diff --git a/transport/tuic/client.go b/transport/tuic/client.go index 583eaba7..c18a9049 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -11,6 +11,7 @@ import ( "sync" "sync/atomic" "time" + "unsafe" "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" @@ -289,7 +290,8 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta return nil, err } - conn := &earlyConn{BufferedConn: N.NewBufferedConn(stream), RequestTimeout: t.RequestTimeout} + bufConn := N.NewBufferedConn(stream) + conn := &earlyConn{ExtendedConn: bufConn, bufConn: bufConn, RequestTimeout: t.RequestTimeout} if !t.FastOpen { err = conn.Response() if err != nil { @@ -300,22 +302,19 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta } type earlyConn struct { - *N.BufferedConn - resOnce sync.Once - resErr error + N.ExtendedConn // only expose standard N.ExtendedConn function to outside + bufConn *N.BufferedConn + resOnce sync.Once + resErr error RequestTimeout time.Duration } -func (conn *earlyConn) ReaderReplaceable() bool { - return false -} - func (conn *earlyConn) response() error { if conn.RequestTimeout > 0 { _ = conn.SetReadDeadline(time.Now().Add(conn.RequestTimeout)) } - response, err := ReadResponse(conn) + response, err := ReadResponse(conn.bufConn) if err != nil { _ = conn.Close() return err @@ -340,7 +339,7 @@ func (conn *earlyConn) Read(b []byte) (n int, err error) { if err != nil { return 0, err } - return conn.BufferedConn.Read(b) + return conn.bufConn.Read(b) } func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) { @@ -348,7 +347,19 @@ func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) { if err != nil { return err } - return conn.BufferedConn.ReadBuffer(buffer) + return conn.bufConn.ReadBuffer(buffer) +} + +func (conn *earlyConn) Upstream() any { + return conn.bufConn +} + +func (conn *earlyConn) ReaderReplaceable() bool { + return atomic.LoadUint32((*uint32)(unsafe.Pointer(&conn.resOnce))) == 1 && conn.resErr == nil +} + +func (conn *earlyConn) WriterReplaceable() bool { + return true } func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) {