220 lines
5.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package storage
import (
"database/sql"
"fmt"
"log"
"os"
"time"
_ "github.com/go-sql-driver/mysql"
)
// DBConfig holds the parameters used to build the MySQL DSN in InitDB.
// Field order is Host, Port, User, Password, Database.
type DBConfig struct {
	Host string // host part of the tcp(host:port) address
	Port int    // port part of the tcp(host:port) address

	// Credentials and the name of the database to select after connecting.
	User, Password, Database string
}

// db is the package-wide connection pool. It is assigned by InitDB and
// exposed through GetDB; it stays nil until InitDB has been called.
var db *sql.DB
// InitDB 初始化数据库连接
func InitDB(config DBConfig) error {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
config.User, config.Password, config.Host, config.Port, config.Database)
var err error
db, err = sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to open database: %v", err)
}
// 配置连接池
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
// 测试连接
if err := db.Ping(); err != nil {
return fmt.Errorf("failed to ping database: %v", err)
}
log.Println("✅ Database connection established")
// 自动迁移表结构
if err := autoMigrate(); err != nil {
return fmt.Errorf("failed to auto migrate: %v", err)
}
return nil
}
// autoMigrate creates the application tables if they are missing and brings
// tables created by older versions up to date. It operates on the
// package-level db handle, which must already be open.
//
// Steps, in order:
//  1. CREATE TABLE IF NOT EXISTS for customers, followups and trial_periods.
//  2. Ensure customers.screenshots exists and has type LONGTEXT.
//  3. Ensure trial_periods.deal_status exists (VARCHAR(50), default '初步接触').
//  4. Rewrite legacy deal_status values to the current names (best-effort).
//
// Any failure in steps 1-3 is returned wrapped; failures in step 4 are logged.
func autoMigrate() error {
	// CREATE TABLE IF NOT EXISTS never alters an existing table, so columns
	// that were added later are handled by the explicit checks further down.
	tables := []struct{ name, ddl string }{
		{"customers", `
CREATE TABLE IF NOT EXISTS customers (
id VARCHAR(64) PRIMARY KEY,
created_at DATETIME NOT NULL,
customer_name VARCHAR(255) NOT NULL,
intended_product VARCHAR(255),
version VARCHAR(100),
description TEXT,
solution TEXT,
type VARCHAR(100),
module VARCHAR(100),
status_progress VARCHAR(100),
reporter VARCHAR(255),
screenshots LONGTEXT,
INDEX idx_customer_name (customer_name),
INDEX idx_created_at (created_at)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
`},
		{"followups", `
CREATE TABLE IF NOT EXISTS followups (
id VARCHAR(64) PRIMARY KEY,
created_at DATETIME NOT NULL,
customer_name VARCHAR(255) NOT NULL,
deal_status VARCHAR(50),
customer_level VARCHAR(50),
industry VARCHAR(100),
follow_up_time DATETIME,
notification_sent BOOLEAN DEFAULT FALSE,
INDEX idx_customer_name (customer_name),
INDEX idx_follow_up_time (follow_up_time)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
`},
		{"trial_periods", `
CREATE TABLE IF NOT EXISTS trial_periods (
id VARCHAR(64) PRIMARY KEY,
customer_name VARCHAR(255) NOT NULL,
source VARCHAR(100),
intended_product VARCHAR(255),
start_time DATETIME,
end_time DATETIME,
is_trial BOOLEAN DEFAULT TRUE,
created_at DATETIME NOT NULL,
INDEX idx_customer_name (customer_name),
INDEX idx_end_time (end_time)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
`},
	}
	for _, t := range tables {
		if _, err := db.Exec(t.ddl); err != nil {
			return fmt.Errorf("failed to create %s table: %w", t.name, err)
		}
	}

	// columnType reports whether table.column exists in the current schema
	// and, if so, its DATA_TYPE lower-cased so the comparison below does not
	// depend on how the server spells the type name.
	columnType := func(table, column string) (sql.NullString, bool, error) {
		var dataType sql.NullString
		err := db.QueryRow(`
SELECT LOWER(DATA_TYPE)
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = ?
AND COLUMN_NAME = ?
`, table, column).Scan(&dataType)
		if err == sql.ErrNoRows {
			return sql.NullString{}, false, nil
		}
		if err != nil {
			return sql.NullString{}, false, err
		}
		return dataType, true, nil
	}

	// customers.screenshots must be LONGTEXT: add it if missing, widen it if
	// an existing column has a different type.
	shotType, found, err := columnType("customers", "screenshots")
	switch {
	case err != nil:
		return fmt.Errorf("failed to check screenshots column: %w", err)
	case !found:
		if _, err := db.Exec("ALTER TABLE customers ADD COLUMN screenshots LONGTEXT"); err != nil {
			return fmt.Errorf("failed to add screenshots column: %w", err)
		}
		log.Println("✅ Added screenshots column (LONGTEXT) to customers table")
	case shotType.Valid && shotType.String != "longtext":
		if _, err := db.Exec("ALTER TABLE customers MODIFY COLUMN screenshots LONGTEXT"); err != nil {
			return fmt.Errorf("failed to modify screenshots column type: %w", err)
		}
		log.Printf("✅ Modified screenshots column type from %s to LONGTEXT\n", shotType.String)
	}

	// trial_periods.deal_status is not part of the CREATE TABLE above, so it
	// is added here for both new and older databases.
	_, found, err = columnType("trial_periods", "deal_status")
	if err != nil {
		return fmt.Errorf("failed to check deal_status column: %w", err)
	}
	if !found {
		if _, err := db.Exec("ALTER TABLE trial_periods ADD COLUMN deal_status VARCHAR(50) DEFAULT '初步接触'"); err != nil {
			return fmt.Errorf("failed to add deal_status column: %w", err)
		}
		log.Println("✅ Added deal_status column to trial_periods table")
	}

	// Data cleanup: map legacy deal_status values onto the current names.
	// Best-effort: a failure is logged rather than returned so that data
	// normalization cannot block start-up.
	legacyStatusFixes := []string{
		"UPDATE trial_periods SET deal_status = '初步接触' WHERE deal_status = '潜在客户' OR deal_status = '' OR deal_status IS NULL",
		"UPDATE trial_periods SET deal_status = '需求确认' WHERE deal_status = '试用中'",
	}
	for _, stmt := range legacyStatusFixes {
		if _, err := db.Exec(stmt); err != nil {
			log.Printf("⚠️ Failed to normalize trial_periods.deal_status values: %v", err)
		}
	}

	log.Println("✅ Database tables migrated successfully")
	return nil
}
// GetDB returns the shared connection pool stored by InitDB.
// The result is nil if InitDB has never been called.
func GetDB() *sql.DB { return db }
// CloseDB releases the shared connection pool. When no pool has been
// created yet it does nothing and returns nil; otherwise it returns the
// result of closing the pool.
func CloseDB() error {
	if db == nil {
		return nil
	}
	return db.Close()
}
// GetDBConfigFromEnv 从环境变量获取数据库配置
func GetDBConfigFromEnv() DBConfig {
config := DBConfig{
Host: "localhost",
Port: 3306,
User: "root",
Password: "",
Database: "crm_db",
}
if host := os.Getenv("DB_HOST"); host != "" {
config.Host = host
}
if user := os.Getenv("DB_USER"); user != "" {
config.User = user
}
if pwd := os.Getenv("DB_PASSWORD"); pwd != "" {
config.Password = pwd
}
if dbName := os.Getenv("DB_NAME"); dbName != "" {
config.Database = dbName
}
if port := os.Getenv("DB_PORT"); port != "" {
fmt.Sscanf(port, "%d", &config.Port)
}
return config
}