Compare commits

..

No commits in common. "Alpha" and "v1.18.8" have entirely different histories.

290 changed files with 3907 additions and 13068 deletions

View file

@ -1,17 +1,17 @@
[Unit] [Unit]
Description=mihomo Daemon, Another Clash Kernel. Description=mihomo Daemon, Another Clash Kernel.
Documentation=https://wiki.metacubex.one After=network.target NetworkManager.service systemd-networkd.service iwd.service
After=network.target nss-lookup.target network-online.target
[Service] [Service]
Type=simple Type=simple
LimitNPROC=500
LimitNOFILE=1000000
CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE
AmbientCapabilities=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE AmbientCapabilities=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE
Restart=always
ExecStartPre=/usr/bin/sleep 2s
ExecStart=/usr/bin/mihomo -d /etc/mihomo ExecStart=/usr/bin/mihomo -d /etc/mihomo
ExecReload=/bin/kill -HUP $MAINPID ExecReload=/bin/kill -HUP $MAINPID
Restart=on-failure
RestartSec=10
LimitNOFILE=infinity
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target

View file

@ -1,17 +0,0 @@
[Unit]
Description=mihomo Daemon, Another Clash Kernel.
Documentation=https://wiki.metacubex.one
After=network.target nss-lookup.target network-online.target
[Service]
Type=simple
CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE
AmbientCapabilities=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE
ExecStart=/usr/bin/mihomo -d /etc/mihomo
ExecReload=/bin/kill -HUP $MAINPID
Restart=on-failure
RestartSec=10
LimitNOFILE=infinity
[Install]
WantedBy=multi-user.target

View file

@ -14,7 +14,7 @@ on:
- Alpha - Alpha
tags: tags:
- "v*" - "v*"
pull_request: pull_request_target:
branches: branches:
- Alpha - Alpha
concurrency: concurrency:
@ -54,6 +54,7 @@ jobs:
- { goos: windows, goarch: '386', output: '386' } - { goos: windows, goarch: '386', output: '386' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible } - { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64 } - { goos: windows, goarch: amd64, goamd64: v3, output: amd64 }
- { goos: windows, goarch: arm, goarm: '7', output: armv7 }
- { goos: windows, goarch: arm64, output: arm64 } - { goos: windows, goarch: arm64, output: arm64 }
- { goos: freebsd, goarch: '386', output: '386' } - { goos: freebsd, goarch: '386', output: '386' }
@ -66,12 +67,6 @@ jobs:
- { goos: android, goarch: arm, ndk: armv7a-linux-androideabi34, output: armv7 } - { goos: android, goarch: arm, ndk: armv7a-linux-androideabi34, output: armv7 }
- { goos: android, goarch: arm64, ndk: aarch64-linux-android34, output: arm64-v8 } - { goos: android, goarch: arm64, ndk: aarch64-linux-android34, output: arm64-v8 }
# Go 1.23 with special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.23/
- { goos: windows, goarch: '386', output: '386-go123', goversion: '1.23' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible-go123, goversion: '1.23' }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-go123, goversion: '1.23' }
# Go 1.22 with special patch can work on Windows 7 # Go 1.22 with special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.22/ # https://github.com/MetaCubeX/go/commits/release-branch.go1.22/
- { goos: windows, goarch: '386', output: '386-go122', goversion: '1.22' } - { goos: windows, goarch: '386', output: '386-go122', goversion: '1.22' }
@ -100,11 +95,6 @@ jobs:
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20' } - { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20' }
- { goos: darwin, goarch: amd64, goamd64: v3, output: amd64-go120, goversion: '1.20' } - { goos: darwin, goarch: amd64, goamd64: v3, output: amd64-go120, goversion: '1.20' }
# Go 1.23 is the last release that requires Linux kernel version 2.6.32 or later. Go 1.24 will require Linux kernel version 3.2 or later.
- { goos: linux, goarch: '386', output: '386-go123', goversion: '1.23' }
- { goos: linux, goarch: amd64, goamd64: v1, output: amd64-compatible-go123, goversion: '1.23', test: test }
- { goos: linux, goarch: amd64, goamd64: v3, output: amd64-go123, goversion: '1.23' }
# only for test # only for test
- { goos: linux, goarch: '386', output: '386-go120', goversion: '1.20' } - { goos: linux, goarch: '386', output: '386-go120', goversion: '1.20' }
- { goos: linux, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20', test: test } - { goos: linux, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20', test: test }
@ -114,41 +104,30 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Go - name: Set up Go
if: ${{ matrix.jobs.goversion == '' && matrix.jobs.abi != '1' }} if: ${{ matrix.jobs.goversion == '' && matrix.jobs.goarch != 'loong64' }}
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: '1.24' go-version: '1.23'
- name: Set up Go - name: Set up Go
if: ${{ matrix.jobs.goversion != '' && matrix.jobs.abi != '1' }} if: ${{ matrix.jobs.goversion != '' && matrix.jobs.goarch != 'loong64' }}
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: ${{ matrix.jobs.goversion }} go-version: ${{ matrix.jobs.goversion }}
- name: Set up Go1.23 loongarch abi1 - name: Set up Go1.22 loongarch abi1
if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '1' }} if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '1' }}
run: | run: |
wget -q https://github.com/MetaCubeX/loongarch64-golang/releases/download/1.23.0/go1.23.0.linux-amd64-abi1.tar.gz wget -q https://github.com/MetaCubeX/loongarch64-golang/releases/download/1.22.4/go1.22.4.linux-amd64-abi1.tar.gz
sudo tar zxf go1.23.0.linux-amd64-abi1.tar.gz -C /usr/local sudo tar zxf go1.22.4.linux-amd64-abi1.tar.gz -C /usr/local
echo "/usr/local/go/bin" >> $GITHUB_PATH echo "/usr/local/go/bin" >> $GITHUB_PATH
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557 - name: Set up Go1.22 loongarch abi2
# this patch file only works on golang1.24.x if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '2' }}
# that means after golang1.25 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.24/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.24 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '' }}
run: | run: |
cd $(go env GOROOT) wget -q https://github.com/MetaCubeX/loongarch64-golang/releases/download/1.22.4/go1.22.4.linux-amd64-abi2.tar.gz
curl https://github.com/MetaCubeX/go/commit/2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8.diff | patch --verbose -p 1 sudo tar zxf go1.22.4.linux-amd64-abi2.tar.gz -C /usr/local
curl https://github.com/MetaCubeX/go/commit/7b1fd7d39c6be0185fbe1d929578ab372ac5c632.diff | patch --verbose -p 1 echo "/usr/local/go/bin" >> $GITHUB_PATH
curl https://github.com/MetaCubeX/go/commit/979d6d8bab3823ff572ace26767fd2ce3cf351ae.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/ac3e93c061779dfefc0dd13a5b6e6f764a25621e.diff | patch --verbose -p 1
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557 # modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.23.x # this patch file only works on golang1.23.x
@ -160,7 +139,7 @@ jobs:
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround" # 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries" # a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.23 commit for Windows7/8 - name: Revert Golang1.23 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '1.23' }} if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '' }}
run: | run: |
cd $(go env GOROOT) cd $(go env GOROOT)
curl https://github.com/MetaCubeX/go/commit/9ac42137ef6730e8b7daca016ece831297a1d75b.diff | patch --verbose -p 1 curl https://github.com/MetaCubeX/go/commit/9ac42137ef6730e8b7daca016ece831297a1d75b.diff | patch --verbose -p 1
@ -215,7 +194,7 @@ jobs:
uses: nttld/setup-ndk@v1 uses: nttld/setup-ndk@v1
id: setup-ndk id: setup-ndk
with: with:
ndk-version: r28-beta1 ndk-version: r27
- name: Set NDK path - name: Set NDK path
if: ${{ matrix.jobs.goos == 'android' }} if: ${{ matrix.jobs.goos == 'android' }}
@ -276,20 +255,17 @@ jobs:
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/DEBIAN mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/DEBIAN
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/bin mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/bin
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/share/licenses/mihomo
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/mihomo mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/mihomo
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/lib/systemd/system mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/systemd/system/
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/share/licenses/mihomo
cp mihomo mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/bin/ cp mihomo mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/bin/mihomo
cp LICENSE mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/share/licenses/mihomo/ cp LICENSE mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/share/licenses/mihomo/
cp .github/{mihomo.service,mihomo@.service} mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/lib/systemd/system/ cp .github/mihomo.service mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/systemd/system/
cat > mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/mihomo/config.yaml <<EOF cat > mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/mihomo/config.yaml <<EOF
mixed-port: 7890 mixed-port: 7890
EOF external-controller: 127.0.0.1:9090
cat > mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/DEBIAN/conffiles <<EOF
/etc/mihomo/config.yaml
EOF EOF
cat > mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/DEBIAN/control <<EOF cat > mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/DEBIAN/control <<EOF
@ -329,6 +305,7 @@ jobs:
run: | run: |
echo ${VERSION} > version.txt echo ${VERSION} > version.txt
shell: bash shell: bash
- name: Archive production artifacts - name: Archive production artifacts
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
@ -339,7 +316,6 @@ jobs:
mihomo*.rpm mihomo*.rpm
mihomo*.zip mihomo*.zip
version.txt version.txt
checksums.txt
Upload-Prerelease: Upload-Prerelease:
permissions: write-all permissions: write-all
@ -353,13 +329,6 @@ jobs:
path: bin/ path: bin/
merge-multiple: true merge-multiple: true
- name: Calculate checksums
run: |
cd bin/
find . -type f -not -name "checksums.*" -not -name "version.txt" | sort | xargs sha256sum > checksums.txt
cat checksums.txt
shell: bash
- name: Delete current release assets - name: Delete current release assets
uses: 8Mi-Tech/delete-release-assets-action@main uses: 8Mi-Tech/delete-release-assets-action@main
with: with:

View file

@ -1,115 +0,0 @@
name: Test
on:
push:
paths-ignore:
- "docs/**"
- "README.md"
- ".github/ISSUE_TEMPLATE/**"
branches:
- Alpha
tags:
- "v*"
pull_request:
branches:
- Alpha
jobs:
test:
strategy:
matrix:
os:
- 'ubuntu-latest' # amd64 linux
- 'windows-latest' # amd64 windows
- 'macos-latest' # arm64 macos
- 'ubuntu-24.04-arm' # arm64 linux
- 'macos-13' # amd64 macos
go-version:
- '1.24'
- '1.23'
- '1.22'
- '1.21'
- '1.20'
fail-fast: false
runs-on: ${{ matrix.os }}
defaults:
run:
shell: bash
env:
CGO_ENABLED: 0
GOTOOLCHAIN: local
# Fix mingw trying to be smart and converting paths https://github.com/moby/moby/issues/24029#issuecomment-250412919
MSYS_NO_PATHCONV: true
steps:
- uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.24.x
# that means after golang1.25 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.24/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.24 commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version == '1.24' }}
run: |
cd $(go env GOROOT)
curl https://github.com/MetaCubeX/go/commit/2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/7b1fd7d39c6be0185fbe1d929578ab372ac5c632.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/979d6d8bab3823ff572ace26767fd2ce3cf351ae.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/ac3e93c061779dfefc0dd13a5b6e6f764a25621e.diff | patch --verbose -p 1
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.23.x
# that means after golang1.24 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.23/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.23 commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version == '1.23' }}
run: |
cd $(go env GOROOT)
curl https://github.com/MetaCubeX/go/commit/9ac42137ef6730e8b7daca016ece831297a1d75b.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/21290de8a4c91408de7c2b5b68757b1e90af49dd.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/6a31d3fa8e47ddabc10bd97bff10d9a85f4cfb76.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/69e2eed6dd0f6d815ebf15797761c13f31213dd6.diff | patch --verbose -p 1
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.22.x
# that means after golang1.23 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.22/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.22 commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version == '1.22' }}
run: |
cd $(go env GOROOT)
curl https://github.com/MetaCubeX/go/commit/9779155f18b6556a034f7bb79fb7fb2aad1e26a9.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/ef0606261340e608017860b423ffae5c1ce78239.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/7f83badcb925a7e743188041cb6e561fc9b5b642.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/83ff9782e024cb328b690cbf0da4e7848a327f4f.diff | patch --verbose -p 1
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
- name: Revert Golang1.21 commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version == '1.21' }}
run: |
cd $(go env GOROOT)
curl https://github.com/golang/go/commit/9e43850a3298a9b8b1162ba0033d4c53f8637571.diff | patch --verbose -R -p 1
- name: Test
run: go test ./... -v -count=1
- name: Test with tag with_gvisor
run: go test ./... -v -count=1 -tags "with_gvisor"

View file

@ -98,4 +98,3 @@ API.
This software is released under the GPL-3.0 license. This software is released under the GPL-3.0 license.
**In addition, any downstream projects not affiliated with `MetaCubeX` shall not contain the word `mihomo` in their names.**

View file

@ -10,7 +10,6 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/atomic"
@ -19,7 +18,6 @@ import (
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
) )
@ -41,11 +39,6 @@ type Proxy struct {
extra *xsync.MapOf[string, *internalProxyState] extra *xsync.MapOf[string, *internalProxyState]
} }
// Adapter implements C.Proxy
func (p *Proxy) Adapter() C.ProxyAdapter {
return p.ProxyAdapter
}
// AliveForTestUrl implements C.Proxy // AliveForTestUrl implements C.Proxy
func (p *Proxy) AliveForTestUrl(url string) bool { func (p *Proxy) AliveForTestUrl(url string) bool {
if state, ok := p.extra.Load(url); ok { if state, ok := p.extra.Load(url); ok {
@ -163,17 +156,8 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
mapping["alive"] = p.alive.Load() mapping["alive"] = p.alive.Load()
mapping["name"] = p.Name() mapping["name"] = p.Name()
mapping["udp"] = p.SupportUDP() mapping["udp"] = p.SupportUDP()
mapping["uot"] = p.SupportUOT() mapping["xudp"] = p.SupportXUDP()
mapping["tfo"] = p.SupportTFO()
proxyInfo := p.ProxyInfo()
mapping["xudp"] = proxyInfo.XUDP
mapping["tfo"] = proxyInfo.TFO
mapping["mptcp"] = proxyInfo.MPTCP
mapping["smux"] = proxyInfo.SMUX
mapping["interface"] = proxyInfo.Interface
mapping["dialer-proxy"] = proxyInfo.DialerProxy
mapping["routing-mark"] = proxyInfo.RoutingMark
return json.Marshal(mapping) return json.Marshal(mapping)
} }
@ -271,18 +255,10 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
if unifiedDelay { if unifiedDelay {
second := time.Now() second := time.Now()
var ignoredErr error resp, err = client.Do(req)
var secondResp *http.Response if err == nil {
secondResp, ignoredErr = client.Do(req)
if ignoredErr == nil {
resp = secondResp
_ = resp.Body.Close() _ = resp.Body.Close()
start = second start = second
} else {
if strings.HasPrefix(url, "http://") {
log.Errorln("%s failed to get the second response from %s: %v", p.Name(), url, ignoredErr)
log.Warnln("It is recommended to use HTTPS for provider.health-check.url and group.url to ensure better reliability. Due to some proxy providers hijacking test addresses and not being compatible with repeated HEAD requests, using HTTP may result in failed tests.")
}
} }
} }
@ -290,7 +266,6 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
t = uint16(time.Since(start) / time.Millisecond) t = uint16(time.Since(start) / time.Millisecond)
return return
} }
func NewProxy(adapter C.ProxyAdapter) *Proxy { func NewProxy(adapter C.ProxyAdapter) *Proxy {
return &Proxy{ return &Proxy{
ProxyAdapter: adapter, ProxyAdapter: adapter,

View file

@ -34,5 +34,12 @@ func SkipAuthRemoteAddress(addr string) bool {
} }
func skipAuth(addr netip.Addr) bool { func skipAuth(addr netip.Addr) bool {
return prefixesContains(skipAuthPrefixes, addr) if addr.IsValid() {
for _, prefix := range skipAuthPrefixes {
if prefix.Contains(addr.Unmap()) {
return true
}
}
}
return false
} }

View file

@ -11,8 +11,6 @@ import (
func NewHTTPS(request *http.Request, conn net.Conn, additions ...Addition) (net.Conn, *C.Metadata) { func NewHTTPS(request *http.Request, conn net.Conn, additions ...Addition) (net.Conn, *C.Metadata) {
metadata := parseHTTPAddr(request) metadata := parseHTTPAddr(request)
metadata.Type = C.HTTPS metadata.Type = C.HTTPS
metadata.RawSrcAddr = conn.RemoteAddr()
metadata.RawDstAddr = conn.LocalAddr()
ApplyAdditions(metadata, WithSrcAddr(conn.RemoteAddr()), WithInAddr(conn.LocalAddr())) ApplyAdditions(metadata, WithSrcAddr(conn.RemoteAddr()), WithInAddr(conn.LocalAddr()))
ApplyAdditions(metadata, additions...) ApplyAdditions(metadata, additions...)
return conn, metadata return conn, metadata

View file

@ -31,17 +31,27 @@ func IsRemoteAddrDisAllowed(addr net.Addr) bool {
if err := m.SetRemoteAddr(addr); err != nil { if err := m.SetRemoteAddr(addr); err != nil {
return false return false
} }
ipAddr := m.AddrPort().Addr() return isAllowed(m.AddrPort().Addr().Unmap()) && !isDisAllowed(m.AddrPort().Addr().Unmap())
if ipAddr.IsValid() { }
return isAllowed(ipAddr) && !isDisAllowed(ipAddr)
func isAllowed(addr netip.Addr) bool {
if addr.IsValid() {
for _, prefix := range lanAllowedIPs {
if prefix.Contains(addr) {
return true
}
}
} }
return false return false
} }
func isAllowed(addr netip.Addr) bool {
return prefixesContains(lanAllowedIPs, addr)
}
func isDisAllowed(addr netip.Addr) bool { func isDisAllowed(addr netip.Addr) bool {
return prefixesContains(lanDisAllowedIPs, addr) if addr.IsValid() {
for _, prefix := range lanDisAllowedIPs {
if prefix.Contains(addr) {
return true
}
}
}
return false
} }

View file

@ -2,12 +2,7 @@ package inbound
import ( import (
"context" "context"
"fmt"
"net" "net"
"net/netip"
"sync"
"github.com/metacubex/mihomo/component/keepalive"
"github.com/metacubex/tfo-go" "github.com/metacubex/tfo-go"
) )
@ -16,91 +11,28 @@ var (
lc = tfo.ListenConfig{ lc = tfo.ListenConfig{
DisableTFO: true, DisableTFO: true,
} }
mutex sync.RWMutex
) )
func SetTfo(open bool) { func SetTfo(open bool) {
mutex.Lock()
defer mutex.Unlock()
lc.DisableTFO = !open lc.DisableTFO = !open
} }
func Tfo() bool { func Tfo() bool {
mutex.RLock()
defer mutex.RUnlock()
return !lc.DisableTFO return !lc.DisableTFO
} }
func SetMPTCP(open bool) { func SetMPTCP(open bool) {
mutex.Lock()
defer mutex.Unlock()
setMultiPathTCP(&lc.ListenConfig, open) setMultiPathTCP(&lc.ListenConfig, open)
} }
func MPTCP() bool { func MPTCP() bool {
mutex.RLock()
defer mutex.RUnlock()
return getMultiPathTCP(&lc.ListenConfig) return getMultiPathTCP(&lc.ListenConfig)
} }
func preResolve(network, address string) (string, error) {
switch network { // like net.Resolver.internetAddrList but filter domain to avoid call net.Resolver.lookupIPAddr
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "ip", "ip4", "ip6":
if host, port, err := net.SplitHostPort(address); err == nil {
switch host {
case "localhost":
switch network {
case "tcp6", "udp6", "ip6":
address = net.JoinHostPort("::1", port)
default:
address = net.JoinHostPort("127.0.0.1", port)
}
case "": // internetAddrList can handle this special case
break
default:
if _, err := netip.ParseAddr(host); err != nil { // not ip
return "", fmt.Errorf("invalid network address: %s", address)
}
}
}
}
return address, nil
}
func ListenContext(ctx context.Context, network, address string) (net.Listener, error) { func ListenContext(ctx context.Context, network, address string) (net.Listener, error) {
address, err := preResolve(network, address)
if err != nil {
return nil, err
}
mutex.RLock()
defer mutex.RUnlock()
return lc.Listen(ctx, network, address) return lc.Listen(ctx, network, address)
} }
func Listen(network, address string) (net.Listener, error) { func Listen(network, address string) (net.Listener, error) {
return ListenContext(context.Background(), network, address) return ListenContext(context.Background(), network, address)
} }
func ListenPacketContext(ctx context.Context, network, address string) (net.PacketConn, error) {
address, err := preResolve(network, address)
if err != nil {
return nil, err
}
mutex.RLock()
defer mutex.RUnlock()
return lc.ListenPacket(ctx, network, address)
}
func ListenPacket(network, address string) (net.PacketConn, error) {
return ListenPacketContext(context.Background(), network, address)
}
func init() {
keepalive.SetDisableKeepAliveCallback.Register(func(b bool) {
mutex.Lock()
defer mutex.Unlock()
keepalive.SetNetListenConfig(&lc.ListenConfig)
})
}

View file

@ -1,14 +0,0 @@
//go:build !windows
package inbound
import (
"net"
"os"
)
const SupportNamedPipe = false
func ListenNamedPipe(path string) (net.Listener, error) {
return nil, os.ErrInvalid
}

View file

@ -1,32 +0,0 @@
package inbound
import (
"net"
"os"
"github.com/metacubex/wireguard-go/ipc/namedpipe"
"golang.org/x/sys/windows"
)
const SupportNamedPipe = true
// windowsSDDL is the Security Descriptor set on the namedpipe.
// It provides read/write access to all users and the local system.
const windowsSDDL = "D:PAI(A;OICI;GWGR;;;BU)(A;OICI;GWGR;;;SY)"
func ListenNamedPipe(path string) (net.Listener, error) {
sddl := os.Getenv("LISTEN_NAMEDPIPE_SDDL")
if sddl == "" {
sddl = windowsSDDL
}
securityDescriptor, err := windows.SecurityDescriptorFromString(sddl)
if err != nil {
return nil, err
}
namedpipeLC := namedpipe.ListenConfig{
SecurityDescriptor: securityDescriptor,
InputBufferSize: 256 * 1024,
OutputBufferSize: 256 * 1024,
}
return namedpipeLC.Listen(path)
}

View file

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/metacubex/mihomo/common/nnip"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
) )
@ -20,13 +21,13 @@ func parseSocksAddr(target socks5.Addr) *C.Metadata {
metadata.Host = strings.TrimRight(string(target[2:2+target[1]]), ".") metadata.Host = strings.TrimRight(string(target[2:2+target[1]]), ".")
metadata.DstPort = uint16((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1])) metadata.DstPort = uint16((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1]))
case socks5.AtypIPv4: case socks5.AtypIPv4:
metadata.DstIP, _ = netip.AddrFromSlice(target[1 : 1+net.IPv4len]) metadata.DstIP = nnip.IpToAddr(net.IP(target[1 : 1+net.IPv4len]))
metadata.DstPort = uint16((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1])) metadata.DstPort = uint16((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1]))
case socks5.AtypIPv6: case socks5.AtypIPv6:
metadata.DstIP, _ = netip.AddrFromSlice(target[1 : 1+net.IPv6len]) ip6, _ := netip.AddrFromSlice(target[1 : 1+net.IPv6len])
metadata.DstIP = ip6.Unmap()
metadata.DstPort = uint16((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1])) metadata.DstPort = uint16((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1]))
} }
metadata.DstIP = metadata.DstIP.Unmap()
return metadata return metadata
} }
@ -60,19 +61,3 @@ func parseHTTPAddr(request *http.Request) *C.Metadata {
return metadata return metadata
} }
func prefixesContains(prefixes []netip.Prefix, addr netip.Addr) bool {
if len(prefixes) == 0 {
return false
}
if !addr.IsValid() {
return false
}
addr = addr.Unmap().WithZone("") // netip.Prefix.Contains returns false if ip has an IPv6 zone
for _, prefix := range prefixes {
if prefix.Contains(addr) {
return true
}
}
return false
}

View file

@ -1,137 +0,0 @@
package outbound
import (
"context"
"errors"
"net"
"strconv"
"time"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/anytls"
"github.com/metacubex/mihomo/transport/vmess"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/uot"
)
type AnyTLS struct {
*Base
client *anytls.Client
dialer proxydialer.SingDialer
option *AnyTLSOption
}
type AnyTLSOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Password string `proxy:"password"`
ALPN []string `proxy:"alpn,omitempty"`
SNI string `proxy:"sni,omitempty"`
ClientFingerprint string `proxy:"client-fingerprint,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
UDP bool `proxy:"udp,omitempty"`
IdleSessionCheckInterval int `proxy:"idle-session-check-interval,omitempty"`
IdleSessionTimeout int `proxy:"idle-session-timeout,omitempty"`
MinIdleSession int `proxy:"min-idle-session,omitempty"`
}
func (t *AnyTLS) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
options := t.Base.DialOptions(opts...)
t.dialer.SetDialer(dialer.NewDialer(options...))
c, err := t.client.CreateProxy(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
if err != nil {
return nil, err
}
return NewConn(c, t), nil
}
func (t *AnyTLS) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
// create tcp
options := t.Base.DialOptions(opts...)
t.dialer.SetDialer(dialer.NewDialer(options...))
c, err := t.client.CreateProxy(ctx, uot.RequestDestination(2))
if err != nil {
return nil, err
}
// create uot on tcp
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
}
destination := M.SocksaddrFromNet(metadata.UDPAddr())
return newPacketConn(CN.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), nil
}
// SupportUOT implements C.ProxyAdapter
func (t *AnyTLS) SupportUOT() bool {
return true
}
// ProxyInfo implements C.ProxyAdapter
func (t *AnyTLS) ProxyInfo() C.ProxyInfo {
info := t.Base.ProxyInfo()
info.DialerProxy = t.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (t *AnyTLS) Close() error {
return t.client.Close()
}
func NewAnyTLS(option AnyTLSOption) (*AnyTLS, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer())
tOption := anytls.ClientConfig{
Password: option.Password,
Server: M.ParseSocksaddrHostPort(option.Server, uint16(option.Port)),
Dialer: singDialer,
IdleSessionCheckInterval: time.Duration(option.IdleSessionCheckInterval) * time.Second,
IdleSessionTimeout: time.Duration(option.IdleSessionTimeout) * time.Second,
MinIdleSession: option.MinIdleSession,
}
tlsConfig := &vmess.TLSConfig{
Host: option.SNI,
SkipCertVerify: option.SkipCertVerify,
NextProtos: option.ALPN,
FingerPrint: option.Fingerprint,
ClientFingerprint: option.ClientFingerprint,
}
if tlsConfig.Host == "" {
tlsConfig.Host = option.Server
}
tOption.TLSConfig = tlsConfig
outbound := &AnyTLS{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.AnyTLS,
udp: option.UDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
},
client: anytls.NewClient(context.TODO(), tOption),
option: &option,
dialer: singDialer,
}
return outbound, nil
}

View file

@ -4,23 +4,15 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"net" "net"
"runtime"
"strings" "strings"
"sync"
"syscall" "syscall"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
) )
type ProxyAdapter interface {
C.ProxyAdapter
DialOptions(opts ...dialer.Option) []dialer.Option
}
type Base struct { type Base struct {
name string name string
addr string addr string
@ -93,15 +85,14 @@ func (b *Base) SupportUDP() bool {
return b.udp return b.udp
} }
// ProxyInfo implements C.ProxyAdapter // SupportXUDP implements C.ProxyAdapter
func (b *Base) ProxyInfo() (info C.ProxyInfo) { func (b *Base) SupportXUDP() bool {
info.XUDP = b.xudp return b.xudp
info.TFO = b.tfo }
info.MPTCP = b.mpTcp
info.SMUX = false // SupportTFO implements C.ProxyAdapter
info.Interface = b.iface func (b *Base) SupportTFO() bool {
info.RoutingMark = b.rmark return b.tfo
return
} }
// IsL3Protocol implements C.ProxyAdapter // IsL3Protocol implements C.ProxyAdapter
@ -160,16 +151,12 @@ func (b *Base) DialOptions(opts ...dialer.Option) []dialer.Option {
return opts return opts
} }
func (b *Base) Close() error {
return nil
}
type BasicOption struct { type BasicOption struct {
TFO bool `proxy:"tfo,omitempty"` TFO bool `proxy:"tfo,omitempty" group:"tfo,omitempty"`
MPTCP bool `proxy:"mptcp,omitempty"` MPTCP bool `proxy:"mptcp,omitempty" group:"mptcp,omitempty"`
Interface string `proxy:"interface-name,omitempty" group:"interface-name,omitempty"` Interface string `proxy:"interface-name,omitempty" group:"interface-name,omitempty"`
RoutingMark int `proxy:"routing-mark,omitempty" group:"routing-mark,omitempty"` RoutingMark int `proxy:"routing-mark,omitempty" group:"routing-mark,omitempty"`
IPVersion string `proxy:"ip-version,omitempty"` IPVersion string `proxy:"ip-version,omitempty" group:"ip-version,omitempty"`
DialerProxy string `proxy:"dialer-proxy,omitempty"` // don't apply this option into groups, but can set a group name in a proxy DialerProxy string `proxy:"dialer-proxy,omitempty"` // don't apply this option into groups, but can set a group name in a proxy
} }
@ -233,10 +220,6 @@ func (c *conn) ReaderReplaceable() bool {
return true return true
} }
func (c *conn) AddRef(ref any) {
c.ExtendedConn = N.NewRefConn(c.ExtendedConn, ref) // add ref for autoCloseProxyAdapter
}
func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn { func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
if _, ok := c.(syscall.Conn); !ok { // exclusion system conn like *net.TCPConn if _, ok := c.(syscall.Conn); !ok { // exclusion system conn like *net.TCPConn
c = N.NewDeadlineConn(c) // most conn from outbound can't handle readDeadline correctly c = N.NewDeadlineConn(c) // most conn from outbound can't handle readDeadline correctly
@ -283,10 +266,6 @@ func (c *packetConn) ReaderReplaceable() bool {
return true return true
} }
func (c *packetConn) AddRef(ref any) {
c.EnhancePacketConn = N.NewRefPacketConn(c.EnhancePacketConn, ref) // add ref for autoCloseProxyAdapter
}
func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn { func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn {
epc := N.NewEnhancePacketConn(pc) epc := N.NewEnhancePacketConn(pc)
if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn
@ -306,75 +285,3 @@ func parseRemoteDestination(addr string) string {
} }
} }
} }
type AddRef interface {
AddRef(ref any)
}
type autoCloseProxyAdapter struct {
ProxyAdapter
closeOnce sync.Once
closeErr error
}
func (p *autoCloseProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
c, err := p.ProxyAdapter.DialContext(ctx, metadata, opts...)
if err != nil {
return nil, err
}
if c, ok := c.(AddRef); ok {
c.AddRef(p)
}
return c, nil
}
func (p *autoCloseProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
c, err := p.ProxyAdapter.DialContextWithDialer(ctx, dialer, metadata)
if err != nil {
return nil, err
}
if c, ok := c.(AddRef); ok {
c.AddRef(p)
}
return c, nil
}
func (p *autoCloseProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata, opts...)
if err != nil {
return nil, err
}
if pc, ok := pc.(AddRef); ok {
pc.AddRef(p)
}
return pc, nil
}
func (p *autoCloseProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
pc, err := p.ProxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata)
if err != nil {
return nil, err
}
if pc, ok := pc.(AddRef); ok {
pc.AddRef(p)
}
return pc, nil
}
func (p *autoCloseProxyAdapter) Close() error {
p.closeOnce.Do(func() {
log.Debugln("Closing outdated proxy [%s]", p.Name())
runtime.SetFinalizer(p, nil)
p.closeErr = p.ProxyAdapter.Close()
})
return p.closeErr
}
func NewAutoCloseProxyAdapter(adapter ProxyAdapter) ProxyAdapter {
proxy := &autoCloseProxyAdapter{
ProxyAdapter: adapter,
}
// auto close ProxyAdapter
runtime.SetFinalizer(proxy, (*autoCloseProxyAdapter).Close)
return proxy
}

View file

@ -3,12 +3,19 @@ package outbound
import ( import (
"context" "context"
"errors" "errors"
"os"
"strconv"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/loopback" "github.com/metacubex/mihomo/component/loopback"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/features"
) )
var DisableLoopBackDetector, _ = strconv.ParseBool(os.Getenv("DISABLE_LOOPBACK_DETECTOR"))
type Direct struct { type Direct struct {
*Base *Base
loopBack *loopback.Detector loopBack *loopback.Detector
@ -21,25 +28,30 @@ type DirectOption struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
if !features.CMFA && !DisableLoopBackDetector {
if err := d.loopBack.CheckConn(metadata); err != nil { if err := d.loopBack.CheckConn(metadata); err != nil {
return nil, err return nil, err
} }
opts = append(opts, dialer.WithResolver(resolver.DirectHostResolver)) }
opts = append(opts, dialer.WithResolver(resolver.DefaultResolver))
c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
N.TCPKeepAlive(c)
return d.loopBack.NewConn(NewConn(c, d)), nil return d.loopBack.NewConn(NewConn(c, d)), nil
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if !features.CMFA && !DisableLoopBackDetector {
if err := d.loopBack.CheckPacketConn(metadata); err != nil { if err := d.loopBack.CheckPacketConn(metadata); err != nil {
return nil, err return nil, err
} }
}
// net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr // net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, resolver.DirectHostResolver) ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, resolver.DefaultResolver)
if err != nil { if err != nil {
return nil, errors.New("can't resolve ip") return nil, errors.New("can't resolve ip")
} }

View file

@ -7,6 +7,8 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
@ -51,7 +53,7 @@ func (h *Http) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Me
} }
} }
if err := h.shakeHandContext(ctx, c, metadata); err != nil { if err := h.shakeHand(metadata, c); err != nil {
return nil, err return nil, err
} }
return c, nil return c, nil
@ -74,6 +76,7 @@ func (h *Http) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", h.addr, err) return nil, fmt.Errorf("%s connect error: %w", h.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@ -92,19 +95,7 @@ func (h *Http) SupportWithDialer() C.NetWork {
return C.TCP return C.TCP
} }
// ProxyInfo implements C.ProxyAdapter func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error {
func (h *Http) ProxyInfo() C.ProxyInfo {
info := h.Base.ProxyInfo()
info.DialerProxy = h.option.DialerProxy
return info
}
func (h *Http) shakeHandContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
addr := metadata.RemoteAddress() addr := metadata.RemoteAddress()
HeaderString := "CONNECT " + addr + " HTTP/1.1\r\n" HeaderString := "CONNECT " + addr + " HTTP/1.1\r\n"
tempHeaders := map[string]string{ tempHeaders := map[string]string{
@ -128,13 +119,13 @@ func (h *Http) shakeHandContext(ctx context.Context, c net.Conn, metadata *C.Met
HeaderString += "\r\n" HeaderString += "\r\n"
_, err = c.Write([]byte(HeaderString)) _, err := rw.Write([]byte(HeaderString))
if err != nil { if err != nil {
return err return err
} }
resp, err := http.ReadResponse(bufio.NewReader(c), nil) resp, err := http.ReadResponse(bufio.NewReader(rw), nil)
if err != nil { if err != nil {
return err return err

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime"
"strconv" "strconv"
"time" "time"
@ -14,6 +15,7 @@ import (
"github.com/metacubex/quic-go/congestion" "github.com/metacubex/quic-go/congestion"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
@ -43,6 +45,8 @@ type Hysteria struct {
option *HysteriaOption option *HysteriaOption
client *core.Client client *core.Client
closeCh chan struct{} // for test
} }
func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
@ -51,7 +55,7 @@ func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts .
return nil, err return nil, err
} }
return NewConn(tcpConn, h), nil return NewConn(CN.NewRefConn(tcpConn, h), h), nil
} }
func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
@ -59,13 +63,13 @@ func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newPacketConn(&hyPacketConn{udpConn}, h), nil return newPacketConn(CN.NewRefPacketConn(&hyPacketConn{udpConn}, h), h), nil
} }
func (h *Hysteria) genHdc(ctx context.Context, opts ...dialer.Option) utils.PacketDialer { func (h *Hysteria) genHdc(ctx context.Context, opts ...dialer.Option) utils.PacketDialer {
return &hyDialerWithContext{ return &hyDialerWithContext{
ctx: context.Background(), ctx: context.Background(),
hyDialer: func(network string, rAddr net.Addr) (net.PacketConn, error) { hyDialer: func(network string) (net.PacketConn, error) {
var err error var err error
var cDialer C.Dialer = dialer.NewDialer(h.Base.DialOptions(opts...)...) var cDialer C.Dialer = dialer.NewDialer(h.Base.DialOptions(opts...)...)
if len(h.option.DialerProxy) > 0 { if len(h.option.DialerProxy) > 0 {
@ -74,22 +78,15 @@ func (h *Hysteria) genHdc(ctx context.Context, opts ...dialer.Option) utils.Pack
return nil, err return nil, err
} }
} }
rAddrPort, _ := netip.ParseAddrPort(rAddr.String()) rAddrPort, _ := netip.ParseAddrPort(h.Addr())
return cDialer.ListenPacket(ctx, network, "", rAddrPort) return cDialer.ListenPacket(ctx, network, "", rAddrPort)
}, },
remoteAddr: func(addr string) (net.Addr, error) { remoteAddr: func(addr string) (net.Addr, error) {
return resolveUDPAddr(ctx, "udp", addr, h.prefer) return resolveUDPAddrWithPrefer(ctx, "udp", addr, h.prefer)
}, },
} }
} }
// ProxyInfo implements C.ProxyAdapter
func (h *Hysteria) ProxyInfo() C.ProxyInfo {
info := h.Base.ProxyInfo()
info.DialerProxy = h.option.DialerProxy
return info
}
type HysteriaOption struct { type HysteriaOption struct {
BasicOption BasicOption
Name string `proxy:"name"` Name string `proxy:"name"`
@ -134,7 +131,11 @@ func (c *HysteriaOption) Speed() (uint64, uint64, error) {
} }
func NewHysteria(option HysteriaOption) (*Hysteria, error) { func NewHysteria(option HysteriaOption) (*Hysteria, error) {
clientTransport := &transport.ClientTransport{} clientTransport := &transport.ClientTransport{
Dialer: &net.Dialer{
Timeout: 8 * time.Second,
},
}
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
ports := option.Ports ports := option.Ports
@ -235,16 +236,18 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
option: &option, option: &option,
client: client, client: client,
} }
runtime.SetFinalizer(outbound, closeHysteria)
return outbound, nil return outbound, nil
} }
// Close implements C.ProxyAdapter func closeHysteria(h *Hysteria) {
func (h *Hysteria) Close() error {
if h.client != nil { if h.client != nil {
return h.client.Close() _ = h.client.Close()
}
if h.closeCh != nil {
close(h.closeCh)
} }
return nil
} }
type hyPacketConn struct { type hyPacketConn struct {
@ -281,7 +284,7 @@ func (c *hyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
} }
type hyDialerWithContext struct { type hyDialerWithContext struct {
hyDialer func(network string, rAddr net.Addr) (net.PacketConn, error) hyDialer func(network string) (net.PacketConn, error)
ctx context.Context ctx context.Context
remoteAddr func(host string) (net.Addr, error) remoteAddr func(host string) (net.Addr, error)
} }
@ -291,7 +294,7 @@ func (h *hyDialerWithContext) ListenPacket(rAddr net.Addr) (net.PacketConn, erro
if addrPort, err := netip.ParseAddrPort(rAddr.String()); err == nil { if addrPort, err := netip.ParseAddrPort(rAddr.String()); err == nil {
network = dialer.ParseNetwork(network, addrPort.Addr()) network = dialer.ParseNetwork(network, addrPort.Addr())
} }
return h.hyDialer(network, rAddr) return h.hyDialer(network)
} }
func (h *hyDialerWithContext) Context() context.Context { func (h *hyDialerWithContext) Context() context.Context {

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"runtime"
"strconv" "strconv"
"time" "time"
@ -20,7 +21,6 @@ import (
"github.com/metacubex/sing-quic/hysteria2" "github.com/metacubex/sing-quic/hysteria2"
"github.com/metacubex/quic-go"
"github.com/metacubex/randv2" "github.com/metacubex/randv2"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
) )
@ -38,6 +38,8 @@ type Hysteria2 struct {
option *Hysteria2Option option *Hysteria2Option
client *hysteria2.Client client *hysteria2.Client
dialer proxydialer.SingDialer dialer proxydialer.SingDialer
closeCh chan struct{} // for test
} }
type Hysteria2Option struct { type Hysteria2Option struct {
@ -60,12 +62,6 @@ type Hysteria2Option struct {
CustomCAString string `proxy:"ca-str,omitempty"` CustomCAString string `proxy:"ca-str,omitempty"`
CWND int `proxy:"cwnd,omitempty"` CWND int `proxy:"cwnd,omitempty"`
UdpMTU int `proxy:"udp-mtu,omitempty"` UdpMTU int `proxy:"udp-mtu,omitempty"`
// quic-go special config
InitialStreamReceiveWindow uint64 `proxy:"initial-stream-receive-window,omitempty"`
MaxStreamReceiveWindow uint64 `proxy:"max-stream-receive-window,omitempty"`
InitialConnectionReceiveWindow uint64 `proxy:"initial-connection-receive-window,omitempty"`
MaxConnectionReceiveWindow uint64 `proxy:"max-connection-receive-window,omitempty"`
} }
func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
@ -75,7 +71,7 @@ func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata, opts
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(c, h), nil return NewConn(CN.NewRefConn(c, h), h), nil
} }
func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
@ -88,22 +84,16 @@ func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if pc == nil { if pc == nil {
return nil, errors.New("packetConn is nil") return nil, errors.New("packetConn is nil")
} }
return newPacketConn(CN.NewThreadSafePacketConn(pc), h), nil return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(pc), h), h), nil
} }
// Close implements C.ProxyAdapter func closeHysteria2(h *Hysteria2) {
func (h *Hysteria2) Close() error {
if h.client != nil { if h.client != nil {
return h.client.CloseWithError(errors.New("proxy removed")) _ = h.client.CloseWithError(errors.New("proxy removed"))
} }
return nil if h.closeCh != nil {
close(h.closeCh)
} }
// ProxyInfo implements C.ProxyAdapter
func (h *Hysteria2) ProxyInfo() C.ProxyInfo {
info := h.Base.ProxyInfo()
info.DialerProxy = h.option.DialerProxy
return info
} }
func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) { func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
@ -148,13 +138,6 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
option.UdpMTU = 1200 - 3 option.UdpMTU = 1200 - 3
} }
quicConfig := &quic.Config{
InitialStreamReceiveWindow: option.InitialStreamReceiveWindow,
MaxStreamReceiveWindow: option.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: option.InitialConnectionReceiveWindow,
MaxConnectionReceiveWindow: option.MaxConnectionReceiveWindow,
}
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()) singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer())
clientOptions := hysteria2.ClientOptions{ clientOptions := hysteria2.ClientOptions{
@ -166,12 +149,11 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
SalamanderPassword: salamanderPassword, SalamanderPassword: salamanderPassword,
Password: option.Password, Password: option.Password,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
QUICConfig: quicConfig,
UDPDisabled: false, UDPDisabled: false,
CWND: option.CWND, CWND: option.CWND,
UdpMTU: option.UdpMTU, UdpMTU: option.UdpMTU,
ServerAddress: func(ctx context.Context) (*net.UDPAddr, error) { ServerAddress: func(ctx context.Context) (*net.UDPAddr, error) {
return resolveUDPAddr(ctx, "udp", addr, C.NewDNSPrefer(option.IPVersion)) return resolveUDPAddrWithPrefer(ctx, "udp", addr, C.NewDNSPrefer(option.IPVersion))
}, },
} }
@ -188,7 +170,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
}) })
if len(serverAddress) > 0 { if len(serverAddress) > 0 {
clientOptions.ServerAddress = func(ctx context.Context) (*net.UDPAddr, error) { clientOptions.ServerAddress = func(ctx context.Context) (*net.UDPAddr, error) {
return resolveUDPAddr(ctx, "udp", serverAddress[randv2.IntN(len(serverAddress))], C.NewDNSPrefer(option.IPVersion)) return resolveUDPAddrWithPrefer(ctx, "udp", serverAddress[randv2.IntN(len(serverAddress))], C.NewDNSPrefer(option.IPVersion))
} }
if option.HopInterval == 0 { if option.HopInterval == 0 {
@ -222,6 +204,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
client: client, client: client,
dialer: singDialer, dialer: singDialer,
} }
runtime.SetFinalizer(outbound, closeHysteria2)
return outbound, nil return outbound, nil
} }

View file

@ -0,0 +1,38 @@
package outbound
import (
"context"
"runtime"
"testing"
"time"
)
func TestHysteria2GC(t *testing.T) {
option := Hysteria2Option{}
option.Server = "127.0.0.1"
option.Ports = "200,204,401-429,501-503"
option.HopInterval = 30
option.Password = "password"
option.Obfs = "salamander"
option.ObfsPassword = "password"
option.SNI = "example.com"
option.ALPN = []string{"h3"}
hy, err := NewHysteria2(option)
if err != nil {
t.Error(err)
return
}
closeCh := make(chan struct{})
hy.closeCh = closeCh
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
hy = nil
runtime.GC()
select {
case <-closeCh:
return
case <-ctx.Done():
t.Error("timeout not GC")
}
}

View file

@ -0,0 +1,39 @@
package outbound
import (
"context"
"runtime"
"testing"
"time"
)
func TestHysteriaGC(t *testing.T) {
option := HysteriaOption{}
option.Server = "127.0.0.1"
option.Ports = "200,204,401-429,501-503"
option.Protocol = "udp"
option.Up = "1Mbps"
option.Down = "1Mbps"
option.HopInterval = 30
option.Obfs = "salamander"
option.SNI = "example.com"
option.ALPN = []string{"h3"}
hy, err := NewHysteria(option)
if err != nil {
t.Error(err)
return
}
closeCh := make(chan struct{})
hy.closeCh = closeCh
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
hy = nil
runtime.GC()
select {
case <-closeCh:
return
case <-ctx.Done():
t.Error("timeout not GC")
}
}

View file

@ -1,301 +0,0 @@
package outbound
import (
"context"
"fmt"
"net"
"strconv"
"sync"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
mieruclient "github.com/enfein/mieru/v3/apis/client"
mierucommon "github.com/enfein/mieru/v3/apis/common"
mierumodel "github.com/enfein/mieru/v3/apis/model"
mierupb "github.com/enfein/mieru/v3/pkg/appctl/appctlpb"
"google.golang.org/protobuf/proto"
)
type Mieru struct {
*Base
option *MieruOption
client mieruclient.Client
mu sync.Mutex
}
type MieruOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port,omitempty"`
PortRange string `proxy:"port-range,omitempty"`
Transport string `proxy:"transport"`
UDP bool `proxy:"udp,omitempty"`
UserName string `proxy:"username"`
Password string `proxy:"password"`
Multiplexing string `proxy:"multiplexing,omitempty"`
}
// DialContext implements C.ProxyAdapter
func (m *Mieru) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
if err := m.ensureClientIsRunning(opts...); err != nil {
return nil, err
}
addr := metadataToMieruNetAddrSpec(metadata)
c, err := m.client.DialContext(ctx, addr)
if err != nil {
return nil, fmt.Errorf("dial to %s failed: %w", addr, err)
}
return NewConn(c, m), nil
}
// ListenPacketContext implements C.ProxyAdapter
func (m *Mieru) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if err := m.ensureClientIsRunning(opts...); err != nil {
return nil, err
}
c, err := m.client.DialContext(ctx, metadata.UDPAddr())
if err != nil {
return nil, fmt.Errorf("dial to %s failed: %w", metadata.UDPAddr(), err)
}
return newPacketConn(CN.NewThreadSafePacketConn(mierucommon.NewUDPAssociateWrapper(mierucommon.NewPacketOverStreamTunnel(c))), m), nil
}
// SupportUOT implements C.ProxyAdapter
func (m *Mieru) SupportUOT() bool {
return true
}
// ProxyInfo implements C.ProxyAdapter
func (m *Mieru) ProxyInfo() C.ProxyInfo {
info := m.Base.ProxyInfo()
info.DialerProxy = m.option.DialerProxy
return info
}
func (m *Mieru) ensureClientIsRunning(opts ...dialer.Option) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.client.IsRunning() {
return nil
}
// Create a dialer and add it to the client config, before starting the client.
var dialer C.Dialer = dialer.NewDialer(m.Base.DialOptions(opts...)...)
var err error
if len(m.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(m.option.DialerProxy, dialer)
if err != nil {
return err
}
}
config, err := m.client.Load()
if err != nil {
return err
}
config.Dialer = dialer
if err := m.client.Store(config); err != nil {
return err
}
if err := m.client.Start(); err != nil {
return fmt.Errorf("failed to start mieru client: %w", err)
}
return nil
}
func NewMieru(option MieruOption) (*Mieru, error) {
config, err := buildMieruClientConfig(option)
if err != nil {
return nil, fmt.Errorf("failed to build mieru client config: %w", err)
}
c := mieruclient.NewClient()
if err := c.Store(config); err != nil {
return nil, fmt.Errorf("failed to store mieru client config: %w", err)
}
// Client is started lazily on the first use.
var addr string
if option.Port != 0 {
addr = net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
} else {
beginPort, _, _ := beginAndEndPortFromPortRange(option.PortRange)
addr = net.JoinHostPort(option.Server, strconv.Itoa(beginPort))
}
outbound := &Mieru{
Base: &Base{
name: option.Name,
addr: addr,
iface: option.Interface,
tp: C.Mieru,
udp: option.UDP,
xudp: false,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
},
option: &option,
client: c,
}
return outbound, nil
}
// Close implements C.ProxyAdapter
func (m *Mieru) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.client != nil && m.client.IsRunning() {
return m.client.Stop()
}
return nil
}
func metadataToMieruNetAddrSpec(metadata *C.Metadata) mierumodel.NetAddrSpec {
if metadata.Host != "" {
return mierumodel.NetAddrSpec{
AddrSpec: mierumodel.AddrSpec{
FQDN: metadata.Host,
Port: int(metadata.DstPort),
},
Net: "tcp",
}
} else {
return mierumodel.NetAddrSpec{
AddrSpec: mierumodel.AddrSpec{
IP: metadata.DstIP.AsSlice(),
Port: int(metadata.DstPort),
},
Net: "tcp",
}
}
}
func buildMieruClientConfig(option MieruOption) (*mieruclient.ClientConfig, error) {
if err := validateMieruOption(option); err != nil {
return nil, fmt.Errorf("failed to validate mieru option: %w", err)
}
transportProtocol := mierupb.TransportProtocol_TCP.Enum()
var server *mierupb.ServerEndpoint
if net.ParseIP(option.Server) != nil {
// server is an IP address
if option.PortRange != "" {
server = &mierupb.ServerEndpoint{
IpAddress: proto.String(option.Server),
PortBindings: []*mierupb.PortBinding{
{
PortRange: proto.String(option.PortRange),
Protocol: transportProtocol,
},
},
}
} else {
server = &mierupb.ServerEndpoint{
IpAddress: proto.String(option.Server),
PortBindings: []*mierupb.PortBinding{
{
Port: proto.Int32(int32(option.Port)),
Protocol: transportProtocol,
},
},
}
}
} else {
// server is a domain name
if option.PortRange != "" {
server = &mierupb.ServerEndpoint{
DomainName: proto.String(option.Server),
PortBindings: []*mierupb.PortBinding{
{
PortRange: proto.String(option.PortRange),
Protocol: transportProtocol,
},
},
}
} else {
server = &mierupb.ServerEndpoint{
DomainName: proto.String(option.Server),
PortBindings: []*mierupb.PortBinding{
{
Port: proto.Int32(int32(option.Port)),
Protocol: transportProtocol,
},
},
}
}
}
config := &mieruclient.ClientConfig{
Profile: &mierupb.ClientProfile{
ProfileName: proto.String(option.Name),
User: &mierupb.User{
Name: proto.String(option.UserName),
Password: proto.String(option.Password),
},
Servers: []*mierupb.ServerEndpoint{server},
},
}
if multiplexing, ok := mierupb.MultiplexingLevel_value[option.Multiplexing]; ok {
config.Profile.Multiplexing = &mierupb.MultiplexingConfig{
Level: mierupb.MultiplexingLevel(multiplexing).Enum(),
}
}
return config, nil
}
func validateMieruOption(option MieruOption) error {
if option.Name == "" {
return fmt.Errorf("name is empty")
}
if option.Server == "" {
return fmt.Errorf("server is empty")
}
if option.Port == 0 && option.PortRange == "" {
return fmt.Errorf("either port or port-range must be set")
}
if option.Port != 0 && option.PortRange != "" {
return fmt.Errorf("port and port-range cannot be set at the same time")
}
if option.Port != 0 && (option.Port < 1 || option.Port > 65535) {
return fmt.Errorf("port must be between 1 and 65535")
}
if option.PortRange != "" {
begin, end, err := beginAndEndPortFromPortRange(option.PortRange)
if err != nil {
return fmt.Errorf("invalid port-range format")
}
if begin < 1 || begin > 65535 {
return fmt.Errorf("begin port must be between 1 and 65535")
}
if end < 1 || end > 65535 {
return fmt.Errorf("end port must be between 1 and 65535")
}
if begin > end {
return fmt.Errorf("begin port must be less than or equal to end port")
}
}
if option.Transport != "TCP" {
return fmt.Errorf("transport must be TCP")
}
if option.UserName == "" {
return fmt.Errorf("username is empty")
}
if option.Password == "" {
return fmt.Errorf("password is empty")
}
if option.Multiplexing != "" {
if _, ok := mierupb.MultiplexingLevel_value[option.Multiplexing]; !ok {
return fmt.Errorf("invalid multiplexing level: %s", option.Multiplexing)
}
}
return nil
}
func beginAndEndPortFromPortRange(portRange string) (int, int, error) {
var begin, end int
_, err := fmt.Sscanf(portRange, "%d-%d", &begin, &end)
return begin, end, err
}

View file

@ -1,92 +0,0 @@
package outbound
import "testing"
func TestNewMieru(t *testing.T) {
testCases := []struct {
option MieruOption
wantBaseAddr string
}{
{
option: MieruOption{
Name: "test",
Server: "1.2.3.4",
Port: 10000,
Transport: "TCP",
UserName: "test",
Password: "test",
},
wantBaseAddr: "1.2.3.4:10000",
},
{
option: MieruOption{
Name: "test",
Server: "2001:db8::1",
PortRange: "10001-10002",
Transport: "TCP",
UserName: "test",
Password: "test",
},
wantBaseAddr: "[2001:db8::1]:10001",
},
{
option: MieruOption{
Name: "test",
Server: "example.com",
Port: 10003,
Transport: "TCP",
UserName: "test",
Password: "test",
},
wantBaseAddr: "example.com:10003",
},
}
for _, testCase := range testCases {
mieru, err := NewMieru(testCase.option)
if err != nil {
t.Error(err)
}
if mieru.addr != testCase.wantBaseAddr {
t.Errorf("got addr %q, want %q", mieru.addr, testCase.wantBaseAddr)
}
}
}
func TestBeginAndEndPortFromPortRange(t *testing.T) {
testCases := []struct {
input string
begin int
end int
hasErr bool
}{
{"1-10", 1, 10, false},
{"1000-2000", 1000, 2000, false},
{"65535-65535", 65535, 65535, false},
{"1", 0, 0, true},
{"1-", 0, 0, true},
{"-10", 0, 0, true},
{"a-b", 0, 0, true},
{"1-b", 0, 0, true},
{"a-10", 0, 0, true},
}
for _, testCase := range testCases {
begin, end, err := beginAndEndPortFromPortRange(testCase.input)
if testCase.hasErr {
if err == nil {
t.Errorf("beginAndEndPortFromPortRange(%s) should return an error", testCase.input)
}
} else {
if err != nil {
t.Errorf("beginAndEndPortFromPortRange(%s) should not return an error, but got %v", testCase.input, err)
}
if begin != testCase.begin {
t.Errorf("beginAndEndPortFromPortRange(%s) begin port mismatch, got %d, want %d", testCase.input, begin, testCase.begin)
}
if end != testCase.end {
t.Errorf("beginAndEndPortFromPortRange(%s) end port mismatch, got %d, want %d", testCase.input, end, testCase.end)
}
}
}
}

View file

