Feature: add experimental connections API

This commit is contained in:
gVisor bot 2019-10-27 21:44:07 +08:00
parent b34604523c
commit 93f13c627c
16 changed files with 365 additions and 130 deletions

View file

@ -23,6 +23,7 @@ func (h *HTTPAdapter) Metadata() *C.Metadata {
// NewHTTP is HTTPAdapter generator // NewHTTP is HTTPAdapter generator
func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter { func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter {
metadata := parseHTTPAddr(request) metadata := parseHTTPAddr(request)
metadata.Type = C.HTTP
if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil {
metadata.SrcIP = ip metadata.SrcIP = ip
metadata.SrcPort = port metadata.SrcPort = port

View file

@ -3,11 +3,14 @@ package adapters
import ( import (
"net" "net"
"net/http" "net/http"
C "github.com/Dreamacro/clash/constant"
) )
// NewHTTPS is HTTPAdapter generator // NewHTTPS is HTTPAdapter generator
func NewHTTPS(request *http.Request, conn net.Conn) *SocketAdapter { func NewHTTPS(request *http.Request, conn net.Conn) *SocketAdapter {
metadata := parseHTTPAddr(request) metadata := parseHTTPAddr(request)
metadata.Type = C.HTTPCONNECT
if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil {
metadata.SrcIP = ip metadata.SrcIP = ip
metadata.SrcPort = port metadata.SrcPort = port

View file

@ -20,11 +20,11 @@ func parseSocksAddr(target socks5.Addr) *C.Metadata {
metadata.DstPort = strconv.Itoa((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1])) metadata.DstPort = strconv.Itoa((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1]))
case socks5.AtypIPv4: case socks5.AtypIPv4:
ip := net.IP(target[1 : 1+net.IPv4len]) ip := net.IP(target[1 : 1+net.IPv4len])
metadata.DstIP = &ip metadata.DstIP = ip
metadata.DstPort = strconv.Itoa((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1])) metadata.DstPort = strconv.Itoa((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1]))
case socks5.AtypIPv6: case socks5.AtypIPv6:
ip := net.IP(target[1 : 1+net.IPv6len]) ip := net.IP(target[1 : 1+net.IPv6len])
metadata.DstIP = &ip metadata.DstIP = ip
metadata.DstPort = strconv.Itoa((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1])) metadata.DstPort = strconv.Itoa((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1]))
} }
@ -40,7 +40,6 @@ func parseHTTPAddr(request *http.Request) *C.Metadata {
metadata := &C.Metadata{ metadata := &C.Metadata{
NetWork: C.TCP, NetWork: C.TCP,
Type: C.HTTP,
AddrType: C.AtypDomainName, AddrType: C.AtypDomainName,
Host: host, Host: host,
DstIP: nil, DstIP: nil,
@ -55,18 +54,18 @@ func parseHTTPAddr(request *http.Request) *C.Metadata {
default: default:
metadata.AddrType = C.AtypIPv4 metadata.AddrType = C.AtypIPv4
} }
metadata.DstIP = &ip metadata.DstIP = ip
} }
return metadata return metadata
} }
func parseAddr(addr string) (*net.IP, string, error) { func parseAddr(addr string) (net.IP, string, error) {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ip := net.ParseIP(host) ip := net.ParseIP(host)
return &ip, port, nil return ip, port, nil
} }

View file

@ -14,7 +14,7 @@ type Pool struct {
min uint32 min uint32
gateway uint32 gateway uint32
offset uint32 offset uint32
mux *sync.Mutex mux sync.Mutex
cache *cache.LruCache cache *cache.LruCache
} }
@ -111,7 +111,6 @@ func New(ipnet *net.IPNet, size int) (*Pool, error) {
min: min, min: min,
max: max, max: max,
gateway: min - 1, gateway: min - 1,
mux: &sync.Mutex{},
cache: cache.NewLRUCache(cache.WithSize(size * 2)), cache: cache.NewLRUCache(cache.WithSize(size * 2)),
}, nil }, nil
} }

View file

