Index: Add tests and refactor database record updates #1438

This commit is contained in:
Michael Mayer 2022-04-05 11:40:53 +02:00
parent 9986986f8f
commit 7b508d6ad5
10 changed files with 384 additions and 199 deletions

View file

@ -1,52 +0,0 @@
package entity
import (
"fmt"
"reflect"
"runtime/debug"
"strings"
)
// Save updates an entity in the database, or inserts if it doesn't exist.
func Save(m interface{}, primaryKeys ...string) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("index: save failed (%s)\nstack: %s", r, debug.Stack())
log.Error(err)
}
}()
if err := Update(m, primaryKeys...); err == nil {
return nil
} else if err := UnscopedDb().Save(m).Error; err == nil {
return nil
} else if !strings.Contains(strings.ToLower(err.Error()), "lock") {
return err
} else if err := UnscopedDb().Save(m).Error; err != nil {
return err
}
return nil
}
// Update updates an existing entity in the database.
func Update(m interface{}, primaryKeys ...string) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("index: update failed (%s)\nstack: %s", r, debug.Stack())
log.Error(err)
}
}()
// Return with error if a primary key is empty.
v := reflect.ValueOf(m).Elem()
for _, k := range primaryKeys {
if field := v.FieldByName(k); !field.CanSet() || field.IsZero() {
return fmt.Errorf("empty primary key '%s'", k)
}
}
err = UnscopedDb().FirstOrCreate(m, GetValues(m)).Error
return err
}

View file

@ -1,64 +0,0 @@
package entity
import (
"testing"
"time"
"github.com/photoprism/photoprism/internal/face"
"github.com/photoprism/photoprism/pkg/rnd"
"github.com/stretchr/testify/assert"
)
func TestUpdate(t *testing.T) {
t.Run("HasCreatedUpdatedAt", func(t *testing.T) {
m := NewFace(rnd.PPID('j'), SrcAuto, face.RandomEmbeddings(1, face.RegularFace))
id := m.ID
m.CreatedAt = time.Now()
m.UpdatedAt = time.Now()
if err := m.Save(); err != nil {
t.Fatal(err)
return
}
found := FindFace(id)
assert.NotNil(t, found)
assert.Equal(t, id, found.ID)
assert.Greater(t, time.Now(), m.UpdatedAt)
assert.Equal(t, found.CreatedAt.UTC(), m.CreatedAt.UTC())
})
t.Run("HasCreatedAt", func(t *testing.T) {
m := NewFace(rnd.PPID('j'), SrcAuto, face.RandomEmbeddings(1, face.RegularFace))
id := m.ID
m.CreatedAt = time.Now()
if err := m.Save(); err != nil {
t.Fatal(err)
return
}
found := FindFace(id)
assert.NotNil(t, found)
assert.Equal(t, id, found.ID)
assert.Greater(t, time.Now().UTC(), m.UpdatedAt.UTC())
assert.Equal(t, found.CreatedAt.UTC(), m.CreatedAt.UTC())
})
t.Run("NoCreatedAt", func(t *testing.T) {
m := NewFace(rnd.PPID('j'), SrcAuto, face.RandomEmbeddings(1, face.RegularFace))
id := m.ID
if err := m.Save(); err != nil {
t.Fatal(err)
return
}
found := FindFace(id)
assert.NotNil(t, found)
assert.Equal(t, id, found.ID)
assert.Greater(t, time.Now(), m.UpdatedAt.UTC())
assert.Equal(t, found.CreatedAt.UTC(), m.CreatedAt.UTC())
})
}

View file

