219 lines
5.5 KiB
Go
219 lines
5.5 KiB
Go
package storage
|
||
|
||
import (
|
||
"crm-go/models"
|
||
"crypto/rand"
|
||
"database/sql"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// mysqlCustomerStorage is the MySQL-backed customer store. Every method on it
// issues its SQL against the customers table through db.
type mysqlCustomerStorage struct {
	// db is the database handle shared by all storage methods.
	db *sql.DB
}
|
||
|
||
// NewMySQLCustomerStorage 创建MySQL客户存储
|
||
func NewMySQLCustomerStorage() CustomerStorage {
|
||
return &mysqlCustomerStorage{
|
||
db: GetDB(),
|
||
}
|
||
}
|
||
|
||
func (cs *mysqlCustomerStorage) GetAllCustomers() ([]models.Customer, error) {
|
||
query := `
|
||
SELECT id, created_at, customer_name, intended_product, version,
|
||
description, solution, type, module, status_progress, reporter
|
||
FROM customers
|
||
ORDER BY created_at DESC
|
||
`
|
||
|
||
rows, err := cs.db.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var customers []models.Customer
|
||
for rows.Next() {
|
||
var c models.Customer
|
||
var intendedProduct, version, description, solution, typ, module, statusProgress, reporter sql.NullString
|
||
|
||
err := rows.Scan(
|
||
&c.ID, &c.CreatedAt, &c.CustomerName,
|
||
&intendedProduct, &version, &description,
|
||
&solution, &typ, &module, &statusProgress, &reporter,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
c.IntendedProduct = intendedProduct.String
|
||
c.Version = version.String
|
||
c.Description = description.String
|
||
c.Solution = solution.String
|
||
c.Type = typ.String
|
||
c.Module = module.String
|
||
c.StatusProgress = statusProgress.String
|
||
c.Reporter = reporter.String
|
||
|
||
customers = append(customers, c)
|
||
}
|
||
|
||
return customers, rows.Err()
|
||
}
|
||
|
||
func (cs *mysqlCustomerStorage) GetCustomerByID(id string) (*models.Customer, error) {
|
||
query := `
|
||
SELECT id, created_at, customer_name, intended_product, version,
|
||
description, solution, type, module, status_progress, reporter
|
||
FROM customers
|
||
WHERE id = ?
|
||
`
|
||
|
||
var c models.Customer
|
||
var intendedProduct, version, description, solution, typ, module, statusProgress, reporter sql.NullString
|
||
|
||
err := cs.db.QueryRow(query, id).Scan(
|
||
&c.ID, &c.CreatedAt, &c.CustomerName,
|
||
&intendedProduct, &version, &description,
|
||
&solution, &typ, &module, &statusProgress, &reporter,
|
||
)
|
||
|
||
if err == sql.ErrNoRows {
|
||
return nil, nil
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
c.IntendedProduct = intendedProduct.String
|
||
c.Version = version.String
|
||
c.Description = description.String
|
||
c.Solution = solution.String
|
||
c.Type = typ.String
|
||
c.Module = module.String
|
||
c.StatusProgress = statusProgress.String
|
||
c.Reporter = reporter.String
|
||
|
||
return &c, nil
|
||
}
|
||
|
||
func (cs *mysqlCustomerStorage) CreateCustomer(customer models.Customer) error {
|
||
if customer.ID == "" {
|
||
customer.ID = generateMySQLUUID()
|
||
}
|
||
if customer.CreatedAt.IsZero() {
|
||
customer.CreatedAt = time.Now()
|
||
}
|
||
|
||
query := `
|
||
INSERT INTO customers (id, created_at, customer_name, intended_product, version,
|
||
description, solution, type, module, status_progress, reporter)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||
`
|
||
|
||
_, err := cs.db.Exec(query,
|
||
customer.ID, customer.CreatedAt, customer.CustomerName,
|
||
customer.IntendedProduct, customer.Version, customer.Description,
|
||
customer.Solution, customer.Type, customer.Module,
|
||
customer.StatusProgress, customer.Reporter,
|
||
)
|
||
|
||
return err
|
||
}
|
||
|
||
func (cs *mysqlCustomerStorage) UpdateCustomer(id string, updates models.UpdateCustomerRequest) error {
|
||
// 首先获取现有客户
|
||
existing, err := cs.GetCustomerByID(id)
|
||
if err != nil || existing == nil {
|
||
return err
|
||
}
|
||
|
||
// 应用更新
|
||
if updates.CustomerName != nil {
|
||
existing.CustomerName = *updates.CustomerName
|
||
}
|
||
if updates.IntendedProduct != nil {
|
||
existing.IntendedProduct = *updates.IntendedProduct
|
||
}
|
||
if updates.Version != nil {
|
||
existing.Version = *updates.Version
|
||
}
|
||
if updates.Description != nil {
|
||
existing.Description = *updates.Description
|
||
}
|
||
if updates.Solution != nil {
|
||
existing.Solution = *updates.Solution
|
||
}
|
||
if updates.Type != nil {
|
||
existing.Type = *updates.Type
|
||
}
|
||
if updates.Module != nil {
|
||
existing.Module = *updates.Module
|
||
}
|
||
if updates.StatusProgress != nil {
|
||
existing.StatusProgress = *updates.StatusProgress
|
||
}
|
||
if updates.Reporter != nil {
|
||
existing.Reporter = *updates.Reporter
|
||
}
|
||
|
||
query := `
|
||
UPDATE customers
|
||
SET customer_name = ?, intended_product = ?, version = ?,
|
||
description = ?, solution = ?, type = ?,
|
||
module = ?, status_progress = ?, reporter = ?
|
||
WHERE id = ?
|
||
`
|
||
|
||
_, err = cs.db.Exec(query,
|
||
existing.CustomerName, existing.IntendedProduct, existing.Version,
|
||
existing.Description, existing.Solution, existing.Type,
|
||
existing.Module, existing.StatusProgress, existing.Reporter,
|
||
id,
|
||
)
|
||
|
||
return err
|
||
}
|
||
|
||
func (cs *mysqlCustomerStorage) DeleteCustomer(id string) error {
|
||
query := `DELETE FROM customers WHERE id = ?`
|
||
_, err := cs.db.Exec(query, id)
|
||
if err != nil {
|
||
if strings.Contains(err.Error(), "command denied") {
|
||
return fmt.Errorf("数据库权限不足:无法执行删除操作,请联系管理员")
|
||
}
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (cs *mysqlCustomerStorage) SaveCustomers(customers []models.Customer) error {
|
||
// MySQL版本不需要使用此方法,保留接口兼容
|
||
return nil
|
||
}
|
||
|
||
func (cs *mysqlCustomerStorage) LoadCustomers() ([]models.Customer, error) {
|
||
return cs.GetAllCustomers()
|
||
}
|
||
|
||
func (cs *mysqlCustomerStorage) CustomerExists(customer models.Customer) (bool, error) {
|
||
query := `SELECT COUNT(*) FROM customers WHERE description = ?`
|
||
var count int
|
||
err := cs.db.QueryRow(query, customer.Description).Scan(&count)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count > 0, nil
|
||
}
|
||
|
||
// generateMySQLUUID returns a random identifier: 16 bytes from crypto/rand with
// the UUID version nibble set to 4 and the variant bits set to 10, rendered as
// 32 lowercase hexadecimal characters without dashes.
//
// It panics if the secure random source cannot be read. Without that check the
// buffer would stay zeroed and every call would yield the same identifier,
// which surfaces later as duplicate-key failures that are hard to trace.
func generateMySQLUUID() string {
	var id [16]byte
	if _, err := rand.Read(id[:]); err != nil {
		panic("storage: reading random bytes for UUID: " + err.Error())
	}

	id[6] = (id[6] & 0x0f) | 0x40 // high nibble 0100: version 4
	id[8] = (id[8] & 0x3f) | 0x80 // top two bits 10: variant

	return hex.EncodeToString(id[:])
}
|