refactor: Implement extended IO

This commit is contained in:
H1JK 2023-01-16 09:42:03 +08:00
parent 8fa66c13a9
commit d1565bb46f
7 changed files with 219 additions and 39 deletions

View file

@ -4,12 +4,15 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/gofrs/uuid"
"net" "net"
"strings" "strings"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
) )
type Base struct { type Base struct {
@ -166,7 +169,7 @@ func NewBase(opt BaseOption) *Base {
} }
type conn struct { type conn struct {
net.Conn network.ExtendedConn
chain C.Chain chain C.Chain
actualRemoteDestination string actualRemoteDestination string
} }
@ -185,8 +188,15 @@ func (c *conn) AppendToChains(a C.ProxyAdapter) {
c.chain = append(c.chain, a.Name()) c.chain = append(c.chain, a.Name())
} }
func (c *conn) Upstream() any {
if wrapper, ok := c.ExtendedConn.(*bufio.ExtendedConnWrapper); ok {
return wrapper.Conn
}
return c.ExtendedConn
}
func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn { func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
return &conn{c, []string{a.Name()}, parseRemoteDestination(a.Addr())} return &conn{bufio.NewExtendedConn(c), []string{a.Name()}, parseRemoteDestination(a.Addr())}
} }
type packetConn struct { type packetConn struct {

View file

@ -14,6 +14,8 @@ import (
"github.com/Dreamacro/clash/transport/gun" "github.com/Dreamacro/clash/transport/gun"
"github.com/Dreamacro/clash/transport/trojan" "github.com/Dreamacro/clash/transport/trojan"
"github.com/Dreamacro/clash/transport/vless" "github.com/Dreamacro/clash/transport/vless"
"github.com/sagernet/sing/common/bufio"
) )
type Trojan struct { type Trojan struct {
@ -95,7 +97,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error)
return c, err return c, err
} }
err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)) err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
return c, err return bufio.NewExtendedConn(c), err
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter

View file

