diff --git a/server/migrations/83_fix_ott_index.down.sql b/server/migrations/83_fix_ott_index.down.sql index 571df81a5..5eab1e6a2 100644 --- a/server/migrations/83_fix_ott_index.down.sql +++ b/server/migrations/83_fix_ott_index.down.sql @@ -1,11 +1,9 @@ - --- Add unique index on otts for email_hash,app,ott BEGIN; ALTER TABLE - otts DROP CONSTRAINT unique_otts_emailhash_ott; + otts DROP CONSTRAINT IF EXISTS unique_otts_emailhash_app_ott; ALTER TABLE otts ADD - CONSTRAINT unique_otts_emailhash_app_ott UNIQUE (ott,app, email_hash); -COMMIT; + CONSTRAINT unique_otts_emailhash_ott UNIQUE (ott, email_hash); +COMMIT; \ No newline at end of file diff --git a/server/migrations/83_fix_ott_index.up.sql b/server/migrations/83_fix_ott_index.up.sql index 096ded0d7..4c9015f85 100644 --- a/server/migrations/83_fix_ott_index.up.sql +++ b/server/migrations/83_fix_ott_index.up.sql @@ -1,2 +1,9 @@ -DROP TRIGGER IF EXISTS update_location_tag_updated_at ON location_tag; -DROP TABLE location_tag; +BEGIN; +ALTER TABLE + otts DROP CONSTRAINT IF EXISTS unique_otts_emailhash_ott; + +ALTER TABLE + otts + ADD + CONSTRAINT unique_otts_emailhash_app_ott UNIQUE (ott,app, email_hash); +COMMIT; \ No newline at end of file diff --git a/server/pkg/controller/user/userauth.go b/server/pkg/controller/user/userauth.go index bbc9942de..7247bb18f 100644 --- a/server/pkg/controller/user/userauth.go +++ b/server/pkg/controller/user/userauth.go @@ -140,7 +140,7 @@ func (c *UserController) verifyEmailOtt(context *gin.Context, email string, ott if err != nil { return stacktrace.Propagate(err, "") } - wrongAttempt, err := c.UserAuthRepo.GetMaxWrongAttempts(emailHash) + wrongAttempt, err := c.UserAuthRepo.GetMaxWrongAttempts(emailHash, auth.GetApp(context)) if err != nil { return stacktrace.Propagate(err, "") } @@ -166,12 +166,12 @@ func (c *UserController) verifyEmailOtt(context *gin.Context, email string, ott } } if !isValidOTT { - if err = c.UserAuthRepo.RecordWrongAttemptForActiveOtt(emailHash); err != nil { + if err = c.UserAuthRepo.RecordWrongAttemptForActiveOtt(emailHash, auth.GetApp(context)); err != nil { log.WithError(err).Warn("Failed to track wrong attempt") } return stacktrace.Propagate(ente.ErrIncorrectOTT, "") } - err = c.UserAuthRepo.RemoveOTT(emailHash, ott) + err = c.UserAuthRepo.RemoveOTT(emailHash, ott, auth.GetApp(context)) if err != nil { return stacktrace.Propagate(err, "") } diff --git a/server/pkg/repo/userauth.go b/server/pkg/repo/userauth.go index c5f86e8ec..c182e9e87 100644 --- a/server/pkg/repo/userauth.go +++ b/server/pkg/repo/userauth.go @@ -20,14 +20,14 @@ type UserAuthRepository struct { func (repo *UserAuthRepository) AddOTT(emailHash string, app ente.App, ott string, expirationTime int64) error { _, err := repo.DB.Exec(`INSERT INTO otts(email_hash, ott, creation_time, expiration_time, app) VALUES($1, $2, $3, $4, $5) - ON CONFLICT ON CONSTRAINT unique_otts_emailhash_ott DO UPDATE SET creation_time = $3, expiration_time = $4`, + ON CONFLICT ON CONSTRAINT unique_otts_emailhash_app_ott DO UPDATE SET creation_time = $3, expiration_time = $4`, emailHash, ott, time.Microseconds(), expirationTime, app) return stacktrace.Propagate(err, "") } // RemoveOTT removes the specified OTT (to be used when an OTT has been consumed) -func (repo *UserAuthRepository) RemoveOTT(emailHash string, ott string) error { - _, err := repo.DB.Exec(`DELETE FROM otts WHERE email_hash = $1 AND ott = $2`, emailHash, ott) +func (repo *UserAuthRepository) RemoveOTT(emailHash string, ott string, app ente.App) error { + _, err := repo.DB.Exec(`DELETE FROM otts WHERE email_hash = $1 AND ott = $2 AND app = $3`, emailHash, ott, app) return stacktrace.Propagate(err, "") } @@ -69,9 +69,9 @@ func (repo *UserAuthRepository) GetValidOTTs(emailHash string, app ente.App) ([] return otts, nil } -func (repo *UserAuthRepository) GetMaxWrongAttempts(emailHash string) (int, error) { - row := repo.DB.QueryRow(`SELECT COALESCE(MAX(wrong_attempt),0) FROM otts WHERE email_hash = $1 AND expiration_time > $2`, - emailHash, time.Microseconds()) +func (repo *UserAuthRepository) GetMaxWrongAttempts(emailHash string, app ente.App) (int, error) { + row := repo.DB.QueryRow(`SELECT COALESCE(MAX(wrong_attempt),0) FROM otts WHERE email_hash = $1 AND expiration_time > $2 AND app = $3`, + emailHash, time.Microseconds(), app) var wrongAttempt int if err := row.Scan(&wrongAttempt); err != nil { return 0, stacktrace.Propagate(err, "Failed to scan row") @@ -81,9 +81,9 @@ func (repo *UserAuthRepository) GetMaxWrongAttempts(emailHash string) (int, erro // RecordWrongAttemptForActiveOtt increases the wrong_attempt count for given emailHash and active ott. // Assuming tha we keep deleting expired OTT, max(wrong_attempt) can be used to track brute-force attack -func (repo *UserAuthRepository) RecordWrongAttemptForActiveOtt(emailHash string) error { +func (repo *UserAuthRepository) RecordWrongAttemptForActiveOtt(emailHash string, app ente.App) error { _, err := repo.DB.Exec(`UPDATE otts SET wrong_attempt = otts.wrong_attempt + 1 - WHERE email_hash = $1 AND expiration_time > $2`, emailHash, time.Microseconds()) + WHERE email_hash = $1 AND expiration_time > $2 AND app=$3`, emailHash, time.Microseconds(), app) if err != nil { return stacktrace.Propagate(err, "Failed to update wrong attempt count") }