@ -1,6 +1,7 @@
package constant package constant
import ( import (
"encoding/json"
"net" "net"
) )
@ -14,6 +15,7 @@ const (
UDP UDP
HTTP Type = iota HTTP Type = iota
HTTPCONNECT
SOCKS SOCKS
REDIR REDIR
) )
@ -27,18 +29,41 @@ func (n *NetWork) String() string {
return "udp" return "udp"
} }
func (n NetWork) MarshalJSON() ([]byte, error) {
return json.Marshal(n.String())
}
type Type int type Type int
func (t Type) String() string {
switch t {
case HTTP:
return "HTTP"
case HTTPCONNECT:
return "HTTP Connect"
case SOCKS:
return "Socks5"
case REDIR:
return "Redir"
default:
return "Unknown"
}
}
func (t Type) MarshalJSON() ([]byte, error) {
return json.Marshal(t.String())
}
// Metadata is used to store connection address // Metadata is used to store connection address
type Metadata struct { type Metadata struct {
NetWork NetWork NetWork NetWork `json:"network"`
Type Type Type Type `json:"type"`
SrcIP *net.IP SrcIP net.IP `json:"sourceIP"`
DstIP *net.IP DstIP net.IP `json:"destinationIP"`
SrcPort string SrcPort string `json:"sourcePort"`
DstPort string DstPort string `json:"destinationPort"`
AddrType int AddrType int `json:"-"`
Host string Host string `json:"host"`
} }
func (m *Metadata) RemoteAddress() string { func (m *Metadata) RemoteAddress() string {

View file

@ -24,7 +24,7 @@ func (rt RuleType) String() string {
case DomainKeyword: case DomainKeyword:
return "DomainKeyword" return "DomainKeyword"
case GEOIP: case GEOIP:
return "GEOIP" return "GeoIP"
case IPCIDR: case IPCIDR:
return "IPCIDR" return "IPCIDR"
case SrcIPCIDR: case SrcIPCIDR:
@ -34,7 +34,7 @@ func (rt RuleType) String() string {
case DstPort: case DstPort:
return "DstPort" return "DstPort"
case MATCH: case MATCH:
return "MATCH" return "Match"
default: default:
return "Unknown" return "Unknown"
} }

View file

@ -1,55 +0,0 @@
package constant
import (
"time"
)
type Traffic struct {
up chan int64
down chan int64
upCount int64
downCount int64
upTotal int64
downTotal int64
interval time.Duration
}
func (t *Traffic) Up() chan<- int64 {
return t.up
}
func (t *Traffic) Down() chan<- int64 {
return t.down
}
func (t *Traffic) Now() (up int64, down int64) {
return t.upTotal, t.downTotal
}
func (t *Traffic) handle() {
go t.handleCh(t.up, &t.upCount, &t.upTotal)
go t.handleCh(t.down, &t.downCount, &t.downTotal)
}
func (t *Traffic) handleCh(ch <-chan int64, count *int64, total *int64) {
ticker := time.NewTicker(t.interval)
for {
select {
case n := <-ch:
*count += n
case <-ticker.C:
*total = *count
*count = 0
}
}
}
func NewTraffic(interval time.Duration) *Traffic {
t := &Traffic{
up: make(chan int64),
down: make(chan int64),
interval: interval,
}
go t.handle()
return t
}

91
hub/route/connections.go Normal file
View file

@ -0,0 +1,91 @@
package route
import (
"bytes"
"encoding/json"
"net/http"
"strconv"
"time"
T "github.com/Dreamacro/clash/tunnel"
"github.com/gorilla/websocket"
"github.com/go-chi/chi"
"github.com/go-chi/render"
)
func connectionRouter() http.Handler {
r := chi.NewRouter()
r.Get("/", getConnections)
r.Delete("/", closeAllConnections)
r.Delete("/{id}", closeConnection)
return r
}
func getConnections(w http.ResponseWriter, r *http.Request) {
if !websocket.IsWebSocketUpgrade(r) {
snapshot := T.DefaultManager.Snapshot()
render.JSON(w, r, snapshot)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
intervalStr := r.URL.Query().Get("interval")
interval := 1000
if intervalStr != "" {
t, err := strconv.Atoi(intervalStr)
if err != nil {
render.Status(r, http.StatusBadRequest)
render.JSON(w, r, ErrBadRequest)
return
}
interval = t
}
buf := &bytes.Buffer{}
sendSnapshot := func() error {
buf.Reset()
snapshot := T.DefaultManager.Snapshot()
if err := json.NewEncoder(buf).Encode(snapshot); err != nil {
return err
}
return conn.WriteMessage(websocket.TextMessage, buf.Bytes())
}
if err := sendSnapshot(); err != nil {
return
}
tick := time.NewTicker(time.Millisecond * time.Duration(interval))
for range tick.C {
if err := sendSnapshot(); err != nil {
break
}
}
}
func closeConnection(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
snapshot := T.DefaultManager.Snapshot()
for _, c := range snapshot.Connections {
if id == c.ID() {
c.Close()
break
}
}
render.NoContent(w, r)
}
func closeAllConnections(w http.ResponseWriter, r *http.Request) {
snapshot := T.DefaultManager.Snapshot()
for _, c := range snapshot.Connections {
c.Close()
}
render.NoContent(w, r)
}

View file

@ -67,6 +67,7 @@ func Start(addr string, secret string) {
r.Mount("/configs", configRouter()) r.Mount("/configs", configRouter())
r.Mount("/proxies", proxyRouter()) r.Mount("/proxies", proxyRouter())
r.Mount("/rules", ruleRouter()) r.Mount("/rules", ruleRouter())
r.Mount("/connections", connectionRouter())
}) })
if uiPath != "" { if uiPath != "" {
@ -140,7 +141,7 @@ func traffic(w http.ResponseWriter, r *http.Request) {
} }
tick := time.NewTicker(time.Second) tick := time.NewTicker(time.Second)
t := T.Instance().Traffic() t := T.DefaultManager
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
var err error var err error
for range tick.C { for range tick.C {

View file

@ -27,7 +27,7 @@ func (g *GEOIP) IsMatch(metadata *C.Metadata) bool {
if metadata.DstIP == nil { if metadata.DstIP == nil {
return false return false
} }
record, _ := mmdb.Country(*metadata.DstIP) record, _ := mmdb.Country(metadata.DstIP)
return record.Country.IsoCode == g.country return record.Country.IsoCode == g.country
} }

View file

@ -24,7 +24,7 @@ func (i *IPCIDR) IsMatch(metadata *C.Metadata) bool {
if i.isSourceIP { if i.isSourceIP {
ip = metadata.SrcIP ip = metadata.SrcIP
} }
return ip != nil && i.ipnet.Contains(*ip) return ip != nil && i.ipnet.Contains(ip)
} }
func (i *IPCIDR) Adapter() string { func (i *IPCIDR) Adapter() string {

View file

@ -13,12 +13,11 @@ import (
) )
func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) {
conn := newTrafficTrack(outbound, t.traffic)
req := request.R req := request.R
host := req.Host host := req.Host
inboundReeder := bufio.NewReader(request) inboundReeder := bufio.NewReader(request)
outboundReeder := bufio.NewReader(conn) outboundReeder := bufio.NewReader(outbound)
for { for {
keepAlive := strings.TrimSpace(strings.ToLower(req.Header.Get("Proxy-Connection"))) == "keep-alive" keepAlive := strings.TrimSpace(strings.ToLower(req.Header.Get("Proxy-Connection"))) == "keep-alive"
@ -26,7 +25,7 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) {
req.Header.Set("Connection", "close") req.Header.Set("Connection", "close")
req.RequestURI = "" req.RequestURI = ""
adapters.RemoveHopByHopHeaders(req.Header) adapters.RemoveHopByHopHeaders(req.Header)
err := req.Write(conn) err := req.Write(outbound)
if err != nil { if err != nil {
break break
} }
@ -91,7 +90,7 @@ func (t *Tunnel) handleUDPToRemote(conn net.Conn, pc net.PacketConn, addr net.Ad
if _, err = pc.WriteTo(buf[:n], addr); err != nil { if _, err = pc.WriteTo(buf[:n], addr); err != nil {
return return
} }
t.traffic.Up() <- int64(n) DefaultManager.Upload() <- int64(n)
} }
func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, timeout time.Duration) { func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, timeout time.Duration) {
@ -111,13 +110,12 @@ func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string,
if err != nil { if err != nil {
return return
} }
t.traffic.Down() <- int64(n) DefaultManager.Download() <- int64(n)
} }
} }
func (t *Tunnel) handleSocket(request *adapters.SocketAdapter, outbound net.Conn) { func (t *Tunnel) handleSocket(request *adapters.SocketAdapter, outbound net.Conn) {
conn := newTrafficTrack(outbound, t.traffic) relay(request, outbound)
relay(request, conn)
} }
// relay copies between left and right bidirectionally. // relay copies between left and right bidirectionally.

