chore: use early conn to support real ws 0-rtt
This commit is contained in:
parent
a1d008e6f0
commit
75680c5866
13 changed files with 132 additions and 40 deletions
|
@ -53,6 +53,9 @@ func (rw *nopConn) Read(b []byte) (int, error) {
|
|||
}
|
||||
|
||||
func (rw *nopConn) Write(b []byte) (int, error) {
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
|
|
|
@ -103,9 +103,9 @@ func (ss *ShadowSocks) streamConn(c net.Conn, metadata *C.Metadata) (net.Conn, e
|
|||
}
|
||||
}
|
||||
if metadata.NetWork == C.UDP && ss.option.UDPOverTCP {
|
||||
return ss.method.DialConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443"))
|
||||
return ss.method.DialEarlyConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443")), nil
|
||||
}
|
||||
return ss.method.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||
return ss.method.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
|
||||
}
|
||||
|
||||
// DialContext implements C.ProxyAdapter
|
||||
|
|
|
@ -213,12 +213,12 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
|
|||
}
|
||||
if metadata.NetWork == C.UDP {
|
||||
if v.option.XUDP {
|
||||
return v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||
return v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
|
||||
} else {
|
||||
return v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||
return v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
|
||||
}
|
||||
} else {
|
||||
return v.client.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||
return v.client.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -289,9 +289,9 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
|
|||
}(c)
|
||||
|
||||
if v.option.XUDP {
|
||||
c, err = v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||
c = v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||
} else {
|
||||
c, err = v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||
c = v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/Dreamacro/clash/adapter/outbound"
|
||||
"github.com/Dreamacro/clash/common/callback"
|
||||
"github.com/Dreamacro/clash/component/dialer"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/constant/provider"
|
||||
|
@ -30,10 +31,20 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts .
|
|||
c, err := proxy.DialContext(ctx, metadata, f.Base.DialOptions(opts...)...)
|
||||
if err == nil {
|
||||
c.AppendToChains(f)
|
||||
} else {
|
||||
f.onDialFailed(proxy.Type(), err)
|
||||
}
|
||||
|
||||
c = &callback.FirstWriteCallBackConn{
|
||||
Conn: c,
|
||||
Callback: func(err error) {
|
||||
if err == nil {
|
||||
f.onDialSuccess()
|
||||
} else {
|
||||
f.onDialFailed(proxy.Type(), err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return c, err
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/Dreamacro/clash/adapter/outbound"
|
||||
"github.com/Dreamacro/clash/common/cache"
|
||||
"github.com/Dreamacro/clash/common/callback"
|
||||
"github.com/Dreamacro/clash/common/murmur3"
|
||||
"github.com/Dreamacro/clash/component/dialer"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
|
@ -83,17 +84,24 @@ func jumpHash(key uint64, buckets int32) int32 {
|
|||
// DialContext implements C.ProxyAdapter
|
||||
func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
|
||||
proxy := lb.Unwrap(metadata, true)
|
||||
c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
c.AppendToChains(lb)
|
||||
} else {
|
||||
lb.onDialFailed(proxy.Type(), err)
|
||||
}
|
||||
|
||||
c = &callback.FirstWriteCallBackConn{
|
||||
Conn: c,
|
||||
Callback: func(err error) {
|
||||
if err == nil {
|
||||
lb.onDialSuccess()
|
||||
} else {
|
||||
lb.onDialFailed(proxy.Type(), err)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/Dreamacro/clash/adapter/outbound"
|
||||
"github.com/Dreamacro/clash/common/callback"
|
||||
"github.com/Dreamacro/clash/common/singledo"
|
||||
"github.com/Dreamacro/clash/component/dialer"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
|
@ -38,10 +39,20 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ..
|
|||
c, err = proxy.DialContext(ctx, metadata, u.Base.DialOptions(opts...)...)
|
||||
if err == nil {
|
||||
c.AppendToChains(u)
|
||||
} else {
|
||||
u.onDialFailed(proxy.Type(), err)
|
||||
}
|
||||
|
||||
c = &callback.FirstWriteCallBackConn{
|
||||
Conn: c,
|
||||
Callback: func(err error) {
|
||||
if err == nil {
|
||||
u.onDialSuccess()
|
||||
} else {
|
||||
u.onDialFailed(proxy.Type(), err)
|
||||
}
|
||||
},
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
|
|
25
common/callback/callback.go
Normal file
25
common/callback/callback.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package callback
|
||||
|
||||
import (
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
)
|
||||
|
||||
type FirstWriteCallBackConn struct {
|
||||
C.Conn
|
||||
Callback func(error)
|
||||
written bool
|
||||
}
|
||||
|
||||
func (c *FirstWriteCallBackConn) Write(b []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if !c.written {
|
||||
c.written = true
|
||||
c.Callback(err)
|
||||
}
|
||||
}()
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *FirstWriteCallBackConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
|
@ -12,13 +12,14 @@ var _ ExtendedConn = (*BufferedConn)(nil)
|
|||
type BufferedConn struct {
|
||||
r *bufio.Reader
|
||||
ExtendedConn
|
||||
peeked bool
|
||||
}
|
||||
|
||||
func NewBufferedConn(c net.Conn) *BufferedConn {
|
||||
if bc, ok := c.(*BufferedConn); ok {
|
||||
return bc
|
||||
}
|
||||
return &BufferedConn{bufio.NewReader(c), NewExtendedConn(c)}
|
||||
return &BufferedConn{bufio.NewReader(c), NewExtendedConn(c), false}
|
||||
}
|
||||
|
||||
// Reader returns the internal bufio.Reader.
|
||||
|
@ -26,11 +27,20 @@ func (c *BufferedConn) Reader() *bufio.Reader {
|
|||
return c.r
|
||||
}
|
||||
|
||||
func (c *BufferedConn) Peeked() bool {
|
||||
return c.peeked
|
||||
}
|
||||
|
||||
// Peek returns the next n bytes without advancing the reader.
|
||||
func (c *BufferedConn) Peek(n int) ([]byte, error) {
|
||||
c.peeked = true
|
||||
return c.r.Peek(n)
|
||||
}
|
||||
|
||||
func (c *BufferedConn) Discard(n int) (discarded int, err error) {
|
||||
return c.r.Discard(n)
|
||||
}
|
||||
|
||||
func (c *BufferedConn) Read(p []byte) (int, error) {
|
||||
return c.r.Read(p)
|
||||
}
|
||||
|
|
|
@ -36,12 +36,7 @@ type SnifferDispatcher struct {
|
|||
parsePureIp bool
|
||||
}
|
||||
|
||||
func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
|
||||
bufConn, ok := conn.(*N.BufferedConn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) {
|
||||
if (metadata.Host == "" && sd.parsePureIp) || sd.forceDomain.Search(metadata.Host) != nil || (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping) {
|
||||
port, err := strconv.ParseUint(metadata.DstPort, 10, 16)
|
||||
if err != nil {
|
||||
|
@ -74,7 +69,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
|
|||
}
|
||||
sd.rwMux.RUnlock()
|
||||
|
||||
if host, err := sd.sniffDomain(bufConn, metadata); err != nil {
|
||||
if host, err := sd.sniffDomain(conn, metadata); err != nil {
|
||||
sd.cacheSniffFailed(metadata)
|
||||
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
|
||||
return
|
||||
|
|
|
@ -3,6 +3,8 @@ package constant
|
|||
import (
|
||||
"net"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
)
|
||||
|
||||
|
@ -13,7 +15,7 @@ type PlainContext interface {
|
|||
type ConnContext interface {
|
||||
PlainContext
|
||||
Metadata() *Metadata
|
||||
Conn() net.Conn
|
||||
Conn() *N.BufferedConn
|
||||
}
|
||||
|
||||
type PacketConnContext interface {
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
type ConnContext struct {
|
||||
id uuid.UUID
|
||||
metadata *C.Metadata
|
||||
conn net.Conn
|
||||
conn *N.BufferedConn
|
||||
}
|
||||
|
||||
func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
|
||||
|
@ -36,6 +36,6 @@ func (c *ConnContext) Metadata() *C.Metadata {
|
|||
}
|
||||
|
||||
// Conn implement C.ConnContext Conn
|
||||
func (c *ConnContext) Conn() net.Conn {
|
||||
func (c *ConnContext) Conn() *N.BufferedConn {
|
||||
return c.conn
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Dreamacro/clash/common/buf"
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
|
@ -208,12 +207,12 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
|
|||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-c.handshake:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
c.sendRequest(nil)
|
||||
}
|
||||
}()
|
||||
//go func() {
|
||||
// select {
|
||||
// case <-c.handshake:
|
||||
// case <-time.After(200 * time.Millisecond):
|
||||
// c.sendRequest(nil)
|
||||
// }
|
||||
//}()
|
||||
return c, nil
|
||||
}
|
||||
|
|
|
@ -366,8 +366,20 @@ func handleTCPConn(connCtx C.ConnContext) {
|
|||
return
|
||||
}
|
||||
|
||||
conn := connCtx.Conn()
|
||||
if sniffer.Dispatcher.Enable() && sniffingEnable {
|
||||
sniffer.Dispatcher.TCPSniff(connCtx.Conn(), metadata)
|
||||
sniffer.Dispatcher.TCPSniff(conn, metadata)
|
||||
}
|
||||
|
||||
peekMutex := sync.Mutex{}
|
||||
if !conn.Peeked() {
|
||||
peekMutex.Lock()
|
||||
go func() {
|
||||
defer peekMutex.Unlock()
|
||||
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
_, _ = conn.Peek(1)
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
}
|
||||
|
||||
proxy, rule, err := resolveMetadata(connCtx, metadata)
|
||||
|
@ -387,10 +399,26 @@ func handleTCPConn(connCtx C.ConnContext) {
|
|||
}
|
||||
}
|
||||
|
||||
var peekBytes []byte
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
|
||||
defer cancel()
|
||||
remoteConn, err := retry(ctx, func(ctx context.Context) (C.Conn, error) {
|
||||
return proxy.DialContext(ctx, dialMetadata)
|
||||
remoteConn, err := proxy.DialContext(ctx, dialMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peekMutex.Lock()
|
||||
defer peekMutex.Unlock()
|
||||
peekBytes, _ = conn.Peek(conn.Buffered())
|
||||
_, err = remoteConn.Write(peekBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if peekLen := len(peekBytes); peekLen > 0 {
|
||||
_, _ = conn.Discard(peekLen)
|
||||
}
|
||||
return remoteConn, err
|
||||
}, func(err error) {
|
||||
if rule == nil {
|
||||
log.Warnln(
|
||||
|
|
Loading…
Reference in a new issue