From f4b9f2965ffc1e2a9d72b52ef1a400adfbd92cc9 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sun, 26 Jun 2022 21:52:22 +0800 Subject: [PATCH] fix: hysteria dial use external context --- adapter/outbound/hysteria.go | 42 ++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/adapter/outbound/hysteria.go b/adapter/outbound/hysteria.go index 60b2efa0..7b4bcd78 100644 --- a/adapter/outbound/hysteria.go +++ b/adapter/outbound/hysteria.go @@ -41,24 +41,32 @@ var rateStringRegexp = regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`) type Hysteria struct { *Base - client *core.Client - clientTransport *transport.ClientTransport + client *core.Client } func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { - tcpConn, err := h.client.DialTCP(metadata.RemoteAddress(), hyDialer(func() (net.PacketConn, error) { - return dialer.ListenPacket(ctx, "udp", "", h.Base.DialOptions(opts...)...) - })) + hdc := hyDialerWithContext{ + ctx: ctx, + hyDialer: func() (net.PacketConn, error) { + return dialer.ListenPacket(ctx, "udp", "", h.Base.DialOptions(opts...)...) + }, + } + tcpConn, err := h.client.DialTCP(metadata.RemoteAddress(), &hdc) if err != nil { return nil, err } + return NewConn(tcpConn, h), nil } func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { - udpConn, err := h.client.DialUDP(hyDialer(func() (net.PacketConn, error) { - return dialer.ListenPacket(ctx, "udp", "", h.Base.DialOptions(opts...)...) - })) + hdc := hyDialerWithContext{ + ctx: ctx, + hyDialer: func() (net.PacketConn, error) { + return dialer.ListenPacket(ctx, "udp", "", h.Base.DialOptions(opts...)...) + }, + } + udpConn, err := h.client.DialUDP(&hdc) if err != nil { return nil, err } @@ -191,8 +199,7 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) { iface: option.Interface, rmark: option.RoutingMark, }, - client: client, - clientTransport: clientTransport, + client: client, }, nil } @@ -255,8 +262,15 @@ func (c *hyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return } -type hyDialer func() (net.PacketConn, error) - -func (h hyDialer) ListenPacket() (net.PacketConn, error) { - return h() +type hyDialerWithContext struct { + hyDialer func() (net.PacketConn, error) + ctx context.Context +} + +func (h *hyDialerWithContext) ListenPacket() (net.PacketConn, error) { + return h.hyDialer() +} + +func (h *hyDialerWithContext) Context() context.Context { + return h.ctx }