@ -0,0 +1,89 @@
package entity
import (
"fmt"
"runtime/debug"
"strings"
"github.com/jinzhu/gorm"
)
// Save updates a record in the database, or inserts if it doesn't exist.
func Save(m interface{}, keyNames ...string) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("index: save failed (%s)\nstack: %s", r, debug.Stack())
log.Error(err)
}
}()
// Try updating first.
if err = Update(m, keyNames...); err == nil {
return nil
} else if err = UnscopedDb().Save(m).Error; err == nil {
return nil
} else if !strings.Contains(strings.ToLower(err.Error()), "lock") {
return err
} else if err = UnscopedDb().Save(m).Error; err != nil {
return err
}
return nil
}
// Update updates an existing record in the database.
func Update(m interface{}, keyNames ...string) (err error) {
// New entity?
if Db().NewRecord(m) {
return fmt.Errorf("new record")
}
values, keys, err := ModelValues(m, keyNames...)
// Has keys and values?
if err != nil {
return err
} else if len(keys) != len(keyNames) {
return fmt.Errorf("record keys missing")
}
// Perform update.
res := Db().Model(m).Updates(values)
// Successful?
if res.Error != nil {
return err
} else if res.RowsAffected > 1 {
log.Debugf("entity: updated statement affected more than one record - bug?")
return nil
} else if res.RowsAffected == 1 {
return nil
} else if Count(m, keyNames, keys) != 1 {
return fmt.Errorf("record not found")
}
return err
}
// Count returns the number of records for a given a model and key values.
func Count(m interface{}, keys []string, values []interface{}) int {
if m == nil || len(keys) != len(values) {
log.Debugf("entity: invalid parameters (count records)")
return -1
}
var count int
stmt := Db().Model(m)
for k := range keys {
stmt.Where("? = ?", gorm.Expr(keys[k]), values[k])
}
if err := stmt.Count(&count).Error; err != nil {
log.Debugf("entity: %s (count records)", err)
return -1
}
return count
}

View file

@ -0,0 +1,144 @@
package entity
import (
"math/rand"
"testing"
"time"
"github.com/photoprism/photoprism/pkg/rnd"
"github.com/stretchr/testify/assert"
)
func TestSave(t *testing.T) {
var r = rand.New(rand.NewSource(time.Now().UnixNano()))
t.Run("HasCreatedUpdatedAt", func(t *testing.T) {
id := 99999 + r.Intn(10000)
m := Photo{ID: uint(id), PhotoUID: rnd.PPID('p'), UpdatedAt: TimeStamp(), CreatedAt: TimeStamp()}
if err := m.Save(); err != nil {
t.Fatal(err)
return
}
if err := m.Find(); err != nil {
t.Fatal(err)
return
}
})
t.Run("HasCreatedAt", func(t *testing.T) {
id := 99999 + r.Intn(10000)
m := Photo{ID: uint(id), PhotoUID: rnd.PPID('p'), CreatedAt: TimeStamp()}
if err := m.Save(); err != nil {
t.Fatal(err)
return
}
if err := m.Find(); err != nil {
t.Fatal(err)
return
}
})
t.Run("NoCreatedAt", func(t *testing.T) {
id := 99999 + r.Intn(10000)
m := Photo{ID: uint(id), PhotoUID: rnd.PPID('p'), CreatedAt: TimeStamp()}
if err := m.Save(); err != nil {
t.Fatal(err)
return
}
if err := m.Find(); err != nil {
t.Fatal(err)
return
}
})
}
func TestUpdate(t *testing.T) {
var r = rand.New(rand.NewSource(time.Now().UnixNano()))
t.Run("IDMissing", func(t *testing.T) {
uid := rnd.PPID('p')
m := &Photo{ID: 0, PhotoUID: uid, UpdatedAt: TimeStamp(), CreatedAt: TimeStamp(), PhotoTitle: "Foo"}
updatedAt := m.UpdatedAt
err := Update(m, "ID", "PhotoUID")
if err == nil {
t.Fatal("error expected")
}
assert.ErrorContains(t, err, "new record")
assert.Equal(t, m.UpdatedAt.UTC(), updatedAt.UTC())
})
t.Run("UIDMissing", func(t *testing.T) {
id := 99999 + r.Intn(10000)
m := &Photo{ID: uint(id), PhotoUID: "", UpdatedAt: TimeStamp(), CreatedAt: TimeStamp(), PhotoTitle: "Foo"}
updatedAt := m.UpdatedAt
err := Update(m, "ID", "PhotoUID")
if err == nil {
t.Fatal("error expected")
}
assert.ErrorContains(t, err, "record keys missing")
assert.Equal(t, m.UpdatedAt.UTC(), updatedAt.UTC())
})
t.Run("NotUpdated", func(t *testing.T) {
id := 99999 + r.Intn(10000)
uid := rnd.PPID('p')
m := &Photo{ID: uint(id), PhotoUID: uid, UpdatedAt: time.Now(), CreatedAt: TimeStamp(), PhotoTitle: "Foo"}
updatedAt := m.UpdatedAt
err := Update(m, "ID", "PhotoUID")
if err == nil {
t.Fatal("error expected")
}
assert.ErrorContains(t, err, "record not found")
assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC())
})
t.Run("Photo01", func(t *testing.T) {
m := PhotoFixtures.Pointer("Photo01")
updatedAt := m.UpdatedAt
// Should be updated without any issues.
if err := Update(m, "ID", "PhotoUID"); err != nil {
assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC())
t.Fatal(err)
return
} else {
assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC())
t.Logf("(1) UpdatedAt: %s -> %s", updatedAt.UTC(), m.UpdatedAt.UTC())
t.Logf("(1) Successfully updated values")
}
// Tests that no error is returned on MySQL/MariaDB although
// the number of affected rows is 0.
if err := Update(m, "ID", "PhotoUID"); err != nil {
assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC())
t.Fatal(err)
return
} else {
assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC())
t.Logf("(2) UpdatedAt: %s -> %s", updatedAt.UTC(), m.UpdatedAt.UTC())
t.Logf("(2) Successfully updated values")
}
})
t.Run("NonExistentKeys", func(t *testing.T) {
m := PhotoFixtures.Pointer("Photo01")
m.ID = uint(99999 + r.Intn(10000))
m.PhotoUID = rnd.PPID('p')
updatedAt := m.UpdatedAt
if err := Update(m, "ID", "PhotoUID"); err == nil {
t.Fatal("error expected")
return
} else {
assert.ErrorContains(t, err, "record not found")
assert.Greater(t, m.UpdatedAt.UTC(), updatedAt.UTC())
}
})
}

