220 lines
5.9 KiB
Go
220 lines
5.9 KiB
Go
package storage
|
||
|
||
import (
|
||
"database/sql"
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"time"
|
||
|
||
_ "github.com/go-sql-driver/mysql"
|
||
)
|
||
|
||
// DBConfig holds the connection parameters InitDB turns into a MySQL DSN.
type DBConfig struct {
	Host string // host placed in the tcp(host:port) part of the DSN
	Port int    // TCP port placed in the tcp(host:port) part of the DSN
	// Login credentials followed by the database (schema) name to select.
	User, Password, Database string
}
|
||
|
||
// Package-level connection pool: InitDB assigns it, GetDB returns it and
// CloseDB closes it.
var (
	db *sql.DB
)
|
||
|
||
// InitDB 初始化数据库连接
|
||
func InitDB(config DBConfig) error {
|
||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||
config.User, config.Password, config.Host, config.Port, config.Database)
|
||
|
||
var err error
|
||
db, err = sql.Open("mysql", dsn)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to open database: %v", err)
|
||
}
|
||
|
||
// 配置连接池
|
||
db.SetMaxOpenConns(25)
|
||
db.SetMaxIdleConns(5)
|
||
db.SetConnMaxLifetime(5 * time.Minute)
|
||
|
||
// 测试连接
|
||
if err := db.Ping(); err != nil {
|
||
return fmt.Errorf("failed to ping database: %v", err)
|
||
}
|
||
|
||
log.Println("✅ Database connection established")
|
||
|
||
// 自动迁移表结构
|
||
if err := autoMigrate(); err != nil {
|
||
return fmt.Errorf("failed to auto migrate: %v", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// autoMigrate creates the customers, followups and trial_periods tables when
// they are missing, then upgrades tables created by older schema versions:
//   - adds customers.screenshots, or converts it to LONGTEXT if it exists with
//     another type;
//   - adds trial_periods.deal_status with default '初步接触';
//   - rewrites legacy deal_status values to their current names (best effort:
//     failures are logged, not returned).
//
// It runs against the package-level db handle, which must already be open.
// Errors from table creation or column checks/changes are returned wrapped
// with %w.
func autoMigrate() error {
	const createCustomers = `
		CREATE TABLE IF NOT EXISTS customers (
			id VARCHAR(64) PRIMARY KEY,
			created_at DATETIME NOT NULL,
			customer_name VARCHAR(255) NOT NULL,
			intended_product VARCHAR(255),
			version VARCHAR(100),
			description TEXT,
			solution TEXT,
			type VARCHAR(100),
			module VARCHAR(100),
			status_progress VARCHAR(100),
			reporter VARCHAR(255),
			screenshots LONGTEXT,
			INDEX idx_customer_name (customer_name),
			INDEX idx_created_at (created_at)
		) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
	`
	const createFollowups = `
		CREATE TABLE IF NOT EXISTS followups (
			id VARCHAR(64) PRIMARY KEY,
			created_at DATETIME NOT NULL,
			customer_name VARCHAR(255) NOT NULL,
			deal_status VARCHAR(50),
			customer_level VARCHAR(50),
			industry VARCHAR(100),
			follow_up_time DATETIME,
			notification_sent BOOLEAN DEFAULT FALSE,
			INDEX idx_customer_name (customer_name),
			INDEX idx_follow_up_time (follow_up_time)
		) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
	`
	const createTrialPeriods = `
		CREATE TABLE IF NOT EXISTS trial_periods (
			id VARCHAR(64) PRIMARY KEY,
			customer_name VARCHAR(255) NOT NULL,
			source VARCHAR(100),
			intended_product VARCHAR(255),
			start_time DATETIME,
			end_time DATETIME,
			is_trial BOOLEAN DEFAULT TRUE,
			created_at DATETIME NOT NULL,
			INDEX idx_customer_name (customer_name),
			INDEX idx_end_time (end_time)
		) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
	`

	// Tables are created in this order; the first failure aborts the migration.
	for _, tbl := range []struct{ name, ddl string }{
		{"customers", createCustomers},
		{"followups", createFollowups},
		{"trial_periods", createTrialPeriods},
	} {
		if _, err := db.Exec(tbl.ddl); err != nil {
			return fmt.Errorf("failed to create %s table: %w", tbl.name, err)
		}
	}

	// customers.screenshots: add it if absent, convert it to LONGTEXT if it has
	// another type. LOWER() makes the comparison independent of how the server
	// capitalises DATA_TYPE, so an existing LONGTEXT column is never re-ALTERed
	// (MODIFY COLUMN can rebuild the whole table).
	var screenshotsType sql.NullString
	err := db.QueryRow(`
		SELECT LOWER(DATA_TYPE)
		FROM INFORMATION_SCHEMA.COLUMNS
		WHERE TABLE_SCHEMA = DATABASE()
		  AND TABLE_NAME = 'customers'
		  AND COLUMN_NAME = 'screenshots'
	`).Scan(&screenshotsType)
	switch {
	case err == sql.ErrNoRows:
		if _, err := db.Exec("ALTER TABLE customers ADD COLUMN screenshots LONGTEXT"); err != nil {
			return fmt.Errorf("failed to add screenshots column: %w", err)
		}
		log.Println("✅ Added screenshots column (LONGTEXT) to customers table")
	case err != nil:
		return fmt.Errorf("failed to check screenshots column: %w", err)
	case screenshotsType.Valid && screenshotsType.String != "longtext":
		if _, err := db.Exec("ALTER TABLE customers MODIFY COLUMN screenshots LONGTEXT"); err != nil {
			return fmt.Errorf("failed to modify screenshots column type: %w", err)
		}
		log.Printf("✅ Modified screenshots column type from %s to LONGTEXT\n", screenshotsType.String)
	}

	// trial_periods.deal_status is not part of the CREATE TABLE above, so it is
	// added here when missing.
	var dealStatusColumn sql.NullString
	err = db.QueryRow(`
		SELECT COLUMN_NAME
		FROM INFORMATION_SCHEMA.COLUMNS
		WHERE TABLE_SCHEMA = DATABASE()
		  AND TABLE_NAME = 'trial_periods'
		  AND COLUMN_NAME = 'deal_status'
	`).Scan(&dealStatusColumn)
	switch {
	case err == sql.ErrNoRows:
		if _, err := db.Exec("ALTER TABLE trial_periods ADD COLUMN deal_status VARCHAR(50) DEFAULT '初步接触'"); err != nil {
			return fmt.Errorf("failed to add deal_status column: %w", err)
		}
		log.Println("✅ Added deal_status column to trial_periods table")
	case err != nil:
		return fmt.Errorf("failed to check deal_status column: %w", err)
	}

	// Data cleanup: map legacy deal_status names (and empty/NULL values) to the
	// current names. This is deliberately best-effort — a failure does not
	// abort startup — but it is logged instead of silently discarded.
	legacyStatusUpdates := []string{
		"UPDATE trial_periods SET deal_status = '初步接触' WHERE deal_status = '潜在客户' OR deal_status = '' OR deal_status IS NULL",
		"UPDATE trial_periods SET deal_status = '需求确认' WHERE deal_status = '试用中'",
	}
	for _, stmt := range legacyStatusUpdates {
		if _, err := db.Exec(stmt); err != nil {
			log.Printf("⚠️ failed to normalize legacy deal_status values: %v", err)
		}
	}

	log.Println("✅ Database tables migrated successfully")
	return nil
}
|
||
|
||
// GetDB returns the package-level *sql.DB handle that InitDB assigns.
// Until it has been assigned, the result is nil.
func GetDB() *sql.DB {
	current := db
	return current
}
|
||
|
||
// CloseDB 关闭数据库连接
|
||
func CloseDB() error {
|
||
if db != nil {
|
||
return db.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetDBConfigFromEnv 从环境变量获取数据库配置
|
||
func GetDBConfigFromEnv() DBConfig {
|
||
config := DBConfig{
|
||
Host: "localhost",
|
||
Port: 3306,
|
||
User: "root",
|
||
Password: "",
|
||
Database: "crm_db",
|
||
}
|
||
|
||
if host := os.Getenv("DB_HOST"); host != "" {
|
||
config.Host = host
|
||
}
|
||
if user := os.Getenv("DB_USER"); user != "" {
|
||
config.User = user
|
||
}
|
||
if pwd := os.Getenv("DB_PASSWORD"); pwd != "" {
|
||
config.Password = pwd
|
||
}
|
||
if dbName := os.Getenv("DB_NAME"); dbName != "" {
|
||
config.Database = dbName
|
||
}
|
||
if port := os.Getenv("DB_PORT"); port != "" {
|
||
fmt.Sscanf(port, "%d", &config.Port)
|
||
}
|
||
|
||
return config
|
||
}
|