@ -3,18 +3,24 @@ package net
import ( import (
"bufio" "bufio"
"net" "net"
"github.com/sagernet/sing/common/buf"
sing_bufio "github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
) )
var _ network.ExtendedConn = (*BufferedConn)(nil)
type BufferedConn struct { type BufferedConn struct {
r *bufio.Reader r *bufio.Reader
net.Conn network.ExtendedConn
} }
func NewBufferedConn(c net.Conn) *BufferedConn { func NewBufferedConn(c net.Conn) *BufferedConn {
if bc, ok := c.(*BufferedConn); ok { if bc, ok := c.(*BufferedConn); ok {
return bc return bc
} }
return &BufferedConn{bufio.NewReader(c), c} return &BufferedConn{bufio.NewReader(c), sing_bufio.NewExtendedConn(c)}
} }
// Reader returns the internal bufio.Reader. // Reader returns the internal bufio.Reader.
@ -42,3 +48,18 @@ func (c *BufferedConn) UnreadByte() error {
func (c *BufferedConn) Buffered() int { func (c *BufferedConn) Buffered() int {
return c.r.Buffered() return c.r.Buffered()
} }
func (c *BufferedConn) ReadBuffer(buffer *buf.Buffer) (err error) {
if c.r.Buffered() > 0 {
_, err = buffer.ReadOnceFrom(c.r)
return
}
return c.ExtendedConn.ReadBuffer(buffer)
}
func (c *BufferedConn) Upstream() any {
if wrapper, ok := c.ExtendedConn.(*sing_bufio.ExtendedConnWrapper); ok {
return wrapper.Conn
}
return c.ExtendedConn
}

View file

@ -1,7 +1,6 @@
package vless package vless
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -9,12 +8,16 @@ import (
"net" "net"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
xtls "github.com/xtls/go" xtls "github.com/xtls/go"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
type Conn struct { type Conn struct {
net.Conn network.ExtendedConn
dst *DstAddr dst *DstAddr
id *uuid.UUID id *uuid.UUID
addons *Addons addons *Addons
@ -23,57 +26,82 @@ type Conn struct {
func (vc *Conn) Read(b []byte) (int, error) { func (vc *Conn) Read(b []byte) (int, error) {
if vc.received { if vc.received {
return vc.Conn.Read(b) return vc.ExtendedConn.Read(b)
} }
if err := vc.recvResponse(); err != nil { if err := vc.recvResponse(); err != nil {
return 0, err return 0, err
} }
vc.received = true vc.received = true
return vc.Conn.Read(b) return vc.ExtendedConn.Read(b)
} }
func (vc *Conn) sendRequest() error { func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
buf := &bytes.Buffer{} if vc.received {
return vc.ExtendedConn.ReadBuffer(buffer)
}
buf.WriteByte(Version) // protocol version if err := vc.recvResponse(); err != nil {
buf.Write(vc.id.Bytes()) // 16 bytes of uuid return err
}
vc.received = true
return vc.ExtendedConn.ReadBuffer(buffer)
}
func (vc *Conn) sendRequest() (err error) {
requestLen := 1 // protocol version
requestLen += 16 // UUID
requestLen += 1 // addons length
var addonsBytes []byte
if vc.addons != nil { if vc.addons != nil {
bytes, err := proto.Marshal(vc.addons) addonsBytes, err = proto.Marshal(vc.addons)
if err != nil { if err != nil {
return err return err
} }
buf.WriteByte(byte(len(bytes)))
buf.Write(bytes)
} else {
buf.WriteByte(0) // addon data length. 0 means no addon data
} }
requestLen += len(addonsBytes)
requestLen += 1 // command
if !vc.dst.Mux {
requestLen += 2 // port
requestLen += 1 // addr type
requestLen += len(vc.dst.Addr)
}
_buffer := buf.StackNewSize(requestLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
common.Must(
buffer.WriteByte(Version), // protocol version
common.Error(buffer.Write(vc.id.Bytes())), // 16 bytes of uuid
buffer.WriteByte(byte(len(addonsBytes))),
common.Error(buffer.Write(addonsBytes)),
)
if vc.dst.Mux { if vc.dst.Mux {
buf.WriteByte(CommandMux) common.Must(buffer.WriteByte(CommandMux))
} else { } else {
if vc.dst.UDP { if vc.dst.UDP {
buf.WriteByte(CommandUDP) common.Must(buffer.WriteByte(CommandUDP))
} else { } else {
buf.WriteByte(CommandTCP) common.Must(buffer.WriteByte(CommandTCP))
} }
// Port AddrType Addr binary.BigEndian.PutUint16(buffer.Extend(2), vc.dst.Port)
binary.Write(buf, binary.BigEndian, vc.dst.Port) common.Must(
buf.WriteByte(vc.dst.AddrType) buffer.WriteByte(vc.dst.AddrType),
buf.Write(vc.dst.Addr) common.Error(buffer.Write(vc.dst.Addr)),
)
} }
_, err := vc.Conn.Write(buf.Bytes()) _, err = vc.ExtendedConn.Write(buffer.Bytes())
return err return
} }
func (vc *Conn) recvResponse() error { func (vc *Conn) recvResponse() error {
var err error var err error
buf := make([]byte, 1) var buf [1]byte
_, err = io.ReadFull(vc.Conn, buf) _, err = io.ReadFull(vc.ExtendedConn, buf[:])
if err != nil { if err != nil {
return err return err
} }
@ -82,23 +110,30 @@ func (vc *Conn) recvResponse() error {
return errors.New("unexpected response version") return errors.New("unexpected response version")
} }
_, err = io.ReadFull(vc.Conn, buf) _, err = io.ReadFull(vc.ExtendedConn, buf[:])
if err != nil { if err != nil {
return err return err
} }
length := int64(buf[0]) length := int64(buf[0])
if length != 0 { // addon data length > 0 if length != 0 { // addon data length > 0
io.CopyN(io.Discard, vc.Conn, length) // just discard io.CopyN(io.Discard, vc.ExtendedConn, length) // just discard
} }
return nil return nil
} }
func (vc *Conn) Upstream() any {
if wrapper, ok := vc.ExtendedConn.(*bufio.ExtendedConnWrapper); ok {
return wrapper.Conn
}
return vc.ExtendedConn
}
// newConn return a Conn instance // newConn return a Conn instance
func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
c := &Conn{ c := &Conn{
Conn: conn, ExtendedConn: bufio.NewExtendedConn(conn),
id: client.uuid, id: client.uuid,
dst: dst, dst: dst,
} }

View file

@ -5,9 +5,11 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -15,15 +17,24 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
_ "unsafe"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
) )
//go:linkname maskBytes github.com/gorilla/websocket.maskBytes
func maskBytes(key [4]byte, pos int, b []byte) int
type websocketConn struct { type websocketConn struct {
conn *websocket.Conn conn *websocket.Conn
reader io.Reader reader io.Reader
remoteAddr net.Addr remoteAddr net.Addr
rawWriter network.ExtendedWriter
// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency // https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
rMux sync.Mutex rMux sync.Mutex
wMux sync.Mutex wMux sync.Mutex
@ -31,6 +42,7 @@ type websocketConn struct {
type websocketWithEarlyDataConn struct { type websocketWithEarlyDataConn struct {
net.Conn net.Conn
wsWriter network.ExtendedWriter
underlay net.Conn underlay net.Conn
closed bool closed bool
dialed chan bool dialed chan bool
@ -79,6 +91,54 @@ func (wsc *websocketConn) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
var payloadBitLength int
dataLen := buffer.Len()
data := buffer.Bytes()
if dataLen < 126 {
payloadBitLength = 1
} else if dataLen < 65536 {
payloadBitLength = 3
} else {
payloadBitLength = 9
}
var headerLen int
headerLen += 1 // FIN / RSV / OPCODE
headerLen += payloadBitLength
headerLen += 4 // MASK KEY
header := buffer.ExtendHeader(headerLen)
header[0] = websocket.BinaryMessage | 1<<7
header[1] = 1 << 7
if dataLen < 126 {
header[1] |= byte(dataLen)
} else if dataLen < 65536 {
header[1] |= 126
binary.BigEndian.PutUint16(header[2:], uint16(dataLen))
} else {
header[1] |= 127
binary.BigEndian.PutUint64(header[2:], uint64(dataLen))
}
maskKey := rand.Uint32()
binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey)
maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data)
wsc.wMux.Lock()
defer wsc.wMux.Unlock()
return wsc.rawWriter.WriteBuffer(buffer)
}
func (wsc *websocketConn) FrontHeadroom() int {
return 14
}
func (wsc *websocketConn) Upstream() any {
return wsc.conn.UnderlyingConn()
}
func (wsc *websocketConn) Close() error { func (wsc *websocketConn) Close() error {
var errors []string var errors []string
if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
@ -149,6 +209,7 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
} }
wsedc.dialed <- true wsedc.dialed <- true
wsedc.wsWriter = bufio.NewExtendedWriter(wsedc.Conn)
if earlyDataBuf.Len() != 0 { if earlyDataBuf.Len() != 0 {
_, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) _, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
} }
@ -170,6 +231,20 @@ func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
return wsedc.Conn.Write(b) return wsedc.Conn.Write(b)
} }
func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
if wsedc.closed {
return io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(buffer.Bytes()); err != nil {
return err
}
return nil
}
return wsedc.wsWriter.WriteBuffer(buffer)
}
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
if wsedc.closed { if wsedc.closed {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
@ -228,6 +303,10 @@ func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
return wsedc.Conn.SetWriteDeadline(t) return wsedc.Conn.SetWriteDeadline(t)
} }
func (wsedc *websocketWithEarlyDataConn) Upstream() any {
return wsedc.Conn
}
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
conn = &websocketWithEarlyDataConn{ conn = &websocketWithEarlyDataConn{
@ -294,6 +373,7 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf
return &websocketConn{ return &websocketConn{
conn: wsConn, conn: wsConn,
rawWriter: bufio.NewExtendedWriter(wsConn.UnderlyingConn()),
remoteAddr: conn.RemoteAddr(), remoteAddr: conn.RemoteAddr(),
}, nil }, nil
} }