View file

@ -0,0 +1,81 @@
package entity
import (
"fmt"
"reflect"
)
// Values is a shortcut for map[string]interface{}
type Values map[string]interface{}
// ModelValues extracts Values from an entity model.
func ModelValues(m interface{}, keyNames ...string) (result Values, keys []interface{}, err error) {
isKey := func(name string) bool {
for _, s := range keyNames {
if name == s {
return true
}
}
return false
}
r := reflect.ValueOf(m)
if r.Kind() != reflect.Pointer {
return result, keys, fmt.Errorf("model interface expected")
}
values := r.Elem()
if kind := values.Kind(); kind != reflect.Struct {
return result, keys, fmt.Errorf("model expected")
}
t := values.Type()
num := t.NumField()
keys = make([]interface{}, 0, len(keyNames))
result = make(map[string]interface{}, num)
// Add exported fields to result.
for i := 0; i < num; i++ {
field := t.Field(i)
// Skip non-exported fields.
if !field.IsExported() {
continue
}
name := field.Name
// Skip timestamps.
if name == "" || name == "UpdatedAt" || name == "CreatedAt" {
continue
}
v := values.Field(i)
// Skip read-only fields.
if !v.CanSet() {
continue
}
// Skip keys.
if isKey(name) {
if !v.IsZero() {
keys = append(keys, v.Interface())
}
continue
}
// Add value to result.
result[name] = v.Interface()
}
if len(result) == 0 {
return result, keys, fmt.Errorf("no values")
}
return result, keys, nil
}

View file

@ -0,0 +1,66 @@
package entity
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestModelValues(t *testing.T) {
t.Run("NoInterface", func(t *testing.T) {
m := Photo{}
values, keys, err := ModelValues(m, "ID", "PhotoUID")
assert.Error(t, err)
assert.IsType(t, Values{}, values)
assert.Len(t, keys, 0)
})
t.Run("NewPhoto", func(t *testing.T) {
m := &Photo{}
values, keys, err := ModelValues(m, "ID", "PhotoUID")
if err != nil {
t.Fatal(err)
}
assert.Len(t, keys, 0)
assert.NotNil(t, values)
assert.IsType(t, Values{}, values)
})
t.Run("ExistingPhoto", func(t *testing.T) {
m := PhotoFixtures.Pointer("Photo01")
values, keys, err := ModelValues(m, "ID", "PhotoUID")
if err != nil {
t.Fatal(err)
}
assert.Len(t, keys, 2)
assert.NotNil(t, values)
assert.IsType(t, Values{}, values)
})
t.Run("NewFace", func(t *testing.T) {
m := &Face{}
values, keys, err := ModelValues(m, "ID")
if err != nil {
t.Fatal(err)
}
assert.Len(t, keys, 0)
assert.NotNil(t, values)
assert.IsType(t, Values{}, values)
})
t.Run("ExistingFace", func(t *testing.T) {
m := FaceFixtures.Pointer("john-doe")
values, keys, err := ModelValues(m, "ID")
if err != nil {
t.Fatal(err)
}
assert.Len(t, keys, 1)
assert.NotNil(t, values)
assert.IsType(t, Values{}, values)
})
}

