Fix: tunnel UDP race condition (#1043)

This commit is contained in:
gVisor bot 2020-10-28 21:26:50 +08:00
parent 42d5c8d1d8
commit f066687f93
2 changed files with 52 additions and 44 deletions

View file

@ -22,9 +22,9 @@ func (t *Table) Get(key string) C.PacketConn {
return item.(C.PacketConn) return item.(C.PacketConn)
} }
func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) { func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{}) item, loaded := t.mapping.LoadOrStore(key, sync.NewCond(&sync.Mutex{}))
return item.(*sync.WaitGroup), loaded return item.(*sync.Cond), loaded
} }
func (t *Table) Delete(key string) { func (t *Table) Delete(key string) {

View file

@ -164,7 +164,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
return return
} }
// make a fAddr if requset ip is fakeip // make a fAddr if request ip is fakeip
var fAddr net.Addr var fAddr net.Addr
if resolver.IsExistFakeIP(metadata.DstIP) { if resolver.IsExistFakeIP(metadata.DstIP) {
fAddr = metadata.UDPAddr() fAddr = metadata.UDPAddr()
@ -176,34 +176,49 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
} }
key := packet.LocalAddr().String() key := packet.LocalAddr().String()
handle := func() bool {
pc := natTable.Get(key) pc := natTable.Get(key)
if pc != nil { if pc != nil {
handleUDPToRemote(packet, pc, metadata) handleUDPToRemote(packet, pc, metadata)
return true
}
return false
}
if handle() {
return return
} }
lockKey := key + "-lock" lockKey := key + "-lock"
wg, loaded := natTable.GetOrCreateLock(lockKey) cond, loaded := natTable.GetOrCreateLock(lockKey)
go func() { go func() {
if !loaded { if loaded {
wg.Add(1) cond.L.Lock()
cond.Wait()
handle()
cond.L.Unlock()
return
}
defer func() {
natTable.Delete(lockKey)
cond.Broadcast()
}()
proxy, rule, err := resolveMetadata(metadata) proxy, rule, err := resolveMetadata(metadata)
if err != nil { if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
natTable.Delete(lockKey)
wg.Done()
return return
} }
rawPc, err := proxy.DialUDP(metadata) rawPc, err := proxy.DialUDP(metadata)
if err != nil { if err != nil {
log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error()) log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error())
natTable.Delete(lockKey)
wg.Done()
return return
} }
pc = newUDPTracker(rawPc, DefaultManager, metadata, rule) pc := newUDPTracker(rawPc, DefaultManager, metadata, rule)
switch true { switch true {
case rule != nil: case rule != nil:
@ -216,17 +231,10 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
} }
natTable.Set(key, pc)
natTable.Delete(lockKey)
wg.Done()
go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr) go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr)
}
wg.Wait() natTable.Set(key, pc)
pc := natTable.Get(key) handle()
if pc != nil {
handleUDPToRemote(packet, pc, metadata)
}
}() }()
} }