Backend: Refactor gorm.DB connection provider in entity package
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
24fc54e326
commit
65e9a58979
14
internal/entity/db.go
Normal file
14
internal/entity/db.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package entity
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
|
||||
// Db returns the default *gorm.DB connection.
|
||||
func Db() *gorm.DB {
|
||||
return dbConn.Db()
|
||||
}
|
||||
|
||||
// UnscopedDb returns an unscoped *gorm.DB connection
|
||||
// that returns all records including deleted records.
|
||||
func UnscopedDb() *gorm.DB {
|
||||
return Db().Unscoped()
|
||||
}
|
|
@ -10,7 +10,7 @@ import (
|
|||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
)
|
||||
|
||||
// SQL Databases.
|
||||
// Supported test databases.
|
||||
const (
|
||||
MySQL = "mysql"
|
||||
SQLite3 = "sqlite3"
|
||||
|
@ -18,43 +18,16 @@ const (
|
|||
SQLiteMemoryDSN = ":memory:"
|
||||
)
|
||||
|
||||
var dbProvider DbProvider
|
||||
// dbConn is the global gorm.DB connection provider.
|
||||
var dbConn Gorm
|
||||
|
||||
type DbProvider interface {
|
||||
// Gorm is a gorm.DB connection provider interface.
|
||||
type Gorm interface {
|
||||
Db() *gorm.DB
|
||||
}
|
||||
|
||||
// IsDialect returns true if the given sql dialect is used.
|
||||
func IsDialect(name string) bool {
|
||||
return name == Db().Dialect().GetName()
|
||||
}
|
||||
|
||||
// DbDialect returns the sql dialect name.
|
||||
func DbDialect() string {
|
||||
return Db().Dialect().GetName()
|
||||
}
|
||||
|
||||
// SetDbProvider sets the Gorm database connection provider.
|
||||
func SetDbProvider(provider DbProvider) {
|
||||
dbProvider = provider
|
||||
}
|
||||
|
||||
// HasDbProvider returns true if a db provider exists.
|
||||
func HasDbProvider() bool {
|
||||
return dbProvider != nil
|
||||
}
|
||||
|
||||
// Db returns the database connection.
|
||||
func Db() *gorm.DB {
|
||||
return dbProvider.Db()
|
||||
}
|
||||
|
||||
// UnscopedDb returns an unscoped database connection.
|
||||
func UnscopedDb() *gorm.DB {
|
||||
return Db().Unscoped()
|
||||
}
|
||||
|
||||
type Gorm struct {
|
||||
// DbConn is a gorm.DB connection provider.
|
||||
type DbConn struct {
|
||||
Driver string
|
||||
Dsn string
|
||||
|
||||
|
@ -63,8 +36,8 @@ type Gorm struct {
|
|||
}
|
||||
|
||||
// Db returns the gorm db connection.
|
||||
func (g *Gorm) Db() *gorm.DB {
|
||||
g.once.Do(g.Connect)
|
||||
func (g *DbConn) Db() *gorm.DB {
|
||||
g.once.Do(g.Open)
|
||||
|
||||
if g.db == nil {
|
||||
log.Fatal("migrate: database not connected")
|
||||
|
@ -73,8 +46,8 @@ func (g *Gorm) Db() *gorm.DB {
|
|||
return g.db
|
||||
}
|
||||
|
||||
// Connect creates a new gorm db connection.
|
||||
func (g *Gorm) Connect() {
|
||||
// Open creates a new gorm db connection.
|
||||
func (g *DbConn) Open() {
|
||||
db, err := gorm.Open(g.Driver, g.Dsn)
|
||||
|
||||
if err != nil || db == nil {
|
||||
|
@ -104,7 +77,7 @@ func (g *Gorm) Connect() {
|
|||
}
|
||||
|
||||
// Close closes the gorm db connection.
|
||||
func (g *Gorm) Close() {
|
||||
func (g *DbConn) Close() {
|
||||
if g.db != nil {
|
||||
if err := g.db.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -113,3 +86,23 @@ func (g *Gorm) Close() {
|
|||
g.db = nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsDialect returns true if the given sql dialect is used.
|
||||
func IsDialect(name string) bool {
|
||||
return name == Db().Dialect().GetName()
|
||||
}
|
||||
|
||||
// DbDialect returns the sql dialect name.
|
||||
func DbDialect() string {
|
||||
return Db().Dialect().GetName()
|
||||
}
|
||||
|
||||
// SetDbProvider sets the Gorm database connection provider.
|
||||
func SetDbProvider(conn Gorm) {
|
||||
dbConn = conn
|
||||
}
|
||||
|
||||
// HasDbProvider returns true if a db provider exists.
|
||||
func HasDbProvider() bool {
|
||||
return dbConn != nil
|
||||
}
|
|
@ -43,7 +43,7 @@ func InitDb(dropDeprecated, runFailed bool, ids []string) {
|
|||
}
|
||||
|
||||
// InitTestDb connects to and completely initializes the test database incl fixtures.
|
||||
func InitTestDb(driver, dsn string) *Gorm {
|
||||
func InitTestDb(driver, dsn string) *DbConn {
|
||||
if HasDbProvider() {
|
||||
return nil
|
||||
}
|
||||
|
@ -68,13 +68,13 @@ func InitTestDb(driver, dsn string) *Gorm {
|
|||
|
||||
log.Infof("initializing %s test db in %s", driver, dsn)
|
||||
|
||||
// Create ORM instance.
|
||||
db := &Gorm{
|
||||
// Create gorm.DB connection provider.
|
||||
db := &DbConn{
|
||||
Driver: driver,
|
||||
Dsn: dsn,
|
||||
}
|
||||
|
||||
// Insert test fixtures.
|
||||
// Insert test fixtures into the database.
|
||||
SetDbProvider(db)
|
||||
ResetTestFixtures()
|
||||
File{}.RegenerateIndex()
|
||||
|
|
Loading…
Reference in a new issue