// Package storage owns the process-wide MySQL connection pool and the schema
// migrations that run when the pool is initialised.
package storage

import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"os"
	"strconv"
	"strings"
	"time"

	_ "github.com/go-sql-driver/mysql"
)

// DBConfig holds the settings used to build the MySQL DSN.
type DBConfig struct {
	Host     string
	Port     int
	User     string
	Password string
	Database string
}

// db is the shared connection pool. It stays nil until InitDB has opened a
// pool and pinged the server successfully.
var db *sql.DB

// InitDB opens a MySQL connection pool from config, verifies it with a ping,
// publishes it as the package-wide pool and then runs autoMigrate.
//
// The pool is only published once the ping succeeds; if the ping fails the
// freshly opened pool is closed so that repeated start-up attempts do not
// accumulate idle pools. Returned errors wrap the underlying cause and can be
// inspected with errors.Is / errors.As.
func InitDB(config DBConfig) error {
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
		config.User, config.Password, config.Host, config.Port, config.Database)

	conn, err := sql.Open("mysql", dsn)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	// Connection pool limits.
	conn.SetMaxOpenConns(25)
	conn.SetMaxIdleConns(5)
	conn.SetConnMaxLifetime(5 * time.Minute)

	// sql.Open does not necessarily connect, so ping to validate the DSN.
	if err := conn.Ping(); err != nil {
		// The ping failure is the error worth reporting; a secondary Close
		// error on an unusable pool adds nothing, so it is dropped.
		_ = conn.Close()
		return fmt.Errorf("failed to ping database: %w", err)
	}

	db = conn
	log.Println("✅ Database connection established")

	if err := autoMigrate(); err != nil {
		return fmt.Errorf("failed to auto migrate: %w", err)
	}
	return nil
}

