Fix: tunnel UDP race condition (#1043)

This commit is contained in:
Jason Lyu 2020-10-28 21:26:50 +08:00 committed by GitHub
parent ba060bd0ee
commit 2cd1b890ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 44 deletions

View file

@ -22,9 +22,9 @@ func (t *Table) Get(key string) C.PacketConn {
return item.(C.PacketConn) return item.(C.PacketConn)
} }
func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) { func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{}) item, loaded := t.mapping.LoadOrStore(key, sync.NewCond(&sync.Mutex{}))
return item.(*sync.WaitGroup), loaded return item.(*sync.Cond), loaded
} }
func (t *Table) Delete(key string) { func (t *Table) Delete(key string) {

View file

@ -164,7 +164,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
return return
} }
// make a fAddr if requset ip is fakeip // make a fAddr if request ip is fakeip
var fAddr net.Addr var fAddr net.Addr
if resolver.IsExistFakeIP(metadata.DstIP) { if resolver.IsExistFakeIP(metadata.DstIP) {
fAddr = metadata.UDPAddr() fAddr = metadata.UDPAddr()
@ -176,57 +176,65 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
} }
key := packet.LocalAddr().String() key := packet.LocalAddr().String()
pc := natTable.Get(key)
if pc != nil { handle := func() bool {
handleUDPToRemote(packet, pc, metadata) pc := natTable.Get(key)
if pc != nil {
handleUDPToRemote(packet, pc, metadata)
return true
}
return false
}
if handle() {
return return
} }
lockKey := key + "-lock" lockKey := key + "-lock"
wg, loaded := natTable.GetOrCreateLock(lockKey) cond, loaded := natTable.GetOrCreateLock(lockKey)
go func() { go func() {
if !loaded { if loaded {
wg.Add(1) cond.L.Lock()
proxy, rule, err := resolveMetadata(metadata) cond.Wait()
if err != nil { handle()
log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) cond.L.Unlock()
natTable.Delete(lockKey) return
wg.Done() }
return
}
rawPc, err := proxy.DialUDP(metadata) defer func() {
if err != nil {
log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error())
natTable.Delete(lockKey)
wg.Done()
return
}
pc = newUDPTracker(rawPc, DefaultManager, metadata, rule)
switch true {
case rule != nil:
log.Infoln("[UDP] %s --> %v match %s(%s) using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rule.Payload(), rawPc.Chains().String())
case mode == Global:
log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String())
case mode == Direct:
log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String())
default:
log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
}
natTable.Set(key, pc)
natTable.Delete(lockKey) natTable.Delete(lockKey)
wg.Done() cond.Broadcast()
go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr) }()
proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return
} }
wg.Wait() rawPc, err := proxy.DialUDP(metadata)
pc := natTable.Get(key) if err != nil {
if pc != nil { log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error())
handleUDPToRemote(packet, pc, metadata) return
} }
pc := newUDPTracker(rawPc, DefaultManager, metadata, rule)
switch true {
case rule != nil:
log.Infoln("[UDP] %s --> %v match %s(%s) using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rule.Payload(), rawPc.Chains().String())
case mode == Global:
log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String())
case mode == Direct:
log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String())
default:
log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
}
go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr)
natTable.Set(key, pc)
handle()
}() }()
} }