[server] Add timeout while fetching embedding

This commit is contained in:
Neeraj Gupta 2024-05-10 17:17:55 +05:30
parent 5caa9c5f61
commit 3a70dcd930

View file

@ -2,12 +2,14 @@ package embedding
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/ente-io/museum/pkg/utils/array"
"strconv"
"sync"
gTime "time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
@ -306,7 +308,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
defer wg.Done()
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
obj, err := c.getEmbeddingObject(objectKey, downloader)
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
if err != nil {
errs = append(errs, err)
log.Error("error fetching embedding object: "+objectKey, err)
@ -343,7 +345,9 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
defer wg.Done()
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
obj, err := c.getEmbeddingObject(objectKey, downloader)
ctx, cancel := context.WithTimeout(context.Background(), gTime.Second*10) // 10 seconds timeout
defer cancel()
obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
if err != nil {
log.Error("error fetching embedding object: "+objectKey, err)
embeddingObjects[i] = embeddingObjectResult{
@ -363,21 +367,21 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
return embeddingObjects, nil
}
func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
return c.getEmbeddingObjectWithRetries(objectKey, downloader, 3)
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
}
func (c *Controller) getEmbeddingObjectWithRetries(objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
var obj ente.EmbeddingObject
buff := &aws.WriteAtBuffer{}
_, err := downloader.Download(buff, &s3.GetObjectInput{
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
Bucket: c.S3Config.GetHotBucket(),
Key: &objectKey,
})
if err != nil {
log.Error(err)
if retryCount > 0 {
return c.getEmbeddingObjectWithRetries(objectKey, downloader, retryCount-1)
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
}
return obj, stacktrace.Propagate(err, "")
}