103 lines
2.2 KiB
Go
103 lines
2.2 KiB
Go
package settings
|
|
|
|
import (
	"errors"
	"strings"
	"sync"

	"0451meishiditu/backend/internal/config"
	"0451meishiditu/backend/internal/models"

	"gorm.io/gorm"
)
|
|
|
|
// keyCORSAllowOrigins is the models.SystemSetting key under which the CORS
// allow-list is persisted as a single comma-separated string.
const keyCORSAllowOrigins = "cors_allow_origins"
|
|
|
|
// Store holds runtime settings in memory; at present that is only the CORS
// allow-list. All access to the list goes through mu, so a Store must be
// shared by pointer and never copied.
type Store struct {
	// mu guards origins.
	mu sync.RWMutex
	// origins is the set of allowed origins. The special entry "*" means
	// every origin is allowed.
	origins map[string]struct{}
}
|
|
|
|
func New(db *gorm.DB, cfg config.Config) (*Store, error) {
|
|
s := &Store{origins: map[string]struct{}{}}
|
|
|
|
// init from DB if exists; otherwise seed from env
|
|
var row models.SystemSetting
|
|
err := db.Where("`key` = ?", keyCORSAllowOrigins).First(&row).Error
|
|
if err == nil {
|
|
s.SetCORSAllowOrigins(parseOrigins(row.Value))
|
|
return s, nil
|
|
}
|
|
if err != nil && err != gorm.ErrRecordNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
s.SetCORSAllowOrigins(cfg.CORSAllowOrigins)
|
|
_ = db.Create(&models.SystemSetting{Key: keyCORSAllowOrigins, Value: strings.Join(cfg.CORSAllowOrigins, ",")}).Error
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Store) CORSAllowOrigins() []string {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
out := make([]string, 0, len(s.origins))
|
|
for o := range s.origins {
|
|
out = append(out, o)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (s *Store) CORSAllowOrigin(origin string) bool {
|
|
origin = strings.TrimSpace(origin)
|
|
if origin == "" {
|
|
return false
|
|
}
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
if _, ok := s.origins["*"]; ok {
|
|
return true
|
|
}
|
|
_, ok := s.origins[origin]
|
|
return ok
|
|
}
|
|
|
|
func (s *Store) SetCORSAllowOrigins(origins []string) {
|
|
m := map[string]struct{}{}
|
|
for _, o := range origins {
|
|
v := strings.TrimSpace(o)
|
|
if v == "" {
|
|
continue
|
|
}
|
|
m[v] = struct{}{}
|
|
}
|
|
s.mu.Lock()
|
|
s.origins = m
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// parseOrigins splits a comma-separated list into its entries, trimming
// surrounding whitespace from each one and discarding entries that end up
// empty. Order and duplicates are preserved. It returns nil when raw holds no
// non-empty entry.
func parseOrigins(raw string) []string {
	var out []string
	for _, part := range strings.Split(raw, ",") {
		v := strings.TrimSpace(part)
		if v == "" {
			continue
		}
		out = append(out, v)
	}
	return out
}
|
|
|
|
func UpsertCORSAllowOrigins(db *gorm.DB, origins []string) error {
|
|
val := strings.Join(origins, ",")
|
|
return db.Transaction(func(tx *gorm.DB) error {
|
|
var row models.SystemSetting
|
|
err := tx.Where("`key` = ?", keyCORSAllowOrigins).First(&row).Error
|
|
if err == nil {
|
|
return tx.Model(&models.SystemSetting{}).Where("id = ?", row.ID).Update("value", val).Error
|
|
}
|
|
if err != nil && err != gorm.ErrRecordNotFound {
|
|
return err
|
|
}
|
|
return tx.Create(&models.SystemSetting{Key: keyCORSAllowOrigins, Value: val}).Error
|
|
})
|
|
}
|
|
|