diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index a30043e7f..349ab9d9d 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -2,12 +2,14 @@ package embedding import ( "bytes" + "context" "encoding/json" "errors" "fmt" "github.com/ente-io/museum/pkg/utils/array" "strconv" "sync" + gTime "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/s3" @@ -306,7 +308,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em defer wg.Done() defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore - obj, err := c.getEmbeddingObject(objectKey, downloader) + obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader) if err != nil { errs = append(errs, err) log.Error("error fetching embedding object: "+objectKey, err) @@ -343,7 +345,9 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows defer wg.Done() defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model) - obj, err := c.getEmbeddingObject(objectKey, downloader) + ctx, cancel := context.WithTimeout(context.Background(), gTime.Second*10) // 10 seconds timeout + defer cancel() + obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0) if err != nil { log.Error("error fetching embedding object: "+objectKey, err) embeddingObjects[i] = embeddingObjectResult{ @@ -363,21 +367,21 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows return embeddingObjects, nil } -func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) { - return c.getEmbeddingObjectWithRetries(objectKey, downloader, 3) +func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) { + return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3) } -func (c *Controller) getEmbeddingObjectWithRetries(objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) { +func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) { var obj ente.EmbeddingObject buff := &aws.WriteAtBuffer{} - _, err := downloader.Download(buff, &s3.GetObjectInput{ + _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: c.S3Config.GetHotBucket(), Key: &objectKey, }) if err != nil { log.Error(err) if retryCount > 0 { - return c.getEmbeddingObjectWithRetries(objectKey, downloader, retryCount-1) + return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1) } return obj, stacktrace.Propagate(err, "") }