Backend: Refactor gorm.DB connection provider in entity package

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer 2022-10-01 15:17:04 +02:00
parent 24fc54e326
commit 65e9a58979
4 changed files with 50 additions and 43 deletions

14
internal/entity/db.go Normal file
View file

@ -0,0 +1,14 @@
package entity
import "github.com/jinzhu/gorm"
// Db returns the default *gorm.DB connection.
func Db() *gorm.DB {
return dbConn.Db()
}
// UnscopedDb returns an unscoped *gorm.DB connection
// that returns all records including deleted records.
func UnscopedDb() *gorm.DB {
return Db().Unscoped()
}

View file

@ -10,7 +10,7 @@ import (
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
// SQL Databases.
// Supported test databases.
const (
MySQL = "mysql"
SQLite3 = "sqlite3"
@ -18,43 +18,16 @@ const (
SQLiteMemoryDSN = ":memory:"
)
var dbProvider DbProvider
// dbConn is the global gorm.DB connection provider.
var dbConn Gorm
type DbProvider interface {
// Gorm is a gorm.DB connection provider interface.
type Gorm interface {
Db() *gorm.DB
}
// IsDialect returns true if the given sql dialect is used.
func IsDialect(name string) bool {
return name == Db().Dialect().GetName()
}
// DbDialect returns the sql dialect name.
func DbDialect() string {
return Db().Dialect().GetName()
}
// SetDbProvider sets the Gorm database connection provider.
func SetDbProvider(provider DbProvider) {
dbProvider = provider
}
// HasDbProvider returns true if a db provider exists.
func HasDbProvider() bool {
return dbProvider != nil
}
// Db returns the database connection.
func Db() *gorm.DB {
return dbProvider.Db()
}
// UnscopedDb returns an unscoped database connection.
func UnscopedDb() *gorm.DB {
return Db().Unscoped()
}
type Gorm struct {
// DbConn is a gorm.DB connection provider.
type DbConn struct {
Driver string
Dsn string
@ -63,8 +36,8 @@ type Gorm struct {
}
// Db returns the gorm db connection.
func (g *Gorm) Db() *gorm.DB {
g.once.Do(g.Connect)
func (g *DbConn) Db() *gorm.DB {
g.once.Do(g.Open)
if g.db == nil {
log.Fatal("migrate: database not connected")
@ -73,8 +46,8 @@ func (g *Gorm) Db() *gorm.DB {
return g.db
}
// Connect creates a new gorm db connection.
func (g *Gorm) Connect() {
// Open creates a new gorm db connection.
func (g *DbConn) Open() {
db, err := gorm.Open(g.Driver, g.Dsn)
if err != nil || db == nil {
@ -104,7 +77,7 @@ func (g *Gorm) Connect() {
}
// Close closes the gorm db connection.
func (g *Gorm) Close() {
func (g *DbConn) Close() {
if g.db != nil {
if err := g.db.Close(); err != nil {
log.Fatal(err)
@ -113,3 +86,23 @@ func (g *Gorm) Close() {
g.db = nil
}
}
// IsDialect returns true if the given sql dialect is used.
func IsDialect(name string) bool {
return name == Db().Dialect().GetName()
}
// DbDialect returns the sql dialect name.
func DbDialect() string {
return Db().Dialect().GetName()
}
// SetDbProvider sets the Gorm database connection provider.
func SetDbProvider(conn Gorm) {
dbConn = conn
}
// HasDbProvider returns true if a db provider exists.
func HasDbProvider() bool {
return dbConn != nil
}

View file

@ -43,7 +43,7 @@ func InitDb(dropDeprecated, runFailed bool, ids []string) {
}
// InitTestDb connects to and completely initializes the test database incl fixtures.
func InitTestDb(driver, dsn string) *Gorm {
func InitTestDb(driver, dsn string) *DbConn {
if HasDbProvider() {
return nil
}
@ -68,13 +68,13 @@ func InitTestDb(driver, dsn string) *Gorm {
log.Infof("initializing %s test db in %s", driver, dsn)
// Create ORM instance.
db := &Gorm{
// Create gorm.DB connection provider.
db := &DbConn{
Driver: driver,
Dsn: dsn,
}
// Insert test fixtures.
// Insert test fixtures into the database.
SetDbProvider(db)
ResetTestFixtures()
File{}.RegenerateIndex()