@ -20,19 +20,16 @@ func (o RealityOptions) Parse() (*tlsC.RealityConfig, error) {
config := new(tlsC.RealityConfig) config := new(tlsC.RealityConfig)
const x25519ScalarSize = 32 const x25519ScalarSize = 32
publicKey, err := base64.RawURLEncoding.DecodeString(o.PublicKey) var publicKey [x25519ScalarSize]byte
if err != nil || len(publicKey) != x25519ScalarSize { n, err := base64.RawURLEncoding.Decode(publicKey[:], []byte(o.PublicKey))
if err != nil || n != x25519ScalarSize {
return nil, errors.New("invalid REALITY public key") return nil, errors.New("invalid REALITY public key")
} }
config.PublicKey, err = ecdh.X25519().NewPublicKey(publicKey) config.PublicKey, err = ecdh.X25519().NewPublicKey(publicKey[:])
if err != nil { if err != nil {
return nil, fmt.Errorf("fail to create REALITY public key: %w", err) return nil, fmt.Errorf("fail to create REALITY public key: %w", err)
} }
n := hex.DecodedLen(len(o.ShortID))
if n > tlsC.RealityMaxShortIDLen {
return nil, errors.New("invalid REALITY short id")
}
n, err = hex.Decode(config.ShortID[:], []byte(o.ShortID)) n, err = hex.Decode(config.ShortID[:], []byte(o.ShortID))
if err != nil || n > tlsC.RealityMaxShortIDLen { if err != nil || n > tlsC.RealityMaxShortIDLen {
return nil, errors.New("invalid REALITY short ID") return nil, errors.New("invalid REALITY short ID")

View file

@ -37,7 +37,7 @@ func NewRejectWithOption(option RejectOption) *Reject {
return &Reject{ return &Reject{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
tp: C.Reject, tp: C.Direct,
udp: true, udp: true,
}, },
} }

View file

@ -13,7 +13,6 @@ import (
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
gost "github.com/metacubex/mihomo/transport/gost-plugin"
"github.com/metacubex/mihomo/transport/restls" "github.com/metacubex/mihomo/transport/restls"
obfs "github.com/metacubex/mihomo/transport/simple-obfs" obfs "github.com/metacubex/mihomo/transport/simple-obfs"
shadowtls "github.com/metacubex/mihomo/transport/sing-shadowtls" shadowtls "github.com/metacubex/mihomo/transport/sing-shadowtls"
@ -35,7 +34,6 @@ type ShadowSocks struct {
obfsMode string obfsMode string
obfsOption *simpleObfsOption obfsOption *simpleObfsOption
v2rayOption *v2rayObfs.Option v2rayOption *v2rayObfs.Option
gostOption *gost.Option
shadowTLSOption *shadowtls.ShadowTLSOption shadowTLSOption *shadowtls.ShadowTLSOption
restlsConfig *restlsC.Config restlsConfig *restlsC.Config
} }
@ -73,17 +71,6 @@ type v2rayObfsOption struct {
V2rayHttpUpgradeFastOpen bool `obfs:"v2ray-http-upgrade-fast-open,omitempty"` V2rayHttpUpgradeFastOpen bool `obfs:"v2ray-http-upgrade-fast-open,omitempty"`
} }
type gostObfsOption struct {
Mode string `obfs:"mode"`
Host string `obfs:"host,omitempty"`
Path string `obfs:"path,omitempty"`
TLS bool `obfs:"tls,omitempty"`
Fingerprint string `obfs:"fingerprint,omitempty"`
Headers map[string]string `obfs:"headers,omitempty"`
SkipCertVerify bool `obfs:"skip-cert-verify,omitempty"`
Mux bool `obfs:"mux,omitempty"`
}
type shadowTLSOption struct { type shadowTLSOption struct {
Password string `obfs:"password"` Password string `obfs:"password"`
Host string `obfs:"host"` Host string `obfs:"host"`
@ -100,7 +87,7 @@ type restlsOption struct {
} }
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
useEarly := false useEarly := false
switch ss.obfsMode { switch ss.obfsMode {
case "tls": case "tls":
@ -109,23 +96,20 @@ func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metada
_, port, _ := net.SplitHostPort(ss.addr) _, port, _ := net.SplitHostPort(ss.addr)
c = obfs.NewHTTPObfs(c, ss.obfsOption.Host, port) c = obfs.NewHTTPObfs(c, ss.obfsOption.Host, port)
case "websocket": case "websocket":
if ss.v2rayOption != nil { var err error
c, err = v2rayObfs.NewV2rayObfs(ctx, c, ss.v2rayOption) c, err = v2rayObfs.NewV2rayObfs(ctx, c, ss.v2rayOption)
} else if ss.gostOption != nil {
c, err = gost.NewGostWebsocket(ctx, c, ss.gostOption)
} else {
return nil, fmt.Errorf("plugin options is required")
}
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
} }
case shadowtls.Mode: case shadowtls.Mode:
var err error
c, err = shadowtls.NewShadowTLS(ctx, c, ss.shadowTLSOption) c, err = shadowtls.NewShadowTLS(ctx, c, ss.shadowTLSOption)
if err != nil { if err != nil {
return nil, err return nil, err
} }
useEarly = true useEarly = true
case restls.Mode: case restls.Mode:
var err error
c, err = restls.NewRestls(ctx, c, ss.restlsConfig) c, err = restls.NewRestls(ctx, c, ss.restlsConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s (restls) connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s (restls) connect error: %w", ss.addr, err)
@ -133,12 +117,6 @@ func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metada
useEarly = true useEarly = true
} }
useEarly = useEarly || N.NeedHandshake(c) useEarly = useEarly || N.NeedHandshake(c)
if !useEarly {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
}
if metadata.NetWork == C.UDP && ss.option.UDPOverTCP { if metadata.NetWork == C.UDP && ss.option.UDPOverTCP {
uotDestination := uot.RequestDestination(uint8(ss.option.UDPOverTCPVersion)) uotDestination := uot.RequestDestination(uint8(ss.option.UDPOverTCPVersion))
if useEarly { if useEarly {
@ -171,6 +149,7 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@ -200,7 +179,7 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial
return nil, err return nil, err
} }
} }
addr, err := resolveUDPAddr(ctx, "udp", ss.addr, ss.prefer) addr, err := resolveUDPAddrWithPrefer(ctx, "udp", ss.addr, ss.prefer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -218,13 +197,6 @@ func (ss *ShadowSocks) SupportWithDialer() C.NetWork {
return C.ALLNet return C.ALLNet
} }
// ProxyInfo implements C.ProxyAdapter
func (ss *ShadowSocks) ProxyInfo() C.ProxyInfo {
info := ss.Base.ProxyInfo()
info.DialerProxy = ss.option.DialerProxy
return info
}
// ListenPacketOnStreamConn implements C.ProxyAdapter // ListenPacketOnStreamConn implements C.ProxyAdapter
func (ss *ShadowSocks) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { func (ss *ShadowSocks) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
if ss.option.UDPOverTCP { if ss.option.UDPOverTCP {
@ -258,11 +230,10 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
Password: option.Password, Password: option.Password,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("ss %s cipher: %s initialize error: %w", addr, option.Cipher, err) return nil, fmt.Errorf("ss %s initialize error: %w", addr, err)
} }
var v2rayOption *v2rayObfs.Option var v2rayOption *v2rayObfs.Option
var gostOption *gost.Option
var obfsOption *simpleObfsOption var obfsOption *simpleObfsOption
var shadowTLSOpt *shadowtls.ShadowTLSOption var shadowTLSOpt *shadowtls.ShadowTLSOption
var restlsConfig *restlsC.Config var restlsConfig *restlsC.Config
@ -304,28 +275,6 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
v2rayOption.SkipCertVerify = opts.SkipCertVerify v2rayOption.SkipCertVerify = opts.SkipCertVerify
v2rayOption.Fingerprint = opts.Fingerprint v2rayOption.Fingerprint = opts.Fingerprint
} }
} else if option.Plugin == "gost-plugin" {
opts := gostObfsOption{Host: "bing.com", Mux: true}
if err := decoder.Decode(option.PluginOpts, &opts); err != nil {
return nil, fmt.Errorf("ss %s initialize gost-plugin error: %w", addr, err)
}
if opts.Mode != "websocket" {
return nil, fmt.Errorf("ss %s obfs mode error: %s", addr, opts.Mode)
}
obfsMode = opts.Mode
gostOption = &gost.Option{
Host: opts.Host,
Path: opts.Path,
Headers: opts.Headers,
Mux: opts.Mux,
}
if opts.TLS {
gostOption.TLS = true
gostOption.SkipCertVerify = opts.SkipCertVerify
gostOption.Fingerprint = opts.Fingerprint
}
} else if option.Plugin == shadowtls.Mode { } else if option.Plugin == shadowtls.Mode {
obfsMode = shadowtls.Mode obfsMode = shadowtls.Mode
opt := &shadowTLSOption{ opt := &shadowTLSOption{
@ -381,7 +330,6 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
option: &option, option: &option,
obfsMode: obfsMode, obfsMode: obfsMode,
v2rayOption: v2rayOption, v2rayOption: v2rayOption,
gostOption: gostOption,
obfsOption: obfsOption, obfsOption: obfsOption,
shadowTLSOption: shadowTLSOpt, shadowTLSOption: shadowTLSOpt,
restlsConfig: restlsConfig, restlsConfig: restlsConfig,

View file

@ -42,15 +42,12 @@ type ShadowSocksROption struct {
} }
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { func (ssr *ShadowSocksR) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
c = ssr.obfs.StreamConn(c) c = ssr.obfs.StreamConn(c)
c = ssr.cipher.StreamConn(c) c = ssr.cipher.StreamConn(c)
var ( var (
iv []byte iv []byte
err error
) )
switch conn := c.(type) { switch conn := c.(type) {
case *shadowstream.Conn: case *shadowstream.Conn:
@ -83,6 +80,7 @@ func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dia
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ssr.addr, err) return nil, fmt.Errorf("%s connect error: %w", ssr.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@ -105,7 +103,7 @@ func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Di
return nil, err return nil, err
} }
} }
addr, err := resolveUDPAddr(ctx, "udp", ssr.addr, ssr.prefer) addr, err := resolveUDPAddrWithPrefer(ctx, "udp", ssr.addr, ssr.prefer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -125,13 +123,6 @@ func (ssr *ShadowSocksR) SupportWithDialer() C.NetWork {
return C.ALLNet return C.ALLNet
} }
// ProxyInfo implements C.ProxyAdapter
func (ssr *ShadowSocksR) ProxyInfo() C.ProxyInfo {
info := ssr.Base.ProxyInfo()
info.DialerProxy = ssr.option.DialerProxy
return info
}
func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) { func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
// SSR protocol compatibility // SSR protocol compatibility
// https://github.com/metacubex/mihomo/pull/2056 // https://github.com/metacubex/mihomo/pull/2056
@ -144,7 +135,7 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
password := option.Password password := option.Password
coreCiph, err := core.PickCipher(cipher, nil, password) coreCiph, err := core.PickCipher(cipher, nil, password)
if err != nil { if err != nil {
return nil, fmt.Errorf("ssr %s cipher: %s initialize error: %w", addr, cipher, err) return nil, fmt.Errorf("ssr %s initialize error: %w", addr, err)
} }
var ( var (
ivSize int ivSize int

View file

@ -3,6 +3,7 @@ package outbound
import ( import (
"context" "context"
"errors" "errors"
"runtime"
CN "github.com/metacubex/mihomo/common/net" CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
@ -17,7 +18,8 @@ import (
) )
type SingMux struct { type SingMux struct {
ProxyAdapter C.ProxyAdapter
base ProxyBase
client *mux.Client client *mux.Client
dialer proxydialer.SingDialer dialer proxydialer.SingDialer
onlyTcp bool onlyTcp bool
@ -41,21 +43,25 @@ type BrutalOption struct {
Down string `proxy:"down,omitempty"` Down string `proxy:"down,omitempty"`
} }
type ProxyBase interface {
DialOptions(opts ...dialer.Option) []dialer.Option
}
func (s *SingMux) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { func (s *SingMux) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
options := s.ProxyAdapter.DialOptions(opts...) options := s.base.DialOptions(opts...)
s.dialer.SetDialer(dialer.NewDialer(options...)) s.dialer.SetDialer(dialer.NewDialer(options...))
c, err := s.client.DialContext(ctx, "tcp", M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort)) c, err := s.client.DialContext(ctx, "tcp", M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(c, s), err return NewConn(CN.NewRefConn(c, s), s.ProxyAdapter), err
} }
func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if s.onlyTcp { if s.onlyTcp {
return s.ProxyAdapter.ListenPacketContext(ctx, metadata, opts...) return s.ProxyAdapter.ListenPacketContext(ctx, metadata, opts...)
} }
options := s.ProxyAdapter.DialOptions(opts...) options := s.base.DialOptions(opts...)
s.dialer.SetDialer(dialer.NewDialer(options...)) s.dialer.SetDialer(dialer.NewDialer(options...))
// sing-mux use stream-oriented udp with a special address, so we need a net.UDPAddr // sing-mux use stream-oriented udp with a special address, so we need a net.UDPAddr
@ -74,7 +80,7 @@ func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
if pc == nil { if pc == nil {
return nil, E.New("packetConn is nil") return nil, E.New("packetConn is nil")
} }
return newPacketConn(CN.NewThreadSafePacketConn(pc), s), nil return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(pc), s), s.ProxyAdapter), nil
} }
func (s *SingMux) SupportUDP() bool { func (s *SingMux) SupportUDP() bool {
@ -91,21 +97,11 @@ func (s *SingMux) SupportUOT() bool {
return true return true
} }
func (s *SingMux) ProxyInfo() C.ProxyInfo { func closeSingMux(s *SingMux) {
info := s.ProxyAdapter.ProxyInfo()
info.SMUX = true
return info
}
// Close implements C.ProxyAdapter
func (s *SingMux) Close() error {
if s.client != nil {
_ = s.client.Close() _ = s.client.Close()
} }
return s.ProxyAdapter.Close()
}
func NewSingMux(option SingMuxOption, proxy ProxyAdapter) (ProxyAdapter, error) { func NewSingMux(option SingMuxOption, proxy C.ProxyAdapter, base ProxyBase) (C.ProxyAdapter, error) {
// TODO // TODO
// "TCP Brutal is only supported on Linux-based systems" // "TCP Brutal is only supported on Linux-based systems"
@ -129,9 +125,11 @@ func NewSingMux(option SingMuxOption, proxy ProxyAdapter) (ProxyAdapter, error)
} }
outbound := &SingMux{ outbound := &SingMux{
ProxyAdapter: proxy, ProxyAdapter: proxy,
base: base,
client: client, client: client,
dialer: singDialer, dialer: singDialer,
onlyTcp: option.OnlyTcp, onlyTcp: option.OnlyTcp,
} }
runtime.SetFinalizer(outbound, closeSingMux)
return outbound, nil return outbound, nil
} }

View file

@ -42,7 +42,7 @@ type streamOption struct {
obfsOption *simpleObfsOption obfsOption *simpleObfsOption
} }
func snellStreamConn(c net.Conn, option streamOption) *snell.Snell { func streamConn(c net.Conn, option streamOption) *snell.Snell {
switch option.obfsOption.Mode { switch option.obfsOption.Mode {
case "tls": case "tls":
c = obfs.NewTLSObfs(c, option.obfsOption.Host) c = obfs.NewTLSObfs(c, option.obfsOption.Host)
@ -55,35 +55,25 @@ func snellStreamConn(c net.Conn, option streamOption) *snell.Snell {
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
func (s *Snell) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (s *Snell) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
c = snellStreamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption}) c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption})
err := s.writeHeaderContext(ctx, c, metadata) if metadata.NetWork == C.UDP {
err := snell.WriteUDPHeader(c, s.version)
return c, err return c, err
} }
err := snell.WriteHeader(c, metadata.String(), uint(metadata.DstPort), s.version)
func (s *Snell) writeHeaderContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (err error) { return c, err
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
if metadata.NetWork == C.UDP {
err = snell.WriteUDPHeader(c, s.version)
return
}
err = snell.WriteHeader(c, metadata.String(), uint(metadata.DstPort), s.version)
return
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
if s.version == snell.Version2 && dialer.IsZeroOptions(opts) { if s.version == snell.Version2 && len(opts) == 0 {
c, err := s.pool.Get() c, err := s.pool.Get()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = s.writeHeaderContext(ctx, c, metadata); err != nil { if err = snell.WriteHeader(c, metadata.String(), uint(metadata.DstPort), s.version); err != nil {
_ = c.Close() c.Close()
return nil, err return nil, err
} }
return NewConn(c, s), err return NewConn(c, s), err
@ -104,6 +94,7 @@ func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err) return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@ -131,8 +122,13 @@ func (s *Snell) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
if err != nil { if err != nil {
return nil, err return nil, err
} }
N.TCPKeepAlive(c)
c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption})
c, err = s.StreamConnContext(ctx, c, metadata) err = snell.WriteUDPHeader(c, s.version)
if err != nil {
return nil, err
}
pc := snell.PacketConn(c) pc := snell.PacketConn(c)
return newPacketConn(pc, s), nil return newPacketConn(pc, s), nil
@ -148,13 +144,6 @@ func (s *Snell) SupportUOT() bool {
return true return true
} }
// ProxyInfo implements C.ProxyAdapter
func (s *Snell) ProxyInfo() C.ProxyInfo {
info := s.Base.ProxyInfo()
info.DialerProxy = s.option.DialerProxy
return info
}
func NewSnell(option SnellOption) (*Snell, error) { func NewSnell(option SnellOption) (*Snell, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
psk := []byte(option.Psk) psk := []byte(option.Psk)
@ -219,7 +208,8 @@ func NewSnell(option SnellOption) (*Snell, error) {
return nil, err return nil, err
} }
return snellStreamConn(c, streamOption{psk, option.Version, addr, obfsOption}), nil N.TCPKeepAlive(c)
return streamConn(c, streamOption{psk, option.Version, addr, obfsOption}), nil
}) })
} }
return s, nil return s, nil

View file

@ -59,7 +59,7 @@ func (ss *Socks5) StreamConnContext(ctx context.Context, c net.Conn, metadata *C
Password: ss.pass, Password: ss.pass,
} }
} }
if _, err := ss.clientHandshakeContext(ctx, c, serializesSocksAddr(metadata), socks5.CmdConnect, user); err != nil { if _, err := socks5.ClientHandshake(c, serializesSocksAddr(metadata), socks5.CmdConnect, user); err != nil {
return nil, err return nil, err
} }
return c, nil return c, nil
@ -82,6 +82,7 @@ func (ss *Socks5) DialContextWithDialer(ctx context.Context, dialer C.Dialer, me
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@ -127,6 +128,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
N.TCPKeepAlive(c)
var user *socks5.User var user *socks5.User
if ss.user != "" { if ss.user != "" {
user = &socks5.User{ user = &socks5.User{
@ -136,7 +138,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
} }
udpAssocateAddr := socks5.AddrFromStdAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)) udpAssocateAddr := socks5.AddrFromStdAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
bindAddr, err := ss.clientHandshakeContext(ctx, c, udpAssocateAddr, socks5.CmdUDPAssociate, user) bindAddr, err := socks5.ClientHandshake(c, udpAssocateAddr, socks5.CmdUDPAssociate, user)
if err != nil { if err != nil {
err = fmt.Errorf("client hanshake error: %w", err) err = fmt.Errorf("client hanshake error: %w", err)
return return
@ -148,7 +150,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
err = errors.New("invalid UDP bind address") err = errors.New("invalid UDP bind address")
return return
} else if bindUDPAddr.IP.IsUnspecified() { } else if bindUDPAddr.IP.IsUnspecified() {
serverAddr, err := resolveUDPAddr(ctx, "udp", ss.Addr(), C.IPv4Prefer) serverAddr, err := resolveUDPAddr(ctx, "udp", ss.Addr())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -172,21 +174,6 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
return newPacketConn(&socksPacketConn{PacketConn: pc, rAddr: bindUDPAddr, tcpConn: c}, ss), nil return newPacketConn(&socksPacketConn{PacketConn: pc, rAddr: bindUDPAddr, tcpConn: c}, ss), nil
} }
// ProxyInfo implements C.ProxyAdapter
func (ss *Socks5) ProxyInfo() C.ProxyInfo {
info := ss.Base.ProxyInfo()
info.DialerProxy = ss.option.DialerProxy
return info
}
func (ss *Socks5) clientHandshakeContext(ctx context.Context, c net.Conn, addr socks5.Addr, command socks5.Command, user *socks5.User) (_ socks5.Addr, err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
return socks5.ClientHandshake(c, addr, command, user)
}
func NewSocks5(option Socks5Option) (*Socks5, error) { func NewSocks5(option Socks5Option) (*Socks5, error) {
var tlsConfig *tls.Config var tlsConfig *tls.Config
if option.TLS { if option.TLS {

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -24,10 +25,7 @@ type Ssh struct {
*Base *Base
option *SshOption option *SshOption
client *sshClient // using a standalone struct to avoid its inner loop invalidate the Finalizer
config *ssh.ClientConfig
client *ssh.Client
cMutex sync.Mutex
} }
type SshOption struct { type SshOption struct {
@ -51,7 +49,7 @@ func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dia
return nil, err return nil, err
} }
} }
client, err := s.connect(ctx, cDialer, s.addr) client, err := s.client.connect(ctx, cDialer, s.addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -60,10 +58,16 @@ func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dia
return nil, err return nil, err
} }
return NewConn(c, s), nil return NewConn(N.NewRefConn(c, s), s), nil
} }
func (s *Ssh) connect(ctx context.Context, cDialer C.Dialer, addr string) (client *ssh.Client, err error) { type sshClient struct {
config *ssh.ClientConfig
client *ssh.Client
cMutex sync.Mutex
}
func (s *sshClient) connect(ctx context.Context, cDialer C.Dialer, addr string) (client *ssh.Client, err error) {
s.cMutex.Lock() s.cMutex.Lock()
defer s.cMutex.Unlock() defer s.cMutex.Unlock()
if s.client != nil { if s.client != nil {
@ -73,6 +77,7 @@ func (s *Ssh) connect(ctx context.Context, cDialer C.Dialer, addr string) (clien
if err != nil { if err != nil {
return nil, err return nil, err
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@ -104,15 +109,7 @@ func (s *Ssh) connect(ctx context.Context, cDialer C.Dialer, addr string) (clien
return client, nil return client, nil
} }
// ProxyInfo implements C.ProxyAdapter func (s *sshClient) Close() error {
func (s *Ssh) ProxyInfo() C.ProxyInfo {
info := s.Base.ProxyInfo()
info.DialerProxy = s.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (s *Ssh) Close() error {
s.cMutex.Lock() s.cMutex.Lock()
defer s.cMutex.Unlock() defer s.cMutex.Unlock()
if s.client != nil { if s.client != nil {
@ -121,6 +118,10 @@ func (s *Ssh) Close() error {
return nil return nil
} }
func closeSsh(s *Ssh) {
_ = s.client.Close()
}
func NewSsh(option SshOption) (*Ssh, error) { func NewSsh(option SshOption) (*Ssh, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
@ -197,8 +198,11 @@ func NewSsh(option SshOption) (*Ssh, error) {
prefer: C.NewDNSPrefer(option.IPVersion), prefer: C.NewDNSPrefer(option.IPVersion),
}, },
option: &option, option: &option,
client: &sshClient{
config: &config, config: &config,
},
} }
runtime.SetFinalizer(outbound, closeSsh)
return outbound, nil return outbound, nil
} }

View file

@ -18,13 +18,12 @@ import (
"github.com/metacubex/mihomo/transport/gun" "github.com/metacubex/mihomo/transport/gun"
"github.com/metacubex/mihomo/transport/shadowsocks/core" "github.com/metacubex/mihomo/transport/shadowsocks/core"
"github.com/metacubex/mihomo/transport/trojan" "github.com/metacubex/mihomo/transport/trojan"
"github.com/metacubex/mihomo/transport/vmess"
) )
type Trojan struct { type Trojan struct {
*Base *Base
instance *trojan.Trojan
option *TrojanOption option *TrojanOption
hexPassword [trojan.KeyLength]byte
// for gun mux // for gun mux
gunTLSConfig *tls.Config gunTLSConfig *tls.Config
@ -62,21 +61,15 @@ type TrojanSSOption struct {
Password string `proxy:"password,omitempty"` Password string `proxy:"password,omitempty"`
} }
// StreamConnContext implements C.ProxyAdapter func (t *Trojan) plainStream(ctx context.Context, c net.Conn) (net.Conn, error) {
func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { if t.option.Network == "ws" {
switch t.option.Network {
case "ws":
host, port, _ := net.SplitHostPort(t.addr) host, port, _ := net.SplitHostPort(t.addr)
wsOpts := &trojan.WebsocketOption{
wsOpts := &vmess.WebsocketConfig{
Host: host, Host: host,
Port: port, Port: port,
Path: t.option.WSOpts.Path, Path: t.option.WSOpts.Path,
MaxEarlyData: t.option.WSOpts.MaxEarlyData,
EarlyDataHeaderName: t.option.WSOpts.EarlyDataHeaderName,
V2rayHttpUpgrade: t.option.WSOpts.V2rayHttpUpgrade, V2rayHttpUpgrade: t.option.WSOpts.V2rayHttpUpgrade,
V2rayHttpUpgradeFastOpen: t.option.WSOpts.V2rayHttpUpgradeFastOpen, V2rayHttpUpgradeFastOpen: t.option.WSOpts.V2rayHttpUpgradeFastOpen,
ClientFingerprint: t.option.ClientFingerprint,
Headers: http.Header{}, Headers: http.Header{},
} }
@ -90,95 +83,57 @@ func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.
} }
} }
alpn := trojan.DefaultWebsocketALPN return t.instance.StreamWebsocketConn(ctx, c, wsOpts)
if len(t.option.ALPN) != 0 {
alpn = t.option.ALPN
} }
wsOpts.TLS = true return t.instance.StreamConn(ctx, c)
tlsConfig := &tls.Config{
NextProtos: alpn,
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: t.option.SkipCertVerify,
ServerName: t.option.SNI,
} }
wsOpts.TLSConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint) // StreamConnContext implements C.ProxyAdapter
if err != nil { func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
return nil, err var err error
if tlsC.HaveGlobalFingerprint() && len(t.option.ClientFingerprint) == 0 {
t.option.ClientFingerprint = tlsC.GetGlobalFingerprint()
} }
c, err = vmess.StreamWebsocketConn(ctx, c, wsOpts) if t.transport != nil {
case "grpc":
c, err = gun.StreamGunWithConn(c, t.gunTLSConfig, t.gunConfig, t.realityConfig) c, err = gun.StreamGunWithConn(c, t.gunTLSConfig, t.gunConfig, t.realityConfig)
default: } else {
// default tcp network c, err = t.plainStream(ctx, c)
// handle TLS
alpn := trojan.DefaultALPN
if len(t.option.ALPN) != 0 {
alpn = t.option.ALPN
}
c, err = vmess.StreamTLSConn(ctx, c, &vmess.TLSConfig{
Host: t.option.SNI,
SkipCertVerify: t.option.SkipCertVerify,
FingerPrint: t.option.Fingerprint,
ClientFingerprint: t.option.ClientFingerprint,
NextProtos: alpn,
Reality: t.realityConfig,
})
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
} }
return t.streamConnContext(ctx, c, metadata)
}
func (t *Trojan) streamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) {
if t.ssCipher != nil { if t.ssCipher != nil {
c = t.ssCipher.StreamConn(c) c = t.ssCipher.StreamConn(c)
} }
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
command := trojan.CommandTCP
if metadata.NetWork == C.UDP { if metadata.NetWork == C.UDP {
command = trojan.CommandUDP err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata))
}
err = trojan.WriteHeader(c, t.hexPassword, command, serializesSocksAddr(metadata))
return c, err return c, err
} }
err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
func (t *Trojan) writeHeaderContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (err error) { return c, err
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
command := trojan.CommandTCP
if metadata.NetWork == C.UDP {
command = trojan.CommandUDP
}
err = trojan.WriteHeader(c, t.hexPassword, command, serializesSocksAddr(metadata))
return err
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
var c net.Conn
// gun transport // gun transport
if t.transport != nil && dialer.IsZeroOptions(opts) { if t.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(t.transport, t.gunConfig) c, err := gun.StreamGunWithTransport(t.transport, t.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func(c net.Conn) {
safeConnClose(c, err)
}(c)
c, err = t.streamConnContext(ctx, c, metadata) if t.ssCipher != nil {
if err != nil { c = t.ssCipher.StreamConn(c)
}
if err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)); err != nil {
c.Close()
return nil, err return nil, err
} }
@ -199,6 +154,7 @@ func (t *Trojan) DialContextWithDialer(ctx context.Context, dialer C.Dialer, met
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@ -217,7 +173,7 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
var c net.Conn var c net.Conn
// grpc transport // grpc transport
if t.transport != nil && dialer.IsZeroOptions(opts) { if t.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(t.transport, t.gunConfig) c, err = gun.StreamGunWithTransport(t.transport, t.gunConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
@ -226,12 +182,16 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = t.streamConnContext(ctx, c, metadata) if t.ssCipher != nil {
c = t.ssCipher.StreamConn(c)
}
err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata))
if err != nil { if err != nil {
return nil, err return nil, err
} }
pc := trojan.NewPacketConn(c) pc := t.instance.PacketConn(c)
return newPacketConn(pc, t), err return newPacketConn(pc, t), err
} }
return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata) return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata)
@ -252,12 +212,22 @@ func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, me
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = t.StreamConnContext(ctx, c, metadata) N.TCPKeepAlive(c)
c, err = t.plainStream(ctx, c)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
}
if t.ssCipher != nil {
c = t.ssCipher.StreamConn(c)
}
err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata))
if err != nil { if err != nil {
return nil, err return nil, err
} }
pc := trojan.NewPacketConn(c) pc := t.instance.PacketConn(c)
return newPacketConn(pc, t), err return newPacketConn(pc, t), err
} }
@ -268,7 +238,7 @@ func (t *Trojan) SupportWithDialer() C.NetWork {
// ListenPacketOnStreamConn implements C.ProxyAdapter // ListenPacketOnStreamConn implements C.ProxyAdapter
func (t *Trojan) ListenPacketOnStreamConn(c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { func (t *Trojan) ListenPacketOnStreamConn(c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
pc := trojan.NewPacketConn(c) pc := t.instance.PacketConn(c)
return newPacketConn(pc, t), err return newPacketConn(pc, t), err
} }
@ -277,24 +247,22 @@ func (t *Trojan) SupportUOT() bool {
return true return true
} }
// ProxyInfo implements C.ProxyAdapter
func (t *Trojan) ProxyInfo() C.ProxyInfo {
info := t.Base.ProxyInfo()
info.DialerProxy = t.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (t *Trojan) Close() error {
if t.transport != nil {
return t.transport.Close()
}
return nil
}
func NewTrojan(option TrojanOption) (*Trojan, error) { func NewTrojan(option TrojanOption) (*Trojan, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
tOption := &trojan.Option{
Password: option.Password,
ALPN: option.ALPN,
ServerName: option.Server,
SkipCertVerify: option.SkipCertVerify,
Fingerprint: option.Fingerprint,
ClientFingerprint: option.ClientFingerprint,
}
if option.SNI != "" {
tOption.ServerName = option.SNI
}
t := &Trojan{ t := &Trojan{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
@ -307,8 +275,8 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: C.NewDNSPrefer(option.IPVersion),
}, },
instance: trojan.New(tOption),
option: &option, option: &option,
hexPassword: trojan.Key(option.Password),
} }
var err error var err error
@ -316,6 +284,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
tOption.Reality = t.realityConfig
if option.SSOpts.Enabled { if option.SSOpts.Enabled {
if option.SSOpts.Password == "" { if option.SSOpts.Password == "" {
@ -332,7 +301,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
} }
if option.Network == "grpc" { if option.Network == "grpc" {
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(network, addr string) (net.Conn, error) {
var err error var err error
var cDialer C.Dialer = dialer.NewDialer(t.Base.DialOptions()...) var cDialer C.Dialer = dialer.NewDialer(t.Base.DialOptions()...)
if len(t.option.DialerProxy) > 0 { if len(t.option.DialerProxy) > 0 {
@ -341,18 +310,19 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
return nil, err return nil, err
} }
} }
c, err := cDialer.DialContext(ctx, "tcp", t.addr) c, err := cDialer.DialContext(context.Background(), "tcp", t.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error())
} }
N.TCPKeepAlive(c)
return c, nil return c, nil
} }
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: option.ALPN, NextProtos: option.ALPN,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
InsecureSkipVerify: option.SkipCertVerify, InsecureSkipVerify: tOption.SkipCertVerify,
ServerName: option.SNI, ServerName: tOption.ServerName,
} }
var err error var err error
@ -361,13 +331,12 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
return nil, err return nil, err
} }
t.transport = gun.NewHTTP2Client(dialFn, tlsConfig, option.ClientFingerprint, t.realityConfig) t.transport = gun.NewHTTP2Client(dialFn, tlsConfig, tOption.ClientFingerprint, t.realityConfig)
t.gunTLSConfig = tlsConfig t.gunTLSConfig = tlsConfig
t.gunConfig = &gun.Config{ t.gunConfig = &gun.Config{
ServiceName: option.GrpcOpts.GrpcServiceName, ServiceName: option.GrpcOpts.GrpcServiceName,
Host: option.SNI, Host: tOption.ServerName,
ClientFingerprint: option.ClientFingerprint,
} }
} }

View file

