Fix: tunnel UDP race condition (#1043)
This commit is contained in:
parent
ba060bd0ee
commit
2cd1b890ce
2 changed files with 52 additions and 44 deletions
|
@ -22,9 +22,9 @@ func (t *Table) Get(key string) C.PacketConn {
|
|||
return item.(C.PacketConn)
|
||||
}
|
||||
|
||||
func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) {
|
||||
item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{})
|
||||
return item.(*sync.WaitGroup), loaded
|
||||
func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
|
||||
item, loaded := t.mapping.LoadOrStore(key, sync.NewCond(&sync.Mutex{}))
|
||||
return item.(*sync.Cond), loaded
|
||||
}
|
||||
|
||||
func (t *Table) Delete(key string) {
|
||||
|
|
|
@ -164,7 +164,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
|
|||
return
|
||||
}
|
||||
|
||||
// make a fAddr if requset ip is fakeip
|
||||
// make a fAddr if request ip is fakeip
|
||||
var fAddr net.Addr
|
||||
if resolver.IsExistFakeIP(metadata.DstIP) {
|
||||
fAddr = metadata.UDPAddr()
|
||||
|
@ -176,34 +176,49 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
|
|||
}
|
||||
|
||||
key := packet.LocalAddr().String()
|
||||
|
||||
handle := func() bool {
|
||||
pc := natTable.Get(key)
|
||||
if pc != nil {
|
||||
handleUDPToRemote(packet, pc, metadata)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if handle() {
|
||||
return
|
||||
}
|
||||
|
||||
lockKey := key + "-lock"
|
||||
wg, loaded := natTable.GetOrCreateLock(lockKey)
|
||||
cond, loaded := natTable.GetOrCreateLock(lockKey)
|
||||
|
||||
go func() {
|
||||
if !loaded {
|
||||
wg.Add(1)
|
||||
if loaded {
|
||||
cond.L.Lock()
|
||||
cond.Wait()
|
||||
handle()
|
||||
cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
natTable.Delete(lockKey)
|
||||
cond.Broadcast()
|
||||
}()
|
||||
|
||||
proxy, rule, err := resolveMetadata(metadata)
|
||||
if err != nil {
|
||||
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
|
||||
natTable.Delete(lockKey)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
|
||||
rawPc, err := proxy.DialUDP(metadata)
|
||||
if err != nil {
|
||||
log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error())
|
||||
natTable.Delete(lockKey)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
pc = newUDPTracker(rawPc, DefaultManager, metadata, rule)
|
||||
pc := newUDPTracker(rawPc, DefaultManager, metadata, rule)
|
||||
|
||||
switch true {
|
||||
case rule != nil:
|
||||
|
@ -216,17 +231,10 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
|
|||
log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
|
||||
}
|
||||
|
||||
natTable.Set(key, pc)
|
||||
natTable.Delete(lockKey)
|
||||
wg.Done()
|
||||
go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
pc := natTable.Get(key)
|
||||
if pc != nil {
|
||||
handleUDPToRemote(packet, pc, metadata)
|
||||
}
|
||||
natTable.Set(key, pc)
|
||||
handle()
|
||||
}()
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue