refactor: passing dial in constructor

This commit is contained in:
anytls 2025-02-17 14:22:19 +09:00
parent e4cf7e9133
commit b012a3cbba
2 changed files with 12 additions and 9 deletions

View file

@ -50,12 +50,12 @@ func NewClient(ctx context.Context, config ClientConfig) *Client {
}
// Initialize the padding state of this client
padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &c.padding)
c.sessionClient = session.NewClient(ctx, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout)
c.sessionClient = session.NewClient(ctx, c.CreateOutboundTLSConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout)
return c
}
func (c *Client) CreateProxy(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
conn, err := c.sessionClient.CreateStream(ctx, c.CreateOutboundTLSConnection)
conn, err := c.sessionClient.CreateStream(ctx)
if err != nil {
return nil, err
}

View file

@ -22,6 +22,8 @@ type Client struct {
die context.Context
dieCancel context.CancelFunc
dialOut func(ctx context.Context) (net.Conn, error)
sessionCounter atomic.Uint64
idleSession *skiplist.SkipList[uint64, *Session]
idleSessionLock sync.Mutex
@ -31,8 +33,9 @@ type Client struct {
idleSessionTimeout time.Duration
}
func NewClient(ctx context.Context, _padding *singAtomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client {
func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *singAtomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client {
c := &Client{
dialOut: dialOut,
padding: _padding,
idleSessionTimeout: idleSessionTimeout,
}
@ -48,7 +51,7 @@ func NewClient(ctx context.Context, _padding *singAtomic.TypedValue[*padding.Pad
return c
}
func (c *Client) CreateStream(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error)) (net.Conn, error) {
func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) {
select {
case <-c.die.Done():
return nil, io.ErrClosedPipe
@ -60,7 +63,7 @@ func (c *Client) CreateStream(ctx context.Context, dialOut func(ctx context.Cont
var err error
for i := 0; i < 3; i++ {
session, err = c.findSession(ctx, dialOut)
session, err = c.findSession(ctx)
if session == nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
@ -91,7 +94,7 @@ func (c *Client) CreateStream(ctx context.Context, dialOut func(ctx context.Cont
return streamC, nil
}
func (c *Client) findSession(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error)) (*Session, error) {
func (c *Client) findSession(ctx context.Context) (*Session, error) {
var idle *Session
c.idleSessionLock.Lock()
@ -103,14 +106,14 @@ func (c *Client) findSession(ctx context.Context, dialOut func(ctx context.Conte
c.idleSessionLock.Unlock()
if idle == nil {
s, err := c.createSession(ctx, dialOut)
s, err := c.createSession(ctx)
return s, err
}
return idle, nil
}
func (c *Client) createSession(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error)) (*Session, error) {
underlying, err := dialOut(ctx)
func (c *Client) createSession(ctx context.Context) (*Session, error) {
underlying, err := c.dialOut(ctx)
if err != nil {
return nil, err
}