Implemented migrate-db command
This commit is contained in:
parent
cfad5ecb35
commit
595c32b856
75
Gopkg.lock
generated
75
Gopkg.lock
generated
|
@ -1,6 +1,12 @@
|
||||||
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
|
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
|
||||||
|
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
name = "cloud.google.com/go"
|
||||||
|
packages = ["civil"]
|
||||||
|
revision = "777200caa7fb8936aed0f12b1fd79af64cc83ec9"
|
||||||
|
version = "v0.24.0"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
name = "github.com/araddon/dateparse"
|
name = "github.com/araddon/dateparse"
|
||||||
|
@ -25,6 +31,15 @@
|
||||||
revision = "346938d642f2ec3594ed81d874461961cd0faa76"
|
revision = "346938d642f2ec3594ed81d874461961cd0faa76"
|
||||||
version = "v1.1.0"
|
version = "v1.1.0"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "github.com/denisenkom/go-mssqldb"
|
||||||
|
packages = [
|
||||||
|
".",
|
||||||
|
"internal/cp"
|
||||||
|
]
|
||||||
|
revision = "94c9c97e8c9f9844d15c846854a7a6031ae2132c"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
name = "github.com/disintegration/imaging"
|
name = "github.com/disintegration/imaging"
|
||||||
packages = ["."]
|
packages = ["."]
|
||||||
|
@ -38,10 +53,28 @@
|
||||||
version = "v1.0.1"
|
version = "v1.0.1"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
name = "github.com/julienschmidt/httprouter"
|
name = "github.com/go-sql-driver/mysql"
|
||||||
packages = ["."]
|
packages = ["."]
|
||||||
revision = "8c199fb6259ffc1af525cc3ad52ee60ba8359669"
|
revision = "d523deb1b23d913de5bdada721a6071e71283618"
|
||||||
version = "v1.1"
|
version = "v1.4.0"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
name = "github.com/jinzhu/gorm"
|
||||||
|
packages = [
|
||||||
|
".",
|
||||||
|
"dialects/mssql",
|
||||||
|
"dialects/mysql",
|
||||||
|
"dialects/postgres",
|
||||||
|
"dialects/sqlite"
|
||||||
|
]
|
||||||
|
revision = "6ed508ec6a4ecb3531899a69cbc746ccf65a4166"
|
||||||
|
version = "v1.9.1"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "github.com/jinzhu/inflection"
|
||||||
|
packages = ["."]
|
||||||
|
revision = "04140366298a54a039076d798123ffa108fff46c"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
|
@ -49,6 +82,28 @@
|
||||||
packages = ["yaml"]
|
packages = ["yaml"]
|
||||||
revision = "08cad365cd28a7fba23bb1e57aa43c5e18ad8bb8"
|
revision = "08cad365cd28a7fba23bb1e57aa43c5e18ad8bb8"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "github.com/lib/pq"
|
||||||
|
packages = [
|
||||||
|
".",
|
||||||
|
"hstore",
|
||||||
|
"oid"
|
||||||
|
]
|
||||||
|
revision = "90697d60dd844d5ef6ff15135d0203f65d2f53b8"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
name = "github.com/mattn/go-sqlite3"
|
||||||
|
packages = ["."]
|
||||||
|
revision = "25ecb14adfc7543176f7d85291ec7dba82c6f7e4"
|
||||||
|
version = "v1.9.0"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "github.com/photoprism/photoprism"
|
||||||
|
packages = ["."]
|
||||||
|
revision = "b2659ba5ce48b223490b8f51db065d93ae8f0cf5"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
name = "github.com/pkg/errors"
|
name = "github.com/pkg/errors"
|
||||||
packages = ["."]
|
packages = ["."]
|
||||||
|
@ -98,6 +153,12 @@
|
||||||
revision = "cfb38830724cc34fedffe9a2a29fb54fa9169cd1"
|
revision = "cfb38830724cc34fedffe9a2a29fb54fa9169cd1"
|
||||||
version = "v1.20.0"
|
version = "v1.20.0"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "golang.org/x/crypto"
|
||||||
|
packages = ["md4"]
|
||||||
|
revision = "a49355c7e3f8fe157a85be2f77e6e269a0f89602"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
name = "golang.org/x/image"
|
name = "golang.org/x/image"
|
||||||
|
@ -112,9 +173,15 @@
|
||||||
]
|
]
|
||||||
revision = "12117c17ca67ffa1ce22e9409f3b0b0a93ac08c7"
|
revision = "12117c17ca67ffa1ce22e9409f3b0b0a93ac08c7"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
name = "google.golang.org/appengine"
|
||||||
|
packages = ["cloudsql"]
|
||||||
|
revision = "b1f26356af11148e710935ed1ac8a7f5702c7612"
|
||||||
|
version = "v1.1.0"
|
||||||
|
|
||||||
[solve-meta]
|
[solve-meta]
|
||||||
analyzer-name = "dep"
|
analyzer-name = "dep"
|
||||||
analyzer-version = 1
|
analyzer-version = 1
|
||||||
inputs-digest = "1ffc35574da6350baeabae12e2aed7a151f4db332e9e73656bd10b8db55de7d2"
|
inputs-digest = "8aa59b793f2c56ca48723acc6a2b517cbd2dd323af673b71c1bcc74d491951bf"
|
||||||
solver-name = "gps-cdcl"
|
solver-name = "gps-cdcl"
|
||||||
solver-version = 1
|
solver-version = 1
|
||||||
|
|
11
album.go
Normal file
11
album.go
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
package photoprism
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Album struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Photos []Photo `gorm:"many2many:album_photos;"`
|
||||||
|
}
|
|
@ -36,6 +36,23 @@ func main() {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "migrate-db",
|
||||||
|
Usage: "Automatically migrates / initializes database",
|
||||||
|
Action: func(context *cli.Context) error {
|
||||||
|
conf.SetValuesFromFile(photoprism.GetExpandedFilename(context.GlobalString("config-file")))
|
||||||
|
|
||||||
|
conf.SetValuesFromCliContext(context)
|
||||||
|
|
||||||
|
fmt.Println("Migrating database...")
|
||||||
|
|
||||||
|
conf.MigrateDb()
|
||||||
|
|
||||||
|
fmt.Println("Done.")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "import",
|
Name: "import",
|
||||||
Usage: "Imports photos",
|
Usage: "Imports photos",
|
||||||
|
@ -48,7 +65,7 @@ func main() {
|
||||||
|
|
||||||
fmt.Printf("Importing photos from %s...\n", conf.ImportPath)
|
fmt.Printf("Importing photos from %s...\n", conf.ImportPath)
|
||||||
|
|
||||||
importer := photoprism.NewImporter(conf.OriginalsPath)
|
importer := photoprism.NewImporter(conf.OriginalsPath, conf.GetDb())
|
||||||
|
|
||||||
importer.ImportPhotosFromDirectory(conf.ImportPath)
|
importer.ImportPhotosFromDirectory(conf.ImportPath)
|
||||||
|
|
||||||
|
@ -226,4 +243,14 @@ var globalCliFlags = []cli.Flag{
|
||||||
Usage: "thumbnails path",
|
Usage: "thumbnails path",
|
||||||
Value: "~/Photos/Thumbnails",
|
Value: "~/Photos/Thumbnails",
|
||||||
},
|
},
|
||||||
|
cli.StringFlag{
|
||||||
|
Name: "database-driver",
|
||||||
|
Usage: "database driver (mysql, mssql, postgres or sqlite)",
|
||||||
|
Value: "mysql",
|
||||||
|
},
|
||||||
|
cli.StringFlag{
|
||||||
|
Name: "database-dsn",
|
||||||
|
Usage: "database data source name (DSN)",
|
||||||
|
Value: "photoprism:photoprism@tcp(database:3306)/photoprism",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,17 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
func responseError(w http.ResponseWriter, message string, code int) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(code)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": message})
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseJSON(w http.ResponseWriter, data interface{}) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(data)
|
|
||||||
}
|
|
|
@ -3,3 +3,5 @@ originals-path: photos/originals
|
||||||
thumbnails-path: photos/thumbnails
|
thumbnails-path: photos/thumbnails
|
||||||
import-path: photos/import
|
import-path: photos/import
|
||||||
export-path: photos/export
|
export-path: photos/export
|
||||||
|
database-driver: mysql
|
||||||
|
database-dsn: photoprism:photoprism@tcp(database:3306)/photoprism
|
53
config.go
53
config.go
|
@ -1,8 +1,14 @@
|
||||||
package photoprism
|
package photoprism
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/mssql"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/postgres"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
"github.com/kylelemons/go-gypsy/yaml"
|
"github.com/kylelemons/go-gypsy/yaml"
|
||||||
"github.com/urfave/cli"
|
"github.com/urfave/cli"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
)
|
)
|
||||||
|
@ -14,6 +20,9 @@ type Config struct {
|
||||||
ThumbnailsPath string
|
ThumbnailsPath string
|
||||||
ImportPath string
|
ImportPath string
|
||||||
ExportPath string
|
ExportPath string
|
||||||
|
DatabaseDriver string
|
||||||
|
DatabaseDsn string
|
||||||
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfig() *Config {
|
func NewConfig() *Config {
|
||||||
|
@ -49,6 +58,14 @@ func (c *Config) SetValuesFromFile(fileName string) error {
|
||||||
c.DarktableCli = GetExpandedFilename(DarktableCli)
|
c.DarktableCli = GetExpandedFilename(DarktableCli)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if DatabaseDriver, err := yamlConfig.Get("database-driver"); err == nil {
|
||||||
|
c.DatabaseDriver = DatabaseDriver
|
||||||
|
}
|
||||||
|
|
||||||
|
if DatabaseDsn, err := yamlConfig.Get("database-dsn"); err == nil {
|
||||||
|
c.DatabaseDsn = DatabaseDsn
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,6 +90,14 @@ func (c *Config) SetValuesFromCliContext(context *cli.Context) error {
|
||||||
c.DarktableCli = GetExpandedFilename(context.String("darktable-cli"))
|
c.DarktableCli = GetExpandedFilename(context.String("darktable-cli"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if context.IsSet("database-driver") {
|
||||||
|
c.DatabaseDriver = context.String("database-driver")
|
||||||
|
}
|
||||||
|
|
||||||
|
if context.IsSet("database-dsn") {
|
||||||
|
c.DatabaseDsn = context.String("database-dsn")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,3 +107,31 @@ func (c *Config) CreateDirectories() {
|
||||||
os.MkdirAll(path.Dir(c.ImportPath), os.ModePerm)
|
os.MkdirAll(path.Dir(c.ImportPath), os.ModePerm)
|
||||||
os.MkdirAll(path.Dir(c.ExportPath), os.ModePerm)
|
os.MkdirAll(path.Dir(c.ExportPath), os.ModePerm)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) ConnectToDatabase() error {
|
||||||
|
db, err := gorm.Open(c.DatabaseDriver, c.DatabaseDsn)
|
||||||
|
|
||||||
|
if err != nil || db == nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.db = db
|
||||||
|
|
||||||
|
c.MigrateDb()
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) GetDb() *gorm.DB {
|
||||||
|
if c.db == nil {
|
||||||
|
c.ConnectToDatabase()
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) MigrateDb() {
|
||||||
|
db := c.GetDb()
|
||||||
|
|
||||||
|
db.AutoMigrate(&File{}, &Photo{}, &Tag{}, &Album{})
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package photoprism
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -17,6 +18,8 @@ var originalsPath = GetExpandedFilename(testDataPath + "/originals")
|
||||||
var thumbnailsPath = GetExpandedFilename(testDataPath + "/thumbnails")
|
var thumbnailsPath = GetExpandedFilename(testDataPath + "/thumbnails")
|
||||||
var importPath = GetExpandedFilename(testDataPath + "/import")
|
var importPath = GetExpandedFilename(testDataPath + "/import")
|
||||||
var exportPath = GetExpandedFilename(testDataPath + "/export")
|
var exportPath = GetExpandedFilename(testDataPath + "/export")
|
||||||
|
var databaseDriver = "mysql"
|
||||||
|
var databaseDsn = "photoprism:photoprism@tcp(database:3306)/photoprism"
|
||||||
|
|
||||||
func (c *Config) RemoveTestData(t *testing.T) {
|
func (c *Config) RemoveTestData(t *testing.T) {
|
||||||
os.RemoveAll(c.ImportPath)
|
os.RemoveAll(c.ImportPath)
|
||||||
|
@ -67,6 +70,8 @@ func NewTestConfig() *Config {
|
||||||
ThumbnailsPath: thumbnailsPath,
|
ThumbnailsPath: thumbnailsPath,
|
||||||
ImportPath: importPath,
|
ImportPath: importPath,
|
||||||
ExportPath: exportPath,
|
ExportPath: exportPath,
|
||||||
|
DatabaseDriver: databaseDriver,
|
||||||
|
DatabaseDsn: databaseDsn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,4 +90,16 @@ func TestConfig_SetValuesFromFile(t *testing.T) {
|
||||||
assert.Equal(t, GetExpandedFilename("photos/thumbnails"), c.ThumbnailsPath)
|
assert.Equal(t, GetExpandedFilename("photos/thumbnails"), c.ThumbnailsPath)
|
||||||
assert.Equal(t, GetExpandedFilename("photos/import"), c.ImportPath)
|
assert.Equal(t, GetExpandedFilename("photos/import"), c.ImportPath)
|
||||||
assert.Equal(t, GetExpandedFilename("photos/export"), c.ExportPath)
|
assert.Equal(t, GetExpandedFilename("photos/export"), c.ExportPath)
|
||||||
|
assert.Equal(t, databaseDriver, c.DatabaseDriver)
|
||||||
|
assert.Equal(t, databaseDsn, c.DatabaseDsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ConnectToDatabase(t *testing.T) {
|
||||||
|
c := NewTestConfig()
|
||||||
|
|
||||||
|
c.ConnectToDatabase()
|
||||||
|
|
||||||
|
db := c.GetDb()
|
||||||
|
|
||||||
|
assert.IsType(t, &gorm.DB{}, db)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ services:
|
||||||
- 8080:8080
|
- 8080:8080
|
||||||
volumes:
|
volumes:
|
||||||
- .:/go/src/photoprism
|
- .:/go/src/photoprism
|
||||||
|
|
||||||
database:
|
database:
|
||||||
image: mysql:latest
|
image: mysql:latest
|
||||||
command: mysqld --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --max-connections=1024
|
command: mysqld --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --max-connections=1024
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
package photoprism
|
package photoprism
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"log"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
14
file.go
Normal file
14
file.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package photoprism
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type File struct {
|
||||||
|
gorm.Model
|
||||||
|
PhotoID uint
|
||||||
|
Filename string
|
||||||
|
Hash string
|
||||||
|
FileType string
|
||||||
|
MimeType string
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package photoprism
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
@ -13,14 +14,16 @@ import (
|
||||||
|
|
||||||
type Importer struct {
|
type Importer struct {
|
||||||
originalsPath string
|
originalsPath string
|
||||||
|
db *gorm.DB
|
||||||
removeDotFiles bool
|
removeDotFiles bool
|
||||||
removeExistingFiles bool
|
removeExistingFiles bool
|
||||||
removeEmptyDirectories bool
|
removeEmptyDirectories bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewImporter(originalsPath string) *Importer {
|
func NewImporter(originalsPath string, db *gorm.DB) *Importer {
|
||||||
instance := &Importer{
|
instance := &Importer{
|
||||||
originalsPath: originalsPath,
|
originalsPath: originalsPath,
|
||||||
|
db: db,
|
||||||
removeDotFiles: true,
|
removeDotFiles: true,
|
||||||
removeExistingFiles: true,
|
removeExistingFiles: true,
|
||||||
removeEmptyDirectories: true,
|
removeEmptyDirectories: true,
|
||||||
|
|
|
@ -5,13 +5,13 @@ import (
|
||||||
"github.com/brett-lempereur/ish"
|
"github.com/brett-lempereur/ish"
|
||||||
"github.com/djherbis/times"
|
"github.com/djherbis/times"
|
||||||
"github.com/steakknife/hamming"
|
"github.com/steakknife/hamming"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"io"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
22
photo.go
Normal file
22
photo.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package photoprism
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Photo struct {
|
||||||
|
gorm.Model
|
||||||
|
CanonicalName string
|
||||||
|
PerceptualHash string
|
||||||
|
Tags []Tag `gorm:"many2many:photo_tags;"`
|
||||||
|
Files []File
|
||||||
|
Albums []Album `gorm:"many2many:album_photos;"`
|
||||||
|
Author string
|
||||||
|
CameraModel string
|
||||||
|
LocationName string
|
||||||
|
Lat float64
|
||||||
|
Long float64
|
||||||
|
Liked bool
|
||||||
|
Private bool
|
||||||
|
Deleted bool
|
||||||
|
}
|
10
tag.go
Normal file
10
tag.go
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
package photoprism
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tag struct {
|
||||||
|
gorm.Model
|
||||||
|
Label string
|
||||||
|
}
|
BIN
tensorflow/cat.jpg
Normal file
BIN
tensorflow/cat.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 130 KiB |
|
@ -1,14 +1,12 @@
|
||||||
package main
|
package tensorflow
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
|
|
||||||
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||||
"github.com/tensorflow/tensorflow/tensorflow/go/op"
|
"github.com/tensorflow/tensorflow/tensorflow/go/op"
|
||||||
)
|
)
|
||||||
|
|
||||||
func makeTensorFromImage(imageBuffer *bytes.Buffer, imageFormat string) (*tf.Tensor, error) {
|
func makeTensorFromImage(image string, imageFormat string) (*tf.Tensor, error) {
|
||||||
tensor, err := tf.NewTensor(imageBuffer.String())
|
tensor, err := tf.NewTensor(image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
|
@ -1,18 +1,13 @@
|
||||||
package main
|
package tensorflow
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/julienschmidt/httprouter"
|
|
||||||
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||||
|
"log"
|
||||||
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClassifyResult struct {
|
type ClassifyResult struct {
|
||||||
|
@ -30,15 +25,41 @@ var (
|
||||||
labels []string
|
labels []string
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func RecognizeImage(image string) (result []LabelResult, err error) {
|
||||||
if err := loadModel(); err != nil {
|
if err := loadModel(); err != nil {
|
||||||
log.Fatal(err)
|
return nil, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r := httprouter.New()
|
// Make tensor
|
||||||
r.POST("/recognize", recognizeHandler)
|
tensor, err := makeTensorFromImage(image, "jpeg")
|
||||||
log.Fatal(http.ListenAndServe(":8080", r))
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("invalid image")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run inference
|
||||||
|
session, err := tf.NewSession(graph, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
output, err := session.Run(
|
||||||
|
map[tf.Output]*tf.Tensor{
|
||||||
|
graph.Operation("input").Output(0): tensor,
|
||||||
|
},
|
||||||
|
[]tf.Output{
|
||||||
|
graph.Operation("output").Output(0),
|
||||||
|
},
|
||||||
|
nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("could not run inference")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return best labels
|
||||||
|
return findBestLabels(output[0].Value().([][]float32)[0]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadModel() error {
|
func loadModel() error {
|
||||||
|
@ -51,6 +72,7 @@ func loadModel() error {
|
||||||
if err := graph.Import(model, ""); err != nil {
|
if err := graph.Import(model, ""); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load labels
|
// Load labels
|
||||||
labelsFile, err := os.Open("/model/imagenet_comp_graph_label_strings.txt")
|
labelsFile, err := os.Open("/model/imagenet_comp_graph_label_strings.txt")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -58,6 +80,7 @@ func loadModel() error {
|
||||||
}
|
}
|
||||||
defer labelsFile.Close()
|
defer labelsFile.Close()
|
||||||
scanner := bufio.NewScanner(labelsFile)
|
scanner := bufio.NewScanner(labelsFile)
|
||||||
|
|
||||||
// Labels are separated by newlines
|
// Labels are separated by newlines
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
labels = append(labels, scanner.Text())
|
labels = append(labels, scanner.Text())
|
||||||
|
@ -68,54 +91,6 @@ func loadModel() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func recognizeHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
|
||||||
// Read image
|
|
||||||
imageFile, header, err := r.FormFile("image")
|
|
||||||
// Will contain filename and extension
|
|
||||||
imageName := strings.Split(header.Filename, ".")
|
|
||||||
if err != nil {
|
|
||||||
responseError(w, "Could not read image", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer imageFile.Close()
|
|
||||||
var imageBuffer bytes.Buffer
|
|
||||||
// Copy image data to a buffer
|
|
||||||
io.Copy(&imageBuffer, imageFile)
|
|
||||||
|
|
||||||
// ...
|
|
||||||
// Make tensor
|
|
||||||
tensor, err := makeTensorFromImage(&imageBuffer, imageName[:1][0])
|
|
||||||
if err != nil {
|
|
||||||
responseError(w, "Invalid image", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run inference
|
|
||||||
session, err := tf.NewSession(graph, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
output, err := session.Run(
|
|
||||||
map[tf.Output]*tf.Tensor{
|
|
||||||
graph.Operation("input").Output(0): tensor,
|
|
||||||
},
|
|
||||||
[]tf.Output{
|
|
||||||
graph.Operation("output").Output(0),
|
|
||||||
},
|
|
||||||
nil)
|
|
||||||
if err != nil {
|
|
||||||
responseError(w, "Could not run inference", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return best labels
|
|
||||||
responseJSON(w, ClassifyResult{
|
|
||||||
Filename: header.Filename,
|
|
||||||
Labels: findBestLabels(output[0].Value().([][]float32)[0]),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type ByProbability []LabelResult
|
type ByProbability []LabelResult
|
||||||
|
|
||||||
func (a ByProbability) Len() int { return len(a) }
|
func (a ByProbability) Len() int { return len(a) }
|
||||||
|
@ -131,8 +106,10 @@ func findBestLabels(probabilities []float32) []LabelResult {
|
||||||
}
|
}
|
||||||
resultLabels = append(resultLabels, LabelResult{Label: labels[i], Probability: p})
|
resultLabels = append(resultLabels, LabelResult{Label: labels[i], Probability: p})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort by probability
|
// Sort by probability
|
||||||
sort.Sort(ByProbability(resultLabels))
|
sort.Sort(ByProbability(resultLabels))
|
||||||
|
|
||||||
// Return top 5 labels
|
// Return top 5 labels
|
||||||
return resultLabels[:5]
|
return resultLabels[:5]
|
||||||
}
|
}
|
25
tensorflow/recognize_test.go
Normal file
25
tensorflow/recognize_test.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package tensorflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"io/ioutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecognizeImage(t *testing.T) {
|
||||||
|
if imageBuffer, err := ioutil.ReadFile("cat.jpg"); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
result, err := RecognizeImage(string(imageBuffer))
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.IsType(t, []LabelResult{}, result)
|
||||||
|
assert.Equal(t, 5, len(result))
|
||||||
|
|
||||||
|
assert.Equal(t, "tabby", result[0].Label)
|
||||||
|
assert.Equal(t, "tiger cat", result[1].Label)
|
||||||
|
|
||||||
|
assert.Equal(t, float32(0.23251747), result[1].Probability)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,10 +1,10 @@
|
||||||
package photoprism
|
package photoprism
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/disintegration/imaging"
|
"github.com/disintegration/imaging"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"fmt"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue