package storage import ( "database/sql" "fmt" "log" "os" "time" _ "github.com/go-sql-driver/mysql" ) // DBConfig 数据库配置 type DBConfig struct { Host string Port int User string Password string Database string } var db *sql.DB // InitDB 初始化数据库连接 func InitDB(config DBConfig) error { dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, config.Password, config.Host, config.Port, config.Database) var err error db, err = sql.Open("mysql", dsn) if err != nil { return fmt.Errorf("failed to open database: %v", err) } // 配置连接池 db.SetMaxOpenConns(25) db.SetMaxIdleConns(5) db.SetConnMaxLifetime(5 * time.Minute) // 测试连接 if err := db.Ping(); err != nil { return fmt.Errorf("failed to ping database: %v", err) } log.Println("✅ Database connection established") // 自动迁移表结构 if err := autoMigrate(); err != nil { return fmt.Errorf("failed to auto migrate: %v", err) } return nil } // autoMigrate 自动创建或更新表结构 func autoMigrate() error { // 创建 customers 表 createCustomersTable := ` CREATE TABLE IF NOT EXISTS customers ( id VARCHAR(64) PRIMARY KEY, created_at DATETIME NOT NULL, customer_name VARCHAR(255) NOT NULL, intended_product VARCHAR(255), version VARCHAR(100), description TEXT, solution TEXT, type VARCHAR(100), module VARCHAR(100), status_progress VARCHAR(100), reporter VARCHAR(255), screenshots TEXT, INDEX idx_customer_name (customer_name), INDEX idx_created_at (created_at) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; ` if _, err := db.Exec(createCustomersTable); err != nil { return fmt.Errorf("failed to create customers table: %v", err) } // 创建 followups 表 createFollowupsTable := ` CREATE TABLE IF NOT EXISTS followups ( id VARCHAR(64) PRIMARY KEY, created_at DATETIME NOT NULL, customer_name VARCHAR(255) NOT NULL, deal_status VARCHAR(50), customer_level VARCHAR(50), industry VARCHAR(100), follow_up_time DATETIME, notification_sent BOOLEAN DEFAULT FALSE, INDEX idx_customer_name (customer_name), INDEX idx_follow_up_time (follow_up_time) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; ` if _, err := db.Exec(createFollowupsTable); err != nil { return fmt.Errorf("failed to create followups table: %v", err) } // 创建 trial_periods 表 createTrialPeriodsTable := ` CREATE TABLE IF NOT EXISTS trial_periods ( id VARCHAR(64) PRIMARY KEY, customer_name VARCHAR(255) NOT NULL, source VARCHAR(100), intended_product VARCHAR(255), start_time DATETIME, end_time DATETIME, is_trial BOOLEAN DEFAULT TRUE, created_at DATETIME NOT NULL, INDEX idx_customer_name (customer_name), INDEX idx_end_time (end_time) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; ` if _, err := db.Exec(createTrialPeriodsTable); err != nil { return fmt.Errorf("failed to create trial_periods table: %v", err) } // 检查并添加 screenshots 列(如果不存在) // MySQL 不支持 IF NOT EXISTS,所以我们需要先检查列是否存在 var columnExists int err := db.QueryRow(` SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'customers' AND COLUMN_NAME = 'screenshots' `).Scan(&columnExists) if err != nil { return fmt.Errorf("failed to check screenshots column: %v", err) } if columnExists == 0 { _, err = db.Exec("ALTER TABLE customers ADD COLUMN screenshots TEXT") if err != nil { return fmt.Errorf("failed to add screenshots column: %v", err) } log.Println("✅ Added screenshots column to customers table") } log.Println("✅ Database tables migrated successfully") return nil } // GetDB 获取数据库连接 func GetDB() *sql.DB { return db } // CloseDB 关闭数据库连接 func CloseDB() error { if db != nil { return db.Close() } return nil } // GetDBConfigFromEnv 从环境变量获取数据库配置 func GetDBConfigFromEnv() DBConfig { config := DBConfig{ Host: "localhost", Port: 3306, User: "root", Password: "", Database: "crm_db", } if host := os.Getenv("DB_HOST"); host != "" { config.Host = host } if user := os.Getenv("DB_USER"); user != "" { config.User = user } if pwd := os.Getenv("DB_PASSWORD"); pwd != "" { config.Password = pwd } if dbName := os.Getenv("DB_NAME"); dbName != "" { config.Database = dbName } if port := os.Getenv("DB_PORT"); port != "" { fmt.Sscanf(port, "%d", &config.Port) } return config }