Refactor
This commit is contained in:
parent
3485b31475
commit
20e9a6a1fc
|
@ -33,6 +33,14 @@ const (
|
||||||
embeddingFetchTimeout = 15 * gTime.Second
|
embeddingFetchTimeout = 15 * gTime.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// _fetchConfig is the configuration for fetching objects from S3
|
||||||
|
type _fetchConfig struct {
|
||||||
|
RetryCount int
|
||||||
|
FetchTimeOut gTime.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
var _defaultFetchConfig = _fetchConfig{RetryCount: 3, FetchTimeOut: 15 * gTime.Second}
|
||||||
|
|
||||||
type Controller struct {
|
type Controller struct {
|
||||||
Repo *embedding.Repository
|
Repo *embedding.Repository
|
||||||
AccessCtrl access.Controller
|
AccessCtrl access.Controller
|
||||||
|
@ -251,7 +259,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
||||||
|
|
||||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
|
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
log.Error("error fetching embedding object: "+objectKey, err)
|
log.Error("error fetching embedding object: "+objectKey, err)
|
||||||
|
@ -289,7 +297,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
||||||
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
||||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
|
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("error fetching embedding object: "+objectKey, err)
|
log.Error("error fetching embedding object: "+objectKey, err)
|
||||||
embeddingObjects[i] = embeddingObjectResult{
|
embeddingObjects[i] = embeddingObjectResult{
|
||||||
|
@ -309,18 +317,8 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
||||||
return embeddingObjects, nil
|
return embeddingObjects, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type getOptions struct {
|
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
||||||
RetryCount int
|
opt := _defaultFetchConfig
|
||||||
FetchTimeOut gTime.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
|
|
||||||
if opt == nil {
|
|
||||||
opt = &getOptions{
|
|
||||||
RetryCount: 3,
|
|
||||||
FetchTimeOut: embeddingFetchTimeout,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctxLogger := log.WithField("objectKey", objectKey)
|
ctxLogger := log.WithField("objectKey", objectKey)
|
||||||
totalAttempts := opt.RetryCount + 1
|
totalAttempts := opt.RetryCount + 1
|
||||||
for i := 0; i < totalAttempts; i++ {
|
for i := 0; i < totalAttempts; i++ {
|
||||||
|
@ -346,7 +344,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
|
||||||
// check if the error is due to object not found
|
// check if the error is due to object not found
|
||||||
if s3Err, ok := err.(awserr.RequestFailure); ok {
|
if s3Err, ok := err.(awserr.RequestFailure); ok {
|
||||||
if s3Err.Code() == s3.ErrCodeNoSuchKey {
|
if s3Err.Code() == s3.ErrCodeNoSuchKey {
|
||||||
if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
|
if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
|
||||||
ctxLogger.Error("Object not found: ", s3Err)
|
ctxLogger.Error("Object not found: ", s3Err)
|
||||||
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
||||||
} else {
|
} else {
|
||||||
|
@ -389,11 +387,11 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
|
||||||
|
|
||||||
// download the embedding object from hot bucket and upload to embeddings bucket
|
// download the embedding object from hot bucket and upload to embeddings bucket
|
||||||
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
|
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
|
||||||
if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
|
if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
|
||||||
return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
|
return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
|
||||||
}
|
}
|
||||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
||||||
obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter())
|
obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBackblazeDC())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
|
return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue