[server] Increase embedding fetch limit (#1300)

## Description

Also use different semaphore than existing diff API

## Tests
This commit is contained in:
Neeraj Gupta 2024-04-03 12:38:34 +05:30 committed by GitHub
parent ca688d0d46
commit 2fe703df92
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -275,7 +275,9 @@ func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, er
return len(embeddingObj), nil
}
var globalFetchSemaphore = make(chan struct{}, 300)
var globalDiffFetchSemaphore = make(chan struct{}, 300)
var globalFileFetchSemaphore = make(chan struct{}, 400)
func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
var wg sync.WaitGroup
@ -285,10 +287,10 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
for i, objectKey := range objectKeys {
wg.Add(1)
globalFetchSemaphore <- struct{}{} // Acquire from global semaphore
globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
go func(i int, objectKey string) {
defer wg.Done()
defer func() { <-globalFetchSemaphore }() // Release back to global semaphore
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
obj, err := c.getEmbeddingObject(objectKey, downloader)
if err != nil {
@ -322,10 +324,10 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
for i, dbEmbeddingRow := range dbEmbeddingRows {
wg.Add(1)
globalFetchSemaphore <- struct{}{} // Acquire from global semaphore
globalFileFetchSemaphore <- struct{}{} // Acquire from global semaphore
go func(i int, dbEmbeddingRow ente.Embedding) {
defer wg.Done()
defer func() { <-globalFetchSemaphore }() // Release back to global semaphore
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
obj, err := c.getEmbeddingObject(objectKey, downloader)
if err != nil {
@ -373,8 +375,8 @@ func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID
if len(req.FileIDs) == 0 {
return ente.NewBadRequestWithMessage("fileIDs are required")
}
if len(req.FileIDs) > 100 {
return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 100")
if len(req.FileIDs) > 200 {
return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 200")
}
if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
ActorUserId: userID,