2018-09-14 10:44:15 +00:00
|
|
|
package photoprism
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bufio"
|
2019-04-30 11:17:01 +00:00
|
|
|
"bytes"
|
2018-09-14 10:44:15 +00:00
|
|
|
"errors"
|
2019-05-16 06:41:16 +00:00
|
|
|
"fmt"
|
2019-04-30 11:17:01 +00:00
|
|
|
"image"
|
2018-09-14 10:44:15 +00:00
|
|
|
"io/ioutil"
|
2019-04-30 11:17:01 +00:00
|
|
|
"math"
|
2018-09-14 10:44:15 +00:00
|
|
|
"os"
|
2019-12-13 15:25:47 +00:00
|
|
|
"path"
|
2019-12-11 03:12:54 +00:00
|
|
|
"path/filepath"
|
2018-09-14 10:44:15 +00:00
|
|
|
"sort"
|
2019-05-16 06:41:16 +00:00
|
|
|
"strings"
|
2018-10-31 06:14:33 +00:00
|
|
|
|
2019-04-30 11:17:01 +00:00
|
|
|
"github.com/disintegration/imaging"
|
2019-06-05 16:25:20 +00:00
|
|
|
"github.com/photoprism/photoprism/internal/config"
|
2019-06-09 03:22:53 +00:00
|
|
|
"github.com/photoprism/photoprism/internal/util"
|
2018-10-31 06:14:33 +00:00
|
|
|
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
2019-05-16 06:41:16 +00:00
|
|
|
"gopkg.in/yaml.v2"
|
2018-09-14 10:44:15 +00:00
|
|
|
)
|
|
|
|
|
2019-06-05 16:25:20 +00:00
|
|
|
// TensorFlow if a wrapper for their low-level API.
|
2018-09-14 10:44:15 +00:00
|
|
|
type TensorFlow struct {
|
2019-06-05 16:25:20 +00:00
|
|
|
conf *config.Config
|
2019-05-16 06:41:16 +00:00
|
|
|
model *tf.SavedModel
|
2019-12-13 15:25:47 +00:00
|
|
|
modelName string
|
|
|
|
modelTags []string
|
2019-05-16 06:41:16 +00:00
|
|
|
labels []string
|
|
|
|
labelRules LabelRules
|
2018-09-14 10:44:15 +00:00
|
|
|
}
|
|
|
|
|
2019-05-16 06:41:16 +00:00
|
|
|
type (
	// LabelRule customizes how a classification label is reported.
	//
	// Rules are decoded from labels.yml with gopkg.in/yaml.v2; with that
	// package's default naming the YAML keys are presumably the lowercased
	// field names (label, see, threshold, categories, priority) — confirm
	// against labels.yml.
	LabelRule struct {
		// Label, when non-empty, replaces the model's label name.
		Label string
		// See names another rule whose settings apply instead of this one.
		See string
		// Threshold is the minimum probability required to report the label.
		Threshold float32
		// Categories are attached to the reported label.
		Categories []string
		// Priority is attached to the reported label.
		Priority int
	}

	// LabelRules maps a lowercase label name to its rule.
	LabelRules map[string]LabelRule
)
|
|
|
|
|
2019-12-13 15:25:47 +00:00
|
|
|
// NewTensorFlow returns new TensorFlow instance with Nasnet model.
|
2019-06-05 16:25:20 +00:00
|
|
|
func NewTensorFlow(conf *config.Config) *TensorFlow {
|
2019-12-13 15:25:47 +00:00
|
|
|
return &TensorFlow{conf: conf, modelName: "nasnet", modelTags: []string{"photoprism"}}
|
2018-09-14 10:44:15 +00:00
|
|
|
}
|
|
|
|
|
2019-05-16 06:41:16 +00:00
|
|
|
// loadLabelRules reads the label rules from labels.yml in the config path.
//
// It is a no-op once rules have been loaded. On failure it returns the error
// without logging it (callers log it) and leaves the rule set empty, so the
// next call tries again.
func (t *TensorFlow) loadLabelRules() (err error) {
	if len(t.labelRules) > 0 {
		return nil
	}

	t.labelRules = make(LabelRules)

	fileName := filepath.Join(t.conf.ConfigPath(), "labels.yml")

	log.Debugf("tensorflow: loading label rules from \"%s\"", filepath.Base(fileName))

	if !util.Exists(fileName) {
		return fmt.Errorf("tensorflow: label rules file not found in \"%s\"", filepath.Base(fileName))
	}

	yamlConfig, err := ioutil.ReadFile(fileName)
	if err != nil {
		return err
	}

	// Decode into a fresh map and only publish it on success: yaml.v2 may
	// partially fill its target before reporting an error, and a non-empty
	// t.labelRules would suppress any further reload attempts.
	rules := make(LabelRules)
	if err := yaml.Unmarshal(yamlConfig, &rules); err != nil {
		return err
	}

	t.labelRules = rules

	return nil
}
|
|
|
|
|
2019-06-05 16:25:20 +00:00
|
|
|
// LabelsFromFile returns matching labels for a jpeg media file.
|
2019-06-04 16:26:35 +00:00
|
|
|
func (t *TensorFlow) LabelsFromFile(filename string) (result Labels, err error) {
|
2018-09-14 10:44:15 +00:00
|
|
|
imageBuffer, err := ioutil.ReadFile(filename)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2019-06-04 16:26:35 +00:00
|
|
|
return t.Labels(imageBuffer)
|
2018-09-14 10:44:15 +00:00
|
|
|
}
|
|
|
|
|
2019-06-05 16:25:20 +00:00
|
|
|
// Labels returns matching labels for a jpeg media string.
|
2019-06-04 16:26:35 +00:00
|
|
|
func (t *TensorFlow) Labels(img []byte) (result Labels, err error) {
|
2018-09-14 10:44:15 +00:00
|
|
|
if err := t.loadModel(); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Make tensor
|
2019-06-04 16:26:35 +00:00
|
|
|
tensor, err := t.makeTensor(img, "jpeg")
|
2018-09-14 10:44:15 +00:00
|
|
|
|
|
|
|
if err != nil {
|
2019-12-14 13:38:43 +00:00
|
|
|
log.Error(err)
|
2018-09-14 10:44:15 +00:00
|
|
|
return nil, errors.New("invalid image")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Run inference
|
2019-04-30 11:17:01 +00:00
|
|
|
output, err := t.model.Session.Run(
|
2018-09-14 10:44:15 +00:00
|
|
|
map[tf.Output]*tf.Tensor{
|
2019-04-30 11:17:01 +00:00
|
|
|
t.model.Graph.Operation("input_1").Output(0): tensor,
|
2018-09-14 10:44:15 +00:00
|
|
|
},
|
|
|
|
[]tf.Output{
|
2019-04-30 11:17:01 +00:00
|
|
|
t.model.Graph.Operation("predictions/Softmax").Output(0),
|
2018-09-14 10:44:15 +00:00
|
|
|
},
|
|
|
|
nil)
|
|
|
|
|
|
|
|
if err != nil {
|
2019-12-14 13:38:43 +00:00
|
|
|
log.Error(err)
|
2019-04-30 11:17:01 +00:00
|
|
|
return result, errors.New("could not run inference")
|
2018-09-14 10:44:15 +00:00
|
|
|
}
|
|
|
|
|
2019-04-30 11:17:01 +00:00
|
|
|
if len(output) < 1 {
|
|
|
|
return result, errors.New("result is empty")
|
|
|
|
}
|
|
|
|
|
2018-09-14 10:44:15 +00:00
|
|
|
// Return best labels
|
2019-06-04 16:26:35 +00:00
|
|
|
result = t.bestLabels(output[0].Value().([][]float32)[0])
|
2019-05-04 15:34:51 +00:00
|
|
|
|
2019-12-11 03:12:54 +00:00
|
|
|
if len(result) > 0 {
|
|
|
|
log.Debugf("tensorflow: image classified as %+v", result)
|
|
|
|
}
|
2019-05-04 15:34:51 +00:00
|
|
|
|
|
|
|
return result, nil
|
2018-09-14 10:44:15 +00:00
|
|
|
}
|
|
|
|
|
2019-07-17 09:53:33 +00:00
|
|
|
// loadLabels reads the classification labels from labels.txt in the model
// directory path, one label per line.
//
// bestLabels pairs t.labels[i] with the i-th model output, so every line is
// kept, including empty ones. t.labels is replaced (not appended to) and only
// once the whole file was read successfully, so a failed or repeated call
// cannot leave duplicated or misaligned labels behind.
func (t *TensorFlow) loadLabels(path string) error {
	modelLabels := filepath.Join(path, "labels.txt")

	log.Infof("tensorflow: loading classification labels from labels.txt")

	// Load labels
	f, err := os.Open(modelLabels)
	if err != nil {
		return err
	}

	defer f.Close()

	var labels []string

	scanner := bufio.NewScanner(f)

	// Labels are separated by newlines
	for scanner.Scan() {
		labels = append(labels, scanner.Text())
	}

	if err := scanner.Err(); err != nil {
		return err
	}

	t.labels = labels

	return nil
}
|
|
|
|
|
2019-07-17 09:53:33 +00:00
|
|
|
func (t *TensorFlow) loadModel() error {
|
|
|
|
if t.model != nil {
|
|
|
|
// Already loaded
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2019-12-13 15:25:47 +00:00
|
|
|
modelPath := path.Join(t.conf.ResourcesPath(), t.modelName)
|
2019-07-17 09:53:33 +00:00
|
|
|
|
2019-12-13 15:25:47 +00:00
|
|
|
log.Infof("tensorflow: loading image classification model from \"%s\"", filepath.Base(modelPath))
|
2019-07-17 09:53:33 +00:00
|
|
|
|
|
|
|
// Load model
|
2019-12-13 15:25:47 +00:00
|
|
|
model, err := tf.LoadSavedModel(modelPath, t.modelTags, nil)
|
2019-07-17 09:53:33 +00:00
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
t.model = model
|
|
|
|
|
2019-12-13 15:25:47 +00:00
|
|
|
return t.loadLabels(modelPath)
|
2019-07-17 09:53:33 +00:00
|
|
|
}
|
|
|
|
|
2019-05-16 06:41:16 +00:00
|
|
|
func (t *TensorFlow) labelRule(label string) LabelRule {
|
2019-06-05 08:18:03 +00:00
|
|
|
label = strings.ToLower(label)
|
|
|
|
|
2019-05-16 06:41:16 +00:00
|
|
|
if err := t.loadLabelRules(); err != nil {
|
|
|
|
log.Error(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if rule, ok := t.labelRules[label]; ok {
|
|
|
|
if rule.See != "" {
|
|
|
|
return t.labelRule(rule.See)
|
|
|
|
}
|
|
|
|
|
|
|
|
return t.labelRules[label]
|
|
|
|
}
|
|
|
|
|
|
|
|
return LabelRule{Threshold: 0.08}
|
|
|
|
}
|
|
|
|
|
2019-06-04 16:26:35 +00:00
|
|
|
// bestLabels turns the model's per-class probabilities into at most five
// labels.
//
// probabilities[i] is paired with t.labels[i]; surplus probabilities without a
// label are ignored. A label is kept if:
//   - its probability is at least 0.08;
//   - its probability meets its rule's Threshold;
//   - its final name is non-empty.
//
// The rule's Label, Priority and Categories are applied. The result is sorted
// with sort.Sort on Labels before truncation.
func (t *TensorFlow) bestLabels(probabilities []float32) Labels {
	if err := t.loadLabelRules(); err != nil {
		log.Error(err)
	}

	// Make a list of label/probability pairs
	var result Labels

	for i, p := range probabilities {
		if i >= len(t.labels) {
			break
		}

		if p < 0.08 {
			continue
		}

		labelText := strings.ToLower(t.labels[i])

		rule := t.labelRule(labelText)

		if p < rule.Threshold {
			continue
		}

		if rule.Label != "" {
			labelText = rule.Label
		}

		labelText = strings.TrimSpace(labelText)

		// Blank lines in labels.txt are kept for index alignment, and a rule
		// label may be whitespace only; neither yields a usable label name.
		if labelText == "" {
			continue
		}

		uncertainty := 100 - int(math.Round(float64(p*100)))

		result = append(result, Label{Name: labelText, Source: "image", Uncertainty: uncertainty, Priority: rule.Priority, Categories: rule.Categories})
	}

	// Sort by probability
	sort.Sort(result)

	if len(result) > 5 {
		result = result[:5]
	}

	return result
}
|
|
|
|
|
2019-06-04 16:26:35 +00:00
|
|
|
func (t *TensorFlow) makeTensor(image []byte, imageFormat string) (*tf.Tensor, error) {
|
2019-05-01 12:54:11 +00:00
|
|
|
img, err := imaging.Decode(bytes.NewReader(image), imaging.AutoOrientation(true))
|
2019-04-30 11:17:01 +00:00
|
|
|
|
2018-09-14 10:44:15 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2019-04-30 11:17:01 +00:00
|
|
|
|
|
|
|
width, height := 224, 224
|
|
|
|
|
2019-05-13 16:01:50 +00:00
|
|
|
img = imaging.Fill(img, width, height, imaging.Center, imaging.Lanczos)
|
2019-04-30 11:17:01 +00:00
|
|
|
|
|
|
|
return imageToTensorTF(img, width, height)
|
|
|
|
}
|
|
|
|
|
|
|
|
// imageToTensorTF converts the top-left imageWidth x imageHeight pixels of img
// into a [1][imageHeight][imageWidth][3]float32 tensor of RGB values, each
// scaled by convertTF. The alpha channel is ignored.
//
// Pixel coordinates are taken relative to img.Bounds().Min, so images whose
// origin is not (0,0) (e.g. sub-images) are sampled correctly.
func imageToTensorTF(img image.Image, imageHeight, imageWidth int) (*tf.Tensor, error) {
	origin := img.Bounds().Min

	var tfImage [1][][][3]float32

	tfImage[0] = make([][][3]float32, imageHeight)

	// Iterate row by row so pixel access follows the usual row-major layout.
	for y := 0; y < imageHeight; y++ {
		row := make([][3]float32, imageWidth)

		for x := 0; x < imageWidth; x++ {
			r, g, b, _ := img.At(origin.X+x, origin.Y+y).RGBA()
			row[x] = [3]float32{convertTF(r), convertTF(g), convertTF(b)}
		}

		tfImage[0][y] = row
	}

	return tf.NewTensor(tfImage)
}
|
|
|
|
|
2019-04-30 11:17:01 +00:00
|
|
|
// convertTF maps a 16-bit color channel value, as returned by
// color.Color.RGBA, to a float32. The value is reduced to its high 8 bits and
// centered around 127.5, giving a result in [-1, 1] for inputs up to 0xFFFF.
func convertTF(value uint32) float32 {
	const midpoint = float32(127.5)

	eightBit := float32(value >> 8)

	return (eightBit - midpoint) / midpoint
}
|