chore: better tuic conn close
This commit is contained in:
parent
b2939ad863
commit
25540e6c96
2 changed files with 26 additions and 11 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -199,6 +200,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
|
|||
RequestTimeout: option.RequestTimeout,
|
||||
}
|
||||
clientMap[o] = client
|
||||
runtime.SetFinalizer(client, closeTuicClient)
|
||||
return client
|
||||
}
|
||||
|
||||
|
@ -214,3 +216,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
|
|||
getClient: getClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func closeTuicClient(client *tuic.Client) {
|
||||
client.Close(nil)
|
||||
}
|
||||
|
|
|
@ -197,18 +197,26 @@ func (t *Client) deferQuicConn(quicConn quic.Connection, err error) {
|
|||
t.connMutex.Lock()
|
||||
defer t.connMutex.Unlock()
|
||||
if t.quicConn == quicConn {
|
||||
t.udpInputMap.Range(func(key, value any) bool {
|
||||
if conn, ok := value.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
t.udpInputMap = sync.Map{} // new one
|
||||
t.quicConn = nil
|
||||
t.Close(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Client) Close(err error) {
|
||||
quicConn := t.quicConn
|
||||
if quicConn != nil {
|
||||
_ = t.quicConn.CloseWithError(ProtocolError, err.Error())
|
||||
t.udpInputMap.Range(func(key, value any) bool {
|
||||
if conn, ok := value.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
t.udpInputMap = sync.Map{} // new one
|
||||
t.quicConn = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) {
|
||||
quicConn, err := t.getQuicConn(ctx, dialFn)
|
||||
if err != nil {
|
||||
|
@ -237,7 +245,7 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
|
|||
if t.RequestTimeout > 0 {
|
||||
_ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond))
|
||||
}
|
||||
conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr()})
|
||||
conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr(), t})
|
||||
response, err := ReadResponse(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -252,8 +260,9 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
|
|||
|
||||
type quicStreamConn struct {
|
||||
quic.Stream
|
||||
lAddr net.Addr
|
||||
rAddr net.Addr
|
||||
lAddr net.Addr
|
||||
rAddr net.Addr
|
||||
client *Client
|
||||
}
|
||||
|
||||
func (q *quicStreamConn) LocalAddr() net.Addr {
|
||||
|
|
Loading…
Reference in a new issue