diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go
index 84c34189d..8ccb43cc0 100644
--- a/server/cmd/museum/main.go
+++ b/server/cmd/museum/main.go
@@ -678,7 +678,7 @@ func main() {
 	pushHandler := &api.PushHandler{PushController: pushController}
 	privateAPI.POST("/push/token", pushHandler.AddToken)
 
-	embeddingController := &embeddingCtrl.Controller{Repo: embeddingRepo, AccessCtrl: accessCtrl, ObjectCleanupController: objectCleanupController, S3Config: s3Config, FileRepo: fileRepo, CollectionRepo: collectionRepo, QueueRepo: queueRepo, TaskLockingRepo: taskLockingRepo, HostName: hostName}
+	embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName)
 	embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController}
 
 	privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
diff --git a/server/configurations/local.yaml b/server/configurations/local.yaml
index 196c56f1f..87502c271 100644
--- a/server/configurations/local.yaml
+++ b/server/configurations/local.yaml
@@ -125,6 +125,16 @@ s3:
         endpoint:
         region:
         bucket:
+    wasabi-eu-central-2-derived:
+        key:
+        secret:
+        endpoint:
+        region:
+        bucket:
+    # The derived storage bucket is used for storing derived data like embeddings, previews, etc.
+    # By default, it is the same as the hot storage bucket.
+    # derived-storage: wasabi-eu-central-2-derived
+
     # If true, enable some workarounds to allow us to use a local minio instance
     # for object storage.
     #
diff --git a/server/migrations/86_add_dc_embedding.down.sql b/server/migrations/86_add_dc_embedding.down.sql
new file mode 100644
index 000000000..b705b29b6
--- /dev/null
+++ b/server/migrations/86_add_dc_embedding.down.sql
@@ -0,0 +1,18 @@
+-- Revert: drop the datacenters column added for the derived data and recreate the updated_at trigger
+ALTER TABLE embeddings DROP COLUMN IF EXISTS datacenters;
+
+DO
+$$
+    BEGIN
+        IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_embeddings_updated_at') THEN
+            CREATE TRIGGER update_embeddings_updated_at
+                BEFORE UPDATE
+                ON embeddings
+                FOR EACH ROW
+            EXECUTE PROCEDURE
+                trigger_updated_at_microseconds_column();
+        ELSE
+            RAISE NOTICE 'Trigger update_embeddings_updated_at already exists.';
+        END IF;
+    END
+$$;
\ No newline at end of file
diff --git a/server/migrations/86_add_dc_embedding.up.sql b/server/migrations/86_add_dc_embedding.up.sql
new file mode 100644
index 000000000..9d8e28ba7
--- /dev/null
+++ b/server/migrations/86_add_dc_embedding.up.sql
@@ -0,0 +1,4 @@
+-- Add types for the new dcs that are introduced for the derived data
+ALTER TYPE s3region ADD VALUE 'wasabi-eu-central-2-derived';
+DROP TRIGGER IF EXISTS update_embeddings_updated_at ON embeddings;
+ALTER TABLE embeddings ADD COLUMN IF NOT EXISTS datacenters s3region[] default '{b2-eu-cen}';
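Two details of the up migration are worth noting. Existing rows pick up datacenters = '{b2-eu-cen}' through the column default, matching where all embeddings were stored before this change (the hot Backblaze bucket). The update_embeddings_updated_at trigger is also dropped (and recreated by the down migration), presumably so that datacenter bookkeeping on a row does not bump updated_at and cause clients to re-fetch unchanged embeddings through the diff API.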
"github.com/sirupsen/logrus" @@ -31,20 +32,54 @@ import ( const ( // maxEmbeddingDataSize is the min size of an embedding object in bytes minEmbeddingDataSize = 2048 - embeddingFetchTimeout = 15 * gTime.Second + embeddingFetchTimeout = 10 * gTime.Second ) +// _fetchConfig is the configuration for the fetching objects from S3 +type _fetchConfig struct { + RetryCount int + InitialTimeout gTime.Duration + MaxTimeout gTime.Duration +} + +var _defaultFetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 10 * gTime.Second, MaxTimeout: 30 * gTime.Second} +var _b2FetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 15 * gTime.Second, MaxTimeout: 30 * gTime.Second} + type Controller struct { - Repo *embedding.Repository - AccessCtrl access.Controller - ObjectCleanupController *controller.ObjectCleanupController - S3Config *s3config.S3Config - QueueRepo *repo.QueueRepository - TaskLockingRepo *repo.TaskLockRepository - FileRepo *repo.FileRepository - CollectionRepo *repo.CollectionRepository - HostName string - cleanupCronRunning bool + Repo *embedding.Repository + AccessCtrl access.Controller + ObjectCleanupController *controller.ObjectCleanupController + S3Config *s3config.S3Config + QueueRepo *repo.QueueRepository + TaskLockingRepo *repo.TaskLockRepository + FileRepo *repo.FileRepository + CollectionRepo *repo.CollectionRepository + HostName string + cleanupCronRunning bool + derivedStorageDataCenter string + downloadManagerCache map[string]*s3manager.Downloader +} + +func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller { + embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetWasabiDerivedDC(), s3Config.GetDerivedStorageDataCenter()} + cache := make(map[string]*s3manager.Downloader, len(embeddingDcs)) + for i := range embeddingDcs { + s3Client := s3Config.GetS3Client(embeddingDcs[i]) + cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client) + } + return &Controller{ + Repo: repo, + AccessCtrl: accessCtrl, + ObjectCleanupController: objectCleanupController, + S3Config: s3Config, + QueueRepo: queueRepo, + TaskLockingRepo: taskLockingRepo, + FileRepo: fileRepo, + CollectionRepo: collectionRepo, + HostName: hostName, + derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(), + downloadManagerCache: cache, + } } func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) { @@ -77,12 +112,12 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb DecryptionHeader: req.DecryptionHeader, Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"), } - size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model)) + size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter) if uploadErr != nil { log.Error(uploadErr) return nil, stacktrace.Propagate(uploadErr, "") } - embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version) + embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter) embedding.Version = &version if err != nil { return nil, stacktrace.Propagate(err, "") @@ -113,7 +148,7 @@ func (c 
@@ -77,12 +112,12 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
 		DecryptionHeader: req.DecryptionHeader,
 		Client:           network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
 	}
-	size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
+	size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
 	if uploadErr != nil {
 		log.Error(uploadErr)
 		return nil, stacktrace.Propagate(uploadErr, "")
 	}
-	embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version)
+	embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
 	embedding.Version = &version
 	if err != nil {
 		return nil, stacktrace.Propagate(err, "")
@@ -113,7 +148,7 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest)
 
 	// Fetch missing embeddings in parallel
 	if len(objectKeys) > 0 {
-		embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
+		embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys, c.derivedStorageDataCenter)
 		if err != nil {
 			return nil, stacktrace.Propagate(err, "")
 		}
@@ -146,7 +181,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
 	embeddingsWithData := make([]ente.Embedding, 0)
 	noEmbeddingFileIds := make([]int64, 0)
 	dbFileIds := make([]int64, 0)
-	// fileIDs that were indexed but they don't contain any embedding information
+	// fileIDs that were indexed but don't contain any embedding information
 	for i := range userFileEmbeddings {
 		dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
 		if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
@@ -159,7 +194,7 @@ func (c *Controller) GetFilesEmbedd
 	errFileIds := make([]int64, 0)
 
 	// Fetch missing userFileEmbeddings in parallel
-	embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
+	embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData, c.derivedStorageDataCenter)
 	if err != nil {
 		return nil, stacktrace.Propagate(err, "")
 	}
@@ -189,82 +224,6 @@ func (c *Controller) GetFilesEmbedd
 	}, nil
 }
 
-func (c *Controller) DeleteAll(ctx *gin.Context) error {
-	userID := auth.GetUserID(ctx.Request.Header)
-
-	err := c.Repo.DeleteAll(ctx, userID)
-	if err != nil {
-		return stacktrace.Propagate(err, "")
-	}
-	return nil
-}
-
-// CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store
-func (c *Controller) CleanupDeletedEmbeddings() {
-	log.Info("Cleaning up deleted embeddings")
-	if c.cleanupCronRunning {
-		log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
-		return
-	}
-	c.cleanupCronRunning = true
-	defer func() {
-		c.cleanupCronRunning = false
-	}()
-	items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
-	if err != nil {
-		log.WithError(err).Error("Failed to fetch items from queue")
-		return
-	}
-	for _, i := range items {
-		c.deleteEmbedding(i)
-	}
-}
-
-func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
-	lockName := fmt.Sprintf("Embedding:%s", qItem.Item)
-	lockStatus, err := c.TaskLockingRepo.AcquireLock(lockName, time.MicrosecondsAfterHours(1), c.HostName)
-	ctxLogger := log.WithField("item", qItem.Item).WithField("queue_id", qItem.Id)
-	if err != nil || !lockStatus {
-		ctxLogger.Warn("unable to acquire lock")
-		return
-	}
-	defer func() {
-		err = c.TaskLockingRepo.ReleaseLock(lockName)
-		if err != nil {
-			ctxLogger.Errorf("Error while releasing lock %s", err)
-		}
-	}()
-	ctxLogger.Info("Deleting all embeddings")
-
-	fileID, _ := strconv.ParseInt(qItem.Item, 10, 64)
-	ownerID, err := c.FileRepo.GetOwnerID(fileID)
-	if err != nil {
-		ctxLogger.WithError(err).Error("Failed to fetch ownerID")
-		return
-	}
-	prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
-
-	err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
-	if err != nil {
-		ctxLogger.WithError(err).Error("Failed to delete all objects")
-		return
-	}
-
-	err = c.Repo.Delete(fileID)
-	if err != nil {
-		ctxLogger.WithError(err).Error("Failed to remove from db")
-		return
-	}
-
-	err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item)
-	if err != nil {
-		ctxLogger.WithError(err).Error("Failed to remove item from the queue")
-		return
-	}
-
-	ctxLogger.Info("Successfully deleted all embeddings")
-}
-
 func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
 	return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
 }
@@ -273,12 +232,23 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
 	return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
 }
 
+// Get userId, model and fileID from the object key
+func (c *Controller) getEmbeddingObjectDetails(objectKey string) (userID int64, model string, fileID int64) {
+	split := strings.Split(objectKey, "/")
+	userID, _ = strconv.ParseInt(split[0], 10, 64)
+	fileID, _ = strconv.ParseInt(split[2], 10, 64)
+	model = strings.Split(split[3], ".")[0]
+	return userID, model, fileID
+}
+
 // uploadObject uploads the embedding object to the object store and returns the object size
-func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
+func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
 	embeddingObj, _ := json.Marshal(obj)
-	uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
+	s3Client := c.S3Config.GetS3Client(dc)
+	s3Bucket := c.S3Config.GetBucket(dc)
+	uploader := s3manager.NewUploaderWithClient(&s3Client)
 	up := s3manager.UploadInput{
-		Bucket: c.S3Config.GetHotBucket(),
+		Bucket: s3Bucket,
 		Key:    &key,
 		Body:   bytes.NewReader(embeddingObj),
 	}
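Object keys follow the "&lt;userID&gt;/ml-data/&lt;fileID&gt;/&lt;model&gt;.json" layout produced by getObjectKey and getEmbeddingObjectPrefix above; for a hypothetical key "42/ml-data/1234/onnx-clip.json", getEmbeddingObjectDetails returns userID 42, model "onnx-clip", and fileID 1234.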
@@ -296,12 +266,10 @@ var globalDiffFetchSemaphore = make(chan struct{}, 300)
 
 var globalFileFetchSemaphore = make(chan struct{}, 400)
 
-func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
+func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string, dc string) ([]ente.EmbeddingObject, error) {
 	var wg sync.WaitGroup
 	var errs []error
 	embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
-	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
-
 	for i, objectKey := range objectKeys {
 		wg.Add(1)
 		globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
@@ -309,7 +277,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
 			defer wg.Done()
 			defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
 
-			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
+			obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
 			if err != nil {
 				errs = append(errs, err)
 				log.Error("error fetching embedding object: "+objectKey, err)
@@ -334,10 +302,9 @@ type embeddingObjectResult struct {
 	err error
 }
 
-func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
+func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding, dc string) ([]embeddingObjectResult, error) {
 	var wg sync.WaitGroup
 	embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
-	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
 
 	for i, dbEmbeddingRow := range dbEmbeddingRows {
 		wg.Add(1)
@@ -346,9 +313,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 			defer wg.Done()
 			defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
 			objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
-			ctx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout)
-			defer cancel()
-			obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
+			obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
 			if err != nil {
 				log.Error("error fetching embedding object: "+objectKey, err)
 				embeddingObjects[i] = embeddingObjectResult{
@@ -368,32 +333,126 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 	return embeddingObjects, nil
 }
 
-func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
-	return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
+func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
+	opt := _defaultFetchConfig
+	if dc == c.S3Config.GetHotBackblazeDC() {
+		opt = _b2FetchConfig
+	}
+	ctxLogger := log.WithField("objectKey", objectKey).WithField("dc", dc)
+	totalAttempts := opt.RetryCount + 1
+	timeout := opt.InitialTimeout
+	for i := 0; i < totalAttempts; i++ {
+		if i > 0 {
+			timeout = timeout * 2
+			if timeout > opt.MaxTimeout {
+				timeout = opt.MaxTimeout
+			}
+		}
+		fetchCtx, cancel := context.WithTimeout(ctx, timeout)
+		select {
+		case <-ctx.Done():
+			cancel()
+			return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
+		default:
+			obj, err := c.downloadObject(fetchCtx, objectKey, dc)
+			cancel() // Ensure cancel is called to release resources
+			if err == nil {
+				if i > 0 {
+					ctxLogger.Infof("Fetched object after %d attempts", i)
+				}
+				return obj, nil
+			}
+			// Check if the error is due to context timeout or cancellation
+			if fetchCtx.Err() != nil {
+				ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
+			} else {
+				// check if the error is due to object not found
+				if s3Err, ok := err.(awserr.RequestFailure); ok {
+					if s3Err.Code() == s3.ErrCodeNoSuchKey {
+						var srcDc, destDc string
+						destDc = c.S3Config.GetDerivedStorageDataCenter()
+						// todo:(neeraj) Refactor this later to fetch the available DCs once for the whole
+						// request instead of querying the DB for each object. This will help in case of
+						// multiple DCs.
+						// For the initial migration, we know that the original DC was b2, so if the
+						// embedding is not found in the new derived DC, we can try to fetch it from the B2 DC.
+						if c.derivedStorageDataCenter != c.S3Config.GetHotBackblazeDC() {
+							// embeddings should ideally be in the default hot bucket (b2)
+							srcDc = c.S3Config.GetHotBackblazeDC()
+						} else {
+							_, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
+							activeDcs, err := c.Repo.GetOtherDCsForFileAndModel(context.Background(), fileID, modelName, c.derivedStorageDataCenter)
+							if err != nil {
+								return ente.EmbeddingObject{}, stacktrace.Propagate(err, "failed to get other dc")
+							}
+							if len(activeDcs) > 0 {
+								srcDc = activeDcs[0]
+							} else {
+								ctxLogger.Error("Object not found in any dc ", s3Err)
+								return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
+							}
+						}
+						copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey, srcDc, destDc)
+						if err == nil {
+							ctxLogger.Infof("Got object from dc %s", srcDc)
+							return *copyEmbeddingObject, nil
+						} else {
+							ctxLogger.WithError(err).Errorf("Failed to get object from fallback dc %s", srcDc)
+						}
+						return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
+					}
+				}
+				ctxLogger.Error("Failed to fetch object: ", err)
+			}
+		}
+	}
+	return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
 }
 
-func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
+func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
 	var obj ente.EmbeddingObject
 	buff := &aws.WriteAtBuffer{}
+	bucket := c.S3Config.GetBucket(dc)
+	downloader := c.downloadManagerCache[dc]
 	_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
-		Bucket: c.S3Config.GetHotBucket(),
+		Bucket: bucket,
 		Key:    &objectKey,
 	})
 	if err != nil {
-		log.Error(err)
-		if retryCount > 0 {
-			return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
-		}
-		return obj, stacktrace.Propagate(err, "")
+		return obj, err
 	}
 	err = json.Unmarshal(buff.Bytes(), &obj)
 	if err != nil {
-		log.Error(err)
-		return obj, stacktrace.Propagate(err, "")
+		return obj, stacktrace.Propagate(err, "unmarshal failed")
 	}
 	return obj, nil
 }
 
+// copyEmbeddingObject downloads the embedding object from srcDC and uploads it to the derived storage bucket (destDC)
+func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string, srcDC, destDC string) (*ente.EmbeddingObject, error) {
+	if srcDC == destDC {
+		return nil, stacktrace.Propagate(errors.New("src and dest dc can not be same"), "")
+	}
+	obj, err := c.downloadObject(ctx, objectKey, srcDC)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, fmt.Sprintf("failed to download object from %s", srcDC))
+	}
+	go func() {
+		userID, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
+		size, uploadErr := c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
+		if uploadErr != nil {
+			log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", uploadErr)
+			return
+		}
+		updateDcErr := c.Repo.AddNewDC(context.Background(), fileID, ente.Model(modelName), userID, size, destDC)
+		if updateDcErr != nil {
+			log.WithField("object", objectKey).Error("Failed to update dc in db: ", updateDcErr)
+			return
+		}
+	}()
+	return &obj, nil
+}
+
 func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
 	if req.Model == "" {
 		return ente.NewBadRequestWithMessage("model is required")
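Note the ordering in copyEmbeddingObject: the object downloaded from the fallback DC is returned to the caller immediately, while the re-upload to the derived storage bucket happens in a background goroutine, and AddNewDC records the new location only after that upload has succeeded. A failed copy therefore only costs a log line, and the next read simply takes the fallback path again.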
diff --git a/server/pkg/controller/embedding/delete.go b/server/pkg/controller/embedding/delete.go
new file mode 100644
index 000000000..dd2027e42
--- /dev/null
+++ b/server/pkg/controller/embedding/delete.go
@@ -0,0 +1,126 @@
+package embedding
+
+import (
+	"context"
+	"fmt"
+	"github.com/ente-io/museum/pkg/repo"
+	"github.com/ente-io/museum/pkg/utils/auth"
+	"github.com/ente-io/museum/pkg/utils/time"
+	"github.com/ente-io/stacktrace"
+	"github.com/gin-gonic/gin"
+	log "github.com/sirupsen/logrus"
+	"strconv"
+)
+
+func (c *Controller) DeleteAll(ctx *gin.Context) error {
+	userID := auth.GetUserID(ctx.Request.Header)
+
+	err := c.Repo.DeleteAll(ctx, userID)
+	if err != nil {
+		return stacktrace.Propagate(err, "")
+	}
+	return nil
+}
+
+// CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store
+func (c *Controller) CleanupDeletedEmbeddings() {
+	log.Info("Cleaning up deleted embeddings")
+	if c.cleanupCronRunning {
+		log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
+		return
+	}
+	c.cleanupCronRunning = true
+	defer func() {
+		c.cleanupCronRunning = false
+	}()
+	items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
+	if err != nil {
+		log.WithError(err).Error("Failed to fetch items from queue")
+		return
+	}
+	for _, i := range items {
+		c.deleteEmbedding(i)
+	}
+}
+
+func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
+	lockName := fmt.Sprintf("Embedding:%s", qItem.Item)
+	lockStatus, err := c.TaskLockingRepo.AcquireLock(lockName, time.MicrosecondsAfterHours(1), c.HostName)
+	ctxLogger := log.WithField("item", qItem.Item).WithField("queue_id", qItem.Id)
+	if err != nil || !lockStatus {
+		ctxLogger.Warn("unable to acquire lock")
+		return
+	}
+	defer func() {
+		err = c.TaskLockingRepo.ReleaseLock(lockName)
+		if err != nil {
+			ctxLogger.Errorf("Error while releasing lock %s", err)
+		}
+	}()
+	ctxLogger.Info("Deleting all embeddings")
+
+	fileID, _ := strconv.ParseInt(qItem.Item, 10, 64)
+	ownerID, err := c.FileRepo.GetOwnerID(fileID)
+	if err != nil {
+		ctxLogger.WithError(err).Error("Failed to fetch ownerID")
+		return
+	}
+	prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
+	datacenters, err := c.Repo.GetDatacenters(context.Background(), fileID)
+	if err != nil {
+		ctxLogger.WithError(err).Error("Failed to fetch datacenters")
+		return
+	}
+	// Ensure that the objects are deleted from the active derived storage DC. Ideally, this section should never be executed
+	// unless there's a bug in storing the DC or the service restarts before removing the rows from the table
+	// todo:(neeraj): remove this section after a few weeks of deployment
+	if len(datacenters) == 0 {
+		ctxLogger.Warn("No datacenters found for file, ensuring deletion from derived storage and hot DC")
+		err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetDerivedStorageDataCenter())
+		if err != nil {
+			ctxLogger.WithError(err).Error("Failed to delete all objects")
+			return
+		}
+		// if the derived DC is different from the hot DC, delete from the hot DC as well
+		if c.derivedStorageDataCenter != c.S3Config.GetHotDataCenter() {
+			err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
+			if err != nil {
+				ctxLogger.WithError(err).Error("Failed to delete all objects from hot DC")
+				return
+			}
+		}
+	} else {
+		ctxLogger.Infof("Deleting from all datacenters %v", datacenters)
+	}
+
+	for i := range datacenters {
+		err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, datacenters[i])
+		if err != nil {
+			ctxLogger.WithError(err).Errorf("Failed to delete all objects from %s", datacenters[i])
+			return
+		} else {
+			removeErr := c.Repo.RemoveDatacenter(context.Background(), fileID, datacenters[i])
+			if removeErr != nil {
+				ctxLogger.WithError(removeErr).Error("Failed to remove datacenter from db")
+				return
+			}
+		}
+	}
+
+	noDcs, noDcErr := c.Repo.GetDatacenters(context.Background(), fileID)
+	if len(noDcs) > 0 || noDcErr != nil {
+		ctxLogger.Errorf("Failed to delete from all datacenters %s", noDcs)
+		return
+	}
+	err = c.Repo.Delete(fileID)
+	if err != nil {
+		ctxLogger.WithError(err).Error("Failed to remove from db")
+		return
+	}
+	err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item)
+	if err != nil {
+		ctxLogger.WithError(err).Error("Failed to remove item from the queue")
+		return
+	}
+	ctxLogger.Info("Successfully deleted all embeddings")
+}
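With the datacenters column in place, deleteEmbedding walks the per-file list of DCs instead of assuming the hot DC: it takes a one-hour task lock, deletes the object prefix from each recorded DC (removing the DC from the row as each one is cleaned up), re-checks that no DCs remain, and only then deletes the row and dequeues the item. The len(datacenters) == 0 branch is a temporary safety net for rows written before the column existed, sweeping both the derived storage DC and the hot DC.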
diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go
index 86915fde5..5cfbd35c5 100644
--- a/server/pkg/repo/embedding/repository.go
+++ b/server/pkg/repo/embedding/repository.go
@@ -3,11 +3,11 @@ package embedding
 import (
 	"context"
 	"database/sql"
+	"errors"
 	"fmt"
-	"github.com/lib/pq"
-
 	"github.com/ente-io/museum/ente"
 	"github.com/ente-io/stacktrace"
+	"github.com/lib/pq"
 	"github.com/sirupsen/logrus"
 )
 
@@ -18,15 +18,26 @@ type Repository struct {
 }
 
 // Create inserts a new embedding
-
-func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int) (ente.Embedding, error) {
+func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int, dc string) (ente.Embedding, error) {
 	var updatedAt int64
-	err := r.DB.QueryRowContext(ctx, `INSERT INTO embeddings
-		(file_id, owner_id, model, size, version)
-		VALUES ($1, $2, $3, $4, $5)
-		ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model
-		DO UPDATE SET updated_at = now_utc_micro_seconds(), size = $4, version = $5
-		RETURNING updated_at`, entry.FileID, ownerID, entry.Model, size, version).Scan(&updatedAt)
+	err := r.DB.QueryRowContext(ctx, `
+		INSERT INTO embeddings
+			(file_id, owner_id, model, size, version, datacenters)
+		VALUES
+			($1, $2, $3, $4, $5, ARRAY[$6]::s3region[])
+		ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model
+		DO UPDATE
+			SET
+				updated_at = now_utc_micro_seconds(),
+				size = $4,
+				version = $5,
+				datacenters = CASE
+					WHEN $6 = ANY(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[])) THEN embeddings.datacenters
+					ELSE array_append(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[]), $6::s3region)
+				END
+		RETURNING updated_at`,
+		entry.FileID, ownerID, entry.Model, size, version, dc).Scan(&updatedAt)
+
 	if err != nil {
 		// check if error is due to model enum invalid value
 		if err.Error() == fmt.Sprintf("pq: invalid input value for enum model: \"%s\"", entry.Model) {
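The datacenters arm of the upsert keeps the array unchanged when the DC is already recorded and appends it otherwise, with COALESCE defaulting rows that somehow have a NULL array to '{b2-eu-cen}'. The same semantics in plain Go, as a standalone illustration (hypothetical helper, not part of the patch):

    package main

    import "fmt"

    // appendIfMissing mirrors the SQL CASE arm: keep the list as-is when dc is
    // already present, otherwise append it.
    func appendIfMissing(datacenters []string, dc string) []string {
        for _, d := range datacenters {
            if d == dc {
                return datacenters
            }
        }
        return append(datacenters, dc)
    }

    func main() {
        dcs := []string{"b2-eu-cen"}
        dcs = appendIfMissing(dcs, "wasabi-eu-central-2-derived")
        fmt.Println(dcs) // [b2-eu-cen wasabi-eu-central-2-derived]
        dcs = appendIfMissing(dcs, "wasabi-eu-central-2-derived")
        fmt.Println(dcs) // unchanged: [b2-eu-cen wasabi-eu-central-2-derived]
    }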
@@ -82,6 +93,91 @@ func (r *Repository) Delete(fileID int64) error {
 	return nil
 }
 
+// GetDatacenters returns the unique list of datacenters where derived embeddings are stored
+func (r *Repository) GetDatacenters(ctx context.Context, fileID int64) ([]string, error) {
+	rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1`, fileID)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+	defer rows.Close()
+	uniqueDatacenters := make(map[string]struct{})
+	for rows.Next() {
+		var datacenters []string
+		err = rows.Scan(pq.Array(&datacenters))
+		if err != nil {
+			return nil, stacktrace.Propagate(err, "")
+		}
+		for _, dc := range datacenters {
+			uniqueDatacenters[dc] = struct{}{}
+		}
+	}
+	datacenters := make([]string, 0, len(uniqueDatacenters))
+	for dc := range uniqueDatacenters {
+		datacenters = append(datacenters, dc)
+	}
+	return datacenters, nil
+}
+
+// GetOtherDCsForFileAndModel returns the list of datacenters where the embeddings are stored for a given file and model, excluding the ignoredDC
+func (r *Repository) GetOtherDCsForFileAndModel(ctx context.Context, fileID int64, model string, ignoredDC string) ([]string, error) {
+	rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1 AND model = $2`, fileID, model)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+	defer rows.Close()
+	uniqueDatacenters := make(map[string]bool)
+	for rows.Next() {
+		var datacenters []string
+		err = rows.Scan(pq.Array(&datacenters))
+		if err != nil {
+			return nil, stacktrace.Propagate(err, "")
+		}
+		for _, dc := range datacenters {
+			// add to uniqueDatacenters if it is not the ignoredDC
+			if dc != ignoredDC {
+				uniqueDatacenters[dc] = true
+			}
+		}
+	}
+	datacenters := make([]string, 0, len(uniqueDatacenters))
+	for dc := range uniqueDatacenters {
+		datacenters = append(datacenters, dc)
+	}
+	return datacenters, nil
+}
+
+// RemoveDatacenter removes the given datacenter from the list of datacenters
+func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc string) error {
+	_, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET datacenters = array_remove(datacenters, $1) WHERE file_id = $2`, dc, fileID)
+	if err != nil {
+		return stacktrace.Propagate(err, "")
+	}
+	return nil
+}
+
+// AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user. It also updates the size of the embedding
+func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error {
+	res, err := r.DB.ExecContext(ctx, `
+		UPDATE embeddings
+		SET size = $1,
+		    datacenters = CASE
+		        WHEN $2::s3region = ANY(datacenters) THEN datacenters
+		        ELSE array_append(datacenters, $2::s3region)
+		    END
+		WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID)
+	if err != nil {
+		return stacktrace.Propagate(err, "")
+	}
+	rowsAffected, err := res.RowsAffected()
+	if err != nil {
+		return stacktrace.Propagate(err, "")
+	}
+	if rowsAffected == 0 {
+		return stacktrace.Propagate(errors.New("no row got updated"), "")
+	}
+	return nil
+}
+
 func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
 	defer func() {
 		if err := rows.Close(); err != nil {
diff --git a/server/pkg/utils/s3config/s3config.go b/server/pkg/utils/s3config/s3config.go
index 9b273bd61..a562e5181 100644
--- a/server/pkg/utils/s3config/s3config.go
+++ b/server/pkg/utils/s3config/s3config.go
@@ -28,6 +28,8 @@ type S3Config struct {
 	hotDC string
 	// Secondary (hot) data center
 	secondaryHotDC string
+	// Data center for derived files like ML embeddings & preview files
+	derivedStorageDC string
 	// A map from data centers to S3 configurations
 	s3Configs map[string]*aws.Config
 	// A map from data centers to pre-created S3 clients
@@ -71,6 +73,7 @@ var (
 	dcWasabiEuropeCentralDeprecated string = "wasabi-eu-central-2"
 	dcWasabiEuropeCentral_v3        string = "wasabi-eu-central-2-v3"
 	dcSCWEuropeFrance_v3            string = "scw-eu-fr-v3"
+	dcWasabiEuropeCentralDerived    string = "wasabi-eu-central-2-derived"
 )
 
 // Number of days that the wasabi bucket is configured to retain objects.
@@ -86,9 +89,9 @@ func NewS3Config() *S3Config {
 }
 
 func (config *S3Config) initialize() {
-	dcs := [5]string{
+	dcs := [6]string{
 		dcB2EuropeCentral, dcSCWEuropeFranceLockedDeprecated, dcWasabiEuropeCentralDeprecated,
-		dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3}
+		dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3, dcWasabiEuropeCentralDerived}
 
 	config.hotDC = dcB2EuropeCentral
 	config.secondaryHotDC = dcWasabiEuropeCentral_v3
@@ -99,6 +102,12 @@ func (config *S3Config) initialize() {
 		config.secondaryHotDC = hs2
 		log.Infof("Hot storage: %s (secondary: %s)", hs1, hs2)
 	}
+	config.derivedStorageDC = config.hotDC
+	embeddingsDC := viper.GetString("s3.derived-storage")
+	if embeddingsDC != "" && array.StringInList(embeddingsDC, dcs[:]) {
+		config.derivedStorageDC = embeddingsDC
+		log.Infof("Embeddings bucket: %s", embeddingsDC)
+	}
 
 	config.buckets = make(map[string]string)
 	config.s3Configs = make(map[string]*aws.Config)
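initialize() resolves the derived storage DC with a conservative fallback: the s3.derived-storage setting is honoured only when it names a known DC, otherwise derived data stays in the hot DC. A standalone sketch of that resolution (illustrative function name, not part of the patch):

    package main

    import "fmt"

    // resolveDerivedDC returns the configured derived-storage DC when it is one
    // of the known DCs, and falls back to the hot DC otherwise.
    func resolveDerivedDC(hotDC, configured string, knownDCs []string) string {
        for _, dc := range knownDCs {
            if configured != "" && configured == dc {
                return configured
            }
        }
        return hotDC
    }

    func main() {
        known := []string{"b2-eu-cen", "wasabi-eu-central-2-v3", "wasabi-eu-central-2-derived"}
        fmt.Println(resolveDerivedDC("b2-eu-cen", "", known))                            // b2-eu-cen
        fmt.Println(resolveDerivedDC("b2-eu-cen", "wasabi-eu-central-2-derived", known)) // wasabi-eu-central-2-derived
    }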
@@ -171,6 +180,18 @@ func (config *S3Config) GetHotS3Client() *s3.S3 {
 	return &s3Client
 }
 
+func (config *S3Config) GetDerivedStorageDataCenter() string {
+	return config.derivedStorageDC
+}
+func (config *S3Config) GetDerivedStorageBucket() *string {
+	return config.GetBucket(config.derivedStorageDC)
+}
+
+func (config *S3Config) GetDerivedStorageS3Client() *s3.S3 {
+	s3Client := config.GetS3Client(config.derivedStorageDC)
+	return &s3Client
+}
+
 // Return the name of the hot Backblaze data center
 func (config *S3Config) GetHotBackblazeDC() string {
 	return dcB2EuropeCentral
@@ -181,6 +202,10 @@ func (config *S3Config) GetHotWasabiDC() string {
 	return dcWasabiEuropeCentral_v3
 }
 
+func (config *S3Config) GetWasabiDerivedDC() string {
+	return dcWasabiEuropeCentralDerived
+}
+
 // Return the name of the cold Scaleway data center
 func (config *S3Config) GetColdScalewayDC() string {
 	return dcSCWEuropeFrance_v3