diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 720a5655..b57ff4f3 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -62,9 +62,9 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, request.RequestURI = "" if isUpgradeRequest(request) { - handleUpgrade(conn, request, in) - - return // hijack connection + if resp = handleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil { + return // hijack connection + } } removeHopByHopHeaders(request.Header) @@ -96,7 +96,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, } } - conn.Close() + _ = conn.Close() } func authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response { diff --git a/listener/http/upgrade.go b/listener/http/upgrade.go index 643de541..c12fc33d 100644 --- a/listener/http/upgrade.go +++ b/listener/http/upgrade.go @@ -1,9 +1,12 @@ package http import ( + "context" + "crypto/tls" "net" "net/http" "strings" + "time" "github.com/Dreamacro/clash/adapter/inbound" N "github.com/Dreamacro/clash/common/net" @@ -15,15 +18,17 @@ func isUpgradeRequest(req *http.Request) bool { return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") } -func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext) { - defer conn.Close() - +func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { removeProxyHeaders(request.Header) removeExtraHTTPHostPort(request) address := request.Host if _, _, err := net.SplitHostPort(address); err != nil { - address = net.JoinHostPort(address, "80") + port := "80" + if request.TLS != nil { + port = "443" + } + address = net.JoinHostPort(address, port) } dstAddr := socks5.ParseAddr(address) @@ -35,27 +40,56 @@ func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext in <- inbound.NewHTTP(dstAddr, conn.RemoteAddr(), right) - bufferedLeft := N.NewBufferedConn(left) - defer bufferedLeft.Close() + var remoteServer *N.BufferedConn + if request.TLS != nil { + tlsConn := tls.Client(left, &tls.Config{ + ServerName: request.URL.Hostname(), + }) - err := request.Write(bufferedLeft) + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + if tlsConn.HandshakeContext(ctx) != nil { + _ = localConn.Close() + _ = left.Close() + return + } + + remoteServer = N.NewBufferedConn(tlsConn) + } else { + remoteServer = N.NewBufferedConn(left) + } + defer func() { + _ = remoteServer.Close() + }() + + err := request.Write(remoteServer) if err != nil { + _ = localConn.Close() return } - resp, err := http.ReadResponse(bufferedLeft.Reader(), request) - if err != nil { - return - } - - removeProxyHeaders(resp.Header) - - err = resp.Write(conn) + resp, err = http.ReadResponse(remoteServer.Reader(), request) if err != nil { + _ = localConn.Close() return } if resp.StatusCode == http.StatusSwitchingProtocols { - N.Relay(bufferedLeft, conn) + removeProxyHeaders(resp.Header) + + err = localConn.SetReadDeadline(time.Time{}) // set to not time out + if err != nil { + return + } + + err = resp.Write(localConn) + if err != nil { + return + } + + N.Relay(remoteServer, localConn) // blocking here + _ = localConn.Close() + resp = nil } + return } diff --git a/listener/tun/ipstack/system/mars/nat/table.go b/listener/tun/ipstack/system/mars/nat/table.go index 9c1b32cd..38b7d6c6 100644 --- a/listener/tun/ipstack/system/mars/nat/table.go +++ b/listener/tun/ipstack/system/mars/nat/table.go @@ -2,7 +2,6 @@ package nat import ( "net/netip" - "sync" "github.com/Dreamacro/clash/common/generics/list" ) @@ -25,7 +24,6 @@ type binding struct { } type table struct { - mu sync.Mutex tuples map[tuple]*list.Element[*binding] ports [portLength]*list.Element[*binding] available *list.List[*binding] @@ -39,13 +37,13 @@ func (t *table) tupleOf(port uint16) tuple { elm := t.ports[offset] + t.available.MoveToFront(elm) + return elm.Value.tuple } func (t *table) portOf(tuple tuple) uint16 { - t.mu.Lock() elm := t.tuples[tuple] - t.mu.Unlock() if elm == nil { return 0 } @@ -59,11 +57,8 @@ func (t *table) newConn(tuple tuple) uint16 { elm := t.available.Back() b := elm.Value - t.mu.Lock() delete(t.tuples, b.tuple) t.tuples[tuple] = elm - t.mu.Unlock() - b.tuple = tuple t.available.MoveToFront(elm) @@ -71,19 +66,6 @@ func (t *table) newConn(tuple tuple) uint16 { return portBegin + b.offset } -func (t *table) delete(tup tuple) { - t.mu.Lock() - elm := t.tuples[tup] - if elm == nil { - t.mu.Unlock() - return - } - delete(t.tuples, tup) - t.mu.Unlock() - - t.available.MoveToBack(elm) -} - func newTable() *table { result := &table{ tuples: make(map[tuple]*list.Element[*binding], portLength), diff --git a/listener/tun/ipstack/system/mars/nat/tcp.go b/listener/tun/ipstack/system/mars/nat/tcp.go index 48ad3e43..cc0abe7d 100644 --- a/listener/tun/ipstack/system/mars/nat/tcp.go +++ b/listener/tun/ipstack/system/mars/nat/tcp.go @@ -16,8 +16,6 @@ type conn struct { net.Conn tuple tuple - - close func(tuple tuple) } func (t *TCP) Accept() (net.Conn, error) { @@ -39,9 +37,6 @@ func (t *TCP) Accept() (net.Conn, error) { return &conn{ Conn: c, tuple: tup, - close: func(tuple tuple) { - t.table.delete(tuple) - }, }, nil } @@ -57,11 +52,6 @@ func (t *TCP) SetDeadline(time time.Time) error { return t.listener.SetDeadline(time) } -func (c *conn) Close() error { - c.close(c.tuple) - return c.Conn.Close() -} - func (c *conn) LocalAddr() net.Addr { return &net.TCPAddr{ IP: c.tuple.SourceAddr.Addr().AsSlice(),