chore: proxyDialer first using old function to let mux work

This commit is contained in:
wwqgtxx 2023-04-12 11:09:31 +08:00
parent fda8857ec8
commit 17922dc857
2 changed files with 20 additions and 40 deletions

View file

@ -45,23 +45,13 @@ func (p proxyDialer) DialContext(ctx context.Context, network, address string) (
return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil
} }
var conn C.Conn var conn C.Conn
switch p.proxy.SupportWithDialer() { if d, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work
case C.ALLNet: conn, err = p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt))
fallthrough } else {
case C.TCP:
conn, err = p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta) conn, err = p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta)
if err != nil { }
return nil, err if err != nil {
} return nil, err
default: // fallback to old function
if d, ok := p.dialer.(dialer.Dialer); ok { // fallback to old function
conn, err = p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt))
if err != nil {
return nil, err
}
} else {
return nil, C.ErrNotSupport
}
} }
if p.statistic { if p.statistic {
conn = statistic.NewTCPTracker(conn, statistic.DefaultManager, currentMeta, nil, 0, 0, false) conn = statistic.NewTCPTracker(conn, statistic.DefaultManager, currentMeta, nil, 0, 0, false)
@ -81,23 +71,13 @@ func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata)
var pc C.PacketConn var pc C.PacketConn
var err error var err error
currentMeta.NetWork = C.UDP currentMeta.NetWork = C.UDP
switch p.proxy.SupportWithDialer() { if d, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work
case C.ALLNet: pc, err = p.proxy.ListenPacketContext(ctx, currentMeta, dialer.WithOption(d.Opt))
fallthrough } else {
case C.UDP:
pc, err = p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta) pc, err = p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta)
if err != nil { }
return nil, err if err != nil {
} return nil, err
default: // fallback to old function
if d, ok := p.dialer.(dialer.Dialer); ok { // fallback to old function
pc, err = p.proxy.ListenPacketContext(ctx, currentMeta, dialer.WithOption(d.Opt))
if err != nil {
return nil, err
}
} else {
return nil, C.ErrNotSupport
}
} }
if p.statistic { if p.statistic {
pc = statistic.NewUDPTracker(pc, statistic.DefaultManager, currentMeta, nil, 0, 0, false) pc = statistic.NewUDPTracker(pc, statistic.DefaultManager, currentMeta, nil, 0, 0, false)

View file

@ -170,15 +170,15 @@ func getDialHandler(r *Resolver, proxyAdapter C.ProxyAdapter, proxyName string,
Host: host, Host: host,
DstPort: port, DstPort: port,
} }
if proxyAdapter.IsL3Protocol(metadata) {
dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r)
if err != nil {
return nil, err
}
metadata.Host = ""
metadata.DstIP = dstIP
}
if proxyAdapter != nil { if proxyAdapter != nil {
if proxyAdapter.IsL3Protocol(metadata) { // L3 proxy should resolve domain before to avoid loopback
dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r)
if err != nil {
return nil, err
}
metadata.Host = ""
metadata.DstIP = dstIP
}
return proxyAdapter.DialContext(ctx, metadata, opts...) return proxyAdapter.DialContext(ctx, metadata, opts...)
} }
opts = append(opts, dialer.WithResolver(r)) opts = append(opts, dialer.WithResolver(r))