diff --git a/transport/anytls/client.go b/transport/anytls/client.go index 8fbb314e..2076019e 100644 --- a/transport/anytls/client.go +++ b/transport/anytls/client.go @@ -50,12 +50,12 @@ func NewClient(ctx context.Context, config ClientConfig) *Client { } // Initialize the padding state of this client padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &c.padding) - c.sessionClient = session.NewClient(ctx, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout) + c.sessionClient = session.NewClient(ctx, c.CreateOutboundTLSConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout) return c } func (c *Client) CreateProxy(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { - conn, err := c.sessionClient.CreateStream(ctx, c.CreateOutboundTLSConnection) + conn, err := c.sessionClient.CreateStream(ctx) if err != nil { return nil, err } diff --git a/transport/anytls/session/client.go b/transport/anytls/session/client.go index 1ab0a18b..0ce7acc0 100644 --- a/transport/anytls/session/client.go +++ b/transport/anytls/session/client.go @@ -22,6 +22,8 @@ type Client struct { die context.Context dieCancel context.CancelFunc + dialOut func(ctx context.Context) (net.Conn, error) + sessionCounter atomic.Uint64 idleSession *skiplist.SkipList[uint64, *Session] idleSessionLock sync.Mutex @@ -31,8 +33,9 @@ type Client struct { idleSessionTimeout time.Duration } -func NewClient(ctx context.Context, _padding *singAtomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client { +func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *singAtomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client { c := &Client{ + dialOut: dialOut, padding: _padding, idleSessionTimeout: idleSessionTimeout, } @@ -48,7 +51,7 @@ func NewClient(ctx context.Context, _padding *singAtomic.TypedValue[*padding.Pad return c } -func (c *Client) CreateStream(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error)) (net.Conn, error) { +func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) { select { case <-c.die.Done(): return nil, io.ErrClosedPipe @@ -60,7 +63,7 @@ func (c *Client) CreateStream(ctx context.Context, dialOut func(ctx context.Cont var err error for i := 0; i < 3; i++ { - session, err = c.findSession(ctx, dialOut) + session, err = c.findSession(ctx) if session == nil { return nil, fmt.Errorf("failed to create session: %w", err) } @@ -91,7 +94,7 @@ func (c *Client) CreateStream(ctx context.Context, dialOut func(ctx context.Cont return streamC, nil } -func (c *Client) findSession(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error)) (*Session, error) { +func (c *Client) findSession(ctx context.Context) (*Session, error) { var idle *Session c.idleSessionLock.Lock() @@ -103,14 +106,14 @@ func (c *Client) findSession(ctx context.Context, dialOut func(ctx context.Conte c.idleSessionLock.Unlock() if idle == nil { - s, err := c.createSession(ctx, dialOut) + s, err := c.createSession(ctx) return s, err } return idle, nil } -func (c *Client) createSession(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error)) (*Session, error) { - underlying, err := dialOut(ctx) +func (c *Client) createSession(ctx context.Context) (*Session, error) { + underlying, err := c.dialOut(ctx) if err != nil { return nil, err }