ente/server/pkg/controller/user/srp.go
2024-03-01 13:37:01 +05:30

230 lines
6.9 KiB
Go

package user
import (
"context"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/utils/auth"
"github.com/ente-io/stacktrace"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/kong/go-srp"
"github.com/sirupsen/logrus"
"net/http"
)
// Srp4096Params is the value handed to srp.GetParams whenever this package
// builds an SRP server. Presumably it selects the 4096-bit SRP group —
// TODO confirm against the go-srp documentation.
const Srp4096Params = 4096
// SetupSRP begins an SRP setup for userID.
//
// It creates an SRP session from the SRP user ID, verifier and client value A
// carried in req, then records a temporary setup entry through
// UserAuthRepo.InsertTempSRPSetup that references the new session. The
// response contains the setup ID and the server value B.
func (c *UserController) SetupSRP(context *gin.Context, userID int64, req ente.SetupSRPRequest) (*ente.SetupSRPResponse, error) {
	serverB, sessID, err := c.createAndInsertSRPSession(context, req.SrpUserID, req.SRPVerifier, req.SRPA)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}

	tempSetupID, err := c.UserAuthRepo.InsertTempSRPSetup(context, req, userID, sessID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "failed to add entry in setup table")
	}

	resp := ente.SetupSRPResponse{
		SetupID: *tempSetupID,
		SRPB:    *serverB,
	}
	return &resp, nil
}
// CompleteSRPSetup finishes an SRP setup.
//
// It loads the temporary setup identified by req.SetupID, verifies the client
// proof req.SRPM1 against the session referenced by that setup, and on success
// inserts the SRP auth entry (SRP user ID, verifier, salt) for the user ID read
// from the request headers. The response carries the server proof M2.
//
// NOTE(review): nothing in this method checks that the temporary setup belongs
// to the user ID read from the headers — confirm the repository or the route
// enforces that.
func (c *UserController) CompleteSRPSetup(context *gin.Context, req ente.CompleteSRPSetupRequest) (*ente.CompleteSRPSetupResponse, error) {
	userID := auth.GetUserID(context.Request.Header)
	setup, err := c.UserAuthRepo.GetTempSRPSetupEntity(context, req.SetupID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	srpM2, err := c.verifySRPSession(context, setup.Verifier, setup.SessionID, req.SRPM1)
	if err != nil {
		// Propagate so the stack trace records this call site, matching the
		// other error paths in this method.
		return nil, stacktrace.Propagate(err, "")
	}
	err = c.UserAuthRepo.InsertSRPAuth(context, userID, setup.SRPUserID, setup.Verifier, setup.Salt)
	if err != nil {
		return nil, stacktrace.Propagate(err, "failed to add entry in srp auth")
	}
	return &ente.CompleteSRPSetupResponse{
		SetupID: req.SetupID,
		SRPM2:   *srpM2,
	}, nil
}
// UpdateSrpAndKeyAttributes updates the SRP and key attributes once the SRP
// setup identified by req.SetupID has been verified.
//
// It loads the temporary setup, checks the client proof req.SRPM1 against the
// session referenced by that setup, and then calls
// UserAuthRepo.InsertOrUpdateSRPAuthAndKeyAttr for userID. When
// shouldClearTokens is true, RemoveAllOtherTokens is invoked with the token of
// the current request; otherwise an info line is logged. The response carries
// the server proof M2.
func (c *UserController) UpdateSrpAndKeyAttributes(context *gin.Context,
	userID int64,
	req ente.UpdateSRPAndKeysRequest,
	shouldClearTokens bool,
) (*ente.UpdateSRPSetupResponse, error) {
	setup, err := c.UserAuthRepo.GetTempSRPSetupEntity(context, req.SetupID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	srpM2, err := c.verifySRPSession(context, setup.Verifier, setup.SessionID, req.SRPM1)
	if err != nil {
		// Propagate so the stack trace records this call site, matching the
		// other error paths in this method.
		return nil, stacktrace.Propagate(err, "")
	}
	err = c.UserAuthRepo.InsertOrUpdateSRPAuthAndKeyAttr(context, userID, req, setup)
	if err != nil {
		return nil, stacktrace.Propagate(err, "failed to add entry in srp auth")
	}
	if shouldClearTokens {
		token := auth.GetToken(context)
		err = c.UserAuthRepo.RemoveAllOtherTokens(userID, token)
		if err != nil {
			return nil, stacktrace.Propagate(err, "")
		}
	} else {
		logrus.WithField("user_id", userID).Info("not clearing tokens")
	}
	return &ente.UpdateSRPSetupResponse{
		SetupID: req.SetupID,
		SRPM2:   *srpM2,
	}, nil
}
// GetSRPAttributes returns the SRP attributes of the user registered under
// email.
//
// The user ID is resolved from the email first; any failure of that lookup is
// propagated with the message "user does not exist". The context argument is
// not used by this method.
func (c *UserController) GetSRPAttributes(context *gin.Context, email string) (*ente.GetSRPAttributesResponse, error) {
	uid, err := c.UserRepo.GetUserIDWithEmail(email)
	if err != nil {
		return nil, stacktrace.Propagate(err, "user does not exist")
	}

	attrs, err := c.UserAuthRepo.GetSRPAttributes(uid)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}

	return attrs, nil
}
// CreateSrpSession starts an SRP login session for req.SRPUserID.
//
// It looks up the stored SRP auth entity, refuses to continue with an
// *ente.ApiError (code "EMAIL_MFA_ENABLED", HTTP 409 Conflict) when email MFA
// is enabled for the owning user, and otherwise creates a session from the
// stored verifier and the client value req.SRPA. The response carries the
// server value B and the new session ID.
func (c *UserController) CreateSrpSession(context *gin.Context, req ente.CreateSRPSessionRequest) (*ente.CreateSRPSessionResponse, error) {
	srpAuthEntity, err := c.UserAuthRepo.GetSRPAuthEntityBySRPUserID(context, req.SRPUserID)
	if err != nil {
		// Propagate so the stack trace records this call site, matching the
		// other error paths in this method.
		return nil, stacktrace.Propagate(err, "")
	}
	isEmailMFAEnabled, err := c.UserAuthRepo.IsEmailMFAEnabled(context, srpAuthEntity.UserID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	if *isEmailMFAEnabled {
		return nil, stacktrace.Propagate(&ente.ApiError{
			Code:           "EMAIL_MFA_ENABLED",
			Message:        "Email MFA is enabled",
			HttpStatusCode: http.StatusConflict,
		}, "email mfa is enabled")
	}
	srpBBase64, sessionID, err := c.createAndInsertSRPSession(context, req.SRPUserID, srpAuthEntity.Verifier, req.SRPA)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return &ente.CreateSRPSessionResponse{
		SRPB:      *srpBBase64,
		SessionID: *sessionID,
	}, nil
}
// VerifySRPSession completes an SRP login.
//
// It loads the SRP auth entity for req.SRPUserID, verifies the client proof
// req.SRPM1 against the session req.SessionID using the stored verifier,
// fetches the owning user, and runs onVerificationSuccess with that user's
// email. The server proof M2 is attached to the returned response.
func (c *UserController) VerifySRPSession(context *gin.Context, req ente.VerifySRPSessionRequest) (*ente.EmailAuthorizationResponse, error) {
	srpAuthEntity, err := c.UserAuthRepo.GetSRPAuthEntityBySRPUserID(context, req.SRPUserID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	srpM2, err := c.verifySRPSession(context, srpAuthEntity.Verifier, req.SessionID, req.SRPM1)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	user, err := c.UserRepo.Get(srpAuthEntity.UserID)
	if err != nil {
		// Propagate so the stack trace records this call site, matching the
		// other error paths in this method.
		return nil, stacktrace.Propagate(err, "")
	}
	verResponse, err := c.onVerificationSuccess(context, user.Email, nil)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	verResponse.SrpM2 = srpM2
	return &verResponse, nil
}
// createAndInsertSRPSession builds an SRP server from srpVerifier and a freshly
// generated server secret, feeds it the client value srpA, and computes the
// server value B. The secret and srpA are then stored through
// UserAuthRepo.AddSRPSession under srpUserID.
//
// It returns the string-encoded B and the ID of the stored session. The
// gContext argument is not used by this method.
//
// NOTE(review): srpA originates from the request; confirm how SetA and
// convertStringToBytes behave on malformed input.
func (c *UserController) createAndInsertSRPSession(
	gContext *gin.Context,
	srpUserID uuid.UUID,
	srpVerifier string,
	srpA string,
) (*string, *uuid.UUID, error) {
	secret := srp.GenKey()
	server := srp.NewServer(srp.GetParams(Srp4096Params), convertStringToBytes(srpVerifier), secret)
	if server == nil {
		return nil, nil, stacktrace.NewError("server is nil")
	}
	server.SetA(convertStringToBytes(srpA))

	bBytes := server.ComputeB()
	if bBytes == nil {
		return nil, nil, stacktrace.NewError("srpB is nil")
	}

	// Persist the secret and A so a later request can rebuild the server.
	sessionID, err := c.UserAuthRepo.AddSRPSession(srpUserID, convertBytesToString(secret), srpA)
	if err != nil {
		return nil, nil, stacktrace.Propagate(err, "")
	}

	encodedB := convertBytesToString(bBytes)
	return &encodedB, &sessionID, nil
}
// verifySRPSession checks the client proof srpM1 for the stored session
// sessionID and returns the string-encoded server proof M2.
//
// A session that is already verified, or that has five or more recorded
// attempts, is rejected with an *ente.ApiError carrying HTTP 410 Gone. A proof
// that fails CheckM1 increments the session's attempt count and yields
// ente.ErrInvalidPassword; a valid proof marks the session as verified.
func (c *UserController) verifySRPSession(ctx context.Context,
	srpVerifier string,
	sessionID uuid.UUID,
	srpM1 string,
) (*string, error) {
	session, err := c.UserAuthRepo.GetSrpSessionEntity(ctx, sessionID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}

	switch {
	case session.IsVerified:
		return nil, stacktrace.Propagate(&ente.ApiError{
			Code:           "SESSION_ALREADY_VERIFIED",
			HttpStatusCode: http.StatusGone,
		}, "")
	case session.AttemptCount >= 5:
		return nil, stacktrace.Propagate(&ente.ApiError{
			Code:           "TOO_MANY_WRONG_ATTEMPTS",
			HttpStatusCode: http.StatusGone,
		}, "")
	}

	// Rebuild the server from the stored secret and client value A.
	server := srp.NewServer(
		srp.GetParams(Srp4096Params),
		convertStringToBytes(srpVerifier),
		convertStringToBytes(session.ServerKey),
	)
	if server == nil {
		return nil, stacktrace.NewError("server is nil")
	}
	server.SetA(convertStringToBytes(session.SRP_A))

	m2Bytes, checkErr := server.CheckM1(convertStringToBytes(srpM1))
	if checkErr != nil {
		// Record the failed attempt before reporting the bad proof.
		if incErr := c.UserAuthRepo.IncrementSrpSessionAttemptCount(ctx, sessionID); incErr != nil {
			return nil, stacktrace.Propagate(incErr, "")
		}
		return nil, stacktrace.Propagate(ente.ErrInvalidPassword, "failed to verify srp session")
	}

	if markErr := c.UserAuthRepo.SetSrpSessionVerified(ctx, sessionID); markErr != nil {
		return nil, stacktrace.Propagate(markErr, "")
	}

	encodedM2 := convertBytesToString(m2Bytes)
	return &encodedM2, nil
}