[server] Return embedding version in API response
This commit is contained in:
parent
e8756a8cf7
commit
2b3494e61c
|
@ -1,12 +1,12 @@
|
||||||
package ente
|
package ente
|
||||||
|
|
||||||
type Embedding struct {
|
type Embedding struct {
|
||||||
FileID int64 `json:"fileID"`
|
FileID int64 `json:"fileID"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
EncryptedEmbedding string `json:"encryptedEmbedding"`
|
EncryptedEmbedding string `json:"encryptedEmbedding"`
|
||||||
DecryptionHeader string `json:"decryptionHeader"`
|
DecryptionHeader string `json:"decryptionHeader"`
|
||||||
UpdatedAt int64 `json:"updatedAt"`
|
UpdatedAt int64 `json:"updatedAt"`
|
||||||
Client *string `json:"client,omitempty"`
|
Version *int `json:"version,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type InsertOrUpdateEmbeddingRequest struct {
|
type InsertOrUpdateEmbeddingRequest struct {
|
||||||
|
|
|
@ -74,7 +74,8 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
|
||||||
log.Error(uploadErr)
|
log.Error(uploadErr)
|
||||||
return nil, stacktrace.Propagate(uploadErr, "")
|
return nil, stacktrace.Propagate(uploadErr, "")
|
||||||
}
|
}
|
||||||
embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size)
|
embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version)
|
||||||
|
embedding.Version = &version
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, stacktrace.Propagate(err, "")
|
return nil, stacktrace.Propagate(err, "")
|
||||||
}
|
}
|
||||||
|
@ -159,7 +160,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
|
||||||
EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
|
EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
|
||||||
DecryptionHeader: obj.embeddingObject.DecryptionHeader,
|
DecryptionHeader: obj.embeddingObject.DecryptionHeader,
|
||||||
UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
|
UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
|
||||||
Client: obj.dbEmbeddingRow.Client,
|
Version: obj.dbEmbeddingRow.Version,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,12 +19,8 @@ type Repository struct {
|
||||||
|
|
||||||
// Create inserts a new embedding
|
// Create inserts a new embedding
|
||||||
|
|
||||||
func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int) (ente.Embedding, error) {
|
func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int) (ente.Embedding, error) {
|
||||||
var updatedAt int64
|
var updatedAt int64
|
||||||
version := 1
|
|
||||||
if entry.Version != nil {
|
|
||||||
version = *entry.Version
|
|
||||||
}
|
|
||||||
err := r.DB.QueryRowContext(ctx, `INSERT INTO embeddings
|
err := r.DB.QueryRowContext(ctx, `INSERT INTO embeddings
|
||||||
(file_id, owner_id, model, size, version)
|
(file_id, owner_id, model, size, version)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
@ -49,7 +45,7 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry en
|
||||||
|
|
||||||
// GetDiff returns the embeddings that have been updated since the given time
|
// GetDiff returns the embeddings that have been updated since the given time
|
||||||
func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) {
|
func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) {
|
||||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at
|
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version,
|
||||||
FROM embeddings
|
FROM embeddings
|
||||||
WHERE owner_id = $1 AND model = $2 AND updated_at > $3
|
WHERE owner_id = $1 AND model = $2 AND updated_at > $3
|
||||||
ORDER BY updated_at ASC
|
ORDER BY updated_at ASC
|
||||||
|
@ -61,7 +57,7 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Mode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
|
func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
|
||||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at
|
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version
|
||||||
FROM embeddings
|
FROM embeddings
|
||||||
WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs))
|
WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -97,13 +93,19 @@ func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
embedding := ente.Embedding{}
|
embedding := ente.Embedding{}
|
||||||
var encryptedEmbedding, decryptionHeader sql.NullString
|
var encryptedEmbedding, decryptionHeader sql.NullString
|
||||||
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt)
|
var version sql.NullInt32
|
||||||
|
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version)
|
||||||
if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 {
|
if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 {
|
||||||
embedding.EncryptedEmbedding = encryptedEmbedding.String
|
embedding.EncryptedEmbedding = encryptedEmbedding.String
|
||||||
}
|
}
|
||||||
if decryptionHeader.Valid && len(decryptionHeader.String) > 0 {
|
if decryptionHeader.Valid && len(decryptionHeader.String) > 0 {
|
||||||
embedding.DecryptionHeader = decryptionHeader.String
|
embedding.DecryptionHeader = decryptionHeader.String
|
||||||
}
|
}
|
||||||
|
v := 1
|
||||||
|
if version.Valid {
|
||||||
|
v = int(version.Int32)
|
||||||
|
}
|
||||||
|
embedding.Version = &v
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, stacktrace.Propagate(err, "")
|
return nil, stacktrace.Propagate(err, "")
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,15 +49,6 @@ func Int64InList(a int64, list []int64) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindMissingElementsInSecondList identifies elements in 'sourceList' that are not present in 'targetList'.
|
// FindMissingElementsInSecondList identifies elements in 'sourceList' that are not present in 'targetList'.
|
||||||
//
|
|
||||||
// This function creates a set from 'targetList' for efficient lookup, then iterates through 'sourceList'
|
|
||||||
// to identify which elements are missing in 'targetList'. This method is particularly efficient for large
|
|
||||||
// lists, as it avoids the quadratic complexity of nested iterations by utilizing a hash set for O(1) lookups.
|
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// - sourceList: An array of int64 elements to check against 'targetList'.
|
|
||||||
// - targetList: An array of int64 elements used as the reference set to identify missing elements from 'sourceList'.
|
|
||||||
//
|
|
||||||
// Returns:
|
// Returns:
|
||||||
// - A slice of int64 representing the elements found in 'sourceList' but not in 'targetList'.
|
// - A slice of int64 representing the elements found in 'sourceList' but not in 'targetList'.
|
||||||
// If all elements of 'sourceList' are present in 'targetList', an empty slice is returned.
|
// If all elements of 'sourceList' are present in 'targetList', an empty slice is returned.
|
||||||
|
|
Loading…
Reference in a new issue