crm/cmd/server/main.go
2026-01-13 18:02:43 +08:00

309 lines
8.4 KiB
Go

package main
import (
"crm-go/internal/handlers"
"crm-go/internal/middleware"
"crm-go/internal/storage"
"crm-go/services"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// main wires up the JSON-file storage, HTTP handlers, background Feishu
// notification jobs and the HTTP server, then serves requests until the
// server fails (at which point the process exits via log.Fatal).
func main() {
	// Initialize storage backed by JSON files under ./data.
	customerStorage := storage.NewCustomerStorage("./data/customers.json")
	followUpStorage := storage.NewFollowUpStorage("./data/followups.json")
	trialPeriodStorage := storage.NewTrialPeriodStorage("./data/trial_periods.json")

	// Feishu bot webhook URL, taken from the FEISHU_WEBHOOK environment
	// variable. The built-in fallback only keeps existing deployments working:
	// it is a credential committed to source control and should be rotated,
	// with this fallback removed once the variable is configured everywhere.
	feishuWebhook := os.Getenv("FEISHU_WEBHOOK")
	if feishuWebhook == "" {
		log.Println("FEISHU_WEBHOOK not set; using built-in fallback webhook URL")
		feishuWebhook = "https://open.feishu.cn/open-apis/bot/v2/hook/d75c14ad-d782-489e-8a99-81b511ee4abd"
	}

	// Initialize handlers.
	customerHandler := handlers.NewCustomerHandler(customerStorage, feishuWebhook)
	followUpHandler := handlers.NewFollowUpHandler(followUpStorage, customerStorage, feishuWebhook)
	trialPeriodHandler := handlers.NewTrialPeriodHandler(trialPeriodStorage, customerStorage, feishuWebhook)
	authHandler := handlers.NewAuthHandler()

	// Follow-up notification checker: runs once a minute for the life of the process.
	go func() {
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()
		for range ticker.C {
			if err := followUpHandler.CheckAndSendNotifications(); err != nil {
				log.Printf("Error checking notifications: %v", err)
			}
		}
	}()

	// Trial expiry checker.
	trialChecker := services.NewTrialExpiryChecker(feishuWebhook)

	// runTrialCheck loads all trial periods and customers and passes them to
	// trialChecker. Load failures are logged and that run is skipped; the
	// next scheduled run retries.
	runTrialCheck := func() {
		trialPeriods, err := trialPeriodStorage.GetAllTrialPeriods()
		if err != nil {
			log.Printf("Error loading trial periods for expiry check: %v", err)
			return
		}
		customers, err := customerStorage.GetAllCustomers()
		if err != nil {
			log.Printf("Error loading customers for trial check: %v", err)
			return
		}
		// Customer ID -> customer name, as expected by CheckTrialPeriodsAndNotify.
		customersMap := make(map[string]string, len(customers))
		for _, c := range customers {
			customersMap[c.ID] = c.CustomerName
		}
		// Convert storage records to the services.TrialPeriod type.
		serviceTrialPeriods := make([]services.TrialPeriod, len(trialPeriods))
		for i, tp := range trialPeriods {
			serviceTrialPeriods[i] = services.TrialPeriod{
				ID:         tp.ID,
				CustomerID: tp.CustomerID,
				StartTime:  tp.StartTime,
				EndTime:    tp.EndTime,
				CreatedAt:  tp.CreatedAt,
			}
		}
		if err := trialChecker.CheckTrialPeriodsAndNotify(serviceTrialPeriods, customersMap); err != nil {
			log.Printf("Error checking trial expiry: %v", err)
		}
	}

	go func() {
		// Daily run time, as an hour of the server's local clock.
		const trialCheckHour = 10

		// Check immediately on startup, then once per day at trialCheckHour:00.
		runTrialCheck()
		next := nextDailyRun(time.Now(), trialCheckHour)
		for {
			time.Sleep(time.Until(next))
			runTrialCheck()
			// Schedule from the later of the slot just served and now. An
			// early wake-up (wall-clock adjustment) then cannot fire twice,
			// and a long stall (e.g. host suspend) does not replay missed days.
			from := next
			if now := time.Now(); now.After(from) {
				from = now
			}
			next = nextDailyRun(from, trialCheckHour)
		}
	}()

	// Enable CORS manually. Preflight OPTIONS requests are answered here,
	// before they reach the auth middleware.
	corsHandler := func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Access-Control-Allow-Origin", "*")
			w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}
			h.ServeHTTP(w, r)
		})
	}

	// Auth routes.
	http.HandleFunc("/api/login", authHandler.Login)

	// Customer routes.
	http.HandleFunc("/api/customers", func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			customerHandler.GetCustomers(w, r)
		case http.MethodPost:
			customerHandler.CreateCustomer(w, r)
		default:
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	})
	http.HandleFunc("/api/customers/", func(w http.ResponseWriter, r *http.Request) {
		// The import endpoint shares this prefix. It must not fall through
		// to the ID routes, or "import" would be treated as a customer ID.
		if r.URL.Path == "/api/customers/import" {
			if r.Method != http.MethodPost {
				http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
				return
			}
			customerHandler.ImportCustomers(w, r)
			return
		}
		if resourceID(r.URL.Path, "/api/customers/") == "" {
			http.NotFound(w, r)
			return
		}
		switch r.Method {
		case http.MethodGet:
			customerHandler.GetCustomerByID(w, r)
		case http.MethodPut:
			customerHandler.UpdateCustomer(w, r)
		case http.MethodDelete:
			customerHandler.DeleteCustomer(w, r)
		default:
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	})

	// Follow-up routes.
	http.HandleFunc("/api/followups", func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			followUpHandler.GetFollowUps(w, r)
		case http.MethodPost:
			followUpHandler.CreateFollowUp(w, r)
		default:
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	})
	http.HandleFunc("/api/followups/", func(w http.ResponseWriter, r *http.Request) {
		if resourceID(r.URL.Path, "/api/followups/") == "" {
			http.NotFound(w, r)
			return
		}
		switch r.Method {
		case http.MethodGet:
			followUpHandler.GetFollowUpByID(w, r)
		case http.MethodPut:
			followUpHandler.UpdateFollowUp(w, r)
		case http.MethodDelete:
			followUpHandler.DeleteFollowUp(w, r)
		default:
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	})

	// Customer list endpoint for the follow-up form. The more specific pattern
	// takes precedence over "/api/customers/" in ServeMux.
	http.HandleFunc("/api/customers/list", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		followUpHandler.GetCustomerList(w, r)
	})

	// Trial period routes.
	http.HandleFunc("/api/trial-periods", func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			trialPeriodHandler.GetTrialPeriodsByCustomer(w, r)
		case http.MethodPost:
			trialPeriodHandler.CreateTrialPeriod(w, r)
		default:
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	})
	http.HandleFunc("/api/trial-periods/all", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		trialPeriodHandler.GetAllTrialPeriods(w, r)
	})
	http.HandleFunc("/api/trial-periods/", func(w http.ResponseWriter, r *http.Request) {
		if resourceID(r.URL.Path, "/api/trial-periods/") == "" {
			http.NotFound(w, r)
			return
		}
		switch r.Method {
		case http.MethodPut:
			trialPeriodHandler.UpdateTrialPeriod(w, r)
		case http.MethodDelete:
			trialPeriodHandler.DeleteTrialPeriod(w, r)
		default:
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	})

	// Serve static files for the frontend. MkdirAll is a no-op when the
	// directory already exists. A failure is logged rather than fatal so the
	// API keeps working without a frontend.
	staticDir := "./frontend"
	if err := os.MkdirAll(staticDir, 0o755); err != nil {
		log.Printf("Error creating static directory %s: %v", staticDir, err)
	}
	http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir))))

	// Every path not matched above serves the index page.
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.ServeFile(w, r, filepath.Join(staticDir, "index.html"))
	})

	// Final handler chain, outermost first: CORS -> AuthMiddleware -> DefaultServeMux.
	finalHandler := middleware.AuthMiddleware(http.DefaultServeMux)

	port := os.Getenv("PORT")
	if port == "" {
		port = "8081"
	}
	addr := ":" + port

	// ReadHeaderTimeout guards against slow-header (Slowloris) clients without
	// capping the time allowed for large request bodies such as imports.
	srv := &http.Server{
		Addr:              addr,
		Handler:           corsHandler(finalHandler),
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	log.Println("Server starting on " + addr)
	log.Fatal(srv.ListenAndServe())
}

// resourceID returns the path segment following prefix (for example the ID in
// "/api/customers/<id>"), or "" when path does not start with prefix or
// nothing follows it. r.URL.Path is already decoded and excludes the query
// string, so a '?' here can only come from an encoded "%3F". Anything from
// that character onward is dropped.
func resourceID(path, prefix string) string {
	if !strings.HasPrefix(path, prefix) {
		return ""
	}
	id := strings.TrimPrefix(path, prefix)
	if i := strings.IndexByte(id, '?'); i >= 0 {
		id = id[:i]
	}
	return id
}

// nextDailyRun returns the first instant strictly after now at which the clock
// in now's location reads hour:00:00 (today if that is still ahead, otherwise
// tomorrow). time.Date normalizes day overflow across month and year ends.
func nextDailyRun(now time.Time, hour int) time.Time {
	next := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, now.Location())
	if !next.After(now) {
		next = time.Date(now.Year(), now.Month(), now.Day()+1, hour, 0, 0, 0, now.Location())
	}
	return next
}