diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 9e9db82b..d80893c9 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -406,29 +406,36 @@ func handleTCPConn(connCtx C.ConnContext) { ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) defer cancel() - remoteConn, err := retry(ctx, func(ctx context.Context) (C.Conn, error) { - remoteConn, err := proxy.DialContext(ctx, dialMetadata) + remoteConn, err := retry(ctx, func(ctx context.Context) (remoteConn C.Conn, err error) { + remoteConn, err = proxy.DialContext(ctx, dialMetadata) if err != nil { - return nil, err - } - for _, chain := range remoteConn.Chains() { - if chain == "REJECT" { - return remoteConn, nil - } + return } + if N.NeedHandshake(remoteConn) { + defer func() { + for _, chain := range remoteConn.Chains() { + if chain == "REJECT" { + err = nil + return + } + } + if err != nil { + remoteConn = nil + } + }() peekMutex.Lock() defer peekMutex.Unlock() peekBytes, _ = conn.Peek(conn.Buffered()) _, err = remoteConn.Write(peekBytes) if err != nil { - return nil, err + return } if peekLen = len(peekBytes); peekLen > 0 { _, _ = conn.Discard(peekLen) } } - return remoteConn, err + return }, func(err error) { if rule == nil { log.Warnln(