@ -130,7 +130,7 @@ func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *
return nil, nil, err return nil, nil, err
} }
} }
udpAddr, err := resolveUDPAddr(ctx, "udp", t.addr, t.prefer) udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -146,13 +146,6 @@ func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *
return return
} }
// ProxyInfo implements C.ProxyAdapter
func (t *Tuic) ProxyInfo() C.ProxyInfo {
info := t.Base.ProxyInfo()
info.DialerProxy = t.option.DialerProxy
return info
}
func NewTuic(option TuicOption) (*Tuic, error) { func NewTuic(option TuicOption) (*Tuic, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
serverName := option.Server serverName := option.Server

View file

@ -3,59 +3,118 @@ package outbound
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"regexp" "regexp"
"strconv" "strconv"
"sync"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
) )
var (
globalClientSessionCache tls.ClientSessionCache
once sync.Once
)
func getClientSessionCache() tls.ClientSessionCache {
once.Do(func() {
globalClientSessionCache = tls.NewLRUClientSessionCache(128)
})
return globalClientSessionCache
}
func serializesSocksAddr(metadata *C.Metadata) []byte { func serializesSocksAddr(metadata *C.Metadata) []byte {
var buf [][]byte var buf [][]byte
addrType := metadata.AddrType() addrType := metadata.AddrType()
aType := uint8(addrType)
p := uint(metadata.DstPort) p := uint(metadata.DstPort)
port := []byte{uint8(p >> 8), uint8(p & 0xff)} port := []byte{uint8(p >> 8), uint8(p & 0xff)}
switch addrType { switch addrType {
case C.AtypDomainName: case socks5.AtypDomainName:
lenM := uint8(len(metadata.Host)) lenM := uint8(len(metadata.Host))
host := []byte(metadata.Host) host := []byte(metadata.Host)
buf = [][]byte{{socks5.AtypDomainName, lenM}, host, port} buf = [][]byte{{aType, lenM}, host, port}
case C.AtypIPv4: case socks5.AtypIPv4:
host := metadata.DstIP.AsSlice() host := metadata.DstIP.AsSlice()
buf = [][]byte{{socks5.AtypIPv4}, host, port} buf = [][]byte{{aType}, host, port}
case C.AtypIPv6: case socks5.AtypIPv6:
host := metadata.DstIP.AsSlice() host := metadata.DstIP.AsSlice()
buf = [][]byte{{socks5.AtypIPv6}, host, port} buf = [][]byte{{aType}, host, port}
} }
return bytes.Join(buf, nil) return bytes.Join(buf, nil)
} }
func resolveUDPAddr(ctx context.Context, network, address string, prefer C.DNSPrefer) (*net.UDPAddr, error) { func resolveUDPAddr(ctx context.Context, network, address string) (*net.UDPAddr, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
ip, err := resolver.ResolveProxyServerHost(ctx, host)
if err != nil {
return nil, err
}
return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port))
}
func resolveUDPAddrWithPrefer(ctx context.Context, network, address string, prefer C.DNSPrefer) (*net.UDPAddr, error) {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var ip netip.Addr var ip netip.Addr
var fallback netip.Addr
switch prefer { switch prefer {
case C.IPv4Only: case C.IPv4Only:
ip, err = resolver.ResolveIPv4WithResolver(ctx, host, resolver.ProxyServerHostResolver) ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host)
case C.IPv6Only: case C.IPv6Only:
ip, err = resolver.ResolveIPv6WithResolver(ctx, host, resolver.ProxyServerHostResolver) ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host)
case C.IPv6Prefer: case C.IPv6Prefer:
ip, err = resolver.ResolveIPPrefer6WithResolver(ctx, host, resolver.ProxyServerHostResolver) var ips []netip.Addr
ips, err = resolver.LookupIPProxyServerHost(ctx, host)
if err == nil {
for _, addr := range ips {
if addr.Is6() {
ip = addr
break
} else {
if !fallback.IsValid() {
fallback = addr
}
}
}
}
default: default:
ip, err = resolver.ResolveIPWithResolver(ctx, host, resolver.ProxyServerHostResolver) // C.IPv4Prefer, C.DualStack and other
var ips []netip.Addr
ips, err = resolver.LookupIPProxyServerHost(ctx, host)
if err == nil {
for _, addr := range ips {
if addr.Is4() {
ip = addr
break
} else {
if !fallback.IsValid() {
fallback = addr
}
}
}
}
}
if !ip.IsValid() && fallback.IsValid() {
ip = fallback
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
ip, port = resolver.LookupIP4P(ip, port)
return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port)) return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port))
} }

View file

