chore: use early conn to support real ws 0-rtt
This commit is contained in:
parent
a1d008e6f0
commit
75680c5866
13 changed files with 132 additions and 40 deletions
|
@ -53,6 +53,9 @@ func (rw *nopConn) Read(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *nopConn) Write(b []byte) (int, error) {
|
func (rw *nopConn) Write(b []byte) (int, error) {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -103,9 +103,9 @@ func (ss *ShadowSocks) streamConn(c net.Conn, metadata *C.Metadata) (net.Conn, e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if metadata.NetWork == C.UDP && ss.option.UDPOverTCP {
|
if metadata.NetWork == C.UDP && ss.option.UDPOverTCP {
|
||||||
return ss.method.DialConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443"))
|
return ss.method.DialEarlyConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443")), nil
|
||||||
}
|
}
|
||||||
return ss.method.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
return ss.method.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext implements C.ProxyAdapter
|
// DialContext implements C.ProxyAdapter
|
||||||
|
|
|
@ -213,12 +213,12 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
|
||||||
}
|
}
|
||||||
if metadata.NetWork == C.UDP {
|
if metadata.NetWork == C.UDP {
|
||||||
if v.option.XUDP {
|
if v.option.XUDP {
|
||||||
return v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
return v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
|
||||||
} else {
|
} else {
|
||||||
return v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
return v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return v.client.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
return v.client.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,9 +289,9 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
|
||||||
}(c)
|
}(c)
|
||||||
|
|
||||||
if v.option.XUDP {
|
if v.option.XUDP {
|
||||||
c, err = v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
c = v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||||
} else {
|
} else {
|
||||||
c, err = v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
c = v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/adapter/outbound"
|
"github.com/Dreamacro/clash/adapter/outbound"
|
||||||
|
"github.com/Dreamacro/clash/common/callback"
|
||||||
"github.com/Dreamacro/clash/component/dialer"
|
"github.com/Dreamacro/clash/component/dialer"
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
"github.com/Dreamacro/clash/constant/provider"
|
"github.com/Dreamacro/clash/constant/provider"
|
||||||
|
@ -30,10 +31,20 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts .
|
||||||
c, err := proxy.DialContext(ctx, metadata, f.Base.DialOptions(opts...)...)
|
c, err := proxy.DialContext(ctx, metadata, f.Base.DialOptions(opts...)...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.AppendToChains(f)
|
c.AppendToChains(f)
|
||||||
|
} else {
|
||||||
|
f.onDialFailed(proxy.Type(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c = &callback.FirstWriteCallBackConn{
|
||||||
|
Conn: c,
|
||||||
|
Callback: func(err error) {
|
||||||
|
if err == nil {
|
||||||
f.onDialSuccess()
|
f.onDialSuccess()
|
||||||
} else {
|
} else {
|
||||||
f.onDialFailed(proxy.Type(), err)
|
f.onDialFailed(proxy.Type(), err)
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/adapter/outbound"
|
"github.com/Dreamacro/clash/adapter/outbound"
|
||||||
"github.com/Dreamacro/clash/common/cache"
|
"github.com/Dreamacro/clash/common/cache"
|
||||||
|
"github.com/Dreamacro/clash/common/callback"
|
||||||
"github.com/Dreamacro/clash/common/murmur3"
|
"github.com/Dreamacro/clash/common/murmur3"
|
||||||
"github.com/Dreamacro/clash/component/dialer"
|
"github.com/Dreamacro/clash/component/dialer"
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
|
@ -83,17 +84,24 @@ func jumpHash(key uint64, buckets int32) int32 {
|
||||||
// DialContext implements C.ProxyAdapter
|
// DialContext implements C.ProxyAdapter
|
||||||
func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
|
func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
|
||||||
proxy := lb.Unwrap(metadata, true)
|
proxy := lb.Unwrap(metadata, true)
|
||||||
|
c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.AppendToChains(lb)
|
c.AppendToChains(lb)
|
||||||
|
} else {
|
||||||
|
lb.onDialFailed(proxy.Type(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c = &callback.FirstWriteCallBackConn{
|
||||||
|
Conn: c,
|
||||||
|
Callback: func(err error) {
|
||||||
|
if err == nil {
|
||||||
lb.onDialSuccess()
|
lb.onDialSuccess()
|
||||||
} else {
|
} else {
|
||||||
lb.onDialFailed(proxy.Type(), err)
|
lb.onDialFailed(proxy.Type(), err)
|
||||||
}
|
}
|
||||||
}()
|
},
|
||||||
|
}
|
||||||
c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/adapter/outbound"
|
"github.com/Dreamacro/clash/adapter/outbound"
|
||||||
|
"github.com/Dreamacro/clash/common/callback"
|
||||||
"github.com/Dreamacro/clash/common/singledo"
|
"github.com/Dreamacro/clash/common/singledo"
|
||||||
"github.com/Dreamacro/clash/component/dialer"
|
"github.com/Dreamacro/clash/component/dialer"
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
|
@ -38,10 +39,20 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ..
|
||||||
c, err = proxy.DialContext(ctx, metadata, u.Base.DialOptions(opts...)...)
|
c, err = proxy.DialContext(ctx, metadata, u.Base.DialOptions(opts...)...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.AppendToChains(u)
|
c.AppendToChains(u)
|
||||||
|
} else {
|
||||||
|
u.onDialFailed(proxy.Type(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c = &callback.FirstWriteCallBackConn{
|
||||||
|
Conn: c,
|
||||||
|
Callback: func(err error) {
|
||||||
|
if err == nil {
|
||||||
u.onDialSuccess()
|
u.onDialSuccess()
|
||||||
} else {
|
} else {
|
||||||
u.onDialFailed(proxy.Type(), err)
|
u.onDialFailed(proxy.Type(), err)
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
25
common/callback/callback.go
Normal file
25
common/callback/callback.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package callback
|
||||||
|
|
||||||
|
import (
|
||||||
|
C "github.com/Dreamacro/clash/constant"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FirstWriteCallBackConn struct {
|
||||||
|
C.Conn
|
||||||
|
Callback func(error)
|
||||||
|
written bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *FirstWriteCallBackConn) Write(b []byte) (n int, err error) {
|
||||||
|
defer func() {
|
||||||
|
if !c.written {
|
||||||
|
c.written = true
|
||||||
|
c.Callback(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return c.Conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *FirstWriteCallBackConn) Upstream() any {
|
||||||
|
return c.Conn
|
||||||
|
}
|
|
@ -12,13 +12,14 @@ var _ ExtendedConn = (*BufferedConn)(nil)
|
||||||
type BufferedConn struct {
|
type BufferedConn struct {
|
||||||
r *bufio.Reader
|
r *bufio.Reader
|
||||||
ExtendedConn
|
ExtendedConn
|
||||||
|
peeked bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBufferedConn(c net.Conn) *BufferedConn {
|
func NewBufferedConn(c net.Conn) *BufferedConn {
|
||||||
if bc, ok := c.(*BufferedConn); ok {
|
if bc, ok := c.(*BufferedConn); ok {
|
||||||
return bc
|
return bc
|
||||||
}
|
}
|
||||||
return &BufferedConn{bufio.NewReader(c), NewExtendedConn(c)}
|
return &BufferedConn{bufio.NewReader(c), NewExtendedConn(c), false}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reader returns the internal bufio.Reader.
|
// Reader returns the internal bufio.Reader.
|
||||||
|
@ -26,11 +27,20 @@ func (c *BufferedConn) Reader() *bufio.Reader {
|
||||||
return c.r
|
return c.r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *BufferedConn) Peeked() bool {
|
||||||
|
return c.peeked
|
||||||
|
}
|
||||||
|
|
||||||
// Peek returns the next n bytes without advancing the reader.
|
// Peek returns the next n bytes without advancing the reader.
|
||||||
func (c *BufferedConn) Peek(n int) ([]byte, error) {
|
func (c *BufferedConn) Peek(n int) ([]byte, error) {
|
||||||
|
c.peeked = true
|
||||||
return c.r.Peek(n)
|
return c.r.Peek(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *BufferedConn) Discard(n int) (discarded int, err error) {
|
||||||
|
return c.r.Discard(n)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *BufferedConn) Read(p []byte) (int, error) {
|
func (c *BufferedConn) Read(p []byte) (int, error) {
|
||||||
return c.r.Read(p)
|
return c.r.Read(p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,12 +36,7 @@ type SnifferDispatcher struct {
|
||||||
parsePureIp bool
|
parsePureIp bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
|
func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) {
|
||||||
bufConn, ok := conn.(*N.BufferedConn)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if (metadata.Host == "" && sd.parsePureIp) || sd.forceDomain.Search(metadata.Host) != nil || (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping) {
|
if (metadata.Host == "" && sd.parsePureIp) || sd.forceDomain.Search(metadata.Host) != nil || (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping) {
|
||||||
port, err := strconv.ParseUint(metadata.DstPort, 10, 16)
|
port, err := strconv.ParseUint(metadata.DstPort, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -74,7 +69,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
|
||||||
}
|
}
|
||||||
sd.rwMux.RUnlock()
|
sd.rwMux.RUnlock()
|
||||||
|
|
||||||
if host, err := sd.sniffDomain(bufConn, metadata); err != nil {
|
if host, err := sd.sniffDomain(conn, metadata); err != nil {
|
||||||
sd.cacheSniffFailed(metadata)
|
sd.cacheSniffFailed(metadata)
|
||||||
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
|
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
|
||||||
return
|
return
|
||||||
|
|
|
@ -3,6 +3,8 @@ package constant
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
N "github.com/Dreamacro/clash/common/net"
|
||||||
|
|
||||||
"github.com/gofrs/uuid"
|
"github.com/gofrs/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,7 +15,7 @@ type PlainContext interface {
|
||||||
type ConnContext interface {
|
type ConnContext interface {
|
||||||
PlainContext
|
PlainContext
|
||||||
Metadata() *Metadata
|
Metadata() *Metadata
|
||||||
Conn() net.Conn
|
Conn() *N.BufferedConn
|
||||||
}
|
}
|
||||||
|
|
||||||
type PacketConnContext interface {
|
type PacketConnContext interface {
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
type ConnContext struct {
|
type ConnContext struct {
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
metadata *C.Metadata
|
metadata *C.Metadata
|
||||||
conn net.Conn
|
conn *N.BufferedConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
|
func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
|
||||||
|
@ -36,6 +36,6 @@ func (c *ConnContext) Metadata() *C.Metadata {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn implement C.ConnContext Conn
|
// Conn implement C.ConnContext Conn
|
||||||
func (c *ConnContext) Conn() net.Conn {
|
func (c *ConnContext) Conn() *N.BufferedConn {
|
||||||
return c.conn
|
return c.conn
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/common/buf"
|
"github.com/Dreamacro/clash/common/buf"
|
||||||
N "github.com/Dreamacro/clash/common/net"
|
N "github.com/Dreamacro/clash/common/net"
|
||||||
|
@ -208,12 +207,12 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
//go func() {
|
||||||
select {
|
// select {
|
||||||
case <-c.handshake:
|
// case <-c.handshake:
|
||||||
case <-time.After(200 * time.Millisecond):
|
// case <-time.After(200 * time.Millisecond):
|
||||||
c.sendRequest(nil)
|
// c.sendRequest(nil)
|
||||||
}
|
// }
|
||||||
}()
|
//}()
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -366,8 +366,20 @@ func handleTCPConn(connCtx C.ConnContext) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn := connCtx.Conn()
|
||||||
if sniffer.Dispatcher.Enable() && sniffingEnable {
|
if sniffer.Dispatcher.Enable() && sniffingEnable {
|
||||||
sniffer.Dispatcher.TCPSniff(connCtx.Conn(), metadata)
|
sniffer.Dispatcher.TCPSniff(conn, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
peekMutex := sync.Mutex{}
|
||||||
|
if !conn.Peeked() {
|
||||||
|
peekMutex.Lock()
|
||||||
|
go func() {
|
||||||
|
defer peekMutex.Unlock()
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||||
|
_, _ = conn.Peek(1)
|
||||||
|
_ = conn.SetReadDeadline(time.Time{})
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, rule, err := resolveMetadata(connCtx, metadata)
|
proxy, rule, err := resolveMetadata(connCtx, metadata)
|
||||||
|
@ -387,10 +399,26 @@ func handleTCPConn(connCtx C.ConnContext) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var peekBytes []byte
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
remoteConn, err := retry(ctx, func(ctx context.Context) (C.Conn, error) {
|
remoteConn, err := retry(ctx, func(ctx context.Context) (C.Conn, error) {
|
||||||
return proxy.DialContext(ctx, dialMetadata)
|
remoteConn, err := proxy.DialContext(ctx, dialMetadata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
peekMutex.Lock()
|
||||||
|
defer peekMutex.Unlock()
|
||||||
|
peekBytes, _ = conn.Peek(conn.Buffered())
|
||||||
|
_, err = remoteConn.Write(peekBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if peekLen := len(peekBytes); peekLen > 0 {
|
||||||
|
_, _ = conn.Discard(peekLen)
|
||||||
|
}
|
||||||
|
return remoteConn, err
|
||||||
}, func(err error) {
|
}, func(err error) {
|
||||||
if rule == nil {
|
if rule == nil {
|
||||||
log.Warnln(
|
log.Warnln(
|
||||||
|
|
Loading…
Reference in a new issue