crm/cmd/server/main.go

367 lines
10 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
	"context"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"crm-go/internal/handlers"
	"crm-go/internal/middleware"
	"crm-go/internal/storage"
	"crm-go/services"
)
// defaultFeishuWebhook is the Feishu bot webhook that used to be hard-coded in
// main. It is only used when FEISHU_WEBHOOK is unset, so existing deployments
// keep sending notifications.
// NOTE(review): this URL embeds the bot's secret token and is committed to
// source control. Rotate it, set FEISHU_WEBHOOK everywhere, then delete this.
const defaultFeishuWebhook = "https://open.feishu.cn/open-apis/bot/v2/hook/d75c14ad-d782-489e-8a99-81b511ee4abd"

// main starts the CRM HTTP server and exits non-zero if it fails.
//
// Configuration is read from the environment:
//   - STORAGE_MODE: "mysql" (default) or any other value for JSON files in ./data
//   - FEISHU_WEBHOOK: Feishu bot webhook URL used for notifications
//   - PORT: TCP port to listen on (default "8081")
func main() {
	if err := runServer(); err != nil {
		log.Fatal(err)
	}
}

// runServer wires storage, handlers, background notifiers and HTTP routes
// together, then serves until the listener fails or the process receives
// SIGINT/SIGTERM. Returning (instead of calling log.Fatal here) lets deferred
// cleanup such as closing the MySQL connection actually run.
func runServer() error {
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	// MySQL is the default storage backend.
	storageMode := os.Getenv("STORAGE_MODE")
	if storageMode == "" {
		storageMode = "mysql"
	}
	customerStorage, followUpStorage, trialPeriodStorage, closeStorage := initStorageBackends(storageMode)
	defer closeStorage()

	feishuWebhook := feishuWebhookURL()

	customerHandler := handlers.NewCustomerHandler(customerStorage, feishuWebhook)
	followUpHandler := handlers.NewFollowUpHandler(followUpStorage, customerStorage, trialPeriodStorage, feishuWebhook)
	trialPeriodHandler := handlers.NewTrialPeriodHandler(trialPeriodStorage, customerStorage, feishuWebhook)
	authHandler := handlers.NewAuthHandler()

	// Check for due follow-up notifications once a minute until shutdown.
	go func() {
		ticker := time.NewTicker(time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				if err := followUpHandler.CheckAndSendNotifications(); err != nil {
					log.Printf("Error checking notifications: %v", err)
				}
			}
		}
	}()

	startTrialExpiryScheduler(ctx, trialPeriodStorage, customerStorage, feishuWebhook)

	// Routes stay on DefaultServeMux, as before, in case other packages
	// register handlers there. Exact paths such as "/api/customers/list" win
	// over the "/api/customers/" subtree pattern.
	mux := http.DefaultServeMux

	mux.HandleFunc("/api/login", authHandler.Login)

	// Customers.
	mux.HandleFunc("/api/customers", methodRouter(map[string]http.HandlerFunc{
		http.MethodGet:  customerHandler.GetCustomers,
		http.MethodPost: customerHandler.CreateCustomer,
	}))
	// Registered explicitly so that non-POST requests get 405 instead of being
	// treated as a lookup of customer ID "import".
	mux.HandleFunc("/api/customers/import", methodRouter(map[string]http.HandlerFunc{
		http.MethodPost: customerHandler.ImportCustomers,
	}))
	// Customer list used by the follow-up form.
	mux.HandleFunc("/api/customers/list", methodRouter(map[string]http.HandlerFunc{
		http.MethodGet: followUpHandler.GetCustomerList,
	}))
	mux.HandleFunc("/api/customers/", idRouter("/api/customers/", map[string]http.HandlerFunc{
		http.MethodGet:    customerHandler.GetCustomerByID,
		http.MethodPut:    customerHandler.UpdateCustomer,
		http.MethodDelete: customerHandler.DeleteCustomer,
	}))
	mux.HandleFunc("/api/upload", methodRouter(map[string]http.HandlerFunc{
		http.MethodPost: customerHandler.UploadScreenshots,
	}))

	// Follow-ups.
	mux.HandleFunc("/api/followups", methodRouter(map[string]http.HandlerFunc{
		http.MethodGet:  followUpHandler.GetFollowUps,
		http.MethodPost: followUpHandler.CreateFollowUp,
	}))
	mux.HandleFunc("/api/followups/", idRouter("/api/followups/", map[string]http.HandlerFunc{
		http.MethodGet:    followUpHandler.GetFollowUpByID,
		http.MethodPut:    followUpHandler.UpdateFollowUp,
		http.MethodDelete: followUpHandler.DeleteFollowUp,
	}))

	// Trial periods.
	mux.HandleFunc("/api/trial-periods", methodRouter(map[string]http.HandlerFunc{
		http.MethodGet:  trialPeriodHandler.GetTrialPeriodsByCustomer,
		http.MethodPost: trialPeriodHandler.CreateTrialPeriod,
	}))
	mux.HandleFunc("/api/trial-periods/all", methodRouter(map[string]http.HandlerFunc{
		http.MethodGet: trialPeriodHandler.GetAllTrialPeriods,
	}))
	mux.HandleFunc("/api/trial-periods/", idRouter("/api/trial-periods/", map[string]http.HandlerFunc{
		http.MethodPut:    trialPeriodHandler.UpdateTrialPeriod,
		http.MethodDelete: trialPeriodHandler.DeleteTrialPeriod,
	}))
	// Unique customer list derived from trial periods.
	mux.HandleFunc("/api/trial-customers/list", methodRouter(map[string]http.HandlerFunc{
		http.MethodGet: trialPeriodHandler.GetTrialCustomerList,
	}))

	// Frontend. MkdirAll is a no-op when the directory already exists.
	const staticDir = "./frontend"
	if err := os.MkdirAll(staticDir, 0o755); err != nil {
		log.Printf("Failed to create static directory %q: %v", staticDir, err)
	}
	// Static files, including uploaded files, are served from staticDir.
	mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir))))
	// Any path not matched above is answered with the frontend's index page.
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.ServeFile(w, r, filepath.Join(staticDir, "index.html"))
	})

	port := os.Getenv("PORT")
	if port == "" {
		port = "8081"
	}
	addr := ":" + port
	srv := &http.Server{
		Addr: addr,
		// Request flow: withCORS -> AuthMiddleware -> DefaultServeMux, so CORS
		// preflight requests are answered before authentication runs.
		Handler: withCORS(middleware.AuthMiddleware(mux)),
		// Bound header reads and idle keep-alives. Body/write timeouts are left
		// unset so that large screenshot uploads are not cut off.
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}

	log.Println("Server starting on " + addr)
	serveErr := make(chan error, 1)
	go func() { serveErr <- srv.ListenAndServe() }()

	select {
	case err := <-serveErr:
		// ListenAndServe always returns a non-nil error, and Shutdown has not
		// been called on this path, so this is a real failure (e.g. port in use).
		return err
	case <-ctx.Done():
	}

	// Restore default signal handling so a second Ctrl-C kills immediately.
	stop()
	log.Println("Shutting down server...")
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		return fmt.Errorf("shutting down server: %w", err)
	}
	return nil
}