87
tunnel/manager.go Normal file
View file

@ -0,0 +1,87 @@
package tunnel
import (
"sync"
"time"
)
var DefaultManager *Manager
func init() {
DefaultManager = &Manager{
upload: make(chan int64),
download: make(chan int64),
}
DefaultManager.handle()
}
type Manager struct {
connections sync.Map
upload chan int64
download chan int64
uploadTemp int64
downloadTemp int64
uploadBlip int64
downloadBlip int64
uploadTotal int64
downloadTotal int64
}
func (m *Manager) Join(c tracker) {
m.connections.Store(c.ID(), c)
}
func (m *Manager) Leave(c tracker) {
m.connections.Delete(c.ID())
}
func (m *Manager) Upload() chan<- int64 {
return m.upload
}
func (m *Manager) Download() chan<- int64 {
return m.download
}
func (m *Manager) Now() (up int64, down int64) {
return m.uploadBlip, m.downloadBlip
}
func (m *Manager) Snapshot() *Snapshot {
connections := []tracker{}
m.connections.Range(func(key, value interface{}) bool {
connections = append(connections, value.(tracker))
return true
})
return &Snapshot{
UploadTotal: m.uploadTotal,
DownloadTotal: m.downloadTotal,
Connections: connections,
}
}
func (m *Manager) handle() {
go m.handleCh(m.upload, &m.uploadTemp, &m.uploadBlip, &m.uploadTotal)
go m.handleCh(m.download, &m.downloadTemp, &m.downloadBlip, &m.downloadTotal)
}
func (m *Manager) handleCh(ch <-chan int64, temp *int64, blip *int64, total *int64) {
ticker := time.NewTicker(time.Second)
for {
select {
case n := <-ch:
*temp += n
*total += n
case <-ticker.C:
*blip = *temp
*temp = 0
}
}
}
type Snapshot struct {
DownloadTotal int64 `json:"downloadTotal"`
UploadTotal int64 `json:"uploadTotal"`
Connections []tracker `json:"connections"`
}

