Refactor computation of collecitonKey

This commit is contained in:
Neeraj Gupta 2023-09-23 09:34:21 +05:30
parent 564696e0ef
commit 42a6217d44
3 changed files with 35 additions and 29 deletions

View file

@ -1,11 +1,5 @@
package api
import (
enteCrypto "cli-go/internal/crypto"
"cli-go/utils"
"log"
)
// Collection represents a collection
type Collection struct {
ID int64 `json:"id"`
@ -25,20 +19,6 @@ type Collection struct {
collectionKey []byte
}
func (c *Collection) GetCollectionKey(masterKey []byte) []byte {
if c.collectionKey == nil || len(c.collectionKey) == 0 {
collKey, err := enteCrypto.SecretBoxOpen(
utils.DecodeBase64(c.EncryptedKey),
utils.DecodeBase64(c.KeyDecryptionNonce),
masterKey)
if err != nil {
log.Fatalf("failed to decrypt collection key %s", err)
}
c.collectionKey = collKey
}
return c.collectionKey
}
// CollectionUser represents the owner of a collection
type CollectionUser struct {
ID int64 `json:"id"`

View file

@ -1,6 +1,9 @@
package pkg
import "cli-go/pkg/model"
import (
"cli-go/pkg/model"
"context"
)
type KeyHolder struct {
AccountSecrets map[string]*accSecretInfo
@ -25,3 +28,8 @@ func (k *KeyHolder) LoadSecrets(account model.Account, cliKey []byte) (*accSecre
}
return k.AccountSecrets[account.AccountKey()], nil
}
func (k *KeyHolder) GetAccountSecretInfo(ctx context.Context) *accSecretInfo {
accountKey := ctx.Value("account_id").(string)
return k.AccountSecrets[accountKey]
}

View file

@ -1,8 +1,10 @@
package pkg
import (
"cli-go/internal/api"
enteCrypto "cli-go/internal/crypto"
"cli-go/pkg/model"
"cli-go/utils"
"context"
"encoding/base64"
"fmt"
@ -10,27 +12,25 @@ import (
"log"
)
var accountMasterKey = map[string][]byte{}
func (c *ClICtrl) SyncAccount(account model.Account) error {
log.SetPrefix(fmt.Sprintf("[%s] ", account.Email))
secretInfo, err := c.KeyHolder.LoadSecrets(account, c.CliKey)
if err != nil {
return err
}
ctx := c.buildRequestContext(context.Background(), account)
err = createDataBuckets(c.DB, account)
if err != nil {
return err
}
accountMasterKey[account.AccountKey()] = secretInfo.MasterKey
c.Client.AddToken(account.AccountKey(), base64.URLEncoding.EncodeToString(secretInfo.Token))
ctx := c.buildRequestContext(context.Background(), account)
return c.syncRemoteCollections(ctx, account)
}
func (c *ClICtrl) buildRequestContext(ctx context.Context, account model.Account) context.Context {
ctx = context.WithValue(ctx, "app", string(account.App))
ctx = context.WithValue(ctx, "account_id", account.AccountKey())
ctx = context.WithValue(ctx, "user_id", account.UserID)
return ctx
}
@ -55,16 +55,17 @@ func createDataBuckets(db *bolt.DB, account model.Account) error {
func (c *ClICtrl) syncRemoteCollections(ctx context.Context, info model.Account) error {
collections, err := c.Client.GetCollections(ctx, 0)
if err != nil {
log.Printf("failed to get collections: %s\n", err)
return err
return fmt.Errorf("failed to get collections: %s", err)
}
masterKey := accountMasterKey[info.AccountKey()]
for _, collection := range collections {
if collection.Owner.ID != info.UserID {
fmt.Printf("Skipping collection %d\n", collection.ID)
continue
}
collectionKey := collection.GetCollectionKey(masterKey)
collectionKey, err := c.getCollectionKey(ctx, collection)
if err != nil {
return err
}
name, nameErr := enteCrypto.SecretBoxOpenBase64(collection.EncryptedName, collection.NameDecryptionNonce, collectionKey)
if nameErr != nil {
log.Fatalf("failed to decrypt collection name: %v", nameErr)
@ -73,3 +74,20 @@ func (c *ClICtrl) syncRemoteCollections(ctx context.Context, info model.Account)
}
return nil
}
func (c *ClICtrl) getCollectionKey(ctx context.Context, collection api.Collection) ([]byte, error) {
accSecretInfo := c.KeyHolder.GetAccountSecretInfo(ctx)
userID := ctx.Value("user_id").(int64)
if collection.Owner.ID == userID {
collKey, err := enteCrypto.SecretBoxOpen(
utils.DecodeBase64(collection.EncryptedKey),
utils.DecodeBase64(collection.KeyDecryptionNonce),
accSecretInfo.MasterKey)
if err != nil {
log.Fatalf("failed to decrypt collection key %s", err)
}
return collKey, nil
} else {
panic("not implemented")
}
}