This commit is contained in:
Neeraj Gupta 2024-05-16 13:39:47 +05:30
parent 3485b31475
commit 20e9a6a1fc

View file

@ -33,6 +33,14 @@ const (
embeddingFetchTimeout = 15 * gTime.Second embeddingFetchTimeout = 15 * gTime.Second
) )
// _fetchConfig is the configuration for the fetching objects from S3
type _fetchConfig struct {
RetryCount int
FetchTimeOut gTime.Duration
}
var _defaultFetchConfig = _fetchConfig{RetryCount: 3, FetchTimeOut: 15 * gTime.Second}
type Controller struct { type Controller struct {
Repo *embedding.Repository Repo *embedding.Repository
AccessCtrl access.Controller AccessCtrl access.Controller
@ -251,7 +259,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
defer wg.Done() defer wg.Done()
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil) obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
log.Error("error fetching embedding object: "+objectKey, err) log.Error("error fetching embedding object: "+objectKey, err)
@ -289,7 +297,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
defer wg.Done() defer wg.Done()
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model) objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil) obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
if err != nil { if err != nil {
log.Error("error fetching embedding object: "+objectKey, err) log.Error("error fetching embedding object: "+objectKey, err)
embeddingObjects[i] = embeddingObjectResult{ embeddingObjects[i] = embeddingObjectResult{
@ -309,18 +317,8 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
return embeddingObjects, nil return embeddingObjects, nil
} }
type getOptions struct { func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
RetryCount int opt := _defaultFetchConfig
FetchTimeOut gTime.Duration
}
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
if opt == nil {
opt = &getOptions{
RetryCount: 3,
FetchTimeOut: embeddingFetchTimeout,
}
}
ctxLogger := log.WithField("objectKey", objectKey) ctxLogger := log.WithField("objectKey", objectKey)
totalAttempts := opt.RetryCount + 1 totalAttempts := opt.RetryCount + 1
for i := 0; i < totalAttempts; i++ { for i := 0; i < totalAttempts; i++ {
@ -346,7 +344,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
// check if the error is due to object not found // check if the error is due to object not found
if s3Err, ok := err.(awserr.RequestFailure); ok { if s3Err, ok := err.(awserr.RequestFailure); ok {
if s3Err.Code() == s3.ErrCodeNoSuchKey { if s3Err.Code() == s3.ErrCodeNoSuchKey {
if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() { if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
ctxLogger.Error("Object not found: ", s3Err) ctxLogger.Error("Object not found: ", s3Err)
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "") return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
} else { } else {
@ -389,11 +387,11 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
// download the embedding object from hot bucket and upload to embeddings bucket // download the embedding object from hot bucket and upload to embeddings bucket
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) { func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() { if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "") return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
} }
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client()) downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter()) obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBackblazeDC())
if err != nil { if err != nil {
return nil, stacktrace.Propagate(err, "failed to download from hot bucket") return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
} }