Refactor
This commit is contained in:
parent
20e9a6a1fc
commit
a522631c2b
|
@ -53,9 +53,16 @@ type Controller struct {
|
||||||
HostName string
|
HostName string
|
||||||
cleanupCronRunning bool
|
cleanupCronRunning bool
|
||||||
derivedStorageDataCenter string
|
derivedStorageDataCenter string
|
||||||
|
downloadManagerCache map[string]*s3manager.Downloader
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
|
func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
|
||||||
|
embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetDerivedStorageDataCenter()}
|
||||||
|
cache := make(map[string]*s3manager.Downloader, len(embeddingDcs))
|
||||||
|
for i := range embeddingDcs {
|
||||||
|
s3Client := s3Config.GetS3Client(embeddingDcs[i])
|
||||||
|
cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client)
|
||||||
|
}
|
||||||
return &Controller{
|
return &Controller{
|
||||||
Repo: repo,
|
Repo: repo,
|
||||||
AccessCtrl: accessCtrl,
|
AccessCtrl: accessCtrl,
|
||||||
|
@ -67,6 +74,7 @@ func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanup
|
||||||
CollectionRepo: collectionRepo,
|
CollectionRepo: collectionRepo,
|
||||||
HostName: hostName,
|
HostName: hostName,
|
||||||
derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
|
derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
|
||||||
|
downloadManagerCache: cache,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,7 +144,7 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest)
|
||||||
|
|
||||||
// Fetch missing embeddings in parallel
|
// Fetch missing embeddings in parallel
|
||||||
if len(objectKeys) > 0 {
|
if len(objectKeys) > 0 {
|
||||||
embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
|
embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys, c.derivedStorageDataCenter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, stacktrace.Propagate(err, "")
|
return nil, stacktrace.Propagate(err, "")
|
||||||
}
|
}
|
||||||
|
@ -182,7 +190,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
|
||||||
errFileIds := make([]int64, 0)
|
errFileIds := make([]int64, 0)
|
||||||
|
|
||||||
// Fetch missing userFileEmbeddings in parallel
|
// Fetch missing userFileEmbeddings in parallel
|
||||||
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
|
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData, c.derivedStorageDataCenter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, stacktrace.Propagate(err, "")
|
return nil, stacktrace.Propagate(err, "")
|
||||||
}
|
}
|
||||||
|
@ -245,13 +253,10 @@ var globalDiffFetchSemaphore = make(chan struct{}, 300)
|
||||||
|
|
||||||
var globalFileFetchSemaphore = make(chan struct{}, 400)
|
var globalFileFetchSemaphore = make(chan struct{}, 400)
|
||||||
|
|
||||||
func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
|
func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string, dc string) ([]ente.EmbeddingObject, error) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
var errs []error
|
var errs []error
|
||||||
embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
|
embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
|
||||||
s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter)
|
|
||||||
downloader := s3manager.NewDownloaderWithClient(&s3Client)
|
|
||||||
|
|
||||||
for i, objectKey := range objectKeys {
|
for i, objectKey := range objectKeys {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
|
globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
|
||||||
|
@ -259,7 +264,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
||||||
|
|
||||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
log.Error("error fetching embedding object: "+objectKey, err)
|
log.Error("error fetching embedding object: "+objectKey, err)
|
||||||
|
@ -284,11 +289,9 @@ type embeddingObjectResult struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
|
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding, dc string) ([]embeddingObjectResult, error) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
|
embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
|
||||||
s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter)
|
|
||||||
downloader := s3manager.NewDownloaderWithClient(&s3Client)
|
|
||||||
|
|
||||||
for i, dbEmbeddingRow := range dbEmbeddingRows {
|
for i, dbEmbeddingRow := range dbEmbeddingRows {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -297,7 +300,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
||||||
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
||||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("error fetching embedding object: "+objectKey, err)
|
log.Error("error fetching embedding object: "+objectKey, err)
|
||||||
embeddingObjects[i] = embeddingObjectResult{
|
embeddingObjects[i] = embeddingObjectResult{
|
||||||
|
@ -317,7 +320,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
||||||
return embeddingObjects, nil
|
return embeddingObjects, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
|
||||||
opt := _defaultFetchConfig
|
opt := _defaultFetchConfig
|
||||||
ctxLogger := log.WithField("objectKey", objectKey)
|
ctxLogger := log.WithField("objectKey", objectKey)
|
||||||
totalAttempts := opt.RetryCount + 1
|
totalAttempts := opt.RetryCount + 1
|
||||||
|
@ -329,7 +332,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
|
||||||
cancel()
|
cancel()
|
||||||
return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
|
return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
|
||||||
default:
|
default:
|
||||||
obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter)
|
obj, err := c.downloadObject(fetchCtx, objectKey, dc)
|
||||||
cancel() // Ensure cancel is called to release resources
|
cancel() // Ensure cancel is called to release resources
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
|
@ -367,10 +370,11 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
|
||||||
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
|
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) {
|
func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
|
||||||
var obj ente.EmbeddingObject
|
var obj ente.EmbeddingObject
|
||||||
buff := &aws.WriteAtBuffer{}
|
buff := &aws.WriteAtBuffer{}
|
||||||
bucket := c.S3Config.GetBucket(dc)
|
bucket := c.S3Config.GetBucket(dc)
|
||||||
|
downloader := c.downloadManagerCache[dc]
|
||||||
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
||||||
Bucket: bucket,
|
Bucket: bucket,
|
||||||
Key: &objectKey,
|
Key: &objectKey,
|
||||||
|
@ -390,8 +394,7 @@ func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string)
|
||||||
if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
|
if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
|
||||||
return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
|
return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
|
||||||
}
|
}
|
||||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
obj, err := c.downloadObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC())
|
||||||
obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBackblazeDC())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
|
return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue