484 lines
11 KiB
Go
484 lines
11 KiB
Go
package passkey
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
ente_time "github.com/ente-io/museum/pkg/utils/time"
|
|
"github.com/ente-io/stacktrace"
|
|
"github.com/go-webauthn/webauthn/protocol"
|
|
"github.com/google/uuid"
|
|
"github.com/spf13/viper"
|
|
|
|
"github.com/ente-io/museum/ente"
|
|
"github.com/ente-io/museum/pkg/utils/byteMarshaller"
|
|
"github.com/go-webauthn/webauthn/webauthn"
|
|
)
|
|
|
|
type Repository struct {
|
|
DB *sql.DB
|
|
webAuthnInstance *webauthn.WebAuthn
|
|
}
|
|
|
|
type PasskeyUser struct {
|
|
*ente.User
|
|
repo *Repository
|
|
}
|
|
|
|
func (u *PasskeyUser) WebAuthnID() []byte {
|
|
b, _ := byteMarshaller.ConvertInt64ToByte(u.ID)
|
|
return b
|
|
}
|
|
|
|
func (u *PasskeyUser) WebAuthnName() string {
|
|
return u.Email
|
|
}
|
|
|
|
func (u *PasskeyUser) WebAuthnDisplayName() string {
|
|
return u.Name
|
|
}
|
|
|
|
func (u *PasskeyUser) WebAuthnCredentials() []webauthn.Credential {
|
|
creds, err := u.repo.GetUserPasskeyCredentials(u.ID)
|
|
if err != nil {
|
|
return []webauthn.Credential{}
|
|
}
|
|
|
|
return creds
|
|
}
|
|
|
|
func (u *PasskeyUser) WebAuthnIcon() string {
|
|
// this specification is deprecated but the interface requires it
|
|
return ""
|
|
}
|
|
|
|
func NewRepository(
|
|
db *sql.DB,
|
|
) (repo *Repository, err error) {
|
|
rpId := viper.GetString("webauthn.rpid")
|
|
if rpId == "" {
|
|
rpId = "accounts.ente.io"
|
|
}
|
|
rpOrigins := viper.GetStringSlice("webauthn.rporigins")
|
|
|
|
wconfig := &webauthn.Config{
|
|
RPDisplayName: "Ente",
|
|
RPID: rpId,
|
|
RPOrigins: rpOrigins,
|
|
Timeouts: webauthn.TimeoutsConfig{
|
|
Login: webauthn.TimeoutConfig{
|
|
Enforce: true,
|
|
Timeout: time.Duration(5) * time.Minute,
|
|
},
|
|
Registration: webauthn.TimeoutConfig{
|
|
Enforce: true,
|
|
Timeout: time.Duration(5) * time.Minute,
|
|
},
|
|
},
|
|
}
|
|
|
|
webAuthnInstance, err := webauthn.New(wconfig)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
repo = &Repository{
|
|
DB: db,
|
|
webAuthnInstance: webAuthnInstance,
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Repository) GetUserPasskeys(userID int64) (passkeys []ente.Passkey, err error) {
|
|
rows, err := r.DB.Query(`
|
|
SELECT id, user_id, friendly_name, created_at
|
|
FROM passkeys
|
|
WHERE user_id = $1 AND deleted_at IS NULL
|
|
`, userID)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var passkey ente.Passkey
|
|
if err = rows.Scan(
|
|
&passkey.ID,
|
|
&passkey.UserID,
|
|
&passkey.FriendlyName,
|
|
&passkey.CreatedAt,
|
|
); err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
passkeys = append(passkeys, passkey)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Repository) CreateBeginRegistrationData(user *ente.User) (options *protocol.CredentialCreation, session *webauthn.SessionData, id uuid.UUID, err error) {
|
|
passkeyUser := &PasskeyUser{
|
|
User: user,
|
|
repo: r,
|
|
}
|
|
|
|
if len(passkeyUser.WebAuthnCredentials()) >= ente.MaxPasskeys {
|
|
err = stacktrace.NewError(ente.ErrMaxPasskeysReached.Error())
|
|
return
|
|
}
|
|
|
|
options, session, err = r.webAuthnInstance.BeginRegistration(passkeyUser)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
// save session data
|
|
marshalledSessionData, err := r.marshalSessionDataToWebAuthnSession(session)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
id = uuid.New()
|
|
|
|
err = r.saveSessionData(id, marshalledSessionData)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Repository) AddPasskeyTwoFactorSession(userID int64, sessionID string, expirationTime int64) error {
|
|
_, err := r.DB.Exec(`INSERT INTO passkey_login_sessions(user_id, session_id, creation_time, expiration_time) VALUES($1, $2, $3, $4)`,
|
|
userID, sessionID, ente_time.Microseconds(), expirationTime)
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
func (r *Repository) GetUserIDWithPasskeyTwoFactorSession(sessionID string) (userID int64, err error) {
|
|
err = r.DB.QueryRow(`SELECT user_id FROM passkey_login_sessions WHERE session_id = $1`, sessionID).Scan(&userID)
|
|
return
|
|
}
|
|
|
|
func (r *Repository) CreateBeginAuthenticationData(user *ente.User) (options *protocol.CredentialAssertion, session *webauthn.SessionData, id uuid.UUID, err error) {
|
|
passkeyUser := &PasskeyUser{
|
|
User: user,
|
|
repo: r,
|
|
}
|
|
|
|
options, session, err = r.webAuthnInstance.BeginLogin(passkeyUser)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
// save session data
|
|
marshalledSessionData, err := r.marshalSessionDataToWebAuthnSession(session)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
id = uuid.New()
|
|
|
|
err = r.saveSessionData(id, marshalledSessionData)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Repository) FinishRegistration(user *ente.User, friendlyName string, req *http.Request, sessionID uuid.UUID) (err error) {
|
|
passkeyUser := &PasskeyUser{
|
|
User: user,
|
|
repo: r,
|
|
}
|
|
|
|
session, err := r.getWebAuthnSessionByID(sessionID)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
if session.UserID != user.ID {
|
|
err = stacktrace.NewError("session does not belong to user")
|
|
return
|
|
}
|
|
|
|
sessionData, err := session.SessionData()
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
if time.Now().After(sessionData.Expires) {
|
|
err = stacktrace.NewError("session expired")
|
|
return
|
|
}
|
|
|
|
credential, err := r.webAuthnInstance.FinishRegistration(passkeyUser, *sessionData, req)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
newPasskey, err := r.createPasskey(user.ID, friendlyName)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
passkeyCredential, err := r.marshalCredentialToPasskeyCredential(credential, newPasskey.ID)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
err = r.createPasskeyCredential(passkeyCredential)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Repository) FinishAuthentication(user *ente.User, req *http.Request, sessionID uuid.UUID) (err error) {
|
|
passkeyUser := &PasskeyUser{
|
|
User: user,
|
|
repo: r,
|
|
}
|
|
|
|
session, err := r.getWebAuthnSessionByID(sessionID)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
if session.UserID != user.ID {
|
|
err = stacktrace.NewError("session does not belong to user")
|
|
return
|
|
}
|
|
|
|
sessionData, err := session.SessionData()
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
if time.Now().After(sessionData.Expires) {
|
|
err = stacktrace.NewError("session expired")
|
|
return
|
|
}
|
|
|
|
_, err = r.webAuthnInstance.FinishLogin(passkeyUser, *sessionData, req)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Repository) DeletePasskey(user *ente.User, passkeyID uuid.UUID) (err error) {
|
|
_, err = r.DB.Exec(`
|
|
UPDATE passkeys
|
|
SET friendly_name = $1,
|
|
deleted_at = $2
|
|
WHERE id = $3 AND user_id = $4 AND deleted_at IS NULL
|
|
`, passkeyID, ente_time.Microseconds(), passkeyID, user.ID)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Repository) RenamePasskey(user *ente.User, passkeyID uuid.UUID, newName string) (err error) {
|
|
_, err = r.DB.Exec(`
|
|
UPDATE passkeys
|
|
SET friendly_name = $1
|
|
WHERE id = $2 AND user_id = $3 AND deleted_at IS NULL
|
|
`, newName, passkeyID, user.ID)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Repository) saveSessionData(id uuid.UUID, session *ente.WebAuthnSession) (err error) {
|
|
_, err = r.DB.Exec(`
|
|
INSERT INTO webauthn_sessions (
|
|
id,
|
|
challenge,
|
|
user_id,
|
|
allowed_credential_ids,
|
|
expires_at,
|
|
user_verification_requirement,
|
|
extensions,
|
|
created_at
|
|
) VALUES (
|
|
$1,
|
|
$2,
|
|
$3,
|
|
$4,
|
|
$5,
|
|
$6,
|
|
$7,
|
|
$8
|
|
)
|
|
`,
|
|
id,
|
|
session.Challenge,
|
|
session.UserID,
|
|
session.AllowedCredentialIDs,
|
|
session.ExpiresAt,
|
|
session.UserVerificationRequirement,
|
|
session.Extensions,
|
|
session.CreatedAt,
|
|
)
|
|
return
|
|
}
|
|
|
|
func (r *Repository) marshalCredentialToPasskeyCredential(cred *webauthn.Credential, passkeyID uuid.UUID) (*ente.PasskeyCredential, error) {
|
|
// Convert the PublicKey to base64
|
|
publicKeyB64 := base64.StdEncoding.EncodeToString(cred.PublicKey)
|
|
|
|
// Convert the Transports slice to a comma-separated string
|
|
var transports []string
|
|
for _, t := range cred.Transport {
|
|
transports = append(transports, string(t))
|
|
}
|
|
authenticatorTransports := strings.Join(transports, ",")
|
|
|
|
// Marshal the Flags to JSON
|
|
credentialFlags, err := json.Marshal(cred.Flags)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Marshal the Authenticator to JSON and encode AAGUID to base64
|
|
authenticatorMap := map[string]interface{}{
|
|
"AAGUID": base64.StdEncoding.EncodeToString(cred.Authenticator.AAGUID),
|
|
"SignCount": cred.Authenticator.SignCount,
|
|
"CloneWarning": cred.Authenticator.CloneWarning,
|
|
"Attachment": cred.Authenticator.Attachment,
|
|
}
|
|
authenticatorJSON, err := json.Marshal(authenticatorMap)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// convert cred.ID into base64
|
|
credID := base64.StdEncoding.EncodeToString(cred.ID)
|
|
|
|
// Create the PasskeyCredential
|
|
passkeyCred := &ente.PasskeyCredential{
|
|
CredentialID: credID,
|
|
PasskeyID: passkeyID,
|
|
PublicKey: publicKeyB64,
|
|
AttestationType: cred.AttestationType,
|
|
AuthenticatorTransports: authenticatorTransports,
|
|
CredentialFlags: string(credentialFlags),
|
|
Authenticator: string(authenticatorJSON),
|
|
CreatedAt: time.Now().UnixMicro(),
|
|
}
|
|
|
|
return passkeyCred, nil
|
|
}
|
|
|
|
func (r *Repository) marshalSessionDataToWebAuthnSession(session *webauthn.SessionData) (webAuthnSession *ente.WebAuthnSession, err error) {
|
|
|
|
userID, err := byteMarshaller.ConvertBytesToInt64(session.UserID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
extensionsJson, err := json.Marshal(session.Extensions)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
newWebAuthnSession := &ente.WebAuthnSession{
|
|
Challenge: session.Challenge,
|
|
UserID: userID,
|
|
AllowedCredentialIDs: byteMarshaller.EncodeSlices(session.AllowedCredentialIDs),
|
|
ExpiresAt: session.Expires.UnixMicro(),
|
|
UserVerificationRequirement: string(session.UserVerification),
|
|
Extensions: string(extensionsJson),
|
|
CreatedAt: time.Now().UnixMicro(),
|
|
}
|
|
|
|
return newWebAuthnSession, nil
|
|
}
|
|
|
|
func (r *Repository) GetUserPasskeyCredentials(userID int64) (credentials []webauthn.Credential, err error) {
|
|
rows, err := r.DB.Query(`
|
|
SELECT pc.*
|
|
FROM passkey_credentials pc
|
|
JOIN passkeys p ON pc.passkey_id = p.id
|
|
WHERE p.user_id = $1 AND p.deleted_at IS NULL
|
|
`, userID)
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var pc ente.PasskeyCredential
|
|
if err = rows.Scan(
|
|
&pc.PasskeyID,
|
|
&pc.CredentialID,
|
|
&pc.PublicKey,
|
|
&pc.AttestationType,
|
|
&pc.AuthenticatorTransports,
|
|
&pc.CredentialFlags,
|
|
&pc.Authenticator,
|
|
&pc.CreatedAt,
|
|
); err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
var cred *webauthn.Credential
|
|
cred, err = pc.WebAuthnCredential()
|
|
if err != nil {
|
|
err = stacktrace.Propagate(err, "")
|
|
return
|
|
}
|
|
|
|
credentials = append(credentials, *cred)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (repo *Repository) RemoveExpiredPasskeySessions() error {
|
|
_, err := repo.DB.Exec(`DELETE FROM webauthn_sessions WHERE expires_at <= $1`,
|
|
ente_time.Microseconds())
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
_, err = repo.DB.Exec(`DELETE FROM passkey_login_sessions WHERE expiration_time <= $1`,
|
|
ente_time.Microseconds())
|
|
|
|
return stacktrace.Propagate(err, "")
|
|
}
|