Backend: Refactor gorm.DB connection provider in entity package

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer 2022-10-01 15:17:04 +02:00
parent 24fc54e326
commit 65e9a58979
4 changed files with 50 additions and 43 deletions

14
internal/entity/db.go Normal file
View file

@ -0,0 +1,14 @@
package entity
import "github.com/jinzhu/gorm"
// Db returns the default *gorm.DB connection.
func Db() *gorm.DB {
return dbConn.Db()
}
// UnscopedDb returns an unscoped *gorm.DB connection
// that returns all records including deleted records.
func UnscopedDb() *gorm.DB {
return Db().Unscoped()
}

View file

@ -10,7 +10,7 @@ import (
_ "github.com/jinzhu/gorm/dialects/sqlite" _ "github.com/jinzhu/gorm/dialects/sqlite"
) )
// SQL Databases. // Supported test databases.
const ( const (
MySQL = "mysql" MySQL = "mysql"
SQLite3 = "sqlite3" SQLite3 = "sqlite3"
@ -18,43 +18,16 @@ const (
SQLiteMemoryDSN = ":memory:" SQLiteMemoryDSN = ":memory:"
) )
var dbProvider DbProvider // dbConn is the global gorm.DB connection provider.
var dbConn Gorm
type DbProvider interface { // Gorm is a gorm.DB connection provider interface.
type Gorm interface {
Db() *gorm.DB Db() *gorm.DB
} }
// IsDialect returns true if the given sql dialect is used. // DbConn is a gorm.DB connection provider.
func IsDialect(name string) bool { type DbConn struct {
return name == Db().Dialect().GetName()
}
// DbDialect returns the sql dialect name.
func DbDialect() string {
return Db().Dialect().GetName()
}
// SetDbProvider sets the Gorm database connection provider.
func SetDbProvider(provider DbProvider) {
dbProvider = provider
}
// HasDbProvider returns true if a db provider exists.
func HasDbProvider() bool {
return dbProvider != nil
}
// Db returns the database connection.
func Db() *gorm.DB {
return dbProvider.Db()
}
// UnscopedDb returns an unscoped database connection.
func UnscopedDb() *gorm.DB {
return Db().Unscoped()
}
type Gorm struct {
Driver string Driver string
Dsn string Dsn string
@ -63,8 +36,8 @@ type Gorm struct {
} }
// Db returns the gorm db connection. // Db returns the gorm db connection.
func (g *Gorm) Db() *gorm.DB { func (g *DbConn) Db() *gorm.DB {
g.once.Do(g.Connect) g.once.Do(g.Open)
if g.db == nil { if g.db == nil {
log.Fatal("migrate: database not connected") log.Fatal("migrate: database not connected")
@ -73,8 +46,8 @@ func (g *Gorm) Db() *gorm.DB {
return g.db return g.db
} }
// Connect creates a new gorm db connection. // Open creates a new gorm db connection.
func (g *Gorm) Connect() { func (g *DbConn) Open() {
db, err := gorm.Open(g.Driver, g.Dsn) db, err := gorm.Open(g.Driver, g.Dsn)
if err != nil || db == nil { if err != nil || db == nil {
@ -104,7 +77,7 @@ func (g *Gorm) Connect() {
} }
// Close closes the gorm db connection. // Close closes the gorm db connection.
func (g *Gorm) Close() { func (g *DbConn) Close() {
if g.db != nil { if g.db != nil {
if err := g.db.Close(); err != nil { if err := g.db.Close(); err != nil {
log.Fatal(err) log.Fatal(err)
@ -113,3 +86,23 @@ func (g *Gorm) Close() {
g.db = nil g.db = nil
} }
} }
// IsDialect returns true if the given sql dialect is used.
func IsDialect(name string) bool {
return name == Db().Dialect().GetName()
}
// DbDialect returns the sql dialect name.
func DbDialect() string {
return Db().Dialect().GetName()
}
// SetDbProvider sets the Gorm database connection provider.
func SetDbProvider(conn Gorm) {
dbConn = conn
}
// HasDbProvider returns true if a db provider exists.
func HasDbProvider() bool {
return dbConn != nil
}

View file

@ -43,7 +43,7 @@ func InitDb(dropDeprecated, runFailed bool, ids []string) {
} }
// InitTestDb connects to and completely initializes the test database incl fixtures. // InitTestDb connects to and completely initializes the test database incl fixtures.
func InitTestDb(driver, dsn string) *Gorm { func InitTestDb(driver, dsn string) *DbConn {
if HasDbProvider() { if HasDbProvider() {
return nil return nil
} }
@ -68,13 +68,13 @@ func InitTestDb(driver, dsn string) *Gorm {
log.Infof("initializing %s test db in %s", driver, dsn) log.Infof("initializing %s test db in %s", driver, dsn)
// Create ORM instance. // Create gorm.DB connection provider.
db := &Gorm{ db := &DbConn{
Driver: driver, Driver: driver,
Dsn: dsn, Dsn: dsn,
} }
// Insert test fixtures. // Insert test fixtures into the database.
SetDbProvider(db) SetDbProvider(db)
ResetTestFixtures() ResetTestFixtures()
File{}.RegenerateIndex() File{}.RegenerateIndex()