From 4e2d000a0bc1fbf8d744379665c02d336b7a6d3c Mon Sep 17 00:00:00 2001 From: NyaMisty Date: Thu, 20 Jun 2024 22:25:30 +0000 Subject: [PATCH] feat: improve authentication parsing logic in http listener - credentials are always parsed - fix IN-USER matching for non CONNECT http proxy requests - fix authentication caching by including username in cache --- listener/http/proxy.go | 90 +++++++++++++++++++++++++++-------------- listener/http/server.go | 4 +- listener/mixed/mixed.go | 6 +-- 3 files changed, 65 insertions(+), 35 deletions(-) diff --git a/listener/http/proxy.go b/listener/http/proxy.go index c77f9230..a2161ac1 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -31,9 +31,13 @@ func (b *bodyWrapper) Read(p []byte) (n int, err error) { return n, err } -func HandleConn(c net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, bool], additions ...inbound.Addition) { - client := newClient(c, tunnel, additions...) - defer client.CloseIdleConnections() +func HandleConn(c net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, AuthResult], additions ...inbound.Addition) { + var client *http.Client // create the outbound client on-demand + defer func() { + if client != nil { + client.CloseIdleConnections() + } + }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() peekMutex := sync.Mutex{} @@ -42,6 +46,7 @@ func HandleConn(c net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, bool], keepAlive := true trusted := cache == nil // disable authenticate if lru is nil + lastUser := "" for keepAlive { peekMutex.Lock() @@ -57,12 +62,10 @@ func HandleConn(c net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, bool], var resp *http.Response - if !trusted { - var user string - resp, user = authenticate(request, cache) - additions = append(additions, inbound.WithInUser(user)) - trusted = resp == nil - } + var user string + resp, user = authenticate(request, cache) // always call authenticate function to get user + trusted = trusted || resp == nil + additions = append(additions, inbound.WithInUser(user)) if trusted { if request.Method == http.MethodConnect { @@ -89,6 +92,15 @@ func HandleConn(c net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, bool], return // hijack connection } + // ensure there is a client with correct additions + // when the authenticated user changed, outbound client should also get rebuilt + if client == nil || user != lastUser { + if client != nil { + client.CloseIdleConnections() + } + client = newClient(c, tunnel, additions...) + } + removeHopByHopHeaders(request.Header) removeExtraHTTPHostPort(request) @@ -138,33 +150,51 @@ func HandleConn(c net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, bool], _ = conn.Close() } -func authenticate(request *http.Request, cache *lru.LruCache[string, bool]) (resp *http.Response, u string) { +type AuthResult struct { + user string + authed bool +} + +func authenticate(request *http.Request, cache *lru.LruCache[string, AuthResult]) (resp *http.Response, u string) { authenticator := authStore.Authenticator() if inbound.SkipAuthRemoteAddress(request.RemoteAddr) { authenticator = nil } - if authenticator != nil { - credential := parseBasicProxyAuthorization(request) - if credential == "" { - resp := responseWith(request, http.StatusProxyAuthRequired) - resp.Header.Set("Proxy-Authenticate", "Basic") - return resp, "" - } - - authed, exist := cache.Get(credential) - if !exist { - user, pass, err := decodeBasicProxyAuthorization(credential) - authed = err == nil && authenticator.Verify(user, pass) - u = user - cache.Set(credential, authed) - } - if !authed { - log.Infoln("Auth failed from %s", request.RemoteAddr) - - return responseWith(request, http.StatusForbidden), u - } + credential := parseBasicProxyAuthorization(request) + if credential == "" && authenticator != nil { + resp := responseWith(request, http.StatusProxyAuthRequired) + resp.Header.Set("Proxy-Authenticate", "Basic") + return resp, "" } + var authret AuthResult + exist := false + if cache != nil { + authret, exist = cache.Get(credential) + } + if !exist { + user, pass, err := decodeBasicProxyAuthorization(credential) + authed := false + if authenticator == nil { + // skipped authentication + authed = true + } else if err == nil && authenticator.Verify(user, pass) { + authed = true + } + authret = AuthResult{ + user: user, + authed: authed, + } + cache.Set(credential, authret) + } + u = authret.user + if !authret.authed { + log.Infoln("Auth failed from %s", request.RemoteAddr) + + return responseWith(request, http.StatusForbidden), u + } + log.Infoln("Auth success from %s -> %s", request.RemoteAddr, u) + return nil, u } diff --git a/listener/http/server.go b/listener/http/server.go index 8fc9da59..c1c50dff 100644 --- a/listener/http/server.go +++ b/listener/http/server.go @@ -50,9 +50,9 @@ func NewWithAuthenticate(addr string, tunnel C.Tunnel, authenticate bool, additi return nil, err } - var c *lru.LruCache[string, bool] + var c *lru.LruCache[string, AuthResult] if authenticate { - c = lru.New[string, bool](lru.WithAge[string, bool](30)) + c = lru.New[string, AuthResult](lru.WithAge[string, AuthResult](30)) } hl := &Listener{ diff --git a/listener/mixed/mixed.go b/listener/mixed/mixed.go index 367b7a36..2bdac6cf 100644 --- a/listener/mixed/mixed.go +++ b/listener/mixed/mixed.go @@ -16,7 +16,7 @@ import ( type Listener struct { listener net.Listener addr string - cache *lru.LruCache[string, bool] + cache *lru.LruCache[string, http.AuthResult] closed bool } @@ -53,7 +53,7 @@ func New(addr string, tunnel C.Tunnel, additions ...inbound.Addition) (*Listener ml := &Listener{ listener: l, addr: addr, - cache: lru.New[string, bool](lru.WithAge[string, bool](30)), + cache: lru.New[string, http.AuthResult](lru.WithAge[string, http.AuthResult](30)), } go func() { for { @@ -77,7 +77,7 @@ func New(addr string, tunnel C.Tunnel, additions ...inbound.Addition) (*Listener return ml, nil } -func handleConn(conn net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, bool], additions ...inbound.Addition) { +func handleConn(conn net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, http.AuthResult], additions ...inbound.Addition) { N.TCPKeepAlive(conn) bufConn := N.NewBufferedConn(conn)