Backend: Refactor gorm.DB connection provider in entity package
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
24fc54e326
commit
65e9a58979
14
internal/entity/db.go
Normal file
14
internal/entity/db.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package entity
|
||||||
|
|
||||||
|
import "github.com/jinzhu/gorm"
|
||||||
|
|
||||||
|
// Db returns the default *gorm.DB connection.
|
||||||
|
func Db() *gorm.DB {
|
||||||
|
return dbConn.Db()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnscopedDb returns an unscoped *gorm.DB connection
|
||||||
|
// that returns all records including deleted records.
|
||||||
|
func UnscopedDb() *gorm.DB {
|
||||||
|
return Db().Unscoped()
|
||||||
|
}
|
|
@ -10,7 +10,7 @@ import (
|
||||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SQL Databases.
|
// Supported test databases.
|
||||||
const (
|
const (
|
||||||
MySQL = "mysql"
|
MySQL = "mysql"
|
||||||
SQLite3 = "sqlite3"
|
SQLite3 = "sqlite3"
|
||||||
|
@ -18,43 +18,16 @@ const (
|
||||||
SQLiteMemoryDSN = ":memory:"
|
SQLiteMemoryDSN = ":memory:"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dbProvider DbProvider
|
// dbConn is the global gorm.DB connection provider.
|
||||||
|
var dbConn Gorm
|
||||||
|
|
||||||
type DbProvider interface {
|
// Gorm is a gorm.DB connection provider interface.
|
||||||
|
type Gorm interface {
|
||||||
Db() *gorm.DB
|
Db() *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDialect returns true if the given sql dialect is used.
|
// DbConn is a gorm.DB connection provider.
|
||||||
func IsDialect(name string) bool {
|
type DbConn struct {
|
||||||
return name == Db().Dialect().GetName()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DbDialect returns the sql dialect name.
|
|
||||||
func DbDialect() string {
|
|
||||||
return Db().Dialect().GetName()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDbProvider sets the Gorm database connection provider.
|
|
||||||
func SetDbProvider(provider DbProvider) {
|
|
||||||
dbProvider = provider
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasDbProvider returns true if a db provider exists.
|
|
||||||
func HasDbProvider() bool {
|
|
||||||
return dbProvider != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Db returns the database connection.
|
|
||||||
func Db() *gorm.DB {
|
|
||||||
return dbProvider.Db()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnscopedDb returns an unscoped database connection.
|
|
||||||
func UnscopedDb() *gorm.DB {
|
|
||||||
return Db().Unscoped()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Gorm struct {
|
|
||||||
Driver string
|
Driver string
|
||||||
Dsn string
|
Dsn string
|
||||||
|
|
||||||
|
@ -63,8 +36,8 @@ type Gorm struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Db returns the gorm db connection.
|
// Db returns the gorm db connection.
|
||||||
func (g *Gorm) Db() *gorm.DB {
|
func (g *DbConn) Db() *gorm.DB {
|
||||||
g.once.Do(g.Connect)
|
g.once.Do(g.Open)
|
||||||
|
|
||||||
if g.db == nil {
|
if g.db == nil {
|
||||||
log.Fatal("migrate: database not connected")
|
log.Fatal("migrate: database not connected")
|
||||||
|
@ -73,8 +46,8 @@ func (g *Gorm) Db() *gorm.DB {
|
||||||
return g.db
|
return g.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect creates a new gorm db connection.
|
// Open creates a new gorm db connection.
|
||||||
func (g *Gorm) Connect() {
|
func (g *DbConn) Open() {
|
||||||
db, err := gorm.Open(g.Driver, g.Dsn)
|
db, err := gorm.Open(g.Driver, g.Dsn)
|
||||||
|
|
||||||
if err != nil || db == nil {
|
if err != nil || db == nil {
|
||||||
|
@ -104,7 +77,7 @@ func (g *Gorm) Connect() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the gorm db connection.
|
// Close closes the gorm db connection.
|
||||||
func (g *Gorm) Close() {
|
func (g *DbConn) Close() {
|
||||||
if g.db != nil {
|
if g.db != nil {
|
||||||
if err := g.db.Close(); err != nil {
|
if err := g.db.Close(); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
@ -113,3 +86,23 @@ func (g *Gorm) Close() {
|
||||||
g.db = nil
|
g.db = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsDialect returns true if the given sql dialect is used.
|
||||||
|
func IsDialect(name string) bool {
|
||||||
|
return name == Db().Dialect().GetName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DbDialect returns the sql dialect name.
|
||||||
|
func DbDialect() string {
|
||||||
|
return Db().Dialect().GetName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDbProvider sets the Gorm database connection provider.
|
||||||
|
func SetDbProvider(conn Gorm) {
|
||||||
|
dbConn = conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasDbProvider returns true if a db provider exists.
|
||||||
|
func HasDbProvider() bool {
|
||||||
|
return dbConn != nil
|
||||||
|
}
|
|
@ -43,7 +43,7 @@ func InitDb(dropDeprecated, runFailed bool, ids []string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitTestDb connects to and completely initializes the test database incl fixtures.
|
// InitTestDb connects to and completely initializes the test database incl fixtures.
|
||||||
func InitTestDb(driver, dsn string) *Gorm {
|
func InitTestDb(driver, dsn string) *DbConn {
|
||||||
if HasDbProvider() {
|
if HasDbProvider() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -68,13 +68,13 @@ func InitTestDb(driver, dsn string) *Gorm {
|
||||||
|
|
||||||
log.Infof("initializing %s test db in %s", driver, dsn)
|
log.Infof("initializing %s test db in %s", driver, dsn)
|
||||||
|
|
||||||
// Create ORM instance.
|
// Create gorm.DB connection provider.
|
||||||
db := &Gorm{
|
db := &DbConn{
|
||||||
Driver: driver,
|
Driver: driver,
|
||||||
Dsn: dsn,
|
Dsn: dsn,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert test fixtures.
|
// Insert test fixtures into the database.
|
||||||
SetDbProvider(db)
|
SetDbProvider(db)
|
||||||
ResetTestFixtures()
|
ResetTestFixtures()
|
||||||
File{}.RegenerateIndex()
|
File{}.RegenerateIndex()
|
||||||
|
|
Loading…
Reference in a new issue