diff --git a/server/ente/embedding.go b/server/ente/embedding.go index efde0aae8..2990a779a 100644 --- a/server/ente/embedding.go +++ b/server/ente/embedding.go @@ -1,12 +1,12 @@ package ente type Embedding struct { - FileID int64 `json:"fileID"` - Model string `json:"model"` - EncryptedEmbedding string `json:"encryptedEmbedding"` - DecryptionHeader string `json:"decryptionHeader"` - UpdatedAt int64 `json:"updatedAt"` - Client *string `json:"client,omitempty"` + FileID int64 `json:"fileID"` + Model string `json:"model"` + EncryptedEmbedding string `json:"encryptedEmbedding"` + DecryptionHeader string `json:"decryptionHeader"` + UpdatedAt int64 `json:"updatedAt"` + Version *int `json:"version,omitempty"` } type InsertOrUpdateEmbeddingRequest struct { diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index 905a29ece..7f2f5dd80 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -74,7 +74,8 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb log.Error(uploadErr) return nil, stacktrace.Propagate(uploadErr, "") } - embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size) + embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version) + embedding.Version = &version if err != nil { return nil, stacktrace.Propagate(err, "") } @@ -159,7 +160,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding, DecryptionHeader: obj.embeddingObject.DecryptionHeader, UpdatedAt: obj.dbEmbeddingRow.UpdatedAt, - Client: obj.dbEmbeddingRow.Client, + Version: obj.dbEmbeddingRow.Version, }) } } diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go index f3fcb16cb..80b5ae9cc 100644 --- a/server/pkg/repo/embedding/repository.go +++ b/server/pkg/repo/embedding/repository.go @@ -19,12 +19,8 @@ type Repository struct { // Create inserts a new embedding -func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int) (ente.Embedding, error) { +func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int) (ente.Embedding, error) { var updatedAt int64 - version := 1 - if entry.Version != nil { - version = *entry.Version - } err := r.DB.QueryRowContext(ctx, `INSERT INTO embeddings (file_id, owner_id, model, size, version) VALUES ($1, $2, $3, $4, $5) @@ -49,7 +45,7 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry en // GetDiff returns the embeddings that have been updated since the given time func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) { - rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at + rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3 ORDER BY updated_at ASC @@ -61,7 +57,7 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Mode } func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) { - rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at + rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version FROM embeddings WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs)) if err != nil { @@ -97,13 +93,19 @@ func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) { for rows.Next() { embedding := ente.Embedding{} var encryptedEmbedding, decryptionHeader sql.NullString - err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt) + var version sql.NullInt32 + err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version) if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 { embedding.EncryptedEmbedding = encryptedEmbedding.String } if decryptionHeader.Valid && len(decryptionHeader.String) > 0 { embedding.DecryptionHeader = decryptionHeader.String } + v := 1 + if version.Valid { + v = int(version.Int32) + } + embedding.Version = &v if err != nil { return nil, stacktrace.Propagate(err, "") } diff --git a/server/pkg/utils/array/array.go b/server/pkg/utils/array/array.go index 0c2d25d90..42b6b8fa9 100644 --- a/server/pkg/utils/array/array.go +++ b/server/pkg/utils/array/array.go @@ -49,15 +49,6 @@ func Int64InList(a int64, list []int64) bool { } // FindMissingElementsInSecondList identifies elements in 'sourceList' that are not present in 'targetList'. -// -// This function creates a set from 'targetList' for efficient lookup, then iterates through 'sourceList' -// to identify which elements are missing in 'targetList'. This method is particularly efficient for large -// lists, as it avoids the quadratic complexity of nested iterations by utilizing a hash set for O(1) lookups. -// -// Parameters: -// - sourceList: An array of int64 elements to check against 'targetList'. -// - targetList: An array of int64 elements used as the reference set to identify missing elements from 'sourceList'. -// // Returns: // - A slice of int64 representing the elements found in 'sourceList' but not in 'targetList'. // If all elements of 'sourceList' are present in 'targetList', an empty slice is returned.