diff --git a/adapter/outbound/base.go b/adapter/outbound/base.go index d7ffec5a..03e9e6ca 100644 --- a/adapter/outbound/base.go +++ b/adapter/outbound/base.go @@ -4,12 +4,15 @@ import ( "context" "encoding/json" "errors" - "github.com/gofrs/uuid" "net" "strings" "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" + + "github.com/gofrs/uuid" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/network" ) type Base struct { @@ -166,7 +169,7 @@ func NewBase(opt BaseOption) *Base { } type conn struct { - net.Conn + network.ExtendedConn chain C.Chain actualRemoteDestination string } @@ -185,8 +188,15 @@ func (c *conn) AppendToChains(a C.ProxyAdapter) { c.chain = append(c.chain, a.Name()) } +func (c *conn) Upstream() any { + if wrapper, ok := c.ExtendedConn.(*bufio.ExtendedConnWrapper); ok { + return wrapper.Conn + } + return c.ExtendedConn +} + func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn { - return &conn{c, []string{a.Name()}, parseRemoteDestination(a.Addr())} + return &conn{bufio.NewExtendedConn(c), []string{a.Name()}, parseRemoteDestination(a.Addr())} } type packetConn struct { diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index c401999f..a36f5f57 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -14,6 +14,8 @@ import ( "github.com/Dreamacro/clash/transport/gun" "github.com/Dreamacro/clash/transport/trojan" "github.com/Dreamacro/clash/transport/vless" + + "github.com/sagernet/sing/common/bufio" ) type Trojan struct { @@ -95,7 +97,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) return c, err } err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)) - return c, err + return bufio.NewExtendedConn(c), err } // DialContext implements C.ProxyAdapter diff --git a/common/net/bufconn.go b/common/net/bufconn.go index a50c7f03..bcac11d3 100644 --- a/common/net/bufconn.go +++ b/common/net/bufconn.go @@ -3,18 +3,24 @@ package net import ( "bufio" "net" + + "github.com/sagernet/sing/common/buf" + sing_bufio "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/network" ) +var _ network.ExtendedConn = (*BufferedConn)(nil) + type BufferedConn struct { r *bufio.Reader - net.Conn + network.ExtendedConn } func NewBufferedConn(c net.Conn) *BufferedConn { if bc, ok := c.(*BufferedConn); ok { return bc } - return &BufferedConn{bufio.NewReader(c), c} + return &BufferedConn{bufio.NewReader(c), sing_bufio.NewExtendedConn(c)} } // Reader returns the internal bufio.Reader. @@ -42,3 +48,18 @@ func (c *BufferedConn) UnreadByte() error { func (c *BufferedConn) Buffered() int { return c.r.Buffered() } + +func (c *BufferedConn) ReadBuffer(buffer *buf.Buffer) (err error) { + if c.r.Buffered() > 0 { + _, err = buffer.ReadOnceFrom(c.r) + return + } + return c.ExtendedConn.ReadBuffer(buffer) +} + +func (c *BufferedConn) Upstream() any { + if wrapper, ok := c.ExtendedConn.(*sing_bufio.ExtendedConnWrapper); ok { + return wrapper.Conn + } + return c.ExtendedConn +} diff --git a/transport/vless/conn.go b/transport/vless/conn.go index 5ee69611..72d14d0c 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -1,7 +1,6 @@ package vless import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -9,12 +8,16 @@ import ( "net" "github.com/gofrs/uuid" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/network" xtls "github.com/xtls/go" "google.golang.org/protobuf/proto" ) type Conn struct { - net.Conn + network.ExtendedConn dst *DstAddr id *uuid.UUID addons *Addons @@ -23,57 +26,82 @@ type Conn struct { func (vc *Conn) Read(b []byte) (int, error) { if vc.received { - return vc.Conn.Read(b) + return vc.ExtendedConn.Read(b) } if err := vc.recvResponse(); err != nil { return 0, err } vc.received = true - return vc.Conn.Read(b) + return vc.ExtendedConn.Read(b) } -func (vc *Conn) sendRequest() error { - buf := &bytes.Buffer{} +func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { + if vc.received { + return vc.ExtendedConn.ReadBuffer(buffer) + } - buf.WriteByte(Version) // protocol version - buf.Write(vc.id.Bytes()) // 16 bytes of uuid + if err := vc.recvResponse(); err != nil { + return err + } + vc.received = true + return vc.ExtendedConn.ReadBuffer(buffer) +} +func (vc *Conn) sendRequest() (err error) { + requestLen := 1 // protocol version + requestLen += 16 // UUID + requestLen += 1 // addons length + var addonsBytes []byte if vc.addons != nil { - bytes, err := proto.Marshal(vc.addons) + addonsBytes, err = proto.Marshal(vc.addons) if err != nil { return err } - - buf.WriteByte(byte(len(bytes))) - buf.Write(bytes) - } else { - buf.WriteByte(0) // addon data length. 0 means no addon data } + requestLen += len(addonsBytes) + requestLen += 1 // command + if !vc.dst.Mux { + requestLen += 2 // port + requestLen += 1 // addr type + requestLen += len(vc.dst.Addr) + } + _buffer := buf.StackNewSize(requestLen) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + + common.Must( + buffer.WriteByte(Version), // protocol version + common.Error(buffer.Write(vc.id.Bytes())), // 16 bytes of uuid + buffer.WriteByte(byte(len(addonsBytes))), + common.Error(buffer.Write(addonsBytes)), + ) if vc.dst.Mux { - buf.WriteByte(CommandMux) + common.Must(buffer.WriteByte(CommandMux)) } else { if vc.dst.UDP { - buf.WriteByte(CommandUDP) + common.Must(buffer.WriteByte(CommandUDP)) } else { - buf.WriteByte(CommandTCP) + common.Must(buffer.WriteByte(CommandTCP)) } - // Port AddrType Addr - binary.Write(buf, binary.BigEndian, vc.dst.Port) - buf.WriteByte(vc.dst.AddrType) - buf.Write(vc.dst.Addr) + binary.BigEndian.PutUint16(buffer.Extend(2), vc.dst.Port) + common.Must( + buffer.WriteByte(vc.dst.AddrType), + common.Error(buffer.Write(vc.dst.Addr)), + ) } - _, err := vc.Conn.Write(buf.Bytes()) - return err + _, err = vc.ExtendedConn.Write(buffer.Bytes()) + return } func (vc *Conn) recvResponse() error { var err error - buf := make([]byte, 1) - _, err = io.ReadFull(vc.Conn, buf) + var buf [1]byte + _, err = io.ReadFull(vc.ExtendedConn, buf[:]) if err != nil { return err } @@ -82,25 +110,32 @@ func (vc *Conn) recvResponse() error { return errors.New("unexpected response version") } - _, err = io.ReadFull(vc.Conn, buf) + _, err = io.ReadFull(vc.ExtendedConn, buf[:]) if err != nil { return err } length := int64(buf[0]) if length != 0 { // addon data length > 0 - io.CopyN(io.Discard, vc.Conn, length) // just discard + io.CopyN(io.Discard, vc.ExtendedConn, length) // just discard } return nil } +func (vc *Conn) Upstream() any { + if wrapper, ok := vc.ExtendedConn.(*bufio.ExtendedConnWrapper); ok { + return wrapper.Conn + } + return vc.ExtendedConn +} + // newConn return a Conn instance func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { c := &Conn{ - Conn: conn, - id: client.uuid, - dst: dst, + ExtendedConn: bufio.NewExtendedConn(conn), + id: client.uuid, + dst: dst, } if !dst.UDP && client.Addons != nil { diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index b7b369fd..735ea7f2 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -5,9 +5,11 @@ import ( "context" "crypto/tls" "encoding/base64" + "encoding/binary" "errors" "fmt" "io" + "math/rand" "net" "net/http" "net/url" @@ -15,15 +17,24 @@ import ( "strings" "sync" "time" + _ "unsafe" "github.com/gorilla/websocket" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/network" ) +//go:linkname maskBytes github.com/gorilla/websocket.maskBytes +func maskBytes(key [4]byte, pos int, b []byte) int + type websocketConn struct { conn *websocket.Conn reader io.Reader remoteAddr net.Addr + rawWriter network.ExtendedWriter + // https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency rMux sync.Mutex wMux sync.Mutex @@ -31,6 +42,7 @@ type websocketConn struct { type websocketWithEarlyDataConn struct { net.Conn + wsWriter network.ExtendedWriter underlay net.Conn closed bool dialed chan bool @@ -79,6 +91,54 @@ func (wsc *websocketConn) Write(b []byte) (int, error) { return len(b), nil } +func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error { + var payloadBitLength int + dataLen := buffer.Len() + data := buffer.Bytes() + if dataLen < 126 { + payloadBitLength = 1 + } else if dataLen < 65536 { + payloadBitLength = 3 + } else { + payloadBitLength = 9 + } + + var headerLen int + headerLen += 1 // FIN / RSV / OPCODE + headerLen += payloadBitLength + headerLen += 4 // MASK KEY + + header := buffer.ExtendHeader(headerLen) + header[0] = websocket.BinaryMessage | 1<<7 + header[1] = 1 << 7 + + if dataLen < 126 { + header[1] |= byte(dataLen) + } else if dataLen < 65536 { + header[1] |= 126 + binary.BigEndian.PutUint16(header[2:], uint16(dataLen)) + } else { + header[1] |= 127 + binary.BigEndian.PutUint64(header[2:], uint64(dataLen)) + } + + maskKey := rand.Uint32() + binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey) + maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data) + + wsc.wMux.Lock() + defer wsc.wMux.Unlock() + return wsc.rawWriter.WriteBuffer(buffer) +} + +func (wsc *websocketConn) FrontHeadroom() int { + return 14 +} + +func (wsc *websocketConn) Upstream() any { + return wsc.conn.UnderlyingConn() +} + func (wsc *websocketConn) Close() error { var errors []string if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { @@ -149,6 +209,7 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { } wsedc.dialed <- true + wsedc.wsWriter = bufio.NewExtendedWriter(wsedc.Conn) if earlyDataBuf.Len() != 0 { _, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) } @@ -170,6 +231,20 @@ func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) { return wsedc.Conn.Write(b) } +func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error { + if wsedc.closed { + return io.ErrClosedPipe + } + if wsedc.Conn == nil { + if err := wsedc.Dial(buffer.Bytes()); err != nil { + return err + } + return nil + } + + return wsedc.wsWriter.WriteBuffer(buffer) +} + func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { if wsedc.closed { return 0, io.ErrClosedPipe @@ -228,6 +303,10 @@ func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error { return wsedc.Conn.SetWriteDeadline(t) } +func (wsedc *websocketWithEarlyDataConn) Upstream() any { + return wsedc.Conn +} + func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { ctx, cancel := context.WithCancel(context.Background()) conn = &websocketWithEarlyDataConn{ @@ -294,6 +373,7 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf return &websocketConn{ conn: wsConn, + rawWriter: bufio.NewExtendedWriter(wsConn.UnderlyingConn()), remoteAddr: conn.RemoteAddr(), }, nil } diff --git a/tunnel/connection.go b/tunnel/connection.go index c63bab78..1747597e 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -1,14 +1,16 @@ package tunnel import ( + "context" "errors" "net" "net/netip" "time" - N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" C "github.com/Dreamacro/clash/constant" + + "github.com/sagernet/sing/common/bufio" ) func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error { @@ -60,5 +62,5 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr, } func handleSocket(ctx C.ConnContext, outbound net.Conn) { - N.Relay(ctx.Conn(), outbound) + bufio.CopyConn(context.TODO(), ctx.Conn(), outbound) } diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index fc627297..32d44f0c 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -7,6 +7,9 @@ import ( C "github.com/Dreamacro/clash/constant" "github.com/gofrs/uuid" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/network" "go.uber.org/atomic" ) @@ -29,7 +32,9 @@ type trackerInfo struct { type tcpTracker struct { C.Conn `json:"-"` *trackerInfo - manager *Manager + manager *Manager + extendedReader network.ExtendedReader + extendedWriter network.ExtendedWriter } func (tt *tcpTracker) ID() string { @@ -44,6 +49,14 @@ func (tt *tcpTracker) Read(b []byte) (int, error) { return n, err } +func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) { + err = tt.extendedReader.ReadBuffer(buffer) + download := int64(buffer.Len()) + tt.manager.PushDownloaded(download) + tt.DownloadTotal.Add(download) + return +} + func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) @@ -52,11 +65,26 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { return n, err } +func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) { + err = tt.extendedWriter.WriteBuffer(buffer) + var upload int64 + if err != nil { + upload = int64(buffer.Len()) + } + tt.manager.PushUploaded(upload) + tt.UploadTotal.Add(upload) + return +} + func (tt *tcpTracker) Close() error { tt.manager.Leave(tt) return tt.Conn.Close() } +func (tt *tcpTracker) Upstream() any { + return tt.Conn +} + func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { uuid, _ := uuid.NewV4() if conn != nil { @@ -79,6 +107,8 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R UploadTotal: atomic.NewInt64(0), DownloadTotal: atomic.NewInt64(0), }, + extendedReader: bufio.NewExtendedReader(conn), + extendedWriter: bufio.NewExtendedWriter(conn), } if rule != nil {