chore: better tuic conn close
This commit is contained in:
parent
b2939ad863
commit
25540e6c96
2 changed files with 26 additions and 11 deletions
|
@ -9,6 +9,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -199,6 +200,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
|
||||||
RequestTimeout: option.RequestTimeout,
|
RequestTimeout: option.RequestTimeout,
|
||||||
}
|
}
|
||||||
clientMap[o] = client
|
clientMap[o] = client
|
||||||
|
runtime.SetFinalizer(client, closeTuicClient)
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,3 +216,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
|
||||||
getClient: getClient,
|
getClient: getClient,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func closeTuicClient(client *tuic.Client) {
|
||||||
|
client.Close(nil)
|
||||||
|
}
|
||||||
|
|
|
@ -197,18 +197,26 @@ func (t *Client) deferQuicConn(quicConn quic.Connection, err error) {
|
||||||
t.connMutex.Lock()
|
t.connMutex.Lock()
|
||||||
defer t.connMutex.Unlock()
|
defer t.connMutex.Unlock()
|
||||||
if t.quicConn == quicConn {
|
if t.quicConn == quicConn {
|
||||||
t.udpInputMap.Range(func(key, value any) bool {
|
t.Close(err)
|
||||||
if conn, ok := value.(net.Conn); ok {
|
|
||||||
_ = conn.Close()
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
t.udpInputMap = sync.Map{} // new one
|
|
||||||
t.quicConn = nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Client) Close(err error) {
|
||||||
|
quicConn := t.quicConn
|
||||||
|
if quicConn != nil {
|
||||||
|
_ = t.quicConn.CloseWithError(ProtocolError, err.Error())
|
||||||
|
t.udpInputMap.Range(func(key, value any) bool {
|
||||||
|
if conn, ok := value.(net.Conn); ok {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
t.udpInputMap = sync.Map{} // new one
|
||||||
|
t.quicConn = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) {
|
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) {
|
||||||
quicConn, err := t.getQuicConn(ctx, dialFn)
|
quicConn, err := t.getQuicConn(ctx, dialFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -237,7 +245,7 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
|
||||||
if t.RequestTimeout > 0 {
|
if t.RequestTimeout > 0 {
|
||||||
_ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond))
|
_ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond))
|
||||||
}
|
}
|
||||||
conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr()})
|
conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr(), t})
|
||||||
response, err := ReadResponse(conn)
|
response, err := ReadResponse(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -252,8 +260,9 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
|
||||||
|
|
||||||
type quicStreamConn struct {
|
type quicStreamConn struct {
|
||||||
quic.Stream
|
quic.Stream
|
||||||
lAddr net.Addr
|
lAddr net.Addr
|
||||||
rAddr net.Addr
|
rAddr net.Addr
|
||||||
|
client *Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *quicStreamConn) LocalAddr() net.Addr {
|
func (q *quicStreamConn) LocalAddr() net.Addr {
|
||||||
|
|
Loading…
Reference in a new issue