From 6236cb1cf0f589eabf158346511f973346502a44 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sat, 19 Apr 2025 01:32:55 +0800 Subject: [PATCH] chore: cleanup trojan code --- adapter/outbound/trojan.go | 101 +++++++++++++++++-------------------- adapter/outbound/vless.go | 3 +- adapter/outbound/vmess.go | 11 ++-- 3 files changed, 55 insertions(+), 60 deletions(-) diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index 49bc6cd2..257b58d6 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -62,10 +62,16 @@ type TrojanSSOption struct { Password string `proxy:"password,omitempty"` } -func (t *Trojan) plainStream(ctx context.Context, c net.Conn) (net.Conn, error) { +// StreamConnContext implements C.ProxyAdapter +func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { var err error - if t.option.Network == "ws" { + if tlsC.HaveGlobalFingerprint() && len(t.option.ClientFingerprint) == 0 { + t.option.ClientFingerprint = tlsC.GetGlobalFingerprint() + } + + switch t.option.Network { + case "ws": host, port, _ := net.SplitHostPort(t.addr) wsOpts := &vmess.WebsocketConfig{ @@ -108,46 +114,46 @@ func (t *Trojan) plainStream(ctx context.Context, c net.Conn) (net.Conn, error) return nil, err } - return vmess.StreamWebsocketConn(ctx, c, wsOpts) - } - - alpn := trojan.DefaultALPN - if len(t.option.ALPN) != 0 { - alpn = t.option.ALPN - } - return vmess.StreamTLSConn(ctx, c, &vmess.TLSConfig{ - Host: t.option.SNI, - SkipCertVerify: t.option.SkipCertVerify, - FingerPrint: t.option.Fingerprint, - ClientFingerprint: t.option.ClientFingerprint, - NextProtos: alpn, - Reality: t.realityConfig, - }) -} - -// StreamConnContext implements C.ProxyAdapter -func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { - var err error - - if tlsC.HaveGlobalFingerprint() && len(t.option.ClientFingerprint) == 0 { - t.option.ClientFingerprint = tlsC.GetGlobalFingerprint() - } - - if t.transport != nil { + c, err = vmess.StreamWebsocketConn(ctx, c, wsOpts) + case "grpc": c, err = gun.StreamGunWithConn(c, t.gunTLSConfig, t.gunConfig, t.realityConfig) - } else { - c, err = t.plainStream(ctx, c) + default: + // default tcp network + // handle TLS + alpn := trojan.DefaultALPN + if len(t.option.ALPN) != 0 { + alpn = t.option.ALPN + } + c, err = vmess.StreamTLSConn(ctx, c, &vmess.TLSConfig{ + Host: t.option.SNI, + SkipCertVerify: t.option.SkipCertVerify, + FingerPrint: t.option.Fingerprint, + ClientFingerprint: t.option.ClientFingerprint, + NextProtos: alpn, + Reality: t.realityConfig, + }) } - if err != nil { return nil, fmt.Errorf("%s connect error: %w", t.addr, err) } + return t.streamConnContext(ctx, c, metadata) +} + +func (t *Trojan) streamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { if t.ssCipher != nil { c = t.ssCipher.StreamConn(c) } - err = t.writeHeaderContext(ctx, c, metadata) + if ctx.Done() != nil { + done := N.SetupContextForConn(ctx, c) + defer done(&err) + } + command := trojan.CommandTCP + if metadata.NetWork == C.UDP { + command = trojan.CommandUDP + } + err = trojan.WriteHeader(c, t.hexPassword, command, serializesSocksAddr(metadata)) return c, err } @@ -166,19 +172,19 @@ func (t *Trojan) writeHeaderContext(ctx context.Context, c net.Conn, metadata *C // DialContext implements C.ProxyAdapter func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { + var c net.Conn // gun transport if t.transport != nil && dialer.IsZeroOptions(opts) { - c, err := gun.StreamGunWithTransport(t.transport, t.gunConfig) + c, err = gun.StreamGunWithTransport(t.transport, t.gunConfig) if err != nil { return nil, err } + defer func(c net.Conn) { + safeConnClose(c, err) + }(c) - if t.ssCipher != nil { - c = t.ssCipher.StreamConn(c) - } - - if err = t.writeHeaderContext(ctx, c, metadata); err != nil { - c.Close() + c, err = t.streamConnContext(ctx, c, metadata) + if err != nil { return nil, err } @@ -226,11 +232,7 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata, safeConnClose(c, err) }(c) - if t.ssCipher != nil { - c = t.ssCipher.StreamConn(c) - } - - err = t.writeHeaderContext(ctx, c, metadata) + c, err = t.streamConnContext(ctx, c, metadata) if err != nil { return nil, err } @@ -256,16 +258,7 @@ func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, me defer func(c net.Conn) { safeConnClose(c, err) }(c) - c, err = t.plainStream(ctx, c) - if err != nil { - return nil, fmt.Errorf("%s connect error: %w", t.addr, err) - } - - if t.ssCipher != nil { - c = t.ssCipher.StreamConn(c) - } - - err = t.writeHeaderContext(ctx, c, metadata) + c, err = t.StreamConnContext(ctx, c, metadata) if err != nil { return nil, err } diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 079d7bc2..9cdccef9 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -232,9 +232,10 @@ func (v *Vless) streamTLSConn(ctx context.Context, conn net.Conn, isH2 bool) (ne // DialContext implements C.ProxyAdapter func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { + var c net.Conn // gun transport if v.transport != nil && dialer.IsZeroOptions(opts) { - c, err := gun.StreamGunWithTransport(v.transport, v.gunConfig) + c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) if err != nil { return nil, err } diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 4db0cedd..42b8a434 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -226,10 +226,10 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M if err != nil { return nil, err } - return v.streamConnConntext(ctx, c, metadata) + return v.streamConnContext(ctx, c, metadata) } -func (v *Vmess) streamConnConntext(ctx context.Context, c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) { +func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) { useEarly := N.NeedHandshake(c) if !useEarly { if ctx.Done() != nil { @@ -287,9 +287,10 @@ func (v *Vmess) streamConnConntext(ctx context.Context, c net.Conn, metadata *C. // DialContext implements C.ProxyAdapter func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { + var c net.Conn // gun transport if v.transport != nil && dialer.IsZeroOptions(opts) { - c, err := gun.StreamGunWithTransport(v.transport, v.gunConfig) + c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) if err != nil { return nil, err } @@ -297,7 +298,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d safeConnClose(c, err) }(c) - c, err = v.streamConnConntext(ctx, c, metadata) + c, err = v.streamConnContext(ctx, c, metadata) if err != nil { return nil, err } @@ -348,7 +349,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o safeConnClose(c, err) }(c) - c, err = v.streamConnConntext(ctx, c, metadata) + c, err = v.streamConnContext(ctx, c, metadata) if err != nil { return nil, fmt.Errorf("new vmess client error: %v", err) }