diff --git a/listener/sing_tun/dns.go b/listener/sing_tun/dns.go index 4fd38e1d..57dcb1a5 100644 --- a/listener/sing_tun/dns.go +++ b/listener/sing_tun/dns.go @@ -73,7 +73,7 @@ func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, meta ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout) defer cancel() inData := buff[:n] - msg, err := RelayDnsPacket(ctx, inData) + msg, err := RelayDnsPacket(ctx, inData, buff) if err != nil { return err } @@ -110,7 +110,9 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. conn2 = nil }() rwOptions := network.ReadWaitOptions{ - MTU: 2 * 1024, // safe size which is 1232 from https://dnsflagday.net/2020/, so 2048 is enough + FrontHeadroom: network.CalculateFrontHeadroom(conn), + RearHeadroom: network.CalculateRearHeadroom(conn), + MTU: 2 * 1024, // safe size which is 1232 from https://dnsflagday.net/2020/, so 2048 is enough } readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn) if isReadWaiter { @@ -118,24 +120,24 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. } for { var ( - buff *buf.Buffer - dest M.Socksaddr - err error + readBuff *buf.Buffer + dest M.Socksaddr + err error ) _ = conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) - buff = nil // clear last loop status, avoid repeat release + readBuff = nil // clear last loop status, avoid repeat release if isReadWaiter { - buff, dest, err = readWaiter.WaitReadPacket() + readBuff, dest, err = readWaiter.WaitReadPacket() } else { - buff = rwOptions.NewPacketBuffer() - dest, err = conn.ReadPacket(buff) - if buff != nil { - rwOptions.PostReturn(buff) + readBuff = rwOptions.NewPacketBuffer() + dest, err = conn.ReadPacket(readBuff) + if readBuff != nil { + rwOptions.PostReturn(readBuff) } } if err != nil { - if buff != nil { - buff.Release() + if readBuff != nil { + readBuff.Release() } if sing.ShouldIgnorePacketError(err) { break @@ -145,26 +147,30 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. go func() { ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout) defer cancel() - inData := buff.Bytes() - msg, err := RelayDnsPacket(ctx, inData) + inData := readBuff.Bytes() + writeBuff := readBuff + if writeBuff.Cap() < rwOptions.MTU { // only create a new buffer when space don't enough + writeBuff = rwOptions.NewPacketBuffer() + } + msg, err := RelayDnsPacket(ctx, inData, writeBuff.FreeBytes()) + if writeBuff != readBuff { + readBuff.Release() + } if err != nil { - buff.Release() - return - } - buff.Reset() - _, err = buff.Write(msg) - if err != nil { - buff.Release() + writeBuff.Release() return } + writeBuff.Truncate(len(msg)) mutex.Lock() defer mutex.Unlock() conn := conn2 if conn == nil { + writeBuff.Release() return } - err = conn.WritePacket(buff, dest) // WritePacket will release buff + err = conn.WritePacket(writeBuff, dest) // WritePacket will release writeBuff if err != nil { + writeBuff.Release() return } }() @@ -174,7 +180,7 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. return h.ListenerHandler.NewPacketConnection(ctx, conn, metadata) } -func RelayDnsPacket(ctx context.Context, payload []byte) ([]byte, error) { +func RelayDnsPacket(ctx context.Context, payload []byte, target []byte) ([]byte, error) { msg := &D.Msg{} if err := msg.Unpack(payload); err != nil { return nil, err @@ -184,10 +190,10 @@ func RelayDnsPacket(ctx context.Context, payload []byte) ([]byte, error) { if err != nil { m := new(D.Msg) m.SetRcode(msg, D.RcodeServerFailure) - return m.Pack() + return m.PackBuffer(target) } r.SetRcode(msg, r.Rcode) r.Compress = true - return r.Pack() + return r.PackBuffer(target) }