122
tunnel/tracker.go Normal file
View file

@ -0,0 +1,122 @@
package tunnel
import (
"net"
"time"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type tracker interface {
ID() string
Close() error
}
type trackerInfo struct {
UUID uuid.UUID `json:"id"`
Metadata *C.Metadata `json:"metadata"`
UploadTotal int64 `json:"upload"`
DownloadTotal int64 `json:"download"`
Start time.Time `json:"start"`
Chain C.Chain `json:"chains"`
Rule string `json:"rule"`
}
type tcpTracker struct {
C.Conn `json:"-"`
*trackerInfo
manager *Manager
}
func (tt *tcpTracker) ID() string {
return tt.UUID.String()
}
func (tt *tcpTracker) Read(b []byte) (int, error) {
n, err := tt.Conn.Read(b)
download := int64(n)
tt.manager.Download() <- download
tt.DownloadTotal += download
return n, err
}
func (tt *tcpTracker) Write(b []byte) (int, error) {
n, err := tt.Conn.Write(b)
upload := int64(n)
tt.manager.Upload() <- upload
tt.UploadTotal += upload
return n, err
}
func (tt *tcpTracker) Close() error {
tt.manager.Leave(tt)
return tt.Conn.Close()
}
func newTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker {
uuid, _ := uuid.NewV4()
t := &tcpTracker{
Conn: conn,
manager: manager,
trackerInfo: &trackerInfo{
UUID: uuid,
Start: time.Now(),
Metadata: metadata,
Chain: conn.Chains(),
Rule: rule.RuleType().String(),
},
}
manager.Join(t)
return t
}
type udpTracker struct {
C.PacketConn `json:"-"`
*trackerInfo
manager *Manager
}
func (ut *udpTracker) ID() string {
return ut.UUID.String()
}
func (ut *udpTracker) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := ut.PacketConn.ReadFrom(b)
download := int64(n)
ut.manager.Download() <- download
ut.DownloadTotal += download
return n, addr, err
}
func (ut *udpTracker) WriteTo(b []byte, addr net.Addr) (int, error) {
n, err := ut.PacketConn.WriteTo(b, addr)
upload := int64(n)
ut.manager.Upload() <- upload
ut.UploadTotal += upload
return n, err
}
func (ut *udpTracker) Close() error {
ut.manager.Leave(ut)
return ut.PacketConn.Close()
}
func newUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker {
uuid, _ := uuid.NewV4()
ut := &udpTracker{
PacketConn: conn,
manager: manager,
trackerInfo: &trackerInfo{
UUID: uuid,
Start: time.Now(),
Metadata: metadata,
Chain: conn.Chains(),
Rule: rule.RuleType().String(),
},
}
manager.Join(ut)
return ut
}

View file

@ -30,8 +30,7 @@ type Tunnel struct {
natTable *nat.Table natTable *nat.Table
rules []C.Rule rules []C.Rule
proxies map[string]C.Proxy proxies map[string]C.Proxy
configMux *sync.RWMutex configMux sync.RWMutex
traffic *C.Traffic
// experimental features // experimental features
ignoreResolveFail bool ignoreResolveFail bool
@ -50,11 +49,6 @@ func (t *Tunnel) Add(req C.ServerAdapter) {
} }
} }
// Traffic return traffic of all connections
func (t *Tunnel) Traffic() *C.Traffic {
return t.traffic
}
// Rules return all rules // Rules return all rules
func (t *Tunnel) Rules() []C.Rule { func (t *Tunnel) Rules() []C.Rule {
return t.rules return t.rules
@ -123,7 +117,7 @@ func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool {
func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
// preprocess enhanced-mode metadata // preprocess enhanced-mode metadata
if t.needLookupIP(metadata) { if t.needLookupIP(metadata) {
host, exist := dns.DefaultResolver.IPToHost(*metadata.DstIP) host, exist := dns.DefaultResolver.IPToHost(metadata.DstIP)
if exist { if exist {
metadata.Host = host metadata.Host = host
metadata.AddrType = C.AtypDomainName metadata.AddrType = C.AtypDomainName
@ -188,8 +182,8 @@ func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) {
wg.Done() wg.Done()
return return
} }
pc = rawPc
addr = nAddr addr = nAddr
pc = newUDPTracker(rawPc, DefaultManager, metadata, rule)
if rule != nil { if rule != nil {
log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String()) log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String())
@ -231,6 +225,7 @@ func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) {
log.Warnln("dial %s error: %s", proxy.Name(), err.Error()) log.Warnln("dial %s error: %s", proxy.Name(), err.Error())
return return
} }
remoteConn = newTCPTracker(remoteConn, DefaultManager, metadata, rule)
defer remoteConn.Close() defer remoteConn.Close()
if rule != nil { if rule != nil {
@ -259,7 +254,7 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
if node := dns.DefaultHosts.Search(metadata.Host); node != nil { if node := dns.DefaultHosts.Search(metadata.Host); node != nil {
ip := node.Data.(net.IP) ip := node.Data.(net.IP)
metadata.DstIP = &ip metadata.DstIP = ip
resolved = true resolved = true
} }
@ -273,7 +268,7 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error())
} else { } else {
log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String()) log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String())
metadata.DstIP = &ip metadata.DstIP = ip
} }
resolved = true resolved = true
} }
@ -300,8 +295,6 @@ func newTunnel() *Tunnel {
udpQueue: channels.NewInfiniteChannel(), udpQueue: channels.NewInfiniteChannel(),
natTable: nat.New(), natTable: nat.New(),
proxies: make(map[string]C.Proxy), proxies: make(map[string]C.Proxy),
configMux: &sync.RWMutex{},
traffic: C.NewTraffic(time.Second),
mode: Rule, mode: Rule,
} }
} }

View file

@ -1,29 +0,0 @@
package tunnel
import (
"net"
C "github.com/Dreamacro/clash/constant"
)
// TrafficTrack record traffic of net.Conn
type TrafficTrack struct {
net.Conn
traffic *C.Traffic
}
func (tt *TrafficTrack) Read(b []byte) (int, error) {
n, err := tt.Conn.Read(b)
tt.traffic.Down() <- int64(n)
return n, err
}
func (tt *TrafficTrack) Write(b []byte) (int, error) {
n, err := tt.Conn.Write(b)
tt.traffic.Up() <- int64(n)
return n, err
}
func newTrafficTrack(conn net.Conn, traffic *C.Traffic) *TrafficTrack {
return &TrafficTrack{traffic: traffic, Conn: conn}
}