fix: peek not work with some inbound

This commit is contained in:
wwqgtxx 2023-02-25 19:41:01 +08:00
parent de92bc0234
commit a3b8c9c233
3 changed files with 15 additions and 9 deletions

View file

@ -27,6 +27,10 @@ func (c *BufferedConn) Reader() *bufio.Reader {
return c.r return c.r
} }
func (c *BufferedConn) ResetPeeked() {
c.peeked = false
}
func (c *BufferedConn) Peeked() bool { func (c *BufferedConn) Peeked() bool {
return c.peeked return c.peeked
} }

View file

@ -81,7 +81,7 @@ func (tt *tcpTracker) Upstream() any {
return tt.Conn return tt.Conn
} }
func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64) *tcpTracker {
uuid, _ := uuid.NewV4() uuid, _ := uuid.NewV4()
if conn != nil { if conn != nil {
if tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { if tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
@ -100,8 +100,8 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
Metadata: metadata, Metadata: metadata,
Chain: conn.Chains(), Chain: conn.Chains(),
Rule: "", Rule: "",
UploadTotal: atomic.NewInt64(0), UploadTotal: atomic.NewInt64(uploadTotal),
DownloadTotal: atomic.NewInt64(0), DownloadTotal: atomic.NewInt64(downloadTotal),
}, },
extendedReader: N.NewExtendedReader(conn), extendedReader: N.NewExtendedReader(conn),
extendedWriter: N.NewExtendedWriter(conn), extendedWriter: N.NewExtendedWriter(conn),
@ -147,7 +147,7 @@ func (ut *udpTracker) Close() error {
return ut.PacketConn.Close() return ut.PacketConn.Close()
} }
func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker { func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule, uploadTotal int64, downloadTotal int64) *udpTracker {
uuid, _ := uuid.NewV4() uuid, _ := uuid.NewV4()
metadata.RemoteDst = conn.RemoteDestination() metadata.RemoteDst = conn.RemoteDestination()
@ -160,8 +160,8 @@ func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, ru
Metadata: metadata, Metadata: metadata,
Chain: conn.Chains(), Chain: conn.Chains(),
Rule: "", Rule: "",
UploadTotal: atomic.NewInt64(0), UploadTotal: atomic.NewInt64(uploadTotal),
DownloadTotal: atomic.NewInt64(0), DownloadTotal: atomic.NewInt64(downloadTotal),
}, },
} }

View file

@ -322,7 +322,7 @@ func handleUDPConn(packet C.PacketAdapter) {
} }
pCtx.InjectPacketConn(rawPc) pCtx.InjectPacketConn(rawPc)
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule) pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0)
switch true { switch true {
case metadata.SpecialProxy != "": case metadata.SpecialProxy != "":
@ -367,6 +367,7 @@ func handleTCPConn(connCtx C.ConnContext) {
} }
conn := connCtx.Conn() conn := connCtx.Conn()
conn.ResetPeeked()
if sniffer.Dispatcher.Enable() && sniffingEnable { if sniffer.Dispatcher.Enable() && sniffingEnable {
sniffer.Dispatcher.TCPSniff(conn, metadata) sniffer.Dispatcher.TCPSniff(conn, metadata)
} }
@ -400,6 +401,7 @@ func handleTCPConn(connCtx C.ConnContext) {
} }
var peekBytes []byte var peekBytes []byte
var peekLen int
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel() defer cancel()
@ -415,7 +417,7 @@ func handleTCPConn(connCtx C.ConnContext) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if peekLen := len(peekBytes); peekLen > 0 { if peekLen = len(peekBytes); peekLen > 0 {
_, _ = conn.Discard(peekLen) _, _ = conn.Discard(peekLen)
} }
return remoteConn, err return remoteConn, err
@ -436,7 +438,7 @@ func handleTCPConn(connCtx C.ConnContext) {
return return
} }
remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule) remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule, 0, int64(peekLen))
defer func(remoteConn C.Conn) { defer func(remoteConn C.Conn) {
_ = remoteConn.Close() _ = remoteConn.Close()
}(remoteConn) }(remoteConn)