ente/server/pkg/repo/passkey/passkey.go
2024-03-01 13:37:01 +05:30

484 lines
11 KiB
Go

package passkey
import (
"database/sql"
"encoding/base64"
"encoding/json"
"net/http"
"strings"
"time"
ente_time "github.com/ente-io/museum/pkg/utils/time"
"github.com/ente-io/stacktrace"
"github.com/go-webauthn/webauthn/protocol"
"github.com/google/uuid"
"github.com/spf13/viper"
"github.com/ente-io/museum/ente"
"github.com/ente-io/museum/pkg/utils/byteMarshaller"
"github.com/go-webauthn/webauthn/webauthn"
)
// Repository bundles the SQL database handle with the WebAuthn instance that
// the passkey registration and login methods of this package operate on.
type Repository struct {
// DB is the database handle the repository's queries run against.
DB *sql.DB
// webAuthnInstance is the WebAuthn object built from configuration in
// NewRepository and used for the Begin*/Finish* ceremony calls.
webAuthnInstance *webauthn.WebAuthn
}
// PasskeyUser pairs an ente.User with the Repository used to load that user's
// stored credentials. It is the value handed to the WebAuthn instance's
// BeginRegistration, BeginLogin, FinishRegistration and FinishLogin calls.
type PasskeyUser struct {
// User is embedded, so its fields (ID, Email, Name, ...) are promoted.
*ente.User
// repo is consulted by WebAuthnCredentials to fetch the user's credentials.
repo *Repository
}
// WebAuthnID returns the user's numeric ID encoded as bytes by
// byteMarshaller.ConvertInt64ToByte.
func (u *PasskeyUser) WebAuthnID() []byte {
	// The method has no error result, so a conversion error is discarded and
	// whatever bytes the converter produced are returned.
	handle, _ := byteMarshaller.ConvertInt64ToByte(u.User.ID)
	return handle
}
// WebAuthnName returns the email of the wrapped user.
func (u *PasskeyUser) WebAuthnName() string {
	account := u.User
	return account.Email
}
// WebAuthnDisplayName returns the name of the wrapped user.
func (u *PasskeyUser) WebAuthnDisplayName() string {
	account := u.User
	return account.Name
}
// WebAuthnCredentials returns the credentials stored for the user, as loaded
// by Repository.GetUserPasskeyCredentials.
//
// The method has no error result: when the lookup fails, the error is dropped
// and an empty (non-nil) slice is returned instead.
func (u *PasskeyUser) WebAuthnCredentials() []webauthn.Credential {
	if stored, err := u.repo.GetUserPasskeyCredentials(u.User.ID); err == nil {
		return stored
	}
	return []webauthn.Credential{}
}
// WebAuthnIcon always returns the empty string.
func (u *PasskeyUser) WebAuthnIcon() string {
	// The icon part of the specification is deprecated; the method exists
	// only because the interface still requires it.
	var none string
	return none
}
// NewRepository builds a Repository around db and a WebAuthn instance
// configured from viper.
//
// Configuration keys read:
//   - "webauthn.rpid": relying-party ID; falls back to "accounts.ente.io"
//     when empty.
//   - "webauthn.rporigins": list of relying-party origins, passed through
//     unchanged.
//
// Login and registration ceremonies both use an enforced five-minute timeout.
// If webauthn.New rejects the configuration, its error is returned as-is and
// the repository is nil.
func NewRepository(
	db *sql.DB,
) (repo *Repository, err error) {
	const ceremonyTimeout = 5 * time.Minute

	relyingPartyID := viper.GetString("webauthn.rpid")
	if relyingPartyID == "" {
		relyingPartyID = "accounts.ente.io"
	}

	// The same timeout policy applies to both ceremonies.
	ceremony := webauthn.TimeoutConfig{
		Enforce: true,
		Timeout: ceremonyTimeout,
	}
	instance, err := webauthn.New(&webauthn.Config{
		RPDisplayName: "Ente",
		RPID:          relyingPartyID,
		RPOrigins:     viper.GetStringSlice("webauthn.rporigins"),
		Timeouts: webauthn.TimeoutsConfig{
			Login:        ceremony,
			Registration: ceremony,
		},
	})
	if err != nil {
		return nil, err
	}
	return &Repository{DB: db, webAuthnInstance: instance}, nil
}
// GetUserPasskeys returns the passkeys of userID whose deleted_at column is
// NULL, i.e. those that have not been soft-deleted.
//
// The returned slice is nil when the user has no passkeys. Any query, scan or
// row-iteration failure is returned wrapped by stacktrace together with a nil
// slice.
func (r *Repository) GetUserPasskeys(userID int64) (passkeys []ente.Passkey, err error) {
	rows, err := r.DB.Query(`
SELECT id, user_id, friendly_name, created_at
FROM passkeys
WHERE user_id = $1 AND deleted_at IS NULL
`, userID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	defer rows.Close()

	for rows.Next() {
		var passkey ente.Passkey
		if err = rows.Scan(
			&passkey.ID,
			&passkey.UserID,
			&passkey.FriendlyName,
			&passkey.CreatedAt,
		); err != nil {
			return nil, stacktrace.Propagate(err, "")
		}
		passkeys = append(passkeys, passkey)
	}
	// rows.Next also returns false when iteration fails part-way through;
	// without this check a truncated list would be reported as success.
	if err = rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return passkeys, nil
}
// CreateBeginRegistrationData starts a WebAuthn registration ceremony for
// user and persists the resulting session.
//
// It first loads the user's existing credentials and refuses to continue when
// the user already holds ente.MaxPasskeys or more. On success it returns the
// credential-creation options, the session data produced by
// BeginRegistration, and the freshly generated UUID under which the session
// was saved. Every failure is returned wrapped by stacktrace.
func (r *Repository) CreateBeginRegistrationData(user *ente.User) (options *protocol.CredentialCreation, session *webauthn.SessionData, id uuid.UUID, err error) {
	// Load the credentials directly instead of going through
	// PasskeyUser.WebAuthnCredentials: that method replaces a lookup error
	// with an empty slice, which would let a failed query slip past the
	// limit check below.
	existing, err := r.GetUserPasskeyCredentials(user.ID)
	if err != nil {
		err = stacktrace.Propagate(err, "")
		return
	}
	if len(existing) >= ente.MaxPasskeys {
		err = stacktrace.NewError(ente.ErrMaxPasskeysReached.Error())
		return
	}

	passkeyUser := &PasskeyUser{
		User: user,
		repo: r,
	}
	options, session, err = r.webAuthnInstance.BeginRegistration(passkeyUser)
	if err != nil {
		err = stacktrace.Propagate(err, "")
		return
	}

	// Persist the session so the matching finish call can load it by ID.
	marshalledSessionData, err := r.marshalSessionDataToWebAuthnSession(session)
	if err != nil {
		err = stacktrace.Propagate(err, "")
		return
	}
	id = uuid.New()
	if err = r.saveSessionData(id, marshalledSessionData); err != nil {
		err = stacktrace.Propagate(err, "")
		return
	}
	return
}
// AddPasskeyTwoFactorSession inserts a row into passkey_login_sessions linking
// sessionID to userID. creation_time is set to the current value of
// ente_time.Microseconds(); expiration_time is stored exactly as supplied.
//
// The result of the insert is passed through stacktrace.Propagate.
func (r *Repository) AddPasskeyTwoFactorSession(userID int64, sessionID string, expirationTime int64) error {
	const insertSession = `INSERT INTO passkey_login_sessions(user_id, session_id, creation_time, expiration_time) VALUES($1, $2, $3, $4)`

	createdAt := ente_time.Microseconds()
	_, execErr := r.DB.Exec(insertSession, userID, sessionID, createdAt, expirationTime)
	return stacktrace.Propagate(execErr, "")
}
// GetUserIDWithPasskeyTwoFactorSession returns the user ID recorded for
// sessionID in passkey_login_sessions, provided the session has not expired.
//
// A session counts as expired once its expiration_time is less than or equal
// to ente_time.Microseconds(), the same rule RemoveExpiredPasskeySessions uses
// when deleting rows. Filtering here means an expired session is rejected even
// if the cleanup has not run yet.
//
// The error from Scan is returned unwrapped; per database/sql it is
// sql.ErrNoRows when no matching, unexpired session exists.
func (r *Repository) GetUserIDWithPasskeyTwoFactorSession(sessionID string) (userID int64, err error) {
	err = r.DB.QueryRow(
		`SELECT user_id FROM passkey_login_sessions WHERE session_id = $1 AND expiration_time > $2`,
		sessionID, ente_time.Microseconds(),
	).Scan(&userID)
	return
}
// CreateBeginAuthenticationData starts a WebAuthn login ceremony for user and
// persists the resulting session.
//
// On success it returns the assertion options, the session data produced by
// BeginLogin, and the freshly generated UUID under which the session was
// saved. Every failure is returned wrapped by stacktrace.
func (r *Repository) CreateBeginAuthenticationData(user *ente.User) (options *protocol.CredentialAssertion, session *webauthn.SessionData, id uuid.UUID, err error) {
	options, session, err = r.webAuthnInstance.BeginLogin(&PasskeyUser{User: user, repo: r})
	if err != nil {
		return options, session, id, stacktrace.Propagate(err, "")
	}

	// Convert the session into its storable form and save it under a new ID
	// so the matching finish call can load it again.
	var storable *ente.WebAuthnSession
	if storable, err = r.marshalSessionDataToWebAuthnSession(session); err != nil {
		return options, session, id, stacktrace.Propagate(err, "")
	}
	id = uuid.New()
	if err = r.saveSessionData(id, storable); err != nil {
		return options, session, id, stacktrace.Propagate(err, "")
	}
	return options, session, id, nil
}
// FinishRegistration completes a registration ceremony that was started under
// sessionID and stores the new passkey for user under friendlyName.
//
// Steps, in order: load the saved session; reject it if it belongs to another
// user or its expiry has passed; let the WebAuthn instance validate req
// against the session; create the passkey row; convert the returned
// credential and store it against that passkey's ID.
//
// Every failure is returned wrapped by (or created through) stacktrace.
func (r *Repository) FinishRegistration(user *ente.User, friendlyName string, req *http.Request, sessionID uuid.UUID) (err error) {
	stored, err := r.getWebAuthnSessionByID(sessionID)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	if stored.UserID != user.ID {
		return stacktrace.NewError("session does not belong to user")
	}

	ceremony, err := stored.SessionData()
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	if ceremony.Expires.Before(time.Now()) {
		return stacktrace.NewError("session expired")
	}

	owner := &PasskeyUser{User: user, repo: r}
	verified, err := r.webAuthnInstance.FinishRegistration(owner, *ceremony, req)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}

	// NOTE(review): the passkey row and its credential are written by two
	// separate helper calls; whether they share a transaction is not visible
	// from this function.
	created, err := r.createPasskey(user.ID, friendlyName)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	record, err := r.marshalCredentialToPasskeyCredential(verified, created.ID)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	if err = r.createPasskeyCredential(record); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// FinishAuthentication completes a login ceremony that was started under
// sessionID for user.
//
// It loads the saved session, rejects it if it belongs to another user or its
// expiry has passed, and then lets the WebAuthn instance validate req against
// the session. A nil return means validation succeeded.
//
// Every failure is returned wrapped by (or created through) stacktrace.
func (r *Repository) FinishAuthentication(user *ente.User, req *http.Request, sessionID uuid.UUID) (err error) {
	stored, err := r.getWebAuthnSessionByID(sessionID)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	if stored.UserID != user.ID {
		return stacktrace.NewError("session does not belong to user")
	}

	ceremony, err := stored.SessionData()
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	if ceremony.Expires.Before(time.Now()) {
		return stacktrace.NewError("session expired")
	}

	// NOTE(review): the credential returned by FinishLogin is discarded and
	// the session row is not removed here; confirm that is intended.
	owner := &PasskeyUser{User: user, repo: r}
	if _, err = r.webAuthnInstance.FinishLogin(owner, *ceremony, req); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// DeletePasskey soft-deletes the passkey passkeyID owned by user by setting
// its deleted_at column to the current ente_time.Microseconds().
//
// The same statement overwrites friendly_name with the passkey's own ID,
// presumably to scrub the user-chosen name (confirm intent). Rows that are
// already soft-deleted, or that belong to another user, are not touched; no
// error is reported when nothing matches.
func (r *Repository) DeletePasskey(user *ente.User, passkeyID uuid.UUID) (err error) {
	const softDelete = `
UPDATE passkeys
SET friendly_name = $1,
deleted_at = $2
WHERE id = $3 AND user_id = $4 AND deleted_at IS NULL
`
	deletedAt := ente_time.Microseconds()
	if _, err = r.DB.Exec(softDelete, passkeyID, deletedAt, passkeyID, user.ID); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// RenamePasskey sets friendly_name to newName for the passkey passkeyID owned
// by user. Soft-deleted passkeys and passkeys of other users are not touched;
// no error is reported when nothing matches.
func (r *Repository) RenamePasskey(user *ente.User, passkeyID uuid.UUID, newName string) (err error) {
	const rename = `
UPDATE passkeys
SET friendly_name = $1
WHERE id = $2 AND user_id = $3 AND deleted_at IS NULL
`
	if _, err = r.DB.Exec(rename, newName, passkeyID, user.ID); err != nil {
		return stacktrace.Propagate(err, "")
	}
	return nil
}
// saveSessionData inserts session into webauthn_sessions under the given id.
// The error from the insert, if any, is returned unwrapped.
func (r *Repository) saveSessionData(id uuid.UUID, session *ente.WebAuthnSession) (err error) {
	const insertSession = `
INSERT INTO webauthn_sessions (
id,
challenge,
user_id,
allowed_credential_ids,
expires_at,
user_verification_requirement,
extensions,
created_at
) VALUES (
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8
)
`
	// Arguments are listed in the same order as the columns above.
	args := []interface{}{
		id,
		session.Challenge,
		session.UserID,
		session.AllowedCredentialIDs,
		session.ExpiresAt,
		session.UserVerificationRequirement,
		session.Extensions,
		session.CreatedAt,
	}
	_, err = r.DB.Exec(insertSession, args...)
	return err
}
// marshalCredentialToPasskeyCredential converts a WebAuthn credential into the
// storable ente.PasskeyCredential linked to passkeyID.
//
// Encoding used for each field:
//   - CredentialID, PublicKey: standard base64 of the raw bytes.
//   - AuthenticatorTransports: transports joined with ",".
//   - CredentialFlags: JSON of cred.Flags.
//   - Authenticator: JSON object with keys AAGUID (standard base64),
//     SignCount, CloneWarning and Attachment.
//   - CreatedAt: current time in Unix microseconds.
//
// A JSON marshalling failure is returned unwrapped with a nil credential.
func (r *Repository) marshalCredentialToPasskeyCredential(cred *webauthn.Credential, passkeyID uuid.UUID) (*ente.PasskeyCredential, error) {
	b64 := base64.StdEncoding

	flagsJSON, err := json.Marshal(cred.Flags)
	if err != nil {
		return nil, err
	}

	// The AAGUID is raw bytes, so it is base64-encoded before being embedded
	// in the JSON object.
	authenticatorJSON, err := json.Marshal(map[string]interface{}{
		"AAGUID":       b64.EncodeToString(cred.Authenticator.AAGUID),
		"SignCount":    cred.Authenticator.SignCount,
		"CloneWarning": cred.Authenticator.CloneWarning,
		"Attachment":   cred.Authenticator.Attachment,
	})
	if err != nil {
		return nil, err
	}

	transportNames := make([]string, len(cred.Transport))
	for i, transport := range cred.Transport {
		transportNames[i] = string(transport)
	}

	return &ente.PasskeyCredential{
		CredentialID:            b64.EncodeToString(cred.ID),
		PasskeyID:               passkeyID,
		PublicKey:               b64.EncodeToString(cred.PublicKey),
		AttestationType:         cred.AttestationType,
		AuthenticatorTransports: strings.Join(transportNames, ","),
		CredentialFlags:         string(flagsJSON),
		Authenticator:           string(authenticatorJSON),
		CreatedAt:               time.Now().UnixMicro(),
	}, nil
}
// marshalSessionDataToWebAuthnSession converts WebAuthn session data into the
// storable ente.WebAuthnSession.
//
// The session's user ID bytes are decoded back to an int64, the allowed
// credential IDs are encoded with byteMarshaller.EncodeSlices, the extensions
// are serialised as JSON, and the expiry and creation time are stored as Unix
// microseconds. A decoding or marshalling failure is returned unwrapped with a
// nil session.
func (r *Repository) marshalSessionDataToWebAuthnSession(session *webauthn.SessionData) (webAuthnSession *ente.WebAuthnSession, err error) {
	ownerID, err := byteMarshaller.ConvertBytesToInt64(session.UserID)
	if err != nil {
		return nil, err
	}
	extensions, err := json.Marshal(session.Extensions)
	if err != nil {
		return nil, err
	}

	webAuthnSession = &ente.WebAuthnSession{
		Challenge:                   session.Challenge,
		UserID:                      ownerID,
		AllowedCredentialIDs:        byteMarshaller.EncodeSlices(session.AllowedCredentialIDs),
		ExpiresAt:                   session.Expires.UnixMicro(),
		UserVerificationRequirement: string(session.UserVerification),
		Extensions:                  string(extensions),
		CreatedAt:                   time.Now().UnixMicro(),
	}
	return webAuthnSession, nil
}
// GetUserPasskeyCredentials returns the WebAuthn credentials attached to the
// passkeys of userID that have not been soft-deleted.
//
// Each passkey_credentials row is scanned into an ente.PasskeyCredential and
// converted with its WebAuthnCredential method. The returned slice is nil when
// the user has no credentials. Any query, scan, conversion or row-iteration
// failure is returned wrapped by stacktrace together with a nil slice.
func (r *Repository) GetUserPasskeyCredentials(userID int64) (credentials []webauthn.Credential, err error) {
	// NOTE(review): the query selects pc.*, so the Scan below depends on the
	// table's column order matching the field order used here.
	rows, err := r.DB.Query(`
SELECT pc.*
FROM passkey_credentials pc
JOIN passkeys p ON pc.passkey_id = p.id
WHERE p.user_id = $1 AND p.deleted_at IS NULL
`, userID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	defer rows.Close()

	for rows.Next() {
		var pc ente.PasskeyCredential
		if err = rows.Scan(
			&pc.PasskeyID,
			&pc.CredentialID,
			&pc.PublicKey,
			&pc.AttestationType,
			&pc.AuthenticatorTransports,
			&pc.CredentialFlags,
			&pc.Authenticator,
			&pc.CreatedAt,
		); err != nil {
			return nil, stacktrace.Propagate(err, "")
		}

		var cred *webauthn.Credential
		if cred, err = pc.WebAuthnCredential(); err != nil {
			return nil, stacktrace.Propagate(err, "")
		}
		credentials = append(credentials, *cred)
	}
	// rows.Next also returns false when iteration fails part-way through;
	// without this check a truncated credential list would be reported as
	// success.
	if err = rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return credentials, nil
}
// RemoveExpiredPasskeySessions deletes expired rows from webauthn_sessions
// (expires_at) and then from passkey_login_sessions (expiration_time). A row
// is expired when its timestamp is less than or equal to the current
// ente_time.Microseconds().
//
// If the first delete fails, the second is not attempted. Errors are wrapped
// by stacktrace.
func (r *Repository) RemoveExpiredPasskeySessions() error {
	if _, err := r.DB.Exec(`DELETE FROM webauthn_sessions WHERE expires_at <= $1`, ente_time.Microseconds()); err != nil {
		return stacktrace.Propagate(err, "")
	}

	// The result is handed to Propagate unconditionally; this relies on
	// Propagate yielding nil for a nil error.
	_, err := r.DB.Exec(`DELETE FROM passkey_login_sessions WHERE expiration_time <= $1`, ente_time.Microseconds())
	return stacktrace.Propagate(err, "")
}