@ -21,7 +21,9 @@ import (
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/gun" "github.com/metacubex/mihomo/transport/gun"
"github.com/metacubex/mihomo/transport/socks5"
"github.com/metacubex/mihomo/transport/vless" "github.com/metacubex/mihomo/transport/vless"
"github.com/metacubex/mihomo/transport/vmess" "github.com/metacubex/mihomo/transport/vmess"
@ -75,7 +77,13 @@ type VlessOption struct {
ClientFingerprint string `proxy:"client-fingerprint,omitempty"` ClientFingerprint string `proxy:"client-fingerprint,omitempty"`
} }
func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error
if tlsC.HaveGlobalFingerprint() && len(v.option.ClientFingerprint) == 0 {
v.option.ClientFingerprint = tlsC.GetGlobalFingerprint()
}
switch v.option.Network { switch v.option.Network {
case "ws": case "ws":
host, port, _ := net.SplitHostPort(v.addr) host, port, _ := net.SplitHostPort(v.addr)
@ -149,7 +157,7 @@ func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
Path: v.option.HTTP2Opts.Path, Path: v.option.HTTP2Opts.Path,
} }
c, err = vmess.StreamH2Conn(ctx, c, h2Opts) c, err = vmess.StreamH2Conn(c, h2Opts)
case "grpc": case "grpc":
c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.realityConfig) c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.realityConfig)
default: default:
@ -162,14 +170,10 @@ func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
return nil, err return nil, err
} }
return v.streamConnContext(ctx, c, metadata) return v.streamConn(c, metadata)
} }
func (v *Vless) streamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) { func (v *Vless) streamConn(c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
if metadata.NetWork == C.UDP { if metadata.NetWork == C.UDP {
if v.option.PacketAddr { if v.option.PacketAddr {
metadata = &C.Metadata{ metadata = &C.Metadata{
@ -226,10 +230,9 @@ func (v *Vless) streamTLSConn(ctx context.Context, conn net.Conn, isH2 bool) (ne
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
var c net.Conn
// gun transport // gun transport
if v.transport != nil && dialer.IsZeroOptions(opts) { if v.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) c, err := gun.StreamGunWithTransport(v.transport, v.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -237,7 +240,7 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = v.streamConnContext(ctx, c, metadata) c, err = v.client.StreamConn(c, parseVlessAddr(metadata, v.option.XUDP))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -259,6 +262,7 @@ func (v *Vless) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
@ -282,7 +286,7 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
} }
var c net.Conn var c net.Conn
// gun transport // gun transport
if v.transport != nil && dialer.IsZeroOptions(opts) { if v.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -291,7 +295,7 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = v.streamConnContext(ctx, c, metadata) c, err = v.streamConn(c, metadata)
if err != nil { if err != nil {
return nil, fmt.Errorf("new vless client error: %v", err) return nil, fmt.Errorf("new vless client error: %v", err)
} }
@ -323,6 +327,7 @@ func (v *Vless) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
@ -376,34 +381,19 @@ func (v *Vless) SupportUOT() bool {
return true return true
} }
// ProxyInfo implements C.ProxyAdapter
func (v *Vless) ProxyInfo() C.ProxyInfo {
info := v.Base.ProxyInfo()
info.DialerProxy = v.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (v *Vless) Close() error {
if v.transport != nil {
return v.transport.Close()
}
return nil
}
func parseVlessAddr(metadata *C.Metadata, xudp bool) *vless.DstAddr { func parseVlessAddr(metadata *C.Metadata, xudp bool) *vless.DstAddr {
var addrType byte var addrType byte
var addr []byte var addr []byte
switch metadata.AddrType() { switch metadata.AddrType() {
case C.AtypIPv4: case socks5.AtypIPv4:
addrType = vless.AtypIPv4 addrType = vless.AtypIPv4
addr = make([]byte, net.IPv4len) addr = make([]byte, net.IPv4len)
copy(addr[:], metadata.DstIP.AsSlice()) copy(addr[:], metadata.DstIP.AsSlice())
case C.AtypIPv6: case socks5.AtypIPv6:
addrType = vless.AtypIPv6 addrType = vless.AtypIPv6
addr = make([]byte, net.IPv6len) addr = make([]byte, net.IPv6len)
copy(addr[:], metadata.DstIP.AsSlice()) copy(addr[:], metadata.DstIP.AsSlice())
case C.AtypDomainName: case socks5.AtypDomainName:
addrType = vless.AtypDomainName addrType = vless.AtypDomainName
addr = make([]byte, len(metadata.Host)+1) addr = make([]byte, len(metadata.Host)+1)
addr[0] = byte(len(metadata.Host)) addr[0] = byte(len(metadata.Host))
@ -515,12 +505,17 @@ func NewVless(option VlessOption) (*Vless, error) {
var addons *vless.Addons var addons *vless.Addons
if option.Network != "ws" && len(option.Flow) >= 16 { if option.Network != "ws" && len(option.Flow) >= 16 {
option.Flow = option.Flow[:16] option.Flow = option.Flow[:16]
if option.Flow != vless.XRV { switch option.Flow {
return nil, fmt.Errorf("unsupported xtls flow type: %s", option.Flow) case vless.XRV:
} log.Warnln("To use %s, ensure your server is upgrade to Xray-core v1.8.0+", vless.XRV)
addons = &vless.Addons{ addons = &vless.Addons{
Flow: option.Flow, Flow: option.Flow,
} }
case vless.XRO, vless.XRD, vless.XRS:
log.Fatalln("Legacy XTLS protocol %s is deprecated and no longer supported", option.Flow)
default:
return nil, fmt.Errorf("unsupported xtls flow type: %s", option.Flow)
}
} }
switch option.PacketEncoding { switch option.PacketEncoding {
@ -569,7 +564,7 @@ func NewVless(option VlessOption) (*Vless, error) {
option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com") option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com")
} }
case "grpc": case "grpc":
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(network, addr string) (net.Conn, error) {
var err error var err error
var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...) var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...)
if len(v.option.DialerProxy) > 0 { if len(v.option.DialerProxy) > 0 {
@ -578,10 +573,11 @@ func NewVless(option VlessOption) (*Vless, error) {
return nil, err return nil, err
} }
} }
c, err := cDialer.DialContext(ctx, "tcp", v.addr) c, err := cDialer.DialContext(context.Background(), "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
return c, nil return c, nil
} }
@ -595,13 +591,10 @@ func NewVless(option VlessOption) (*Vless, error) {
} }
var tlsConfig *tls.Config var tlsConfig *tls.Config
if option.TLS { if option.TLS {
tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(&tls.Config{ tlsConfig = ca.GetGlobalTLSConfig(&tls.Config{
InsecureSkipVerify: v.option.SkipCertVerify, InsecureSkipVerify: v.option.SkipCertVerify,
ServerName: v.option.ServerName, ServerName: v.option.ServerName,
}, v.option.Fingerprint) })
if err != nil {
return nil, err
}
if option.ServerName == "" { if option.ServerName == "" {
host, _, _ := net.SplitHostPort(v.addr) host, _, _ := net.SplitHostPort(v.addr)
tlsConfig.ServerName = host tlsConfig.ServerName = host

View file

@ -96,7 +96,13 @@ type WSOptions struct {
} }
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error
if tlsC.HaveGlobalFingerprint() && (len(v.option.ClientFingerprint) == 0) {
v.option.ClientFingerprint = tlsC.GetGlobalFingerprint()
}
switch v.option.Network { switch v.option.Network {
case "ws": case "ws":
host, port, _ := net.SplitHostPort(v.addr) host, port, _ := net.SplitHostPort(v.addr)
@ -193,7 +199,7 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
Path: v.option.HTTP2Opts.Path, Path: v.option.HTTP2Opts.Path,
} }
c, err = mihomoVMess.StreamH2Conn(ctx, c, h2Opts) c, err = mihomoVMess.StreamH2Conn(c, h2Opts)
case "grpc": case "grpc":
c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.realityConfig) c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.realityConfig)
default: default:
@ -220,24 +226,17 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
if err != nil { if err != nil {
return nil, err return nil, err
} }
return v.streamConnContext(ctx, c, metadata) return v.streamConn(c, metadata)
} }
func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) { func (v *Vmess) streamConn(c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) {
useEarly := N.NeedHandshake(c)
if !useEarly {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
}
if metadata.NetWork == C.UDP { if metadata.NetWork == C.UDP {
if v.option.XUDP { if v.option.XUDP {
var globalID [8]byte var globalID [8]byte
if metadata.SourceValid() { if metadata.SourceValid() {
globalID = utils.GlobalID(metadata.SourceAddress()) globalID = utils.GlobalID(metadata.SourceAddress())
} }
if useEarly { if N.NeedHandshake(c) {
conn = v.client.DialEarlyXUDPPacketConn(c, conn = v.client.DialEarlyXUDPPacketConn(c,
globalID, globalID,
M.SocksaddrFromNet(metadata.UDPAddr())) M.SocksaddrFromNet(metadata.UDPAddr()))
@ -247,7 +246,7 @@ func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
M.SocksaddrFromNet(metadata.UDPAddr())) M.SocksaddrFromNet(metadata.UDPAddr()))
} }
} else if v.option.PacketAddr { } else if v.option.PacketAddr {
if useEarly { if N.NeedHandshake(c) {
conn = v.client.DialEarlyPacketConn(c, conn = v.client.DialEarlyPacketConn(c,
M.ParseSocksaddrHostPort(packetaddr.SeqPacketMagicAddress, 443)) M.ParseSocksaddrHostPort(packetaddr.SeqPacketMagicAddress, 443))
} else { } else {
@ -256,7 +255,7 @@ func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
} }
conn = packetaddr.NewBindConn(conn) conn = packetaddr.NewBindConn(conn)
} else { } else {
if useEarly { if N.NeedHandshake(c) {
conn = v.client.DialEarlyPacketConn(c, conn = v.client.DialEarlyPacketConn(c,
M.SocksaddrFromNet(metadata.UDPAddr())) M.SocksaddrFromNet(metadata.UDPAddr()))
} else { } else {
@ -265,7 +264,7 @@ func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
} }
} }
} else { } else {
if useEarly { if N.NeedHandshake(c) {
conn = v.client.DialEarlyConn(c, conn = v.client.DialEarlyConn(c,
M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort)) M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
} else { } else {
@ -281,10 +280,9 @@ func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
var c net.Conn
// gun transport // gun transport
if v.transport != nil && dialer.IsZeroOptions(opts) { if v.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) c, err := gun.StreamGunWithTransport(v.transport, v.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -292,7 +290,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = v.streamConnContext(ctx, c, metadata) c, err = v.client.DialConn(c, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -314,6 +312,7 @@ func (v *Vmess) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
@ -334,7 +333,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
} }
var c net.Conn var c net.Conn
// gun transport // gun transport
if v.transport != nil && dialer.IsZeroOptions(opts) { if v.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -343,7 +342,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = v.streamConnContext(ctx, c, metadata) c, err = v.streamConn(c, metadata)
if err != nil { if err != nil {
return nil, fmt.Errorf("new vmess client error: %v", err) return nil, fmt.Errorf("new vmess client error: %v", err)
} }
@ -374,6 +373,7 @@ func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
@ -390,21 +390,6 @@ func (v *Vmess) SupportWithDialer() C.NetWork {
return C.ALLNet return C.ALLNet
} }
// ProxyInfo implements C.ProxyAdapter
func (v *Vmess) ProxyInfo() C.ProxyInfo {
info := v.Base.ProxyInfo()
info.DialerProxy = v.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (v *Vmess) Close() error {
if v.transport != nil {
return v.transport.Close()
}
return nil
}
// ListenPacketOnStreamConn implements C.ProxyAdapter // ListenPacketOnStreamConn implements C.ProxyAdapter
func (v *Vmess) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { func (v *Vmess) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
// vmess use stream-oriented udp with a special address, so we need a net.UDPAddr // vmess use stream-oriented udp with a special address, so we need a net.UDPAddr
@ -469,18 +454,13 @@ func NewVmess(option VmessOption) (*Vmess, error) {
option: &option, option: &option,
} }
v.realityConfig, err = v.option.RealityOpts.Parse()
if err != nil {
return nil, err
}
switch option.Network { switch option.Network {
case "h2": case "h2":
if len(option.HTTP2Opts.Host) == 0 { if len(option.HTTP2Opts.Host) == 0 {
option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com") option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com")
} }
case "grpc": case "grpc":
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(network, addr string) (net.Conn, error) {
var err error var err error
var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...) var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...)
if len(v.option.DialerProxy) > 0 { if len(v.option.DialerProxy) > 0 {
@ -489,10 +469,11 @@ func NewVmess(option VmessOption) (*Vmess, error) {
return nil, err return nil, err
} }
} }
c, err := cDialer.DialContext(ctx, "tcp", v.addr) c, err := cDialer.DialContext(context.Background(), "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
return c, nil return c, nil
} }
@ -506,13 +487,10 @@ func NewVmess(option VmessOption) (*Vmess, error) {
} }
var tlsConfig *tls.Config var tlsConfig *tls.Config
if option.TLS { if option.TLS {
tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(&tls.Config{ tlsConfig = ca.GetGlobalTLSConfig(&tls.Config{
InsecureSkipVerify: v.option.SkipCertVerify, InsecureSkipVerify: v.option.SkipCertVerify,
ServerName: v.option.ServerName, ServerName: v.option.ServerName,
}, v.option.Fingerprint) })
if err != nil {
return nil, err
}
if option.ServerName == "" { if option.ServerName == "" {
host, _, _ := net.SplitHostPort(v.addr) host, _, _ := net.SplitHostPort(v.addr)
tlsConfig.ServerName = host tlsConfig.ServerName = host
@ -525,6 +503,11 @@ func NewVmess(option VmessOption) (*Vmess, error) {
v.transport = gun.NewHTTP2Client(dialFn, tlsConfig, v.option.ClientFingerprint, v.realityConfig) v.transport = gun.NewHTTP2Client(dialFn, tlsConfig, v.option.ClientFingerprint, v.realityConfig)
} }
v.realityConfig, err = v.option.RealityOpts.Parse()
if err != nil {
return nil, err
}
return v, nil return v, nil
} }

View file

@ -8,12 +8,14 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/atomic"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
@ -22,27 +24,22 @@ import (
"github.com/metacubex/mihomo/dns" "github.com/metacubex/mihomo/dns"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
amnezia "github.com/metacubex/amneziawg-go/device"
wireguard "github.com/metacubex/sing-wireguard" wireguard "github.com/metacubex/sing-wireguard"
"github.com/metacubex/wireguard-go/device"
"github.com/sagernet/sing/common/debug" "github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/wireguard-go/device"
) )
type wireguardGoDevice interface {
Close()
IpcSet(uapiConf string) error
}
type WireGuard struct { type WireGuard struct {
*Base *Base
bind *wireguard.ClientBind bind *wireguard.ClientBind
device wireguardGoDevice device *device.Device
tunDevice wireguard.Device tunDevice wireguard.Device
dialer proxydialer.SingDialer dialer proxydialer.SingDialer
resolver resolver.Resolver resolver *dns.Resolver
refP *refProxyAdapter
initOk atomic.Bool initOk atomic.Bool
initMutex sync.Mutex initMutex sync.Mutex
@ -54,6 +51,8 @@ type WireGuard struct {
serverAddrMap map[M.Socksaddr]netip.AddrPort serverAddrMap map[M.Socksaddr]netip.AddrPort
serverAddrTime atomic.TypedValue[time.Time] serverAddrTime atomic.TypedValue[time.Time]
serverAddrMutex sync.Mutex serverAddrMutex sync.Mutex
closeCh chan struct{} // for test
} }
type WireGuardOption struct { type WireGuardOption struct {
@ -68,8 +67,6 @@ type WireGuardOption struct {
UDP bool `proxy:"udp,omitempty"` UDP bool `proxy:"udp,omitempty"`
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"` PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
AmneziaWGOption *AmneziaWGOption `proxy:"amnezia-wg-option,omitempty"`
Peers []WireGuardPeerOption `proxy:"peers,omitempty"` Peers []WireGuardPeerOption `proxy:"peers,omitempty"`
RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"` RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"`
@ -87,18 +84,6 @@ type WireGuardPeerOption struct {
AllowedIPs []string `proxy:"allowed-ips,omitempty"` AllowedIPs []string `proxy:"allowed-ips,omitempty"`
} }
type AmneziaWGOption struct {
JC int `proxy:"jc"`
JMin int `proxy:"jmin"`
JMax int `proxy:"jmax"`
S1 int `proxy:"s1"`
S2 int `proxy:"s2"`
H1 uint32 `proxy:"h1"`
H2 uint32 `proxy:"h2"`
H3 uint32 `proxy:"h3"`
H4 uint32 `proxy:"h4"`
}
type wgSingErrorHandler struct { type wgSingErrorHandler struct {
name string name string
} }
@ -168,6 +153,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
}, },
dialer: proxydialer.NewSlowDownSingDialer(proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), slowdown.New()), dialer: proxydialer.NewSlowDownSingDialer(proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), slowdown.New()),
} }
runtime.SetFinalizer(outbound, closeWireGuard)
var reserved [3]uint8 var reserved [3]uint8
if len(option.Reserved) > 0 { if len(option.Reserved) > 0 {
@ -257,20 +243,14 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "create WireGuard device") return nil, E.Cause(err, "create WireGuard device")
} }
logger := &device.Logger{ outbound.device = device.NewDevice(context.Background(), outbound.tunDevice, outbound.bind, &device.Logger{
Verbosef: func(format string, args ...interface{}) { Verbosef: func(format string, args ...interface{}) {
log.SingLogger.Debug(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...))) log.SingLogger.Debug(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
}, },
Errorf: func(format string, args ...interface{}) { Errorf: func(format string, args ...interface{}) {
log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...))) log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
}, },
} }, option.Workers)
if option.AmneziaWGOption != nil {
outbound.bind.SetParseReserved(false) // AmneziaWG don't need parse reserved
outbound.device = amnezia.NewDevice(outbound.tunDevice, outbound.bind, logger, option.Workers)
} else {
outbound.device = device.NewDevice(outbound.tunDevice, outbound.bind, logger, option.Workers)
}
var has6 bool var has6 bool
for _, address := range outbound.localPrefixes { for _, address := range outbound.localPrefixes {
@ -280,13 +260,15 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
} }
} }
refP := &refProxyAdapter{}
outbound.refP = refP
if option.RemoteDnsResolve && len(option.Dns) > 0 { if option.RemoteDnsResolve && len(option.Dns) > 0 {
nss, err := dns.ParseNameServer(option.Dns) nss, err := dns.ParseNameServer(option.Dns)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for i := range nss { for i := range nss {
nss[i].ProxyAdapter = outbound nss[i].ProxyAdapter = refP
} }
outbound.resolver = dns.NewResolver(dns.Config{ outbound.resolver = dns.NewResolver(dns.Config{
Main: nss, Main: nss,
@ -301,7 +283,7 @@ func (w *WireGuard) resolve(ctx context.Context, address M.Socksaddr) (netip.Add
if address.Addr.IsValid() { if address.Addr.IsValid() {
return address.AddrPort(), nil return address.AddrPort(), nil
} }
udpAddr, err := resolveUDPAddr(ctx, "udp", address.String(), w.prefer) udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), w.prefer)
if err != nil { if err != nil {
return netip.AddrPort{}, err return netip.AddrPort{}, err
} }
@ -385,17 +367,6 @@ func (w *WireGuard) genIpcConf(ctx context.Context, updateOnly bool) (string, er
ipcConf := "" ipcConf := ""
if !updateOnly { if !updateOnly {
ipcConf += "private_key=" + w.option.PrivateKey + "\n" ipcConf += "private_key=" + w.option.PrivateKey + "\n"
if w.option.AmneziaWGOption != nil {
ipcConf += "jc=" + strconv.Itoa(w.option.AmneziaWGOption.JC) + "\n"
ipcConf += "jmin=" + strconv.Itoa(w.option.AmneziaWGOption.JMin) + "\n"
ipcConf += "jmax=" + strconv.Itoa(w.option.AmneziaWGOption.JMax) + "\n"
ipcConf += "s1=" + strconv.Itoa(w.option.AmneziaWGOption.S1) + "\n"
ipcConf += "s2=" + strconv.Itoa(w.option.AmneziaWGOption.S2) + "\n"
ipcConf += "h1=" + strconv.FormatUint(uint64(w.option.AmneziaWGOption.H1), 10) + "\n"
ipcConf += "h2=" + strconv.FormatUint(uint64(w.option.AmneziaWGOption.H2), 10) + "\n"
ipcConf += "h3=" + strconv.FormatUint(uint64(w.option.AmneziaWGOption.H3), 10) + "\n"
ipcConf += "h4=" + strconv.FormatUint(uint64(w.option.AmneziaWGOption.H4), 10) + "\n"
}
} }
if len(w.option.Peers) > 0 { if len(w.option.Peers) > 0 {
for i, peer := range w.option.Peers { for i, peer := range w.option.Peers {
@ -480,12 +451,13 @@ func (w *WireGuard) genIpcConf(ctx context.Context, updateOnly bool) (string, er
return ipcConf, nil return ipcConf, nil
} }
// Close implements C.ProxyAdapter func closeWireGuard(w *WireGuard) {
func (w *WireGuard) Close() error {
if w.device != nil { if w.device != nil {
w.device.Close() w.device.Close()
} }
return nil if w.closeCh != nil {
close(w.closeCh)
}
} }
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
@ -498,6 +470,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
if !metadata.Resolved() || w.resolver != nil { if !metadata.Resolved() || w.resolver != nil {
r := resolver.DefaultResolver r := resolver.DefaultResolver
if w.resolver != nil { if w.resolver != nil {
w.refP.SetProxyAdapter(w)
defer w.refP.ClearProxyAdapter()
r = w.resolver r = w.resolver
} }
options = append(options, dialer.WithResolver(r)) options = append(options, dialer.WithResolver(r))
@ -512,7 +486,7 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
if conn == nil { if conn == nil {
return nil, E.New("conn is nil") return nil, E.New("conn is nil")
} }
return NewConn(conn, w), nil return NewConn(CN.NewRefConn(conn, w), w), nil
} }
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
@ -525,6 +499,8 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" { if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver r := resolver.DefaultResolver
if w.resolver != nil { if w.resolver != nil {
w.refP.SetProxyAdapter(w)
defer w.refP.ClearProxyAdapter()
r = w.resolver r = w.resolver
} }
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
@ -540,10 +516,146 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if pc == nil { if pc == nil {
return nil, E.New("packetConn is nil") return nil, E.New("packetConn is nil")
} }
return newPacketConn(pc, w), nil return newPacketConn(CN.NewRefPacketConn(pc, w), w), nil
} }
// IsL3Protocol implements C.ProxyAdapter // IsL3Protocol implements C.ProxyAdapter
func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool { func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool {
return true return true
} }
type refProxyAdapter struct {
proxyAdapter C.ProxyAdapter
count int
mutex sync.Mutex
}
func (r *refProxyAdapter) SetProxyAdapter(proxyAdapter C.ProxyAdapter) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.proxyAdapter = proxyAdapter
r.count++
}
func (r *refProxyAdapter) ClearProxyAdapter() {
r.mutex.Lock()
defer r.mutex.Unlock()
r.count--
if r.count == 0 {
r.proxyAdapter = nil
}
}
func (r *refProxyAdapter) Name() string {
if r.proxyAdapter != nil {
return r.proxyAdapter.Name()
}
return ""
}
func (r *refProxyAdapter) Type() C.AdapterType {
if r.proxyAdapter != nil {
return r.proxyAdapter.Type()
}
return C.AdapterType(0)
}
func (r *refProxyAdapter) Addr() string {
if r.proxyAdapter != nil {
return r.proxyAdapter.Addr()
}
return ""
}
func (r *refProxyAdapter) SupportUDP() bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportUDP()
}
return false
}
func (r *refProxyAdapter) SupportXUDP() bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportXUDP()
}
return false
}
func (r *refProxyAdapter) SupportTFO() bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportTFO()
}
return false
}
func (r *refProxyAdapter) MarshalJSON() ([]byte, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.MarshalJSON()
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.StreamConnContext(ctx, c, metadata)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.DialContext(ctx, metadata, opts...)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.ListenPacketContext(ctx, metadata, opts...)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) SupportUOT() bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportUOT()
}
return false
}
func (r *refProxyAdapter) SupportWithDialer() C.NetWork {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportWithDialer()
}
return C.InvalidNet
}
func (r *refProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.DialContextWithDialer(ctx, dialer, metadata)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) IsL3Protocol(metadata *C.Metadata) bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.IsL3Protocol(metadata)
}
return false
}
func (r *refProxyAdapter) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
if r.proxyAdapter != nil {
return r.proxyAdapter.Unwrap(metadata, touch)
}
return nil
}
var _ C.ProxyAdapter = (*refProxyAdapter)(nil)

View file

@ -0,0 +1,45 @@
//go:build with_gvisor
package outbound
import (
"context"
"runtime"
"testing"
"time"
)
func TestWireGuardGC(t *testing.T) {
option := WireGuardOption{}
option.Server = "162.159.192.1"
option.Port = 2408
option.PrivateKey = "iOx7749AdqH3IqluG7+0YbGKd0m1mcEXAfGRzpy9rG8="
option.PublicKey = "bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo="
option.Ip = "172.16.0.2"
option.Ipv6 = "2606:4700:110:8d29:be92:3a6a:f4:c437"
option.Reserved = []uint8{51, 69, 125}
wg, err := NewWireGuard(option)
if err != nil {
t.Error(err)
}
closeCh := make(chan struct{})
wg.closeCh = closeCh
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = wg.init(ctx)
if err != nil {
t.Error(err)
return
}
// must do a small sleep before test GC
// because it maybe deadlocks if w.device.Close call too fast after w.device.Start
time.Sleep(10 * time.Millisecond)
wg = nil
runtime.GC()
select {
case <-closeCh:
return
case <-ctx.Done():
t.Error("timeout not GC")
}
}

View file

@ -37,7 +37,7 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts .
if err == nil { if err == nil {
c.AppendToChains(f) c.AppendToChains(f)
} else { } else {
f.onDialFailed(proxy.Type(), err, f.healthCheck) f.onDialFailed(proxy.Type(), err)
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
@ -45,7 +45,7 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts .
if err == nil { if err == nil {
f.onDialSuccess() f.onDialSuccess()
} else { } else {
f.onDialFailed(proxy.Type(), err, f.healthCheck) f.onDialFailed(proxy.Type(), err)
} }
}) })
} }

View file

@ -2,7 +2,6 @@ package outboundgroup
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@ -18,26 +17,22 @@ import (
"github.com/metacubex/mihomo/tunnel" "github.com/metacubex/mihomo/tunnel"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
"golang.org/x/exp/slices"
) )
type GroupBase struct { type GroupBase struct {
*outbound.Base *outbound.Base
filterRegs []*regexp2.Regexp filterRegs []*regexp2.Regexp
excludeFilterRegs []*regexp2.Regexp excludeFilterReg *regexp2.Regexp
excludeTypeArray []string excludeTypeArray []string
providers []provider.ProxyProvider providers []provider.ProxyProvider
failedTestMux sync.Mutex failedTestMux sync.Mutex
failedTimes int failedTimes int
failedTime time.Time failedTime time.Time
failedTesting atomic.Bool failedTesting atomic.Bool
proxies [][]C.Proxy
versions []atomic.Uint32
TestTimeout int TestTimeout int
maxFailedTimes int maxFailedTimes int
// for GetProxies
getProxiesMutex sync.Mutex
providerVersions []uint32
providerProxies []C.Proxy
} }
type GroupBaseOption struct { type GroupBaseOption struct {
@ -51,26 +46,15 @@ type GroupBaseOption struct {
} }
func NewGroupBase(opt GroupBaseOption) *GroupBase { func NewGroupBase(opt GroupBaseOption) *GroupBase {
if opt.RoutingMark != 0 { var excludeFilterReg *regexp2.Regexp
log.Warnln("The group [%s] with routing-mark configuration is deprecated, please set it directly on the proxy instead", opt.Name) if opt.excludeFilter != "" {
excludeFilterReg = regexp2.MustCompile(opt.excludeFilter, regexp2.None)
} }
if opt.Interface != "" {
log.Warnln("The group [%s] with interface-name configuration is deprecated, please set it directly on the proxy instead", opt.Name)
}
var excludeTypeArray []string var excludeTypeArray []string
if opt.excludeType != "" { if opt.excludeType != "" {
excludeTypeArray = strings.Split(opt.excludeType, "|") excludeTypeArray = strings.Split(opt.excludeType, "|")
} }
var excludeFilterRegs []*regexp2.Regexp
if opt.excludeFilter != "" {
for _, excludeFilter := range strings.Split(opt.excludeFilter, "`") {
excludeFilterReg := regexp2.MustCompile(excludeFilter, regexp2.None)
excludeFilterRegs = append(excludeFilterRegs, excludeFilterReg)
}
}
var filterRegs []*regexp2.Regexp var filterRegs []*regexp2.Regexp
if opt.filter != "" { if opt.filter != "" {
for _, filter := range strings.Split(opt.filter, "`") { for _, filter := range strings.Split(opt.filter, "`") {
@ -82,7 +66,7 @@ func NewGroupBase(opt GroupBaseOption) *GroupBase {
gb := &GroupBase{ gb := &GroupBase{
Base: outbound.NewBase(opt.BaseOption), Base: outbound.NewBase(opt.BaseOption),
filterRegs: filterRegs, filterRegs: filterRegs,
excludeFilterRegs: excludeFilterRegs, excludeFilterReg: excludeFilterReg,
excludeTypeArray: excludeTypeArray, excludeTypeArray: excludeTypeArray,
providers: opt.providers, providers: opt.providers,
failedTesting: atomic.NewBool(false), failedTesting: atomic.NewBool(false),
@ -97,6 +81,9 @@ func NewGroupBase(opt GroupBaseOption) *GroupBase {
gb.maxFailedTimes = 5 gb.maxFailedTimes = 5
} }
gb.proxies = make([][]C.Proxy, len(opt.providers))
gb.versions = make([]atomic.Uint32, len(opt.providers))
return gb return gb
} }
@ -107,39 +94,37 @@ func (gb *GroupBase) Touch() {
} }
func (gb *GroupBase) GetProxies(touch bool) []C.Proxy { func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
providerVersions := make([]uint32, len(gb.providers))
for i, pd := range gb.providers {
if touch { // touch first
pd.Touch()
}
providerVersions[i] = pd.Version()
}
// thread safe
gb.getProxiesMutex.Lock()
defer gb.getProxiesMutex.Unlock()
// return the cached proxies if version not changed
if slices.Equal(providerVersions, gb.providerVersions) {
return gb.providerProxies
}
var proxies []C.Proxy var proxies []C.Proxy
if len(gb.filterRegs) == 0 { if len(gb.filterRegs) == 0 {
for _, pd := range gb.providers { for _, pd := range gb.providers {
if touch {
pd.Touch()
}
proxies = append(proxies, pd.Proxies()...) proxies = append(proxies, pd.Proxies()...)
} }
} else { } else {
for _, pd := range gb.providers { for i, pd := range gb.providers {
if pd.VehicleType() == types.Compatible { // compatible provider unneeded filter if touch {
proxies = append(proxies, pd.Proxies()...) pd.Touch()
}
if pd.VehicleType() == types.Compatible {
gb.versions[i].Store(pd.Version())
gb.proxies[i] = pd.Proxies()
continue continue
} }
var newProxies []C.Proxy version := gb.versions[i].Load()
if version != pd.Version() && gb.versions[i].CompareAndSwap(version, pd.Version()) {
var (
proxies []C.Proxy
newProxies []C.Proxy
)
proxies = pd.Proxies()
proxiesSet := map[string]struct{}{} proxiesSet := map[string]struct{}{}
for _, filterReg := range gb.filterRegs { for _, filterReg := range gb.filterRegs {
for _, p := range pd.Proxies() { for _, p := range proxies {
name := p.Name() name := p.Name()
if mat, _ := filterReg.MatchString(name); mat { if mat, _ := filterReg.MatchString(name); mat {
if _, ok := proxiesSet[name]; !ok { if _, ok := proxiesSet[name]; !ok {
@ -149,13 +134,16 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
} }
} }
} }
proxies = append(proxies, newProxies...)
gb.proxies[i] = newProxies
}
}
for _, p := range gb.proxies {
proxies = append(proxies, p...)
} }
} }
// Multiple filers means that proxies are sorted in the order in which the filers appear.
// Although the filter has been performed once in the previous process,
// when there are multiple providers, the array needs to be reordered as a whole.
if len(gb.providers) > 1 && len(gb.filterRegs) > 1 { if len(gb.providers) > 1 && len(gb.filterRegs) > 1 {
var newProxies []C.Proxy var newProxies []C.Proxy
proxiesSet := map[string]struct{}{} proxiesSet := map[string]struct{}{}
@ -179,31 +167,32 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
} }
proxies = newProxies proxies = newProxies
} }
if gb.excludeTypeArray != nil {
if len(gb.excludeFilterRegs) > 0 {
var newProxies []C.Proxy var newProxies []C.Proxy
LOOP1:
for _, p := range proxies { for _, p := range proxies {
name := p.Name() mType := p.Type().String()
for _, excludeFilterReg := range gb.excludeFilterRegs { flag := false
if mat, _ := excludeFilterReg.MatchString(name); mat { for i := range gb.excludeTypeArray {
continue LOOP1 if strings.EqualFold(mType, gb.excludeTypeArray[i]) {
flag = true
break
} }
}
if flag {
continue
} }
newProxies = append(newProxies, p) newProxies = append(newProxies, p)
} }
proxies = newProxies proxies = newProxies
} }
if gb.excludeTypeArray != nil { if gb.excludeFilterReg != nil {
var newProxies []C.Proxy var newProxies []C.Proxy
LOOP2:
for _, p := range proxies { for _, p := range proxies {
mType := p.Type().String() name := p.Name()
for _, excludeType := range gb.excludeTypeArray { if mat, _ := gb.excludeFilterReg.MatchString(name); mat {
if strings.EqualFold(mType, excludeType) { continue
continue LOOP2
}
} }
newProxies = append(newProxies, p) newProxies = append(newProxies, p)
} }
@ -211,13 +200,9 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
} }
if len(proxies) == 0 { if len(proxies) == 0 {
return []C.Proxy{tunnel.Proxies()["COMPATIBLE"]} return append(proxies, tunnel.Proxies()["COMPATIBLE"])
} }
// only cache when proxies not empty
gb.providerVersions = providerVersions
gb.providerProxies = proxies
return proxies return proxies
} }
@ -249,21 +234,17 @@ func (gb *GroupBase) URLTest(ctx context.Context, url string, expectedStatus uti
} }
} }
func (gb *GroupBase) onDialFailed(adapterType C.AdapterType, err error, fn func()) { func (gb *GroupBase) onDialFailed(adapterType C.AdapterType, err error) {
if adapterType == C.Direct || adapterType == C.Compatible || adapterType == C.Reject || adapterType == C.Pass || adapterType == C.RejectDrop { if adapterType == C.Direct || adapterType == C.Compatible || adapterType == C.Reject || adapterType == C.Pass || adapterType == C.RejectDrop {
return return
} }
if errors.Is(err, C.ErrNotSupport) { if strings.Contains(err.Error(), "connection refused") {
go gb.healthCheck()
return return
} }
go func() { go func() {
if strings.Contains(err.Error(), "connection refused") {
fn()
return
}
gb.failedTestMux.Lock() gb.failedTestMux.Lock()
defer gb.failedTestMux.Unlock() defer gb.failedTestMux.Unlock()
@ -280,7 +261,7 @@ func (gb *GroupBase) onDialFailed(adapterType C.AdapterType, err error, fn func(
log.Debugln("ProxyGroup: %s failed count: %d", gb.Name(), gb.failedTimes) log.Debugln("ProxyGroup: %s failed count: %d", gb.Name(), gb.failedTimes)
if gb.failedTimes >= gb.maxFailedTimes { if gb.failedTimes >= gb.maxFailedTimes {
log.Warnln("because %s failed multiple times, active health check", gb.Name()) log.Warnln("because %s failed multiple times, active health check", gb.Name())
fn() gb.healthCheck()
} }
} }
}() }()

View file

@ -95,7 +95,7 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, op
if err == nil { if err == nil {
c.AppendToChains(lb) c.AppendToChains(lb)
} else { } else {
lb.onDialFailed(proxy.Type(), err, lb.healthCheck) lb.onDialFailed(proxy.Type(), err)
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
@ -103,7 +103,7 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, op
if err == nil { if err == nil {
lb.onDialSuccess() lb.onDialSuccess()
} else { } else {
lb.onDialFailed(proxy.Type(), err, lb.healthCheck) lb.onDialFailed(proxy.Type(), err)
} }
}) })
} }
@ -204,7 +204,7 @@ func strategyStickySessions(url string) strategyFn {
for i := 1; i < maxRetry; i++ { for i := 1; i < maxRetry; i++ {
proxy := proxies[nowIdx] proxy := proxies[nowIdx]
if proxy.AliveForTestUrl(url) { if proxy.AliveForTestUrl(url) {
if !has || nowIdx != idx { if nowIdx != idx {
lruCache.Set(key, nowIdx) lruCache.Set(key, nowIdx)
} }

View file

@ -4,6 +4,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"sync"
"time" "time"
"github.com/metacubex/mihomo/adapter/outbound" "github.com/metacubex/mihomo/adapter/outbound"
@ -52,13 +54,13 @@ func (u *URLTest) Set(name string) error {
if p == nil { if p == nil {
return errors.New("proxy not exist") return errors.New("proxy not exist")
} }
u.ForceSet(name) u.selected = name
u.fast(false)
return nil return nil
} }
func (u *URLTest) ForceSet(name string) { func (u *URLTest) ForceSet(name string) {
u.selected = name u.selected = name
u.fastSingle.Reset()
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
@ -68,7 +70,7 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ..
if err == nil { if err == nil {
c.AppendToChains(u) c.AppendToChains(u)
} else { } else {
u.onDialFailed(proxy.Type(), err, u.healthCheck) u.onDialFailed(proxy.Type(), err)
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
@ -76,7 +78,7 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ..
if err == nil { if err == nil {
u.onDialSuccess() u.onDialSuccess()
} else { } else {
u.onDialFailed(proxy.Type(), err, u.healthCheck) u.onDialFailed(proxy.Type(), err)
} }
}) })
} }
@ -86,12 +88,9 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ..
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (u *URLTest) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { func (u *URLTest) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
proxy := u.fast(true) pc, err := u.fast(true).ListenPacketContext(ctx, metadata, u.Base.DialOptions(opts...)...)
pc, err := proxy.ListenPacketContext(ctx, metadata, u.Base.DialOptions(opts...)...)
if err == nil { if err == nil {
pc.AppendToChains(u) pc.AppendToChains(u)
} else {
u.onDialFailed(proxy.Type(), err, u.healthCheck)
} }
return pc, err return pc, err
@ -102,14 +101,8 @@ func (u *URLTest) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
return u.fast(touch) return u.fast(touch)
} }
func (u *URLTest) healthCheck() {
u.fastSingle.Reset()
u.GroupBase.healthCheck()
u.fastSingle.Reset()
}
func (u *URLTest) fast(touch bool) C.Proxy { func (u *URLTest) fast(touch bool) C.Proxy {
elm, _, shared := u.fastSingle.Do(func() (C.Proxy, error) {
proxies := u.GetProxies(touch) proxies := u.GetProxies(touch)
if u.selected != "" { if u.selected != "" {
for _, proxy := range proxies { for _, proxy := range proxies {
@ -118,11 +111,12 @@ func (u *URLTest) fast(touch bool) C.Proxy {
} }
if proxy.Name() == u.selected { if proxy.Name() == u.selected {
u.fastNode = proxy u.fastNode = proxy
return proxy, nil return proxy
} }
} }
} }
elm, _, shared := u.fastSingle.Do(func() (C.Proxy, error) {
fast := proxies[0] fast := proxies[0]
minDelay := fast.LastDelayForTestUrl(u.testUrl) minDelay := fast.LastDelayForTestUrl(u.testUrl)
fastNotExist := true fastNotExist := true
@ -188,7 +182,31 @@ func (u *URLTest) MarshalJSON() ([]byte, error) {
} }
func (u *URLTest) URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (map[string]uint16, error) { func (u *URLTest) URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (map[string]uint16, error) {
return u.GroupBase.URLTest(ctx, u.testUrl, expectedStatus) var wg sync.WaitGroup
var lock sync.Mutex
mp := map[string]uint16{}
proxies := u.GetProxies(false)
for _, proxy := range proxies {
proxy := proxy
wg.Add(1)
go func() {
delay, err := proxy.URLTest(ctx, u.testUrl, expectedStatus)
if err == nil {
lock.Lock()
mp[proxy.Name()] = delay
lock.Unlock()
}
wg.Done()
}()
}
wg.Wait()
if len(mp) == 0 {
return mp, fmt.Errorf("get delay: all proxies timeout")
} else {
return mp, nil
}
} }
func parseURLTestOption(config map[string]any) []urlTestOption { func parseURLTestOption(config map[string]any) []urlTestOption {

View file

@ -4,7 +4,3 @@ type SelectAble interface {
Set(string) error Set(string) error
ForceSet(name string) ForceSet(name string)
} }
var _ SelectAble = (*Fallback)(nil)
var _ SelectAble = (*URLTest)(nil)
var _ SelectAble = (*Selector)(nil)

View file

@ -3,6 +3,8 @@ package adapter
import ( import (
"fmt" "fmt"
tlsC "github.com/metacubex/mihomo/component/tls"
"github.com/metacubex/mihomo/adapter/outbound" "github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/common/structure"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
@ -16,12 +18,12 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
} }
var ( var (
proxy outbound.ProxyAdapter proxy C.ProxyAdapter
err error err error
) )
switch proxyType { switch proxyType {
case "ss": case "ss":
ssOption := &outbound.ShadowSocksOption{} ssOption := &outbound.ShadowSocksOption{ClientFingerprint: tlsC.GetGlobalFingerprint()}
err = decoder.Decode(mapping, ssOption) err = decoder.Decode(mapping, ssOption)
if err != nil { if err != nil {
break break
@ -54,6 +56,7 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
Method: "GET", Method: "GET",
Path: []string{"/"}, Path: []string{"/"},
}, },
ClientFingerprint: tlsC.GetGlobalFingerprint(),
} }
err = decoder.Decode(mapping, vmessOption) err = decoder.Decode(mapping, vmessOption)
@ -62,7 +65,7 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
} }
proxy, err = outbound.NewVmess(*vmessOption) proxy, err = outbound.NewVmess(*vmessOption)
case "vless": case "vless":
vlessOption := &outbound.VlessOption{} vlessOption := &outbound.VlessOption{ClientFingerprint: tlsC.GetGlobalFingerprint()}
err = decoder.Decode(mapping, vlessOption) err = decoder.Decode(mapping, vlessOption)
if err != nil { if err != nil {
break break
@ -76,7 +79,7 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
} }
proxy, err = outbound.NewSnell(*snellOption) proxy, err = outbound.NewSnell(*snellOption)
case "trojan": case "trojan":
trojanOption := &outbound.TrojanOption{} trojanOption := &outbound.TrojanOption{ClientFingerprint: tlsC.GetGlobalFingerprint()}
err = decoder.Decode(mapping, trojanOption) err = decoder.Decode(mapping, trojanOption)
if err != nil { if err != nil {
break break
@ -138,20 +141,6 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
break break
} }
proxy, err = outbound.NewSsh(*sshOption) proxy, err = outbound.NewSsh(*sshOption)
case "mieru":
mieruOption := &outbound.MieruOption{}
err = decoder.Decode(mapping, mieruOption)
if err != nil {
break
}
proxy, err = outbound.NewMieru(*mieruOption)
case "anytls":
anytlsOption := &outbound.AnyTLSOption{}
err = decoder.Decode(mapping, anytlsOption)
if err != nil {
break
}
proxy, err = outbound.NewAnyTLS(*anytlsOption)
default: default:
return nil, fmt.Errorf("unsupport proxy type: %s", proxyType) return nil, fmt.Errorf("unsupport proxy type: %s", proxyType)
} }
@ -167,13 +156,12 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
return nil, err return nil, err
} }
if muxOption.Enabled { if muxOption.Enabled {
proxy, err = outbound.NewSingMux(*muxOption, proxy) proxy, err = outbound.NewSingMux(*muxOption, proxy, proxy.(outbound.ProxyBase))
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
} }
proxy = outbound.NewAutoCloseProxyAdapter(proxy)
return NewProxy(proxy), nil return NewProxy(proxy), nil
} }

View file

@ -1,7 +1,6 @@
package provider package provider
import ( import (
"encoding"
"errors" "errors"
"fmt" "fmt"
"time" "time"
@ -10,9 +9,8 @@ import (
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/resource" "github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/features"
types "github.com/metacubex/mihomo/constant/provider" types "github.com/metacubex/mihomo/constant/provider"
"github.com/dlclark/regexp2"
) )
var ( var (
@ -29,15 +27,6 @@ type healthCheckSchema struct {
ExpectedStatus string `provider:"expected-status,omitempty"` ExpectedStatus string `provider:"expected-status,omitempty"`
} }
type OverrideProxyNameSchema struct {
// matching expression for regex replacement
Pattern *regexp2.Regexp `provider:"pattern"`
// the new content after regex matching
Target string `provider:"target"`
}
var _ encoding.TextUnmarshaler = (*regexp2.Regexp)(nil) // ensure *regexp2.Regexp can decode direct by structure package
type OverrideSchema struct { type OverrideSchema struct {
TFO *bool `provider:"tfo,omitempty"` TFO *bool `provider:"tfo,omitempty"`
MPTcp *bool `provider:"mptcp,omitempty"` MPTcp *bool `provider:"mptcp,omitempty"`
@ -52,8 +41,6 @@ type OverrideSchema struct {
IPVersion *string `provider:"ip-version,omitempty"` IPVersion *string `provider:"ip-version,omitempty"`
AdditionalPrefix *string `provider:"additional-prefix,omitempty"` AdditionalPrefix *string `provider:"additional-prefix,omitempty"`
AdditionalSuffix *string `provider:"additional-suffix,omitempty"` AdditionalSuffix *string `provider:"additional-suffix,omitempty"`
ProxyName []OverrideProxyNameSchema `provider:"proxy-name,omitempty"`
} }
type proxyProviderSchema struct { type proxyProviderSchema struct {
@ -66,8 +53,6 @@ type proxyProviderSchema struct {
ExcludeFilter string `provider:"exclude-filter,omitempty"` ExcludeFilter string `provider:"exclude-filter,omitempty"`
ExcludeType string `provider:"exclude-type,omitempty"` ExcludeType string `provider:"exclude-type,omitempty"`
DialerProxy string `provider:"dialer-proxy,omitempty"` DialerProxy string `provider:"dialer-proxy,omitempty"`
SizeLimit int64 `provider:"size-limit,omitempty"`
Payload []map[string]any `provider:"payload,omitempty"`
HealthCheck healthCheckSchema `provider:"health-check,omitempty"` HealthCheck healthCheckSchema `provider:"health-check,omitempty"`
Override OverrideSchema `provider:"override,omitempty"` Override OverrideSchema `provider:"override,omitempty"`
@ -100,11 +85,6 @@ func ParseProxyProvider(name string, mapping map[string]any) (types.ProxyProvide
} }
hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, uint(schema.HealthCheck.TestTimeout), hcInterval, schema.HealthCheck.Lazy, expectedStatus) hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, uint(schema.HealthCheck.TestTimeout), hcInterval, schema.HealthCheck.Lazy, expectedStatus)
parser, err := NewProxiesParser(schema.Filter, schema.ExcludeFilter, schema.ExcludeType, schema.DialerProxy, schema.Override)
if err != nil {
return nil, err
}
var vehicle types.Vehicle var vehicle types.Vehicle
switch schema.Type { switch schema.Type {
case "file": case "file":
@ -114,18 +94,21 @@ func ParseProxyProvider(name string, mapping map[string]any) (types.ProxyProvide
path := C.Path.GetPathByHash("proxies", schema.URL) path := C.Path.GetPathByHash("proxies", schema.URL)
if schema.Path != "" { if schema.Path != "" {
path = C.Path.Resolve(schema.Path) path = C.Path.Resolve(schema.Path)
if !C.Path.IsSafePath(path) { if !features.CMFA && !C.Path.IsSafePath(path) {
return nil, fmt.Errorf("%w: %s", errSubPath, path) return nil, fmt.Errorf("%w: %s", errSubPath, path)
} }
} }
vehicle = resource.NewHTTPVehicle(schema.URL, path, schema.Proxy, schema.Header, resource.DefaultHttpTimeout, schema.SizeLimit) vehicle = resource.NewHTTPVehicle(schema.URL, path, schema.Proxy, schema.Header)
case "inline":
return NewInlineProvider(name, schema.Payload, parser, hc)
default: default:
return nil, fmt.Errorf("%w: %s", errVehicleType, schema.Type) return nil, fmt.Errorf("%w: %s", errVehicleType, schema.Type)
} }
interval := time.Duration(uint(schema.Interval)) * time.Second interval := time.Duration(uint(schema.Interval)) * time.Second
filter := schema.Filter
excludeFilter := schema.ExcludeFilter
excludeType := schema.ExcludeType
dialerProxy := schema.DialerProxy
override := schema.Override
return NewProxySetProvider(name, interval, parser, vehicle, hc) return NewProxySetProvider(name, interval, filter, excludeFilter, excludeType, dialerProxy, override, vehicle, hc)
} }

View file

@ -0,0 +1,19 @@
//go:build android && cmfa
package provider
import (
"time"
)
var (
suspended bool
)
type UpdatableProvider interface {
UpdatedAt() time.Time
}
func Suspend(s bool) {
suspended = s
}

View file

@ -1,6 +1,7 @@
package provider package provider
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -13,10 +14,11 @@ import (
"github.com/metacubex/mihomo/adapter" "github.com/metacubex/mihomo/adapter"
"github.com/metacubex/mihomo/common/convert" "github.com/metacubex/mihomo/common/convert"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/profile/cachefile" mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/mihomo/component/resource" "github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
types "github.com/metacubex/mihomo/constant/provider" types "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/tunnel/statistic" "github.com/metacubex/mihomo/tunnel/statistic"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
@ -31,119 +33,125 @@ type ProxySchema struct {
Proxies []map[string]any `yaml:"proxies"` Proxies []map[string]any `yaml:"proxies"`
} }
type providerForApi struct {
Name string `json:"name"`
Type string `json:"type"`
VehicleType string `json:"vehicleType"`
Proxies []C.Proxy `json:"proxies"`
TestUrl string `json:"testUrl"`
ExpectedStatus string `json:"expectedStatus"`
UpdatedAt time.Time `json:"updatedAt,omitempty"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
}
type baseProvider struct {
name string
proxies []C.Proxy
healthCheck *HealthCheck
version uint32
}
func (bp *baseProvider) Name() string {
return bp.name
}
func (bp *baseProvider) Version() uint32 {
return bp.version
}
func (bp *baseProvider) HealthCheck() {
bp.healthCheck.check()
}
func (bp *baseProvider) Type() types.ProviderType {
return types.Proxy
}
func (bp *baseProvider) Proxies() []C.Proxy {
return bp.proxies
}
func (bp *baseProvider) Count() int {
return len(bp.proxies)
}
func (bp *baseProvider) Touch() {
bp.healthCheck.touch()
}
func (bp *baseProvider) HealthCheckURL() string {
return bp.healthCheck.url
}
func (bp *baseProvider) RegisterHealthCheckTask(url string, expectedStatus utils.IntRanges[uint16], filter string, interval uint) {
bp.healthCheck.registerHealthCheckTask(url, expectedStatus, filter, interval)
}
func (bp *baseProvider) setProxies(proxies []C.Proxy) {
bp.proxies = proxies
bp.version += 1
bp.healthCheck.setProxy(proxies)
if bp.healthCheck.auto() {
go bp.healthCheck.check()
}
}
func (bp *baseProvider) Close() error {
bp.healthCheck.close()
return nil
}
// ProxySetProvider for auto gc // ProxySetProvider for auto gc
type ProxySetProvider struct { type ProxySetProvider struct {
*proxySetProvider *proxySetProvider
} }
type proxySetProvider struct { type proxySetProvider struct {
baseProvider
*resource.Fetcher[[]C.Proxy] *resource.Fetcher[[]C.Proxy]
proxies []C.Proxy
healthCheck *HealthCheck
version uint32
subscriptionInfo *SubscriptionInfo subscriptionInfo *SubscriptionInfo
} }
func (pp *proxySetProvider) MarshalJSON() ([]byte, error) { func (pp *proxySetProvider) MarshalJSON() ([]byte, error) {
return json.Marshal(providerForApi{ return json.Marshal(map[string]any{
Name: pp.Name(), "name": pp.Name(),
Type: pp.Type().String(), "type": pp.Type().String(),
VehicleType: pp.VehicleType().String(), "vehicleType": pp.VehicleType().String(),
Proxies: pp.Proxies(), "proxies": pp.Proxies(),
TestUrl: pp.healthCheck.url, "testUrl": pp.healthCheck.url,
ExpectedStatus: pp.healthCheck.expectedStatus.String(), "expectedStatus": pp.healthCheck.expectedStatus.String(),
UpdatedAt: pp.UpdatedAt(), "updatedAt": pp.UpdatedAt(),
SubscriptionInfo: pp.subscriptionInfo, "subscriptionInfo": pp.subscriptionInfo,
}) })
} }
func (pp *proxySetProvider) Version() uint32 {
return pp.version
}
func (pp *proxySetProvider) Name() string { func (pp *proxySetProvider) Name() string {
return pp.Fetcher.Name() return pp.Fetcher.Name()
} }
func (pp *proxySetProvider) HealthCheck() {
pp.healthCheck.check()
}
func (pp *proxySetProvider) Update() error { func (pp *proxySetProvider) Update() error {
_, _, err := pp.Fetcher.Update() elm, same, err := pp.Fetcher.Update()
if err == nil && !same {
pp.OnUpdate(elm)
}
return err return err
} }
func (pp *proxySetProvider) Initial() error { func (pp *proxySetProvider) Initial() error {
_, err := pp.Fetcher.Initial() elm, err := pp.Fetcher.Initial()
if err != nil { if err != nil {
return err return err
} }
if subscriptionInfo := cachefile.Cache().GetSubscriptionInfo(pp.Name()); subscriptionInfo != "" { pp.OnUpdate(elm)
pp.subscriptionInfo = NewSubscriptionInfo(subscriptionInfo) pp.getSubscriptionInfo()
}
pp.closeAllConnections() pp.closeAllConnections()
return nil return nil
} }
func (pp *proxySetProvider) Type() types.ProviderType {
return types.Proxy
}
func (pp *proxySetProvider) Proxies() []C.Proxy {
return pp.proxies
}
func (pp *proxySetProvider) Touch() {
pp.healthCheck.touch()
}
func (pp *proxySetProvider) HealthCheckURL() string {
return pp.healthCheck.url
}
func (pp *proxySetProvider) RegisterHealthCheckTask(url string, expectedStatus utils.IntRanges[uint16], filter string, interval uint) {
pp.healthCheck.registerHealthCheckTask(url, expectedStatus, filter, interval)
}
func (pp *proxySetProvider) setProxies(proxies []C.Proxy) {
pp.proxies = proxies
pp.healthCheck.setProxy(proxies)
if pp.healthCheck.auto() {
go pp.healthCheck.check()
}
}
func (pp *proxySetProvider) getSubscriptionInfo() {
if pp.VehicleType() != types.HTTP {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequestWithProxy(ctx, pp.Vehicle().(*resource.HTTPVehicle).Url(),
http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil, pp.Vehicle().Proxy())
if err != nil {
return
}
defer resp.Body.Close()
userInfoStr := strings.TrimSpace(resp.Header.Get("subscription-userinfo"))
if userInfoStr == "" {
resp2, err := mihomoHttp.HttpRequestWithProxy(ctx, pp.Vehicle().(*resource.HTTPVehicle).Url(),
http.MethodGet, http.Header{"User-Agent": {"Quantumultx"}}, nil, pp.Vehicle().Proxy())
if err != nil {
return
}
defer resp2.Body.Close()
userInfoStr = strings.TrimSpace(resp2.Header.Get("subscription-userinfo"))
if userInfoStr == "" {
return
}
}
pp.subscriptionInfo, err = NewSubscriptionInfo(userInfoStr)
if err != nil {
log.Warnln("[Provider] get subscription-userinfo: %e", err)
}
}()
}
func (pp *proxySetProvider) closeAllConnections() { func (pp *proxySetProvider) closeAllConnections() {
statistic.DefaultManager.Range(func(c statistic.Tracker) bool { statistic.DefaultManager.Range(func(c statistic.Tracker) bool {
for _, chain := range c.Chains() { for _, chain := range c.Chains() {
@ -157,176 +165,11 @@ func (pp *proxySetProvider) closeAllConnections() {
} }
func (pp *proxySetProvider) Close() error { func (pp *proxySetProvider) Close() error {
_ = pp.baseProvider.Close() pp.healthCheck.close()
return pp.Fetcher.Close() return pp.Fetcher.Close()
} }
func NewProxySetProvider(name string, interval time.Duration, parser resource.Parser[[]C.Proxy], vehicle types.Vehicle, hc *HealthCheck) (*ProxySetProvider, error) { func NewProxySetProvider(name string, interval time.Duration, filter string, excludeFilter string, excludeType string, dialerProxy string, override OverrideSchema, vehicle types.Vehicle, hc *HealthCheck) (*ProxySetProvider, error) {
if hc.auto() {
go hc.process()
}
pd := &proxySetProvider{
baseProvider: baseProvider{
name: name,
proxies: []C.Proxy{},
healthCheck: hc,
},
}
fetcher := resource.NewFetcher[[]C.Proxy](name, interval, vehicle, parser, pd.setProxies)
pd.Fetcher = fetcher
if httpVehicle, ok := vehicle.(*resource.HTTPVehicle); ok {
httpVehicle.SetInRead(func(resp *http.Response) {
if subscriptionInfo := resp.Header.Get("subscription-userinfo"); subscriptionInfo != "" {
cachefile.Cache().SetSubscriptionInfo(name, subscriptionInfo)
pd.subscriptionInfo = NewSubscriptionInfo(subscriptionInfo)
}
})
}
wrapper := &ProxySetProvider{pd}
runtime.SetFinalizer(wrapper, (*ProxySetProvider).Close)
return wrapper, nil
}
func (pp *ProxySetProvider) Close() error {
runtime.SetFinalizer(pp, nil)
return pp.proxySetProvider.Close()
}
// InlineProvider for auto gc
type InlineProvider struct {
*inlineProvider
}
type inlineProvider struct {
baseProvider
updateAt time.Time
}
func (ip *inlineProvider) MarshalJSON() ([]byte, error) {
return json.Marshal(providerForApi{
Name: ip.Name(),
Type: ip.Type().String(),
VehicleType: ip.VehicleType().String(),
Proxies: ip.Proxies(),
TestUrl: ip.healthCheck.url,
ExpectedStatus: ip.healthCheck.expectedStatus.String(),
UpdatedAt: ip.updateAt,
})
}
func (ip *inlineProvider) VehicleType() types.VehicleType {
return types.Inline
}
func (ip *inlineProvider) Initial() error {
return nil
}
func (ip *inlineProvider) Update() error {
// make api update happy
ip.updateAt = time.Now()
return nil
}
func NewInlineProvider(name string, payload []map[string]any, parser resource.Parser[[]C.Proxy], hc *HealthCheck) (*InlineProvider, error) {
if hc.auto() {
go hc.process()
}
ps := ProxySchema{Proxies: payload}
buf, err := yaml.Marshal(ps)
if err != nil {
return nil, err
}
proxies, err := parser(buf)
if err != nil {
return nil, err
}
ip := &inlineProvider{
baseProvider: baseProvider{
name: name,
proxies: proxies,
healthCheck: hc,
},
updateAt: time.Now(),
}
wrapper := &InlineProvider{ip}
runtime.SetFinalizer(wrapper, (*InlineProvider).Close)
return wrapper, nil
}
func (ip *InlineProvider) Close() error {
runtime.SetFinalizer(ip, nil)
return ip.baseProvider.Close()
}
// CompatibleProvider for auto gc
type CompatibleProvider struct {
*compatibleProvider
}
type compatibleProvider struct {
baseProvider
}
func (cp *compatibleProvider) MarshalJSON() ([]byte, error) {
return json.Marshal(providerForApi{
Name: cp.Name(),
Type: cp.Type().String(),
VehicleType: cp.VehicleType().String(),
Proxies: cp.Proxies(),
TestUrl: cp.healthCheck.url,
ExpectedStatus: cp.healthCheck.expectedStatus.String(),
})
}
func (cp *compatibleProvider) Update() error {
return nil
}
func (cp *compatibleProvider) Initial() error {
if cp.healthCheck.interval != 0 && cp.healthCheck.url != "" {
cp.HealthCheck()
}
return nil
}
func (cp *compatibleProvider) VehicleType() types.VehicleType {
return types.Compatible
}
func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*CompatibleProvider, error) {
if len(proxies) == 0 {
return nil, errors.New("provider need one proxy at least")
}
if hc.auto() {
go hc.process()
}
pd := &compatibleProvider{
baseProvider: baseProvider{
name: name,
proxies: proxies,
healthCheck: hc,
},
}
wrapper := &CompatibleProvider{pd}
runtime.SetFinalizer(wrapper, (*CompatibleProvider).Close)
return wrapper, nil
}
func (cp *CompatibleProvider) Close() error {
runtime.SetFinalizer(cp, nil)
return cp.compatibleProvider.Close()
}
func NewProxiesParser(filter string, excludeFilter string, excludeType string, dialerProxy string, override OverrideSchema) (resource.Parser[[]C.Proxy], error) {
excludeFilterReg, err := regexp2.Compile(excludeFilter, regexp2.None) excludeFilterReg, err := regexp2.Compile(excludeFilter, regexp2.None)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid excludeFilter regex: %w", err) return nil, fmt.Errorf("invalid excludeFilter regex: %w", err)
@ -345,6 +188,136 @@ func NewProxiesParser(filter string, excludeFilter string, excludeType string, d
filterRegs = append(filterRegs, filterReg) filterRegs = append(filterRegs, filterReg)
} }
if hc.auto() {
go hc.process()
}
pd := &proxySetProvider{
proxies: []C.Proxy{},
healthCheck: hc,
}
fetcher := resource.NewFetcher[[]C.Proxy](name, interval, vehicle, proxiesParseAndFilter(filter, excludeFilter, excludeTypeArray, filterRegs, excludeFilterReg, dialerProxy, override), proxiesOnUpdate(pd))
pd.Fetcher = fetcher
wrapper := &ProxySetProvider{pd}
runtime.SetFinalizer(wrapper, (*ProxySetProvider).Close)
return wrapper, nil
}
func (pp *ProxySetProvider) Close() error {
runtime.SetFinalizer(pp, nil)
return pp.proxySetProvider.Close()
}
// CompatibleProvider for auto gc
type CompatibleProvider struct {
*compatibleProvider
}
type compatibleProvider struct {
name string
healthCheck *HealthCheck
proxies []C.Proxy
version uint32
}
func (cp *compatibleProvider) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"name": cp.Name(),
"type": cp.Type().String(),
"vehicleType": cp.VehicleType().String(),
"proxies": cp.Proxies(),
"testUrl": cp.healthCheck.url,
"expectedStatus": cp.healthCheck.expectedStatus.String(),
})
}
func (cp *compatibleProvider) Version() uint32 {
return cp.version
}
func (cp *compatibleProvider) Name() string {
return cp.name
}
func (cp *compatibleProvider) HealthCheck() {
cp.healthCheck.check()
}
func (cp *compatibleProvider) Update() error {
return nil
}
func (cp *compatibleProvider) Initial() error {
if cp.healthCheck.interval != 0 && cp.healthCheck.url != "" {
cp.HealthCheck()
}
return nil
}
func (cp *compatibleProvider) VehicleType() types.VehicleType {
return types.Compatible
}
func (cp *compatibleProvider) Type() types.ProviderType {
return types.Proxy
}
func (cp *compatibleProvider) Proxies() []C.Proxy {
return cp.proxies
}
func (cp *compatibleProvider) Touch() {
cp.healthCheck.touch()
}
func (cp *compatibleProvider) HealthCheckURL() string {
return cp.healthCheck.url
}
func (cp *compatibleProvider) RegisterHealthCheckTask(url string, expectedStatus utils.IntRanges[uint16], filter string, interval uint) {
cp.healthCheck.registerHealthCheckTask(url, expectedStatus, filter, interval)
}
func (cp *compatibleProvider) Close() error {
cp.healthCheck.close()
return nil
}
func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*CompatibleProvider, error) {
if len(proxies) == 0 {
return nil, errors.New("provider need one proxy at least")
}
if hc.auto() {
go hc.process()
}
pd := &compatibleProvider{
name: name,
proxies: proxies,
healthCheck: hc,
}
wrapper := &CompatibleProvider{pd}
runtime.SetFinalizer(wrapper, (*CompatibleProvider).Close)
return wrapper, nil
}
func (cp *CompatibleProvider) Close() error {
runtime.SetFinalizer(cp, nil)
return cp.compatibleProvider.Close()
}
func proxiesOnUpdate(pd *proxySetProvider) func([]C.Proxy) {
return func(elm []C.Proxy) {
pd.setProxies(elm)
pd.version += 1
pd.getSubscriptionInfo()
}
}
func proxiesParseAndFilter(filter string, excludeFilter string, excludeTypeArray []string, filterRegs []*regexp2.Regexp, excludeFilterReg *regexp2.Regexp, dialerProxy string, override OverrideSchema) resource.Parser[[]C.Proxy] {
return func(buf []byte) ([]C.Proxy, error) { return func(buf []byte) ([]C.Proxy, error) {
schema := &ProxySchema{} schema := &ProxySchema{}
@ -426,16 +399,6 @@ func NewProxiesParser(filter string, excludeFilter string, excludeType string, d
case "additional-suffix": case "additional-suffix":
name := mapping["name"].(string) name := mapping["name"].(string)
mapping["name"] = name + *field.Interface().(*string) mapping["name"] = name + *field.Interface().(*string)
case "proxy-name":
// Iterate through all naming replacement rules and perform the replacements.
for _, expr := range override.ProxyName {
name := mapping["name"].(string)
newName, err := expr.Pattern.Replace(name, expr.Target, 0, -1)
if err != nil {
return nil, fmt.Errorf("proxy name replace error: %w", err)
}
mapping["name"] = newName
}
default: default:
mapping[fieldName] = field.Elem().Interface() mapping[fieldName] = field.Elem().Interface()
} }
@ -459,5 +422,5 @@ func NewProxiesParser(filter string, excludeFilter string, excludeType string, d
} }
return proxies, nil return proxies, nil
}, nil }
} }

View file

@ -1,11 +1,8 @@
package provider package provider
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
"github.com/metacubex/mihomo/log"
) )
type SubscriptionInfo struct { type SubscriptionInfo struct {
@ -15,44 +12,28 @@ type SubscriptionInfo struct {
Expire int64 Expire int64
} }
func NewSubscriptionInfo(userinfo string) (si *SubscriptionInfo) { func NewSubscriptionInfo(userinfo string) (si *SubscriptionInfo, err error) {
userinfo = strings.ReplaceAll(strings.ToLower(userinfo), " ", "") userinfo = strings.ToLower(userinfo)
userinfo = strings.ReplaceAll(userinfo, " ", "")
si = new(SubscriptionInfo) si = new(SubscriptionInfo)
for _, field := range strings.Split(userinfo, ";") { for _, field := range strings.Split(userinfo, ";") {
name, value, ok := strings.Cut(field, "=") switch name, value, _ := strings.Cut(field, "="); name {
if !ok {
continue
}
intValue, err := parseValue(value)
if err != nil {
log.Warnln("[Provider] get subscription-userinfo: %e", err)
continue
}
switch name {
case "upload": case "upload":
si.Upload = intValue si.Upload, err = strconv.ParseInt(value, 10, 64)
case "download": case "download":
si.Download = intValue si.Download, err = strconv.ParseInt(value, 10, 64)
case "total": case "total":
si.Total = intValue si.Total, err = strconv.ParseInt(value, 10, 64)
case "expire": case "expire":
si.Expire = intValue if value == "" {
si.Expire = 0
} else {
si.Expire, err = strconv.ParseInt(value, 10, 64)
} }
} }
return si if err != nil {
return
} }
func parseValue(value string) (int64, error) {
if intValue, err := strconv.ParseInt(value, 10, 64); err == nil {
return intValue, nil
} }
return
if floatValue, err := strconv.ParseFloat(value, 64); err == nil {
return int64(floatValue), nil
}
return 0, fmt.Errorf("failed to parse value '%s'", value)
} }

View file

@ -33,8 +33,15 @@ type ARC[K comparable, V any] struct {
// New returns a new Adaptive Replacement Cache (ARC). // New returns a new Adaptive Replacement Cache (ARC).
func New[K comparable, V any](options ...Option[K, V]) *ARC[K, V] { func New[K comparable, V any](options ...Option[K, V]) *ARC[K, V] {
arc := &ARC[K, V]{} arc := &ARC[K, V]{
arc.Clear() p: 0,
t1: list.New[*entry[K, V]](),
b1: list.New[*entry[K, V]](),
t2: list.New[*entry[K, V]](),
b2: list.New[*entry[K, V]](),
len: 0,
cache: make(map[K]*entry[K, V]),
}
for _, option := range options { for _, option := range options {
option(arc) option(arc)
@ -42,19 +49,6 @@ func New[K comparable, V any](options ...Option[K, V]) *ARC[K, V] {
return arc return arc
} }
func (a *ARC[K, V]) Clear() {
a.mutex.Lock()
defer a.mutex.Unlock()
a.p = 0
a.t1 = list.New[*entry[K, V]]()
a.b1 = list.New[*entry[K, V]]()
a.t2 = list.New[*entry[K, V]]()
a.b2 = list.New[*entry[K, V]]()
a.len = 0
a.cache = make(map[K]*entry[K, V])
}
// Set inserts a new key-value pair into the cache. // Set inserts a new key-value pair into the cache.
// This optimizes future access to this entry (side effect). // This optimizes future access to this entry (side effect).
func (a *ARC[K, V]) Set(key K, value V) { func (a *ARC[K, V]) Set(key K, value V) {

View file

@ -1,31 +0,0 @@
package contextutils
import (
"context"
"sync"
)
func afterFunc(ctx context.Context, f func()) (stop func() bool) {
stopc := make(chan struct{})
once := sync.Once{} // either starts running f or stops f from running
if ctx.Done() != nil {
go func() {
select {
case <-ctx.Done():
once.Do(func() {
go f()
})
case <-stopc:
}
}()
}
return func() bool {
stopped := false
once.Do(func() {
stopped = true
close(stopc)
})
return stopped
}
}

View file

@ -1,11 +0,0 @@
//go:build !go1.21
package contextutils
import (
"context"
)
func AfterFunc(ctx context.Context, f func()) (stop func() bool) {
return afterFunc(ctx, f)
}

View file

@ -1,9 +0,0 @@
//go:build go1.21
package contextutils
import "context"
func AfterFunc(ctx context.Context, f func()) (stop func() bool) {
return context.AfterFunc(ctx, f)
}

View file

@ -1,100 +0,0 @@
package contextutils
import (
"context"
"testing"
"time"
)
const (
shortDuration = 1 * time.Millisecond // a reasonable duration to block in a test
veryLongDuration = 1000 * time.Hour // an arbitrary upper bound on the test's running time
)
func TestAfterFuncCalledAfterCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
donec := make(chan struct{})
stop := afterFunc(ctx, func() {
close(donec)
})
select {
case <-donec:
t.Fatalf("AfterFunc called before context is done")
case <-time.After(shortDuration):
}
cancel()
select {
case <-donec:
case <-time.After(veryLongDuration):
t.Fatalf("AfterFunc not called after context is canceled")
}
if stop() {
t.Fatalf("stop() = true, want false")
}
}
func TestAfterFuncCalledAfterTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), shortDuration)
defer cancel()
donec := make(chan struct{})
afterFunc(ctx, func() {
close(donec)
})
select {
case <-donec:
case <-time.After(veryLongDuration):
t.Fatalf("AfterFunc not called after context is canceled")
}
}
func TestAfterFuncCalledImmediately(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
donec := make(chan struct{})
afterFunc(ctx, func() {
close(donec)
})
select {
case <-donec:
case <-time.After(veryLongDuration):
t.Fatalf("AfterFunc not called for already-canceled context")
}
}
func TestAfterFuncNotCalledAfterStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
donec := make(chan struct{})
stop := afterFunc(ctx, func() {
close(donec)
})
if !stop() {
t.Fatalf("stop() = false, want true")
}
cancel()
select {
case <-donec:
t.Fatalf("AfterFunc called for already-canceled context")
case <-time.After(shortDuration):
}
if stop() {
t.Fatalf("stop() = true, want false")
}
}
// This test verifies that canceling a context does not block waiting for AfterFuncs to finish.
func TestAfterFuncCalledAsynchronously(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
donec := make(chan struct{})
stop := afterFunc(ctx, func() {
// The channel send blocks until donec is read from.
donec <- struct{}{}
})
defer stop()
cancel()
// After cancel returns, read from donec and unblock the AfterFunc.
select {
case <-donec:
case <-time.After(veryLongDuration):
t.Fatalf("AfterFunc not called after context is canceled")
}
}

View file

@ -275,16 +275,15 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
vmess["skip-cert-verify"] = false vmess["skip-cert-verify"] = false
vmess["cipher"] = "auto" vmess["cipher"] = "auto"
if cipher, ok := values["scy"].(string); ok && cipher != "" { if cipher, ok := values["scy"]; ok && cipher != "" {
vmess["cipher"] = cipher vmess["cipher"] = cipher
} }
if sni, ok := values["sni"].(string); ok && sni != "" { if sni, ok := values["sni"]; ok && sni != "" {
vmess["servername"] = sni vmess["servername"] = sni
} }
network, ok := values["net"].(string) network, _ := values["net"].(string)
if ok {
network = strings.ToLower(network) network = strings.ToLower(network)
if values["type"] == "http" { if values["type"] == "http" {
network = "http" network = "http"
@ -292,7 +291,6 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
network = "h2" network = "h2"
} }
vmess["network"] = network vmess["network"] = network
}
tls, ok := values["tls"].(string) tls, ok := values["tls"].(string)
if ok { if ok {
@ -309,12 +307,12 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
case "http": case "http":
headers := make(map[string]any) headers := make(map[string]any)
httpOpts := make(map[string]any) httpOpts := make(map[string]any)
if host, ok := values["host"].(string); ok && host != "" { if host, ok := values["host"]; ok && host != "" {
headers["Host"] = []string{host} headers["Host"] = []string{host.(string)}
} }
httpOpts["path"] = []string{"/"} httpOpts["path"] = []string{"/"}
if path, ok := values["path"].(string); ok && path != "" { if path, ok := values["path"]; ok && path != "" {
httpOpts["path"] = []string{path} httpOpts["path"] = []string{path.(string)}
} }
httpOpts["headers"] = headers httpOpts["headers"] = headers
@ -323,8 +321,8 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
case "h2": case "h2":
headers := make(map[string]any) headers := make(map[string]any)
h2Opts := make(map[string]any) h2Opts := make(map[string]any)
if host, ok := values["host"].(string); ok && host != "" { if host, ok := values["host"]; ok && host != "" {
headers["Host"] = []string{host} headers["Host"] = []string{host.(string)}
} }
h2Opts["path"] = values["path"] h2Opts["path"] = values["path"]
@ -336,11 +334,11 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
headers := make(map[string]any) headers := make(map[string]any)
wsOpts := make(map[string]any) wsOpts := make(map[string]any)
wsOpts["path"] = "/" wsOpts["path"] = "/"
if host, ok := values["host"].(string); ok && host != "" { if host, ok := values["host"]; ok && host != "" {
headers["Host"] = host headers["Host"] = host.(string)
} }
if path, ok := values["path"].(string); ok && path != "" { if path, ok := values["path"]; ok && path != "" {
path := path path := path.(string)
pathURL, err := url.Parse(path) pathURL, err := url.Parse(path)
if err == nil { if err == nil {
query := pathURL.Query() query := pathURL.Query()

View file

@ -68,8 +68,10 @@ type LruCache[K comparable, V any] struct {
// New creates an LruCache // New creates an LruCache
func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] { func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] {
lc := &LruCache[K, V]{} lc := &LruCache[K, V]{
lc.Clear() lru: list.New[*entry[K, V]](),
cache: make(map[K]*list.Element[*entry[K, V]]),
}
for _, option := range options { for _, option := range options {
option(lc) option(lc)
@ -78,14 +80,6 @@ func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] {
return lc return lc
} }
func (c *LruCache[K, V]) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.lru = list.New[*entry[K, V]]()
c.cache = make(map[K]*list.Element[*entry[K, V]])
}
// Get returns any representation of a cached response and a bool // Get returns any representation of a cached response and a bool
// set to true if the key was found. // set to true if the key was found.
func (c *LruCache[K, V]) Get(key K) (V, bool) { func (c *LruCache[K, V]) Get(key K) (V, bool) {
@ -256,6 +250,15 @@ func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) {
} }
} }
func (c *LruCache[K, V]) Clear() error {
c.mu.Lock()
defer c.mu.Unlock()
c.cache = make(map[K]*list.Element[*entry[K, V]])
return nil
}
// Compute either sets the computed new value for the key or deletes // Compute either sets the computed new value for the key or deletes
// the value for the key. When the delete result of the valueFn function // the value for the key. When the delete result of the valueFn function
// is set to true, the value will be deleted, if it exists. When delete // is set to true, the value will be deleted, if it exists. When delete

View file

@ -3,37 +3,29 @@ package net
import ( import (
"context" "context"
"net" "net"
"github.com/metacubex/mihomo/common/contextutils"
) )
// SetupContextForConn is a helper function that starts connection I/O interrupter. // SetupContextForConn is a helper function that starts connection I/O interrupter goroutine.
// if ctx be canceled before done called, it will close the connection.
// should use like this:
//
// func streamConn(ctx context.Context, conn net.Conn) (_ net.Conn, err error) {
// if ctx.Done() != nil {
// done := N.SetupContextForConn(ctx, conn)
// defer done(&err)
// }
// conn, err := xxx
// return conn, err
// }
func SetupContextForConn(ctx context.Context, conn net.Conn) (done func(*error)) { func SetupContextForConn(ctx context.Context, conn net.Conn) (done func(*error)) {
stopc := make(chan struct{}) var (
stop := contextutils.AfterFunc(ctx, func() { quit = make(chan struct{})
interrupt = make(chan error, 1)
)
go func() {
select {
case <-quit:
interrupt <- nil
case <-ctx.Done():
// Close the connection, discarding the error // Close the connection, discarding the error
_ = conn.Close() _ = conn.Close()
close(stopc) interrupt <- ctx.Err()
}) }
}()
return func(inputErr *error) { return func(inputErr *error) {
if !stop() { close(quit)
// The AfterFunc was started, wait for it to complete. if ctxErr := <-interrupt; ctxErr != nil && inputErr != nil {
<-stopc
if ctxErr := ctx.Err(); ctxErr != nil && inputErr != nil {
// Return context error to user. // Return context error to user.
*inputErr = ctxErr inputErr = &ctxErr
}
} }
} }
} }

View file

@ -1,103 +0,0 @@
package net_test
import (
"context"
"errors"
"net"
"testing"
"time"
N "github.com/metacubex/mihomo/common/net"
"github.com/stretchr/testify/assert"
)
func testRead(ctx context.Context, conn net.Conn) (err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, conn)
defer done(&err)
}
_, err = conn.Read(make([]byte, 1))
return err
}
func TestSetupContextForConnWithCancel(t *testing.T) {
t.Parallel()
c1, c2 := N.Pipe()
defer c1.Close()
defer c2.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errc := make(chan error)
go func() {
errc <- testRead(ctx, c1)
}()
select {
case <-errc:
t.Fatal("conn closed before cancel")
case <-time.After(100 * time.Millisecond):
cancel()
}
select {
case err := <-errc:
assert.ErrorIs(t, err, context.Canceled)
case <-time.After(100 * time.Millisecond):
t.Fatal("conn not be canceled")
}
}
func TestSetupContextForConnWithTimeout1(t *testing.T) {
t.Parallel()
c1, c2 := N.Pipe()
defer c1.Close()
defer c2.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
errc := make(chan error)
go func() {
errc <- testRead(ctx, c1)
}()
select {
case err := <-errc:
if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Fatal("conn closed before timeout")
}
assert.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(200 * time.Millisecond):
t.Fatal("conn not be canceled")
}
}
func TestSetupContextForConnWithTimeout2(t *testing.T) {
t.Parallel()
c1, c2 := N.Pipe()
defer c1.Close()
defer c2.Close()
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
errc := make(chan error)
go func() {
errc <- testRead(ctx, c1)
}()
select {
case <-errc:
t.Fatal("conn closed before cancel")
case <-time.After(100 * time.Millisecond):
c2.Write(make([]byte, 1))
}
select {
case err := <-errc:
assert.Nil(t, ctx.Err())
assert.Nil(t, err)
case <-time.After(200 * time.Millisecond):
t.Fatal("conn not be canceled")
}
}

View file

@ -45,11 +45,7 @@ func (c *enhanceUDPConn) WaitReadFrom() (data []byte, put func(), addr net.Addr,
addr = &net.UDPAddr{IP: ip[:], Port: from.Port} addr = &net.UDPAddr{IP: ip[:], Port: from.Port}
case *syscall.SockaddrInet6: case *syscall.SockaddrInet6:
ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes
zone := "" addr = &net.UDPAddr{IP: ip[:], Port: from.Port, Zone: strconv.FormatInt(int64(from.ZoneId), 10)}
if from.ZoneId != 0 {
zone = strconv.FormatInt(int64(from.ZoneId), 10)
}
addr = &net.UDPAddr{IP: ip[:], Port: from.Port, Zone: zone}
} }
} }
// udp should not convert readN == 0 to io.EOF // udp should not convert readN == 0 to io.EOF

View file

@ -54,11 +54,7 @@ func (c *enhanceUDPConn) WaitReadFrom() (data []byte, put func(), addr net.Addr,
addr = &net.UDPAddr{IP: ip[:], Port: from.Port} addr = &net.UDPAddr{IP: ip[:], Port: from.Port}
case *windows.SockaddrInet6: case *windows.SockaddrInet6:
ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes
zone := "" addr = &net.UDPAddr{IP: ip[:], Port: from.Port, Zone: strconv.FormatInt(int64(from.ZoneId), 10)}
if from.ZoneId != 0 {
zone = strconv.FormatInt(int64(from.ZoneId), 10)
}
addr = &net.UDPAddr{IP: ip[:], Port: from.Port, Zone: zone}
} }
} }
// udp should not convert readN == 0 to io.EOF // udp should not convert readN == 0 to io.EOF

View file

@ -77,6 +77,6 @@ func (c *refConn) WriterReplaceable() bool { // Relay() will handle reference
var _ ExtendedConn = (*refConn)(nil) var _ ExtendedConn = (*refConn)(nil)
func NewRefConn(conn net.Conn, ref any) ExtendedConn { func NewRefConn(conn net.Conn, ref any) net.Conn {
return &refConn{conn: NewExtendedConn(conn), ref: ref} return &refConn{conn: NewExtendedConn(conn), ref: ref}
} }

View file

@ -0,0 +1,23 @@
package net
import (
"net"
"runtime"
"time"
)
var (
KeepAliveIdle = 0 * time.Second
KeepAliveInterval = 0 * time.Second
DisableKeepAlive = false
)
func TCPKeepAlive(c net.Conn) {
if tcp, ok := c.(*net.TCPConn); ok {
if runtime.GOOS == "android" || DisableKeepAlive {
_ = tcp.SetKeepAlive(false)
} else {
tcpKeepAlive(tcp)
}
}
}

View file

@ -0,0 +1,10 @@
//go:build !go1.23
package net
import "net"
func tcpKeepAlive(tcp *net.TCPConn) {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetKeepAlivePeriod(KeepAliveInterval)
}

View file

@ -0,0 +1,19 @@
//go:build go1.23
package net
import "net"
func tcpKeepAlive(tcp *net.TCPConn) {
config := net.KeepAliveConfig{
Enable: true,
Idle: KeepAliveIdle,
Interval: KeepAliveInterval,
}
if !SupportTCPKeepAliveCount() {
// it's recommended to set both Idle and Interval to non-negative values in conjunction with a -1
// for Count on those old Windows if you intend to customize the TCP keep-alive settings.
config.Count = -1
}
_ = tcp.SetKeepAliveConfig(config)
}

View file

@ -1,6 +1,6 @@
//go:build go1.23 && unix //go:build go1.23 && unix
package keepalive package net
func SupportTCPKeepAliveIdle() bool { func SupportTCPKeepAliveIdle() bool {
return true return true

View file

@ -2,7 +2,7 @@
// copy and modify from golang1.23's internal/syscall/windows/version_windows.go // copy and modify from golang1.23's internal/syscall/windows/version_windows.go
package keepalive package net
import ( import (
"errors" "errors"

View file

@ -3,10 +3,8 @@ package net
import ( import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/hex"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"math/big" "math/big"
@ -18,11 +16,7 @@ type Path interface {
func ParseCert(certificate, privateKey string, path Path) (tls.Certificate, error) { func ParseCert(certificate, privateKey string, path Path) (tls.Certificate, error) {
if certificate == "" && privateKey == "" { if certificate == "" && privateKey == "" {
var err error return newRandomTLSKeyPair()
certificate, privateKey, _, err = NewRandomTLSKeyPair()
if err != nil {
return tls.Certificate{}, err
}
} }
cert, painTextErr := tls.X509KeyPair([]byte(certificate), []byte(privateKey)) cert, painTextErr := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
if painTextErr == nil { if painTextErr == nil {
@ -38,10 +32,10 @@ func ParseCert(certificate, privateKey string, path Path) (tls.Certificate, erro
return cert, nil return cert, nil
} }
func NewRandomTLSKeyPair() (certificate string, privateKey string, fingerprint string, err error) { func newRandomTLSKeyPair() (tls.Certificate, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return return tls.Certificate{}, err
} }
template := x509.Certificate{SerialNumber: big.NewInt(1)} template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate( certDER, err := x509.CreateCertificate(
@ -51,15 +45,14 @@ func NewRandomTLSKeyPair() (certificate string, privateKey string, fingerprint s
&key.PublicKey, &key.PublicKey,
key) key)
if err != nil { if err != nil {
return return tls.Certificate{}, err
} }
cert, err := x509.ParseCertificate(certDER) keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil { if err != nil {
return return tls.Certificate{}, err
} }
hash := sha256.Sum256(cert.Raw) return tlsCert, nil
fingerprint = hex.EncodeToString(hash[:])
privateKey = string(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}))
certificate = string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}))
return
} }

73
common/nnip/netip.go Normal file
View file

@ -0,0 +1,73 @@
package nnip
import (
"encoding/binary"
"net"
"net/netip"
)
// IpToAddr converts the net.IP to netip.Addr.
// If slice's length is not 4 or 16, IpToAddr returns netip.Addr{}
func IpToAddr(slice net.IP) netip.Addr {
ip := slice
if len(ip) != 4 {
if ip = slice.To4(); ip == nil {
ip = slice
}
}
if addr, ok := netip.AddrFromSlice(ip); ok {
return addr
}
return netip.Addr{}
}
// UnMasked returns p's last IP address.
// If p is invalid, UnMasked returns netip.Addr{}
func UnMasked(p netip.Prefix) netip.Addr {
if !p.IsValid() {
return netip.Addr{}
}
buf := p.Addr().As16()
hi := binary.BigEndian.Uint64(buf[:8])
lo := binary.BigEndian.Uint64(buf[8:])
bits := p.Bits()
if bits <= 32 {
bits += 96
}
hi = hi | ^uint64(0)>>bits
lo = lo | ^(^uint64(0) << (128 - bits))
binary.BigEndian.PutUint64(buf[:8], hi)
binary.BigEndian.PutUint64(buf[8:], lo)
addr := netip.AddrFrom16(buf)
if p.Addr().Is4() {
return addr.Unmap()
}
return addr
}
// PrefixCompare returns an integer comparing two prefixes.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// modify from https://github.com/golang/go/issues/61642#issuecomment-1848587909
func PrefixCompare(p, p2 netip.Prefix) int {
// compare by validity, address family and prefix base address
if c := p.Masked().Addr().Compare(p2.Masked().Addr()); c != 0 {
return c
}
// compare by prefix length
f1, f2 := p.Bits(), p2.Bits()
if f1 < f2 {
return -1
}
if f1 > f2 {
return 1
}
// compare by prefix address
return p.Addr().Compare(p2.Addr())
}

View file

@ -10,7 +10,6 @@ type Observable[T any] struct {
listener map[Subscription[T]]*Subscriber[T] listener map[Subscription[T]]*Subscriber[T]
mux sync.Mutex mux sync.Mutex
done bool done bool
stopCh chan struct{}
} }
func (o *Observable[T]) process() { func (o *Observable[T]) process() {
@ -32,7 +31,6 @@ func (o *Observable[T]) close() {
for _, sub := range o.listener { for _, sub := range o.listener {
sub.Close() sub.Close()
} }
close(o.stopCh)
} }
func (o *Observable[T]) Subscribe() (Subscription[T], error) { func (o *Observable[T]) Subscribe() (Subscription[T], error) {
@ -61,7 +59,6 @@ func NewObservable[T any](iter Iterable[T]) *Observable[T] {
observable := &Observable[T]{ observable := &Observable[T]{
iterable: iter, iterable: iter,
listener: map[Subscription[T]]*Subscriber[T]{}, listener: map[Subscription[T]]*Subscriber[T]{},
stopCh: make(chan struct{}),
} }
go observable.process() go observable.process()
return observable return observable

View file

@ -70,11 +70,9 @@ func TestObservable_SubscribeClosedSource(t *testing.T) {
src := NewObservable[int](iter) src := NewObservable[int](iter)
data, _ := src.Subscribe() data, _ := src.Subscribe()
<-data <-data
select {
case <-src.stopCh: _, closed := src.Subscribe()
case <-time.After(time.Second): assert.NotNil(t, closed)
assert.Fail(t, "timeout not stop")
}
} }
func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) { func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) {

View file

@ -22,10 +22,9 @@ func sleepAndSend[T any](ctx context.Context, delay int, input T) func() (T, err
} }
func TestPicker_Basic(t *testing.T) { func TestPicker_Basic(t *testing.T) {
t.Parallel()
picker, ctx := WithContext[int](context.Background()) picker, ctx := WithContext[int](context.Background())
picker.Go(sleepAndSend(ctx, 200, 2)) picker.Go(sleepAndSend(ctx, 30, 2))
picker.Go(sleepAndSend(ctx, 100, 1)) picker.Go(sleepAndSend(ctx, 20, 1))
number := picker.Wait() number := picker.Wait()
assert.NotNil(t, number) assert.NotNil(t, number)
@ -33,9 +32,8 @@ func TestPicker_Basic(t *testing.T) {
} }
func TestPicker_Timeout(t *testing.T) { func TestPicker_Timeout(t *testing.T) {
t.Parallel()
picker, ctx := WithTimeout[int](context.Background(), time.Millisecond*5) picker, ctx := WithTimeout[int](context.Background(), time.Millisecond*5)
picker.Go(sleepAndSend(ctx, 100, 1)) picker.Go(sleepAndSend(ctx, 20, 1))
number := picker.Wait() number := picker.Wait()
assert.Equal(t, number, lo.Empty[int]()) assert.Equal(t, number, lo.Empty[int]())

View file

@ -13,6 +13,7 @@ type call[T any] struct {
type Single[T any] struct { type Single[T any] struct {
mux sync.Mutex mux sync.Mutex
last time.Time
wait time.Duration wait time.Duration
call *call[T] call *call[T]
result *Result[T] result *Result[T]
@ -21,18 +22,16 @@ type Single[T any] struct {
type Result[T any] struct { type Result[T any] struct {
Val T Val T
Err error Err error
Time time.Time
} }
// Do single.Do likes sync.singleFlight // Do single.Do likes sync.singleFlight
func (s *Single[T]) Do(fn func() (T, error)) (v T, err error, shared bool) { func (s *Single[T]) Do(fn func() (T, error)) (v T, err error, shared bool) {
s.mux.Lock() s.mux.Lock()
result := s.result now := time.Now()
if result != nil && time.Since(result.Time) < s.wait { if now.Before(s.last.Add(s.wait)) {
s.mux.Unlock() s.mux.Unlock()
return result.Val, result.Err, true return s.result.Val, s.result.Err, true
} }
s.result = nil // The result has expired, clear it
if callM := s.call; callM != nil { if callM := s.call; callM != nil {
s.mux.Unlock() s.mux.Unlock()
@ -48,19 +47,15 @@ func (s *Single[T]) Do(fn func() (T, error)) (v T, err error, shared bool) {
callM.wg.Done() callM.wg.Done()
s.mux.Lock() s.mux.Lock()
if s.call == callM { // maybe reset when fn is running
s.call = nil s.call = nil
s.result = &Result[T]{callM.val, callM.err, time.Now()} s.result = &Result[T]{callM.val, callM.err}
} s.last = now
s.mux.Unlock() s.mux.Unlock()
return callM.val, callM.err, false return callM.val, callM.err, false
} }
func (s *Single[T]) Reset() { func (s *Single[T]) Reset() {
s.mux.Lock() s.last = time.Time{}
s.call = nil
s.result = nil
s.mux.Unlock()
} }
func NewSingle[T any](wait time.Duration) *Single[T] { func NewSingle[T any](wait time.Duration) *Single[T] {

View file

@ -11,13 +11,12 @@ import (
) )
func TestBasic(t *testing.T) { func TestBasic(t *testing.T) {
t.Parallel() single := NewSingle[int](time.Millisecond * 30)
single := NewSingle[int](time.Millisecond * 200)
foo := 0 foo := 0
shardCount := atomic.NewInt32(0) shardCount := atomic.NewInt32(0)
call := func() (int, error) { call := func() (int, error) {
foo++ foo++
time.Sleep(time.Millisecond * 20) time.Sleep(time.Millisecond * 5)
return 0, nil return 0, nil
} }
@ -40,8 +39,7 @@ func TestBasic(t *testing.T) {
} }
func TestTimer(t *testing.T) { func TestTimer(t *testing.T) {
t.Parallel() single := NewSingle[int](time.Millisecond * 30)
single := NewSingle[int](time.Millisecond * 200)
foo := 0 foo := 0
callM := func() (int, error) { callM := func() (int, error) {
foo++ foo++
@ -49,7 +47,7 @@ func TestTimer(t *testing.T) {
} }
_, _, _ = single.Do(callM) _, _, _ = single.Do(callM)
time.Sleep(100 * time.Millisecond) time.Sleep(10 * time.Millisecond)
_, _, shard := single.Do(callM) _, _, shard := single.Do(callM)
assert.Equal(t, 1, foo) assert.Equal(t, 1, foo)
@ -57,8 +55,7 @@ func TestTimer(t *testing.T) {
} }
func TestReset(t *testing.T) { func TestReset(t *testing.T) {
t.Parallel() single := NewSingle[int](time.Millisecond * 30)
single := NewSingle[int](time.Millisecond * 200)
foo := 0 foo := 0
callM := func() (int, error) { callM := func() (int, error) {
foo++ foo++

View file

@ -1,30 +0,0 @@
package sockopt
import (
"net"
"syscall"
)
func RawConnReuseaddr(rc syscall.RawConn) (err error) {
var innerErr error
err = rc.Control(func(fd uintptr) {
innerErr = reuseControl(fd)
})
if innerErr != nil {
err = innerErr
}
return
}
func UDPReuseaddr(c net.PacketConn) error {
if c, ok := c.(syscall.Conn); ok {
rc, err := c.SyscallConn()
if err != nil {
return err
}
return RawConnReuseaddr(rc)
}
return nil
}

View file

@ -1,22 +0,0 @@
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
package sockopt
import (
"golang.org/x/sys/unix"
)
func reuseControl(fd uintptr) error {
e1 := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
e2 := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
if e1 != nil {
return e1
}
if e2 != nil {
return e2
}
return nil
}

View file

@ -1,9 +0,0 @@
package sockopt
import (
"golang.org/x/sys/windows"
)
func reuseControl(fd uintptr) error {
return windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1)
}

View file

@ -0,0 +1,19 @@
package sockopt
import (
"net"
"syscall"
)
func UDPReuseaddr(c *net.UDPConn) (err error) {
rc, err := c.SyscallConn()
if err != nil {
return
}
rc.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
})
return
}

View file

@ -0,0 +1,11 @@
//go:build !linux
package sockopt
import (
"net"
)
func UDPReuseaddr(c *net.UDPConn) (err error) {
return
}

View file

@ -3,7 +3,6 @@ package structure
// references: https://github.com/mitchellh/mapstructure // references: https://github.com/mitchellh/mapstructure
import ( import (
"encoding"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"reflect" "reflect"
@ -86,35 +85,8 @@ func (d *Decoder) Decode(src map[string]any, dst any) error {
return nil return nil
} }
// isNil returns true if the input is nil or a typed nil pointer.
func isNil(input any) bool {
if input == nil {
return true
}
val := reflect.ValueOf(input)
return val.Kind() == reflect.Pointer && val.IsNil()
}
func (d *Decoder) decode(name string, data any, val reflect.Value) error { func (d *Decoder) decode(name string, data any, val reflect.Value) error {
if isNil(data) {
// If the data is nil, then we don't set anything
// Maybe we should set to zero value?
return nil
}
if !reflect.ValueOf(data).IsValid() {
// If the input value is invalid, then we just set the value
// to be the zero value.
val.Set(reflect.Zero(val.Type()))
return nil
}
for {
kind := val.Kind() kind := val.Kind()
if kind == reflect.Pointer && val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
if ok, err := d.decodeTextUnmarshaller(name, data, val); ok {
return err
}
switch { switch {
case isInt(kind): case isInt(kind):
return d.decodeInt(name, data, val) return d.decodeInt(name, data, val)
@ -125,8 +97,10 @@ func (d *Decoder) decode(name string, data any, val reflect.Value) error {
} }
switch kind { switch kind {
case reflect.Pointer: case reflect.Pointer:
val = val.Elem() if val.IsNil() {
continue val.Set(reflect.New(val.Type().Elem()))
}
return d.decode(name, data, val.Elem())
case reflect.String: case reflect.String:
return d.decodeString(name, data, val) return d.decodeString(name, data, val)
case reflect.Bool: case reflect.Bool:
@ -143,7 +117,6 @@ func (d *Decoder) decode(name string, data any, val reflect.Value) error {
return fmt.Errorf("type %s not support", val.Kind().String()) return fmt.Errorf("type %s not support", val.Kind().String())
} }
} }
}
func isInt(kind reflect.Kind) bool { func isInt(kind reflect.Kind) bool {
switch kind { switch kind {
@ -580,25 +553,3 @@ func (d *Decoder) setInterface(name string, data any, val reflect.Value) (err er
val.Set(dataVal) val.Set(dataVal)
return nil return nil
} }
func (d *Decoder) decodeTextUnmarshaller(name string, data any, val reflect.Value) (bool, error) {
if !val.CanAddr() {
return false, nil
}
valAddr := val.Addr()
if !valAddr.CanInterface() {
return false, nil
}
unmarshaller, ok := valAddr.Interface().(encoding.TextUnmarshaler)
if !ok {
return false, nil
}
var str string
if err := d.decodeString(name, data, reflect.Indirect(reflect.ValueOf(&str))); err != nil {
return false, err
}
if err := unmarshaller.UnmarshalText([]byte(str)); err != nil {
return true, fmt.Errorf("cannot parse '%s' as %s: %s", name, val.Type(), err)
}
return true, nil
}

View file

@ -1,7 +1,6 @@
package structure package structure
import ( import (
"strconv"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -180,108 +179,3 @@ func TestStructure_SliceNilValueComplex(t *testing.T) {
err = decoder.Decode(rawMap, ss) err = decoder.Decode(rawMap, ss)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
func TestStructure_SliceCap(t *testing.T) {
rawMap := map[string]any{
"foo": []string{},
}
s := &struct {
Foo []string `test:"foo,omitempty"`
Bar []string `test:"bar,omitempty"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.NotNil(t, s.Foo) // structure's Decode will ensure value not nil when input has value even it was set an empty array
assert.Nil(t, s.Bar)
}
func TestStructure_Base64(t *testing.T) {
rawMap := map[string]any{
"foo": "AQID",
}
s := &struct {
Foo []byte `test:"foo"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, []byte{1, 2, 3}, s.Foo)
}
func TestStructure_Pointer(t *testing.T) {
rawMap := map[string]any{
"foo": "foo",
}
s := &struct {
Foo *string `test:"foo,omitempty"`
Bar *string `test:"bar,omitempty"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.NotNil(t, s.Foo)
assert.Equal(t, "foo", *s.Foo)
assert.Nil(t, s.Bar)
}
type num struct {
a int
}
func (n *num) UnmarshalText(text []byte) (err error) {
n.a, err = strconv.Atoi(string(text))
return
}
func TestStructure_TextUnmarshaller(t *testing.T) {
rawMap := map[string]any{
"num": "255",
"num_p": "127",
}
s := &struct {
Num num `test:"num"`
NumP *num `test:"num_p"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, 255, s.Num.a)
assert.NotNil(t, s.NumP)
assert.Equal(t, s.NumP.a, 127)
// test WeaklyTypedInput
rawMap["num"] = 256
err = decoder.Decode(rawMap, s)
assert.NotNilf(t, err, "should throw error: %#v", s)
err = weakTypeDecoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, 256, s.Num.a)
// test invalid input
rawMap["num_p"] = "abc"
err = decoder.Decode(rawMap, s)
assert.NotNilf(t, err, "should throw error: %#v", s)
}
func TestStructure_Null(t *testing.T) {
rawMap := map[string]any{
"opt": map[string]any{
"bar": nil,
},
}
s := struct {
Opt struct {
Bar string `test:"bar,optional"`
} `test:"opt,optional"`
}{}
err := decoder.Decode(rawMap, &s)
assert.Nil(t, err)
assert.Equal(t, s.Opt.Bar, "")
}

View file

@ -1,62 +0,0 @@
package utils
import (
"crypto/md5"
"encoding/hex"
"errors"
)
// HashType warps hash array inside struct
// someday can change to other hash algorithm simply
type HashType struct {
md5 [md5.Size]byte // MD5
}
func MakeHash(data []byte) HashType {
return HashType{md5.Sum(data)}
}
func (h HashType) Equal(hash HashType) bool {
return h.md5 == hash.md5
}
func (h HashType) Bytes() []byte {
return h.md5[:]
}
func (h HashType) String() string {
return hex.EncodeToString(h.Bytes())
}
func (h HashType) MarshalText() ([]byte, error) {
return []byte(h.String()), nil
}
func (h *HashType) UnmarshalText(data []byte) error {
if hex.DecodedLen(len(data)) != md5.Size {
return errors.New("invalid hash length")
}
_, err := hex.Decode(h.md5[:], data)
return err
}
func (h HashType) MarshalBinary() ([]byte, error) {
return h.md5[:], nil
}
func (h *HashType) UnmarshalBinary(data []byte) error {
if len(data) != md5.Size {
return errors.New("invalid hash length")
}
copy(h.md5[:], data)
return nil
}
func (h HashType) Len() int {
return len(h.md5)
}
func (h HashType) IsValid() bool {
var zero HashType
return h != zero
}

View file

@ -3,7 +3,6 @@ package utils
import ( import (
"errors" "errors"
"fmt" "fmt"
"sort"
"strconv" "strconv"
"strings" "strings"
@ -140,34 +139,10 @@ func (ranges IntRanges[T]) Range(f func(t T) bool) {
} }
for _, r := range ranges { for _, r := range ranges {
for i := r.Start(); i <= r.End() && i >= r.Start(); i++ { for i := r.Start(); i <= r.End(); i++ {
if !f(i) { if !f(i) {
return return
} }
if i+1 < i { // integer overflow
break
} }
} }
} }
}
func (ranges IntRanges[T]) Merge() (mergedRanges IntRanges[T]) {
if len(ranges) == 0 {
return
}
sort.Slice(ranges, func(i, j int) bool {
return ranges[i].Start() < ranges[j].Start()
})
mergedRanges = ranges[:1]
var rangeIndex int
for _, r := range ranges[1:] {
if mergedRanges[rangeIndex].End()+1 > mergedRanges[rangeIndex].End() && // integer overflow
r.Start() > mergedRanges[rangeIndex].End()+1 {
mergedRanges = append(mergedRanges, r)
rangeIndex++
} else if r.End() > mergedRanges[rangeIndex].End() {
mergedRanges[rangeIndex].end = r.End()
}
}
return
}

View file

@ -1,82 +0,0 @@
package utils
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestMergeRanges(t *testing.T) {
t.Parallel()
for _, testRange := range []struct {
ranges IntRanges[uint16]
expected IntRanges[uint16]
}{
{
ranges: IntRanges[uint16]{
NewRange[uint16](0, 1),
NewRange[uint16](1, 2),
},
expected: IntRanges[uint16]{
NewRange[uint16](0, 2),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](0, 3),
NewRange[uint16](5, 7),
NewRange[uint16](8, 9),
NewRange[uint16](10, 10),
},
expected: IntRanges[uint16]{
NewRange[uint16](0, 3),
NewRange[uint16](5, 10),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](1, 3),
NewRange[uint16](2, 6),
NewRange[uint16](8, 10),
NewRange[uint16](15, 18),
},
expected: IntRanges[uint16]{
NewRange[uint16](1, 6),
NewRange[uint16](8, 10),
NewRange[uint16](15, 18),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](1, 3),
NewRange[uint16](2, 7),
NewRange[uint16](2, 6),
},
expected: IntRanges[uint16]{
NewRange[uint16](1, 7),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](1, 3),
NewRange[uint16](2, 6),
NewRange[uint16](2, 7),
},
expected: IntRanges[uint16]{
NewRange[uint16](1, 7),
},
},
{
ranges: IntRanges[uint16]{
NewRange[uint16](1, 3),
NewRange[uint16](2, 65535),
NewRange[uint16](2, 7),
NewRange[uint16](3, 16),
},
expected: IntRanges[uint16]{
NewRange[uint16](1, 65535),
},
},
} {
assert.Equal(t, testRange.expected, testRange.ranges.Merge())
}
}

View file

@ -5,11 +5,6 @@ type Authenticator interface {
Users() []string Users() []string
} }
type AuthStore interface {
Authenticator() Authenticator
SetAuthenticator(Authenticator)
}
type AuthUser struct { type AuthUser struct {
User string User string
Pass string Pass string

View file

@ -17,6 +17,7 @@ import (
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
) )
var trustCerts []*x509.Certificate
var globalCertPool *x509.CertPool var globalCertPool *x509.CertPool
var mutex sync.RWMutex var mutex sync.RWMutex
var errNotMatch = errors.New("certificate fingerprints do not match") var errNotMatch = errors.New("certificate fingerprints do not match")
@ -29,19 +30,11 @@ var DisableSystemCa, _ = strconv.ParseBool(os.Getenv("DISABLE_SYSTEM_CA"))
func AddCertificate(certificate string) error { func AddCertificate(certificate string) error {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
if certificate == "" { if certificate == "" {
return fmt.Errorf("certificate is empty") return fmt.Errorf("certificate is empty")
} }
if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil {
if globalCertPool == nil { trustCerts = append(trustCerts, cert)
initializeCertPool()
}
if globalCertPool.AppendCertsFromPEM([]byte(certificate)) {
return nil
} else if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil {
globalCertPool.AddCert(cert)
return nil return nil
} else { } else {
return fmt.Errorf("add certificate failed") return fmt.Errorf("add certificate failed")
@ -58,6 +51,9 @@ func initializeCertPool() {
globalCertPool = x509.NewCertPool() globalCertPool = x509.NewCertPool()
} }
} }
for _, cert := range trustCerts {
globalCertPool.AddCert(cert)
}
if !DisableEmbedCa { if !DisableEmbedCa {
globalCertPool.AppendCertsFromPEM(_CaCertificates) globalCertPool.AppendCertsFromPEM(_CaCertificates)
} }
@ -66,6 +62,7 @@ func initializeCertPool() {
func ResetCertificate() { func ResetCertificate() {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
trustCerts = nil
initializeCertPool() initializeCertPool()
} }
@ -141,6 +138,7 @@ func GetTLSConfig(tlsConfig *tls.Config, fingerprint string, customCA string, cu
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlsConfig = GetGlobalTLSConfig(tlsConfig)
tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes) tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes)
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
} }

View file

@ -6,6 +6,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"github.com/metacubex/mihomo/common/nnip"
"github.com/metacubex/mihomo/component/iface" "github.com/metacubex/mihomo/component/iface"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
@ -85,14 +86,12 @@ func receiveOffer(conn net.PacketConn, id dhcpv4.TransactionID, result chan<- []
return return
} }
results := make([]netip.Addr, 0, len(dns)) dnsAddr := make([]netip.Addr, l)
for _, ip := range dns { for i := 0; i < l; i++ {
if addr, ok := netip.AddrFromSlice(ip); ok { dnsAddr[i] = nnip.IpToAddr(dns[i])
results = append(results, addr.Unmap())
}
} }
result <- results result <- dnsAddr
return return
} }

View file

@ -7,12 +7,13 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/metacubex/mihomo/component/keepalive"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
"github.com/metacubex/mihomo/log"
) )
const ( const (
@ -20,16 +21,34 @@ const (
DefaultUDPTimeout = DefaultTCPTimeout DefaultUDPTimeout = DefaultTCPTimeout
) )
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error)
var ( var (
dialMux sync.Mutex dialMux sync.Mutex
IP4PEnable bool
actualSingleStackDialContext = serialSingleStackDialContext actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext actualDualStackDialContext = serialDualStackDialContext
tcpConcurrent = false tcpConcurrent = false
fallbackTimeout = 300 * time.Millisecond fallbackTimeout = 300 * time.Millisecond
) )
func applyOptions(options ...Option) *option {
opt := &option{
interfaceName: DefaultInterface.Load(),
routingMark: int(DefaultRoutingMark.Load()),
}
for _, o := range DefaultOptions {
o(opt)
}
for _, o := range options {
o(opt)
}
return opt
}
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
opt := applyOptions(options...) opt := applyOptions(options...)
@ -59,43 +78,28 @@ func DialContext(ctx context.Context, network, address string, options ...Option
} }
func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) { func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) {
opt := applyOptions(options...) cfg := applyOptions(options...)
lc := &net.ListenConfig{} lc := &net.ListenConfig{}
if opt.addrReuse { if cfg.addrReuse {
addrReuseToListenConfig(lc) addrReuseToListenConfig(lc)
} }
if DefaultSocketHook != nil { // ignore interfaceName, routingMark when DefaultSocketHook not null (in CMFA) if DefaultSocketHook != nil { // ignore interfaceName, routingMark when DefaultSocketHook not null (in CFMA)
socketHookToListenConfig(lc) socketHookToListenConfig(lc)
} else { } else {
if opt.interfaceName == "" { if cfg.interfaceName != "" {
opt.interfaceName = DefaultInterface.Load()
}
if opt.interfaceName == "" {
if finder := DefaultInterfaceFinder.Load(); finder != nil {
opt.interfaceName = finder.FindInterfaceName(rAddrPort.Addr())
}
}
if rAddrPort.Addr().Unmap().IsLoopback() {
// avoid "The requested address is not valid in its context."
opt.interfaceName = ""
}
if opt.interfaceName != "" {
bind := bindIfaceToListenConfig bind := bindIfaceToListenConfig
if opt.fallbackBind { if cfg.fallbackBind {
bind = fallbackBindIfaceToListenConfig bind = fallbackBindIfaceToListenConfig
} }
addr, err := bind(opt.interfaceName, lc, network, address, rAddrPort) addr, err := bind(cfg.interfaceName, lc, network, address, rAddrPort)
if err != nil { if err != nil {
return nil, err return nil, err
} }
address = addr address = addr
} }
if opt.routingMark == 0 { if cfg.routingMark != 0 {
opt.routingMark = int(DefaultRoutingMark.Load()) bindMarkToListenConfig(cfg.routingMark, lc, network, address)
}
if opt.routingMark != 0 {
bindMarkToListenConfig(opt.routingMark, lc, network, address)
} }
} }
@ -121,9 +125,11 @@ func GetTcpConcurrent() bool {
return tcpConcurrent return tcpConcurrent
} }
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt option) (net.Conn, error) { func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
var address string var address string
destination, port = resolver.LookupIP4P(destination, port) if IP4PEnable {
destination, port = lookupIP4P(destination, port)
}
address = net.JoinHostPort(destination.String(), port) address = net.JoinHostPort(destination.String(), port)
netDialer := opt.netDialer netDialer := opt.netDialer
@ -138,22 +144,13 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
} }
dialer := netDialer.(*net.Dialer) dialer := netDialer.(*net.Dialer)
keepalive.SetNetDialer(dialer)
if opt.mpTcp { if opt.mpTcp {
setMultiPathTCP(dialer) setMultiPathTCP(dialer)
} }
if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA) if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CFMA)
socketHookToToDialer(dialer) socketHookToToDialer(dialer)
} else { } else {
if opt.interfaceName == "" {
opt.interfaceName = DefaultInterface.Load()
}
if opt.interfaceName == "" {
if finder := DefaultInterfaceFinder.Load(); finder != nil {
opt.interfaceName = finder.FindInterfaceName(destination)
}
}
if opt.interfaceName != "" { if opt.interfaceName != "" {
bind := bindIfaceToDialer bind := bindIfaceToDialer
if opt.fallbackBind { if opt.fallbackBind {
@ -163,9 +160,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
return nil, err return nil, err
} }
} }
if opt.routingMark == 0 {
opt.routingMark = int(DefaultRoutingMark.Load())
}
if opt.routingMark != 0 { if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination) bindMarkToDialer(opt.routingMark, dialer, network, destination)
} }
@ -177,26 +171,26 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
return dialer.DialContext(ctx, network, address) return dialer.DialContext(ctx, network, address)
} }
func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
return serialDialContext(ctx, network, ips, port, opt) return serialDialContext(ctx, network, ips, port, opt)
} }
func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt) return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
} }
func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
return parallelDialContext(ctx, network, ips, port, opt) return parallelDialContext(ctx, network, ips, port, opt)
} }
func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
if opt.prefer != 4 && opt.prefer != 6 { if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt) return parallelDialContext(ctx, network, ips, port, opt)
} }
return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt) return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
} }
func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
ipv4s, ipv6s := resolver.SortationAddr(ips) ipv4s, ipv6s := resolver.SortationAddr(ips)
if len(ipv4s) == 0 && len(ipv6s) == 0 { if len(ipv4s) == 0 && len(ipv6s) == 0 {
return nil, ErrorNoIpAddress return nil, ErrorNoIpAddress
@ -277,7 +271,7 @@ loop:
return nil, errors.Join(errs...) return nil, errors.Join(errs...)
} }
func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
if len(ips) == 0 { if len(ips) == 0 {
return nil, ErrorNoIpAddress return nil, ErrorNoIpAddress
} }
@ -316,7 +310,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr,
return nil, os.ErrDeadlineExceeded return nil, os.ErrDeadlineExceeded
} }
func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
if len(ips) == 0 { if len(ips) == 0 {
return nil, ErrorNoIpAddress return nil, ErrorNoIpAddress
} }
@ -344,19 +338,27 @@ func parseAddr(ctx context.Context, network, address string, preferResolver reso
return nil, "-1", err return nil, "-1", err
} }
if preferResolver == nil {
preferResolver = resolver.ProxyServerHostResolver
}
var ips []netip.Addr var ips []netip.Addr
switch network { switch network {
case "tcp4", "udp4": case "tcp4", "udp4":
if preferResolver == nil {
ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host)
} else {
ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver) ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver)
}
case "tcp6", "udp6": case "tcp6", "udp6":
if preferResolver == nil {
ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host)
} else {
ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver) ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver)
}
default: default:
if preferResolver == nil {
ips, err = resolver.LookupIPProxyServerHost(ctx, host)
} else {
ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver) ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver)
} }
}
if err != nil { if err != nil {
return nil, "-1", fmt.Errorf("dns resolve failed: %w", err) return nil, "-1", fmt.Errorf("dns resolve failed: %w", err)
} }
@ -377,10 +379,33 @@ func (d Dialer) DialContext(ctx context.Context, network, address string) (net.C
} }
func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, rAddrPort, WithOption(d.Opt)) opt := d.Opt // make a copy
if rAddrPort.Addr().Unmap().IsLoopback() {
// avoid "The requested address is not valid in its context."
WithInterface("")(&opt)
}
return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, rAddrPort, WithOption(opt))
} }
func NewDialer(options ...Option) Dialer { func NewDialer(options ...Option) Dialer {
opt := applyOptions(options...) opt := applyOptions(options...)
return Dialer{Opt: opt} return Dialer{Opt: *opt}
}
func GetIP4PEnable(enableIP4PConvert bool) {
IP4PEnable = enableIP4PConvert
}
// kanged from https://github.com/heiher/frp/blob/ip4p/client/ip4p.go
func lookupIP4P(addr netip.Addr, port string) (netip.Addr, string) {
ip := addr.AsSlice()
if ip[0] == 0x20 && ip[1] == 0x01 &&
ip[2] == 0x00 && ip[3] == 0x00 {
addr = netip.AddrFrom4([4]byte{ip[12], ip[13], ip[14], ip[15]})
port = strconv.Itoa(int(ip[10])<<8 + int(ip[11]))
log.Debugln("Convert IP4P address %s to %s", ip, net.JoinHostPort(addr.String(), port))
return addr, port
}
return addr, port
} }

View file

@ -3,23 +3,17 @@ package dialer
import ( import (
"context" "context"
"net" "net"
"net/netip"
"github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
) )
var ( var (
DefaultOptions []Option
DefaultInterface = atomic.NewTypedValue[string]("") DefaultInterface = atomic.NewTypedValue[string]("")
DefaultRoutingMark = atomic.NewInt32(0) DefaultRoutingMark = atomic.NewInt32(0)
DefaultInterfaceFinder = atomic.NewTypedValue[InterfaceFinder](nil)
) )
type InterfaceFinder interface {
FindInterfaceName(destination netip.Addr) string
}
type NetDialer interface { type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
@ -114,15 +108,3 @@ func WithOption(o option) Option {
*opt = o *opt = o
} }
} }
func IsZeroOptions(opts []Option) bool {
return applyOptions(opts...) == option{}
}
func applyOptions(options ...Option) option {
opt := option{}
for _, o := range options {
o(&opt)
}
return opt
}

View file

@ -0,0 +1,29 @@
package dialer
import (
"context"
"net"
"net/netip"
)
func init() {
// We must use this DialContext to query DNS
// when using net default resolver.
net.DefaultResolver.PreferGo = true
net.DefaultResolver.Dial = resolverDialContext
}
func resolverDialContext(ctx context.Context, network, address string) (net.Conn, error) {
d := &net.Dialer{}
interfaceName := DefaultInterface.Load()
if interfaceName != "" {
dstIP, err := netip.ParseAddr(address)
if err == nil {
_ = bindIfaceToDialer(interfaceName, d, network, dstIP)
}
}
return d.DialContext(ctx, network, address)
}

View file

@ -1,5 +1,9 @@
//go:build !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows //go:build !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows
package sockopt package dialer
func reuseControl(fd uintptr) error { return nil } import (
"net"
)
func addrReuseToListenConfig(*net.ListenConfig) {}

View file

@ -0,0 +1,20 @@
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
package dialer
import (
"context"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func addrReuseToListenConfig(lc *net.ListenConfig) {
addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
})
}

View file

@ -5,11 +5,13 @@ import (
"net" "net"
"syscall" "syscall"
"github.com/metacubex/mihomo/common/sockopt" "golang.org/x/sys/windows"
) )
func addrReuseToListenConfig(lc *net.ListenConfig) { func addrReuseToListenConfig(lc *net.ListenConfig) {
addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error { addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error {
return sockopt.RawConnReuseaddr(c) return c.Control(func(fd uintptr) {
windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1)
})
}) })
} }

View file

@ -7,11 +7,11 @@ import (
) )
// SocketControl // SocketControl
// never change type traits because it's used in CMFA // never change type traits because it's used in CFMA
type SocketControl func(network, address string, conn syscall.RawConn) error type SocketControl func(network, address string, conn syscall.RawConn) error
// DefaultSocketHook // DefaultSocketHook
// never change type traits because it's used in CMFA // never change type traits because it's used in CFMA
var DefaultSocketHook SocketControl var DefaultSocketHook SocketControl
func socketHookToToDialer(dialer *net.Dialer) { func socketHookToToDialer(dialer *net.Dialer) {

View file

@ -7,32 +7,46 @@ import (
) )
type cachefileStore struct { type cachefileStore struct {
cache *cachefile.FakeIpStore cache *cachefile.CacheFile
} }
// GetByHost implements store.GetByHost // GetByHost implements store.GetByHost
func (c *cachefileStore) GetByHost(host string) (netip.Addr, bool) { func (c *cachefileStore) GetByHost(host string) (netip.Addr, bool) {
return c.cache.GetByHost(host) elm := c.cache.GetFakeip([]byte(host))
if elm == nil {
return netip.Addr{}, false
}
if len(elm) == 4 {
return netip.AddrFrom4(*(*[4]byte)(elm)), true
} else {
return netip.AddrFrom16(*(*[16]byte)(elm)), true
}
} }
// PutByHost implements store.PutByHost // PutByHost implements store.PutByHost
func (c *cachefileStore) PutByHost(host string, ip netip.Addr) { func (c *cachefileStore) PutByHost(host string, ip netip.Addr) {
c.cache.PutByHost(host, ip) c.cache.PutFakeip([]byte(host), ip.AsSlice())
} }
// GetByIP implements store.GetByIP // GetByIP implements store.GetByIP
func (c *cachefileStore) GetByIP(ip netip.Addr) (string, bool) { func (c *cachefileStore) GetByIP(ip netip.Addr) (string, bool) {
return c.cache.GetByIP(ip) elm := c.cache.GetFakeip(ip.AsSlice())
if elm == nil {
return "", false
}
return string(elm), true
} }
// PutByIP implements store.PutByIP // PutByIP implements store.PutByIP
func (c *cachefileStore) PutByIP(ip netip.Addr, host string) { func (c *cachefileStore) PutByIP(ip netip.Addr, host string) {
c.cache.PutByIP(ip, host) c.cache.PutFakeip(ip.AsSlice(), []byte(host))
} }
// DelByIP implements store.DelByIP // DelByIP implements store.DelByIP
func (c *cachefileStore) DelByIP(ip netip.Addr) { func (c *cachefileStore) DelByIP(ip netip.Addr) {
c.cache.DelByIP(ip) addr := ip.AsSlice()
c.cache.DelFakeipPair(addr, c.cache.GetFakeip(addr))
} }
// Exist implements store.Exist // Exist implements store.Exist
@ -49,7 +63,3 @@ func (c *cachefileStore) CloneTo(store store) {}
func (c *cachefileStore) FlushFakeIP() error { func (c *cachefileStore) FlushFakeIP() error {
return c.cache.FlushFakeIP() return c.cache.FlushFakeIP()
} }
func newCachefileStore(cache *cachefile.CacheFile) *cachefileStore {
return &cachefileStore{cache.FakeIpStore()}
}

View file

@ -67,9 +67,8 @@ func (m *memoryStore) CloneTo(store store) {
// FlushFakeIP implements store.FlushFakeIP // FlushFakeIP implements store.FlushFakeIP
func (m *memoryStore) FlushFakeIP() error { func (m *memoryStore) FlushFakeIP() error {
m.cacheIP.Clear() _ = m.cacheIP.Clear()
m.cacheHost.Clear() return m.cacheHost.Clear()
return nil
} }
func newMemoryStore(size int) *memoryStore { func newMemoryStore(size int) *memoryStore {

View file

@ -6,10 +6,9 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/metacubex/mihomo/common/nnip"
"github.com/metacubex/mihomo/component/profile/cachefile" "github.com/metacubex/mihomo/component/profile/cachefile"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"go4.org/netipx"
) )
const ( const (
@ -184,7 +183,7 @@ func New(options Options) (*Pool, error) {
hostAddr = options.IPNet.Masked().Addr() hostAddr = options.IPNet.Masked().Addr()
gateway = hostAddr.Next() gateway = hostAddr.Next()
first = gateway.Next().Next().Next() // default start with 198.18.0.4 first = gateway.Next().Next().Next() // default start with 198.18.0.4
last = netipx.PrefixLastIP(options.IPNet) last = nnip.UnMasked(options.IPNet)
) )
if !options.IPNet.IsValid() || !first.IsValid() || !first.Less(last) { if !options.IPNet.IsValid() || !first.IsValid() || !first.Less(last) {
@ -202,7 +201,9 @@ func New(options Options) (*Pool, error) {
ipnet: options.IPNet, ipnet: options.IPNet,
} }
if options.Persistence { if options.Persistence {
pool.store = newCachefileStore(cachefile.Cache()) pool.store = &cachefileStore{
cache: cachefile.Cache(),
}
} else { } else {
pool.store = newMemoryStore(options.Size) pool.store = newMemoryStore(options.Size)
} }

View file

@ -43,7 +43,9 @@ func createCachefileStore(options Options) (*Pool, string, error) {
return nil, "", err return nil, "", err
} }
pool.store = newCachefileStore(&cachefile.CacheFile{DB: db}) pool.store = &cachefileStore{
cache: &cachefile.CacheFile{DB: db},
}
return pool, f.Name(), nil return pool, f.Name(), nil
} }
@ -61,13 +63,13 @@ func TestPool_Basic(t *testing.T) {
last := pool.Lookup("bar.com") last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last) bar, exist := pool.LookBack(last)
assert.Equal(t, first, netip.AddrFrom4([4]byte{192, 168, 0, 4})) assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.Equal(t, pool.Lookup("foo.com"), netip.AddrFrom4([4]byte{192, 168, 0, 4})) assert.True(t, pool.Lookup("foo.com") == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.Equal(t, last, netip.AddrFrom4([4]byte{192, 168, 0, 5})) assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 5}))
assert.True(t, exist) assert.True(t, exist)
assert.Equal(t, bar, "bar.com") assert.Equal(t, bar, "bar.com")
assert.Equal(t, pool.Gateway(), netip.AddrFrom4([4]byte{192, 168, 0, 1})) assert.True(t, pool.Gateway() == netip.AddrFrom4([4]byte{192, 168, 0, 1}))
assert.Equal(t, pool.Broadcast(), netip.AddrFrom4([4]byte{192, 168, 0, 15})) assert.True(t, pool.Broadcast() == netip.AddrFrom4([4]byte{192, 168, 0, 15}))
assert.Equal(t, pool.IPNet().String(), ipnet.String()) assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 5}))) assert.True(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 5})))
assert.False(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 6}))) assert.False(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 6})))
@ -89,13 +91,13 @@ func TestPool_BasicV6(t *testing.T) {
last := pool.Lookup("bar.com") last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last) bar, exist := pool.LookBack(last)
assert.Equal(t, first, netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804")) assert.True(t, first == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.Equal(t, pool.Lookup("foo.com"), netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804")) assert.True(t, pool.Lookup("foo.com") == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.Equal(t, last, netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805")) assert.True(t, last == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805"))
assert.True(t, exist) assert.True(t, exist)
assert.Equal(t, bar, "bar.com") assert.Equal(t, bar, "bar.com")
assert.Equal(t, pool.Gateway(), netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8801")) assert.True(t, pool.Gateway() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8801"))
assert.Equal(t, pool.Broadcast(), netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8bff")) assert.True(t, pool.Broadcast() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8bff"))
assert.Equal(t, pool.IPNet().String(), ipnet.String()) assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805"))) assert.True(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805")))
assert.False(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8806"))) assert.False(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8806")))
@ -141,8 +143,8 @@ func TestPool_CycleUsed(t *testing.T) {
} }
baz := pool.Lookup("baz.com") baz := pool.Lookup("baz.com")
next := pool.Lookup("foo.com") next := pool.Lookup("foo.com")
assert.Equal(t, foo, baz) assert.True(t, foo == baz)
assert.Equal(t, next, bar) assert.True(t, next == bar)
} }
} }
@ -199,7 +201,7 @@ func TestPool_MaxCacheSize(t *testing.T) {
pool.Lookup("baz.com") pool.Lookup("baz.com")
next := pool.Lookup("foo.com") next := pool.Lookup("foo.com")
assert.NotEqual(t, first, next) assert.False(t, first == next)
} }
func TestPool_DoubleMapping(t *testing.T) { func TestPool_DoubleMapping(t *testing.T) {
@ -229,7 +231,7 @@ func TestPool_DoubleMapping(t *testing.T) {
assert.False(t, bazExist) assert.False(t, bazExist)
assert.True(t, barExist) assert.True(t, barExist)
assert.NotEqual(t, bazIP, newBazIP) assert.False(t, bazIP == newBazIP)
} }
func TestPool_Clone(t *testing.T) { func TestPool_Clone(t *testing.T) {
@ -241,8 +243,8 @@ func TestPool_Clone(t *testing.T) {
first := pool.Lookup("foo.com") first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com") last := pool.Lookup("bar.com")
assert.Equal(t, first, netip.AddrFrom4([4]byte{192, 168, 0, 4})) assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.Equal(t, last, netip.AddrFrom4([4]byte{192, 168, 0, 5})) assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 5}))
newPool, _ := New(Options{ newPool, _ := New(Options{
IPNet: ipnet, IPNet: ipnet,
@ -287,13 +289,13 @@ func TestPool_FlushFileCache(t *testing.T) {
baz := pool.Lookup("foo.com") baz := pool.Lookup("foo.com")
nero := pool.Lookup("foo.com") nero := pool.Lookup("foo.com")
assert.Equal(t, foo, fox) assert.True(t, foo == fox)
assert.Equal(t, foo, next) assert.True(t, foo == next)
assert.NotEqual(t, foo, baz) assert.False(t, foo == baz)
assert.Equal(t, bar, bax) assert.True(t, bar == bax)
assert.Equal(t, bar, baz) assert.True(t, bar == baz)
assert.NotEqual(t, bar, next) assert.False(t, bar == next)
assert.Equal(t, baz, nero) assert.True(t, baz == nero)
} }
} }
@ -316,11 +318,11 @@ func TestPool_FlushMemoryCache(t *testing.T) {
baz := pool.Lookup("foo.com") baz := pool.Lookup("foo.com")
nero := pool.Lookup("foo.com") nero := pool.Lookup("foo.com")
assert.Equal(t, foo, fox) assert.True(t, foo == fox)
assert.Equal(t, foo, next) assert.True(t, foo == next)
assert.NotEqual(t, foo, baz) assert.False(t, foo == baz)
assert.Equal(t, bar, bax) assert.True(t, bar == bax)
assert.Equal(t, bar, baz) assert.True(t, bar == baz)
assert.NotEqual(t, bar, next) assert.False(t, bar == next)
assert.Equal(t, baz, nero) assert.True(t, baz == nero)
} }

View file

@ -1,37 +0,0 @@
package generater
import (
"encoding/base64"
"fmt"
"github.com/gofrs/uuid/v5"
)
func Main(args []string) {
if len(args) < 1 {
panic("Using: generate uuid/reality-keypair/wg-keypair")
}
switch args[0] {
case "uuid":
newUUID, err := uuid.NewV4()
if err != nil {
panic(err)
}
fmt.Println(newUUID.String())
case "reality-keypair":
privateKey, err := GeneratePrivateKey()
if err != nil {
panic(err)
}
publicKey := privateKey.PublicKey()
fmt.Println("PrivateKey: " + base64.RawURLEncoding.EncodeToString(privateKey[:]))
fmt.Println("PublicKey: " + base64.RawURLEncoding.EncodeToString(publicKey[:]))
case "wg-keypair":
privateKey, err := GeneratePrivateKey()
if err != nil {
panic(err)
}
fmt.Println("PrivateKey: " + privateKey.String())
fmt.Println("PublicKey: " + privateKey.PublicKey().String())
}
}

View file

@ -1,97 +0,0 @@
// Copy from https://github.com/WireGuard/wgctrl-go/blob/a9ab2273dd1075ea74b88c76f8757f8b4003fcbf/wgtypes/types.go#L71-L155
package generater
import (
"crypto/rand"
"encoding/base64"
"fmt"
"golang.org/x/crypto/curve25519"
)
// KeyLen is the expected key length for a WireGuard key.
const KeyLen = 32 // wgh.KeyLen
// A Key is a public, private, or pre-shared secret key. The Key constructor
// functions in this package can be used to create Keys suitable for each of
// these applications.
type Key [KeyLen]byte
// GenerateKey generates a Key suitable for use as a pre-shared secret key from
// a cryptographically safe source.
//
// The output Key should not be used as a private key; use GeneratePrivateKey
// instead.
func GenerateKey() (Key, error) {
b := make([]byte, KeyLen)
if _, err := rand.Read(b); err != nil {
return Key{}, fmt.Errorf("wgtypes: failed to read random bytes: %v", err)
}
return NewKey(b)
}
// GeneratePrivateKey generates a Key suitable for use as a private key from a
// cryptographically safe source.
func GeneratePrivateKey() (Key, error) {
key, err := GenerateKey()
if err != nil {
return Key{}, err
}
// Modify random bytes using algorithm described at:
// https://cr.yp.to/ecdh.html.
key[0] &= 248
key[31] &= 127
key[31] |= 64
return key, nil
}
// NewKey creates a Key from an existing byte slice. The byte slice must be
// exactly 32 bytes in length.
func NewKey(b []byte) (Key, error) {
if len(b) != KeyLen {
return Key{}, fmt.Errorf("wgtypes: incorrect key size: %d", len(b))
}
var k Key
copy(k[:], b)
return k, nil
}
// ParseKey parses a Key from a base64-encoded string, as produced by the
// Key.String method.
func ParseKey(s string) (Key, error) {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return Key{}, fmt.Errorf("wgtypes: failed to parse base64-encoded key: %v", err)
}
return NewKey(b)
}
// PublicKey computes a public key from the private key k.
//
// PublicKey should only be called when k is a private key.
func (k Key) PublicKey() Key {
var (
pub [KeyLen]byte
priv = [KeyLen]byte(k)
)
// ScalarBaseMult uses the correct base value per https://cr.yp.to/ecdh.html,
// so no need to specify it.
curve25519.ScalarBaseMult(&pub, &priv)
return Key(pub)
}
// String returns the base64-encoded string representation of a Key.
//
// ParseKey can be used to produce a new Key from this string.
func (k Key) String() string {
return base64.StdEncoding.EncodeToString(k[:])
}

View file

@ -6,10 +6,8 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"sync"
"time" "time"
"github.com/metacubex/mihomo/common/atomic"
mihomoHttp "github.com/metacubex/mihomo/component/http" mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/mihomo/component/mmdb" "github.com/metacubex/mihomo/component/mmdb"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
@ -20,57 +18,36 @@ var (
initGeoSite bool initGeoSite bool
initGeoIP int initGeoIP int
initASN bool initASN bool
initGeoSiteMutex sync.Mutex
initGeoIPMutex sync.Mutex
initASNMutex sync.Mutex
geoIpEnable atomic.Bool
geoSiteEnable atomic.Bool
asnEnable atomic.Bool
geoIpUrl string
mmdbUrl string
geoSiteUrl string
asnUrl string
) )
func GeoIpUrl() string { func InitGeoSite() error {
return geoIpUrl if _, err := os.Stat(C.Path.GeoSite()); os.IsNotExist(err) {
log.Infoln("Can't find GeoSite.dat, start download")
if err := downloadGeoSite(C.Path.GeoSite()); err != nil {
return fmt.Errorf("can't download GeoSite.dat: %s", err.Error())
}
log.Infoln("Download GeoSite.dat finish")
initGeoSite = false
}
if !initGeoSite {
if err := Verify(C.GeositeName); err != nil {
log.Warnln("GeoSite.dat invalid, remove and download: %s", err)
if err := os.Remove(C.Path.GeoSite()); err != nil {
return fmt.Errorf("can't remove invalid GeoSite.dat: %s", err.Error())
}
if err := downloadGeoSite(C.Path.GeoSite()); err != nil {
return fmt.Errorf("can't download GeoSite.dat: %s", err.Error())
}
}
initGeoSite = true
}
return nil
} }
func SetGeoIpUrl(url string) { func downloadGeoSite(path string) (err error) {
geoIpUrl = url
}
func MmdbUrl() string {
return mmdbUrl
}
func SetMmdbUrl(url string) {
mmdbUrl = url
}
func GeoSiteUrl() string {
return geoSiteUrl
}
func SetGeoSiteUrl(url string) {
geoSiteUrl = url
}
func ASNUrl() string {
return asnUrl
}
func SetASNUrl(url string) {
asnUrl = url
}
func downloadToPath(url string, path string) (err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90) ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel() defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, url, http.MethodGet, nil, nil) resp, err := mihomoHttp.HttpRequest(ctx, C.GeoSiteUrl, http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
if err != nil { if err != nil {
return return
} }
@ -86,41 +63,30 @@ func downloadToPath(url string, path string) (err error) {
return err return err
} }
func InitGeoSite() error { func downloadGeoIP(path string) (err error) {
geoSiteEnable.Store(true) ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
initGeoSiteMutex.Lock() defer cancel()
defer initGeoSiteMutex.Unlock() resp, err := mihomoHttp.HttpRequest(ctx, C.GeoIpUrl, http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
if _, err := os.Stat(C.Path.GeoSite()); os.IsNotExist(err) { if err != nil {
log.Infoln("Can't find GeoSite.dat, start download") return
if err := downloadToPath(GeoSiteUrl(), C.Path.GeoSite()); err != nil {
return fmt.Errorf("can't download GeoSite.dat: %s", err.Error())
} }
log.Infoln("Download GeoSite.dat finish") defer resp.Body.Close()
initGeoSite = false
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
} }
if !initGeoSite { defer f.Close()
if err := Verify(C.GeositeName); err != nil { _, err = io.Copy(f, resp.Body)
log.Warnln("GeoSite.dat invalid, remove and download: %s", err)
if err := os.Remove(C.Path.GeoSite()); err != nil { return err
return fmt.Errorf("can't remove invalid GeoSite.dat: %s", err.Error())
}
if err := downloadToPath(GeoSiteUrl(), C.Path.GeoSite()); err != nil {
return fmt.Errorf("can't download GeoSite.dat: %s", err.Error())
}
}
initGeoSite = true
}
return nil
} }
func InitGeoIP() error { func InitGeoIP() error {
geoIpEnable.Store(true) if C.GeodataMode {
initGeoIPMutex.Lock()
defer initGeoIPMutex.Unlock()
if GeodataMode() {
if _, err := os.Stat(C.Path.GeoIP()); os.IsNotExist(err) { if _, err := os.Stat(C.Path.GeoIP()); os.IsNotExist(err) {
log.Infoln("Can't find GeoIP.dat, start download") log.Infoln("Can't find GeoIP.dat, start download")
if err := downloadToPath(GeoIpUrl(), C.Path.GeoIP()); err != nil { if err := downloadGeoIP(C.Path.GeoIP()); err != nil {
return fmt.Errorf("can't download GeoIP.dat: %s", err.Error()) return fmt.Errorf("can't download GeoIP.dat: %s", err.Error())
} }
log.Infoln("Download GeoIP.dat finish") log.Infoln("Download GeoIP.dat finish")
@ -133,7 +99,7 @@ func InitGeoIP() error {
if err := os.Remove(C.Path.GeoIP()); err != nil { if err := os.Remove(C.Path.GeoIP()); err != nil {
return fmt.Errorf("can't remove invalid GeoIP.dat: %s", err.Error()) return fmt.Errorf("can't remove invalid GeoIP.dat: %s", err.Error())
} }
if err := downloadToPath(GeoIpUrl(), C.Path.GeoIP()); err != nil { if err := downloadGeoIP(C.Path.GeoIP()); err != nil {
return fmt.Errorf("can't download GeoIP.dat: %s", err.Error()) return fmt.Errorf("can't download GeoIP.dat: %s", err.Error())
} }
} }
@ -144,7 +110,7 @@ func InitGeoIP() error {
if _, err := os.Stat(C.Path.MMDB()); os.IsNotExist(err) { if _, err := os.Stat(C.Path.MMDB()); os.IsNotExist(err) {
log.Infoln("Can't find MMDB, start download") log.Infoln("Can't find MMDB, start download")
if err := downloadToPath(MmdbUrl(), C.Path.MMDB()); err != nil { if err := mmdb.DownloadMMDB(C.Path.MMDB()); err != nil {
return fmt.Errorf("can't download MMDB: %s", err.Error()) return fmt.Errorf("can't download MMDB: %s", err.Error())
} }
} }
@ -155,7 +121,7 @@ func InitGeoIP() error {
if err := os.Remove(C.Path.MMDB()); err != nil { if err := os.Remove(C.Path.MMDB()); err != nil {
return fmt.Errorf("can't remove invalid MMDB: %s", err.Error()) return fmt.Errorf("can't remove invalid MMDB: %s", err.Error())
} }
if err := downloadToPath(MmdbUrl(), C.Path.MMDB()); err != nil { if err := mmdb.DownloadMMDB(C.Path.MMDB()); err != nil {
return fmt.Errorf("can't download MMDB: %s", err.Error()) return fmt.Errorf("can't download MMDB: %s", err.Error())
} }
} }
@ -165,12 +131,9 @@ func InitGeoIP() error {
} }
func InitASN() error { func InitASN() error {
asnEnable.Store(true)
initASNMutex.Lock()
defer initASNMutex.Unlock()
if _, err := os.Stat(C.Path.ASN()); os.IsNotExist(err) { if _, err := os.Stat(C.Path.ASN()); os.IsNotExist(err) {
log.Infoln("Can't find ASN.mmdb, start download") log.Infoln("Can't find ASN.mmdb, start download")
if err := downloadToPath(ASNUrl(), C.Path.ASN()); err != nil { if err := mmdb.DownloadASN(C.Path.ASN()); err != nil {
return fmt.Errorf("can't download ASN.mmdb: %s", err.Error()) return fmt.Errorf("can't download ASN.mmdb: %s", err.Error())
} }
log.Infoln("Download ASN.mmdb finish") log.Infoln("Download ASN.mmdb finish")
@ -182,7 +145,7 @@ func InitASN() error {
if err := os.Remove(C.Path.ASN()); err != nil { if err := os.Remove(C.Path.ASN()); err != nil {
return fmt.Errorf("can't remove invalid ASN: %s", err.Error()) return fmt.Errorf("can't remove invalid ASN: %s", err.Error())
} }
if err := downloadToPath(ASNUrl(), C.Path.ASN()); err != nil { if err := mmdb.DownloadASN(C.Path.ASN()); err != nil {
return fmt.Errorf("can't download ASN: %s", err.Error()) return fmt.Errorf("can't download ASN: %s", err.Error())
} }
} }
@ -190,15 +153,3 @@ func InitASN() error {
} }
return nil return nil
} }
func GeoIpEnable() bool {
return geoIpEnable.Load()
}
func GeoSiteEnable() bool {
return geoSiteEnable.Load()
}
func ASNEnable() bool {
return asnEnable.Load()
}

View file

@ -13,6 +13,8 @@ import (
var ( var (
geoMode bool geoMode bool
AutoUpdate bool
UpdateInterval int
geoLoaderName = "memconservative" geoLoaderName = "memconservative"
geoSiteMatcher = "succinct" geoSiteMatcher = "succinct"
) )
@ -23,6 +25,14 @@ func GeodataMode() bool {
return geoMode return geoMode
} }
func GeoAutoUpdate() bool {
return AutoUpdate
}
func GeoUpdateInterval() int {
return UpdateInterval
}
func LoaderName() string { func LoaderName() string {
return geoLoaderName return geoLoaderName
} }
@ -34,6 +44,12 @@ func SiteMatcherName() string {
func SetGeodataMode(newGeodataMode bool) { func SetGeodataMode(newGeodataMode bool) {
geoMode = newGeodataMode geoMode = newGeodataMode
} }
func SetGeoAutoUpdate(newAutoUpdate bool) {
AutoUpdate = newAutoUpdate
}
func SetGeoUpdateInterval(newGeoUpdateInterval int) {
UpdateInterval = newGeoUpdateInterval
}
func SetLoader(newLoader string) { func SetLoader(newLoader string) {
if newLoader == "memc" { if newLoader == "memc" {
@ -193,11 +209,8 @@ func LoadGeoIPMatcher(country string) (router.IPMatcher, error) {
return matcher, nil return matcher, nil
} }
func ClearGeoSiteCache() { func ClearCache() {
loadGeoSiteMatcherListSF.Reset() loadGeoSiteMatcherListSF.Reset()
loadGeoSiteMatcherSF.Reset() loadGeoSiteMatcherSF.Reset()
}
func ClearGeoIPCache() {
loadGeoIPMatcherSF.Reset() loadGeoIPMatcherSF.Reset()
} }

Some files were not shown because too many files have changed in this diff Show more