diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go index 1288512b3..d61134ecb 100644 --- a/server/pkg/repo/embedding/repository.go +++ b/server/pkg/repo/embedding/repository.go @@ -129,7 +129,14 @@ func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc stri // AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user. It also updates the size of the embedding func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error { - res, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET size = $1, datacenters = array_append(COALESCE(datacenters, ARRAY[]::s3region[]), $2::s3region) WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID) + res, err := r.DB.ExecContext(ctx, ` + UPDATE embeddings + SET size = $1, + datacenters = CASE + WHEN $2::s3region = ANY(datacenters) THEN datacenters + ELSE array_append(datacenters, $2::s3region) + END + WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID) if err != nil { return stacktrace.Propagate(err, "") } @@ -138,7 +145,7 @@ func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Mode return stacktrace.Propagate(err, "") } if rowsAffected == 0 { - return stacktrace.Propagate(errors.New("no row got updated"), "") + return stacktrace.Propagate(errors.New("no row got updated"), "") } return nil }