// initStorageBackends builds the customer, follow-up and trial-period stores.
// mode "mysql" initialises the MySQL connection from storage.GetDBConfigFromEnv
// and uses the MySQL stores. Any other mode uses the JSON files under ./data,
// which are kept for backward compatibility. The returned func releases the
// backend and must be called on shutdown; it is a no-op for JSON storage.
// A database initialisation failure is fatal, since nothing needs cleanup yet.
func initStorageBackends(mode string) (storage.CustomerStorage, storage.FollowUpStorage, storage.TrialPeriodStorage, func()) {
	if mode != "mysql" {
		customers := storage.NewCustomerStorage("./data/customers.json")
		followUps := storage.NewFollowUpStorage("./data/followups.json")
		trialPeriods := storage.NewTrialPeriodStorage("./data/trial_periods.json")
		log.Println("✅ Using JSON file storage")
		return customers, followUps, trialPeriods, func() {}
	}

	dbConfig := storage.GetDBConfigFromEnv()
	if err := storage.InitDB(dbConfig); err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}
	customers := storage.NewMySQLCustomerStorage()
	followUps := storage.NewMySQLFollowUpStorage()
	trialPeriods := storage.NewMySQLTrialPeriodStorage()
	log.Println("✅ Using MySQL storage")
	return customers, followUps, trialPeriods, func() { storage.CloseDB() }
}

// feishuWebhookURL returns the Feishu bot webhook from FEISHU_WEBHOOK, falling
// back to defaultFeishuWebhook (with a log line) when the variable is unset.
func feishuWebhookURL() string {
	if hook := os.Getenv("FEISHU_WEBHOOK"); hook != "" {
		return hook
	}
	log.Println("FEISHU_WEBHOOK not set; falling back to the built-in Feishu webhook")
	return defaultFeishuWebhook
}

