refactor: Implement extended IO
This commit is contained in:
parent
8fa66c13a9
commit
d1565bb46f
7 changed files with 219 additions and 39 deletions
|
@ -4,12 +4,15 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/gofrs/uuid"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/Dreamacro/clash/component/dialer"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type Base struct {
|
||||
|
@ -166,7 +169,7 @@ func NewBase(opt BaseOption) *Base {
|
|||
}
|
||||
|
||||
type conn struct {
|
||||
net.Conn
|
||||
network.ExtendedConn
|
||||
chain C.Chain
|
||||
actualRemoteDestination string
|
||||
}
|
||||
|
@ -185,8 +188,15 @@ func (c *conn) AppendToChains(a C.ProxyAdapter) {
|
|||
c.chain = append(c.chain, a.Name())
|
||||
}
|
||||
|
||||
func (c *conn) Upstream() any {
|
||||
if wrapper, ok := c.ExtendedConn.(*bufio.ExtendedConnWrapper); ok {
|
||||
return wrapper.Conn
|
||||
}
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
|
||||
return &conn{c, []string{a.Name()}, parseRemoteDestination(a.Addr())}
|
||||
return &conn{bufio.NewExtendedConn(c), []string{a.Name()}, parseRemoteDestination(a.Addr())}
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
|
|
|
@ -14,6 +14,8 @@ import (
|
|||
"github.com/Dreamacro/clash/transport/gun"
|
||||
"github.com/Dreamacro/clash/transport/trojan"
|
||||
"github.com/Dreamacro/clash/transport/vless"
|
||||
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
)
|
||||
|
||||
type Trojan struct {
|
||||
|
@ -95,7 +97,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error)
|
|||
return c, err
|
||||
}
|
||||
err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
|
||||
return c, err
|
||||
return bufio.NewExtendedConn(c), err
|
||||
}
|
||||
|
||||
// DialContext implements C.ProxyAdapter
|
||||
|
|
|
@ -3,18 +3,24 @@ package net
|
|||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
sing_bufio "github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ network.ExtendedConn = (*BufferedConn)(nil)
|
||||
|
||||
type BufferedConn struct {
|
||||
r *bufio.Reader
|
||||
net.Conn
|
||||
network.ExtendedConn
|
||||
}
|
||||
|
||||
func NewBufferedConn(c net.Conn) *BufferedConn {
|
||||
if bc, ok := c.(*BufferedConn); ok {
|
||||
return bc
|
||||
}
|
||||
return &BufferedConn{bufio.NewReader(c), c}
|
||||
return &BufferedConn{bufio.NewReader(c), sing_bufio.NewExtendedConn(c)}
|
||||
}
|
||||
|
||||
// Reader returns the internal bufio.Reader.
|
||||
|
@ -42,3 +48,18 @@ func (c *BufferedConn) UnreadByte() error {
|
|||
func (c *BufferedConn) Buffered() int {
|
||||
return c.r.Buffered()
|
||||
}
|
||||
|
||||
func (c *BufferedConn) ReadBuffer(buffer *buf.Buffer) (err error) {
|
||||
if c.r.Buffered() > 0 {
|
||||
_, err = buffer.ReadOnceFrom(c.r)
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *BufferedConn) Upstream() any {
|
||||
if wrapper, ok := c.ExtendedConn.(*sing_bufio.ExtendedConnWrapper); ok {
|
||||
return wrapper.Conn
|
||||
}
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package vless
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -9,12 +8,16 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
xtls "github.com/xtls/go"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
network.ExtendedConn
|
||||
dst *DstAddr
|
||||
id *uuid.UUID
|
||||
addons *Addons
|
||||
|
@ -23,57 +26,82 @@ type Conn struct {
|
|||
|
||||
func (vc *Conn) Read(b []byte) (int, error) {
|
||||
if vc.received {
|
||||
return vc.Conn.Read(b)
|
||||
return vc.ExtendedConn.Read(b)
|
||||
}
|
||||
|
||||
if err := vc.recvResponse(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
vc.received = true
|
||||
return vc.Conn.Read(b)
|
||||
return vc.ExtendedConn.Read(b)
|
||||
}
|
||||
|
||||
func (vc *Conn) sendRequest() error {
|
||||
buf := &bytes.Buffer{}
|
||||
func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
if vc.received {
|
||||
return vc.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
buf.WriteByte(Version) // protocol version
|
||||
buf.Write(vc.id.Bytes()) // 16 bytes of uuid
|
||||
if err := vc.recvResponse(); err != nil {
|
||||
return err
|
||||
}
|
||||
vc.received = true
|
||||
return vc.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (vc *Conn) sendRequest() (err error) {
|
||||
requestLen := 1 // protocol version
|
||||
requestLen += 16 // UUID
|
||||
requestLen += 1 // addons length
|
||||
var addonsBytes []byte
|
||||
if vc.addons != nil {
|
||||
bytes, err := proto.Marshal(vc.addons)
|
||||
addonsBytes, err = proto.Marshal(vc.addons)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf.WriteByte(byte(len(bytes)))
|
||||
buf.Write(bytes)
|
||||
} else {
|
||||
buf.WriteByte(0) // addon data length. 0 means no addon data
|
||||
}
|
||||
requestLen += len(addonsBytes)
|
||||
requestLen += 1 // command
|
||||
if !vc.dst.Mux {
|
||||
requestLen += 2 // port
|
||||
requestLen += 1 // addr type
|
||||
requestLen += len(vc.dst.Addr)
|
||||
}
|
||||
_buffer := buf.StackNewSize(requestLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
|
||||
common.Must(
|
||||
buffer.WriteByte(Version), // protocol version
|
||||
common.Error(buffer.Write(vc.id.Bytes())), // 16 bytes of uuid
|
||||
buffer.WriteByte(byte(len(addonsBytes))),
|
||||
common.Error(buffer.Write(addonsBytes)),
|
||||
)
|
||||
|
||||
if vc.dst.Mux {
|
||||
buf.WriteByte(CommandMux)
|
||||
common.Must(buffer.WriteByte(CommandMux))
|
||||
} else {
|
||||
if vc.dst.UDP {
|
||||
buf.WriteByte(CommandUDP)
|
||||
common.Must(buffer.WriteByte(CommandUDP))
|
||||
} else {
|
||||
buf.WriteByte(CommandTCP)
|
||||
common.Must(buffer.WriteByte(CommandTCP))
|
||||
}
|
||||
|
||||
// Port AddrType Addr
|
||||
binary.Write(buf, binary.BigEndian, vc.dst.Port)
|
||||
buf.WriteByte(vc.dst.AddrType)
|
||||
buf.Write(vc.dst.Addr)
|
||||
binary.BigEndian.PutUint16(buffer.Extend(2), vc.dst.Port)
|
||||
common.Must(
|
||||
buffer.WriteByte(vc.dst.AddrType),
|
||||
common.Error(buffer.Write(vc.dst.Addr)),
|
||||
)
|
||||
}
|
||||
|
||||
_, err := vc.Conn.Write(buf.Bytes())
|
||||
return err
|
||||
_, err = vc.ExtendedConn.Write(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (vc *Conn) recvResponse() error {
|
||||
var err error
|
||||
buf := make([]byte, 1)
|
||||
_, err = io.ReadFull(vc.Conn, buf)
|
||||
var buf [1]byte
|
||||
_, err = io.ReadFull(vc.ExtendedConn, buf[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -82,23 +110,30 @@ func (vc *Conn) recvResponse() error {
|
|||
return errors.New("unexpected response version")
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(vc.Conn, buf)
|
||||
_, err = io.ReadFull(vc.ExtendedConn, buf[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
length := int64(buf[0])
|
||||
if length != 0 { // addon data length > 0
|
||||
io.CopyN(io.Discard, vc.Conn, length) // just discard
|
||||
io.CopyN(io.Discard, vc.ExtendedConn, length) // just discard
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vc *Conn) Upstream() any {
|
||||
if wrapper, ok := vc.ExtendedConn.(*bufio.ExtendedConnWrapper); ok {
|
||||
return wrapper.Conn
|
||||
}
|
||||
return vc.ExtendedConn
|
||||
}
|
||||
|
||||
// newConn return a Conn instance
|
||||
func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
ExtendedConn: bufio.NewExtendedConn(conn),
|
||||
id: client.uuid,
|
||||
dst: dst,
|
||||
}
|
||||
|
|
|
@ -5,9 +5,11 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -15,15 +17,24 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
//go:linkname maskBytes github.com/gorilla/websocket.maskBytes
|
||||
func maskBytes(key [4]byte, pos int, b []byte) int
|
||||
|
||||
type websocketConn struct {
|
||||
conn *websocket.Conn
|
||||
reader io.Reader
|
||||
remoteAddr net.Addr
|
||||
|
||||
rawWriter network.ExtendedWriter
|
||||
|
||||
// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
|
||||
rMux sync.Mutex
|
||||
wMux sync.Mutex
|
||||
|
@ -31,6 +42,7 @@ type websocketConn struct {
|
|||
|
||||
type websocketWithEarlyDataConn struct {
|
||||
net.Conn
|
||||
wsWriter network.ExtendedWriter
|
||||
underlay net.Conn
|
||||
closed bool
|
||||
dialed chan bool
|
||||
|
@ -79,6 +91,54 @@ func (wsc *websocketConn) Write(b []byte) (int, error) {
|
|||
return len(b), nil
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
var payloadBitLength int
|
||||
dataLen := buffer.Len()
|
||||
data := buffer.Bytes()
|
||||
if dataLen < 126 {
|
||||
payloadBitLength = 1
|
||||
} else if dataLen < 65536 {
|
||||
payloadBitLength = 3
|
||||
} else {
|
||||
payloadBitLength = 9
|
||||
}
|
||||
|
||||
var headerLen int
|
||||
headerLen += 1 // FIN / RSV / OPCODE
|
||||
headerLen += payloadBitLength
|
||||
headerLen += 4 // MASK KEY
|
||||
|
||||
header := buffer.ExtendHeader(headerLen)
|
||||
header[0] = websocket.BinaryMessage | 1<<7
|
||||
header[1] = 1 << 7
|
||||
|
||||
if dataLen < 126 {
|
||||
header[1] |= byte(dataLen)
|
||||
} else if dataLen < 65536 {
|
||||
header[1] |= 126
|
||||
binary.BigEndian.PutUint16(header[2:], uint16(dataLen))
|
||||
} else {
|
||||
header[1] |= 127
|
||||
binary.BigEndian.PutUint64(header[2:], uint64(dataLen))
|
||||
}
|
||||
|
||||
maskKey := rand.Uint32()
|
||||
binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey)
|
||||
maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data)
|
||||
|
||||
wsc.wMux.Lock()
|
||||
defer wsc.wMux.Unlock()
|
||||
return wsc.rawWriter.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) FrontHeadroom() int {
|
||||
return 14
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) Upstream() any {
|
||||
return wsc.conn.UnderlyingConn()
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) Close() error {
|
||||
var errors []string
|
||||
if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
|
||||
|
@ -149,6 +209,7 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
|
|||
}
|
||||
|
||||
wsedc.dialed <- true
|
||||
wsedc.wsWriter = bufio.NewExtendedWriter(wsedc.Conn)
|
||||
if earlyDataBuf.Len() != 0 {
|
||||
_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
|
||||
}
|
||||
|
@ -170,6 +231,20 @@ func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
|
|||
return wsedc.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if wsedc.closed {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
if wsedc.Conn == nil {
|
||||
if err := wsedc.Dial(buffer.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return wsedc.wsWriter.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
|
||||
if wsedc.closed {
|
||||
return 0, io.ErrClosedPipe
|
||||
|
@ -228,6 +303,10 @@ func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
|
|||
return wsedc.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) Upstream() any {
|
||||
return wsedc.Conn
|
||||
}
|
||||
|
||||
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
conn = &websocketWithEarlyDataConn{
|
||||
|
@ -294,6 +373,7 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf
|
|||
|
||||
return &websocketConn{
|
||||
conn: wsConn,
|
||||
rawWriter: bufio.NewExtendedWriter(wsConn.UnderlyingConn()),
|
||||
remoteAddr: conn.RemoteAddr(),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
)
|
||||
|
||||
func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error {
|
||||
|
@ -60,5 +62,5 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr,
|
|||
}
|
||||
|
||||
func handleSocket(ctx C.ConnContext, outbound net.Conn) {
|
||||
N.Relay(ctx.Conn(), outbound)
|
||||
bufio.CopyConn(context.TODO(), ctx.Conn(), outbound)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,9 @@ import (
|
|||
C "github.com/Dreamacro/clash/constant"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
|
@ -30,6 +33,8 @@ type tcpTracker struct {
|
|||
C.Conn `json:"-"`
|
||||
*trackerInfo
|
||||
manager *Manager
|
||||
extendedReader network.ExtendedReader
|
||||
extendedWriter network.ExtendedWriter
|
||||
}
|
||||
|
||||
func (tt *tcpTracker) ID() string {
|
||||
|
@ -44,6 +49,14 @@ func (tt *tcpTracker) Read(b []byte) (int, error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) {
|
||||
err = tt.extendedReader.ReadBuffer(buffer)
|
||||
download := int64(buffer.Len())
|
||||
tt.manager.PushDownloaded(download)
|
||||
tt.DownloadTotal.Add(download)
|
||||
return
|
||||
}
|
||||
|
||||
func (tt *tcpTracker) Write(b []byte) (int, error) {
|
||||
n, err := tt.Conn.Write(b)
|
||||
upload := int64(n)
|
||||
|
@ -52,11 +65,26 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) {
|
||||
err = tt.extendedWriter.WriteBuffer(buffer)
|
||||
var upload int64
|
||||
if err != nil {
|
||||
upload = int64(buffer.Len())
|
||||
}
|
||||
tt.manager.PushUploaded(upload)
|
||||
tt.UploadTotal.Add(upload)
|
||||
return
|
||||
}
|
||||
|
||||
func (tt *tcpTracker) Close() error {
|
||||
tt.manager.Leave(tt)
|
||||
return tt.Conn.Close()
|
||||
}
|
||||
|
||||
func (tt *tcpTracker) Upstream() any {
|
||||
return tt.Conn
|
||||
}
|
||||
|
||||
func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker {
|
||||
uuid, _ := uuid.NewV4()
|
||||
if conn != nil {
|
||||
|
@ -79,6 +107,8 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
|
|||
UploadTotal: atomic.NewInt64(0),
|
||||
DownloadTotal: atomic.NewInt64(0),
|
||||
},
|
||||
extendedReader: bufio.NewExtendedReader(conn),
|
||||
extendedWriter: bufio.NewExtendedWriter(conn),
|
||||
}
|
||||
|
||||
if rule != nil {
|
||||
|
|
Loading…
Reference in a new issue