chore: using http/httptrace to get local/remoteAddr for grpc client

This commit is contained in:
wwqgtxx 2025-04-03 19:47:49 +08:00
parent 7b37fcfc8d
commit 23ffe451f4
3 changed files with 35 additions and 36 deletions

View file

@ -13,6 +13,7 @@ import (
"io"
"net"
"net/http"
"net/http/httptrace"
"net/url"
"sync"
"time"
@ -38,7 +39,7 @@ var defaultHeader = http.Header{
type DialFn = func(network, addr string) (net.Conn, error)
type Conn struct {
initFn func() (io.ReadCloser, error)
initFn func() (io.ReadCloser, netAddr, error)
writer io.Writer
flusher http.Flusher
netAddr
@ -60,7 +61,7 @@ type Config struct {
}
func (g *Conn) initReader() {
reader, err := g.initFn()
reader, addr, err := g.initFn()
if err != nil {
g.err = err
if closer, ok := g.writer.(io.Closer); ok {
@ -68,6 +69,7 @@ func (g *Conn) initReader() {
}
return
}
g.netAddr = addr
if !g.close.Load() {
g.reader = reader
@ -209,15 +211,11 @@ func (g *Conn) SetDeadline(t time.Time) error {
}
func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, Fingerprint string, realityConfig *tlsC.RealityConfig) *TransportWrap {
wrap := TransportWrap{}
dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
pconn, err := dialFn(network, addr)
if err != nil {
return nil, err
}
wrap.remoteAddr = pconn.RemoteAddr()
wrap.localAddr = pconn.LocalAddr()
if tlsConfig == nil {
return pconn, nil
@ -269,15 +267,17 @@ func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, Fingerprint string, re
return conn, nil
}
wrap.Transport = &http2.Transport{
transport := &http2.Transport{
DialTLSContext: dialFunc,
TLSClientConfig: tlsConfig,
AllowHTTP: false,
DisableCompression: true,
PingTimeout: 0,
}
return &wrap
wrap := &TransportWrap{
Transport: transport,
}
return wrap
}
func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, error) {
@ -304,15 +304,22 @@ func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, er
}
conn := &Conn{
initFn: func() (io.ReadCloser, error) {
initFn: func() (io.ReadCloser, netAddr, error) {
nAddr := netAddr{}
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
nAddr.localAddr = connInfo.Conn.LocalAddr()
nAddr.remoteAddr = connInfo.Conn.RemoteAddr()
},
}
request = request.WithContext(httptrace.WithClientTrace(request.Context(), trace))
response, err := transport.RoundTrip(request)
if err != nil {
return nil, err
return nil, nAddr, err
}
return response.Body, nil
return response.Body, nAddr, nil
},
writer: writer,
netAddr: transport.netAddr,
writer: writer,
}
go conn.Init()

View file

@ -43,21 +43,22 @@ func NewServerHandler(options ServerOption) http.Handler {
writer.WriteHeader(http.StatusOK)
conn := &Conn{
initFn: func() (io.ReadCloser, error) {
return request.Body, nil
initFn: func() (io.ReadCloser, netAddr, error) {
nAddr := netAddr{}
if request.RemoteAddr != "" {
metadata := C.Metadata{}
if err := metadata.SetRemoteAddress(request.RemoteAddr); err == nil {
nAddr.remoteAddr = net.TCPAddrFromAddrPort(metadata.AddrPort())
}
}
if addr, ok := request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
nAddr.localAddr = addr
}
return request.Body, nAddr, nil
},
writer: writer,
flusher: writer.(http.Flusher),
}
if request.RemoteAddr != "" {
metadata := C.Metadata{}
if err := metadata.SetRemoteAddress(request.RemoteAddr); err == nil {
conn.remoteAddr = net.TCPAddrFromAddrPort(metadata.AddrPort())
}
}
if addr, ok := request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
conn.localAddr = addr
}
wrapper := &h2ConnWrapper{
// gun.Conn can't correct handle ReadDeadline

View file

@ -7,15 +7,6 @@ import (
type TransportWrap struct {
*http2.Transport
netAddr
}
func (tw *TransportWrap) RemoteAddr() net.Addr {
return tw.remoteAddr
}
func (tw *TransportWrap) LocalAddr() net.Addr {
return tw.localAddr
}
type netAddr struct {
@ -23,10 +14,10 @@ type netAddr struct {
localAddr net.Addr
}
func (addr *netAddr) RemoteAddr() net.Addr {
func (addr netAddr) RemoteAddr() net.Addr {
return addr.remoteAddr
}
func (addr *netAddr) LocalAddr() net.Addr {
func (addr netAddr) LocalAddr() net.Addr {
return addr.localAddr
}