// Mirror of https://github.com/DullJZ/s3-balance.git
// (synced 2026-06-26 21:41:21 +08:00; 238 lines, 5.4 KiB, Go).

package database
|
||
|
||
import (
	"context"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/DullJZ/s3-balance/internal/config"
	"github.com/DullJZ/s3-balance/internal/storage"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	_ "modernc.org/sqlite"
)
|
||
|
||
// DB 全局数据库连接
|
||
var DB *gorm.DB
|
||
|
||
// Initialize 初始化数据库连接
|
||
func Initialize(cfg *config.DatabaseConfig) error {
|
||
var err error
|
||
|
||
// 设置日志级别
|
||
logLevel := getLogLevel(cfg.LogLevel)
|
||
|
||
// GORM配置
|
||
gormConfig := &gorm.Config{
|
||
Logger: logger.Default.LogMode(logLevel),
|
||
NowFunc: func() time.Time {
|
||
return time.Now().Local()
|
||
},
|
||
QueryFields: true,
|
||
}
|
||
|
||
// 根据数据库类型创建连接
|
||
switch cfg.Type {
|
||
case "sqlite":
|
||
DB, err = connectSQLite(cfg.DSN, gormConfig)
|
||
case "mysql":
|
||
DB, err = connectMySQL(cfg.DSN, gormConfig)
|
||
case "postgres", "postgresql":
|
||
DB, err = connectPostgreSQL(cfg.DSN, gormConfig)
|
||
default:
|
||
return fmt.Errorf("unsupported database type: %s", cfg.Type)
|
||
}
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("failed to connect to database: %w", err)
|
||
}
|
||
|
||
// 获取底层SQL数据库连接
|
||
sqlDB, err := DB.DB()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get sql.DB: %w", err)
|
||
}
|
||
|
||
// 设置连接池参数
|
||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||
sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Second)
|
||
|
||
// 测试连接
|
||
if err := sqlDB.Ping(); err != nil {
|
||
return fmt.Errorf("failed to ping database: %w", err)
|
||
}
|
||
|
||
// 自动迁移
|
||
if cfg.AutoMigrate {
|
||
if err := AutoMigrate(); err != nil {
|
||
return fmt.Errorf("failed to auto migrate: %w", err)
|
||
}
|
||
}
|
||
|
||
log.Printf("Successfully connected to %s database", cfg.Type)
|
||
return nil
|
||
}
|
||
|
||
// connectSQLite 连接SQLite数据库(使用modernc.org/sqlite,支持非CGO)
|
||
func connectSQLite(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) {
|
||
// 创建数据目录(如果不存在)
|
||
dir := filepath.Dir(dsn)
|
||
if dir != "" && dir != "." {
|
||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||
}
|
||
}
|
||
|
||
// 添加SQLite特定参数
|
||
if dsn != ":memory:" {
|
||
dsn = fmt.Sprintf("%s?_journal_mode=WAL&_timeout=5000&_synchronous=NORMAL&_cache_size=10000", dsn)
|
||
}
|
||
|
||
// 使用modernc.org/sqlite驱动(纯Go实现,无需CGO)
|
||
dialector := sqlite.Dialector{
|
||
DriverName: "sqlite",
|
||
DSN: dsn,
|
||
}
|
||
|
||
return gorm.Open(dialector, gormConfig)
|
||
} // connectMySQL 连接MySQL数据库
|
||
func connectMySQL(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) {
|
||
// MySQL DSN示例: user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local
|
||
|
||
// 如果DSN中没有指定字符集,添加默认字符集
|
||
if dsn != "" {
|
||
dsn = ensureMySQLParams(dsn)
|
||
}
|
||
|
||
return gorm.Open(mysql.Open(dsn), gormConfig)
|
||
}
|
||
|
||
// connectPostgreSQL 连接PostgreSQL数据库
|
||
func connectPostgreSQL(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) {
|
||
// PostgreSQL DSN示例: host=localhost user=user password=password dbname=mydb port=5432 sslmode=disable TimeZone=Asia/Shanghai
|
||
|
||
return gorm.Open(postgres.Open(dsn), gormConfig)
|
||
}
|
||
|
||
// ensureMySQLParams returns dsn with the default parameters charset=utf8mb4,
// parseTime=True and loc=Local appended for every key the DSN's query string
// does not already set. Parameters already present are left untouched.
//
// The query string is taken to be everything after the first '?' that follows
// the last '/', so a '?' or a "key=" fragment inside the user name or password
// is not mistaken for a parameter. Missing defaults are appended in a fixed
// order (charset, parseTime, loc), which makes the result deterministic.
func ensureMySQLParams(dsn string) string {
	// A slice rather than a map keeps the append order stable between calls.
	defaults := []struct{ key, value string }{
		{"charset", "utf8mb4"},
		{"parseTime", "True"},
		{"loc", "Local"},
	}

	// Locate the query string: first '?' after the last '/'.
	dbStart := strings.LastIndexByte(dsn, '/') + 1
	query, hasQuery := "", false
	if q := strings.IndexByte(dsn[dbStart:], '?'); q >= 0 {
		query, hasQuery = dsn[dbStart+q+1:], true
	}

	// Collect the keys that are explicitly assigned ("key=...") in the query.
	present := make(map[string]bool)
	for _, pair := range strings.Split(query, "&") {
		if eq := strings.IndexByte(pair, '='); eq >= 0 {
			present[pair[:eq]] = true
		}
	}

	// Pick the separator for the first appended parameter.
	separator := "&"
	switch {
	case !hasQuery:
		separator = "?"
	case query == "" || strings.HasSuffix(query, "&"):
		// The DSN already ends in '?' or '&'; do not double it.
		separator = ""
	}

	var b strings.Builder
	b.WriteString(dsn)
	for _, d := range defaults {
		if present[d.key] {
			continue
		}
		b.WriteString(separator)
		b.WriteString(d.key)
		b.WriteByte('=')
		b.WriteString(d.value)
		separator = "&"
	}
	return b.String()
}
|
||
|
||
// contains reports whether substr occurs within s. An empty substr is
// contained in every string, including the empty string.
//
// It delegates to strings.Contains rather than a hand-written scan.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
||
|
||
// containsHelper reports whether substr occurs at any position in s. It
// returns false when substr is longer than s and true when substr is empty.
//
// The manual byte-by-byte window scan is replaced by strings.Contains, which
// gives the same answers.
func containsHelper(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
||
|
||
// getLogLevel 获取GORM日志级别
|
||
func getLogLevel(level string) logger.LogLevel {
|
||
switch level {
|
||
case "silent":
|
||
return logger.Silent
|
||
case "error":
|
||
return logger.Error
|
||
case "warn", "warning":
|
||
return logger.Warn
|
||
case "info":
|
||
return logger.Info
|
||
default:
|
||
return logger.Warn
|
||
}
|
||
}
|
||
|
||
// AutoMigrate 自动迁移数据库表
|
||
func AutoMigrate() error {
|
||
models := []interface{}{
|
||
&storage.Object{},
|
||
&storage.BucketStats{},
|
||
&storage.BucketMonthlyStats{},
|
||
&storage.UploadSession{},
|
||
&storage.AccessLog{},
|
||
&storage.VirtualBucketMapping{},
|
||
}
|
||
|
||
for _, model := range models {
|
||
if err := DB.AutoMigrate(model); err != nil {
|
||
return fmt.Errorf("failed to migrate %T: %w", model, err)
|
||
}
|
||
}
|
||
|
||
log.Println("Database migration completed successfully")
|
||
return nil
|
||
}
|
||
|
||
// Close 关闭数据库连接
|
||
func Close() error {
|
||
if DB != nil {
|
||
sqlDB, err := DB.DB()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return sqlDB.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// HealthCheck 健康检查
|
||
func HealthCheck() error {
|
||
if DB == nil {
|
||
return fmt.Errorf("database not initialized")
|
||
}
|
||
|
||
sqlDB, err := DB.DB()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
return sqlDB.PingContext(ctx)
|
||
}
|
||
|
||
// Transaction 执行事务
|
||
func Transaction(fn func(*gorm.DB) error) error {
|
||
return DB.Transaction(fn)
|
||
}
|
||
|
||
// GetDB 获取数据库连接
|
||
func GetDB() *gorm.DB {
|
||
return DB
|
||
}
|