ente/server/pkg/repo/userauth.go
2024-03-01 13:37:01 +05:30

175 lines
6.8 KiB
Go

package repo
import (
"database/sql"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/utils/network"
"github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
)
// UserAuthRepository provides database access for authentication data: the
// short-lived one-time tokens (OTTs) kept in the otts table and the
// long-lived session tokens kept in the tokens table.
type UserAuthRepository struct {
	// DB is the handle every query of this repository runs against.
	DB *sql.DB
}
// AddOTT stores the one-time token ott for (emailHash, app) with the given
// expirationTime. The expiration is compared elsewhere against
// time.Microseconds, so it should be on that clock.
//
// If the insert conflicts on the unique_otts_emailhash_ott constraint, the
// existing row's creation_time and expiration_time are refreshed instead.
// Its app column is left as it was.
func (repo *UserAuthRepository) AddOTT(emailHash string, app ente.App, ott string, expirationTime int64) error {
	const upsertOTT = `INSERT INTO otts(email_hash, ott, creation_time, expiration_time, app)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT ON CONSTRAINT unique_otts_emailhash_ott DO UPDATE SET creation_time = $3, expiration_time = $4`
	createdAt := time.Microseconds()
	if _, err := repo.DB.Exec(upsertOTT, emailHash, ott, createdAt, expirationTime, app); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// RemoveOTT deletes the OTT row matching both emailHash and ott. It is meant
// to be called once the OTT has been consumed.
func (repo *UserAuthRepository) RemoveOTT(emailHash string, ott string) error {
	const deleteOTT = `DELETE FROM otts WHERE email_hash = $1 AND ott = $2`
	if _, err := repo.DB.Exec(deleteOTT, emailHash, ott); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// RemoveExpiredOTTs purges every OTT, for any user and app, whose
// expiration_time is at or before the current time.microseconds value.
func (repo *UserAuthRepository) RemoveExpiredOTTs() error {
	now := time.Microseconds()
	_, err := repo.DB.Exec(`DELETE FROM otts WHERE expiration_time <= $1`, now)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// GetTokenCreationTime returns the creation_time recorded for token.
//
// The lookup filters on neither app nor is_deleted, so tokens already marked
// as deleted still match. When no row matches, or the query fails, the Scan
// error (e.g. sql.ErrNoRows) is wrapped and returned together with 0.
func (repo *UserAuthRepository) GetTokenCreationTime(token string) (int64, error) {
	var creationTime int64
	err := repo.DB.QueryRow(`SELECT creation_time from tokens where token = $1`, token).Scan(&creationTime)
	if err != nil {
		return 0, stacktrace.Propagate(err, "Failed to scan row")
	}
	return creationTime, nil
}
// GetValidOTTs returns the OTTs stored for emailHash and app whose
// expiration_time is still later than the current time.Microseconds value.
//
// When no valid OTT exists, an empty non-nil slice is returned. On any error
// (query, row scan, or row iteration), the returned slice is nil.
func (repo *UserAuthRepository) GetValidOTTs(emailHash string, app ente.App) ([]string, error) {
	rows, err := repo.DB.Query(`SELECT ott FROM otts WHERE email_hash = $1 AND app = $2 AND expiration_time > $3`,
		emailHash, app, time.Microseconds())
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	defer rows.Close()
	otts := make([]string, 0)
	for rows.Next() {
		var ott string
		if err := rows.Scan(&ott); err != nil {
			// Don't hand back a partially filled list alongside an error.
			return nil, stacktrace.Propagate(err, "")
		}
		otts = append(otts, ott)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails (e.g. a dropped connection). Without this check, a
	// truncated list would be reported as complete.
	if err := rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return otts, nil
}
// GetMaxWrongAttempts returns the highest wrong_attempt count among the
// unexpired OTTs of emailHash, across all apps. COALESCE makes the result 0
// when there is no unexpired OTT.
func (repo *UserAuthRepository) GetMaxWrongAttempts(emailHash string) (int, error) {
	const maxAttemptsQuery = `SELECT COALESCE(MAX(wrong_attempt),0) FROM otts WHERE email_hash = $1 AND expiration_time > $2`
	var maxAttempts int
	row := repo.DB.QueryRow(maxAttemptsQuery, emailHash, time.Microseconds())
	if err := row.Scan(&maxAttempts); err != nil {
		return 0, stacktrace.Propagate(err, "Failed to scan row")
	}
	return maxAttempts, nil
}
// RecordWrongAttemptForActiveOtt increments wrong_attempt on every unexpired
// OTT row of emailHash, across all apps.
//
// Assuming expired OTTs keep being deleted, MAX(wrong_attempt) over the
// active rows (see GetMaxWrongAttempts) can be used to track brute-force
// attempts.
func (repo *UserAuthRepository) RecordWrongAttemptForActiveOtt(emailHash string) error {
	const incrementQuery = `UPDATE otts SET wrong_attempt = otts.wrong_attempt + 1
WHERE email_hash = $1 AND expiration_time > $2`
	now := time.Microseconds()
	if _, err := repo.DB.Exec(incrementQuery, emailHash, now); err != nil {
		return stacktrace.Propagate(err, "Failed to update wrong attempt count")
	}
	return nil
}
// AddToken inserts a new long-lived token for userID and app. The row gets
// the current time.Microseconds value as its creation_time, plus the
// client's ip and userAgent.
func (repo *UserAuthRepository) AddToken(userID int64, app ente.App, token string, ip string, userAgent string) error {
	const insertToken = `INSERT INTO tokens(user_id, app, token, creation_time, ip, user_agent) VALUES($1, $2, $3, $4, $5, $6)`
	createdAt := time.Microseconds()
	if _, err := repo.DB.Exec(insertToken, userID, app, token, createdAt, ip, userAgent); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// GetUserIDWithToken returns the user_id owning token for app, considering
// only tokens that are not marked as deleted. When no such token exists, or
// the query fails, it returns -1 and the wrapped Scan error.
func (repo *UserAuthRepository) GetUserIDWithToken(token string, app ente.App) (int64, error) {
	const lookupQuery = `SELECT user_id FROM tokens WHERE token = $1 AND app = $2 AND is_deleted = false`
	var userID int64
	if err := repo.DB.QueryRow(lookupQuery, token, app).Scan(&userID); err != nil {
		return -1, stacktrace.Propagate(err, "")
	}
	return userID, nil
}
// RemoveToken soft-deletes token for userID (e.g. on logout) by setting
// is_deleted. The row itself stays in the table; RemoveDeletedTokens can
// purge it later.
func (repo *UserAuthRepository) RemoveToken(userID int64, token string) error {
	const softDelete = `UPDATE tokens SET is_deleted = true WHERE user_id = $1 AND token = $2`
	if _, err := repo.DB.Exec(softDelete, userID, token); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// UpdateLastUsedAt records a fresh use of token (owned by userID). It
// overwrites the row's ip and user_agent and sets last_used_at to the current
// time.Microseconds value.
func (repo *UserAuthRepository) UpdateLastUsedAt(userID int64, token string, ip string, userAgent string) error {
	const touchQuery = `UPDATE tokens SET ip = $1, user_agent = $2, last_used_at = $3 WHERE user_id = $4 AND token = $5`
	usedAt := time.Microseconds()
	if _, err := repo.DB.Exec(touchQuery, ip, userAgent, usedAt, userID, token); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// RemoveAllOtherTokens marks all of userID's tokens as deleted except for
// token itself, across all apps.
func (repo *UserAuthRepository) RemoveAllOtherTokens(userID int64, token string) error {
	const softDeleteOthers = `UPDATE tokens SET is_deleted = true WHERE user_id = $1 AND token <> $2`
	if _, err := repo.DB.Exec(softDeleteOthers, userID, token); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// RemoveDeletedTokens permanently removes tokens that are already marked as
// deleted and whose last_used_at is earlier than expiryTime. last_used_at is
// written from time.Microseconds by UpdateLastUsedAt, so expiryTime should be
// on that clock.
//
// NOTE(review): if last_used_at can be NULL (e.g. for a token that was never
// used), such rows never satisfy `<` and are never purged. Confirm the column
// default.
func (repo *UserAuthRepository) RemoveDeletedTokens(expiryTime int64) error {
	const purgeQuery = `DELETE FROM tokens WHERE is_deleted = true AND last_used_at < $1`
	if _, err := repo.DB.Exec(purgeQuery, expiryTime); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// RemoveAllTokens marks every token belonging to userID as deleted, across
// all apps.
func (repo *UserAuthRepository) RemoveAllTokens(userID int64) error {
	const softDeleteAll = `UPDATE tokens SET is_deleted = true WHERE user_id = $1`
	if _, err := repo.DB.Exec(softDeleteAll, userID); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// GetActiveSessions returns one ente.Session per non-deleted token that
// userID holds for app.
//
// A NULL ip column is reported as "Unknown IP". A NULL user_agent is
// reported as "Unknown Device" in both UA and PrettyUA; otherwise PrettyUA
// is derived via network.GetPrettyUA. When no session exists, an empty
// non-nil slice is returned. On any error (query, row scan, or row
// iteration), the returned slice is nil.
func (repo *UserAuthRepository) GetActiveSessions(userID int64, app ente.App) ([]ente.Session, error) {
	rows, err := repo.DB.Query(`SELECT token, creation_time, ip, user_agent, last_used_at FROM tokens WHERE user_id = $1 AND app = $2 AND is_deleted = false`, userID, app)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	defer rows.Close()
	sessions := make([]ente.Session, 0)
	for rows.Next() {
		// ip and user_agent are scanned as NullString because they may be NULL.
		var ip, userAgent sql.NullString
		var session ente.Session
		if err := rows.Scan(&session.Token, &session.CreationTime, &ip, &userAgent, &session.LastUsedTime); err != nil {
			return nil, stacktrace.Propagate(err, "")
		}
		if ip.Valid {
			session.IP = ip.String
		} else {
			session.IP = "Unknown IP"
		}
		if userAgent.Valid {
			session.UA = userAgent.String
			session.PrettyUA = network.GetPrettyUA(userAgent.String)
		} else {
			session.UA = "Unknown Device"
			session.PrettyUA = "Unknown Device"
		}
		sessions = append(sessions, session)
	}
	// rows.Next also returns false when iteration fails. Surface that error
	// instead of returning a silently truncated session list.
	if err := rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return sessions, nil
}