View file

@ -303,18 +303,6 @@ func (m *Face) Show() (err error) {
return m.Update("FaceHidden", false)
}
// Save updates the existing or inserts a new face.
func (m *Face) Save() error {
if m.ID == "" {
return fmt.Errorf("empty id")
}
faceMutex.Lock()
defer faceMutex.Unlock()
return Save(m, "ID")
}
// Create inserts the face to the database.
func (m *Face) Create() error {
if m.ID == "" {

View file

@ -203,7 +203,7 @@ func TestFace_Save(t *testing.T) {
assert.Nil(t, FindFace(m.ID))
if err := m.Save(); err != nil {
if err := m.Create(); err != nil {
t.Fatal(err)
}
@ -213,7 +213,7 @@ func TestFace_Save(t *testing.T) {
t.Run("Error", func(t *testing.T) {
m := NewFace("12345fde", SrcAuto, face.Embeddings{face.Embedding{1}, face.Embedding{2}})
assert.Nil(t, FindFace(m.ID))
assert.Error(t, m.Save())
assert.Error(t, m.Create())
assert.Nil(t, FindFace(m.ID))
})
}
@ -227,7 +227,7 @@ func TestFace_Update(t *testing.T) {
assert.Nil(t, FindFace(id))
if err := m.Save(); err != nil {
if err := m.Create(); err != nil {
t.Fatal(err)
return
}

View file

@ -697,24 +697,6 @@ func FindMarker(markerUid string) *Marker {
return &result
}
// FindFaceMarker finds the best marker for a given face
func FindFaceMarker(faceId string) *Marker {
if faceId == "" {
return nil
}
var result Marker
if err := Db().Where("face_id = ?", faceId).
Where("thumb <> '' AND marker_invalid = 0").
Order("face_dist ASC, q DESC").First(&result).Error; err != nil {
log.Warnf("markers: found no marker for face %s", sanitize.Log(faceId))
return nil
}
return &result
}
// CreateMarkerIfNotExists updates a marker in the database or creates a new one if needed.
func CreateMarkerIfNotExists(m *Marker) (*Marker, error) {
result := Marker{}
@ -727,7 +709,7 @@ func CreateMarkerIfNotExists(m *Marker) (*Marker, error) {
} else if err := m.Create(); err != nil {
return m, err
} else {
log.Debugf("markers: added %s marker %s for %s", TypeString(m.MarkerType), sanitize.Log(m.MarkerUID), sanitize.Log(m.FileUID))
log.Debugf("markers: added %s %s for file %s", TypeString(m.MarkerType), sanitize.Log(m.MarkerUID), sanitize.Log(m.FileUID))
}
return m, nil

View file

@ -1,49 +0,0 @@
package entity
import (
"reflect"
)
// Values is a shortcut for map[string]interface{}
type Values map[string]interface{}
// GetValues extracts entity Values.
func GetValues(m interface{}, omit ...string) (result Values) {
skip := func(name string) bool {
if name == "" || name == "UpdatedAt" || name == "CreatedAt" {
return true
}
for _, s := range omit {
if name == s {
return true
}
}
return false
}
result = make(map[string]interface{})
elem := reflect.ValueOf(m).Elem()
relType := elem.Type()
num := relType.NumField()
result = make(map[string]interface{}, num)
// Add exported fields to result.
for i := 0; i < num; i++ {
n := relType.Field(i).Name
v := elem.Field(i)
if !v.CanSet() {
continue
} else if skip(n) {
continue
}
result[n] = elem.Field(i).Interface()
}
return result
}