diff --git a/internal/entity/face_test.go b/internal/entity/face_test.go index 90195118a..b41cfbf21 100644 --- a/internal/entity/face_test.go +++ b/internal/entity/face_test.go @@ -52,36 +52,56 @@ func TestFace_Match(t *testing.T) { } func TestFace_ReportCollision(t *testing.T) { - m := FaceFixtures.Get("joe-biden") + t.Run("collision", func(t *testing.T) { + m := FaceFixtures.Get("joe-biden") - assert.Zero(t, m.Collisions) - assert.Zero(t, m.CollisionRadius) + assert.Zero(t, m.Collisions) + assert.Zero(t, m.CollisionRadius) - if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-4").Embeddings()); err != nil { - t.Fatal(err) - } else { - assert.True(t, reported) - } + if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-4").Embeddings()); err != nil { + t.Fatal(err) + } else { + assert.True(t, reported) + } - // Number of collisions must have increased by one. - assert.Equal(t, 1, m.Collisions) + // Number of collisions must have increased by one. + assert.Equal(t, 1, m.Collisions) - // Actual distance is ~1.314040 - assert.Greater(t, m.CollisionRadius, 1.2) - assert.Less(t, m.CollisionRadius, 1.314) + // Actual distance is ~1.314040 + assert.Greater(t, m.CollisionRadius, 1.2) + assert.Less(t, m.CollisionRadius, 1.314) - if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-6").Embeddings()); err != nil { - t.Fatal(err) - } else { - assert.False(t, reported) - } + if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-6").Embeddings()); err != nil { + t.Fatal(err) + } else { + assert.False(t, reported) + } - // Number of collisions must not have increased. - assert.Equal(t, 1, m.Collisions) + // Number of collisions must not have increased. + assert.Equal(t, 1, m.Collisions) - // Actual distance is ~1.272604 - assert.Greater(t, m.CollisionRadius, 1.1) - assert.Less(t, m.CollisionRadius, 1.272) + // Actual distance is ~1.272604 + assert.Greater(t, m.CollisionRadius, 1.1) + assert.Less(t, m.CollisionRadius, 1.272) + }) + t.Run("subject id empty", func(t *testing.T) { + m := NewFace("", SrcAuto, Embeddings{}) + if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-4").Embeddings()); err != nil { + t.Fatal(err) + } else { + assert.False(t, reported) + } + }) + t.Run("invalid face id", func(t *testing.T) { + m := NewFace("123", SrcAuto, Embeddings{}) + m.ID = "" + if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-4").Embeddings()); err == nil { + t.Fatal(err) + } else { + assert.False(t, reported) + assert.Equal(t, "invalid face id", err.Error()) + } + }) } func TestFace_ReviseMatches(t *testing.T) { @@ -110,40 +130,84 @@ func TestFace_SetEmbeddings(t *testing.T) { t.Run("success", func(t *testing.T) { marker := MarkerFixtures.Get("1000003-4") e := marker.Embeddings() - f := FaceFixtures.Get("joe-biden") - assert.NotEqual(t, e[0][0], f.Embedding()[0]) + m := FaceFixtures.Get("joe-biden") + assert.NotEqual(t, e[0][0], m.Embedding()[0]) - err := f.SetEmbeddings(e) + err := m.SetEmbeddings(e) if err != nil { t.Fatal(err) } - assert.Equal(t, e[0][0], f.Embedding()[0]) + assert.Equal(t, e[0][0], m.Embedding()[0]) }) } func TestFace_Embedding(t *testing.T) { t.Run("success", func(t *testing.T) { - f := FaceFixtures.Get("joe-biden") + m := FaceFixtures.Get("joe-biden") - assert.Equal(t, 0.10730543085474682, f.Embedding()[0]) + assert.Equal(t, 0.10730543085474682, m.Embedding()[0]) }) t.Run("empty embedding", func(t *testing.T) { - f := NewFace("12345", SrcAuto, Embeddings{}) + m := NewFace("12345", SrcAuto, Embeddings{}) + m.EmbeddingJSON = []byte("") - assert.Empty(t, f.Embedding()) + assert.Empty(t, m.Embedding()) }) t.Run("invalid embedding json", func(t *testing.T) { - f := NewFace("12345", SrcAuto, Embeddings{}) - f.EmbeddingJSON = []byte("[false]") + m := NewFace("12345", SrcAuto, Embeddings{}) + m.EmbeddingJSON = []byte("[false]") - assert.Equal(t, float64(0), f.Embedding()[0]) + assert.Equal(t, float64(0), m.Embedding()[0]) }) } func TestFace_UpdateMatchTime(t *testing.T) { - f := NewFace("12345", SrcAuto, Embeddings{}) - initialMatchTime := f.MatchedAt - assert.Equal(t, initialMatchTime, f.MatchedAt) - f.UpdateMatchTime() - assert.NotEqual(t, initialMatchTime, f.MatchedAt) + m := NewFace("12345", SrcAuto, Embeddings{}) + initialMatchTime := m.MatchedAt + assert.Equal(t, initialMatchTime, m.MatchedAt) + m.UpdateMatchTime() + assert.NotEqual(t, initialMatchTime, m.MatchedAt) +} + +func TestFace_Save(t *testing.T) { + m := NewFace("12345fde", SrcAuto, Embeddings{Embedding{1}, Embedding{2}}) + assert.Nil(t, FindFace(m.ID)) + m.Save() + assert.NotNil(t, FindFace(m.ID)) + assert.Equal(t, "12345fde", FindFace(m.ID).SubjectUID) +} + +func TestFace_Update(t *testing.T) { + m := NewFace("12345fdef", SrcAuto, Embeddings{Embedding{8}, Embedding{16}}) + assert.Nil(t, FindFace(m.ID)) + m.Save() + assert.NotNil(t, FindFace(m.ID)) + assert.Equal(t, "12345fdef", FindFace(m.ID).SubjectUID) + + m2 := FindFace(m.ID) + m2.Update("SubjectUID", "new") + assert.Equal(t, "new", FindFace(m.ID).SubjectUID) +} + +func TestFirstOrCreateFace(t *testing.T) { + t.Run("create new face", func(t *testing.T) { + m := NewFace("12345unique", SrcAuto, Embeddings{Embedding{99}, Embedding{2}}) + r := FirstOrCreateFace(m) + assert.Equal(t, "12345unique", r.SubjectUID) + }) + t.Run("return existing entity", func(t *testing.T) { + m := FaceFixtures.Pointer("joe-biden") + r := FirstOrCreateFace(m) + assert.Equal(t, "jqy3y652h8njw0sx", r.SubjectUID) + assert.Equal(t, 33, r.Samples) + }) +} + +func TestFindFace(t *testing.T) { + t.Run("existing face", func(t *testing.T) { + assert.Equal(t, 3, FindFace("VF7ANLDET2BKZNT4VQWJMMC6HBEFDOG7").Samples) + }) + t.Run("empty id", func(t *testing.T) { + assert.Nil(t, FindFace("")) + }) }