Improve image classification performance and logging in debug mode

This commit is contained in:
Michael Mayer 2019-05-04 17:34:51 +02:00
parent d83e81b49b
commit 1e6f41b417
10 changed files with 103 additions and 33 deletions

View file

@ -33,5 +33,7 @@ func indexAction(ctx *cli.Context) error {
log.Infof("indexed %d files", len(files)) log.Infof("indexed %d files", len(files))
app.Shutdown()
return nil return nil
} }

View file

@ -22,5 +22,7 @@ func migrateAction(ctx *cli.Context) error {
log.Infoln("database migration complete") log.Infoln("database migration complete")
app.Shutdown()
return nil return nil
} }

View file

@ -3,6 +3,7 @@ package context
import ( import (
"errors" "errors"
"os" "os"
"syscall"
"time" "time"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -394,3 +395,15 @@ func (c *Context) ClientConfig() ClientConfig {
return result return result
} }
func (c *Context) Shutdown() {
if err := c.CloseDb(); err != nil {
log.Errorf("could not close database connection: %s", err)
} else {
log.Info("closed database connection")
}
if err := syscall.Kill(syscall.Getpid(), syscall.SIGINT); err != nil {
log.Error(err)
}
}

View file

@ -1,6 +1,8 @@
package models package models
import ( import (
"fmt"
"github.com/gosimple/slug" "github.com/gosimple/slug"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
@ -44,3 +46,13 @@ func (c *Camera) FirstOrCreate(db *gorm.DB) *Camera {
return c return c
} }
func (c *Camera) String() string {
if c.CameraMake != "" && c.CameraModel != "" {
return fmt.Sprintf("%s %s", c.CameraMake, c.CameraModel)
} else if c.CameraModel != "" {
return c.CameraModel
}
return ""
}

View file

@ -122,29 +122,20 @@ func (i *Indexer) indexMediaFile(mediaFile *MediaFile) string {
photo.PhotoTitle = fmt.Sprintf("%s / %s / %s", location.LocCounty, location.LocCountry, mediaFile.DateCreated().Format("2006")) photo.PhotoTitle = fmt.Sprintf("%s / %s / %s", location.LocCounty, location.LocCountry, mediaFile.DateCreated().Format("2006"))
} }
} else { } else {
log.Infof("no location: %s", err) log.Debugf("location cannot be determined precisely: %s", err)
var recentPhoto models.Photo var recentPhoto models.Photo
if result := i.db.Order(gorm.Expr("ABS(DATEDIFF(taken_at, ?)) ASC", mediaFile.DateCreated())).Preload("Country").First(&recentPhoto); result.Error == nil { if result := i.db.Order(gorm.Expr("ABS(DATEDIFF(taken_at, ?)) ASC", mediaFile.DateCreated())).Preload("Country").First(&recentPhoto); result.Error == nil {
if recentPhoto.Country != nil { if recentPhoto.Country != nil {
photo.Country = recentPhoto.Country photo.Country = recentPhoto.Country
log.Debugf("approximate location: %s", recentPhoto.Country.CountryName)
} }
} }
} }
photo.Tags = tags photo.Tags = tags
if photo.PhotoTitle == "" {
if len(photo.Tags) > 0 { // TODO: User defined title format
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Tags[0].TagLabel), mediaFile.DateCreated().Format("2006"))
} else if photo.Country != nil && photo.Country.CountryName != "" {
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Country.CountryName), mediaFile.DateCreated().Format("2006"))
} else {
photo.PhotoTitle = fmt.Sprintf("Unknown / %s", mediaFile.DateCreated().Format("2006"))
}
}
photo.Camera = models.NewCamera(mediaFile.CameraModel(), mediaFile.CameraMake()).FirstOrCreate(i.db) photo.Camera = models.NewCamera(mediaFile.CameraModel(), mediaFile.CameraMake()).FirstOrCreate(i.db)
photo.Lens = models.NewLens(mediaFile.LensModel(), mediaFile.LensMake()).FirstOrCreate(i.db) photo.Lens = models.NewLens(mediaFile.LensModel(), mediaFile.LensMake()).FirstOrCreate(i.db)
photo.PhotoFocalLength = mediaFile.FocalLength() photo.PhotoFocalLength = mediaFile.FocalLength()
@ -154,6 +145,36 @@ func (i *Indexer) indexMediaFile(mediaFile *MediaFile) string {
photo.PhotoCanonicalName = canonicalName photo.PhotoCanonicalName = canonicalName
photo.PhotoFavorite = false photo.PhotoFavorite = false
if photo.PhotoTitle == "" {
if len(photo.Tags) > 0 { // TODO: User defined title format
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Tags[0].TagLabel), mediaFile.DateCreated().Format("2006"))
} else if photo.Country != nil && photo.Country.CountryName != "" {
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Country.CountryName), mediaFile.DateCreated().Format("2006"))
} else if photo.Camera.String() != "" && photo.Camera.String() != "Unknown" {
photo.PhotoTitle = fmt.Sprintf("%s / %s", photo.Camera, mediaFile.DateCreated().Format("January 2006"))
} else {
var daytimeString string
hour := mediaFile.DateCreated().Hour()
switch {
case hour < 8:
daytimeString = "Early Bird"
case hour < 12:
daytimeString = "Morning Mood"
case hour < 17:
daytimeString = "Carpe Diem"
case hour < 20:
daytimeString = "Sunset"
default:
daytimeString = "Late Night"
}
photo.PhotoTitle = fmt.Sprintf("%s / %s", daytimeString, mediaFile.DateCreated().Format("January 2006"))
}
}
log.Debugf("title: \"%s\"", photo.PhotoTitle)
i.db.Create(&photo) i.db.Create(&photo)
} else if time.Now().Sub(photo.UpdatedAt).Minutes() > 10 { // If updated more than 10 minutes ago } else if time.Now().Sub(photo.UpdatedAt).Minutes() > 10 { // If updated more than 10 minutes ago
if jpeg, err := mediaFile.Jpeg(); err == nil { if jpeg, err := mediaFile.Jpeg(); err == nil {

View file

@ -49,7 +49,7 @@ func (m *MediaFile) Location() (*models.Location, error) {
if exifData, err := m.ExifData(); err == nil { if exifData, err := m.ExifData(); err == nil {
if exifData.Lat == 0 && exifData.Long == 0 { if exifData.Lat == 0 && exifData.Long == 0 {
return nil, errors.New("lat and long are missing in metadata") return nil, errors.New("no latitude and longitude in image metadata")
} }
url := fmt.Sprintf("https://nominatim.openstreetmap.org/reverse?lat=%f&lon=%f&format=jsonv2", exifData.Lat, exifData.Long) url := fmt.Sprintf("https://nominatim.openstreetmap.org/reverse?lat=%f&lon=%f&format=jsonv2", exifData.Lat, exifData.Long)

View file

@ -11,13 +11,13 @@ import (
"sort" "sort"
"github.com/disintegration/imaging" "github.com/disintegration/imaging"
log "github.com/sirupsen/logrus"
tf "github.com/tensorflow/tensorflow/tensorflow/go" tf "github.com/tensorflow/tensorflow/tensorflow/go"
) )
// TensorFlow if a tensorflow wrapper given a graph, labels and a modelPath. // TensorFlow if a tensorflow wrapper given a graph, labels and a modelPath.
type TensorFlow struct { type TensorFlow struct {
modelPath string modelPath string
graph *tf.Graph
model *tf.SavedModel model *tf.SavedModel
labels []string labels []string
} }
@ -87,17 +87,26 @@ func (t *TensorFlow) GetImageTags(img []byte) (result []TensorFlowLabel, err err
} }
// Return best labels // Return best labels
return t.findBestLabels(output[0].Value().([][]float32)[0]), nil result = t.findBestLabels(output[0].Value().([][]float32)[0])
log.Debugf("labels: %v", result)
return result, nil
} }
func (t *TensorFlow) loadModel() error { func (t *TensorFlow) loadModel() error {
if t.graph != nil { if t.model != nil {
// Already loaded // Already loaded
return nil return nil
} }
savedModel := t.modelPath + "/nasnet"
modelLabels := savedModel + "/labels.txt"
log.Infof("loading image classification model from \"%s\"", savedModel)
// Load model // Load model
model, err := tf.LoadSavedModel(t.modelPath+"/nasnet", []string{"photoprism"}, nil) model, err := tf.LoadSavedModel(savedModel, []string{"photoprism"}, nil)
if err != nil { if err != nil {
return err return err
@ -105,40 +114,54 @@ func (t *TensorFlow) loadModel() error {
t.model = model t.model = model
log.Infof("loading classification labels from \"%s\"", modelLabels)
// Load labels // Load labels
labelsFile, err := os.Open(t.modelPath + "/nasnet/labels.txt") f, err := os.Open(modelLabels)
if err != nil { if err != nil {
return err return err
} }
defer labelsFile.Close()
scanner := bufio.NewScanner(labelsFile) defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines // Labels are separated by newlines
for scanner.Scan() { for scanner.Scan() {
t.labels = append(t.labels, scanner.Text()) t.labels = append(t.labels, scanner.Text())
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
return err return err
} }
return nil return nil
} }
func (t *TensorFlow) findBestLabels(probabilities []float32) []TensorFlowLabel { func (t *TensorFlow) findBestLabels(probabilities []float32) []TensorFlowLabel {
// Make a list of label/probability pairs // Make a list of label/probability pairs
var resultLabels []TensorFlowLabel var result []TensorFlowLabel
for i, p := range probabilities { for i, p := range probabilities {
if i >= len(t.labels) { if i >= len(t.labels) {
break break
} }
resultLabels = append(resultLabels, TensorFlowLabel{Label: t.labels[i], Probability: p})
if p < 0.08 { continue }
result = append(result, TensorFlowLabel{Label: t.labels[i], Probability: p})
} }
// Sort by probability // Sort by probability
sort.Sort(TensorFlowLabels(resultLabels)) sort.Sort(TensorFlowLabels(result))
// Return top 5 labels l := len(result)
return resultLabels[:5]
if l >= 5 {
return result[:5]
} else {
return result[:l]
}
} }
func (t *TensorFlow) makeTensorFromImage(image []byte, imageFormat string) (*tf.Tensor, error) { func (t *TensorFlow) makeTensorFromImage(image []byte, imageFormat string) (*tf.Tensor, error) {

View file

@ -26,7 +26,7 @@ func TestTensorFlow_GetImageTagsFromFile(t *testing.T) {
assert.NotNil(t, result) assert.NotNil(t, result)
assert.IsType(t, []TensorFlowLabel{}, result) assert.IsType(t, []TensorFlowLabel{}, result)
assert.Equal(t, 5, len(result)) assert.Equal(t, 2, len(result))
t.Log(result) t.Log(result)
@ -59,7 +59,7 @@ func TestTensorFlow_GetImageTags(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.IsType(t, []TensorFlowLabel{}, result) assert.IsType(t, []TensorFlowLabel{}, result)
assert.Equal(t, 5, len(result)) assert.Equal(t, 2, len(result))
assert.Equal(t, "tabby cat", result[0].Label) assert.Equal(t, "tabby cat", result[0].Label)
assert.Equal(t, "tiger cat", result[1].Label) assert.Equal(t, "tiger cat", result[1].Label)
@ -91,7 +91,7 @@ func TestTensorFlow_GetImageTags_Dog(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.IsType(t, []TensorFlowLabel{}, result) assert.IsType(t, []TensorFlowLabel{}, result)
assert.Equal(t, 5, len(result)) assert.Equal(t, 3, len(result))
assert.Equal(t, "belt", result[0].Label) assert.Equal(t, "belt", result[0].Label)
assert.Equal(t, "beagle dog", result[1].Label) assert.Equal(t, "beagle dog", result[1].Label)

View file

@ -50,11 +50,7 @@ func Start(ctx *context.Context) {
<-quit <-quit
log.Info("received interrupt signal - shutting down") log.Info("received interrupt signal - shutting down")
if err := ctx.CloseDb(); err != nil { ctx.Shutdown()
log.Errorf("could not close database connection: %s", err)
} else {
log.Info("closed database connection")
}
if err := server.Close(); err != nil { if err := server.Close(); err != nil {
log.Errorf("server close: %s", err) log.Errorf("server close: %s", err)
@ -71,5 +67,5 @@ func Start(ctx *context.Context) {
log.Info("please come back another time") log.Info("please come back another time")
time.Sleep(3 * time.Second) time.Sleep(2 * time.Second)
} }

View file

@ -117,6 +117,7 @@ func Start(path string, port uint, host string, debug bool) {
runServer() runServer()
cleanup() cleanup()
os.Exit(0) os.Exit(0)
} }