photoprism/internal/face/embeddings_test.go

100 lines
27 KiB
Go
Raw Normal View History

package face
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestEmbeddings_Contains(t *testing.T) {
t.Run("Blacklist", func(t *testing.T) {
e1 := Embedding{-0.00067741907, -0.0473858, -0.055704225, 0.066581585, 0.005290037, 0.03822874, -0.05082791, 0.049087, 0.018308854, -0.01478969, -0.0017964382, 0.057903986, -0.02245441, -0.033446193, -0.0071987803, -0.03150515, 0.017724376, 0.019948162, 0.056362633, -0.11701946, -0.09725581, 0.013104323, 0.07357148, -0.020104313, -0.014956488, -0.017257847, -0.020743687, -0.046340823, -0.021527996, 0.057039287, -0.051751148, -0.00595411, 0.011733857, -0.04644382, -0.028774152, 0.038308877, -0.015685631, 0.0038190065, -0.094677314, 0.07005245, -0.003068884, 0.050274804, -0.018143218, 0.012056174, 0.019058047, -0.00077958487, 0.030408988, 0.0885491, -0.11106459, -0.019910613, 0.01465938, 0.032499526, 0.04720623, 0.03163276, -0.079508044, 0.07914293, 0.03724754, 0.05636608, -0.016979787, -0.04298359, -0.03599786, 0.058342796, -0.024069536, -0.03200436, 0.015390179, 0.07293075, -0.0013783299, -0.01605161, -0.0045498423, -0.016551308, 0.060823392, 0.073548146, -0.021239666, -0.036749583, 0.00028047478, -0.025989566, -0.05310693, -0.009845653, -0.012862101, 0.058108855, 0.06915682, 0.021099294, 0.001225616, 0.033261415, 0.042154703, 0.05271609, -0.030613305, 0.041209485, 0.0019353209, -0.0019135632, 0.080911145, 0.011060149, 0.05281461, -0.10160427, 0.053647887, 0.01762615, -0.018661015, 0.02777136, -0.056023918, 0.057687696, 0.06513923, -0.069180146, -0.015231091, 0.03774808, 0.012945786, 0.090585396, 0.019942591, -0.015280875, -0.0011742035, -0.042135417, 0.05812486, 0.05519544, -0.06453465, -0.04014237, -0.03199121, 0.009241696, 0.043912023, 0.015467731, 0.0020673021, 0.08763925, -0.08974939, -0.05304043, 0.016309343, -0.04766073, -0.0046682972, -0.040477246, 0.012679025, 0.005960806, -0.03256039, -0.07089416, -0.02648642, -0.062463485, 0.08295411, -0.037147924, -0.074104264, 0.077417135, -0.042383663, 0.002088579, 0.07709948, 0.06521331, -0.07541816, 0.057679284, -0.0038482754, 0.055191133, 0.058614884, -0.018541405, -0.012277692, 0.057926673, -0.01724738, 
0.020869015, 0.046103075, -0.0319926, -0.0671411, 0.02629761, 0.044356663, 0.036788594, 0.028035736, 0.06419986, 0.045972086, -0.044160895, 0.09784713, 0.00953585, -0.06252615, 0.025597766, 0.029688764, 0.027506752, -0.055384982, -0.028262418, -0.057812758, 0.042470966, -0.05840525, -0.03801629, 0.0043816785, -0.015086851, 0.022297874, 0.054110903, 0.07420415, -0.040949743, -0.08912868, -0.060081407, -0.046966024, -0.04826231, -0.031198893, 0.06643161, 0.01347482, 0.029003717, 0.047974125, 0.07580259, 0.0364837, 0.012980756, 0.020622224, -0.022555852, 0.10519882, -0.03425076, 0.03889103, -0.007944982, 0.060116176, 0.038143042, -0.043681793, 0.05316621, 0.0016697546, -0.0033659195, 0.020053076, -0.07204762, 0.009732797, 0.03862544, 0.03913928, -0.009832126, 0.06401315, 0.044209804, -0.051490918, 0.014780334, -0.025438532, -0.01395564, -0.038089562, -0.009803314, -0.04146325, 0.03357636, -0.009009651, 0.04373462, -0.05627207, -0.072068065, -0.007331119, -0.04238925, -0.021102922, -0.021610938, -0.063095644, -0.05075978, -0.07491732, 0.0113026835, 0.04940704, 0.084163934, -0.01636119, -0.06292533, -0.014526007, -0.007826649, 0.07212185, 0.004734863, -0.062791124, -0.016170232, -0.016590146, -0.024280416, -0.019522576, 0.01579864, 0.002749153, 0.053687476, -0.016495354, 0.03662209, -0.018870195, 0.039446633, 0.04715818, 0.013310755, 0.028033575, 0.067203104, -0.005826697, 0.042239364, -0.020745477, 0.0015561507, 0.065384425, 0.09041862, 0.01788624, -0.010026493, 0.0639752, -0.0041079777, -0.058741663, 0.034099422, 0.016963305, -0.0052649197, -0.07403116, -0.0029912698, -0.07583906, 0.001012409, -0.029492352, -0.04166984, -0.048218843, -0.06306306, -0.0033243645, -0.02919451, 0.0039520795, -0.037993237, 0.044352133, 0.03976749, -0.030565975, -0.076377496, -0.06569797, 0.01630497, 0.017201763, 0.028935125, 0.004058029, -0.09946857, 0.036038112, -0.06792709, 0.04493338, 0.035801385, -0.048764315, -0.06326099, 0.037968885, 0.038012274, 0.038041856, 0.04288072, 0.01531972, 
0.032477308, 0.043926656, -0.0077148937, 0.056322563, -0.03451439, -0.0083661955, 0.0015379508, -0.007879
e2 := Embedding{-0.037731018, -0.005501065, 0.04339579, 0.050818004, -0.059338734, 0.033849984, -0.006599584, -0.0017643301, 0.049746443, -0.103716515, 0.037138782, -0.0064612515, 0.071909964, 0.013218528, -0.065359734, 0.11057091, 0.031195551, 0.025612833, 0.0075477255, -0.034973715, -0.013490629, -0.08104751, -0.022038054, -0.05304818, 0.008366317, -0.056096837, -0.008484318, 0.049539477, 0.019540254, 0.067417614, -0.027856546, -0.008532138, -0.017063588, -0.00016265438, -0.106199585, 0.03904082, 0.030587498, 0.043008707, -0.015111545, -0.022849092, 0.0025588698, -0.012814152, 0.037556626, -0.0086288145, 0.05265788, 0.011832273, 0.00015048613, -0.0081366515, 0.0013409692, 0.028389124, 0.022627315, -0.015008434, -0.0007749727, 0.013927345, -0.012275729, -0.0090859635, 0.019502806, -0.011900984, 0.016286656, 0.08094661, 0.000306613, -0.06327904, 0.018552454, 0.08885108, -0.07583091, 0.09275318, -0.018484656, 0.074180886, -0.039385945, -0.08063905, -0.05360434, -0.037074074, 0.09909196, 0.025063906, -0.009406389, -0.029612983, -0.018644262, 0.08433939, -0.04466277, -0.07118042, -0.0053266245, -0.07471344, 0.06739151, -0.05399609, 0.03125197, -0.00007781149, -0.04214992, -0.044316035, 0.025013437, 0.031466946, 0.0023496088, 0.042693187, -0.046198968, 0.026152546, -0.017578958, 0.023763098, 0.027511515, -0.05229892, -0.005204117, 0.035853546, -0.031919815, -0.027175877, -0.033706605, 0.018576957, -0.0010251165, -0.006808904, 0.009910016, -0.046926413, -0.02833718, 0.0132687995, -0.033933964, 0.06434295, 0.046245363, 0.044698197, 0.041076522, 0.04224362, -0.050834127, 0.0037004466, 0.061506275, -0.018232772, 0.067569405, 0.048701495, 0.042266034, -0.11045008, 0.03627151, -0.07259142, -0.0027725939, 0.040572345, 0.010365194, -0.018683784, 0.004533848, 0.037213936, -0.050944775, 0.07134523, 0.004012727, 0.036228556, -0.013853831, -0.06910639, -0.011394227, -0.012075533, 0.036311198, -0.02587341, -0.04086224, -0.024498813, -0.019423751, -0.022674281, 0.052483488, 
0.026303312, -0.051178075, 0.008410645, 0.039851066, -0.028721321, -0.027934253, -0.029567441, 0.054549955, 0.07423011, -0.07211806, 0.015979288, 0.002092099, 0.049062036, 0.025120452, 0.045975365, 0.025024865, -0.04019101, 0.0013054911, -0.0049644294, 0.0065203104, -0.03237452, -0.020704443, 0.028736785, -0.027353559, 0.07551169, 0.0842262, 0.019992182, 0.11138123, -0.028617613, 0.06700691, -0.048681036, -0.008201593, -0.058066163, 0.027867565, 0.07693089, -0.033642102, -0.05855467, -0.08575646, -0.019721355, 0.018443357, -0.0037373751, 0.032450553, -0.0074002664, -0.028135147, 0.046631414, 0.0192969, -0.0071076434, -0.004898368, 0.011896125, -0.026020564, 0.074016415, -0.033884488, -0.07919758, 0.021606326, -0.0142197255, 0.0807476, 0.03722956, -0.0015949347, 0.008076167, -0.009640628, 0.02341143, 0.015375526, -0.059428506, 0.051759534, 0.028049389, 0.07790443, 0.0478649, 0.09191913, -0.055882096, -0.026637457, 0.01236174, -0.0033003334, 0.008522798, 0.027216703, -0.033221588, -0.028086975, -0.11505473, -0.044336796, 0.013873659, 0.03982099, 0.060988583, -0.07439005, 0.01333661, -0.004818605, 0.02561305, -0.059055943, 0.0081638545, -0.032278564, 0.046092775, 0.025316834, -0.046857174, -0.0341012, 0.04379944, -0.029710777, 0.09238533, 0.009769442, 0.018552538, -0.02632421, 0.033739865, -0.022547472, 0.016400741, -0.05336998, -0.012623122, -0.08303054, -0.010368709, 0.01690871, 0.0014627968, -0.050720602, 0.038742293, -0.065664165, -0.10676187, -0.013403239, 0.075702645, -0.055623896, -0.03871971, -0.042371742, 0.03794916, -0.0590573, -0.002583715, -0.029995736, -0.08144537, -0.043295015, -0.034286328, -0.026538746, 0.01953962, 0.08203153, 0.036415525, 0.045531306, 0.004713152, 0.026550433, -0.0055336948, -0.031087596, -0.01923592, -0.1047651, 0.051826596, -0.009522955, 0.0023846119, -0.030824797, 0.0011774554, 0.03384506, 0.010090165, -0.033521466, -0.052155476, 0.0032979914, -0.004305921, -0.08622774, 0.03262125, 0.06332183, 0.00067599304, 0.01989574, 0.04406689, 
0.019945903, -0.003796719, 0.00025200442, -0.010055775, 0.04070448, -0.004082432, -0.026942603, 0.1101
assert.True(t, Blacklist.Contains(e1, BlacklistRadius))
assert.False(t, Blacklist.Contains(e2, BlacklistRadius))
assert.False(t, Blacklist.Contains(e1, 0.1))
})
}
func TestEmbeddings_First(t *testing.T) {
t.Run("Blacklist", func(t *testing.T) {
e := Blacklist.First()
assert.Equal(t,
Embedding{0.0001326936762779951, 0.010595318133709952, -0.025556722866895143, 0.0469118170440197, -0.006627591326832771, 0.05271399952471256, -0.04542037146165967, 0.027480189339257777, -0.004917271726299077, -0.07468410208821297, 0.02064464334398508, 0.027222666889429092, 0.07686506863683462, -0.03543879697099328, -0.06587888672947884, 0.00656710215844214, 0.0468103364109993, 0.026114298962056637, 0.09671456180512905, -0.10109077207744122, -0.00781591737177223, -0.03552762418985367, 0.06885470915585756, -0.012004591058939695, 0.08148624002933502, -0.028154293075203896, 0.04417960252612829, -0.04447614587843418, -0.02476067957468331, 0.08517011441290379, 0.002213838277384639, -0.04421043721958995, -0.030782480724155903, -0.022004681872203946, 0.0009286352433264256, 0.010065821232274175, -0.02192891761660576, -0.012186611769720912, -0.08409351948648691, 0.03493268555030227, -0.044278232380747795, 0.028376419097185135, 0.0037802220904268324, 0.04365090653300285, 0.020489776856265962, 0.0062931065913289785, 0.012816649512387812, 0.08537860121577978, -0.08391124662011862, -0.0074469012033659965, 0.007379985763691366, 0.010717783472500741, 0.05138006154447794, 0.03530120011419058, -0.0252615287899971, 0.048870823346078396, 0.051952130161225796, 0.05130995064973831, -0.07164964452385902, 0.008358421036973596, 0.07212860928848386, 0.023144953418523073, 0.016010764054954052, -0.0020557132083922625, 0.0006047732058505062, 0.05894300062209368, 0.018385239876806736, 0.013549384311772883, 0.013471916783601046, -0.026798027334734797, 0.055714783258736134, 0.056760589592158794, -0.03564274637028575, 0.006639213301241398, -0.010688409092836082, -0.014488881919533014, -0.041894917376339436, 0.027378612896427512, 0.04120381874963641, 0.026110727107152343, 0.014362117013661191, -0.015710897743701935, 0.01369841955602169, -0.013755762134678662, -0.025898753898218274, 0.03539643401745707, -0.02275478537194431, 0.044916168320924044, 0.06104293651878834, 0.020197266509057954, 
0.03360230568796396, 0.0005524924490600824, -0.03635450592264533, -0.04944554786197841, 0.07170013058930635, -0.0030192392878234386, 0.0010193748748861253, -0.010943924076855183, -0.06179766729474068, 0.056694529950618744, 0.004520991584286094, -0.09191464446485043, -0.02631867417949252, 0.03776269545778632, -0.00876810192130506, 0.10884877666831017, 0.005137963918969035, 0.014990123057032179, -0.012330824043601751, -0.007994036190211773, 0.04028965998440981, 0.030694966204464436, 0.013997081900015473, -0.027155806310474873, 0.008769600055529736, 0.0038997385127004236, -0.003813096962403506, 0.06714510265737772, 0.006777862668968737, 0.061618989799171686, -0.016695198486559093, -0.014862070740491617, 0.047256719786673784, -0.030920962191885337, -0.010902314214035869, 0.029649023665115237, -0.03882072772830725, 0.037115989718586206, -0.00044789700768887997, -0.050685918889939785, 0.023225090000778437, 0.005942594725638628, 0.03058242436964065, -0.05067136138677597, -0.013880123384296894, 0.0521021569147706, -0.03878975426778197, 0.0057146906128764385, 0.01622107159346342, 0.0286275401012972, -0.07171816006302834, -0.024164889007806778, -0.036472175968810916, -0.06874032877385616, -0.013022568658925593, -0.02637113118544221, 0.061983236111700535, 0.0328070695977658, 0.004039793537231162, -0.022473737597465515, 0.017253196332603693, -0.04853619821369648, -0.053010642528533936, 0.04317591618746519, 0.00835538915998768, 0.015444972610566765, 0.04270970821380615, 0.028640874341363087, 0.06504016369581223, 0.03173913527280092, 0.05225925333797932, 0.03651800798252225, -0.022531148977577686, -0.0032108076702570543, 0.030786060029640794, 0.01946059288457036, -0.020198959624394774, -0.010514225577935576, -0.04853967670351267, 0.02134396170731634, -0.05255058314651251, 0.016728751827031374, 0.0456616897135973, -0.021495807450264692, 0.03207355597987771, 0.010148180415853858, 0.03585670003667474, -0.02635771268978715, -0.04969983361661434, -0.034095218405127525, 
-0.02873873570933938, 0.020237101707607508, -0.026095976354554296, 0.06887094769626856, 0.010905503178946674, 0.054851
e)
})
}
// TestEmbeddingsMidpoint verifies the midpoint, radius and embedding count
// returned by EmbeddingsMidpoint for a range of inputs.
//
// Every expected radius below equals the largest Euclidean distance from the
// midpoint to one of the input embeddings plus 0.01 (e.g. |3-2| + 0.01 = 1.01,
// and sqrt(0.75) + 0.01 for the four unit vectors). That margin presumably
// comes from EmbeddingsMidpoint itself — confirm against the implementation.
//
// Floating-point results are compared within floatDelta rather than exactly:
// they are produced by summation, division and square roots, so asserting
// exact bit patterns ties the test to evaluation order.
func TestEmbeddingsMidpoint(t *testing.T) {
	const floatDelta = 1e-6

	tests := []struct {
		name       string
		embeddings Embeddings
		midpoint   Embedding
		radius     float64
		count      int
	}{
		{
			name:       "2 embeddings, 1 dimension",
			embeddings: Embeddings{Embedding{1}, Embedding{3}},
			midpoint:   Embedding{2},
			radius:     1.01,
			count:      2,
		},
		{
			name:       "3 embeddings, 1 dimension",
			embeddings: Embeddings{Embedding{1}, Embedding{3}, Embedding{4}},
			midpoint:   Embedding{2.6666666666666665},
			radius:     1.6766666666666665,
			count:      3,
		},
		{
			name:       "4 embeddings, 1 dimension",
			embeddings: Embeddings{Embedding{1}, Embedding{3}, Embedding{4}, Embedding{8}},
			midpoint:   Embedding{4},
			radius:     4.01,
			count:      4,
		},
		{
			name:       "empty embedding",
			embeddings: Embeddings{},
			midpoint:   Embedding{},
			radius:     0,
			count:      0,
		},
		{
			// The expected midpoint has the first embedding's length; the
			// second embedding's extra component does not appear in it.
			name:       "embedding with different length",
			embeddings: Embeddings{Embedding{1}, Embedding{3, 5}},
			midpoint:   Embedding{2},
			radius:     1.01,
			count:      2,
		},
		{
			// Each unit vector is sqrt(0.75) ≈ 0.8660254 away from the
			// midpoint, so the expected radius is sqrt(0.75) + 0.01.
			name: "vectors",
			embeddings: Embeddings{
				Embedding{1, 0, 0, 0},
				Embedding{0, 1, 0, 0},
				Embedding{0, 0, 1, 0},
				Embedding{0, 0, 0, 1},
			},
			midpoint: Embedding{0.25, 0.25, 0.25, 0.25},
			radius:   0.8760254037844386,
			count:    4,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, radius, count := EmbeddingsMidpoint(tt.embeddings)

			// InDeltaSlice neither compares lengths nor tolerates an actual
			// slice longer than the expected one, so check the length first.
			if assert.Len(t, result, len(tt.midpoint)) {
				assert.InDeltaSlice(t, tt.midpoint, result, floatDelta)
			}

			assert.InDelta(t, tt.radius, radius, floatDelta)
			assert.Equal(t, tt.count, count)
		})
	}
}
func TestUnmarshalEmbeddings(t *testing.T) {
t.Run("success", func(t *testing.T) {
r := UnmarshalEmbeddings("[[-0.013,-0.031]]")
assert.Equal(t, Embeddings{{-0.013, -0.031}}, r)
})
t.Run("no prefix", func(t *testing.T) {
r := UnmarshalEmbeddings("-0.013,-0.031]")
assert.Nil(t, r)
})
t.Run("invalid json", func(t *testing.T) {
r := UnmarshalEmbeddings("[[true, false]]")
assert.Equal(t, Embeddings{{0, 0}}, r)
})
}