Update dc while copying derived file
This commit is contained in:
parent
a522631c2b
commit
b404b77da3
|
@ -5,9 +5,11 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
"github.com/ente-io/museum/pkg/utils/array"
|
"github.com/ente-io/museum/pkg/utils/array"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
gTime "time"
|
gTime "time"
|
||||||
|
|
||||||
|
@ -228,6 +230,15 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
|
||||||
return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
|
return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get userId, model and fileID from the object key
|
||||||
|
func (c *Controller) getEmbeddingObjectDetails(objectKey string) (userID int64, model string, fileID int64) {
|
||||||
|
split := strings.Split(objectKey, "/")
|
||||||
|
userID, _ = strconv.ParseInt(split[0], 10, 64)
|
||||||
|
fileID, _ = strconv.ParseInt(split[2], 10, 64)
|
||||||
|
model = strings.Split(split[3], ".")[0]
|
||||||
|
return userID, model, fileID
|
||||||
|
}
|
||||||
|
|
||||||
// uploadObject uploads the embedding object to the object store and returns the object size
|
// uploadObject uploads the embedding object to the object store and returns the object size
|
||||||
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
|
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
|
||||||
embeddingObj, _ := json.Marshal(obj)
|
embeddingObj, _ := json.Marshal(obj)
|
||||||
|
@ -352,7 +363,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
|
||||||
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
||||||
} else {
|
} else {
|
||||||
// If derived and hot bucket are different, try to copy from hot bucket
|
// If derived and hot bucket are different, try to copy from hot bucket
|
||||||
copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
|
copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC(), c.derivedStorageDataCenter)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
ctxLogger.Info("Got the object from hot bucket object")
|
ctxLogger.Info("Got the object from hot bucket object")
|
||||||
return *copyEmbeddingObject, nil
|
return *copyEmbeddingObject, nil
|
||||||
|
@ -390,21 +401,26 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc st
|
||||||
}
|
}
|
||||||
|
|
||||||
// download the embedding object from hot bucket and upload to embeddings bucket
|
// download the embedding object from hot bucket and upload to embeddings bucket
|
||||||
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
|
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string, srcDC, destDC string) (*ente.EmbeddingObject, error) {
|
||||||
if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
|
if srcDC == destDC {
|
||||||
return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
|
return nil, stacktrace.Propagate(errors.New("src and dest dc can not be same"), "")
|
||||||
}
|
}
|
||||||
obj, err := c.downloadObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC())
|
obj, err := c.downloadObject(ctx, objectKey, srcDC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
|
return nil, stacktrace.Propagate(err, fmt.Sprintf("failed to download object from %s", srcDC))
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
_, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
|
userID, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
|
||||||
if err != nil {
|
size, uploadErr := c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
|
||||||
log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err)
|
if uploadErr != nil {
|
||||||
|
log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", uploadErr)
|
||||||
|
}
|
||||||
|
updateDcErr := c.Repo.AddNewDC(context.Background(), fileID, ente.Model(modelName), userID, size, destDC)
|
||||||
|
if updateDcErr != nil {
|
||||||
|
log.WithField("object", objectKey).Error("Failed to update dc in db: ", updateDcErr)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return &obj, nil
|
return &obj, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package embedding
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
|
||||||
|
@ -126,6 +127,22 @@ func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc stri
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user. It also updates the size of the embedding
|
||||||
|
func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error {
|
||||||
|
res, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET size = $1, datacenters = array_append(COALESCE(datacenters, ARRAY[]::s3region[]), $2::s3region) WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID)
|
||||||
|
if err != nil {
|
||||||
|
return stacktrace.Propagate(err, "")
|
||||||
|
}
|
||||||
|
rowsAffected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return stacktrace.Propagate(err, "")
|
||||||
|
}
|
||||||
|
if rowsAffected == 0 {
|
||||||
|
return stacktrace.Propagate(errors.New("no row got updated"), "")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
|
func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := rows.Close(); err != nil {
|
if err := rows.Close(); err != nil {
|
||||||
|
|
Loading…
Reference in a new issue