// startTrialExpiryScheduler starts a background goroutine that sends trial
// expiry notifications via services.TrialExpiryChecker.
//
// On a workday (Monday-Friday, server local time) it runs one check at startup.
// After that it runs once per workday notification slot (11:00 and 17:00). A
// slot is normally served by the tick at :00. A later tick in the same hour
// catches up if that tick was missed. The goroutine exits when ctx is cancelled.
func startTrialExpiryScheduler(ctx context.Context, trialStore storage.TrialPeriodStorage, customerStore storage.CustomerStorage, webhook string) {
	checker := services.NewTrialExpiryChecker(webhook)

	// checkTrialExpiry loads all trial periods and customers and passes them to
	// the checker. A load error is logged and that run is skipped.
	checkTrialExpiry := func() {
		trialPeriods, err := trialStore.GetAllTrialPeriods()
		if err != nil {
			log.Printf("Error loading trial periods for expiry check: %v", err)
			return
		}
		customers, err := customerStore.GetAllCustomers()
		if err != nil {
			log.Printf("Error loading customers for trial check: %v", err)
			return
		}

		// Customer ID -> customer name.
		customerNames := make(map[string]string, len(customers))
		for _, c := range customers {
			customerNames[c.ID] = c.CustomerName
		}

		// Convert storage records to the services package's TrialPeriod type.
		serviceTrialPeriods := make([]services.TrialPeriod, len(trialPeriods))
		for i, tp := range trialPeriods {
			serviceTrialPeriods[i] = services.TrialPeriod{
				ID:           tp.ID,
				CustomerName: tp.CustomerName,
				StartTime:    tp.StartTime,
				EndTime:      tp.EndTime,
				CreatedAt:    tp.CreatedAt,
			}
		}

		if err := checker.CheckTrialPeriodsAndNotify(serviceTrialPeriods, customerNames); err != nil {
			log.Printf("Error checking trial expiry: %v", err)
		}
	}

	go func() {
		// lastSlot is the trialNotificationSlot of the most recent run. It keeps
		// any hour slot from being notified twice.
		var lastSlot string

		if now := time.Now(); isWorkday(now) {
			log.Println("Trial expiry checker: Running initial check on startup...")
			checkTrialExpiry()
			lastSlot = trialNotificationSlot(now)
		}

		ticker := time.NewTicker(time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}

			now := time.Now()
			if !isWorkday(now) || !isTrialNotificationHour(now) {
				continue
			}
			slot := trialNotificationSlot(now)
			if slot == lastSlot {
				continue
			}
			log.Printf("Trial expiry checker: Sending scheduled notification at %s", now.Format("2006-01-02 15:04"))
			checkTrialExpiry()
			lastSlot = slot
		}
	}()
}

// isWorkday reports whether t falls on Monday through Friday.
func isWorkday(t time.Time) bool {
	wd := t.Weekday()
	return wd >= time.Monday && wd <= time.Friday
}

// isTrialNotificationHour reports whether t is within 11:00-11:59 or 17:00-17:59.
func isTrialNotificationHour(t time.Time) bool {
	h := t.Hour()
	return h == 11 || h == 17
}

// trialNotificationSlot identifies the calendar hour containing t, including
// the year, for de-duplicating scheduled sends.
func trialNotificationSlot(t time.Time) string {
	return t.Format("2006-01-02 15")
}

// withCORS wraps h so every response carries permissive CORS headers, and
// preflight OPTIONS requests are answered with 200 without reaching h.
func withCORS(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", "*")
		hdr.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		hdr.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		h.ServeHTTP(w, r)
	})
}

// methodRouter dispatches a request to the handler registered for its HTTP
// method, and answers 405 Method Not Allowed for any other method.
func methodRouter(routes map[string]http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if h, ok := routes[r.Method]; ok {
			h(w, r)
			return
		}
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}

// idRouter serves paths of the form prefix+id. A request with a non-empty id is
// dispatched by HTTP method to routes. An empty id, a path outside prefix, or
// an unsupported method gets 404. The handlers re-parse the id themselves;
// this only gates dispatch.
func idRouter(prefix string, routes map[string]http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		if strings.HasPrefix(path, prefix) {
			id := strings.TrimPrefix(path, prefix)
			// r.URL.Path is decoded and never includes the query string, so a
			// '?' can only come from an escaped "%3F". Like the original
			// routing, ignore everything from it on.
			if i := strings.IndexByte(id, '?'); i >= 0 {
				id = id[:i]
			}
			if h, ok := routes[r.Method]; ok && id != "" {
				h(w, r)
				return
			}
		}
		http.NotFound(w, r)
	}
}