diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 52c89020..12bcfaba 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -9,6 +9,7 @@ import ( "net/http" "strconv" "strings" + "sync" "github.com/Dreamacro/clash/common/convert" "github.com/Dreamacro/clash/component/dialer" @@ -267,9 +268,9 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o } if v.option.PacketAddr { - return newPacketConn(packetaddr.NewBindClient(c), v), nil + return newPacketConn(&threadSafePacketConn{PacketConn: packetaddr.NewBindClient(c)}, v), nil } else if pc, ok := c.(net.PacketConn); ok { - return newPacketConn(pc, v), nil + return newPacketConn(&threadSafePacketConn{PacketConn: pc}, v), nil } return newPacketConn(&vmessPacketConn{Conn: c, rAddr: metadata.UDPAddr()}, v), nil } @@ -277,9 +278,9 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o // ListenPacketOnStreamConn implements C.ProxyAdapter func (v *Vmess) ListenPacketOnStreamConn(c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { if v.option.PacketAddr { - return newPacketConn(packetaddr.NewBindClient(c), v), nil + return newPacketConn(&threadSafePacketConn{PacketConn: packetaddr.NewBindClient(c)}, v), nil } else if pc, ok := c.(net.PacketConn); ok { - return newPacketConn(pc, v), nil + return newPacketConn(&threadSafePacketConn{PacketConn: pc}, v), nil } return newPacketConn(&vmessPacketConn{Conn: c, rAddr: metadata.UDPAddr()}, v), nil } @@ -357,12 +358,26 @@ func NewVmess(option VmessOption) (*Vmess, error) { return v, nil } +type threadSafePacketConn struct { + net.PacketConn + access sync.Mutex +} + +func (c *threadSafePacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + c.access.Lock() + defer c.access.Unlock() + return c.PacketConn.WriteTo(b, addr) +} + type vmessPacketConn struct { net.Conn - rAddr net.Addr + rAddr net.Addr + access sync.Mutex } func (uc *vmessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + uc.access.Lock() + defer uc.access.Unlock() return uc.Conn.Write(b) }