diff --git a/config/config.go b/config/config.go index 355b3f7c..4d6b87fa 100644 --- a/config/config.go +++ b/config/config.go @@ -87,13 +87,18 @@ func (c *Config) Report() chan<- interface{} { } func (c *Config) readConfig() (*RawConfig, error) { - if _, err := os.Stat(C.ConfigPath); os.IsNotExist(err) { + if _, err := os.Stat(C.Path.Config()); os.IsNotExist(err) { return nil, err } - data, err := ioutil.ReadFile(C.ConfigPath) + data, err := ioutil.ReadFile(C.Path.Config()) if err != nil { return nil, err } + + if len(data) == 0 { + return nil, fmt.Errorf("Configuration file %s is empty", C.Path.Config()) + } + // config with some default value rawConfig := &RawConfig{ AllowLan: false, diff --git a/config/initial.go b/config/initial.go index b13eefd0..62f63570 100644 --- a/config/initial.go +++ b/config/initial.go @@ -55,16 +55,23 @@ func downloadMMDB(path string) (err error) { // Init prepare necessary files func Init() { + // initial homedir + if _, err := os.Stat(C.Path.HomeDir()); os.IsNotExist(err) { + if err := os.MkdirAll(C.Path.HomeDir(), 0777); err != nil { + log.Fatalf("Can't create config directory %s: %s", C.Path.HomeDir(), err.Error()) + } + } + // initial config.ini - if _, err := os.Stat(C.ConfigPath); os.IsNotExist(err) { + if _, err := os.Stat(C.Path.Config()); os.IsNotExist(err) { log.Info("Can't find config, create a empty file") - os.OpenFile(C.ConfigPath, os.O_CREATE|os.O_WRONLY, 0644) + os.OpenFile(C.Path.Config(), os.O_CREATE|os.O_WRONLY, 0644) } // initial mmdb - if _, err := os.Stat(C.MMDBPath); os.IsNotExist(err) { + if _, err := os.Stat(C.Path.MMDB()); os.IsNotExist(err) { log.Info("Can't find MMDB, start download") - err := downloadMMDB(C.MMDBPath) + err := downloadMMDB(C.Path.MMDB()) if err != nil { log.Fatalf("Can't download MMDB: %s", err.Error()) } diff --git a/constant/config.go b/constant/config.go index 0c760919..f0b93fe2 100644 --- a/constant/config.go +++ b/constant/config.go @@ -1,23 +1,5 @@ package constant -import ( - "os" - "os/user" - "path" - - log "github.com/sirupsen/logrus" -) - -const ( - Name = "clash" -) - -var ( - HomeDir string - ConfigPath string - MMDBPath string -) - type General struct { Mode *string `json:"mode,omitempty"` AllowLan *bool `json:"allow-lan,omitempty"` @@ -26,26 +8,3 @@ type General struct { RedirPort *int `json:"redir-port,omitempty"` LogLevel *string `json:"log-level,omitempty"` } - -func init() { - currentUser, err := user.Current() - if err != nil { - dir := os.Getenv("HOME") - if dir == "" { - log.Fatalf("Can't get current user: %s", err.Error()) - } - HomeDir = dir - } else { - HomeDir = currentUser.HomeDir - } - - dirPath := path.Join(HomeDir, ".config", Name) - if _, err := os.Stat(dirPath); os.IsNotExist(err) { - if err := os.MkdirAll(dirPath, 0777); err != nil { - log.Fatalf("Can't create config directory %s: %s", dirPath, err.Error()) - } - } - - ConfigPath = path.Join(dirPath, "config.yml") - MMDBPath = path.Join(dirPath, "Country.mmdb") -} diff --git a/constant/path.go b/constant/path.go new file mode 100644 index 00000000..ad7b847d --- /dev/null +++ b/constant/path.go @@ -0,0 +1,49 @@ +package constant + +import ( + "os" + "os/user" + P "path" +) + +const Name = "clash" + +// Path is used to get the configuration path +var Path *path + +type path struct { + homedir string +} + +func init() { + currentUser, err := user.Current() + var homedir string + if err != nil { + dir := os.Getenv("HOME") + if dir == "" { + dir, _ = os.Getwd() + } + homedir = dir + } else { + homedir = currentUser.HomeDir + } + homedir = P.Join(homedir, ".config", Name) + Path = &path{homedir: homedir} +} + +// SetHomeDir is used to set the configuration path +func SetHomeDir(root string) { + Path = &path{homedir: root} +} + +func (p *path) HomeDir() string { + return p.homedir +} + +func (p *path) Config() string { + return P.Join(p.homedir, "config.yml") +} + +func (p *path) MMDB() string { + return P.Join(p.homedir, "Country.mmdb") +} diff --git a/rules/geoip.go b/rules/geoip.go index ff52ff36..5fb9a1f7 100644 --- a/rules/geoip.go +++ b/rules/geoip.go @@ -42,7 +42,7 @@ func (g *GEOIP) Payload() string { func NewGEOIP(country string, adapter string) *GEOIP { once.Do(func() { var err error - mmdb, err = geoip2.Open(C.MMDBPath) + mmdb, err = geoip2.Open(C.Path.MMDB()) if err != nil { log.Fatalf("Can't load mmdb: %s", err.Error()) }