Improve image classification performance and logging in debug mode

This commit is contained in:
Michael Mayer 2019-05-04 17:34:51 +02:00
parent d83e81b49b
commit 1e6f41b417
10 changed files with 103 additions and 33 deletions

View file

@ -33,5 +33,7 @@ func indexAction(ctx *cli.Context) error {
log.Infof("indexed %d files", len(files))
app.Shutdown()
return nil
}

View file

@ -22,5 +22,7 @@ func migrateAction(ctx *cli.Context) error {
log.Infoln("database migration complete")
app.Shutdown()
return nil
}

View file

@ -3,6 +3,7 @@ package context
import (
"errors"
"os"
"syscall"
"time"
"github.com/jinzhu/gorm"
@ -394,3 +395,15 @@ func (c *Context) ClientConfig() ClientConfig {
return result
}
// Shutdown closes the database connection, logs whether that succeeded, and
// then sends SIGINT to the current process.
//
// NOTE(review): the self-signal presumably wakes a signal handler elsewhere
// that finishes the shutdown — confirm against the callers.
func (c *Context) Shutdown() {
	closeErr := c.CloseDb()

	switch {
	case closeErr != nil:
		log.Errorf("could not close database connection: %s", closeErr)
	default:
		log.Info("closed database connection")
	}

	// Signal our own process; a failure to deliver the signal is only logged.
	pid := syscall.Getpid()
	if killErr := syscall.Kill(pid, syscall.SIGINT); killErr != nil {
		log.Error(killErr)
	}
}

View file

@ -1,6 +1,8 @@
package models
import (
"fmt"
"github.com/gosimple/slug"
"github.com/jinzhu/gorm"
)
@ -44,3 +46,13 @@ func (c *Camera) FirstOrCreate(db *gorm.DB) *Camera {
return c
}
// String returns a human-readable camera name: "<make> <model>" when both
// fields are set, the model alone when only the model is set, and an empty
// string when the model is missing (even if a make is present).
func (c *Camera) String() string {
	switch {
	case c.CameraModel == "":
		// Without a model there is nothing meaningful to show.
		return ""
	case c.CameraMake == "":
		return c.CameraModel
	default:
		return fmt.Sprintf("%s %s", c.CameraMake, c.CameraModel)
	}
}

View file

@ -122,29 +122,20 @@ func (i *Indexer) indexMediaFile(mediaFile *MediaFile) string {
photo.PhotoTitle = fmt.Sprintf("%s / %s / %s", location.LocCounty, location.LocCountry, mediaFile.DateCreated().Format("2006"))
}
} else {
log.Infof("no location: %s", err)
log.Debugf("location cannot be determined precisely: %s", err)
var recentPhoto models.Photo
if result := i.db.Order(gorm.Expr("ABS(DATEDIFF(taken_at, ?)) ASC", mediaFile.DateCreated())).Preload("Country").First(&recentPhoto); result.Error == nil {
if recentPhoto.Country != nil {
photo.Country = recentPhoto.Country
log.Debugf("approximate location: %s", recentPhoto.Country.CountryName)
}
}
}
photo.Tags = tags
if photo.PhotoTitle == "" {
if len(photo.Tags) > 0 { // TODO: User defined title format
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Tags[0].TagLabel), mediaFile.DateCreated().Format("2006"))
} else if photo.Country != nil && photo.Country.CountryName != "" {
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Country.CountryName), mediaFile.DateCreated().Format("2006"))
} else {
photo.PhotoTitle = fmt.Sprintf("Unknown / %s", mediaFile.DateCreated().Format("2006"))
}
}
photo.Camera = models.NewCamera(mediaFile.CameraModel(), mediaFile.CameraMake()).FirstOrCreate(i.db)
photo.Lens = models.NewLens(mediaFile.LensModel(), mediaFile.LensMake()).FirstOrCreate(i.db)
photo.PhotoFocalLength = mediaFile.FocalLength()
@ -154,6 +145,36 @@ func (i *Indexer) indexMediaFile(mediaFile *MediaFile) string {
photo.PhotoCanonicalName = canonicalName
photo.PhotoFavorite = false
if photo.PhotoTitle == "" {
if len(photo.Tags) > 0 { // TODO: User defined title format
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Tags[0].TagLabel), mediaFile.DateCreated().Format("2006"))
} else if photo.Country != nil && photo.Country.CountryName != "" {
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Country.CountryName), mediaFile.DateCreated().Format("2006"))
} else if photo.Camera.String() != "" && photo.Camera.String() != "Unknown" {
photo.PhotoTitle = fmt.Sprintf("%s / %s", photo.Camera, mediaFile.DateCreated().Format("January 2006"))
} else {
var daytimeString string
hour := mediaFile.DateCreated().Hour()
switch {
case hour < 8:
daytimeString = "Early Bird"
case hour < 12:
daytimeString = "Morning Mood"
case hour < 17:
daytimeString = "Carpe Diem"
case hour < 20:
daytimeString = "Sunset"
default:
daytimeString = "Late Night"
}
photo.PhotoTitle = fmt.Sprintf("%s / %s", daytimeString, mediaFile.DateCreated().Format("January 2006"))
}
}
log.Debugf("title: \"%s\"", photo.PhotoTitle)
i.db.Create(&photo)
} else if time.Now().Sub(photo.UpdatedAt).Minutes() > 10 { // If updated more than 10 minutes ago
if jpeg, err := mediaFile.Jpeg(); err == nil {

View file

@ -49,7 +49,7 @@ func (m *MediaFile) Location() (*models.Location, error) {
if exifData, err := m.ExifData(); err == nil {
if exifData.Lat == 0 && exifData.Long == 0 {
return nil, errors.New("lat and long are missing in metadata")
return nil, errors.New("no latitude and longitude in image metadata")
}
url := fmt.Sprintf("https://nominatim.openstreetmap.org/reverse?lat=%f&lon=%f&format=jsonv2", exifData.Lat, exifData.Long)

View file

@ -11,13 +11,13 @@ import (
"sort"
"github.com/disintegration/imaging"
log "github.com/sirupsen/logrus"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// TensorFlow is a tensorflow wrapper given a graph, labels and a modelPath.
type TensorFlow struct {
// modelPath is the base directory; loadModel reads the saved model from its
// "nasnet" subdirectory and the labels from "labels.txt" inside that.
modelPath string
graph *tf.Graph
// model holds the saved model returned by tf.LoadSavedModel in loadModel;
// a non-nil value is treated there as "already loaded".
model *tf.SavedModel
// labels holds one entry per line of the labels file, in file order.
labels []string
}
@ -87,17 +87,26 @@ func (t *TensorFlow) GetImageTags(img []byte) (result []TensorFlowLabel, err err
}
// Return best labels
return t.findBestLabels(output[0].Value().([][]float32)[0]), nil
result = t.findBestLabels(output[0].Value().([][]float32)[0])
log.Debugf("labels: %v", result)
return result, nil
}
func (t *TensorFlow) loadModel() error {
if t.graph != nil {
if t.model != nil {
// Already loaded
return nil
}
savedModel := t.modelPath + "/nasnet"
modelLabels := savedModel + "/labels.txt"
log.Infof("loading image classification model from \"%s\"", savedModel)
// Load model
model, err := tf.LoadSavedModel(t.modelPath+"/nasnet", []string{"photoprism"}, nil)
model, err := tf.LoadSavedModel(savedModel, []string{"photoprism"}, nil)
if err != nil {
return err
@ -105,40 +114,54 @@ func (t *TensorFlow) loadModel() error {
t.model = model
log.Infof("loading classification labels from \"%s\"", modelLabels)
// Load labels
labelsFile, err := os.Open(t.modelPath + "/nasnet/labels.txt")
f, err := os.Open(modelLabels)
if err != nil {
return err
}
defer labelsFile.Close()
scanner := bufio.NewScanner(labelsFile)
defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines
for scanner.Scan() {
t.labels = append(t.labels, scanner.Text())
}
if err := scanner.Err(); err != nil {
return err
}
return nil
}
// findBestLabels pairs each probability with the label at the same index in
// t.labels, sorts the candidates and returns at most five of them.
//
// NOTE(review): this listing appears to interleave the previous statements
// (resultLabels) with their revised counterparts (result) — confirm which set
// is live before relying on the line comments below.
func (t *TensorFlow) findBestLabels(probabilities []float32) []TensorFlowLabel {
// Make a list of label/probability pairs
var resultLabels []TensorFlowLabel
var result []TensorFlowLabel
for i, p := range probabilities {
// Stop as soon as there are more probabilities than known labels.
if i >= len(t.labels) {
break
}
resultLabels = append(resultLabels, TensorFlowLabel{Label: t.labels[i], Probability: p})
// Labels with a probability below 0.08 never reach result.
if p < 0.08 { continue }
result = append(result, TensorFlowLabel{Label: t.labels[i], Probability: p})
}
// Sort by probability; the ordering comes from the TensorFlowLabels sort
// implementation (presumably highest first — confirm against its Less method).
sort.Sort(TensorFlowLabels(resultLabels))
sort.Sort(TensorFlowLabels(result))
// Return top 5 labels
return resultLabels[:5]
// After filtering, result may hold fewer than five entries, so the slice
// bound is capped at its length.
l := len(result)
if l >= 5 {
return result[:5]
} else {
return result[:l]
}
}
func (t *TensorFlow) makeTensorFromImage(image []byte, imageFormat string) (*tf.Tensor, error) {

View file

@ -26,7 +26,7 @@ func TestTensorFlow_GetImageTagsFromFile(t *testing.T) {
assert.NotNil(t, result)
assert.IsType(t, []TensorFlowLabel{}, result)
assert.Equal(t, 5, len(result))
assert.Equal(t, 2, len(result))
t.Log(result)
@ -59,7 +59,7 @@ func TestTensorFlow_GetImageTags(t *testing.T) {
assert.Nil(t, err)
assert.IsType(t, []TensorFlowLabel{}, result)
assert.Equal(t, 5, len(result))
assert.Equal(t, 2, len(result))
assert.Equal(t, "tabby cat", result[0].Label)
assert.Equal(t, "tiger cat", result[1].Label)
@ -91,7 +91,7 @@ func TestTensorFlow_GetImageTags_Dog(t *testing.T) {
assert.Nil(t, err)
assert.IsType(t, []TensorFlowLabel{}, result)
assert.Equal(t, 5, len(result))
assert.Equal(t, 3, len(result))
assert.Equal(t, "belt", result[0].Label)
assert.Equal(t, "beagle dog", result[1].Label)

View file

@ -50,11 +50,7 @@ func Start(ctx *context.Context) {
<-quit
log.Info("received interrupt signal - shutting down")
if err := ctx.CloseDb(); err != nil {
log.Errorf("could not close database connection: %s", err)
} else {
log.Info("closed database connection")
}
ctx.Shutdown()
if err := server.Close(); err != nil {
log.Errorf("server close: %s", err)
@ -71,5 +67,5 @@ func Start(ctx *context.Context) {
log.Info("please come back another time")
time.Sleep(3 * time.Second)
time.Sleep(2 * time.Second)
}

View file

@ -117,6 +117,7 @@ func Start(path string, port uint, host string, debug bool) {
runServer()
cleanup()
os.Exit(0)
}