chore: simplify fast open code

This commit is contained in:
wwqgtxx 2023-11-30 20:00:24 +08:00
parent db973de7bd
commit 599ce784d2
4 changed files with 98 additions and 134 deletions

67
common/net/earlyconn.go Normal file
View file

@ -0,0 +1,67 @@
package net
import (
"net"
"sync"
"sync/atomic"
"unsafe"
"github.com/metacubex/mihomo/common/buf"
)
type earlyConn struct {
ExtendedConn // only expose standard N.ExtendedConn function to outside
resFunc func() error
resOnce sync.Once
resErr error
}
func (conn *earlyConn) Response() error {
conn.resOnce.Do(func() {
conn.resErr = conn.resFunc()
})
return conn.resErr
}
func (conn *earlyConn) Read(b []byte) (n int, err error) {
err = conn.Response()
if err != nil {
return 0, err
}
return conn.ExtendedConn.Read(b)
}
func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) {
err = conn.Response()
if err != nil {
return err
}
return conn.ExtendedConn.ReadBuffer(buffer)
}
func (conn *earlyConn) Upstream() any {
return conn.ExtendedConn
}
func (conn *earlyConn) Success() bool {
// atomic visit sync.Once.done
return atomic.LoadUint32((*uint32)(unsafe.Pointer(&conn.resOnce))) == 1 && conn.resErr == nil
}
func (conn *earlyConn) ReaderReplaceable() bool {
return conn.Success()
}
func (conn *earlyConn) ReaderPossiblyReplaceable() bool {
return !conn.Success()
}
func (conn *earlyConn) WriterReplaceable() bool {
return true
}
var _ ExtendedConn = (*earlyConn)(nil)
func NewEarlyConn(c net.Conn, f func() error) net.Conn {
return &earlyConn{ExtendedConn: NewExtendedConn(c), resFunc: f}
}

View file

@ -11,10 +11,8 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
atomic2 "github.com/metacubex/mihomo/common/atomic" atomic2 "github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/buf"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/common/pool"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
@ -329,75 +327,30 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta
} }
bufConn := N.NewBufferedConn(stream) bufConn := N.NewBufferedConn(stream)
conn := &earlyConn{ExtendedConn: bufConn, bufConn: bufConn, RequestTimeout: t.RequestTimeout} response := func() error {
if !t.FastOpen { if t.RequestTimeout > 0 {
err = conn.Response() _ = bufConn.SetReadDeadline(time.Now().Add(t.RequestTimeout))
}
response, err := ReadResponse(bufConn)
if err != nil { if err != nil {
return nil, err _ = bufConn.Close()
}
}
return conn, nil
}
type earlyConn struct {
N.ExtendedConn // only expose standard N.ExtendedConn function to outside
bufConn *N.BufferedConn
resOnce sync.Once
resErr error
RequestTimeout time.Duration
}
func (conn *earlyConn) response() error {
if conn.RequestTimeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(conn.RequestTimeout))
}
response, err := ReadResponse(conn.bufConn)
if err != nil {
_ = conn.Close()
return err return err
} }
if response.IsFailed() { if response.IsFailed() {
_ = conn.Close() _ = bufConn.Close()
return errors.New("connect failed") return errors.New("connect failed")
} }
_ = conn.SetReadDeadline(time.Time{}) _ = bufConn.SetReadDeadline(time.Time{})
return nil return nil
}
func (conn *earlyConn) Response() error {
conn.resOnce.Do(func() {
conn.resErr = conn.response()
})
return conn.resErr
}
func (conn *earlyConn) Read(b []byte) (n int, err error) {
err = conn.Response()
if err != nil {
return 0, err
} }
return conn.bufConn.Read(b) if t.FastOpen {
} return N.NewEarlyConn(bufConn, response), nil
func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) {
err = conn.Response()
if err != nil {
return err
} }
return conn.bufConn.ReadBuffer(buffer) err = response()
} if err != nil {
return nil, err
func (conn *earlyConn) Upstream() any { }
return conn.bufConn return bufConn, nil
}
func (conn *earlyConn) ReaderReplaceable() bool {
return atomic.LoadUint32((*uint32)(unsafe.Pointer(&conn.resOnce))) == 1 && conn.resErr == nil
}
func (conn *earlyConn) WriterReplaceable() bool {
return true
} }
func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {

View file

@ -1,65 +0,0 @@
package vmess
import (
"fmt"
"net/http"
"strings"
"sync"
"github.com/metacubex/mihomo/common/buf"
"github.com/metacubex/mihomo/common/net"
)
type httpUpgradeEarlyConn struct {
*net.BufferedConn
create sync.Once
done bool
err error
}
func (c *httpUpgradeEarlyConn) readResponse() {
var request http.Request
response, err := http.ReadResponse(c.Reader(), &request)
c.done = true
if err != nil {
c.err = err
return
}
if response.StatusCode != http.StatusSwitchingProtocols ||
!strings.EqualFold(response.Header.Get("Connection"), "upgrade") ||
!strings.EqualFold(response.Header.Get("Upgrade"), "websocket") {
c.err = fmt.Errorf("unexpected status: %s", response.Status)
return
}
}
func (c *httpUpgradeEarlyConn) Read(p []byte) (int, error) {
c.create.Do(c.readResponse)
if c.err != nil {
return 0, c.err
}
return c.BufferedConn.Read(p)
}
func (c *httpUpgradeEarlyConn) ReadBuffer(buffer *buf.Buffer) error {
c.create.Do(c.readResponse)
if c.err != nil {
return c.err
}
return c.BufferedConn.ReadBuffer(buffer)
}
func (c *httpUpgradeEarlyConn) ReaderReplaceable() bool {
return c.done
}
func (c *httpUpgradeEarlyConn) ReaderPossiblyReplaceable() bool {
return !c.done
}
func (c *httpUpgradeEarlyConn) ReadCached() *buf.Buffer {
if c.done {
return c.BufferedConn.ReadCached()
}
return nil
}

View file

@ -418,9 +418,18 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig,
bufferedConn := N.NewBufferedConn(conn) bufferedConn := N.NewBufferedConn(conn)
if c.V2rayHttpUpgrade && c.V2rayHttpUpgradeFastOpen { if c.V2rayHttpUpgrade && c.V2rayHttpUpgradeFastOpen {
return &httpUpgradeEarlyConn{ return N.NewEarlyConn(bufferedConn, func() error {
BufferedConn: bufferedConn, response, err := http.ReadResponse(bufferedConn.Reader(), request)
}, nil if err != nil {
return err
}
if response.StatusCode != http.StatusSwitchingProtocols ||
!strings.EqualFold(response.Header.Get("Connection"), "upgrade") ||
!strings.EqualFold(response.Header.Get("Upgrade"), "websocket") {
return fmt.Errorf("unexpected status: %s", response.Status)
}
return nil
}), nil
} }
response, err := http.ReadResponse(bufferedConn.Reader(), request) response, err := http.ReadResponse(bufferedConn.Reader(), request)