View file

@ -1,14 +1,16 @@
package tunnel package tunnel
import ( import (
"context"
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"time" "time"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/sagernet/sing/common/bufio"
) )
func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error { func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error {
@ -60,5 +62,5 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr,
} }
func handleSocket(ctx C.ConnContext, outbound net.Conn) { func handleSocket(ctx C.ConnContext, outbound net.Conn) {
N.Relay(ctx.Conn(), outbound) bufio.CopyConn(context.TODO(), ctx.Conn(), outbound)
} }

View file

@ -7,6 +7,9 @@ import (
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
"go.uber.org/atomic" "go.uber.org/atomic"
) )
@ -30,6 +33,8 @@ type tcpTracker struct {
C.Conn `json:"-"` C.Conn `json:"-"`
*trackerInfo *trackerInfo
manager *Manager manager *Manager
extendedReader network.ExtendedReader
extendedWriter network.ExtendedWriter
} }
func (tt *tcpTracker) ID() string { func (tt *tcpTracker) ID() string {
@ -44,6 +49,14 @@ func (tt *tcpTracker) Read(b []byte) (int, error) {
return n, err return n, err
} }
func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) {
err = tt.extendedReader.ReadBuffer(buffer)
download := int64(buffer.Len())
tt.manager.PushDownloaded(download)
tt.DownloadTotal.Add(download)
return
}
func (tt *tcpTracker) Write(b []byte) (int, error) { func (tt *tcpTracker) Write(b []byte) (int, error) {
n, err := tt.Conn.Write(b) n, err := tt.Conn.Write(b)
upload := int64(n) upload := int64(n)
@ -52,11 +65,26 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
return n, err return n, err
} }
func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) {
err = tt.extendedWriter.WriteBuffer(buffer)
var upload int64
if err != nil {
upload = int64(buffer.Len())
}
tt.manager.PushUploaded(upload)
tt.UploadTotal.Add(upload)
return
}
func (tt *tcpTracker) Close() error { func (tt *tcpTracker) Close() error {
tt.manager.Leave(tt) tt.manager.Leave(tt)
return tt.Conn.Close() return tt.Conn.Close()
} }
func (tt *tcpTracker) Upstream() any {
return tt.Conn
}
func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker {
uuid, _ := uuid.NewV4() uuid, _ := uuid.NewV4()
if conn != nil { if conn != nil {
@ -79,6 +107,8 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
UploadTotal: atomic.NewInt64(0), UploadTotal: atomic.NewInt64(0),
DownloadTotal: atomic.NewInt64(0), DownloadTotal: atomic.NewInt64(0),
}, },
extendedReader: bufio.NewExtendedReader(conn),
extendedWriter: bufio.NewExtendedWriter(conn),
} }
if rule != nil { if rule != nil {