// autoMigrate creates the customers, followups and trial_periods tables when
// they are missing, brings older schemas up to date (customers.screenshots as
// LONGTEXT, trial_periods.deal_status) and normalises legacy deal_status
// values. It operates on the package-wide pool, which must already be set.
//
// Schema changes are fatal and returned as wrapped errors. The final data
// clean-up is best-effort: a failure there is logged and does not fail the
// migration.
func autoMigrate() error {
	const createCustomersTable = `
	CREATE TABLE IF NOT EXISTS customers (
		id VARCHAR(64) PRIMARY KEY,
		created_at DATETIME NOT NULL,
		customer_name VARCHAR(255) NOT NULL,
		intended_product VARCHAR(255),
		version VARCHAR(100),
		description TEXT,
		solution TEXT,
		type VARCHAR(100),
		module VARCHAR(100),
		status_progress VARCHAR(100),
		reporter VARCHAR(255),
		screenshots LONGTEXT,
		INDEX idx_customer_name (customer_name),
		INDEX idx_created_at (created_at)
	) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
	`

	const createFollowupsTable = `
	CREATE TABLE IF NOT EXISTS followups (
		id VARCHAR(64) PRIMARY KEY,
		created_at DATETIME NOT NULL,
		customer_name VARCHAR(255) NOT NULL,
		deal_status VARCHAR(50),
		customer_level VARCHAR(50),
		industry VARCHAR(100),
		follow_up_time DATETIME,
		notification_sent BOOLEAN DEFAULT FALSE,
		INDEX idx_customer_name (customer_name),
		INDEX idx_follow_up_time (follow_up_time)
	) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
	`

	const createTrialPeriodsTable = `
	CREATE TABLE IF NOT EXISTS trial_periods (
		id VARCHAR(64) PRIMARY KEY,
		customer_name VARCHAR(255) NOT NULL,
		source VARCHAR(100),
		intended_product VARCHAR(255),
		start_time DATETIME,
		end_time DATETIME,
		is_trial BOOLEAN DEFAULT TRUE,
		created_at DATETIME NOT NULL,
		INDEX idx_customer_name (customer_name),
		INDEX idx_end_time (end_time)
	) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
	`

	tables := []struct {
		name string
		ddl  string
	}{
		{"customers", createCustomersTable},
		{"followups", createFollowupsTable},
		{"trial_periods", createTrialPeriodsTable},
	}
	for _, t := range tables {
		if _, err := db.Exec(t.ddl); err != nil {
			return fmt.Errorf("failed to create %s table: %w", t.name, err)
		}
	}

	// customers.screenshots: add the column to tables created before it
	// existed, or widen it when it was created with a different type.
	var columnType sql.NullString
	err := db.QueryRow(`
		SELECT DATA_TYPE
		FROM INFORMATION_SCHEMA.COLUMNS
		WHERE TABLE_SCHEMA = DATABASE()
		AND TABLE_NAME = 'customers'
		AND COLUMN_NAME = 'screenshots'
	`).Scan(&columnType)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Column is missing: add it.
		if _, execErr := db.Exec("ALTER TABLE customers ADD COLUMN screenshots LONGTEXT"); execErr != nil {
			return fmt.Errorf("failed to add screenshots column: %w", execErr)
		}
		log.Println("✅ Added screenshots column (LONGTEXT) to customers table")
	case err != nil:
		return fmt.Errorf("failed to check screenshots column: %w", err)
	case columnType.Valid && !strings.EqualFold(columnType.String, "longtext"):
		// Column exists with another type: convert it. The comparison ignores
		// case so an upper-case DATA_TYPE does not trigger a needless ALTER.
		if _, execErr := db.Exec("ALTER TABLE customers MODIFY COLUMN screenshots LONGTEXT"); execErr != nil {
			return fmt.Errorf("failed to modify screenshots column type: %w", execErr)
		}
		log.Printf("✅ Modified screenshots column type from %s to LONGTEXT\n", columnType.String)
	}

	// trial_periods.deal_status: add the column when it is missing.
	var dealStatusColumn sql.NullString
	err = db.QueryRow(`
		SELECT COLUMN_NAME
		FROM INFORMATION_SCHEMA.COLUMNS
		WHERE TABLE_SCHEMA = DATABASE()
		AND TABLE_NAME = 'trial_periods'
		AND COLUMN_NAME = 'deal_status'
	`).Scan(&dealStatusColumn)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		if _, execErr := db.Exec("ALTER TABLE trial_periods ADD COLUMN deal_status VARCHAR(50) DEFAULT '初步接触'"); execErr != nil {
			return fmt.Errorf("failed to add deal_status column: %w", execErr)
		}
		log.Println("✅ Added deal_status column to trial_periods table")
	case err != nil:
		return fmt.Errorf("failed to check deal_status column: %w", err)
	}

	// Data clean-up: rewrite legacy status names to the current ones. This is
	// best-effort, so a failure is reported in the log instead of aborting.
	cleanups := []string{
		"UPDATE trial_periods SET deal_status = '初步接触' WHERE deal_status = '潜在客户' OR deal_status = '' OR deal_status IS NULL",
		"UPDATE trial_periods SET deal_status = '需求确认' WHERE deal_status = '试用中'",
	}
	for _, stmt := range cleanups {
		if _, execErr := db.Exec(stmt); execErr != nil {
			log.Printf("warning: deal_status cleanup statement failed (continuing): %v", execErr)
		}
	}

	log.Println("✅ Database tables migrated successfully")
	return nil
}

// GetDB returns the package-wide connection pool, or nil if InitDB has not
// published one yet.
func GetDB() *sql.DB {
	return db
}

// CloseDB closes the package-wide connection pool. It returns nil when no
// pool has been published.
func CloseDB() error {
	if db != nil {
		return db.Close()
	}
	return nil
}

// GetDBConfigFromEnv builds a DBConfig from the DB_HOST, DB_PORT, DB_USER,
// DB_PASSWORD and DB_NAME environment variables. Unset or empty variables keep
// the defaults (localhost:3306, user "root", empty password, database
// "crm_db").
//
// DB_PORT must be a base-10 integer (surrounding whitespace is ignored); any
// other value is logged and the default port is kept.
func GetDBConfigFromEnv() DBConfig {
	config := DBConfig{
		Host:     "localhost",
		Port:     3306,
		User:     "root",
		Password: "",
		Database: "crm_db",
	}

	if host := os.Getenv("DB_HOST"); host != "" {
		config.Host = host
	}
	if user := os.Getenv("DB_USER"); user != "" {
		config.User = user
	}
	if pwd := os.Getenv("DB_PASSWORD"); pwd != "" {
		config.Password = pwd
	}
	if dbName := os.Getenv("DB_NAME"); dbName != "" {
		config.Database = dbName
	}
	if raw := os.Getenv("DB_PORT"); raw != "" {
		port, err := strconv.Atoi(strings.TrimSpace(raw))
		if err != nil {
			// The signature has no error return, so report the bad value and
			// fall back to the default instead of using a partial parse.
			log.Printf("warning: ignoring invalid DB_PORT, using %d: %v", config.Port, err)
		} else {
			config.Port = port
		}
	}

	return config
}