ente/server/pkg/repo/srp.go
2024-03-01 13:37:01 +05:30

193 lines
8.3 KiB
Go

package repo
import (
"context"
"database/sql"
"errors"
"github.com/ente-io/museum/ente"
"github.com/ente-io/stacktrace"
"github.com/google/uuid"
)
// AddSRPSession stores a new row in srp_sessions for the given SRP user id,
// server key and srpA value, using a freshly generated UUID as the session id.
//
// The generated id is returned together with any error from the insert, so
// callers must check the error before using the id.
func (repo *UserAuthRepository) AddSRPSession(srpUserID uuid.UUID, serverKey string, srpA string) (uuid.UUID, error) {
	const insertSession = `
INSERT INTO srp_sessions(id, srp_user_id, server_key, srp_a)
VALUES($1, $2 , $3, $4)`

	sessionID := uuid.New()
	_, insertErr := repo.DB.Exec(insertSession, sessionID, srpUserID, serverKey, srpA)
	return sessionID, stacktrace.Propagate(insertErr, "")
}
// GetSRPAuthEntity reads the srp_auth row whose user_id equals userID.
//
// Any error from the query or scan (including sql.ErrNoRows when no row
// matches) is passed through stacktrace.Propagate and returned with a nil
// entity.
func (repo *UserAuthRepository) GetSRPAuthEntity(ctx context.Context, userID int64) (*ente.SRPAuthEntity, error) {
	const selectByUserID = `SELECT user_id, srp_user_id, salt, verifier FROM srp_auth WHERE user_id = $1`

	var entity ente.SRPAuthEntity
	if scanErr := repo.DB.QueryRowContext(ctx, selectByUserID, userID).
		Scan(&entity.UserID, &entity.SRPUserID, &entity.Salt, &entity.Verifier); scanErr != nil {
		return nil, stacktrace.Propagate(scanErr, "")
	}
	return &entity, nil
}
// GetSRPAuthEntityBySRPUserID reads the srp_auth row whose srp_user_id equals
// srpUserID.
//
// Any error from the query or scan (including sql.ErrNoRows when no row
// matches) is passed through stacktrace.Propagate and returned with a nil
// entity.
func (repo *UserAuthRepository) GetSRPAuthEntityBySRPUserID(ctx context.Context, srpUserID uuid.UUID) (*ente.SRPAuthEntity, error) {
	const selectBySRPUserID = `SELECT user_id, srp_user_id, salt, verifier FROM srp_auth WHERE srp_user_id = $1`

	var entity ente.SRPAuthEntity
	if scanErr := repo.DB.QueryRowContext(ctx, selectBySRPUserID, srpUserID).
		Scan(&entity.UserID, &entity.SRPUserID, &entity.Salt, &entity.Verifier); scanErr != nil {
		return nil, stacktrace.Propagate(scanErr, "")
	}
	return &entity, nil
}
// IsSRPSetupDone reports whether an srp_auth entry exists for userID.
//
// A lookup that fails with sql.ErrNoRows is treated as "not set up" and yields
// (false, nil); any other lookup failure is returned as an error.
func (repo *UserAuthRepository) IsSRPSetupDone(ctx context.Context, userID int64) (bool, error) {
	_, lookupErr := repo.GetSRPAuthEntity(ctx, userID)
	switch {
	case lookupErr == nil:
		return true, nil
	case errors.Is(lookupErr, sql.ErrNoRows):
		// No srp_auth row means SRP has simply not been configured yet.
		return false, nil
	default:
		return false, stacktrace.Propagate(lookupErr, "failed to read srp attributes")
	}
}
// UpdateEmailMFA sets the email_mfa column of the users row identified by
// userID to isEnabled. The number of affected rows is not inspected.
func (repo *UserAuthRepository) UpdateEmailMFA(ctx context.Context, userID int64, isEnabled bool) error {
	const updateEmailMFA = `UPDATE users SET email_mfa = $1 WHERE user_id = $2`

	if _, execErr := repo.DB.ExecContext(ctx, updateEmailMFA, isEnabled, userID); execErr != nil {
		return stacktrace.Propagate(execErr, "failed to update email MFA status")
	}
	return nil
}
// IsEmailMFAEnabled reads the email_mfa column of the users row identified by
// userID and returns a pointer to its value.
//
// On any query or scan error (including sql.ErrNoRows when the user does not
// exist) it returns a nil pointer and the propagated error.
func (repo *UserAuthRepository) IsEmailMFAEnabled(ctx context.Context, userID int64) (*bool, error) {
	const selectEmailMFA = `SELECT email_mfa FROM users WHERE user_id = $1`

	enabled := new(bool)
	if scanErr := repo.DB.QueryRowContext(ctx, selectEmailMFA, userID).Scan(enabled); scanErr != nil {
		return nil, stacktrace.Propagate(scanErr, "")
	}
	return enabled, nil
}
// InsertTempSRPSetup adds a row to temp_srp_setup holding the SRP user id, salt
// and verifier from req, together with userID and sessionID, under a freshly
// generated UUID.
//
// A pointer to the generated id is returned alongside any error from the
// insert, so callers must check the error before using the id.
func (repo *UserAuthRepository) InsertTempSRPSetup(ctx context.Context, req ente.SetupSRPRequest, userID int64, sessionID *uuid.UUID) (*uuid.UUID, error) {
	const insertSetup = `
INSERT INTO temp_srp_setup(id, session_id, user_id, srp_user_id, salt, verifier) VALUES($1, $2 , $3, $4, $5, $6)`

	setupID := uuid.New()
	_, insertErr := repo.DB.ExecContext(ctx, insertSetup,
		setupID, sessionID, userID, req.SrpUserID, req.SRPSalt, req.SRPVerifier)
	return &setupID, stacktrace.Propagate(insertErr, "")
}
// GetTempSRPSetupEntity reads the temp_srp_setup row whose id equals setUpID.
//
// Any error from the query or scan (including sql.ErrNoRows when no row
// matches) is passed through stacktrace.Propagate and returned with a nil
// entity.
func (repo *UserAuthRepository) GetTempSRPSetupEntity(ctx context.Context, setUpID uuid.UUID) (*ente.SRPSetupEntity, error) {
	const selectSetup = `SELECT id, session_id, user_id, srp_user_id, salt, verifier FROM temp_srp_setup WHERE id = $1`

	var setup ente.SRPSetupEntity
	if scanErr := repo.DB.QueryRowContext(ctx, selectSetup, setUpID).
		Scan(&setup.ID, &setup.SessionID, &setup.UserID, &setup.SRPUserID, &setup.Salt, &setup.Verifier); scanErr != nil {
		return nil, stacktrace.Propagate(scanErr, "")
	}
	return &setup, nil
}
// InsertSRPAuth creates the srp_auth row for userID with the given SRP user
// id, salt and verifier.
//
// It first checks IsSRPSetupDone and returns a bad-request error
// ("SRP setup already complete") when an entry already exists; otherwise the
// result of the insert is returned.
func (repo *UserAuthRepository) InsertSRPAuth(ctx context.Context, userID int64, srpUserID uuid.UUID, verifier string, salt string) error {
	const insertAuth = `
INSERT INTO srp_auth(user_id, srp_user_id, salt, verifier) VALUES($1, $2 , $3, $4)`

	alreadySetUp, checkErr := repo.IsSRPSetupDone(ctx, userID)
	switch {
	case checkErr != nil:
		return stacktrace.Propagate(checkErr, "")
	case alreadySetUp:
		return stacktrace.Propagate(ente.NewBadRequestWithMessage("SRP setup already complete"), "")
	}

	_, insertErr := repo.DB.ExecContext(ctx, insertAuth, userID, srpUserID, salt, verifier)
	return stacktrace.Propagate(insertErr, "")
}
// InsertOrUpdateSRPAuthAndKeyAttr stores the SRP credentials from setup and the
// key attributes from req.UpdateAttributes for userID in a single transaction.
//
// If the user has no srp_auth row yet, one is inserted; otherwise the existing
// row is updated. The user's key_attributes row is then updated. Both writes
// are committed together, or neither is applied.
//
// It returns an error when setup or req.UpdateAttributes is nil, or when any
// database operation (including the commit) fails. The error that caused a
// failure is always the one returned; a failing rollback never replaces it.
//
// NOTE(review): the existence check runs before the transaction starts, so two
// concurrent calls for the same user could both take the INSERT branch.
// Presumably a uniqueness constraint on srp_auth.user_id makes the second one
// fail — confirm against the schema.
func (repo *UserAuthRepository) InsertOrUpdateSRPAuthAndKeyAttr(ctx context.Context, userID int64, req ente.UpdateSRPAndKeysRequest, setup *ente.SRPSetupEntity) error {
	// Validate before touching the database so a missing argument can neither
	// panic nor leave a transaction open.
	if setup == nil || req.UpdateAttributes == nil {
		return stacktrace.Propagate(errors.New("srp setup and key attributes are required"), "")
	}
	updateKeyAttr := *req.UpdateAttributes

	isSRPSetupDone, err := repo.IsSRPSetupDone(ctx, userID)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}

	tx, err := repo.DB.BeginTx(ctx, nil)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	// Roll back on every exit path that did not commit. After a successful
	// Commit this call returns sql.ErrTxDone, which is safe to ignore; on
	// failure paths the original error is more useful than the rollback error.
	defer func() { _ = tx.Rollback() }()

	if !isSRPSetupDone {
		_, err = tx.ExecContext(ctx, `
INSERT INTO srp_auth(user_id, srp_user_id, salt, verifier) VALUES($1, $2 , $3, $4)`,
			userID, setup.SRPUserID, setup.Salt, setup.Verifier)
	} else {
		_, err = tx.ExecContext(ctx, `UPDATE srp_auth SET srp_user_id = $1, salt = $2, verifier = $3 WHERE user_id = $4`,
			setup.SRPUserID, setup.Salt, setup.Verifier, userID)
	}
	if err != nil {
		return stacktrace.Propagate(err, "")
	}

	_, err = tx.ExecContext(ctx, `UPDATE key_attributes SET kek_salt = $1, encrypted_key = $2, key_decryption_nonce = $3, mem_limit = $4, ops_limit = $5 WHERE user_id = $6`,
		updateKeyAttr.KEKSalt, updateKeyAttr.EncryptedKey, updateKeyAttr.KeyDecryptionNonce, updateKeyAttr.MemLimit, updateKeyAttr.OpsLimit, userID)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}

	return stacktrace.Propagate(tx.Commit(), "")
}
// GetSrpSessionEntity reads the srp_sessions row whose id equals sessionID,
// including its verification flag and attempt count.
//
// Any error from the query or scan (including sql.ErrNoRows when no row
// matches) is passed through stacktrace.Propagate and returned with a nil
// entity.
func (repo *UserAuthRepository) GetSrpSessionEntity(ctx context.Context, sessionID uuid.UUID) (*ente.SRPSessionEntity, error) {
	const selectSession = `SELECT id, srp_user_id, server_key, srp_a, has_verified, attempt_count FROM srp_sessions WHERE id = $1`

	var session ente.SRPSessionEntity
	if scanErr := repo.DB.QueryRowContext(ctx, selectSession, sessionID).
		Scan(&session.ID, &session.SRPUserID, &session.ServerKey, &session.SRP_A, &session.IsVerified, &session.AttemptCount); scanErr != nil {
		return nil, stacktrace.Propagate(scanErr, "")
	}
	return &session, nil
}
// IncrementSrpSessionAttemptCount adds one to the attempt_count column of the
// srp_sessions row identified by sessionID. The increment is performed by the
// database in a single UPDATE; the number of affected rows is not inspected.
func (repo *UserAuthRepository) IncrementSrpSessionAttemptCount(ctx context.Context, sessionID uuid.UUID) error {
	const bumpAttempts = `UPDATE srp_sessions SET attempt_count = attempt_count + 1 WHERE id = $1`

	_, execErr := repo.DB.ExecContext(ctx, bumpAttempts, sessionID)
	return stacktrace.Propagate(execErr, "")
}
// SetSrpSessionVerified sets the has_verified column of the srp_sessions row
// identified by sessionID to true. The number of affected rows is not
// inspected.
func (repo *UserAuthRepository) SetSrpSessionVerified(ctx context.Context, sessionID uuid.UUID) error {
	const markVerified = `UPDATE srp_sessions SET has_verified = true WHERE id = $1`

	_, execErr := repo.DB.ExecContext(ctx, markVerified, sessionID)
	return stacktrace.Propagate(execErr, "")
}
// GetSRPAttributes returns the SRP attributes of a user: the SRP user id and
// salt from srp_auth, the mem_limit, ops_limit and kek_salt from
// key_attributes, and the email_mfa flag from users.
//
// It returns ente.ErrNotFoundError when the user has no srp_auth row. When the
// srp_auth row exists but mem_limit is NULL (the left join found no key
// attributes), the srp_auth row is deleted and ente.ErrNotFoundError is
// returned. Any other failure is returned as a propagated error.
func (repo *UserAuthRepository) GetSRPAttributes(userID int64) (*ente.GetSRPAttributesResponse, error) {
	row := repo.DB.QueryRow(`SELECT srp_user_id, salt, mem_limit, ops_limit, kek_salt, email_mfa FROM srp_auth left join key_attributes on srp_auth.user_id = key_attributes.user_id
left join users on users.user_id = srp_auth.user_id WHERE srp_auth.user_id = $1`, userID)

	// The key_attributes columns come from a left join and are NULL when the
	// user has no key attributes, so they are scanned into nullable holders
	// instead of relying on the text of a scan error to detect that case.
	var (
		srpAttributes      ente.GetSRPAttributesResponse
		memLimit, opsLimit sql.NullInt64
		kekSalt            sql.NullString
	)
	err := row.Scan(&srpAttributes.SRPUserID, &srpAttributes.SRPSalt, &memLimit, &opsLimit, &kekSalt, &srpAttributes.IsEmailMFAEnabled)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, stacktrace.Propagate(&ente.ErrNotFoundError, "srp attributes are not present")
		}
		return nil, stacktrace.Propagate(err, "failed to read srp attributes")
	}

	if !memLimit.Valid {
		// The user doesn't have key attributes: delete the srp auth entry so
		// that the user can set up SRP again from scratch along with the key
		// attributes. This can happen if the key attributes setup API fails
		// but the SRP setup API succeeds.
		// TODO: create a single API for both key attributes and srp setup
		if _, delErr := repo.DB.Exec(`DELETE FROM srp_auth WHERE user_id = $1`, userID); delErr != nil {
			return nil, stacktrace.Propagate(delErr, "")
		}
		return nil, stacktrace.Propagate(&ente.ErrNotFoundError, "key attributes are not present")
	}
	if !opsLimit.Valid || !kekSalt.Valid {
		// A key_attributes row exists but is incomplete; report it as a read
		// failure rather than returning zero values.
		return nil, stacktrace.Propagate(errors.New("ops_limit or kek_salt is NULL"), "failed to read srp attributes")
	}

	// NOTE(review): assumes MemLimit and OpsLimit are int and KekSalt is
	// string — confirm against ente.GetSRPAttributesResponse.
	srpAttributes.MemLimit = int(memLimit.Int64)
	srpAttributes.OpsLimit = int(opsLimit.Int64)
	srpAttributes.KekSalt = kekSalt.String
	return &srpAttributes, nil
}