From dbb5b7db1cc1faddfc70dfcbe8360a37e624a538 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Thu, 10 Apr 2025 23:32:26 +0800 Subject: [PATCH] fix: SetupContextForConn should return context error to user --- common/net/context.go | 15 ++++- common/net/context_test.go | 103 +++++++++++++++++++++++++++++++++++ transport/vmess/websocket.go | 2 +- 3 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 common/net/context_test.go diff --git a/common/net/context.go b/common/net/context.go index ef0e9faf..b170516e 100644 --- a/common/net/context.go +++ b/common/net/context.go @@ -7,7 +7,18 @@ import ( "github.com/metacubex/mihomo/common/contextutils" ) -// SetupContextForConn is a helper function that starts connection I/O interrupter goroutine. +// SetupContextForConn is a helper function that starts connection I/O interrupter. +// if ctx be canceled before done called, it will close the connection. +// should use like this: +// +// func streamConn(ctx context.Context, conn net.Conn) (_ net.Conn, err error) { +// if ctx.Done() != nil { +// done := N.SetupContextForConn(ctx, conn) +// defer done(&err) +// } +// conn, err := xxx +// return conn, err +// } func SetupContextForConn(ctx context.Context, conn net.Conn) (done func(*error)) { stopc := make(chan struct{}) stop := contextutils.AfterFunc(ctx, func() { @@ -21,7 +32,7 @@ func SetupContextForConn(ctx context.Context, conn net.Conn) (done func(*error)) <-stopc if ctxErr := ctx.Err(); ctxErr != nil && inputErr != nil { // Return context error to user. - inputErr = &ctxErr + *inputErr = ctxErr } } } diff --git a/common/net/context_test.go b/common/net/context_test.go new file mode 100644 index 00000000..8e4c4ad1 --- /dev/null +++ b/common/net/context_test.go @@ -0,0 +1,103 @@ +package net_test + +import ( + "context" + "errors" + "net" + "testing" + "time" + + N "github.com/metacubex/mihomo/common/net" + + "github.com/stretchr/testify/assert" +) + +func testRead(ctx context.Context, conn net.Conn) (err error) { + if ctx.Done() != nil { + done := N.SetupContextForConn(ctx, conn) + defer done(&err) + } + _, err = conn.Read(make([]byte, 1)) + return err +} + +func TestSetupContextForConnWithCancel(t *testing.T) { + t.Parallel() + c1, c2 := N.Pipe() + defer c1.Close() + defer c2.Close() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errc := make(chan error) + go func() { + errc <- testRead(ctx, c1) + }() + + select { + case <-errc: + t.Fatal("conn closed before cancel") + case <-time.After(100 * time.Millisecond): + cancel() + } + + select { + case err := <-errc: + assert.ErrorIs(t, err, context.Canceled) + case <-time.After(100 * time.Millisecond): + t.Fatal("conn not be canceled") + } +} + +func TestSetupContextForConnWithTimeout1(t *testing.T) { + t.Parallel() + c1, c2 := N.Pipe() + defer c1.Close() + defer c2.Close() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + errc := make(chan error) + go func() { + errc <- testRead(ctx, c1) + }() + + select { + case err := <-errc: + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Fatal("conn closed before timeout") + } + assert.ErrorIs(t, err, context.DeadlineExceeded) + case <-time.After(200 * time.Millisecond): + t.Fatal("conn not be canceled") + } +} + +func TestSetupContextForConnWithTimeout2(t *testing.T) { + t.Parallel() + c1, c2 := N.Pipe() + defer c1.Close() + defer c2.Close() + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + errc := make(chan error) + go func() { + errc <- testRead(ctx, c1) + }() + + select { + case <-errc: + t.Fatal("conn closed before cancel") + case <-time.After(100 * time.Millisecond): + c2.Write(make([]byte, 1)) + } + + select { + case err := <-errc: + assert.Nil(t, ctx.Err()) + assert.Nil(t, err) + case <-time.After(200 * time.Millisecond): + t.Fatal("conn not be canceled") + } +} diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 586f77e0..772c6fef 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -326,7 +326,7 @@ func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Co return N.NewDeadlineConn(conn), nil } -func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { +func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (_ net.Conn, err error) { u, err := url.Parse(c.Path) if err != nil { return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)