diff --git a/go.mod b/go.mod index 1101def1..c8eb6e10 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c github.com/samber/lo v1.37.0 github.com/sirupsen/logrus v1.9.0 - github.com/stretchr/testify v1.8.1 + github.com/stretchr/testify v1.8.2 github.com/xtls/go v0.0.0-20220914232946-0441cf4cf837 github.com/zhangyunhao116/fastrand v0.3.0 go.etcd.io/bbolt v1.3.6 diff --git a/go.sum b/go.sum index 22cc7c26..f56aed7a 100644 --- a/go.sum +++ b/go.sum @@ -153,8 +153,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/u-root/uio v0.0.0-20221213070652-c3537552635f h1:dpx1PHxYqAnXzbryJrWP1NQLzEjwcVgFLhkknuFQ7ww= github.com/u-root/uio v0.0.0-20221213070652-c3537552635f/go.mod h1:IogEAUBXDEwX7oR/BMmCctShYs80ql4hF0ySdzGxf7E= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= diff --git a/hub/route/server.go b/hub/route/server.go index 054b1ad1..848face9 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -87,6 +87,8 @@ func Start(addr string, tlsAddr string, secret string, r.Mount("/cache", cacheRouter()) r.Mount("/dns", dnsRouter()) r.Mount("/restart", restartRouter()) + r.Mount("/upgrade", upgradeRouter()) + }) if uiPath != "" { diff --git a/hub/route/upgrade.go b/hub/route/upgrade.go new file mode 100644 index 00000000..0d772c85 --- /dev/null +++ b/hub/route/upgrade.go @@ -0,0 +1,69 @@ +package route + +import ( + "fmt" + "net/http" + "os" + "os/exec" + "runtime" + "syscall" + + "github.com/Dreamacro/clash/hub/updater" + "github.com/Dreamacro/clash/log" + "github.com/go-chi/render" + + "github.com/go-chi/chi/v5" +) + +func upgradeRouter() http.Handler { + r := chi.NewRouter() + r.Post("/", upgrade) + return r +} + +func upgrade(w http.ResponseWriter, r *http.Request) { + // modify from https://github.com/AdguardTeam/AdGuardHome/blob/595484e0b3fb4c457f9bb727a6b94faa78a66c5f/internal/home/controlupdate.go#L108 + log.Infoln("start update") + err := updater.Update() + if err != nil { + log.Errorln("err:%s", err) + } + + execPath, err := os.Executable() + if err != nil { + render.Status(r, http.StatusInternalServerError) + render.JSON(w, r, newError(fmt.Sprintf("getting path: %s", err))) + return + } + + render.JSON(w, r, render.M{"status": "ok"}) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + // modify from https://github.com/AdguardTeam/AdGuardHome/blob/595484e0b3fb4c457f9bb727a6b94faa78a66c5f/internal/home/controlupdate.go#L180 + // The background context is used because the underlying functions wrap it + // with timeout and shut down the server, which handles current request. It + // also should be done in a separate goroutine for the same reason. + go func() { + if runtime.GOOS == "windows" { + cmd := exec.Command(execPath, os.Args[1:]...) + log.Infoln("restarting: %q %q", execPath, os.Args[1:]) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err = cmd.Start() + if err != nil { + log.Fatalln("restarting: %s", err) + } + + os.Exit(0) + } + + log.Infoln("restarting: %q %q", execPath, os.Args[1:]) + err = syscall.Exec(execPath, os.Args, os.Environ()) + if err != nil { + log.Fatalln("restarting: %s", err) + } + }() +} diff --git a/hub/updater/limitedreader.go b/hub/updater/limitedreader.go new file mode 100644 index 00000000..c31db601 --- /dev/null +++ b/hub/updater/limitedreader.go @@ -0,0 +1,67 @@ +package updater + +import ( + "fmt" + "io" + + "golang.org/x/exp/constraints" +) + +// LimitReachedError records the limit and the operation that caused it. +type LimitReachedError struct { + Limit int64 +} + +// Error implements the [error] interface for *LimitReachedError. +// +// TODO(a.garipov): Think about error string format. +func (lre *LimitReachedError) Error() string { + return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit) +} + +// limitedReader is a wrapper for [io.Reader] limiting the input and dealing +// with errors package. +type limitedReader struct { + r io.Reader + limit int64 + n int64 +} + +// Read implements the [io.Reader] interface. +func (lr *limitedReader) Read(p []byte) (n int, err error) { + if lr.n == 0 { + return 0, &LimitReachedError{ + Limit: lr.limit, + } + } + + p = p[:Min(lr.n, int64(len(p)))] + + n, err = lr.r.Read(p) + lr.n -= int64(n) + + return n, err +} + +// LimitReader wraps Reader to make it's Reader stop with ErrLimitReached after +// n bytes read. +func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) { + if n < 0 { + return nil, &updateError{Message: "limit must be non-negative"} + } + + return &limitedReader{ + r: r, + limit: n, + n: n, + }, nil +} + +// Min returns the smaller of x or y. +func Min[T constraints.Integer | ~string](x, y T) (res T) { + if x < y { + return x + } + + return y +} diff --git a/hub/updater/updater.go b/hub/updater/updater.go new file mode 100644 index 00000000..e4b55421 --- /dev/null +++ b/hub/updater/updater.go @@ -0,0 +1,460 @@ +package updater + +import ( + "archive/zip" + "compress/gzip" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + + "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +// Updater is the AdGuard Home updater. +var ( + client http.Client + + goarch string + goos string + goarm string + gomips string + + workDir string + versionCheckURL string + + // mu protects all fields below. + mu sync.RWMutex + + // TODO(a.garipov): See if all of these fields actually have to be in + // this struct. + currentExeName string // 当前可执行文件 + updateDir string // 更新目录 + packageName string // 更新压缩文件 + backupDir string // 备份目录 + backupExeName string // 备份文件名 + updateExeName string // 更新后的可执行文件 + unpackedFile string + + baseURL = "https://github.com/MetaCubeX/Clash.Meta/releases/download/Prerelease-Alpha/clash.meta" + versionURL = "https://github.com/MetaCubeX/Clash.Meta/releases/download/Prerelease-Alpha/version.txt" + packageURL string + latestVersion string +) + +type updateError struct { + Message string +} + +func (e *updateError) Error() string { + return fmt.Sprintf("error: %s", e.Message) +} + +// Update performs the auto-updater. It returns an error if the updater failed. +// If firstRun is true, it assumes the configuration file doesn't exist. +func Update() (err error) { + goos = runtime.GOOS + goarch = runtime.GOARCH + latestVersion = getLatestVersion() + + if latestVersion == constant.Version { + err := &updateError{Message: "Already using latest version"} + return err + } + + updateDownloadURL() + mu.Lock() + defer mu.Unlock() + + log.Infoln("current version alpha-%s", constant.Version) + + defer func() { + if err != nil { + log.Errorln("updater: failed: %v", err) + } else { + log.Infoln("updater: finished") + } + }() + + execPath, err := os.Executable() + if err != nil { + return fmt.Errorf("getting executable path: %w", err) + } + + workDir = filepath.Dir(execPath) + //log.Infoln("workDir %s", execPath) + + err = prepare(execPath) + if err != nil { + return fmt.Errorf("preparing: %w", err) + } + + defer clean() + + err = downloadPackageFile() + if err != nil { + return fmt.Errorf("downloading package file: %w", err) + } + + err = unpack() + if err != nil { + return fmt.Errorf("unpacking: %w", err) + } + + err = replace() + if err != nil { + return fmt.Errorf("replacing: %w", err) + } + + return nil +} + +// VersionCheckURL returns the version check URL. +func VersionCheckURL() (vcu string) { + mu.RLock() + defer mu.RUnlock() + + return versionCheckURL +} + +// prepare fills all necessary fields in Updater object. +func prepare(exePath string) (err error) { + updateDir = filepath.Join(workDir, "meta-update") + currentExeName = exePath + _, pkgNameOnly := filepath.Split(packageURL) + if pkgNameOnly == "" { + return fmt.Errorf("invalid PackageURL: %q", packageURL) + } + + packageName = filepath.Join(updateDir, pkgNameOnly) + //log.Infoln(packageName) + backupDir = filepath.Join(workDir, "meta-backup") + + if goos == "windows" { + updateExeName = "clash.meta" + "-" + goos + "-" + goarch + ".exe" + } else { + updateExeName = "clash.meta" + "-" + goos + "-" + goarch + } + + log.Infoln("updateExeName: %s ,currentExeName: %s", updateExeName, currentExeName) + + backupExeName = filepath.Join(backupDir, filepath.Base(exePath)) + updateExeName = filepath.Join(updateDir, updateExeName) + + log.Infoln( + "updater: updating using url: %s", + packageURL, + ) + + currentExeName = exePath + _, err = os.Stat(currentExeName) + if err != nil { + return fmt.Errorf("checking %q: %w", currentExeName, err) + } + + return nil +} + +// unpack extracts the files from the downloaded archive. +func unpack() error { + var err error + _, pkgNameOnly := filepath.Split(packageURL) + + log.Debugln("updater: unpacking package") + if strings.HasSuffix(pkgNameOnly, ".zip") { + unpackedFile, err = zipFileUnpack(packageName, updateDir) + if err != nil { + return fmt.Errorf(".zip unpack failed: %w", err) + } + + } else if strings.HasSuffix(pkgNameOnly, ".gz") { + unpackedFile, err = gzFileUnpack(packageName, updateDir) + if err != nil { + return fmt.Errorf(".gz unpack failed: %w", err) + } + + } else { + return fmt.Errorf("unknown package extension") + } + + return nil +} + +// replace moves the current executable with the updated one and also copies the +// supporting files. +func replace() error { + //err := copySupportingFiles(unpackedFiles, updateDir, workDir) + //if err != nil { + // return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", updateDir, workDir, err) + //} + + log.Infoln("updater: renaming: %s to %s", currentExeName, backupExeName) + err := os.Rename(currentExeName, backupExeName) + if err != nil { + return err + } + + if goos == "windows" { + // rename fails with "File in use" error + log.Infoln("copying:%s to %s", updateExeName, currentExeName) + err = copyFile(updateExeName, currentExeName) + } else { + err = os.Rename(updateExeName, currentExeName) + } + if err != nil { + return err + } + + return nil +} + +// clean removes the temporary directory itself and all it's contents. +func clean() { + _ = os.RemoveAll(updateDir) +} + +// MaxPackageFileSize is a maximum package file length in bytes. The largest +// package whose size is limited by this constant currently has the size of +// approximately 9 MiB. +const MaxPackageFileSize = 32 * 1024 * 1024 + +// Download package file and save it to disk +func downloadPackageFile() (err error) { + var resp *http.Response + resp, err = client.Get(packageURL) + if err != nil { + return fmt.Errorf("http request failed: %w", err) + } + + defer func() { + closeErr := resp.Body.Close() + if closeErr != nil && err == nil { + err = closeErr + } + }() + + var r io.Reader + r, err = LimitReader(resp.Body, MaxPackageFileSize) + if err != nil { + return fmt.Errorf("http request failed: %w", err) + } + + log.Debugln("updater: reading http body") + // This use of ReadAll is now safe, because we limited body's Reader. + body, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("io.ReadAll() failed: %w", err) + } + + log.Debugln("updateDir %s", updateDir) + err = os.Mkdir(updateDir, 0o755) + if err != nil { + fmt.Errorf("mkdir error: %w", err) + } + + log.Debugln("updater: saving package to file %s", packageName) + err = os.WriteFile(packageName, body, 0o755) + if err != nil { + return fmt.Errorf("os.WriteFile() failed: %w", err) + } + return nil +} + +// Unpack a single .gz file to the specified directory +// Existing files are overwritten +// All files are created inside outDir, subdirectories are not created +// Return the output file name +func gzFileUnpack(gzfile, outDir string) (string, error) { + f, err := os.Open(gzfile) + if err != nil { + return "", fmt.Errorf("os.Open(): %w", err) + } + + defer func() { + closeErr := f.Close() + if closeErr != nil && err == nil { + err = closeErr + } + }() + + gzReader, err := gzip.NewReader(f) + if err != nil { + return "", fmt.Errorf("gzip.NewReader(): %w", err) + } + + defer func() { + closeErr := gzReader.Close() + if closeErr != nil && err == nil { + err = closeErr + } + }() + // Get the original file name from the .gz file header + originalName := gzReader.Header.Name + if originalName == "" { + // Fallback: remove the .gz extension from the input file name if the header doesn't provide the original name + originalName = filepath.Base(gzfile) + originalName = strings.TrimSuffix(originalName, ".gz") + } + + outputName := filepath.Join(outDir, originalName) + + // Create the output file + wc, err := os.OpenFile( + outputName, + os.O_WRONLY|os.O_CREATE|os.O_TRUNC, + 0o755, + ) + if err != nil { + return "", fmt.Errorf("os.OpenFile(%s): %w", outputName, err) + } + + defer func() { + closeErr := wc.Close() + if closeErr != nil && err == nil { + err = closeErr + } + }() + + // Copy the contents of the gzReader to the output file + _, err = io.Copy(wc, gzReader) + if err != nil { + return "", fmt.Errorf("io.Copy(): %w", err) + } + + return outputName, nil +} + +// Unpack a single file from .zip file to the specified directory +// Existing files are overwritten +// All files are created inside 'outDir', subdirectories are not created +// Return the output file name +func zipFileUnpack(zipfile, outDir string) (string, error) { + zrc, err := zip.OpenReader(zipfile) + if err != nil { + return "", fmt.Errorf("zip.OpenReader(): %w", err) + } + + defer func() { + closeErr := zrc.Close() + if closeErr != nil && err == nil { + err = closeErr + } + }() + if len(zrc.File) == 0 { + return "", fmt.Errorf("no files in the zip archive") + } + + // Assuming the first file in the zip archive is the target file + zf := zrc.File[0] + var rc io.ReadCloser + rc, err = zf.Open() + if err != nil { + return "", fmt.Errorf("zip file Open(): %w", err) + } + + defer func() { + closeErr := rc.Close() + if closeErr != nil && err == nil { + err = closeErr + } + }() + fi := zf.FileInfo() + name := fi.Name() + outputName := filepath.Join(outDir, name) + + if fi.IsDir() { + return "", fmt.Errorf("the target file is a directory") + } + + var wc io.WriteCloser + wc, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) + if err != nil { + return "", fmt.Errorf("os.OpenFile(): %w", err) + } + + defer func() { + closeErr := wc.Close() + if closeErr != nil && err == nil { + err = closeErr + } + }() + _, err = io.Copy(wc, rc) + if err != nil { + return "", fmt.Errorf("io.Copy(): %w", err) + } + + return outputName, nil +} + +// Copy file on disk +func copyFile(src, dst string) error { + d, e := os.ReadFile(src) + if e != nil { + return e + } + e = os.WriteFile(dst, d, 0o644) + if e != nil { + return e + } + return nil +} + +func getLatestVersion() string { + resp, err := http.Get(versionURL) + if err != nil { + return "" + } + defer func() { + closeErr := resp.Body.Close() + if closeErr != nil && err == nil { + err = closeErr + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "" + } + content := strings.TrimRight(string(body), "\n") + log.Infoln("latest:%s", content) + return content +} + +func updateDownloadURL() { + var middle string + + if goarch == "arm" && goarm != "" { + middle = fmt.Sprintf("-%s-%sv%s-%s", goos, goarch, goarm, latestVersion) + } else if isMIPS(goarch) && gomips != "" { + middle = fmt.Sprintf("-%s-%s-%s-%s", goos, goarch, gomips, latestVersion) + } else { + middle = fmt.Sprintf("-%s-%s-%s", goos, goarch, latestVersion) + } + + if goos == "windows" { + middle += ".zip" + } else { + middle += ".gz" + } + packageURL = baseURL + middle + //log.Infoln(packageURL) +} + +// isMIPS returns true if arch is any MIPS architecture. +func isMIPS(arch string) (ok bool) { + switch arch { + case + "mips", + "mips64", + "mips64le", + "mipsle": + return true + default: + return false + } +}