diff --git a/listener/tun/ipstack/system/mars/nat/nat.go b/listener/tun/ipstack/system/mars/nat/nat.go index 16b94de0..6feebbef 100644 --- a/listener/tun/ipstack/system/mars/nat/nat.go +++ b/listener/tun/ipstack/system/mars/nat/nat.go @@ -25,7 +25,6 @@ func Start( tab := newTable() udp := &UDP{ - calls: map[*call]struct{}{}, device: device, buf: [pool.UDPBufferSize]byte{}, } diff --git a/listener/tun/ipstack/system/mars/nat/udp.go b/listener/tun/ipstack/system/mars/nat/udp.go index 173aa792..b28614a5 100644 --- a/listener/tun/ipstack/system/mars/nat/udp.go +++ b/listener/tun/ipstack/system/mars/nat/udp.go @@ -20,28 +20,28 @@ type call struct { } type UDP struct { - closed bool - lock sync.Mutex - calls map[*call]struct{} - device io.Writer - bufLock sync.Mutex - buf [pool.UDPBufferSize]byte + closed bool + device io.Writer + queueLock sync.Mutex + queue []*call + bufLock sync.Mutex + buf [pool.UDPBufferSize]byte } func (u *UDP) ReadFrom(buf []byte) (int, net.Addr, net.Addr, error) { - u.lock.Lock() - defer u.lock.Unlock() + u.queueLock.Lock() + defer u.queueLock.Unlock() for !u.closed { c := &call{ - cond: sync.NewCond(&u.lock), + cond: sync.NewCond(&u.queueLock), buf: buf, n: -1, source: nil, destination: nil, } - u.calls[c] = struct{}{} + u.queue = append(u.queue, c) c.cond.Wait() @@ -54,6 +54,10 @@ func (u *UDP) ReadFrom(buf []byte) (int, net.Addr, net.Addr, error) { } func (u *UDP) WriteTo(buf []byte, local net.Addr, remote net.Addr) (int, error) { + if u.closed { + return 0, net.ErrClosed + } + u.bufLock.Lock() defer u.bufLock.Unlock() @@ -77,8 +81,9 @@ func (u *UDP) WriteTo(buf []byte, local net.Addr, remote net.Addr) (int, error) return 0, net.InvalidAddrError("invalid ip version") } + tcpip.SetIPv4(u.buf[:]) + ip := tcpip.IPv4Packet(u.buf[:]) - tcpip.SetIPv4(ip) ip.SetHeaderLen(tcpip.IPv4HeaderSize) ip.SetTotalLength(tcpip.IPv4HeaderSize + tcpip.UDPHeaderSize + uint16(len(buf))) ip.SetTypeOfService(0) @@ -102,12 +107,12 @@ func (u *UDP) WriteTo(buf []byte, local net.Addr, remote net.Addr) (int, error) } func (u *UDP) Close() error { - u.lock.Lock() - defer u.lock.Unlock() + u.queueLock.Lock() + defer u.queueLock.Unlock() u.closed = true - for c := range u.calls { + for _, c := range u.queue { c.cond.Signal() } @@ -117,14 +122,15 @@ func (u *UDP) Close() error { func (u *UDP) handleUDPPacket(ip tcpip.IP, pkt tcpip.UDPPacket) { var c *call - u.lock.Lock() + u.queueLock.Lock() - for c = range u.calls { - delete(u.calls, c) - break + if len(u.queue) > 0 { + idx := len(u.queue) - 1 + c = u.queue[idx] + u.queue = u.queue[:idx] } - u.lock.Unlock() + u.queueLock.Unlock() if c != nil { c.source = &net.UDPAddr{