chore: proxyDialer first using old function to let mux work

This commit is contained in:
wwqgtxx 2023-04-12 11:09:31 +08:00
parent fda8857ec8
commit 17922dc857
2 changed files with 20 additions and 40 deletions

View file

@ -45,24 +45,14 @@ func (p proxyDialer) DialContext(ctx context.Context, network, address string) (
return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil
} }
var conn C.Conn var conn C.Conn
switch p.proxy.SupportWithDialer() { if d, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work
case C.ALLNet:
fallthrough
case C.TCP:
conn, err = p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta)
if err != nil {
return nil, err
}
default: // fallback to old function
if d, ok := p.dialer.(dialer.Dialer); ok { // fallback to old function
conn, err = p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt)) conn, err = p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt))
} else {
conn, err = p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else {
return nil, C.ErrNotSupport
}
}
if p.statistic { if p.statistic {
conn = statistic.NewTCPTracker(conn, statistic.DefaultManager, currentMeta, nil, 0, 0, false) conn = statistic.NewTCPTracker(conn, statistic.DefaultManager, currentMeta, nil, 0, 0, false)
} }
@ -81,24 +71,14 @@ func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata)
var pc C.PacketConn var pc C.PacketConn
var err error var err error
currentMeta.NetWork = C.UDP currentMeta.NetWork = C.UDP
switch p.proxy.SupportWithDialer() { if d, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work
case C.ALLNet:
fallthrough
case C.UDP:
pc, err = p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta)
if err != nil {
return nil, err
}
default: // fallback to old function
if d, ok := p.dialer.(dialer.Dialer); ok { // fallback to old function
pc, err = p.proxy.ListenPacketContext(ctx, currentMeta, dialer.WithOption(d.Opt)) pc, err = p.proxy.ListenPacketContext(ctx, currentMeta, dialer.WithOption(d.Opt))
} else {
pc, err = p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else {
return nil, C.ErrNotSupport
}
}
if p.statistic { if p.statistic {
pc = statistic.NewUDPTracker(pc, statistic.DefaultManager, currentMeta, nil, 0, 0, false) pc = statistic.NewUDPTracker(pc, statistic.DefaultManager, currentMeta, nil, 0, 0, false)
} }

View file

@ -170,7 +170,8 @@ func getDialHandler(r *Resolver, proxyAdapter C.ProxyAdapter, proxyName string,
Host: host, Host: host,
DstPort: port, DstPort: port,
} }
if proxyAdapter.IsL3Protocol(metadata) { if proxyAdapter != nil {
if proxyAdapter.IsL3Protocol(metadata) { // L3 proxy should resolve domain before to avoid loopback
dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r) dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -178,7 +179,6 @@ func getDialHandler(r *Resolver, proxyAdapter C.ProxyAdapter, proxyName string,
metadata.Host = "" metadata.Host = ""
metadata.DstIP = dstIP metadata.DstIP = dstIP
} }
if proxyAdapter != nil {
return proxyAdapter.DialContext(ctx, metadata, opts...) return proxyAdapter.DialContext(ctx, metadata, opts...)
} }
opts = append(opts, dialer.WithResolver(r)) opts = append(opts, dialer.WithResolver(r))