chore: better tuic earlyConn impl

This commit is contained in:
wwqgtxx 2023-05-10 09:36:06 +08:00
parent 67b9314693
commit 15a8d7c473

View file

@ -11,6 +11,7 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"github.com/Dreamacro/clash/common/buf" "github.com/Dreamacro/clash/common/buf"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
@ -289,7 +290,8 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta
return nil, err return nil, err
} }
conn := &earlyConn{BufferedConn: N.NewBufferedConn(stream), RequestTimeout: t.RequestTimeout} bufConn := N.NewBufferedConn(stream)
conn := &earlyConn{ExtendedConn: bufConn, bufConn: bufConn, RequestTimeout: t.RequestTimeout}
if !t.FastOpen { if !t.FastOpen {
err = conn.Response() err = conn.Response()
if err != nil { if err != nil {
@ -300,22 +302,19 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta
} }
type earlyConn struct { type earlyConn struct {
*N.BufferedConn N.ExtendedConn // only expose standard N.ExtendedConn function to outside
bufConn *N.BufferedConn
resOnce sync.Once resOnce sync.Once
resErr error resErr error
RequestTimeout time.Duration RequestTimeout time.Duration
} }
func (conn *earlyConn) ReaderReplaceable() bool {
return false
}
func (conn *earlyConn) response() error { func (conn *earlyConn) response() error {
if conn.RequestTimeout > 0 { if conn.RequestTimeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(conn.RequestTimeout)) _ = conn.SetReadDeadline(time.Now().Add(conn.RequestTimeout))
} }
response, err := ReadResponse(conn) response, err := ReadResponse(conn.bufConn)
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
return err return err
@ -340,7 +339,7 @@ func (conn *earlyConn) Read(b []byte) (n int, err error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
return conn.BufferedConn.Read(b) return conn.bufConn.Read(b)
} }
func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) { func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) {
@ -348,7 +347,19 @@ func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) {
if err != nil { if err != nil {
return err return err
} }
return conn.BufferedConn.ReadBuffer(buffer) return conn.bufConn.ReadBuffer(buffer)
}
func (conn *earlyConn) Upstream() any {
return conn.bufConn
}
func (conn *earlyConn) ReaderReplaceable() bool {
return atomic.LoadUint32((*uint32)(unsafe.Pointer(&conn.resOnce))) == 1 && conn.resErr == nil
}
func (conn *earlyConn) WriterReplaceable() bool {
return true
} }
func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) { func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) {