diff --git a/config/config.go b/config/config.go index ac78af24..355b3f7c 100644 --- a/config/config.go +++ b/config/config.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/Dreamacro/clash/adapters/outbound" + adapters "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/common/observable" "github.com/Dreamacro/clash/common/structure" C "github.com/Dreamacro/clash/constant" @@ -50,6 +50,7 @@ type RawConfig struct { Mode string `yaml:"mode"` LogLevel string `yaml:"log-level"` ExternalController string `yaml:"external-controller"` + Secret string `yaml:"secret"` Proxy []map[string]interface{} `yaml:"Proxy"` ProxyGroup []map[string]interface{} `yaml:"Proxy Group"` @@ -190,6 +191,7 @@ func (c *Config) parseGeneral(cfg *RawConfig) error { if restAddr := cfg.ExternalController; restAddr != "" { c.event <- &Event{Type: "external-controller", Payload: restAddr} + c.event <- &Event{Type: "secret", Payload: cfg.Secret} } c.UpdateGeneral(*c.general) diff --git a/hub/server.go b/hub/server.go index cb150067..b85b9c2f 100644 --- a/hub/server.go +++ b/hub/server.go @@ -3,6 +3,7 @@ package hub import ( "encoding/json" "net/http" + "strings" "time" "github.com/Dreamacro/clash/config" @@ -15,6 +16,8 @@ import ( log "github.com/sirupsen/logrus" ) +var secret = "" + type Traffic struct { Up int64 `json:"up"` Down int64 `json:"down"` @@ -24,11 +27,19 @@ func newHub(signal chan struct{}) { var addr string ch := config.Instance().Subscribe() signal <- struct{}{} + count := 0 for { elm := <-ch event := elm.(*config.Event) - if event.Type == "external-controller" { + switch event.Type { + case "external-controller": addr = event.Payload.(string) + count++ + case "secret": + secret = event.Payload.(string) + count++ + } + if count == 2 { break } } @@ -38,11 +49,11 @@ func newHub(signal chan struct{}) { cors := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"Content-Type"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, MaxAge: 300, }) - r.Use(cors.Handler) + r.Use(cors.Handler, authentication) r.With(jsonContentType).Get("/traffic", traffic) r.With(jsonContentType).Get("/logs", getLogs) @@ -65,6 +76,30 @@ func jsonContentType(next http.Handler) http.Handler { return http.HandlerFunc(fn) } +func authentication(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + header := r.Header.Get("Authorization") + text := strings.SplitN(header, " ", 2) + + if secret == "" { + next.ServeHTTP(w, r) + return + } + + hasUnvalidHeader := text[0] != "Bearer" + hasUnvalidSecret := len(text) == 2 && text[1] != secret + if hasUnvalidHeader || hasUnvalidSecret { + w.WriteHeader(http.StatusUnauthorized) + render.JSON(w, r, Error{ + Error: "Authentication failed", + }) + return + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + func traffic(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK)