diff --git a/internal/entity/entity_save.go b/internal/entity/entity_save.go deleted file mode 100644 index f63923e5d..000000000 --- a/internal/entity/entity_save.go +++ /dev/null @@ -1,52 +0,0 @@ -package entity - -import ( - "fmt" - "reflect" - "runtime/debug" - "strings" -) - -// Save updates an entity in the database, or inserts if it doesn't exist. -func Save(m interface{}, primaryKeys ...string) (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("index: save failed (%s)\nstack: %s", r, debug.Stack()) - log.Error(err) - } - }() - - if err := Update(m, primaryKeys...); err == nil { - return nil - } else if err := UnscopedDb().Save(m).Error; err == nil { - return nil - } else if !strings.Contains(strings.ToLower(err.Error()), "lock") { - return err - } else if err := UnscopedDb().Save(m).Error; err != nil { - return err - } - - return nil -} - -// Update updates an existing entity in the database. -func Update(m interface{}, primaryKeys ...string) (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("index: update failed (%s)\nstack: %s", r, debug.Stack()) - log.Error(err) - } - }() - - // Return with error if a primary key is empty. - v := reflect.ValueOf(m).Elem() - for _, k := range primaryKeys { - if field := v.FieldByName(k); !field.CanSet() || field.IsZero() { - return fmt.Errorf("empty primary key '%s'", k) - } - } - - err = UnscopedDb().FirstOrCreate(m, GetValues(m)).Error - - return err -} diff --git a/internal/entity/entity_save_test.go b/internal/entity/entity_save_test.go deleted file mode 100644 index 05a1cd785..000000000 --- a/internal/entity/entity_save_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package entity - -import ( - "testing" - "time" - - "github.com/photoprism/photoprism/internal/face" - "github.com/photoprism/photoprism/pkg/rnd" - "github.com/stretchr/testify/assert" -) - -func TestUpdate(t *testing.T) { - t.Run("HasCreatedUpdatedAt", func(t *testing.T) { - m := NewFace(rnd.PPID('j'), SrcAuto, face.RandomEmbeddings(1, face.RegularFace)) - id := m.ID - - m.CreatedAt = time.Now() - m.UpdatedAt = time.Now() - - if err := m.Save(); err != nil { - t.Fatal(err) - return - } - - found := FindFace(id) - - assert.NotNil(t, found) - assert.Equal(t, id, found.ID) - assert.Greater(t, time.Now(), m.UpdatedAt) - assert.Equal(t, found.CreatedAt.UTC(), m.CreatedAt.UTC()) - }) - t.Run("HasCreatedAt", func(t *testing.T) { - m := NewFace(rnd.PPID('j'), SrcAuto, face.RandomEmbeddings(1, face.RegularFace)) - id := m.ID - - m.CreatedAt = time.Now() - - if err := m.Save(); err != nil { - t.Fatal(err) - return - } - - found := FindFace(id) - assert.NotNil(t, found) - assert.Equal(t, id, found.ID) - assert.Greater(t, time.Now().UTC(), m.UpdatedAt.UTC()) - assert.Equal(t, found.CreatedAt.UTC(), m.CreatedAt.UTC()) - }) - t.Run("NoCreatedAt", func(t *testing.T) { - m := NewFace(rnd.PPID('j'), SrcAuto, face.RandomEmbeddings(1, face.RegularFace)) - id := m.ID - - if err := m.Save(); err != nil { - t.Fatal(err) - return - } - - found := FindFace(id) - assert.NotNil(t, found) - assert.Equal(t, id, found.ID) - assert.Greater(t, time.Now(), m.UpdatedAt.UTC()) - assert.Equal(t, found.CreatedAt.UTC(), m.CreatedAt.UTC()) - }) -} diff --git a/internal/entity/entity_update.go b/internal/entity/entity_update.go new file mode 100644 index 000000000..b35dbb7b7 --- /dev/null +++ b/internal/entity/entity_update.go @@ -0,0 +1,89 @@ +package entity + +import ( + "fmt" + "runtime/debug" + "strings" + + "github.com/jinzhu/gorm" +) + +// Save updates a record in the database, or inserts if it doesn't exist. +func Save(m interface{}, keyNames ...string) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("index: save failed (%s)\nstack: %s", r, debug.Stack()) + log.Error(err) + } + }() + + // Try updating first. + if err = Update(m, keyNames...); err == nil { + return nil + } else if err = UnscopedDb().Save(m).Error; err == nil { + return nil + } else if !strings.Contains(strings.ToLower(err.Error()), "lock") { + return err + } else if err = UnscopedDb().Save(m).Error; err != nil { + return err + } + + return nil +} + +// Update updates an existing record in the database. +func Update(m interface{}, keyNames ...string) (err error) { + // New entity? + if Db().NewRecord(m) { + return fmt.Errorf("new record") + } + + values, keys, err := ModelValues(m, keyNames...) + + // Has keys and values? + if err != nil { + return err + } else if len(keys) != len(keyNames) { + return fmt.Errorf("record keys missing") + } + + // Perform update. + res := Db().Model(m).Updates(values) + + // Successful? + if res.Error != nil { + return err + } else if res.RowsAffected > 1 { + log.Debugf("entity: updated statement affected more than one record - bug?") + return nil + } else if res.RowsAffected == 1 { + return nil + } else if Count(m, keyNames, keys) != 1 { + return fmt.Errorf("record not found") + } + + return err +} + +// Count returns the number of records for a given a model and key values. +func Count(m interface{}, keys []string, values []interface{}) int { + if m == nil || len(keys) != len(values) { + log.Debugf("entity: invalid parameters (count records)") + return -1 + } + + var count int + + stmt := Db().Model(m) + + for k := range keys { + stmt.Where("? = ?", gorm.Expr(keys[k]), values[k]) + } + + if err := stmt.Count(&count).Error; err != nil { + log.Debugf("entity: %s (count records)", err) + return -1 + } + + return count +} diff --git a/internal/entity/entity_update_test.go b/internal/entity/entity_update_test.go new file mode 100644 index 000000000..4a0544ae6 --- /dev/null +++ b/internal/entity/entity_update_test.go @@ -0,0 +1,144 @@ +package entity + +import ( + "math/rand" + "testing" + "time" + + "github.com/photoprism/photoprism/pkg/rnd" + "github.com/stretchr/testify/assert" +) + +func TestSave(t *testing.T) { + var r = rand.New(rand.NewSource(time.Now().UnixNano())) + + t.Run("HasCreatedUpdatedAt", func(t *testing.T) { + id := 99999 + r.Intn(10000) + m := Photo{ID: uint(id), PhotoUID: rnd.PPID('p'), UpdatedAt: TimeStamp(), CreatedAt: TimeStamp()} + + if err := m.Save(); err != nil { + t.Fatal(err) + return + } + + if err := m.Find(); err != nil { + t.Fatal(err) + return + } + }) + t.Run("HasCreatedAt", func(t *testing.T) { + id := 99999 + r.Intn(10000) + m := Photo{ID: uint(id), PhotoUID: rnd.PPID('p'), CreatedAt: TimeStamp()} + + if err := m.Save(); err != nil { + t.Fatal(err) + return + } + + if err := m.Find(); err != nil { + t.Fatal(err) + return + } + }) + t.Run("NoCreatedAt", func(t *testing.T) { + id := 99999 + r.Intn(10000) + m := Photo{ID: uint(id), PhotoUID: rnd.PPID('p'), CreatedAt: TimeStamp()} + + if err := m.Save(); err != nil { + t.Fatal(err) + return + } + + if err := m.Find(); err != nil { + t.Fatal(err) + return + } + }) +} + +func TestUpdate(t *testing.T) { + var r = rand.New(rand.NewSource(time.Now().UnixNano())) + t.Run("IDMissing", func(t *testing.T) { + uid := rnd.PPID('p') + m := &Photo{ID: 0, PhotoUID: uid, UpdatedAt: TimeStamp(), CreatedAt: TimeStamp(), PhotoTitle: "Foo"} + updatedAt := m.UpdatedAt + + err := Update(m, "ID", "PhotoUID") + + if err == nil { + t.Fatal("error expected") + } + + assert.ErrorContains(t, err, "new record") + assert.Equal(t, m.UpdatedAt.UTC(), updatedAt.UTC()) + }) + t.Run("UIDMissing", func(t *testing.T) { + id := 99999 + r.Intn(10000) + m := &Photo{ID: uint(id), PhotoUID: "", UpdatedAt: TimeStamp(), CreatedAt: TimeStamp(), PhotoTitle: "Foo"} + updatedAt := m.UpdatedAt + + err := Update(m, "ID", "PhotoUID") + + if err == nil { + t.Fatal("error expected") + } + + assert.ErrorContains(t, err, "record keys missing") + assert.Equal(t, m.UpdatedAt.UTC(), updatedAt.UTC()) + }) + t.Run("NotUpdated", func(t *testing.T) { + id := 99999 + r.Intn(10000) + uid := rnd.PPID('p') + m := &Photo{ID: uint(id), PhotoUID: uid, UpdatedAt: time.Now(), CreatedAt: TimeStamp(), PhotoTitle: "Foo"} + updatedAt := m.UpdatedAt + + err := Update(m, "ID", "PhotoUID") + + if err == nil { + t.Fatal("error expected") + } + + assert.ErrorContains(t, err, "record not found") + assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC()) + }) + t.Run("Photo01", func(t *testing.T) { + m := PhotoFixtures.Pointer("Photo01") + updatedAt := m.UpdatedAt + + // Should be updated without any issues. + if err := Update(m, "ID", "PhotoUID"); err != nil { + assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC()) + t.Fatal(err) + return + } else { + assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC()) + t.Logf("(1) UpdatedAt: %s -> %s", updatedAt.UTC(), m.UpdatedAt.UTC()) + t.Logf("(1) Successfully updated values") + } + + // Tests that no error is returned on MySQL/MariaDB although + // the number of affected rows is 0. + if err := Update(m, "ID", "PhotoUID"); err != nil { + assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC()) + t.Fatal(err) + return + } else { + assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC()) + t.Logf("(2) UpdatedAt: %s -> %s", updatedAt.UTC(), m.UpdatedAt.UTC()) + t.Logf("(2) Successfully updated values") + } + }) + t.Run("NonExistentKeys", func(t *testing.T) { + m := PhotoFixtures.Pointer("Photo01") + m.ID = uint(99999 + r.Intn(10000)) + m.PhotoUID = rnd.PPID('p') + updatedAt := m.UpdatedAt + if err := Update(m, "ID", "PhotoUID"); err == nil { + t.Fatal("error expected") + return + } else { + assert.ErrorContains(t, err, "record not found") + assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC()) + } + }) +} diff --git a/internal/entity/entity_values.go b/internal/entity/entity_values.go new file mode 100644 index 000000000..df3f01bc7 --- /dev/null +++ b/internal/entity/entity_values.go @@ -0,0 +1,81 @@ +package entity + +import ( + "fmt" + "reflect" +) + +// Values is a shortcut for map[string]interface{} +type Values map[string]interface{} + +// ModelValues extracts Values from an entity model. +func ModelValues(m interface{}, keyNames ...string) (result Values, keys []interface{}, err error) { + isKey := func(name string) bool { + for _, s := range keyNames { + if name == s { + return true + } + } + + return false + } + + r := reflect.ValueOf(m) + + if r.Kind() != reflect.Pointer { + return result, keys, fmt.Errorf("model interface expected") + } + + values := r.Elem() + + if kind := values.Kind(); kind != reflect.Struct { + return result, keys, fmt.Errorf("model expected") + } + + t := values.Type() + num := t.NumField() + + keys = make([]interface{}, 0, len(keyNames)) + result = make(map[string]interface{}, num) + + // Add exported fields to result. + for i := 0; i < num; i++ { + field := t.Field(i) + + // Skip non-exported fields. + if !field.IsExported() { + continue + } + + name := field.Name + + // Skip timestamps. + if name == "" || name == "UpdatedAt" || name == "CreatedAt" { + continue + } + + v := values.Field(i) + + // Skip read-only fields. + if !v.CanSet() { + continue + } + + // Skip keys. + if isKey(name) { + if !v.IsZero() { + keys = append(keys, v.Interface()) + } + continue + } + + // Add value to result. + result[name] = v.Interface() + } + + if len(result) == 0 { + return result, keys, fmt.Errorf("no values") + } + + return result, keys, nil +} diff --git a/internal/entity/entity_values_test.go b/internal/entity/entity_values_test.go new file mode 100644 index 000000000..3b61ab2ba --- /dev/null +++ b/internal/entity/entity_values_test.go @@ -0,0 +1,66 @@ +package entity + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestModelValues(t *testing.T) { + t.Run("NoInterface", func(t *testing.T) { + m := Photo{} + values, keys, err := ModelValues(m, "ID", "PhotoUID") + + assert.Error(t, err) + assert.IsType(t, Values{}, values) + assert.Len(t, keys, 0) + }) + t.Run("NewPhoto", func(t *testing.T) { + m := &Photo{} + values, keys, err := ModelValues(m, "ID", "PhotoUID") + + if err != nil { + t.Fatal(err) + } + + assert.Len(t, keys, 0) + assert.NotNil(t, values) + assert.IsType(t, Values{}, values) + }) + t.Run("ExistingPhoto", func(t *testing.T) { + m := PhotoFixtures.Pointer("Photo01") + values, keys, err := ModelValues(m, "ID", "PhotoUID") + + if err != nil { + t.Fatal(err) + } + + assert.Len(t, keys, 2) + assert.NotNil(t, values) + assert.IsType(t, Values{}, values) + }) + t.Run("NewFace", func(t *testing.T) { + m := &Face{} + values, keys, err := ModelValues(m, "ID") + + if err != nil { + t.Fatal(err) + } + + assert.Len(t, keys, 0) + assert.NotNil(t, values) + assert.IsType(t, Values{}, values) + }) + t.Run("ExistingFace", func(t *testing.T) { + m := FaceFixtures.Pointer("john-doe") + values, keys, err := ModelValues(m, "ID") + + if err != nil { + t.Fatal(err) + } + + assert.Len(t, keys, 1) + assert.NotNil(t, values) + assert.IsType(t, Values{}, values) + }) +} diff --git a/internal/entity/face.go b/internal/entity/face.go index 40912d91c..ed67592a0 100644 --- a/internal/entity/face.go +++ b/internal/entity/face.go @@ -303,18 +303,6 @@ func (m *Face) Show() (err error) { return m.Update("FaceHidden", false) } -// Save updates the existing or inserts a new face. -func (m *Face) Save() error { - if m.ID == "" { - return fmt.Errorf("empty id") - } - - faceMutex.Lock() - defer faceMutex.Unlock() - - return Save(m, "ID") -} - // Create inserts the face to the database. func (m *Face) Create() error { if m.ID == "" { diff --git a/internal/entity/face_test.go b/internal/entity/face_test.go index 0310df6ad..77b24a015 100644 --- a/internal/entity/face_test.go +++ b/internal/entity/face_test.go @@ -203,7 +203,7 @@ func TestFace_Save(t *testing.T) { assert.Nil(t, FindFace(m.ID)) - if err := m.Save(); err != nil { + if err := m.Create(); err != nil { t.Fatal(err) } @@ -213,7 +213,7 @@ func TestFace_Save(t *testing.T) { t.Run("Error", func(t *testing.T) { m := NewFace("12345fde", SrcAuto, face.Embeddings{face.Embedding{1}, face.Embedding{2}}) assert.Nil(t, FindFace(m.ID)) - assert.Error(t, m.Save()) + assert.Error(t, m.Create()) assert.Nil(t, FindFace(m.ID)) }) } @@ -227,7 +227,7 @@ func TestFace_Update(t *testing.T) { assert.Nil(t, FindFace(id)) - if err := m.Save(); err != nil { + if err := m.Create(); err != nil { t.Fatal(err) return } diff --git a/internal/entity/marker.go b/internal/entity/marker.go index 5697633e9..d50fb9f5a 100644 --- a/internal/entity/marker.go +++ b/internal/entity/marker.go @@ -697,24 +697,6 @@ func FindMarker(markerUid string) *Marker { return &result } -// FindFaceMarker finds the best marker for a given face -func FindFaceMarker(faceId string) *Marker { - if faceId == "" { - return nil - } - - var result Marker - - if err := Db().Where("face_id = ?", faceId). - Where("thumb <> '' AND marker_invalid = 0"). - Order("face_dist ASC, q DESC").First(&result).Error; err != nil { - log.Warnf("markers: found no marker for face %s", sanitize.Log(faceId)) - return nil - } - - return &result -} - // CreateMarkerIfNotExists updates a marker in the database or creates a new one if needed. func CreateMarkerIfNotExists(m *Marker) (*Marker, error) { result := Marker{} @@ -727,7 +709,7 @@ func CreateMarkerIfNotExists(m *Marker) (*Marker, error) { } else if err := m.Create(); err != nil { return m, err } else { - log.Debugf("markers: added %s marker %s for %s", TypeString(m.MarkerType), sanitize.Log(m.MarkerUID), sanitize.Log(m.FileUID)) + log.Debugf("markers: added %s %s for file %s", TypeString(m.MarkerType), sanitize.Log(m.MarkerUID), sanitize.Log(m.FileUID)) } return m, nil diff --git a/internal/entity/string_values.go b/internal/entity/string_values.go deleted file mode 100644 index 8771e65c0..000000000 --- a/internal/entity/string_values.go +++ /dev/null @@ -1,49 +0,0 @@ -package entity - -import ( - "reflect" -) - -// Values is a shortcut for map[string]interface{} -type Values map[string]interface{} - -// GetValues extracts entity Values. -func GetValues(m interface{}, omit ...string) (result Values) { - skip := func(name string) bool { - if name == "" || name == "UpdatedAt" || name == "CreatedAt" { - return true - } - - for _, s := range omit { - if name == s { - return true - } - } - - return false - } - - result = make(map[string]interface{}) - - elem := reflect.ValueOf(m).Elem() - relType := elem.Type() - num := relType.NumField() - - result = make(map[string]interface{}, num) - - // Add exported fields to result. - for i := 0; i < num; i++ { - n := relType.Field(i).Name - v := elem.Field(i) - - if !v.CanSet() { - continue - } else if skip(n) { - continue - } - - result[n] = elem.Field(i).Interface() - } - - return result -}