Merge branch 'Alpha' into parental_control_1

update the last code
Parental Clash 2024-09-22 18:43:15 -07:00
commit f34303bd16
187 changed files with 5232 additions and 9620 deletions

@ -1,54 +0,0 @@
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
index 06e684c7116b4..b311a5c74684b 100644
--- a/src/syscall/exec_windows.go
+++ b/src/syscall/exec_windows.go
@@ -319,17 +319,6 @@ func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle
}
}
- var maj, min, build uint32
- rtlGetNtVersionNumbers(&maj, &min, &build)
- isWin7 := maj < 6 || (maj == 6 && min <= 1)
- // NT kernel handles are divisible by 4, with the bottom 3 bits left as
- // a tag. The fully set tag correlates with the types of handles we're
- // concerned about here. Except, the kernel will interpret some
- // special handle values, like -1, -2, and so forth, so kernelbase.dll
- // checks to see that those bottom three bits are checked, but that top
- // bit is not checked.
- isLegacyWin7ConsoleHandle := func(handle Handle) bool { return isWin7 && handle&0x10000003 == 3 }
-
p, _ := GetCurrentProcess()
parentProcess := p
if sys.ParentProcess != 0 {
@@ -338,15 +327,7 @@ func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle
fd := make([]Handle, len(attr.Files))
for i := range attr.Files {
if attr.Files[i] > 0 {
- destinationProcessHandle := parentProcess
-
- // On Windows 7, console handles aren't real handles, and can only be duplicated
- // into the current process, not a parent one, which amounts to the same thing.
- if parentProcess != p && isLegacyWin7ConsoleHandle(Handle(attr.Files[i])) {
- destinationProcessHandle = p
- }
-
- err := DuplicateHandle(p, Handle(attr.Files[i]), destinationProcessHandle, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
+ err := DuplicateHandle(p, Handle(attr.Files[i]), parentProcess, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
if err != nil {
return 0, 0, err
}
@@ -377,14 +358,6 @@ func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle
fd = append(fd, sys.AdditionalInheritedHandles...)
- // On Windows 7, console handles aren't real handles, so don't pass them
- // through to PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
- for i := range fd {
- if isLegacyWin7ConsoleHandle(fd[i]) {
- fd[i] = 0
- }
- }
-
// The presence of a NULL handle in the list is enough to cause PROC_THREAD_ATTRIBUTE_HANDLE_LIST
// to treat the entire list as empty, so remove NULL handles.
j := 0
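The deleted comment above explains the behavior being reverted: on Windows 7, console handles are tagged pseudo-handles rather than real NT kernel handles. A minimal, self-contained Go sketch of that bit test, written for illustration only (it is not part of the patch):

package main

import "fmt"

type Handle uintptr

// isLegacyWin7ConsoleHandle mirrors the check the upstream commit removed:
// real NT handles are multiples of 4, while Win7 console pseudo-handles carry
// the low tag bits with the high marker bit clear, hence handle&0x10000003 == 3.
// The result is only meaningful when actually running on Windows 7.
func isLegacyWin7ConsoleHandle(isWin7 bool, h Handle) bool {
	return isWin7 && h&0x10000003 == 3
}

func main() {
	fmt.Println(isLegacyWin7ConsoleHandle(true, 0x3))  // true: console pseudo-handle
	fmt.Println(isLegacyWin7ConsoleHandle(true, 0x1c)) // false: ordinary kernel handle
	fmt.Println(isLegacyWin7ConsoleHandle(false, 0x3)) // false: not Windows 7
}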

@ -1,158 +0,0 @@
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
index 62738e2cb1a7d..d0dcc7cc71fc0 100644
--- a/src/crypto/rand/rand.go
+++ b/src/crypto/rand/rand.go
@@ -15,7 +15,7 @@ import "io"
// available, /dev/urandom otherwise.
// On OpenBSD and macOS, Reader uses getentropy(2).
// On other Unix-like systems, Reader reads from /dev/urandom.
-// On Windows systems, Reader uses the RtlGenRandom API.
+// On Windows systems, Reader uses the ProcessPrng API.
// On JS/Wasm, Reader uses the Web Crypto API.
// On WASIP1/Wasm, Reader uses random_get from wasi_snapshot_preview1.
var Reader io.Reader
diff --git a/src/crypto/rand/rand_windows.go b/src/crypto/rand/rand_windows.go
index 6c0655c72b692..7380f1f0f1e6e 100644
--- a/src/crypto/rand/rand_windows.go
+++ b/src/crypto/rand/rand_windows.go
@@ -15,11 +15,8 @@ func init() { Reader = &rngReader{} }
type rngReader struct{}
-func (r *rngReader) Read(b []byte) (n int, err error) {
- // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
- // most 1<<31-1 bytes at a time so that this works the same on 32-bit
- // and 64-bit systems.
- if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
+func (r *rngReader) Read(b []byte) (int, error) {
+ if err := windows.ProcessPrng(b); err != nil {
return 0, err
}
return len(b), nil
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
index ab4ad2ec64108..5854ca60b5cef 100644
--- a/src/internal/syscall/windows/syscall_windows.go
+++ b/src/internal/syscall/windows/syscall_windows.go
@@ -373,7 +373,7 @@ func ErrorLoadingGetTempPath2() error {
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
+//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
type FILE_ID_BOTH_DIR_INFO struct {
NextEntryOffset uint32
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
index e3f6d8d2a2208..5a587ad4f146c 100644
--- a/src/internal/syscall/windows/zsyscall_windows.go
+++ b/src/internal/syscall/windows/zsyscall_windows.go
@@ -37,13 +37,14 @@ func errnoErr(e syscall.Errno) error {
}
var (
- modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
- modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
- modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
- modpsapi = syscall.NewLazyDLL(sysdll.Add("psapi.dll"))
- moduserenv = syscall.NewLazyDLL(sysdll.Add("userenv.dll"))
- modws2_32 = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
+ modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
+ modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
+ modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
+ modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
+ modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
+ modpsapi = syscall.NewLazyDLL(sysdll.Add("psapi.dll"))
+ moduserenv = syscall.NewLazyDLL(sysdll.Add("userenv.dll"))
+ modws2_32 = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
procDuplicateTokenEx = modadvapi32.NewProc("DuplicateTokenEx")
@@ -55,7 +56,7 @@ var (
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
+ procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procGetACP = modkernel32.NewProc("GetACP")
@@ -179,12 +180,12 @@ func SetTokenInformation(tokenHandle syscall.Token, tokenInformationClass uint32
return
}
-func RtlGenRandom(buf []byte) (err error) {
+func ProcessPrng(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.Syscall(procSystemFunction036.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
+ r1, _, e1 := syscall.Syscall(procProcessPrng.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
index 8ca8d7790909e..3772a864b2ff4 100644
--- a/src/runtime/os_windows.go
+++ b/src/runtime/os_windows.go
@@ -127,15 +127,8 @@ var (
_WriteFile,
_ stdFunction
- // Use RtlGenRandom to generate cryptographically random data.
- // This approach has been recommended by Microsoft (see issue
- // 15589 for details).
- // The RtlGenRandom is not listed in advapi32.dll, instead
- // RtlGenRandom function can be found by searching for SystemFunction036.
- // Also some versions of Mingw cannot link to SystemFunction036
- // when building executable as Cgo. So load SystemFunction036
- // manually during runtime startup.
- _RtlGenRandom stdFunction
+ // Use ProcessPrng to generate cryptographically random data.
+ _ProcessPrng stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -151,11 +144,11 @@ var (
)
var (
- advapi32dll = [...]uint16{'a', 'd', 'v', 'a', 'p', 'i', '3', '2', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
- ws2_32dll = [...]uint16{'w', 's', '2', '_', '3', '2', '.', 'd', 'l', 'l', 0}
+ bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
+ ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
+ powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
+ winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
+ ws2_32dll = [...]uint16{'w', 's', '2', '_', '3', '2', '.', 'd', 'l', 'l', 0}
)
// Function to be called by windows CreateThread
@@ -251,11 +244,11 @@ func windowsLoadSystemLib(name []uint16) uintptr {
}
func loadOptionalSyscalls() {
- a32 := windowsLoadSystemLib(advapi32dll[:])
- if a32 == 0 {
- throw("advapi32.dll not found")
+ bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
+ if bcryptPrimitives == 0 {
+ throw("bcryptprimitives.dll not found")
}
- _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
+ _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
n32 := windowsLoadSystemLib(ntdlldll[:])
if n32 == 0 {
@@ -531,7 +524,7 @@ func osinit() {
//go:nosplit
func readRandom(r []byte) int {
n := 0
- if stdcall2(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall2(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
return n
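For orientation, a hedged sketch of the call pattern this patch introduces when applied forward: loading ProcessPrng lazily from bcryptprimitives.dll and filling a buffer, mirroring the generated zsyscall wrapper above. Windows-only; the wrapper names follow the diff, but the program itself is illustrative.

//go:build windows

package main

import (
	"fmt"
	"syscall"
	"unsafe"
)

var (
	modbcryptprimitives = syscall.NewLazyDLL("bcryptprimitives.dll")
	procProcessPrng     = modbcryptprimitives.NewProc("ProcessPrng")
)

// processPrng fills buf with cryptographically random bytes via ProcessPrng,
// which returns a nonzero value on success.
func processPrng(buf []byte) error {
	if len(buf) == 0 {
		return nil
	}
	r1, _, e1 := procProcessPrng.Call(uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)))
	if r1 == 0 {
		return e1
	}
	return nil
}

func main() {
	b := make([]byte, 16)
	if err := processPrng(b); err != nil {
		fmt.Println("ProcessPrng failed:", err)
		return
	}
	fmt.Printf("%x\n", b)
}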

@ -1,162 +0,0 @@
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
index ab8656cbbf343..28c49cc6de7e7 100644
--- a/src/net/hook_windows.go
+++ b/src/net/hook_windows.go
@@ -14,7 +14,6 @@ var (
testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349
// Placeholders for socket system calls.
- socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
index 0197feb3f199a..967ce6795aedb 100644
--- a/src/net/internal/socktest/main_test.go
+++ b/src/net/internal/socktest/main_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1
+//go:build !js && !plan9 && !wasip1 && !windows
package socktest_test
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
deleted file mode 100644
index df1cb97784b51..0000000000000
--- a/src/net/internal/socktest/main_windows_test.go
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2015 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package socktest_test
-
-import "syscall"
-
-var (
- socketFunc func(int, int, int) (syscall.Handle, error)
- closeFunc func(syscall.Handle) error
-)
-
-func installTestHooks() {
- socketFunc = sw.Socket
- closeFunc = sw.Closesocket
-}
-
-func uninstallTestHooks() {
- socketFunc = syscall.Socket
- closeFunc = syscall.Closesocket
-}
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
index 8c1c862f33c9b..1c42e5c7f34b7 100644
--- a/src/net/internal/socktest/sys_windows.go
+++ b/src/net/internal/socktest/sys_windows.go
@@ -9,38 +9,6 @@ import (
"syscall"
)
-// Socket wraps syscall.Socket.
-func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
- sw.once.Do(sw.init)
-
- so := &Status{Cookie: cookie(family, sotype, proto)}
- sw.fmu.RLock()
- f, _ := sw.fltab[FilterSocket]
- sw.fmu.RUnlock()
-
- af, err := f.apply(so)
- if err != nil {
- return syscall.InvalidHandle, err
- }
- s, so.Err = syscall.Socket(family, sotype, proto)
- if err = af.apply(so); err != nil {
- if so.Err == nil {
- syscall.Closesocket(s)
- }
- return syscall.InvalidHandle, err
- }
-
- sw.smu.Lock()
- defer sw.smu.Unlock()
- if so.Err != nil {
- sw.stats.getLocked(so.Cookie).OpenFailed++
- return syscall.InvalidHandle, so.Err
- }
- nso := sw.addLocked(s, family, sotype, proto)
- sw.stats.getLocked(nso.Cookie).Opened++
- return s, nil
-}
-
// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
index 07f21b72eb1fc..bc024c0bbd82d 100644
--- a/src/net/main_windows_test.go
+++ b/src/net/main_windows_test.go
@@ -8,7 +8,6 @@ import "internal/poll"
var (
// Placeholders for saving original socket system calls.
- origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -18,7 +17,6 @@ var (
)
func installTestHooks() {
- socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -28,7 +26,6 @@ func installTestHooks() {
}
func uninstallTestHooks() {
- socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
index fa11c7af2e727..5540135a2c43e 100644
--- a/src/net/sock_windows.go
+++ b/src/net/sock_windows.go
@@ -19,21 +19,6 @@ func maxListenerBacklog() int {
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
- if err == nil {
- return s, nil
- }
- // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
- // old versions of Windows, see
- // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
- // for details. Just use syscall.Socket, if windows.WSASocket failed.
-
- // See ../syscall/exec_unix.go for description of ForkLock.
- syscall.ForkLock.RLock()
- s, err = socketFunc(family, sotype, proto)
- if err == nil {
- syscall.CloseOnExec(s)
- }
- syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
index 0a93bc0a80d4e..06e684c7116b4 100644
--- a/src/syscall/exec_windows.go
+++ b/src/syscall/exec_windows.go
@@ -14,6 +14,7 @@ import (
"unsafe"
)
+// ForkLock is not used on Windows.
var ForkLock sync.RWMutex
// EscapeArg rewrites command line argument s as prescribed

@ -40,10 +40,10 @@ jobs:
- { goos: linux, goarch: arm, goarm: '5', output: armv5 }
- { goos: linux, goarch: arm, goarm: '6', output: armv6 }
- { goos: linux, goarch: arm, goarm: '7', output: armv7 }
- { goos: linux, goarch: mips, mips: hardfloat, output: mips-hardfloat }
- { goos: linux, goarch: mips, mips: softfloat, output: mips-softfloat }
- { goos: linux, goarch: mipsle, mips: hardfloat, output: mipsle-hardfloat }
- { goos: linux, goarch: mipsle, mips: softfloat, output: mipsle-softfloat }
- { goos: linux, goarch: mips, gomips: hardfloat, output: mips-hardfloat }
- { goos: linux, goarch: mips, gomips: softfloat, output: mips-softfloat }
- { goos: linux, goarch: mipsle, gomips: hardfloat, output: mipsle-hardfloat }
- { goos: linux, goarch: mipsle, gomips: softfloat, output: mipsle-softfloat }
- { goos: linux, goarch: mips64, output: mips64 }
- { goos: linux, goarch: mips64le, output: mips64le }
- { goos: linux, goarch: loong64, output: loong64-abi1, abi: '1' }
@ -67,6 +67,12 @@ jobs:
- { goos: android, goarch: arm, ndk: armv7a-linux-androideabi34, output: armv7 }
- { goos: android, goarch: arm64, ndk: aarch64-linux-android34, output: arm64-v8 }
# Go 1.22 with a special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.22/
- { goos: windows, goarch: '386', output: '386-go122', goversion: '1.22' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible-go122, goversion: '1.22' }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-go122, goversion: '1.22' }
# Go 1.21 can work on Windows 7 by reverting commit `9e4385`
# https://github.com/golang/go/issues/64622#issuecomment-1847475161
# (OR we can just use golang1.21.4, which doesn't need any patch)
@ -79,6 +85,11 @@ jobs:
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20' }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-go120, goversion: '1.20' }
# Go 1.22 is the last release that will run on macOS 10.15 Catalina. Go 1.23 will require macOS 11 Big Sur or later.
- { goos: darwin, goarch: arm64, output: arm64-go122, goversion: '1.22' }
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-compatible-go122, goversion: '1.22' }
- { goos: darwin, goarch: amd64, goamd64: v3, output: amd64-go122, goversion: '1.22' }
# Go 1.20 is the last release that will run on macOS 10.13 High Sierra or 10.14 Mojave. Go 1.21 will require macOS 10.15 Catalina or later.
- { goos: darwin, goarch: arm64, output: arm64-go120, goversion: '1.20' }
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20' }
@ -96,7 +107,7 @@ jobs:
if: ${{ matrix.jobs.goversion == '' && matrix.jobs.goarch != 'loong64' }}
uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
- name: Set up Go
if: ${{ matrix.jobs.goversion != '' && matrix.jobs.goarch != 'loong64' }}
@ -107,31 +118,52 @@ jobs:
- name: Set up Go1.22 loongarch abi1
if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '1' }}
run: |
wget -q https://github.com/xishang0128/loongarch64-golang/releases/download/1.22.0/go1.22.0.linux-amd64-abi1.tar.gz
sudo tar zxf go1.22.0.linux-amd64-abi1.tar.gz -C /usr/local
wget -q https://github.com/MetaCubeX/loongarch64-golang/releases/download/1.22.4/go1.22.4.linux-amd64-abi1.tar.gz
sudo tar zxf go1.22.4.linux-amd64-abi1.tar.gz -C /usr/local
echo "/usr/local/go/bin" >> $GITHUB_PATH
- name: Set up Go1.22 loongarch abi2
if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '2' }}
run: |
wget -q https://github.com/xishang0128/loongarch64-golang/releases/download/1.22.0/go1.22.0.linux-amd64-abi2.tar.gz
sudo tar zxf go1.22.0.linux-amd64-abi2.tar.gz -C /usr/local
wget -q https://github.com/MetaCubeX/loongarch64-golang/releases/download/1.22.4/go1.22.4.linux-amd64-abi2.tar.gz
sudo tar zxf go1.22.4.linux-amd64-abi2.tar.gz -C /usr/local
echo "/usr/local/go/bin" >> $GITHUB_PATH
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.22.x
# that means after golang1.23 release it must be changed
# this patch file only works on golang1.23.x
# that means after golang1.24 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.23/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
- name: Revert Golang1.22 commit for Windows7/8
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.23 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '' }}
run: |
cd $(go env GOROOT)
patch --verbose -R -p 1 < $GITHUB_WORKSPACE/.github/patch_go122/693def151adff1af707d82d28f55dba81ceb08e1.diff
patch --verbose -R -p 1 < $GITHUB_WORKSPACE/.github/patch_go122/7c1157f9544922e96945196b47b95664b1e39108.diff
patch --verbose -R -p 1 < $GITHUB_WORKSPACE/.github/patch_go122/48042aa09c2f878c4faa576948b07fe625c4707a.diff
curl https://github.com/MetaCubeX/go/commit/9ac42137ef6730e8b7daca016ece831297a1d75b.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/21290de8a4c91408de7c2b5b68757b1e90af49dd.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/6a31d3fa8e47ddabc10bd97bff10d9a85f4cfb76.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/69e2eed6dd0f6d815ebf15797761c13f31213dd6.diff | patch --verbose -p 1
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.22.x
# that means after golang1.23 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.22/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.22 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '1.22' }}
run: |
cd $(go env GOROOT)
curl https://github.com/MetaCubeX/go/commit/9779155f18b6556a034f7bb79fb7fb2aad1e26a9.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/ef0606261340e608017860b423ffae5c1ce78239.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/7f83badcb925a7e743188041cb6e561fc9b5b642.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/83ff9782e024cb328b690cbf0da4e7848a327f4f.diff | patch --verbose -p 1
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
- name: Revert Golang1.21 commit for Windows7/8
@ -155,13 +187,14 @@ jobs:
echo "BUILDTIME=$(date)" >> $GITHUB_ENV
echo "CGO_ENABLED=0" >> $GITHUB_ENV
echo "BUILDTAG=-extldflags --static" >> $GITHUB_ENV
echo "GOTOOLCHAIN=local" >> $GITHUB_ENV
- name: Setup NDK
if: ${{ matrix.jobs.goos == 'android' }}
uses: nttld/setup-ndk@v1
id: setup-ndk
with:
ndk-version: r26c
ndk-version: r27
- name: Set NDK path
if: ${{ matrix.jobs.goos == 'android' }}
@ -174,6 +207,8 @@ jobs:
if: ${{ matrix.jobs.test == 'test' }}
run: |
go test ./...
echo "---test with_gvisor---"
go test ./... -tags "with_gvisor" -count=1
- name: Update CA
run: |
@ -186,10 +221,10 @@ jobs:
GOOS: ${{matrix.jobs.goos}}
GOARCH: ${{matrix.jobs.goarch}}
GOAMD64: ${{matrix.jobs.goamd64}}
GOARM: ${{matrix.jobs.arm}}
GOMIPS: ${{matrix.jobs.mips}}
GOARM: ${{matrix.jobs.goarm}}
GOMIPS: ${{matrix.jobs.gomips}}
run: |
echo $CGO_ENABLED
go env
go build -v -tags "with_gvisor" -trimpath -ldflags "${BUILDTAG} -X 'github.com/metacubex/mihomo/constant.Version=${VERSION}' -X 'github.com/metacubex/mihomo/constant.BuildTime=${BUILDTIME}' -w -s -buildid="
if [ "${{matrix.jobs.goos}}" = "windows" ]; then
cp mihomo.exe mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}.exe
@ -352,18 +387,18 @@ jobs:
git fetch --tags
echo "PREVERSION=$(git describe --tags --abbrev=0 HEAD)" >> $GITHUB_ENV
- name: Merge Alpha branch into Meta
- name: Force push Alpha branch to Meta
run: |
git config --global user.email "github-actions[bot]@users.noreply.github.com"
git config --global user.name "github-actions[bot]"
git fetch origin Alpha:Alpha
git merge Alpha
git push origin Meta
git push origin Alpha:Meta --force
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Tag the commit
- name: Tag the commit on Alpha
run: |
git checkout Alpha
git tag ${{ github.event.inputs.version }}
git push origin ${{ github.event.inputs.version }}
env:

@ -13,7 +13,7 @@ WORKDIR /mihomo
COPY bin/ bin/
RUN FILE_NAME=`sh file-name.sh` && echo $FILE_NAME && \
FILE_NAME=`ls bin/ | egrep "$FILE_NAME.gz"|awk NR==1` && echo $FILE_NAME && \
mv bin/$FILE_NAME mihomo.gz && gzip -d mihomo.gz && echo "$FILE_NAME" > /mihomo-config/test
mv bin/$FILE_NAME mihomo.gz && gzip -d mihomo.gz && chmod +x mihomo && echo "$FILE_NAME" > /mihomo-config/test
FROM alpine:latest
LABEL org.opencontainers.image.source="https://github.com/MetaCubeX/mihomo"
@ -23,5 +23,4 @@ VOLUME ["/root/.config/mihomo/"]
COPY --from=builder /mihomo-config/ /root/.config/mihomo/
COPY --from=builder /mihomo/mihomo /mihomo
RUN chmod +x /mihomo
ENTRYPOINT [ "/mihomo" ]

@ -163,7 +163,3 @@ clean:
CLANG ?= clang-14
CFLAGS := -O2 -g -Wall -Werror $(CFLAGS)
ebpf: export BPF_CLANG := $(CLANG)
ebpf: export BPF_CFLAGS := $(CFLAGS)
ebpf:
cd component/ebpf/ && go generate ./...

@ -69,3 +69,5 @@ func WithDSCP(dscp uint8) Addition {
metadata.DSCP = dscp
}
}
func Placeholder(metadata *C.Metadata) {}

@ -11,6 +11,8 @@ import (
func NewHTTPS(request *http.Request, conn net.Conn, additions ...Addition) (net.Conn, *C.Metadata) {
metadata := parseHTTPAddr(request)
metadata.Type = C.HTTPS
metadata.RawSrcAddr = conn.RemoteAddr()
metadata.RawDstAddr = conn.LocalAddr()
ApplyAdditions(metadata, WithSrcAddr(conn.RemoteAddr()), WithInAddr(conn.LocalAddr()))
ApplyAdditions(metadata, additions...)
return conn, metadata

@ -3,10 +3,30 @@ package inbound
import (
"context"
"net"
"github.com/metacubex/tfo-go"
)
var (
lc = tfo.ListenConfig{
DisableTFO: true,
}
)
func SetTfo(open bool) {
lc.DisableTFO = !open
}
func Tfo() bool {
return !lc.DisableTFO
}
func SetMPTCP(open bool) {
setMultiPathTCP(getListenConfig(), open)
setMultiPathTCP(&lc.ListenConfig, open)
}
func MPTCP() bool {
return getMultiPathTCP(&lc.ListenConfig)
}
func ListenContext(ctx context.Context, network, address string) (net.Listener, error) {

@ -1,23 +0,0 @@
//go:build unix
package inbound
import (
"net"
"github.com/metacubex/tfo-go"
)
var (
lc = tfo.ListenConfig{
DisableTFO: true,
}
)
func SetTfo(open bool) {
lc.DisableTFO = !open
}
func getListenConfig() *net.ListenConfig {
return &lc.ListenConfig
}

@ -1,15 +0,0 @@
package inbound
import (
"net"
)
var (
lc = net.ListenConfig{}
)
func SetTfo(open bool) {}
func getListenConfig() *net.ListenConfig {
return &lc
}

@ -8,3 +8,7 @@ const multipathTCPAvailable = false
func setMultiPathTCP(listenConfig *net.ListenConfig, open bool) {
}
func getMultiPathTCP(listenConfig *net.ListenConfig) bool {
return false
}

@ -9,3 +9,7 @@ const multipathTCPAvailable = true
func setMultiPathTCP(listenConfig *net.ListenConfig, open bool) {
listenConfig.SetMultipathTCP(open)
}
func getMultiPathTCP(listenConfig *net.ListenConfig) bool {
return listenConfig.MultipathTCP()
}

@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"time"
@ -14,6 +15,7 @@ import (
"github.com/metacubex/quic-go/congestion"
M "github.com/sagernet/sing/common/metadata"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
@ -43,6 +45,8 @@ type Hysteria struct {
option *HysteriaOption
client *core.Client
closeCh chan struct{} // for test
}
func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
@ -51,7 +55,7 @@ func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts .
return nil, err
}
return NewConn(tcpConn, h), nil
return NewConn(CN.NewRefConn(tcpConn, h), h), nil
}
func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
@ -59,7 +63,7 @@ func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata
if err != nil {
return nil, err
}
return newPacketConn(&hyPacketConn{udpConn}, h), nil
return newPacketConn(CN.NewRefPacketConn(&hyPacketConn{udpConn}, h), h), nil
}
func (h *Hysteria) genHdc(ctx context.Context, opts ...dialer.Option) utils.PacketDialer {
@ -218,7 +222,7 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
if err != nil {
return nil, fmt.Errorf("hysteria %s create error: %w", addr, err)
}
return &Hysteria{
outbound := &Hysteria{
Base: &Base{
name: option.Name,
addr: addr,
@ -231,7 +235,19 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
},
option: &option,
client: client,
}, nil
}
runtime.SetFinalizer(outbound, closeHysteria)
return outbound, nil
}
func closeHysteria(h *Hysteria) {
if h.client != nil {
_ = h.client.Close()
}
if h.closeCh != nil {
close(h.closeCh)
}
}
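The added closeCh field and the runtime.SetFinalizer registration follow the repository's auto-close pattern: once the outbound becomes unreachable, the finalizer closes the underlying client, and the GC tests below wait on closeCh to observe it. A generic, self-contained sketch of that idea with hypothetical types (not mihomo's):

package main

import (
	"fmt"
	"runtime"
	"time"
)

type client struct{ closed chan struct{} }

func (c *client) Close() { close(c.closed) }

type outbound struct{ client *client }

// closeOutbound plays the role of closeHysteria: attached as a finalizer so the
// client is closed once the outbound is garbage collected.
func closeOutbound(o *outbound) { o.client.Close() }

func main() {
	c := &client{closed: make(chan struct{})}
	o := &outbound{client: c}
	runtime.SetFinalizer(o, closeOutbound)

	o = nil // drop the last reference
	runtime.GC()

	select {
	case <-c.closed:
		fmt.Println("client closed by finalizer")
	case <-time.After(time.Second):
		fmt.Println("finalizer has not run yet")
	}
}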
type hyPacketConn struct {

@ -38,6 +38,8 @@ type Hysteria2 struct {
option *Hysteria2Option
client *hysteria2.Client
dialer proxydialer.SingDialer
closeCh chan struct{} // for test
}
type Hysteria2Option struct {
@ -89,6 +91,9 @@ func closeHysteria2(h *Hysteria2) {
if h.client != nil {
_ = h.client.CloseWithError(errors.New("proxy removed"))
}
if h.closeCh != nil {
close(h.closeCh)
}
}
func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {

@ -0,0 +1,38 @@
package outbound
import (
"context"
"runtime"
"testing"
"time"
)
func TestHysteria2GC(t *testing.T) {
option := Hysteria2Option{}
option.Server = "127.0.0.1"
option.Ports = "200,204,401-429,501-503"
option.HopInterval = 30
option.Password = "password"
option.Obfs = "salamander"
option.ObfsPassword = "password"
option.SNI = "example.com"
option.ALPN = []string{"h3"}
hy, err := NewHysteria2(option)
if err != nil {
t.Error(err)
return
}
closeCh := make(chan struct{})
hy.closeCh = closeCh
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
hy = nil
runtime.GC()
select {
case <-closeCh:
return
case <-ctx.Done():
t.Error("timeout not GC")
}
}

@ -0,0 +1,39 @@
package outbound
import (
"context"
"runtime"
"testing"
"time"
)
func TestHysteriaGC(t *testing.T) {
option := HysteriaOption{}
option.Server = "127.0.0.1"
option.Ports = "200,204,401-429,501-503"
option.Protocol = "udp"
option.Up = "1Mbps"
option.Down = "1Mbps"
option.HopInterval = 30
option.Obfs = "salamander"
option.SNI = "example.com"
option.ALPN = []string{"h3"}
hy, err := NewHysteria(option)
if err != nil {
t.Error(err)
return
}
closeCh := make(chan struct{})
hy.closeCh = closeCh
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
hy = nil
runtime.GC()
select {
case <-closeCh:
return
case <-ctx.Done():
t.Error("timeout not GC")
}
}

@ -37,7 +37,7 @@ func NewRejectWithOption(option RejectOption) *Reject {
return &Reject{
Base: &Base{
name: option.Name,
tp: C.Direct,
tp: C.Reject,
udp: true,
},
}

@ -505,17 +505,14 @@ func NewVless(option VlessOption) (*Vless, error) {
var addons *vless.Addons
if option.Network != "ws" && len(option.Flow) >= 16 {
option.Flow = option.Flow[:16]
switch option.Flow {
case vless.XRV:
log.Warnln("To use %s, ensure your server is upgrade to Xray-core v1.8.0+", vless.XRV)
addons = &vless.Addons{
Flow: option.Flow,
}
case vless.XRO, vless.XRD, vless.XRS:
log.Fatalln("Legacy XTLS protocol %s is deprecated and no longer supported", option.Flow)
default:
if option.Flow != vless.XRV {
return nil, fmt.Errorf("unsupported xtls flow type: %s", option.Flow)
}
log.Warnln("To use %s, ensure your server is upgrade to Xray-core v1.8.0+", vless.XRV)
addons = &vless.Addons{
Flow: option.Flow,
}
}
switch option.PacketEncoding {

@ -24,19 +24,24 @@ import (
"github.com/metacubex/mihomo/dns"
"github.com/metacubex/mihomo/log"
amnezia "github.com/metacubex/amneziawg-go/device"
wireguard "github.com/metacubex/sing-wireguard"
"github.com/metacubex/wireguard-go/device"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/wireguard-go/device"
)
type wireguardGoDevice interface {
Close()
IpcSet(uapiConf string) error
}
type WireGuard struct {
*Base
bind *wireguard.ClientBind
device *device.Device
device wireguardGoDevice
tunDevice wireguard.Device
dialer proxydialer.SingDialer
resolver *dns.Resolver
@ -68,6 +73,8 @@ type WireGuardOption struct {
UDP bool `proxy:"udp,omitempty"`
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
AmneziaWGOption *AmneziaWGOption `proxy:"amnezia-wg-option,omitempty"`
Peers []WireGuardPeerOption `proxy:"peers,omitempty"`
RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"`
@ -85,6 +92,18 @@ type WireGuardPeerOption struct {
AllowedIPs []string `proxy:"allowed-ips,omitempty"`
}
type AmneziaWGOption struct {
JC int `proxy:"jc"`
JMin int `proxy:"jmin"`
JMax int `proxy:"jmax"`
S1 int `proxy:"s1"`
S2 int `proxy:"s2"`
H1 uint32 `proxy:"h1"`
H2 uint32 `proxy:"h2"`
H3 uint32 `proxy:"h3"`
H4 uint32 `proxy:"h4"`
}
type wgSingErrorHandler struct {
name string
}
@ -244,14 +263,19 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
outbound.device = device.NewDevice(context.Background(), outbound.tunDevice, outbound.bind, &device.Logger{
logger := &device.Logger{
Verbosef: func(format string, args ...interface{}) {
log.SingLogger.Debug(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
},
Errorf: func(format string, args ...interface{}) {
log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
},
}, option.Workers)
}
if option.AmneziaWGOption != nil {
outbound.device = amnezia.NewDevice(outbound.tunDevice, outbound.bind, logger, option.Workers)
} else {
outbound.device = device.NewDevice(outbound.tunDevice, outbound.bind, logger, option.Workers)
}
var has6 bool
for _, address := range outbound.localPrefixes {
@ -368,6 +392,17 @@ func (w *WireGuard) genIpcConf(ctx context.Context, updateOnly bool) (string, er
ipcConf := ""
if !updateOnly {
ipcConf += "private_key=" + w.option.PrivateKey + "\n"
if w.option.AmneziaWGOption != nil {
ipcConf += "jc=" + strconv.Itoa(w.option.AmneziaWGOption.JC) + "\n"
ipcConf += "jmin=" + strconv.Itoa(w.option.AmneziaWGOption.JMin) + "\n"
ipcConf += "jmax=" + strconv.Itoa(w.option.AmneziaWGOption.JMax) + "\n"
ipcConf += "s1=" + strconv.Itoa(w.option.AmneziaWGOption.S1) + "\n"
ipcConf += "s2=" + strconv.Itoa(w.option.AmneziaWGOption.S2) + "\n"
ipcConf += "h1=" + strconv.FormatUint(uint64(w.option.AmneziaWGOption.H1), 10) + "\n"
ipcConf += "h2=" + strconv.FormatUint(uint64(w.option.AmneziaWGOption.H2), 10) + "\n"
ipcConf += "h3=" + strconv.FormatUint(uint64(w.option.AmneziaWGOption.H3), 10) + "\n"
ipcConf += "h4=" + strconv.FormatUint(uint64(w.option.AmneziaWGOption.H4), 10) + "\n"
}
}
if len(w.option.Peers) > 0 {
for i, peer := range w.option.Peers {
@ -456,7 +491,6 @@ func closeWireGuard(w *WireGuard) {
if w.device != nil {
w.device.Close()
}
_ = common.Close(w.tunDevice)
if w.closeCh != nil {
close(w.closeCh)
}

@ -29,6 +29,7 @@ func TestWireGuardGC(t *testing.T) {
err = wg.init(ctx)
if err != nil {
t.Error(err)
return
}
// must do a small sleep before testing GC
// because it may deadlock if w.device.Close is called too soon after w.device.Start

@ -205,7 +205,6 @@ func strategyStickySessions(url string) strategyFn {
proxy := proxies[nowIdx]
if proxy.AliveForTestUrl(url) {
if nowIdx != idx {
lruCache.Delete(key)
lruCache.Set(key, nowIdx)
}
@ -215,7 +214,6 @@ func strategyStickySessions(url string) strategyFn {
}
}
lruCache.Delete(key)
lruCache.Set(key, 0)
return proxies[0]
}

@ -69,7 +69,7 @@ func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, provide
}
if groupOption.IncludeAllProviders {
groupOption.Use = append(groupOption.Use, AllProviders...)
groupOption.Use = AllProviders
}
if groupOption.IncludeAllProxies {
if groupOption.Filter != "" {
@ -88,6 +88,9 @@ func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, provide
} else {
groupOption.Proxies = append(groupOption.Proxies, AllProxies...)
}
if len(groupOption.Proxies) == 0 && len(groupOption.Use) == 0 {
groupOption.Proxies = []string{"COMPATIBLE"}
}
}
if len(groupOption.Proxies) == 0 && len(groupOption.Use) == 0 {

@ -27,6 +27,8 @@ type extraOption struct {
}
type HealthCheck struct {
ctx context.Context
ctxCancel context.CancelFunc
url string
extra map[string]*extraOption
mu sync.Mutex
@ -36,7 +38,6 @@ type HealthCheck struct {
lazy bool
expectedStatus utils.IntRanges[uint16]
lastTouch atomic.TypedValue[time.Time]
done chan struct{}
singleDo *singledo.Single[struct{}]
timeout time.Duration
}
@ -59,7 +60,7 @@ func (hc *HealthCheck) process() {
} else {
log.Debugln("Skip once health check because we are lazy")
}
case <-hc.done:
case <-hc.ctx.Done():
ticker.Stop()
hc.stop()
return
@ -146,7 +147,7 @@ func (hc *HealthCheck) check() {
_, _, _ = hc.singleDo.Do(func() (struct{}, error) {
id := utils.NewUUIDV4().String()
log.Debugln("Start New Health Checking {%s}", id)
b, _ := batch.New[bool](context.Background(), batch.WithConcurrencyNum[bool](10))
b, _ := batch.New[bool](hc.ctx, batch.WithConcurrencyNum[bool](10))
// execute default health check
option := &extraOption{filters: nil, expectedStatus: hc.expectedStatus}
@ -195,7 +196,7 @@ func (hc *HealthCheck) execute(b *batch.Batch[bool], url, uid string, option *ex
p := proxy
b.Go(p.Name(), func() (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), hc.timeout)
ctx, cancel := context.WithTimeout(hc.ctx, hc.timeout)
defer cancel()
log.Debugln("Health Checking, proxy: %s, url: %s, id: {%s}", p.Name(), url, uid)
_, _ = p.URLTest(ctx, url, expectedStatus)
@ -206,7 +207,7 @@ func (hc *HealthCheck) execute(b *batch.Batch[bool], url, uid string, option *ex
}
func (hc *HealthCheck) close() {
hc.done <- struct{}{}
hc.ctxCancel()
}
func NewHealthCheck(proxies []C.Proxy, url string, timeout uint, interval uint, lazy bool, expectedStatus utils.IntRanges[uint16]) *HealthCheck {
@ -217,8 +218,11 @@ func NewHealthCheck(proxies []C.Proxy, url string, timeout uint, interval uint,
if timeout == 0 {
timeout = 5000
}
ctx, cancel := context.WithCancel(context.Background())
return &HealthCheck{
ctx: ctx,
ctxCancel: cancel,
proxies: proxies,
url: url,
timeout: time.Duration(timeout) * time.Millisecond,
@ -226,7 +230,6 @@ func NewHealthCheck(proxies []C.Proxy, url string, timeout uint, interval uint,
interval: time.Duration(interval) * time.Second,
lazy: lazy,
expectedStatus: expectedStatus,
done: make(chan struct{}, 1),
singleDo: singledo.NewSingle[struct{}](time.Second),
}
}
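Replacing the done channel with a stored context.CancelFunc lets every per-probe timeout derive from hc.ctx, so closing the health check also cancels probes already in flight. A compact sketch of that lifecycle pattern using a hypothetical worker type:

package main

import (
	"context"
	"fmt"
	"time"
)

type worker struct {
	ctx       context.Context
	ctxCancel context.CancelFunc
}

func newWorker() *worker {
	ctx, cancel := context.WithCancel(context.Background())
	return &worker{ctx: ctx, ctxCancel: cancel}
}

func (w *worker) run(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			// Each probe inherits from w.ctx, so close() cancels it as well.
			probeCtx, cancel := context.WithTimeout(w.ctx, 100*time.Millisecond)
			<-probeCtx.Done()
			cancel()
		case <-w.ctx.Done():
			fmt.Println("worker stopped")
			return
		}
	}
}

func (w *worker) close() { w.ctxCancel() }

func main() {
	w := newWorker()
	go w.run(50 * time.Millisecond)
	time.Sleep(200 * time.Millisecond)
	w.close()
	time.Sleep(50 * time.Millisecond)
}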

@ -1,6 +1,7 @@
package provider
import (
"encoding"
"errors"
"fmt"
"time"
@ -9,8 +10,9 @@ import (
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/features"
types "github.com/metacubex/mihomo/constant/provider"
"github.com/dlclark/regexp2"
)
var (
@ -27,6 +29,15 @@ type healthCheckSchema struct {
ExpectedStatus string `provider:"expected-status,omitempty"`
}
type OverrideProxyNameSchema struct {
// matching expression for regex replacement
Pattern *regexp2.Regexp `provider:"pattern"`
// the new content after regex matching
Target string `provider:"target"`
}
var _ encoding.TextUnmarshaler = (*regexp2.Regexp)(nil) // ensure *regexp2.Regexp can be decoded directly by the structure package
type OverrideSchema struct {
TFO *bool `provider:"tfo,omitempty"`
MPTcp *bool `provider:"mptcp,omitempty"`
@ -41,6 +52,8 @@ type OverrideSchema struct {
IPVersion *string `provider:"ip-version,omitempty"`
AdditionalPrefix *string `provider:"additional-prefix,omitempty"`
AdditionalSuffix *string `provider:"additional-suffix,omitempty"`
ProxyName []OverrideProxyNameSchema `provider:"proxy-name,omitempty"`
}
type proxyProviderSchema struct {
@ -94,11 +107,11 @@ func ParseProxyProvider(name string, mapping map[string]any) (types.ProxyProvide
path := C.Path.GetPathByHash("proxies", schema.URL)
if schema.Path != "" {
path = C.Path.Resolve(schema.Path)
if !features.CMFA && !C.Path.IsSafePath(path) {
if !C.Path.IsSafePath(path) {
return nil, fmt.Errorf("%w: %s", errSubPath, path)
}
}
vehicle = resource.NewHTTPVehicle(schema.URL, path, schema.Proxy, schema.Header)
vehicle = resource.NewHTTPVehicle(schema.URL, path, schema.Proxy, schema.Header, resource.DefaultHttpTimeout)
default:
return nil, fmt.Errorf("%w: %s", errVehicleType, schema.Type)
}

@ -1,36 +0,0 @@
//go:build android && cmfa
package provider
import (
"time"
)
var (
suspended bool
)
type UpdatableProvider interface {
UpdatedAt() time.Time
}
func (pp *proxySetProvider) UpdatedAt() time.Time {
return pp.Fetcher.UpdatedAt
}
func (pp *proxySetProvider) Close() error {
pp.healthCheck.close()
pp.Fetcher.Destroy()
return nil
}
func (cp *compatibleProvider) Close() error {
cp.healthCheck.close()
return nil
}
func Suspend(s bool) {
suspended = s
}

@ -18,7 +18,6 @@ import (
"github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant"
types "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/tunnel/statistic"
"github.com/dlclark/regexp2"
@ -54,7 +53,7 @@ func (pp *proxySetProvider) MarshalJSON() ([]byte, error) {
"proxies": pp.Proxies(),
"testUrl": pp.healthCheck.url,
"expectedStatus": pp.healthCheck.expectedStatus.String(),
"updatedAt": pp.UpdatedAt,
"updatedAt": pp.UpdatedAt(),
"subscriptionInfo": pp.subscriptionInfo,
})
}
@ -72,19 +71,15 @@ func (pp *proxySetProvider) HealthCheck() {
}
func (pp *proxySetProvider) Update() error {
elm, same, err := pp.Fetcher.Update()
if err == nil && !same {
pp.OnUpdate(elm)
}
_, _, err := pp.Fetcher.Update()
return err
}
func (pp *proxySetProvider) Initial() error {
elm, err := pp.Fetcher.Initial()
_, err := pp.Fetcher.Initial()
if err != nil {
return err
}
pp.OnUpdate(elm)
pp.getSubscriptionInfo()
pp.closeAllConnections()
return nil
@ -98,6 +93,10 @@ func (pp *proxySetProvider) Proxies() []C.Proxy {
return pp.proxies
}
func (pp *proxySetProvider) Count() int {
return len(pp.proxies)
}
func (pp *proxySetProvider) Touch() {
pp.healthCheck.touch()
}
@ -125,8 +124,8 @@ func (pp *proxySetProvider) getSubscriptionInfo() {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequestWithProxy(ctx, pp.Vehicle().(*resource.HTTPVehicle).Url(),
http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil, pp.Vehicle().Proxy())
resp, err := mihomoHttp.HttpRequestWithProxy(ctx, pp.Vehicle().Url(),
http.MethodGet, nil, nil, pp.Vehicle().Proxy())
if err != nil {
return
}
@ -134,7 +133,7 @@ func (pp *proxySetProvider) getSubscriptionInfo() {
userInfoStr := strings.TrimSpace(resp.Header.Get("subscription-userinfo"))
if userInfoStr == "" {
resp2, err := mihomoHttp.HttpRequestWithProxy(ctx, pp.Vehicle().(*resource.HTTPVehicle).Url(),
resp2, err := mihomoHttp.HttpRequestWithProxy(ctx, pp.Vehicle().Url(),
http.MethodGet, http.Header{"User-Agent": {"Quantumultx"}}, nil, pp.Vehicle().Proxy())
if err != nil {
return
@ -145,10 +144,7 @@ func (pp *proxySetProvider) getSubscriptionInfo() {
return
}
}
pp.subscriptionInfo, err = NewSubscriptionInfo(userInfoStr)
if err != nil {
log.Warnln("[Provider] get subscription-userinfo: %e", err)
}
pp.subscriptionInfo = NewSubscriptionInfo(userInfoStr)
}()
}
@ -164,9 +160,9 @@ func (pp *proxySetProvider) closeAllConnections() {
})
}
func stopProxyProvider(pd *ProxySetProvider) {
pd.healthCheck.close()
_ = pd.Fetcher.Destroy()
func (pp *proxySetProvider) Close() error {
pp.healthCheck.close()
return pp.Fetcher.Close()
}
func NewProxySetProvider(name string, interval time.Duration, filter string, excludeFilter string, excludeType string, dialerProxy string, override OverrideSchema, vehicle types.Vehicle, hc *HealthCheck) (*ProxySetProvider, error) {
@ -200,10 +196,15 @@ func NewProxySetProvider(name string, interval time.Duration, filter string, exc
fetcher := resource.NewFetcher[[]C.Proxy](name, interval, vehicle, proxiesParseAndFilter(filter, excludeFilter, excludeTypeArray, filterRegs, excludeFilterReg, dialerProxy, override), proxiesOnUpdate(pd))
pd.Fetcher = fetcher
wrapper := &ProxySetProvider{pd}
runtime.SetFinalizer(wrapper, stopProxyProvider)
runtime.SetFinalizer(wrapper, (*ProxySetProvider).Close)
return wrapper, nil
}
func (pp *ProxySetProvider) Close() error {
runtime.SetFinalizer(pp, nil)
return pp.proxySetProvider.Close()
}
// CompatibleProvider for auto gc
type CompatibleProvider struct {
*compatibleProvider
@ -262,6 +263,10 @@ func (cp *compatibleProvider) Proxies() []C.Proxy {
return cp.proxies
}
func (cp *compatibleProvider) Count() int {
return len(cp.proxies)
}
func (cp *compatibleProvider) Touch() {
cp.healthCheck.touch()
}
@ -274,8 +279,9 @@ func (cp *compatibleProvider) RegisterHealthCheckTask(url string, expectedStatus
cp.healthCheck.registerHealthCheckTask(url, expectedStatus, filter, interval)
}
func stopCompatibleProvider(pd *CompatibleProvider) {
pd.healthCheck.close()
func (cp *compatibleProvider) Close() error {
cp.healthCheck.close()
return nil
}
func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*CompatibleProvider, error) {
@ -294,10 +300,15 @@ func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*Co
}
wrapper := &CompatibleProvider{pd}
runtime.SetFinalizer(wrapper, stopCompatibleProvider)
runtime.SetFinalizer(wrapper, (*CompatibleProvider).Close)
return wrapper, nil
}
func (cp *CompatibleProvider) Close() error {
runtime.SetFinalizer(cp, nil)
return cp.compatibleProvider.Close()
}
func proxiesOnUpdate(pd *proxySetProvider) func([]C.Proxy) {
return func(elm []C.Proxy) {
pd.setProxies(elm)
@ -388,6 +399,16 @@ func proxiesParseAndFilter(filter string, excludeFilter string, excludeTypeArray
case "additional-suffix":
name := mapping["name"].(string)
mapping["name"] = name + *field.Interface().(*string)
case "proxy-name":
// Iterate through all naming replacement rules and perform the replacements.
for _, expr := range override.ProxyName {
name := mapping["name"].(string)
newName, err := expr.Pattern.Replace(name, expr.Target, 0, -1)
if err != nil {
return nil, fmt.Errorf("proxy name replace error: %w", err)
}
mapping["name"] = newName
}
default:
mapping[fieldName] = field.Elem().Interface()
}
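The new proxy-name override applies each rule's regexp2 pattern to the proxy name with Replace(name, target, 0, -1), i.e. searching from the start of the string with no cap on the number of replacements. A small hedged example; the pattern and names are invented for illustration:

package main

import (
	"fmt"

	"github.com/dlclark/regexp2"
)

func main() {
	name := "HK-01 | IPLC [Premium]"
	pattern := regexp2.MustCompile(`\s*\[Premium\]`, regexp2.None)

	// Same call shape as the override loop above: start at 0, replace all matches.
	newName, err := pattern.Replace(name, "", 0, -1)
	if err != nil {
		fmt.Println("replace failed:", err)
		return
	}
	fmt.Println(newName) // "HK-01 | IPLC"
}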

@ -1,8 +1,11 @@
package provider
import (
"fmt"
"strconv"
"strings"
"github.com/metacubex/mihomo/log"
)
type SubscriptionInfo struct {
@ -12,28 +15,46 @@ type SubscriptionInfo struct {
Expire int64
}
func NewSubscriptionInfo(userinfo string) (si *SubscriptionInfo, err error) {
func NewSubscriptionInfo(userinfo string) (si *SubscriptionInfo) {
userinfo = strings.ToLower(userinfo)
userinfo = strings.ReplaceAll(userinfo, " ", "")
si = new(SubscriptionInfo)
for _, field := range strings.Split(userinfo, ";") {
switch name, value, _ := strings.Cut(field, "="); name {
case "upload":
si.Upload, err = strconv.ParseInt(value, 10, 64)
case "download":
si.Download, err = strconv.ParseInt(value, 10, 64)
case "total":
si.Total, err = strconv.ParseInt(value, 10, 64)
case "expire":
if value == "" {
si.Expire = 0
} else {
si.Expire, err = strconv.ParseInt(value, 10, 64)
}
name, value, ok := strings.Cut(field, "=")
if !ok {
continue
}
intValue, err := parseValue(value)
if err != nil {
return
log.Warnln("[Provider] get subscription-userinfo: %e", err)
continue
}
switch name {
case "upload":
si.Upload = intValue
case "download":
si.Download = intValue
case "total":
si.Total = intValue
case "expire":
si.Expire = intValue
}
}
return
return si
}
func parseValue(value string) (int64, error) {
if intValue, err := strconv.ParseInt(value, 10, 64); err == nil {
return intValue, nil
}
if floatValue, err := strconv.ParseFloat(value, 64); err == nil {
return int64(floatValue), nil
}
return 0, fmt.Errorf("failed to parse value '%s'", value)
}
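The string parsed here is the subscription-userinfo response header fetched in proxies.go (semicolon-separated key=value pairs). A self-contained sketch with invented sample values; the float fallback presumably covers feeds that report counters in fractional or scientific notation (an assumption, not stated in the diff):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseValue mirrors the tolerant parsing above: integer first, then float
// truncated to int64.
func parseValue(value string) (int64, error) {
	if intValue, err := strconv.ParseInt(value, 10, 64); err == nil {
		return intValue, nil
	}
	if floatValue, err := strconv.ParseFloat(value, 64); err == nil {
		return int64(floatValue), nil
	}
	return 0, fmt.Errorf("failed to parse value '%s'", value)
}

func main() {
	userinfo := "upload=455727941; download=6.1743e9; total=1073741824000; expire=1735689600"
	userinfo = strings.ReplaceAll(strings.ToLower(userinfo), " ", "")

	for _, field := range strings.Split(userinfo, ";") {
		name, value, ok := strings.Cut(field, "=")
		if !ok {
			continue
		}
		n, err := parseValue(value)
		fmt.Println(name, "=", n, err)
	}
}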

@ -33,15 +33,8 @@ type ARC[K comparable, V any] struct {
// New returns a new Adaptive Replacement Cache (ARC).
func New[K comparable, V any](options ...Option[K, V]) *ARC[K, V] {
arc := &ARC[K, V]{
p: 0,
t1: list.New[*entry[K, V]](),
b1: list.New[*entry[K, V]](),
t2: list.New[*entry[K, V]](),
b2: list.New[*entry[K, V]](),
len: 0,
cache: make(map[K]*entry[K, V]),
}
arc := &ARC[K, V]{}
arc.Clear()
for _, option := range options {
option(arc)
@ -49,6 +42,19 @@ func New[K comparable, V any](options ...Option[K, V]) *ARC[K, V] {
return arc
}
func (a *ARC[K, V]) Clear() {
a.mutex.Lock()
defer a.mutex.Unlock()
a.p = 0
a.t1 = list.New[*entry[K, V]]()
a.b1 = list.New[*entry[K, V]]()
a.t2 = list.New[*entry[K, V]]()
a.b2 = list.New[*entry[K, V]]()
a.len = 0
a.cache = make(map[K]*entry[K, V])
}
// Set inserts a new key-value pair into the cache.
// This optimizes future access to this entry (side effect).
func (a *ARC[K, V]) Set(key K, value V) {

@ -68,10 +68,8 @@ type LruCache[K comparable, V any] struct {
// New creates an LruCache
func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] {
lc := &LruCache[K, V]{
lru: list.New[*entry[K, V]](),
cache: make(map[K]*list.Element[*entry[K, V]]),
}
lc := &LruCache[K, V]{}
lc.Clear()
for _, option := range options {
option(lc)
@ -80,6 +78,14 @@ func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] {
return lc
}
func (c *LruCache[K, V]) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.lru = list.New[*entry[K, V]]()
c.cache = make(map[K]*list.Element[*entry[K, V]])
}
// Get returns any representation of a cached response and a bool
// set to true if the key was found.
func (c *LruCache[K, V]) Get(key K) (V, bool) {
@ -223,6 +229,10 @@ func (c *LruCache[K, V]) Delete(key K) {
c.mu.Lock()
defer c.mu.Unlock()
c.delete(key)
}
func (c *LruCache[K, V]) delete(key K) {
if le, ok := c.cache[key]; ok {
c.deleteElement(le)
}
@ -246,13 +256,32 @@ func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) {
}
}
func (c *LruCache[K, V]) Clear() error {
// Compute either sets the computed new value for the key or deletes
// the value for the key. When the delete result of the valueFn function
// is set to true, the value will be deleted, if it exists. When delete
// is set to false, the value is updated to the newValue.
// The ok result indicates whether value was computed and stored, thus, is
// present in the map. The actual result contains the new value in cases where
// the value was computed and stored.
func (c *LruCache[K, V]) Compute(
key K,
valueFn func(oldValue V, loaded bool) (newValue V, delete bool),
) (actual V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache = make(map[K]*list.Element[*entry[K, V]])
return nil
if el := c.get(key); el != nil {
actual, ok = el.value, true
}
if newValue, del := valueFn(actual, ok); del {
if ok { // data not in cache, so needn't delete
c.delete(key)
}
return lo.Empty[V](), false
} else {
c.set(key, newValue)
return newValue, true
}
}
type entry[K comparable, V any] struct {
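The Compute doc comment above defines a store-or-delete contract. Here is a self-contained illustration of that contract against a plain map (not the LruCache type, so without locking or LRU ordering):

package main

import "fmt"

// compute follows the documented semantics: valueFn either returns a new value
// to store or asks for deletion; ok reports whether a value ends up present.
func compute[K comparable, V any](m map[K]V, key K, valueFn func(old V, loaded bool) (V, bool)) (actual V, ok bool) {
	old, loaded := m[key]
	newValue, del := valueFn(old, loaded)
	if del {
		if loaded {
			delete(m, key)
		}
		var zero V
		return zero, false
	}
	m[key] = newValue
	return newValue, true
}

func main() {
	m := map[string]int{"hits": 1}

	v, ok := compute(m, "hits", func(old int, loaded bool) (int, bool) { return old + 1, false })
	fmt.Println(v, ok) // 2 true

	_, ok = compute(m, "hits", func(old int, loaded bool) (int, bool) { return 0, true })
	fmt.Println(ok, m) // false map[]
}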

@ -0,0 +1,23 @@
package net
import (
"net"
"runtime"
"time"
)
var (
KeepAliveIdle = 0 * time.Second
KeepAliveInterval = 0 * time.Second
DisableKeepAlive = false
)
func TCPKeepAlive(c net.Conn) {
if tcp, ok := c.(*net.TCPConn); ok {
if runtime.GOOS == "android" || DisableKeepAlive {
_ = tcp.SetKeepAlive(false)
} else {
tcpKeepAlive(tcp)
}
}
}

@ -0,0 +1,10 @@
//go:build !go1.23
package net
import "net"
func tcpKeepAlive(tcp *net.TCPConn) {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetKeepAlivePeriod(KeepAliveInterval)
}

@ -0,0 +1,19 @@
//go:build go1.23
package net
import "net"
func tcpKeepAlive(tcp *net.TCPConn) {
config := net.KeepAliveConfig{
Enable: true,
Idle: KeepAliveIdle,
Interval: KeepAliveInterval,
}
if !SupportTCPKeepAliveCount() {
// it's recommended to set both Idle and Interval to non-negative values in conjunction with a -1
// for Count on those old Windows if you intend to customize the TCP keep-alive settings.
config.Count = -1
}
_ = tcp.SetKeepAliveConfig(config)
}
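For comparison, a usage sketch under the same Go 1.23+ assumption: an equivalent KeepAliveConfig can be supplied up front on a net.Dialer instead of being applied afterwards by TCPKeepAlive.

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	d := net.Dialer{
		KeepAliveConfig: net.KeepAliveConfig{
			Enable:   true,
			Idle:     15 * time.Second,
			Interval: 15 * time.Second,
			// Count is left at 0 to use Go's default; the go1.23 file above
			// sets -1 where TCP_KEEPCNT cannot be customized.
		},
	}
	conn, err := d.Dial("tcp", "example.com:80")
	if err != nil {
		fmt.Println("dial failed:", err)
		return
	}
	defer conn.Close()
	fmt.Println("connected to", conn.RemoteAddr())
}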

@ -0,0 +1,15 @@
//go:build go1.23 && unix
package net
func SupportTCPKeepAliveIdle() bool {
return true
}
func SupportTCPKeepAliveInterval() bool {
return true
}
func SupportTCPKeepAliveCount() bool {
return true
}

@ -0,0 +1,63 @@
//go:build go1.23 && windows
// copy and modify from golang1.23's internal/syscall/windows/version_windows.go
package net
import (
"errors"
"sync"
"syscall"
"github.com/metacubex/mihomo/constant/features"
"golang.org/x/sys/windows"
)
var (
supportTCPKeepAliveIdle bool
supportTCPKeepAliveInterval bool
supportTCPKeepAliveCount bool
)
var initTCPKeepAlive = sync.OnceFunc(func() {
s, err := windows.WSASocket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP, nil, 0, windows.WSA_FLAG_NO_HANDLE_INHERIT)
if err != nil {
// Fallback to checking the Windows version.
major, build := features.WindowsMajorVersion, features.WindowsBuildNumber
supportTCPKeepAliveIdle = major >= 10 && build >= 16299
supportTCPKeepAliveInterval = major >= 10 && build >= 16299
supportTCPKeepAliveCount = major >= 10 && build >= 15063
return
}
defer windows.Closesocket(s)
var optSupported = func(opt int) bool {
err := windows.SetsockoptInt(s, syscall.IPPROTO_TCP, opt, 1)
return !errors.Is(err, syscall.WSAENOPROTOOPT)
}
supportTCPKeepAliveIdle = optSupported(windows.TCP_KEEPIDLE)
supportTCPKeepAliveInterval = optSupported(windows.TCP_KEEPINTVL)
supportTCPKeepAliveCount = optSupported(windows.TCP_KEEPCNT)
})
// SupportTCPKeepAliveIdle indicates whether TCP_KEEPIDLE is supported.
// The minimal requirement is Windows 10.0.16299.
func SupportTCPKeepAliveIdle() bool {
initTCPKeepAlive()
return supportTCPKeepAliveIdle
}
// SupportTCPKeepAliveInterval indicates whether TCP_KEEPINTVL is supported.
// The minimal requirement is Windows 10.0.16299.
func SupportTCPKeepAliveInterval() bool {
initTCPKeepAlive()
return supportTCPKeepAliveInterval
}
// SupportTCPKeepAliveCount indicates whether TCP_KEEPCNT is supported.
// The minimal requirement is Windows 10.0.15063.
func SupportTCPKeepAliveCount() bool {
initTCPKeepAlive()
return supportTCPKeepAliveCount
}

@ -4,11 +4,8 @@ import (
"fmt"
"net"
"strings"
"time"
)
var KeepAliveInterval = 15 * time.Second
func SplitNetworkType(s string) (string, string, error) {
var (
shecme string
@ -47,10 +44,3 @@ func SplitHostPort(s string) (host, port string, hasPort bool, err error) {
host, port, err = net.SplitHostPort(temp)
return
}
func TCPKeepAlive(c net.Conn) {
if tcp, ok := c.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetKeepAlivePeriod(KeepAliveInterval)
}
}

@ -51,3 +51,23 @@ func UnMasked(p netip.Prefix) netip.Addr {
}
return addr
}
// PrefixCompare returns an integer comparing two prefixes.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// modify from https://github.com/golang/go/issues/61642#issuecomment-1848587909
func PrefixCompare(p, p2 netip.Prefix) int {
// compare by validity, address family and prefix base address
if c := p.Masked().Addr().Compare(p2.Masked().Addr()); c != 0 {
return c
}
// compare by prefix length
f1, f2 := p.Bits(), p2.Bits()
if f1 < f2 {
return -1
}
if f1 > f2 {
return 1
}
// compare by prefix address
return p.Addr().Compare(p2.Addr())
}

@ -59,8 +59,8 @@ func (q *Queue[T]) Copy() []T {
// Len returns the number of items in this queue.
func (q *Queue[T]) Len() int64 {
q.lock.Lock()
defer q.lock.Unlock()
q.lock.RLock()
defer q.lock.RUnlock()
return int64(len(q.items))
}

@ -0,0 +1,224 @@
// copy and modify from "golang.org/x/sync/singleflight"
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression
// mechanism.
package singleflight
import (
"bytes"
"errors"
"fmt"
"runtime"
"runtime/debug"
"sync"
)
// errGoexit indicates the runtime.Goexit was called in
// the user given function.
var errGoexit = errors.New("runtime.Goexit was called")
// A panicError is an arbitrary value recovered from a panic
// with the stack trace during the execution of given function.
type panicError struct {
value interface{}
stack []byte
}
// Error implements error interface.
func (p *panicError) Error() string {
return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
}
func (p *panicError) Unwrap() error {
err, ok := p.value.(error)
if !ok {
return nil
}
return err
}
func newPanicError(v interface{}) error {
stack := debug.Stack()
// The first line of the stack trace is of the form "goroutine N [status]:"
// but by the time the panic reaches Do the goroutine may no longer exist
// and its status will have changed. Trim out the misleading line.
if line := bytes.IndexByte(stack[:], '\n'); line >= 0 {
stack = stack[line+1:]
}
return &panicError{value: v, stack: stack}
}
// call is an in-flight or completed singleflight.Do call
type call[T any] struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val T
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
chans []chan<- Result[T]
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group[T any] struct {
mu sync.Mutex // protects m
m map[string]*call[T] // lazily initialized
StoreResult bool
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result[T any] struct {
Val T
Err error
Shared bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value shared indicates whether v was given to multiple callers.
func (g *Group[T]) Do(key string, fn func() (T, error)) (v T, err error, shared bool) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call[T])
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
if e, ok := c.err.(*panicError); ok {
panic(e)
} else if c.err == errGoexit {
runtime.Goexit()
}
return c.val, c.err, true
}
c := new(call[T])
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
g.doCall(c, key, fn)
return c.val, c.err, c.dups > 0
}
// DoChan is like Do but returns a channel that will receive the
// results when they are ready.
//
// The returned channel will not be closed.
func (g *Group[T]) DoChan(key string, fn func() (T, error)) <-chan Result[T] {
ch := make(chan Result[T], 1)
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call[T])
}
if c, ok := g.m[key]; ok {
c.dups++
c.chans = append(c.chans, ch)
g.mu.Unlock()
return ch
}
c := &call[T]{chans: []chan<- Result[T]{ch}}
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
go g.doCall(c, key, fn)
return ch
}
// doCall handles the single call for a key.
func (g *Group[T]) doCall(c *call[T], key string, fn func() (T, error)) {
normalReturn := false
recovered := false
// use double-defer to distinguish panic from runtime.Goexit,
// more details see https://golang.org/cl/134395
defer func() {
// the given function invoked runtime.Goexit
if !normalReturn && !recovered {
c.err = errGoexit
}
g.mu.Lock()
defer g.mu.Unlock()
c.wg.Done()
if g.m[key] == c && !g.StoreResult {
delete(g.m, key)
}
if e, ok := c.err.(*panicError); ok {
// In order to prevent the waiting channels from being blocked forever,
// needs to ensure that this panic cannot be recovered.
if len(c.chans) > 0 {
go panic(e)
select {} // Keep this goroutine around so that it will appear in the crash dump.
} else {
panic(e)
}
} else if c.err == errGoexit {
// Already in the process of goexit, no need to call again
} else {
// Normal return
for _, ch := range c.chans {
ch <- Result[T]{c.val, c.err, c.dups > 0}
}
}
}()
func() {
defer func() {
if !normalReturn {
// Ideally, we would wait to take a stack trace until we've determined
// whether this is a panic or a runtime.Goexit.
//
// Unfortunately, the only way we can distinguish the two is to see
// whether the recover stopped the goroutine from terminating, and by
// the time we know that, the part of the stack trace relevant to the
// panic has been discarded.
if r := recover(); r != nil {
c.err = newPanicError(r)
}
}
}()
c.val, c.err = fn()
normalReturn = true
}()
if !normalReturn {
recovered = true
}
}
// Forget tells the singleflight to forget about a key. Future calls
// to Do for this key will call the function rather than waiting for
// an earlier call to complete.
func (g *Group[T]) Forget(key string) {
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
}
func (g *Group[T]) Reset() {
g.mu.Lock()
g.m = nil
g.mu.Unlock()
}
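
A short usage sketch for the generic Group above. The import path is assumed from this repository's layout, and fetchConfig stands in for any expensive call worth de-duplicating:

package main

import (
	"fmt"
	"sync"
	"time"

	"github.com/metacubex/mihomo/common/singleflight" // assumed import path
)

func fetchConfig() (string, error) {
	time.Sleep(50 * time.Millisecond) // pretend this is expensive
	return "config-v1", nil
}

func main() {
	var g singleflight.Group[string]
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Concurrent callers with the same key share one in-flight fetchConfig call.
			v, err, shared := g.Do("config", fetchConfig)
			fmt.Println(v, err, shared)
		}()
	}
	wg.Wait()
}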

View file

@@ -3,6 +3,7 @@ package structure
// references: https://github.com/mitchellh/mapstructure
import (
"encoding"
"encoding/base64"
"fmt"
"reflect"
@@ -86,35 +87,41 @@ func (d *Decoder) Decode(src map[string]any, dst any) error {
}
func (d *Decoder) decode(name string, data any, val reflect.Value) error {
kind := val.Kind()
switch {
case isInt(kind):
return d.decodeInt(name, data, val)
case isUint(kind):
return d.decodeUint(name, data, val)
case isFloat(kind):
return d.decodeFloat(name, data, val)
}
switch kind {
case reflect.Pointer:
if val.IsNil() {
for {
kind := val.Kind()
if kind == reflect.Pointer && val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
return d.decode(name, data, val.Elem())
case reflect.String:
return d.decodeString(name, data, val)
case reflect.Bool:
return d.decodeBool(name, data, val)
case reflect.Slice:
return d.decodeSlice(name, data, val)
case reflect.Map:
return d.decodeMap(name, data, val)
case reflect.Interface:
return d.setInterface(name, data, val)
case reflect.Struct:
return d.decodeStruct(name, data, val)
default:
return fmt.Errorf("type %s not support", val.Kind().String())
if ok, err := d.decodeTextUnmarshaller(name, data, val); ok {
return err
}
switch {
case isInt(kind):
return d.decodeInt(name, data, val)
case isUint(kind):
return d.decodeUint(name, data, val)
case isFloat(kind):
return d.decodeFloat(name, data, val)
}
switch kind {
case reflect.Pointer:
val = val.Elem()
continue
case reflect.String:
return d.decodeString(name, data, val)
case reflect.Bool:
return d.decodeBool(name, data, val)
case reflect.Slice:
return d.decodeSlice(name, data, val)
case reflect.Map:
return d.decodeMap(name, data, val)
case reflect.Interface:
return d.setInterface(name, data, val)
case reflect.Struct:
return d.decodeStruct(name, data, val)
default:
return fmt.Errorf("type %s not support", val.Kind().String())
}
}
}
@@ -553,3 +560,25 @@ func (d *Decoder) setInterface(name string, data any, val reflect.Value) (err er
val.Set(dataVal)
return nil
}
func (d *Decoder) decodeTextUnmarshaller(name string, data any, val reflect.Value) (bool, error) {
if !val.CanAddr() {
return false, nil
}
valAddr := val.Addr()
if !valAddr.CanInterface() {
return false, nil
}
unmarshaller, ok := valAddr.Interface().(encoding.TextUnmarshaler)
if !ok {
return false, nil
}
var str string
if err := d.decodeString(name, data, reflect.Indirect(reflect.ValueOf(&str))); err != nil {
return false, err
}
if err := unmarshaller.UnmarshalText([]byte(str)); err != nil {
return true, fmt.Errorf("cannot parse '%s' as %s: %s", name, val.Type(), err)
}
return true, nil
}

View file

@@ -1,6 +1,7 @@
package structure
import (
"strconv"
"testing"
"github.com/stretchr/testify/assert"
@@ -179,3 +180,90 @@ func TestStructure_SliceNilValueComplex(t *testing.T) {
err = decoder.Decode(rawMap, ss)
assert.NotNil(t, err)
}
func TestStructure_SliceCap(t *testing.T) {
rawMap := map[string]any{
"foo": []string{},
}
s := &struct {
Foo []string `test:"foo,omitempty"`
Bar []string `test:"bar,omitempty"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.NotNil(t, s.Foo) // structure's Decode ensures the value is not nil when the input provides one, even if it was set to an empty array
assert.Nil(t, s.Bar)
}
func TestStructure_Base64(t *testing.T) {
rawMap := map[string]any{
"foo": "AQID",
}
s := &struct {
Foo []byte `test:"foo"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, []byte{1, 2, 3}, s.Foo)
}
func TestStructure_Pointer(t *testing.T) {
rawMap := map[string]any{
"foo": "foo",
}
s := &struct {
Foo *string `test:"foo,omitempty"`
Bar *string `test:"bar,omitempty"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.NotNil(t, s.Foo)
assert.Equal(t, "foo", *s.Foo)
assert.Nil(t, s.Bar)
}
type num struct {
a int
}
func (n *num) UnmarshalText(text []byte) (err error) {
n.a, err = strconv.Atoi(string(text))
return
}
func TestStructure_TextUnmarshaller(t *testing.T) {
rawMap := map[string]any{
"num": "255",
"num_p": "127",
}
s := &struct {
Num num `test:"num"`
NumP *num `test:"num_p"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, 255, s.Num.a)
assert.NotNil(t, s.NumP)
assert.Equal(t, s.NumP.a, 127)
// test WeaklyTypedInput
rawMap["num"] = 256
err = decoder.Decode(rawMap, s)
assert.NotNilf(t, err, "should throw error: %#v", s)
err = weakTypeDecoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, 256, s.Num.a)
// test invalid input
rawMap["num_p"] = "abc"
err = decoder.Decode(rawMap, s)
assert.NotNilf(t, err, "should throw error: %#v", s)
}

View file

@@ -17,8 +17,8 @@ func NewCallback[T any]() *Callback[T] {
}
func (c *Callback[T]) Register(item func(T)) io.Closer {
-	c.mutex.RLock()
-	defer c.mutex.RUnlock()
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
element := c.list.PushBack(item)
return &callbackCloser[T]{
element: element,

View file

@@ -46,6 +46,14 @@ func (set *IpCidrSet) IsContain(ip netip.Addr) bool {
return set.ToIPSet().Contains(ip.WithZone(""))
}
// MatchIp implements C.IpMatcher
func (set *IpCidrSet) MatchIp(ip netip.Addr) bool {
if set.IsEmpty() {
return false
}
return set.IsContain(ip)
}
func (set *IpCidrSet) Merge() error {
var b netipx.IPSetBuilder
b.AddSet(set.ToIPSet())
@@ -57,6 +65,20 @@ func (set *IpCidrSet) Merge() error {
return nil
}
func (set *IpCidrSet) IsEmpty() bool {
return set == nil || len(set.rr) == 0
}
func (set *IpCidrSet) Foreach(f func(prefix netip.Prefix) bool) {
for _, r := range set.rr {
for _, prefix := range r.Prefixes() {
if !f(prefix) {
return
}
}
}
}
// ToIPSet is an unsafe conversion to *netipx.IPSet;
// be careful: it must only be used after Merge
func (set *IpCidrSet) ToIPSet() *netipx.IPSet {

View file

@@ -0,0 +1,77 @@
package cidr
import (
"encoding/binary"
"errors"
"io"
"net/netip"
"go4.org/netipx"
)
func (ss *IpCidrSet) WriteBin(w io.Writer) (err error) {
// version
_, err = w.Write([]byte{1})
if err != nil {
return err
}
// rr
err = binary.Write(w, binary.BigEndian, int64(len(ss.rr)))
if err != nil {
return err
}
for _, r := range ss.rr {
err = binary.Write(w, binary.BigEndian, r.From().As16())
if err != nil {
return err
}
err = binary.Write(w, binary.BigEndian, r.To().As16())
if err != nil {
return err
}
}
return nil
}
func ReadIpCidrSet(r io.Reader) (ss *IpCidrSet, err error) {
// version
version := make([]byte, 1)
_, err = io.ReadFull(r, version)
if err != nil {
return nil, err
}
if version[0] != 1 {
return nil, errors.New("version is invalid")
}
ss = NewIpCidrSet()
var length int64
// rr
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return nil, err
}
if length < 1 {
return nil, errors.New("length is invalid")
}
ss.rr = make([]netipx.IPRange, length)
for i := int64(0); i < length; i++ {
var a16 [16]byte
err = binary.Read(r, binary.BigEndian, &a16)
if err != nil {
return nil, err
}
from := netip.AddrFrom16(a16).Unmap()
err = binary.Read(r, binary.BigEndian, &a16)
if err != nil {
return nil, err
}
to := netip.AddrFrom16(a16).Unmap()
ss.rr[i] = netipx.IPRangeFrom(from, to)
}
return ss, nil
}
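
The wire format written above is: one version byte (currently 1), a big-endian int64 range count, then each range's From and To addresses as 16 bytes (IPv4 is stored in its 16-byte mapped form and unmapped again on read). A standalone sketch that emits that layout for a single range; the helper name is made up for illustration:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"net/netip"
)

// writeOneRange mirrors the WriteBin layout above for a single [from, to] range.
func writeOneRange(w *bytes.Buffer, from, to netip.Addr) {
	w.WriteByte(1)                                 // version
	binary.Write(w, binary.BigEndian, int64(1))    // number of ranges
	binary.Write(w, binary.BigEndian, from.As16()) // 16-byte From
	binary.Write(w, binary.BigEndian, to.As16())   // 16-byte To
}

func main() {
	var buf bytes.Buffer
	writeOneRange(&buf, netip.MustParseAddr("10.0.0.0"), netip.MustParseAddr("10.255.255.255"))
	fmt.Println(buf.Len()) // 1 + 8 + 16 + 16 = 41 bytes
}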

View file

@@ -13,7 +13,6 @@ import (
"time"
"github.com/metacubex/mihomo/component/resolver"
"github.com/metacubex/mihomo/constant/features"
"github.com/metacubex/mihomo/log"
)
@@ -79,29 +78,29 @@ func DialContext(ctx context.Context, network, address string, options ...Option
}
func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) {
if features.CMFA && DefaultSocketHook != nil {
return listenPacketHooked(ctx, network, address)
}
cfg := applyOptions(options...)
lc := &net.ListenConfig{}
if cfg.interfaceName != "" {
bind := bindIfaceToListenConfig
if cfg.fallbackBind {
bind = fallbackBindIfaceToListenConfig
}
addr, err := bind(cfg.interfaceName, lc, network, address, rAddrPort)
if err != nil {
return nil, err
}
address = addr
}
if cfg.addrReuse {
addrReuseToListenConfig(lc)
}
if cfg.routingMark != 0 {
bindMarkToListenConfig(cfg.routingMark, lc, network, address)
if DefaultSocketHook != nil { // ignore interfaceName, routingMark when DefaultSocketHook not null (in CMFA)
socketHookToListenConfig(lc)
} else {
if cfg.interfaceName != "" {
bind := bindIfaceToListenConfig
if cfg.fallbackBind {
bind = fallbackBindIfaceToListenConfig
}
addr, err := bind(cfg.interfaceName, lc, network, address, rAddrPort)
if err != nil {
return nil, err
}
address = addr
}
if cfg.routingMark != 0 {
bindMarkToListenConfig(cfg.routingMark, lc, network, address)
}
}
return lc.ListenPacket(ctx, network, address)
@@ -127,10 +126,6 @@ func GetTcpConcurrent() bool {
}
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
if features.CMFA && DefaultSocketHook != nil {
return dialContextHooked(ctx, network, destination, port)
}
var address string
if IP4PEnable {
destination, port = lookupIP4P(destination, port)
@@ -149,24 +144,30 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
}
dialer := netDialer.(*net.Dialer)
if opt.interfaceName != "" {
bind := bindIfaceToDialer
if opt.fallbackBind {
bind = fallbackBindIfaceToDialer
}
if err := bind(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
}
}
if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination)
}
if opt.mpTcp {
setMultiPathTCP(dialer)
}
if opt.tfo && !DisableTFO {
return dialTFO(ctx, *dialer, network, address)
if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA)
socketHookToToDialer(dialer)
} else {
if opt.interfaceName != "" {
bind := bindIfaceToDialer
if opt.fallbackBind {
bind = fallbackBindIfaceToDialer
}
if err := bind(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
}
}
if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination)
}
if opt.tfo && !DisableTFO {
return dialTFO(ctx, *dialer, network, address)
}
}
return dialer.DialContext(ctx, network, address)
}

View file

@@ -1,39 +0,0 @@
//go:build android && cmfa
package dialer
import (
"context"
"net"
"net/netip"
"syscall"
)
type SocketControl func(network, address string, conn syscall.RawConn) error
var DefaultSocketHook SocketControl
func dialContextHooked(ctx context.Context, network string, destination netip.Addr, port string) (net.Conn, error) {
dialer := &net.Dialer{
Control: DefaultSocketHook,
}
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(destination.String(), port))
if err != nil {
return nil, err
}
if t, ok := conn.(*net.TCPConn); ok {
t.SetKeepAlive(false)
}
return conn, nil
}
func listenPacketHooked(ctx context.Context, network, address string) (net.PacketConn, error) {
lc := &net.ListenConfig{
Control: DefaultSocketHook,
}
return lc.ListenPacket(ctx, network, address)
}

View file

@@ -1,22 +0,0 @@
//go:build !(android && cmfa)
package dialer
import (
"context"
"net"
"net/netip"
"syscall"
)
type SocketControl func(network, address string, conn syscall.RawConn) error
var DefaultSocketHook SocketControl
func dialContextHooked(ctx context.Context, network string, destination netip.Addr, port string) (net.Conn, error) {
return nil, nil
}
func listenPacketHooked(ctx context.Context, network, address string) (net.PacketConn, error) {
return nil, nil
}

View file

@@ -0,0 +1,27 @@
package dialer
import (
"context"
"net"
"syscall"
)
// SocketControl
// never change type traits because it's used in CMFA
type SocketControl func(network, address string, conn syscall.RawConn) error
// DefaultSocketHook
// never change type traits because it's used in CMFA
var DefaultSocketHook SocketControl
func socketHookToToDialer(dialer *net.Dialer) {
addControlToDialer(dialer, func(ctx context.Context, network, address string, c syscall.RawConn) error {
return DefaultSocketHook(network, address, c)
})
}
func socketHookToListenConfig(lc *net.ListenConfig) {
addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error {
return DefaultSocketHook(network, address, c)
})
}
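
For context, a hedged sketch of what an embedding application (the CMFA case the comments refer to) might assign to DefaultSocketHook; protectFD is a placeholder for a platform call such as Android's VpnService.protect, and the import path is taken from other files in this commit:

package main

import (
	"syscall"

	"github.com/metacubex/mihomo/component/dialer"
)

// protectFD is a stand-in for a platform-specific call; it is not part of this diff.
func protectFD(fd int) { /* e.g. exclude the socket from the TUN route */ }

func main() {
	dialer.DefaultSocketHook = func(network, address string, conn syscall.RawConn) error {
		// Once the hook is set, dialContext/ListenPacket skip interfaceName,
		// routingMark and TFO handling and only run this control function.
		return conn.Control(func(fd uintptr) {
			protectFD(int(fd))
		})
	}
}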

View file

@@ -5,8 +5,12 @@ import (
"io"
"net"
"time"
"github.com/metacubex/tfo-go"
)
var DisableTFO = false
type tfoConn struct {
net.Conn
closed bool
@@ -120,3 +124,16 @@ func (c *tfoConn) ReaderReplaceable() bool {
func (c *tfoConn) WriterReplaceable() bool {
return c.Conn != nil
}
func dialTFO(ctx context.Context, netDialer net.Dialer, network, address string) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultTCPTimeout)
dialer := tfo.Dialer{Dialer: netDialer, DisableTFO: false}
return &tfoConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
dialFn: func(ctx context.Context, earlyData []byte) (net.Conn, error) {
return dialer.DialContext(ctx, network, address, earlyData)
},
}, nil
}
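
The dialFn closure above defers the real connection until first use so that early data can ride in the TCP Fast Open SYN. A compile-only sketch of the underlying tfo-go call it wraps (address and timeout are illustrative):

package main

import (
	"context"
	"net"
	"time"

	"github.com/metacubex/tfo-go"
)

// dialWithEarlyData shows the call that dialFn delegates to; tfo-go sends
// earlyData with the TFO SYN when the platform and path support it.
func dialWithEarlyData(ctx context.Context, addr string, earlyData []byte) (net.Conn, error) {
	d := tfo.Dialer{Dialer: net.Dialer{Timeout: 5 * time.Second}, DisableTFO: false}
	return d.DialContext(ctx, "tcp", addr, earlyData)
}

func main() {}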

View file

@@ -1,25 +0,0 @@
//go:build unix
package dialer
import (
"context"
"net"
"github.com/metacubex/tfo-go"
)
const DisableTFO = false
func dialTFO(ctx context.Context, netDialer net.Dialer, network, address string) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultTCPTimeout)
dialer := tfo.Dialer{Dialer: netDialer, DisableTFO: false}
return &tfoConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
dialFn: func(ctx context.Context, earlyData []byte) (net.Conn, error) {
return dialer.DialContext(ctx, network, address, earlyData)
},
}, nil
}

View file

@@ -1,12 +1,11 @@
package dialer
import (
"context"
"net"
)
import "github.com/metacubex/mihomo/constant/features"
const DisableTFO = true
func dialTFO(ctx context.Context, netDialer net.Dialer, network, address string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, address)
func init() {
// According to MSDN, this option is available since Windows 10, 1607
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms738596(v=vs.85).aspx
if features.WindowsMajorVersion < 10 || (features.WindowsMajorVersion == 10 && features.WindowsBuildNumber < 14393) {
DisableTFO = true
}
}

View file

@@ -1,99 +0,0 @@
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
#ifndef __BPF_ENDIAN__
#define __BPF_ENDIAN__
/*
* Isolate byte #n and put it into byte #m, for __u##b type.
* E.g., moving byte #6 (nnnnnnnn) into byte #1 (mmmmmmmm) for __u64:
* 1) xxxxxxxx nnnnnnnn xxxxxxxx xxxxxxxx xxxxxxxx xxxxxxxx mmmmmmmm xxxxxxxx
* 2) nnnnnnnn xxxxxxxx xxxxxxxx xxxxxxxx xxxxxxxx mmmmmmmm xxxxxxxx 00000000
* 3) 00000000 00000000 00000000 00000000 00000000 00000000 00000000 nnnnnnnn
* 4) 00000000 00000000 00000000 00000000 00000000 00000000 nnnnnnnn 00000000
*/
#define ___bpf_mvb(x, b, n, m) ((__u##b)(x) << (b-(n+1)*8) >> (b-8) << (m*8))
#define ___bpf_swab16(x) ((__u16)( \
___bpf_mvb(x, 16, 0, 1) | \
___bpf_mvb(x, 16, 1, 0)))
#define ___bpf_swab32(x) ((__u32)( \
___bpf_mvb(x, 32, 0, 3) | \
___bpf_mvb(x, 32, 1, 2) | \
___bpf_mvb(x, 32, 2, 1) | \
___bpf_mvb(x, 32, 3, 0)))
#define ___bpf_swab64(x) ((__u64)( \
___bpf_mvb(x, 64, 0, 7) | \
___bpf_mvb(x, 64, 1, 6) | \
___bpf_mvb(x, 64, 2, 5) | \
___bpf_mvb(x, 64, 3, 4) | \
___bpf_mvb(x, 64, 4, 3) | \
___bpf_mvb(x, 64, 5, 2) | \
___bpf_mvb(x, 64, 6, 1) | \
___bpf_mvb(x, 64, 7, 0)))
/* LLVM's BPF target selects the endianness of the CPU
* it compiles on, or the user specifies (bpfel/bpfeb),
* respectively. The used __BYTE_ORDER__ is defined by
* the compiler, we cannot rely on __BYTE_ORDER from
* libc headers, since it doesn't reflect the actual
* requested byte order.
*
* Note, LLVM's BPF target has different __builtin_bswapX()
* semantics. It does map to BPF_ALU | BPF_END | BPF_TO_BE
* in bpfel and bpfeb case, which means below, that we map
* to cpu_to_be16(). We could use it unconditionally in BPF
* case, but better not rely on it, so that this header here
* can be used from application and BPF program side, which
* use different targets.
*/
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
# define __bpf_ntohs(x) __builtin_bswap16(x)
# define __bpf_htons(x) __builtin_bswap16(x)
# define __bpf_constant_ntohs(x) ___bpf_swab16(x)
# define __bpf_constant_htons(x) ___bpf_swab16(x)
# define __bpf_ntohl(x) __builtin_bswap32(x)
# define __bpf_htonl(x) __builtin_bswap32(x)
# define __bpf_constant_ntohl(x) ___bpf_swab32(x)
# define __bpf_constant_htonl(x) ___bpf_swab32(x)
# define __bpf_be64_to_cpu(x) __builtin_bswap64(x)
# define __bpf_cpu_to_be64(x) __builtin_bswap64(x)
# define __bpf_constant_be64_to_cpu(x) ___bpf_swab64(x)
# define __bpf_constant_cpu_to_be64(x) ___bpf_swab64(x)
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
# define __bpf_ntohs(x) (x)
# define __bpf_htons(x) (x)
# define __bpf_constant_ntohs(x) (x)
# define __bpf_constant_htons(x) (x)
# define __bpf_ntohl(x) (x)
# define __bpf_htonl(x) (x)
# define __bpf_constant_ntohl(x) (x)
# define __bpf_constant_htonl(x) (x)
# define __bpf_be64_to_cpu(x) (x)
# define __bpf_cpu_to_be64(x) (x)
# define __bpf_constant_be64_to_cpu(x) (x)
# define __bpf_constant_cpu_to_be64(x) (x)
#else
# error "Fix your compiler's __BYTE_ORDER__?!"
#endif
#define bpf_htons(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_htons(x) : __bpf_htons(x))
#define bpf_ntohs(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_ntohs(x) : __bpf_ntohs(x))
#define bpf_htonl(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_htonl(x) : __bpf_htonl(x))
#define bpf_ntohl(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_ntohl(x) : __bpf_ntohl(x))
#define bpf_cpu_to_be64(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_cpu_to_be64(x) : __bpf_cpu_to_be64(x))
#define bpf_be64_to_cpu(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_be64_to_cpu(x) : __bpf_be64_to_cpu(x))
#endif /* __BPF_ENDIAN__ */

File diff suppressed because it is too large

View file

@@ -1,262 +0,0 @@
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
#ifndef __BPF_HELPERS__
#define __BPF_HELPERS__
/*
* Note that bpf programs need to include either
* vmlinux.h (auto-generated from BTF) or linux/types.h
* in advance since bpf_helper_defs.h uses such types
* as __u64.
*/
#include "bpf_helper_defs.h"
#define __uint(name, val) int (*name)[val]
#define __type(name, val) typeof(val) *name
#define __array(name, val) typeof(val) *name[]
/*
* Helper macro to place programs, maps, license in
* different sections in elf_bpf file. Section names
* are interpreted by libbpf depending on the context (BPF programs, BPF maps,
* extern variables, etc).
* To allow use of SEC() with externs (e.g., for extern .maps declarations),
* make sure __attribute__((unused)) doesn't trigger compilation warning.
*/
#define SEC(name) \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wignored-attributes\"") \
__attribute__((section(name), used)) \
_Pragma("GCC diagnostic pop") \
/* Avoid 'linux/stddef.h' definition of '__always_inline'. */
#undef __always_inline
#define __always_inline inline __attribute__((always_inline))
#ifndef __noinline
#define __noinline __attribute__((noinline))
#endif
#ifndef __weak
#define __weak __attribute__((weak))
#endif
/*
* Use __hidden attribute to mark a non-static BPF subprogram effectively
* static for BPF verifier's verification algorithm purposes, allowing more
* extensive and permissive BPF verification process, taking into account
* subprogram's caller context.
*/
#define __hidden __attribute__((visibility("hidden")))
/* When utilizing vmlinux.h with BPF CO-RE, user BPF programs can't include
* any system-level headers (such as stddef.h, linux/version.h, etc), and
* commonly-used macros like NULL and KERNEL_VERSION aren't available through
* vmlinux.h. This just adds unnecessary hurdles and forces users to re-define
* them on their own. So as a convenience, provide such definitions here.
*/
#ifndef NULL
#define NULL ((void *)0)
#endif
#ifndef KERNEL_VERSION
#define KERNEL_VERSION(a, b, c) (((a) << 16) + ((b) << 8) + ((c) > 255 ? 255 : (c)))
#endif
/*
* Helper macros to manipulate data structures
*/
#ifndef offsetof
#define offsetof(TYPE, MEMBER) ((unsigned long)&((TYPE *)0)->MEMBER)
#endif
#ifndef container_of
#define container_of(ptr, type, member) \
({ \
void *__mptr = (void *)(ptr); \
((type *)(__mptr - offsetof(type, member))); \
})
#endif
/*
* Helper macro to throw a compilation error if __bpf_unreachable() gets
* built into the resulting code. This works given BPF back end does not
* implement __builtin_trap(). This is useful to assert that certain paths
* of the program code are never used and hence eliminated by the compiler.
*
* For example, consider a switch statement that covers known cases used by
* the program. __bpf_unreachable() can then reside in the default case. If
* the program gets extended such that a case is not covered in the switch
* statement, then it will throw a build error due to the default case not
* being compiled out.
*/
#ifndef __bpf_unreachable
# define __bpf_unreachable() __builtin_trap()
#endif
/*
* Helper function to perform a tail call with a constant/immediate map slot.
*/
#if __clang_major__ >= 8 && defined(__bpf__)
static __always_inline void
bpf_tail_call_static(void *ctx, const void *map, const __u32 slot)
{
if (!__builtin_constant_p(slot))
__bpf_unreachable();
/*
* Provide a hard guarantee that LLVM won't optimize setting r2 (map
* pointer) and r3 (constant map index) from _different paths_ ending
* up at the _same_ call insn as otherwise we won't be able to use the
* jmpq/nopl retpoline-free patching by the x86-64 JIT in the kernel
* given they mismatch. See also d2e4c1e6c294 ("bpf: Constant map key
* tracking for prog array pokes") for details on verifier tracking.
*
* Note on clobber list: we need to stay in-line with BPF calling
* convention, so even if we don't end up using r0, r4, r5, we need
* to mark them as clobber so that LLVM doesn't end up using them
* before / after the call.
*/
asm volatile("r1 = %[ctx]\n\t"
"r2 = %[map]\n\t"
"r3 = %[slot]\n\t"
"call 12"
:: [ctx]"r"(ctx), [map]"r"(map), [slot]"i"(slot)
: "r0", "r1", "r2", "r3", "r4", "r5");
}
#endif
/*
* Helper structure used by eBPF C program
* to describe BPF map attributes to libbpf loader
*/
struct bpf_map_def {
unsigned int type;
unsigned int key_size;
unsigned int value_size;
unsigned int max_entries;
unsigned int map_flags;
};
enum libbpf_pin_type {
LIBBPF_PIN_NONE,
/* PIN_BY_NAME: pin maps by name (in /sys/fs/bpf by default) */
LIBBPF_PIN_BY_NAME,
};
enum libbpf_tristate {
TRI_NO = 0,
TRI_YES = 1,
TRI_MODULE = 2,
};
#define __kconfig __attribute__((section(".kconfig")))
#define __ksym __attribute__((section(".ksyms")))
#ifndef ___bpf_concat
#define ___bpf_concat(a, b) a ## b
#endif
#ifndef ___bpf_apply
#define ___bpf_apply(fn, n) ___bpf_concat(fn, n)
#endif
#ifndef ___bpf_nth
#define ___bpf_nth(_, _1, _2, _3, _4, _5, _6, _7, _8, _9, _a, _b, _c, N, ...) N
#endif
#ifndef ___bpf_narg
#define ___bpf_narg(...) \
___bpf_nth(_, ##__VA_ARGS__, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
#endif
#define ___bpf_fill0(arr, p, x) do {} while (0)
#define ___bpf_fill1(arr, p, x) arr[p] = x
#define ___bpf_fill2(arr, p, x, args...) arr[p] = x; ___bpf_fill1(arr, p + 1, args)
#define ___bpf_fill3(arr, p, x, args...) arr[p] = x; ___bpf_fill2(arr, p + 1, args)
#define ___bpf_fill4(arr, p, x, args...) arr[p] = x; ___bpf_fill3(arr, p + 1, args)
#define ___bpf_fill5(arr, p, x, args...) arr[p] = x; ___bpf_fill4(arr, p + 1, args)
#define ___bpf_fill6(arr, p, x, args...) arr[p] = x; ___bpf_fill5(arr, p + 1, args)
#define ___bpf_fill7(arr, p, x, args...) arr[p] = x; ___bpf_fill6(arr, p + 1, args)
#define ___bpf_fill8(arr, p, x, args...) arr[p] = x; ___bpf_fill7(arr, p + 1, args)
#define ___bpf_fill9(arr, p, x, args...) arr[p] = x; ___bpf_fill8(arr, p + 1, args)
#define ___bpf_fill10(arr, p, x, args...) arr[p] = x; ___bpf_fill9(arr, p + 1, args)
#define ___bpf_fill11(arr, p, x, args...) arr[p] = x; ___bpf_fill10(arr, p + 1, args)
#define ___bpf_fill12(arr, p, x, args...) arr[p] = x; ___bpf_fill11(arr, p + 1, args)
#define ___bpf_fill(arr, args...) \
___bpf_apply(___bpf_fill, ___bpf_narg(args))(arr, 0, args)
/*
* BPF_SEQ_PRINTF to wrap bpf_seq_printf to-be-printed values
* in a structure.
*/
#define BPF_SEQ_PRINTF(seq, fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_seq_printf(seq, ___fmt, sizeof(___fmt), \
___param, sizeof(___param)); \
})
/*
* BPF_SNPRINTF wraps the bpf_snprintf helper with variadic arguments instead of
* an array of u64.
*/
#define BPF_SNPRINTF(out, out_size, fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_snprintf(out, out_size, ___fmt, \
___param, sizeof(___param)); \
})
#ifdef BPF_NO_GLOBAL_DATA
#define BPF_PRINTK_FMT_MOD
#else
#define BPF_PRINTK_FMT_MOD static const
#endif
#define __bpf_printk(fmt, ...) \
({ \
BPF_PRINTK_FMT_MOD char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), \
##__VA_ARGS__); \
})
/*
* __bpf_vprintk wraps the bpf_trace_vprintk helper with variadic arguments
* instead of an array of u64.
*/
#define __bpf_vprintk(fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_trace_vprintk(___fmt, sizeof(___fmt), \
___param, sizeof(___param)); \
})
/* Use __bpf_printk when bpf_printk call has 3 or fewer fmt args
* Otherwise use __bpf_vprintk
*/
#define ___bpf_pick_printk(...) \
___bpf_nth(_, ##__VA_ARGS__, __bpf_vprintk, __bpf_vprintk, __bpf_vprintk, \
__bpf_vprintk, __bpf_vprintk, __bpf_vprintk, __bpf_vprintk, \
__bpf_vprintk, __bpf_vprintk, __bpf_printk /*3*/, __bpf_printk /*2*/,\
__bpf_printk /*1*/, __bpf_printk /*0*/)
/* Helper macro to print out debug messages */
#define bpf_printk(fmt, args...) ___bpf_pick_printk(args)(fmt, ##args)
#endif

View file

@@ -1,342 +0,0 @@
#include <stdint.h>
#include <stdbool.h>
//#include <linux/types.h>
#include <linux/bpf.h>
#include <linux/if_ether.h>
//#include <linux/if_packet.h>
//#include <linux/if_vlan.h>
#include <linux/ip.h>
#include <linux/in.h>
#include <linux/tcp.h>
//#include <linux/udp.h>
#include <linux/pkt_cls.h>
#include "bpf_endian.h"
#include "bpf_helpers.h"
#define IP_CSUM_OFF (ETH_HLEN + offsetof(struct iphdr, check))
#define IP_DST_OFF (ETH_HLEN + offsetof(struct iphdr, daddr))
#define IP_SRC_OFF (ETH_HLEN + offsetof(struct iphdr, saddr))
#define IP_PROTO_OFF (ETH_HLEN + offsetof(struct iphdr, protocol))
#define TCP_CSUM_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct tcphdr, check))
#define TCP_SRC_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct tcphdr, source))
#define TCP_DST_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct tcphdr, dest))
//#define UDP_CSUM_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct udphdr, check))
//#define UDP_SRC_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct udphdr, source))
//#define UDP_DST_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct udphdr, dest))
#define IS_PSEUDO 0x10
struct origin_info {
__be32 ip;
__be16 port;
__u16 pad;
};
struct origin_info *origin_info_unused __attribute__((unused));
struct redir_info {
__be32 sip;
__be32 dip;
__be16 sport;
__be16 dport;
};
struct redir_info *redir_info_unused __attribute__((unused));
struct {
__uint(type, BPF_MAP_TYPE_LRU_HASH);
__type(key, struct redir_info);
__type(value, struct origin_info);
__uint(max_entries, 65535);
__uint(pinning, LIBBPF_PIN_BY_NAME);
} pair_original_dst_map SEC(".maps");
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__type(key, __u32);
__type(value, __u32);
__uint(max_entries, 3);
__uint(pinning, LIBBPF_PIN_BY_NAME);
} redir_params_map SEC(".maps");
static __always_inline int rewrite_ip(struct __sk_buff *skb, __be32 new_ip, bool is_dest) {
int ret, off = 0, flags = IS_PSEUDO;
__be32 old_ip;
if (is_dest)
ret = bpf_skb_load_bytes(skb, IP_DST_OFF, &old_ip, 4);
else
ret = bpf_skb_load_bytes(skb, IP_SRC_OFF, &old_ip, 4);
if (ret < 0) {
return ret;
}
off = TCP_CSUM_OFF;
// __u8 proto;
//
// ret = bpf_skb_load_bytes(skb, IP_PROTO_OFF, &proto, 1);
// if (ret < 0) {
// return BPF_DROP;
// }
//
// switch (proto) {
// case IPPROTO_TCP:
// off = TCP_CSUM_OFF;
// break;
//
// case IPPROTO_UDP:
// off = UDP_CSUM_OFF;
// flags |= BPF_F_MARK_MANGLED_0;
// break;
//
// case IPPROTO_ICMPV6:
// off = offsetof(struct icmp6hdr, icmp6_cksum);
// break;
// }
//
// if (off) {
ret = bpf_l4_csum_replace(skb, off, old_ip, new_ip, flags | sizeof(new_ip));
if (ret < 0) {
return ret;
}
// }
ret = bpf_l3_csum_replace(skb, IP_CSUM_OFF, old_ip, new_ip, sizeof(new_ip));
if (ret < 0) {
return ret;
}
if (is_dest)
ret = bpf_skb_store_bytes(skb, IP_DST_OFF, &new_ip, sizeof(new_ip), 0);
else
ret = bpf_skb_store_bytes(skb, IP_SRC_OFF, &new_ip, sizeof(new_ip), 0);
if (ret < 0) {
return ret;
}
return 1;
}
static __always_inline int rewrite_port(struct __sk_buff *skb, __be16 new_port, bool is_dest) {
int ret, off = 0;
__be16 old_port;
if (is_dest)
ret = bpf_skb_load_bytes(skb, TCP_DST_OFF, &old_port, 2);
else
ret = bpf_skb_load_bytes(skb, TCP_SRC_OFF, &old_port, 2);
if (ret < 0) {
return ret;
}
off = TCP_CSUM_OFF;
ret = bpf_l4_csum_replace(skb, off, old_port, new_port, sizeof(new_port));
if (ret < 0) {
return ret;
}
if (is_dest)
ret = bpf_skb_store_bytes(skb, TCP_DST_OFF, &new_port, sizeof(new_port), 0);
else
ret = bpf_skb_store_bytes(skb, TCP_SRC_OFF, &new_port, sizeof(new_port), 0);
if (ret < 0) {
return ret;
}
return 1;
}
static __always_inline bool is_lan_ip(__be32 addr) {
if (addr == 0xffffffff)
return true;
__u8 fist = (__u8)(addr & 0xff);
if (fist == 127 || fist == 10)
return true;
__u8 second = (__u8)((addr >> 8) & 0xff);
if (fist == 172 && second >= 16 && second <= 31)
return true;
if (fist == 192 && second == 168)
return true;
return false;
}
SEC("tc_mihomo_auto_redir_ingress")
int tc_redir_ingress_func(struct __sk_buff *skb) {
void *data = (void *)(long)skb->data;
void *data_end = (void *)(long)skb->data_end;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return TC_ACT_OK;
if (eth->h_proto != bpf_htons(ETH_P_IP))
return TC_ACT_OK;
struct iphdr *iph = (struct iphdr *)(eth + 1);
if ((void *)(iph + 1) > data_end)
return TC_ACT_OK;
__u32 key = 0, *route_index, *redir_ip, *redir_port;
route_index = bpf_map_lookup_elem(&redir_params_map, &key);
if (!route_index)
return TC_ACT_OK;
if (iph->protocol == IPPROTO_ICMP && *route_index != 0)
return bpf_redirect(*route_index, 0);
if (iph->protocol != IPPROTO_TCP)
return TC_ACT_OK;
struct tcphdr *tcph = (struct tcphdr *)(iph + 1);
if ((void *)(tcph + 1) > data_end)
return TC_ACT_SHOT;
key = 1;
redir_ip = bpf_map_lookup_elem(&redir_params_map, &key);
if (!redir_ip)
return TC_ACT_OK;
key = 2;
redir_port = bpf_map_lookup_elem(&redir_params_map, &key);
if (!redir_port)
return TC_ACT_OK;
__be32 new_ip = bpf_htonl(*redir_ip);
__be16 new_port = bpf_htonl(*redir_port) >> 16;
__be32 old_ip = iph->daddr;
__be16 old_port = tcph->dest;
if (old_ip == new_ip || is_lan_ip(old_ip) || bpf_ntohs(old_port) == 53) {
return TC_ACT_OK;
}
struct redir_info p_key = {
.sip = iph->saddr,
.sport = tcph->source,
.dip = new_ip,
.dport = new_port,
};
if (tcph->syn && !tcph->ack) {
struct origin_info origin = {
.ip = old_ip,
.port = old_port,
};
bpf_map_update_elem(&pair_original_dst_map, &p_key, &origin, BPF_NOEXIST);
if (rewrite_ip(skb, new_ip, true) < 0) {
return TC_ACT_SHOT;
}
if (rewrite_port(skb, new_port, true) < 0) {
return TC_ACT_SHOT;
}
} else {
struct origin_info *origin = bpf_map_lookup_elem(&pair_original_dst_map, &p_key);
if (!origin) {
return TC_ACT_OK;
}
if (rewrite_ip(skb, new_ip, true) < 0) {
return TC_ACT_SHOT;
}
if (rewrite_port(skb, new_port, true) < 0) {
return TC_ACT_SHOT;
}
}
return TC_ACT_OK;
}
SEC("tc_mihomo_auto_redir_egress")
int tc_redir_egress_func(struct __sk_buff *skb) {
void *data = (void *)(long)skb->data;
void *data_end = (void *)(long)skb->data_end;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return TC_ACT_OK;
if (eth->h_proto != bpf_htons(ETH_P_IP))
return TC_ACT_OK;
__u32 key = 0, *redir_ip, *redir_port; // *mihomo_mark
// mihomo_mark = bpf_map_lookup_elem(&redir_params_map, &key);
// if (mihomo_mark && *mihomo_mark != 0 && *mihomo_mark == skb->mark)
// return TC_ACT_OK;
struct iphdr *iph = (struct iphdr *)(eth + 1);
if ((void *)(iph + 1) > data_end)
return TC_ACT_OK;
if (iph->protocol != IPPROTO_TCP)
return TC_ACT_OK;
struct tcphdr *tcph = (struct tcphdr *)(iph + 1);
if ((void *)(tcph + 1) > data_end)
return TC_ACT_SHOT;
key = 1;
redir_ip = bpf_map_lookup_elem(&redir_params_map, &key);
if (!redir_ip)
return TC_ACT_OK;
key = 2;
redir_port = bpf_map_lookup_elem(&redir_params_map, &key);
if (!redir_port)
return TC_ACT_OK;
__be32 new_ip = bpf_htonl(*redir_ip);
__be16 new_port = bpf_htonl(*redir_port) >> 16;
__be32 old_ip = iph->saddr;
__be16 old_port = tcph->source;
if (old_ip != new_ip || old_port != new_port) {
return TC_ACT_OK;
}
struct redir_info p_key = {
.sip = iph->daddr,
.sport = tcph->dest,
.dip = iph->saddr,
.dport = tcph->source,
};
struct origin_info *origin = bpf_map_lookup_elem(&pair_original_dst_map, &p_key);
if (!origin) {
return TC_ACT_OK;
}
if (tcph->fin && tcph->ack) {
bpf_map_delete_elem(&pair_original_dst_map, &p_key);
}
if (rewrite_ip(skb, origin->ip, false) < 0) {
return TC_ACT_SHOT;
}
if (rewrite_port(skb, origin->port, false) < 0) {
return TC_ACT_SHOT;
}
return TC_ACT_OK;
}
char _license[] SEC("license") = "GPL";

View file

@@ -1,103 +0,0 @@
#include <stdbool.h>
#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/in.h>
//#include <linux/tcp.h>
//#include <linux/udp.h>
#include <linux/pkt_cls.h>
#include "bpf_endian.h"
#include "bpf_helpers.h"
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__type(key, __u32);
__type(value, __u32);
__uint(max_entries, 2);
__uint(pinning, LIBBPF_PIN_BY_NAME);
} tc_params_map SEC(".maps");
static __always_inline bool is_lan_ip(__be32 addr) {
if (addr == 0xffffffff)
return true;
__u8 fist = (__u8)(addr & 0xff);
if (fist == 127 || fist == 10)
return true;
__u8 second = (__u8)((addr >> 8) & 0xff);
if (fist == 172 && second >= 16 && second <= 31)
return true;
if (fist == 192 && second == 168)
return true;
return false;
}
SEC("tc_mihomo_redirect_to_tun")
int tc_tun_func(struct __sk_buff *skb) {
void *data = (void *)(long)skb->data;
void *data_end = (void *)(long)skb->data_end;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return TC_ACT_OK;
if (eth->h_proto == bpf_htons(ETH_P_ARP))
return TC_ACT_OK;
__u32 key = 0, *mihomo_mark, *tun_ifindex;
mihomo_mark = bpf_map_lookup_elem(&tc_params_map, &key);
if (!mihomo_mark)
return TC_ACT_OK;
if (skb->mark == *mihomo_mark)
return TC_ACT_OK;
if (eth->h_proto == bpf_htons(ETH_P_IP)) {
struct iphdr *iph = (struct iphdr *)(eth + 1);
if ((void *)(iph + 1) > data_end)
return TC_ACT_OK;
if (iph->protocol == IPPROTO_ICMP)
return TC_ACT_OK;
__be32 daddr = iph->daddr;
if (is_lan_ip(daddr))
return TC_ACT_OK;
// if (iph->protocol == IPPROTO_TCP) {
// struct tcphdr *tcph = (struct tcphdr *)(iph + 1);
// if ((void *)(tcph + 1) > data_end)
// return TC_ACT_OK;
//
// __u16 source = bpf_ntohs(tcph->source);
// if (source == 22 || source == 80 || source == 443 || source == 8080 || source == 8443 || source == 9090 || (source >= 7890 && source <= 7895))
// return TC_ACT_OK;
// } else if (iph->protocol == IPPROTO_UDP) {
// struct udphdr *udph = (struct udphdr *)(iph + 1);
// if ((void *)(udph + 1) > data_end)
// return TC_ACT_OK;
//
// __u16 source = bpf_ntohs(udph->source);
// if (source == 53 || (source >= 135 && source <= 139))
// return TC_ACT_OK;
// }
}
key = 1;
tun_ifindex = bpf_map_lookup_elem(&tc_params_map, &key);
if (!tun_ifindex)
return TC_ACT_OK;
//return bpf_redirect(*tun_ifindex, BPF_F_INGRESS); // __bpf_rx_skb
return bpf_redirect(*tun_ifindex, 0); // __bpf_tx_skb / __dev_xmit_skb
}
char _license[] SEC("license") = "GPL";

View file

@@ -1,13 +0,0 @@
package byteorder
import (
"net"
)
// NetIPv4ToHost32 converts an net.IP to a uint32 in host byte order. ip
// must be a IPv4 address, otherwise the function will panic.
func NetIPv4ToHost32(ip net.IP) uint32 {
ipv4 := ip.To4()
_ = ipv4[3] // Assert length of ipv4.
return Native.Uint32(ipv4)
}

View file

@@ -1,12 +0,0 @@
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
package byteorder
import "encoding/binary"
var Native binary.ByteOrder = binary.BigEndian
func HostToNetwork16(u uint16) uint16 { return u }
func HostToNetwork32(u uint32) uint32 { return u }
func NetworkToHost16(u uint16) uint16 { return u }
func NetworkToHost32(u uint32) uint32 { return u }

View file

@@ -1,15 +0,0 @@
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64 || loong64
package byteorder
import (
"encoding/binary"
"math/bits"
)
var Native binary.ByteOrder = binary.LittleEndian
func HostToNetwork16(u uint16) uint16 { return bits.ReverseBytes16(u) }
func HostToNetwork32(u uint32) uint32 { return bits.ReverseBytes32(u) }
func NetworkToHost16(u uint16) uint16 { return bits.ReverseBytes16(u) }
func NetworkToHost32(u uint32) uint32 { return bits.ReverseBytes32(u) }

View file

@@ -1,33 +0,0 @@
package ebpf
import (
"net/netip"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5"
)
type TcEBpfProgram struct {
pros []C.EBpf
rawNICs []string
}
func (t *TcEBpfProgram) RawNICs() []string {
return t.rawNICs
}
func (t *TcEBpfProgram) Close() {
for _, p := range t.pros {
p.Close()
}
}
func (t *TcEBpfProgram) Lookup(srcAddrPort netip.AddrPort) (addr socks5.Addr, err error) {
for _, p := range t.pros {
addr, err = p.Lookup(srcAddrPort)
if err == nil {
return
}
}
return
}

View file

@@ -1,137 +0,0 @@
//go:build !android
package ebpf
import (
"fmt"
"net/netip"
"github.com/metacubex/mihomo/common/cmd"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ebpf/redir"
"github.com/metacubex/mihomo/component/ebpf/tc"
C "github.com/metacubex/mihomo/constant"
"github.com/sagernet/netlink"
)
func GetAutoDetectInterface() (string, error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
if err != nil {
return "", err
}
for _, route := range routes {
if route.Dst == nil {
lk, err := netlink.LinkByIndex(route.LinkIndex)
if err != nil {
return "", err
}
if lk.Type() == "tuntap" {
continue
}
return lk.Attrs().Name, nil
}
}
return "", fmt.Errorf("interface not found")
}
// NewTcEBpfProgram new redirect to tun ebpf program
func NewTcEBpfProgram(ifaceNames []string, tunName string) (*TcEBpfProgram, error) {
tunIface, err := netlink.LinkByName(tunName)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q: %w", tunName, err)
}
tunIndex := uint32(tunIface.Attrs().Index)
dialer.DefaultRoutingMark.Store(C.MihomoTrafficMark)
ifMark := uint32(dialer.DefaultRoutingMark.Load())
var pros []C.EBpf
for _, ifaceName := range ifaceNames {
iface, err := netlink.LinkByName(ifaceName)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q: %w", ifaceName, err)
}
if iface.Attrs().OperState != netlink.OperUp {
return nil, fmt.Errorf("network iface %q is down", ifaceName)
}
attrs := iface.Attrs()
index := attrs.Index
tcPro := tc.NewEBpfTc(ifaceName, index, ifMark, tunIndex)
if err = tcPro.Start(); err != nil {
return nil, err
}
pros = append(pros, tcPro)
}
systemSetting(ifaceNames...)
return &TcEBpfProgram{pros: pros, rawNICs: ifaceNames}, nil
}
// NewRedirEBpfProgram new auto redirect ebpf program
func NewRedirEBpfProgram(ifaceNames []string, redirPort uint16, defaultRouteInterfaceName string) (*TcEBpfProgram, error) {
defaultRouteInterface, err := netlink.LinkByName(defaultRouteInterfaceName)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q: %w", defaultRouteInterfaceName, err)
}
defaultRouteIndex := uint32(defaultRouteInterface.Attrs().Index)
var pros []C.EBpf
for _, ifaceName := range ifaceNames {
iface, err := netlink.LinkByName(ifaceName)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q: %w", ifaceName, err)
}
attrs := iface.Attrs()
index := attrs.Index
addrs, err := netlink.AddrList(iface, netlink.FAMILY_V4)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q address: %w", ifaceName, err)
}
if len(addrs) == 0 {
return nil, fmt.Errorf("network iface %q does not contain any ipv4 addresses", ifaceName)
}
address, _ := netip.AddrFromSlice(addrs[0].IP)
redirAddrPort := netip.AddrPortFrom(address, redirPort)
redirPro := redir.NewEBpfRedirect(ifaceName, index, 0, defaultRouteIndex, redirAddrPort)
if err = redirPro.Start(); err != nil {
return nil, err
}
pros = append(pros, redirPro)
}
systemSetting(ifaceNames...)
return &TcEBpfProgram{pros: pros, rawNICs: ifaceNames}, nil
}
func systemSetting(ifaceNames ...string) {
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.ip_forward=1")
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.conf.all.forwarding=1")
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.conf.all.accept_local=1")
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.conf.all.accept_redirects=1")
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.conf.all.rp_filter=0")
for _, ifaceName := range ifaceNames {
_, _ = cmd.ExecCmd(fmt.Sprintf("sysctl -w net.ipv4.conf.%s.forwarding=1", ifaceName))
_, _ = cmd.ExecCmd(fmt.Sprintf("sysctl -w net.ipv4.conf.%s.accept_local=1", ifaceName))
_, _ = cmd.ExecCmd(fmt.Sprintf("sysctl -w net.ipv4.conf.%s.accept_redirects=1", ifaceName))
_, _ = cmd.ExecCmd(fmt.Sprintf("sysctl -w net.ipv4.conf.%s.rp_filter=0", ifaceName))
}
}

View file

@@ -1,21 +0,0 @@
//go:build !linux || android
package ebpf
import (
"fmt"
)
// NewTcEBpfProgram new ebpf tc program
func NewTcEBpfProgram(_ []string, _ string) (*TcEBpfProgram, error) {
return nil, fmt.Errorf("system not supported")
}
// NewRedirEBpfProgram new ebpf redirect program
func NewRedirEBpfProgram(_ []string, _ uint16, _ string) (*TcEBpfProgram, error) {
return nil, fmt.Errorf("system not supported")
}
func GetAutoDetectInterface() (string, error) {
return "", fmt.Errorf("system not supported")
}

View file

@@ -1,216 +0,0 @@
//go:build linux
package redir
import (
"encoding/binary"
"fmt"
"io"
"net"
"net/netip"
"os"
"path/filepath"
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/rlimit"
"github.com/sagernet/netlink"
"golang.org/x/sys/unix"
"github.com/metacubex/mihomo/component/ebpf/byteorder"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5"
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -cflags $BPF_CFLAGS bpf ../bpf/redir.c
const (
mapKey1 uint32 = 0
mapKey2 uint32 = 1
mapKey3 uint32 = 2
)
type EBpfRedirect struct {
objs io.Closer
originMap *ebpf.Map
qdisc netlink.Qdisc
filter netlink.Filter
filterEgress netlink.Filter
ifName string
ifIndex int
ifMark uint32
rtIndex uint32
redirIp uint32
redirPort uint16
bpfPath string
}
func NewEBpfRedirect(ifName string, ifIndex int, ifMark uint32, routeIndex uint32, redirAddrPort netip.AddrPort) *EBpfRedirect {
return &EBpfRedirect{
ifName: ifName,
ifIndex: ifIndex,
ifMark: ifMark,
rtIndex: routeIndex,
redirIp: binary.BigEndian.Uint32(redirAddrPort.Addr().AsSlice()),
redirPort: redirAddrPort.Port(),
}
}
func (e *EBpfRedirect) Start() error {
if err := rlimit.RemoveMemlock(); err != nil {
return fmt.Errorf("remove memory lock: %w", err)
}
e.bpfPath = filepath.Join(C.BpfFSPath, e.ifName)
if err := os.MkdirAll(e.bpfPath, os.ModePerm); err != nil {
return fmt.Errorf("failed to create bpf fs subpath: %w", err)
}
var objs bpfObjects
if err := loadBpfObjects(&objs, &ebpf.CollectionOptions{
Maps: ebpf.MapOptions{
PinPath: e.bpfPath,
},
}); err != nil {
e.Close()
return fmt.Errorf("loading objects: %w", err)
}
e.objs = &objs
e.originMap = objs.bpfMaps.PairOriginalDstMap
if err := objs.bpfMaps.RedirParamsMap.Update(mapKey1, e.rtIndex, ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
if err := objs.bpfMaps.RedirParamsMap.Update(mapKey2, e.redirIp, ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
if err := objs.bpfMaps.RedirParamsMap.Update(mapKey3, uint32(e.redirPort), ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
attrs := netlink.QdiscAttrs{
LinkIndex: e.ifIndex,
Handle: netlink.MakeHandle(0xffff, 0),
Parent: netlink.HANDLE_CLSACT,
}
qdisc := &netlink.GenericQdisc{
QdiscAttrs: attrs,
QdiscType: "clsact",
}
e.qdisc = qdisc
if err := netlink.QdiscAdd(qdisc); err != nil {
if os.IsExist(err) {
_ = netlink.QdiscDel(qdisc)
err = netlink.QdiscAdd(qdisc)
}
if err != nil {
e.Close()
return fmt.Errorf("cannot add clsact qdisc: %w", err)
}
}
filterAttrs := netlink.FilterAttrs{
LinkIndex: e.ifIndex,
Parent: netlink.HANDLE_MIN_INGRESS,
Handle: netlink.MakeHandle(0, 1),
Protocol: unix.ETH_P_IP,
Priority: 0,
}
filter := &netlink.BpfFilter{
FilterAttrs: filterAttrs,
Fd: objs.bpfPrograms.TcRedirIngressFunc.FD(),
Name: "mihomo-redir-ingress-" + e.ifName,
DirectAction: true,
}
if err := netlink.FilterAdd(filter); err != nil {
e.Close()
return fmt.Errorf("cannot attach ebpf object to filter ingress: %w", err)
}
e.filter = filter
filterAttrsEgress := netlink.FilterAttrs{
LinkIndex: e.ifIndex,
Parent: netlink.HANDLE_MIN_EGRESS,
Handle: netlink.MakeHandle(0, 1),
Protocol: unix.ETH_P_IP,
Priority: 0,
}
filterEgress := &netlink.BpfFilter{
FilterAttrs: filterAttrsEgress,
Fd: objs.bpfPrograms.TcRedirEgressFunc.FD(),
Name: "mihomo-redir-egress-" + e.ifName,
DirectAction: true,
}
if err := netlink.FilterAdd(filterEgress); err != nil {
e.Close()
return fmt.Errorf("cannot attach ebpf object to filter egress: %w", err)
}
e.filterEgress = filterEgress
return nil
}
func (e *EBpfRedirect) Close() {
if e.filter != nil {
_ = netlink.FilterDel(e.filter)
}
if e.filterEgress != nil {
_ = netlink.FilterDel(e.filterEgress)
}
if e.qdisc != nil {
_ = netlink.QdiscDel(e.qdisc)
}
if e.objs != nil {
_ = e.objs.Close()
}
_ = os.Remove(filepath.Join(e.bpfPath, "redir_params_map"))
_ = os.Remove(filepath.Join(e.bpfPath, "pair_original_dst_map"))
}
func (e *EBpfRedirect) Lookup(srcAddrPort netip.AddrPort) (socks5.Addr, error) {
rAddr := srcAddrPort.Addr().Unmap()
if rAddr.Is6() {
return nil, fmt.Errorf("remote address is ipv6")
}
srcIp := binary.BigEndian.Uint32(rAddr.AsSlice())
scrPort := srcAddrPort.Port()
key := bpfRedirInfo{
Sip: byteorder.HostToNetwork32(srcIp),
Sport: byteorder.HostToNetwork16(scrPort),
Dip: byteorder.HostToNetwork32(e.redirIp),
Dport: byteorder.HostToNetwork16(e.redirPort),
}
origin := bpfOriginInfo{}
err := e.originMap.Lookup(key, &origin)
if err != nil {
return nil, err
}
addr := make([]byte, net.IPv4len+3)
addr[0] = socks5.AtypIPv4
binary.BigEndian.PutUint32(addr[1:1+net.IPv4len], byteorder.NetworkToHost32(origin.Ip)) // big end
binary.BigEndian.PutUint16(addr[1+net.IPv4len:3+net.IPv4len], byteorder.NetworkToHost16(origin.Port)) // big end
return addr, nil
}

View file

@@ -1,139 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
package redir
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
type bpfOriginInfo struct {
Ip uint32
Port uint16
Pad uint16
}
type bpfRedirInfo struct {
Sip uint32
Dip uint32
Sport uint16
Dport uint16
}
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcRedirEgressFunc *ebpf.ProgramSpec `ebpf:"tc_redir_egress_func"`
TcRedirIngressFunc *ebpf.ProgramSpec `ebpf:"tc_redir_ingress_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PairOriginalDstMap *ebpf.MapSpec `ebpf:"pair_original_dst_map"`
RedirParamsMap *ebpf.MapSpec `ebpf:"redir_params_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PairOriginalDstMap *ebpf.Map `ebpf:"pair_original_dst_map"`
RedirParamsMap *ebpf.Map `ebpf:"redir_params_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PairOriginalDstMap,
m.RedirParamsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcRedirEgressFunc *ebpf.Program `ebpf:"tc_redir_egress_func"`
TcRedirIngressFunc *ebpf.Program `ebpf:"tc_redir_ingress_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcRedirEgressFunc,
p.TcRedirIngressFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

Binary file not shown.

View file

@@ -1,139 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64 || loong64
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64 loong64
package redir
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
type bpfOriginInfo struct {
Ip uint32
Port uint16
Pad uint16
}
type bpfRedirInfo struct {
Sip uint32
Dip uint32
Sport uint16
Dport uint16
}
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcRedirEgressFunc *ebpf.ProgramSpec `ebpf:"tc_redir_egress_func"`
TcRedirIngressFunc *ebpf.ProgramSpec `ebpf:"tc_redir_ingress_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PairOriginalDstMap *ebpf.MapSpec `ebpf:"pair_original_dst_map"`
RedirParamsMap *ebpf.MapSpec `ebpf:"redir_params_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PairOriginalDstMap *ebpf.Map `ebpf:"pair_original_dst_map"`
RedirParamsMap *ebpf.Map `ebpf:"redir_params_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PairOriginalDstMap,
m.RedirParamsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcRedirEgressFunc *ebpf.Program `ebpf:"tc_redir_egress_func"`
TcRedirIngressFunc *ebpf.Program `ebpf:"tc_redir_ingress_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcRedirEgressFunc,
p.TcRedirIngressFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfel.o
var _BpfBytes []byte
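All of the generated loaders above follow the same bpf2go shape: an embedded ELF object (_BpfBytes), a loadBpfObjects helper, and Close methods that release every program and map. A minimal usage sketch, assuming it sits in the same redir package as the generated code; the pin path is a placeholder:

```go
package redir

import (
	"log"

	"github.com/cilium/ebpf"
)

// attachSketch shows the intended call sequence; bpfObjects and
// loadBpfObjects are the generated symbols above.
func attachSketch() {
	var objs bpfObjects
	// PinPath is optional; it is only needed when the maps should be
	// pinned to a bpffs directory (placeholder path below).
	if err := loadBpfObjects(&objs, &ebpf.CollectionOptions{
		Maps: ebpf.MapOptions{PinPath: "/sys/fs/bpf/redir"},
	}); err != nil {
		log.Fatalf("loading objects: %v", err)
	}
	defer objs.Close() // closes both bpfPrograms and bpfMaps

	// objs.TcRedirEgressFunc / objs.TcRedirIngressFunc can now be attached
	// as tc filters; PairOriginalDstMap and RedirParamsMap are ready for
	// lookups and updates.
}
```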

Binary file not shown.

View file

@ -1,120 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
package tc
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcTunFunc *ebpf.ProgramSpec `ebpf:"tc_tun_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
TcParamsMap *ebpf.MapSpec `ebpf:"tc_params_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
TcParamsMap *ebpf.Map `ebpf:"tc_params_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.TcParamsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcTunFunc *ebpf.Program `ebpf:"tc_tun_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcTunFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

Binary file not shown.

View file

@ -1,120 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64 || loong64
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64 loong64
package tc
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcTunFunc *ebpf.ProgramSpec `ebpf:"tc_tun_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
TcParamsMap *ebpf.MapSpec `ebpf:"tc_params_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
TcParamsMap *ebpf.Map `ebpf:"tc_params_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.TcParamsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcTunFunc *ebpf.Program `ebpf:"tc_tun_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcTunFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfel.o
var _BpfBytes []byte

Binary file not shown.

View file

@ -1,147 +0,0 @@
//go:build linux
package tc
import (
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/rlimit"
"github.com/sagernet/netlink"
"golang.org/x/sys/unix"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5"
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -cflags $BPF_CFLAGS bpf ../bpf/tc.c
const (
mapKey1 uint32 = 0
mapKey2 uint32 = 1
)
type EBpfTC struct {
objs io.Closer
qdisc netlink.Qdisc
filter netlink.Filter
ifName string
ifIndex int
ifMark uint32
tunIfIndex uint32
bpfPath string
}
func NewEBpfTc(ifName string, ifIndex int, ifMark uint32, tunIfIndex uint32) *EBpfTC {
return &EBpfTC{
ifName: ifName,
ifIndex: ifIndex,
ifMark: ifMark,
tunIfIndex: tunIfIndex,
}
}
func (e *EBpfTC) Start() error {
if err := rlimit.RemoveMemlock(); err != nil {
return fmt.Errorf("remove memory lock: %w", err)
}
e.bpfPath = filepath.Join(C.BpfFSPath, e.ifName)
if err := os.MkdirAll(e.bpfPath, os.ModePerm); err != nil {
return fmt.Errorf("failed to create bpf fs subpath: %w", err)
}
var objs bpfObjects
if err := loadBpfObjects(&objs, &ebpf.CollectionOptions{
Maps: ebpf.MapOptions{
PinPath: e.bpfPath,
},
}); err != nil {
e.Close()
return fmt.Errorf("loading objects: %w", err)
}
e.objs = &objs
if err := objs.bpfMaps.TcParamsMap.Update(mapKey1, e.ifMark, ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
if err := objs.bpfMaps.TcParamsMap.Update(mapKey2, e.tunIfIndex, ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
attrs := netlink.QdiscAttrs{
LinkIndex: e.ifIndex,
Handle: netlink.MakeHandle(0xffff, 0),
Parent: netlink.HANDLE_CLSACT,
}
qdisc := &netlink.GenericQdisc{
QdiscAttrs: attrs,
QdiscType: "clsact",
}
e.qdisc = qdisc
if err := netlink.QdiscAdd(qdisc); err != nil {
if os.IsExist(err) {
_ = netlink.QdiscDel(qdisc)
err = netlink.QdiscAdd(qdisc)
}
if err != nil {
e.Close()
return fmt.Errorf("cannot add clsact qdisc: %w", err)
}
}
filterAttrs := netlink.FilterAttrs{
LinkIndex: e.ifIndex,
Parent: netlink.HANDLE_MIN_EGRESS,
Handle: netlink.MakeHandle(0, 1),
Protocol: unix.ETH_P_ALL,
Priority: 1,
}
filter := &netlink.BpfFilter{
FilterAttrs: filterAttrs,
Fd: objs.bpfPrograms.TcTunFunc.FD(),
Name: "mihomo-tc-" + e.ifName,
DirectAction: true,
}
if err := netlink.FilterAdd(filter); err != nil {
e.Close()
return fmt.Errorf("cannot attach ebpf object to filter: %w", err)
}
e.filter = filter
return nil
}
func (e *EBpfTC) Close() {
if e.filter != nil {
_ = netlink.FilterDel(e.filter)
}
if e.qdisc != nil {
_ = netlink.QdiscDel(e.qdisc)
}
if e.objs != nil {
_ = e.objs.Close()
}
_ = os.Remove(filepath.Join(e.bpfPath, "tc_params_map"))
}
func (e *EBpfTC) Lookup(_ netip.AddrPort) (socks5.Addr, error) {
return nil, fmt.Errorf("not supported")
}
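Taken together, Start builds the whole attach chain: raise the memlock limit, load and pin the objects, seed tc_params_map with the firewall mark and TUN interface index, add a clsact qdisc, and attach the program as an egress BPF filter in direct-action mode; Close unwinds it in reverse. A hedged lifecycle sketch — the import path, interface name, index, mark, and TUN index are placeholders:

```go
//go:build linux

package main

import (
	"log"

	// Import path assumed for illustration; the diff only shows "package tc".
	tc "github.com/metacubex/mihomo/component/ebpf/tc"
)

func main() {
	// Placeholder values: eth0 at ifindex 2, fwmark 2158, TUN ifindex 5.
	e := tc.NewEBpfTc("eth0", 2, 2158, 5)
	if err := e.Start(); err != nil {
		log.Fatalf("attach tc redirect: %v", err)
	}
	defer e.Close() // removes the filter, the clsact qdisc and the pinned map

	// ... run until shutdown ...
}
```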

View file

@ -67,8 +67,9 @@ func (m *memoryStore) CloneTo(store store) {
// FlushFakeIP implements store.FlushFakeIP
func (m *memoryStore) FlushFakeIP() error {
_ = m.cacheIP.Clear()
return m.cacheHost.Clear()
m.cacheIP.Clear()
m.cacheHost.Clear()
return nil
}
func newMemoryStore(size int) *memoryStore {

View file

@ -8,7 +8,7 @@ import (
"github.com/metacubex/mihomo/common/nnip"
"github.com/metacubex/mihomo/component/profile/cachefile"
"github.com/metacubex/mihomo/component/trie"
C "github.com/metacubex/mihomo/constant"
)
const (
@ -35,7 +35,8 @@ type Pool struct {
offset netip.Addr
cycle bool
mux sync.Mutex
host *trie.DomainTrie[struct{}]
host []C.DomainMatcher
mode C.FilterMode
ipnet netip.Prefix
store store
}
@ -66,10 +67,20 @@ func (p *Pool) LookBack(ip netip.Addr) (string, bool) {
// ShouldSkipped return if domain should be skipped
func (p *Pool) ShouldSkipped(domain string) bool {
if p.host == nil {
return false
should := p.shouldSkipped(domain)
if p.mode == C.FilterWhiteList {
return !should
}
return p.host.Search(domain) != nil
return should
}
func (p *Pool) shouldSkipped(domain string) bool {
for _, matcher := range p.host {
if matcher.MatchDomain(domain) {
return true
}
}
return false
}
// Exist returns if given ip exists in fake-ip pool
@ -154,7 +165,8 @@ func (p *Pool) restoreState() {
type Options struct {
IPNet netip.Prefix
Host *trie.DomainTrie[struct{}]
Host []C.DomainMatcher
Mode C.FilterMode
// Size sets the maximum number of entries in memory
// and does not work if Persistence is true
@ -185,6 +197,7 @@ func New(options Options) (*Pool, error) {
offset: first.Prev(),
cycle: false,
host: options.Host,
mode: options.Mode,
ipnet: options.IPNet,
}
if options.Persistence {
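ShouldSkipped now composes two pieces: a list of C.DomainMatcher values and a filter mode. In the default blacklist mode a matched domain is skipped (it keeps its real IP); in whitelist mode the result is inverted, so only matched domains receive fake IPs. A small self-contained sketch of that decision, with a trivial suffix matcher standing in for the real trie-backed DomainSet:

```go
package main

import (
	"fmt"
	"strings"
)

// DomainMatcher mirrors the interface used above; suffixMatcher is a
// stand-in for the trie-backed matcher in the real code.
type DomainMatcher interface{ MatchDomain(domain string) bool }

type suffixMatcher string

func (s suffixMatcher) MatchDomain(d string) bool { return strings.HasSuffix(d, string(s)) }

type filterMode int

const (
	blackList filterMode = iota // matched domains are skipped (keep real IPs)
	whiteList                   // only matched domains get fake IPs
)

func shouldSkip(host []DomainMatcher, mode filterMode, domain string) bool {
	matched := false
	for _, m := range host {
		if m.MatchDomain(domain) {
			matched = true
			break
		}
	}
	if mode == whiteList {
		return !matched
	}
	return matched
}

func main() {
	host := []DomainMatcher{suffixMatcher("example.com")}
	fmt.Println(shouldSkip(host, blackList, "www.example.com")) // true: bypass fake-ip
	fmt.Println(shouldSkip(host, whiteList, "www.example.com")) // false: gets a fake IP
	fmt.Println(shouldSkip(host, whiteList, "foo.com"))         // true: bypass fake-ip
}
```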

View file

@ -9,8 +9,9 @@ import (
"github.com/metacubex/mihomo/component/profile/cachefile"
"github.com/metacubex/mihomo/component/trie"
C "github.com/metacubex/mihomo/constant"
"github.com/sagernet/bbolt"
"github.com/metacubex/bbolt"
"github.com/stretchr/testify/assert"
)
@ -62,13 +63,13 @@ func TestPool_Basic(t *testing.T) {
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, pool.Lookup("foo.com") == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 5}))
assert.Equal(t, first, netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.Equal(t, pool.Lookup("foo.com"), netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.Equal(t, last, netip.AddrFrom4([4]byte{192, 168, 0, 5}))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
assert.True(t, pool.Gateway() == netip.AddrFrom4([4]byte{192, 168, 0, 1}))
assert.True(t, pool.Broadcast() == netip.AddrFrom4([4]byte{192, 168, 0, 15}))
assert.Equal(t, pool.Gateway(), netip.AddrFrom4([4]byte{192, 168, 0, 1}))
assert.Equal(t, pool.Broadcast(), netip.AddrFrom4([4]byte{192, 168, 0, 15}))
assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 5})))
assert.False(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 6})))
@ -90,13 +91,13 @@ func TestPool_BasicV6(t *testing.T) {
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
assert.True(t, first == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.True(t, pool.Lookup("foo.com") == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.True(t, last == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805"))
assert.Equal(t, first, netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.Equal(t, pool.Lookup("foo.com"), netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.Equal(t, last, netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805"))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
assert.True(t, pool.Gateway() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8801"))
assert.True(t, pool.Broadcast() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8bff"))
assert.Equal(t, pool.Gateway(), netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8801"))
assert.Equal(t, pool.Broadcast(), netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8bff"))
assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805")))
assert.False(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8806")))
@ -142,19 +143,20 @@ func TestPool_CycleUsed(t *testing.T) {
}
baz := pool.Lookup("baz.com")
next := pool.Lookup("foo.com")
assert.True(t, foo == baz)
assert.True(t, next == bar)
assert.Equal(t, foo, baz)
assert.Equal(t, next, bar)
}
}
func TestPool_Skip(t *testing.T) {
ipnet := netip.MustParsePrefix("192.168.0.1/29")
tree := trie.New[struct{}]()
tree.Insert("example.com", struct{}{})
assert.NoError(t, tree.Insert("example.com", struct{}{}))
assert.False(t, tree.IsEmpty())
pools, tempfile, err := createPools(Options{
IPNet: ipnet,
Size: 10,
Host: tree,
Host: []C.DomainMatcher{tree.NewDomainSet()},
})
assert.Nil(t, err)
defer os.Remove(tempfile)
@ -162,6 +164,28 @@ func TestPool_Skip(t *testing.T) {
for _, pool := range pools {
assert.True(t, pool.ShouldSkipped("example.com"))
assert.False(t, pool.ShouldSkipped("foo.com"))
assert.False(t, pool.shouldSkipped("baz.com"))
}
}
func TestPool_SkipWhiteList(t *testing.T) {
ipnet := netip.MustParsePrefix("192.168.0.1/29")
tree := trie.New[struct{}]()
assert.NoError(t, tree.Insert("example.com", struct{}{}))
assert.False(t, tree.IsEmpty())
pools, tempfile, err := createPools(Options{
IPNet: ipnet,
Size: 10,
Host: []C.DomainMatcher{tree.NewDomainSet()},
Mode: C.FilterWhiteList,
})
assert.Nil(t, err)
defer os.Remove(tempfile)
for _, pool := range pools {
assert.False(t, pool.ShouldSkipped("example.com"))
assert.True(t, pool.ShouldSkipped("foo.com"))
assert.True(t, pool.ShouldSkipped("baz.com"))
}
}
@ -177,7 +201,7 @@ func TestPool_MaxCacheSize(t *testing.T) {
pool.Lookup("baz.com")
next := pool.Lookup("foo.com")
assert.False(t, first == next)
assert.NotEqual(t, first, next)
}
func TestPool_DoubleMapping(t *testing.T) {
@ -207,7 +231,7 @@ func TestPool_DoubleMapping(t *testing.T) {
assert.False(t, bazExist)
assert.True(t, barExist)
assert.False(t, bazIP == newBazIP)
assert.NotEqual(t, bazIP, newBazIP)
}
func TestPool_Clone(t *testing.T) {
@ -219,8 +243,8 @@ func TestPool_Clone(t *testing.T) {
first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com")
assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 5}))
assert.Equal(t, first, netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.Equal(t, last, netip.AddrFrom4([4]byte{192, 168, 0, 5}))
newPool, _ := New(Options{
IPNet: ipnet,
@ -265,13 +289,13 @@ func TestPool_FlushFileCache(t *testing.T) {
baz := pool.Lookup("foo.com")
nero := pool.Lookup("foo.com")
assert.True(t, foo == fox)
assert.True(t, foo == next)
assert.False(t, foo == baz)
assert.True(t, bar == bax)
assert.True(t, bar == baz)
assert.False(t, bar == next)
assert.True(t, baz == nero)
assert.Equal(t, foo, fox)
assert.Equal(t, foo, next)
assert.NotEqual(t, foo, baz)
assert.Equal(t, bar, bax)
assert.Equal(t, bar, baz)
assert.NotEqual(t, bar, next)
assert.Equal(t, baz, nero)
}
}
@ -294,11 +318,11 @@ func TestPool_FlushMemoryCache(t *testing.T) {
baz := pool.Lookup("foo.com")
nero := pool.Lookup("foo.com")
assert.True(t, foo == fox)
assert.True(t, foo == next)
assert.False(t, foo == baz)
assert.True(t, bar == bax)
assert.True(t, bar == baz)
assert.False(t, bar == next)
assert.True(t, baz == nero)
assert.Equal(t, foo, fox)
assert.Equal(t, foo, next)
assert.NotEqual(t, foo, baz)
assert.Equal(t, bar, bax)
assert.Equal(t, bar, baz)
assert.NotEqual(t, bar, next)
assert.Equal(t, baz, nero)
}

View file

@ -7,7 +7,7 @@ import (
)
type AttributeList struct {
matcher []AttributeMatcher
matcher []BooleanMatcher
}
func (al *AttributeList) Match(domain *router.Domain) bool {
@ -23,6 +23,14 @@ func (al *AttributeList) IsEmpty() bool {
return len(al.matcher) == 0
}
func (al *AttributeList) String() string {
matcher := make([]string, len(al.matcher))
for i, match := range al.matcher {
matcher[i] = string(match)
}
return strings.Join(matcher, ",")
}
func parseAttrs(attrs []string) *AttributeList {
al := new(AttributeList)
for _, attr := range attrs {

View file

@ -6,8 +6,10 @@ import (
"io"
"net/http"
"os"
"sync"
"time"
"github.com/metacubex/mihomo/common/atomic"
mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/mihomo/component/mmdb"
C "github.com/metacubex/mihomo/constant"
@ -18,12 +20,79 @@ var (
initGeoSite bool
initGeoIP int
initASN bool
initGeoSiteMutex sync.Mutex
initGeoIPMutex sync.Mutex
initASNMutex sync.Mutex
geoIpEnable atomic.Bool
geoSiteEnable atomic.Bool
asnEnable atomic.Bool
geoIpUrl string
mmdbUrl string
geoSiteUrl string
asnUrl string
)
func GeoIpUrl() string {
return geoIpUrl
}
func SetGeoIpUrl(url string) {
geoIpUrl = url
}
func MmdbUrl() string {
return mmdbUrl
}
func SetMmdbUrl(url string) {
mmdbUrl = url
}
func GeoSiteUrl() string {
return geoSiteUrl
}
func SetGeoSiteUrl(url string) {
geoSiteUrl = url
}
func ASNUrl() string {
return asnUrl
}
func SetASNUrl(url string) {
asnUrl = url
}
func downloadToPath(url string, path string) (err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, url, http.MethodGet, nil, nil)
if err != nil {
return
}
defer resp.Body.Close()
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
}
func InitGeoSite() error {
geoSiteEnable.Store(true)
initGeoSiteMutex.Lock()
defer initGeoSiteMutex.Unlock()
if _, err := os.Stat(C.Path.GeoSite()); os.IsNotExist(err) {
log.Infoln("Can't find GeoSite.dat, start download")
if err := downloadGeoSite(C.Path.GeoSite()); err != nil {
if err := downloadToPath(GeoSiteUrl(), C.Path.GeoSite()); err != nil {
return fmt.Errorf("can't download GeoSite.dat: %s", err.Error())
}
log.Infoln("Download GeoSite.dat finish")
@ -35,7 +104,7 @@ func InitGeoSite() error {
if err := os.Remove(C.Path.GeoSite()); err != nil {
return fmt.Errorf("can't remove invalid GeoSite.dat: %s", err.Error())
}
if err := downloadGeoSite(C.Path.GeoSite()); err != nil {
if err := downloadToPath(GeoSiteUrl(), C.Path.GeoSite()); err != nil {
return fmt.Errorf("can't download GeoSite.dat: %s", err.Error())
}
}
@ -44,49 +113,14 @@ func InitGeoSite() error {
return nil
}
func downloadGeoSite(path string) (err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, C.GeoSiteUrl, http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
if err != nil {
return
}
defer resp.Body.Close()
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
}
func downloadGeoIP(path string) (err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, C.GeoIpUrl, http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
if err != nil {
return
}
defer resp.Body.Close()
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
}
func InitGeoIP() error {
if C.GeodataMode {
geoIpEnable.Store(true)
initGeoIPMutex.Lock()
defer initGeoIPMutex.Unlock()
if GeodataMode() {
if _, err := os.Stat(C.Path.GeoIP()); os.IsNotExist(err) {
log.Infoln("Can't find GeoIP.dat, start download")
if err := downloadGeoIP(C.Path.GeoIP()); err != nil {
if err := downloadToPath(GeoIpUrl(), C.Path.GeoIP()); err != nil {
return fmt.Errorf("can't download GeoIP.dat: %s", err.Error())
}
log.Infoln("Download GeoIP.dat finish")
@ -99,7 +133,7 @@ func InitGeoIP() error {
if err := os.Remove(C.Path.GeoIP()); err != nil {
return fmt.Errorf("can't remove invalid GeoIP.dat: %s", err.Error())
}
if err := downloadGeoIP(C.Path.GeoIP()); err != nil {
if err := downloadToPath(GeoIpUrl(), C.Path.GeoIP()); err != nil {
return fmt.Errorf("can't download GeoIP.dat: %s", err.Error())
}
}
@ -110,7 +144,7 @@ func InitGeoIP() error {
if _, err := os.Stat(C.Path.MMDB()); os.IsNotExist(err) {
log.Infoln("Can't find MMDB, start download")
if err := mmdb.DownloadMMDB(C.Path.MMDB()); err != nil {
if err := downloadToPath(MmdbUrl(), C.Path.MMDB()); err != nil {
return fmt.Errorf("can't download MMDB: %s", err.Error())
}
}
@ -121,7 +155,7 @@ func InitGeoIP() error {
if err := os.Remove(C.Path.MMDB()); err != nil {
return fmt.Errorf("can't remove invalid MMDB: %s", err.Error())
}
if err := mmdb.DownloadMMDB(C.Path.MMDB()); err != nil {
if err := downloadToPath(MmdbUrl(), C.Path.MMDB()); err != nil {
return fmt.Errorf("can't download MMDB: %s", err.Error())
}
}
@ -131,9 +165,12 @@ func InitGeoIP() error {
}
func InitASN() error {
asnEnable.Store(true)
initASNMutex.Lock()
defer initASNMutex.Unlock()
if _, err := os.Stat(C.Path.ASN()); os.IsNotExist(err) {
log.Infoln("Can't find ASN.mmdb, start download")
if err := mmdb.DownloadASN(C.Path.ASN()); err != nil {
if err := downloadToPath(ASNUrl(), C.Path.ASN()); err != nil {
return fmt.Errorf("can't download ASN.mmdb: %s", err.Error())
}
log.Infoln("Download ASN.mmdb finish")
@ -145,7 +182,7 @@ func InitASN() error {
if err := os.Remove(C.Path.ASN()); err != nil {
return fmt.Errorf("can't remove invalid ASN: %s", err.Error())
}
if err := mmdb.DownloadASN(C.Path.ASN()); err != nil {
if err := downloadToPath(ASNUrl(), C.Path.ASN()); err != nil {
return fmt.Errorf("can't download ASN: %s", err.Error())
}
}
@ -153,3 +190,15 @@ func InitASN() error {
}
return nil
}
func GeoIpEnable() bool {
return geoIpEnable.Load()
}
func GeoSiteEnable() bool {
return geoSiteEnable.Load()
}
func ASNEnable() bool {
return asnEnable.Load()
}
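The download URLs are no longer compile-time constants on C; they are set through package-level setters, and each Init* call serializes behind its own mutex while recording that the corresponding database is enabled. A hedged startup sketch — the import path and URLs are placeholders, not the project's defaults:

```go
package main

import (
	"log"

	// Import path assumed for illustration.
	"github.com/metacubex/mihomo/component/geodata"
)

func main() {
	// Configure the sources before the first Init* call; the downloader
	// only runs when a file is missing or fails verification.
	geodata.SetGeoIpUrl("https://example.com/geoip.dat")     // placeholder URL
	geodata.SetMmdbUrl("https://example.com/country.mmdb")   // placeholder URL
	geodata.SetGeoSiteUrl("https://example.com/geosite.dat") // placeholder URL
	geodata.SetASNUrl("https://example.com/asn.mmdb")        // placeholder URL

	if err := geodata.InitGeoSite(); err != nil {
		log.Fatalf("geosite: %v", err)
	}
	if err := geodata.InitGeoIP(); err != nil {
		log.Fatalf("geoip: %v", err)
	}
}
```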

View file

@ -33,12 +33,13 @@ func domainToMatcher(domain *Domain) (strmatcher.Matcher, error) {
type DomainMatcher interface {
ApplyDomain(string) bool
Count() int
}
type succinctDomainMatcher struct {
set *trie.DomainSet
otherMatchers []strmatcher.Matcher
not bool
count int
}
func (m *succinctDomainMatcher) ApplyDomain(domain string) bool {
@ -51,16 +52,17 @@ func (m *succinctDomainMatcher) ApplyDomain(domain string) bool {
}
}
}
if m.not {
isMatched = !isMatched
}
return isMatched
}
func NewSuccinctMatcherGroup(domains []*Domain, not bool) (DomainMatcher, error) {
func (m *succinctDomainMatcher) Count() int {
return m.count
}
func NewSuccinctMatcherGroup(domains []*Domain) (DomainMatcher, error) {
t := trie.New[struct{}]()
m := &succinctDomainMatcher{
not: not,
count: len(domains),
}
for _, d := range domains {
switch d.Type {
@ -90,10 +92,10 @@ func NewSuccinctMatcherGroup(domains []*Domain, not bool) (DomainMatcher, error)
type v2rayDomainMatcher struct {
matchers strmatcher.IndexMatcher
not bool
count int
}
func NewMphMatcherGroup(domains []*Domain, not bool) (DomainMatcher, error) {
func NewMphMatcherGroup(domains []*Domain) (DomainMatcher, error) {
g := strmatcher.NewMphMatcherGroup()
for _, d := range domains {
matcherType, f := matcherTypeMap[d.Type]
@ -108,119 +110,80 @@ func NewMphMatcherGroup(domains []*Domain, not bool) (DomainMatcher, error) {
g.Build()
return &v2rayDomainMatcher{
matchers: g,
not: not,
count: len(domains),
}, nil
}
func (m *v2rayDomainMatcher) ApplyDomain(domain string) bool {
isMatched := len(m.matchers.Match(strings.ToLower(domain))) > 0
if m.not {
isMatched = !isMatched
}
return isMatched
return len(m.matchers.Match(strings.ToLower(domain))) > 0
}
type GeoIPMatcher struct {
countryCode string
reverseMatch bool
cidrSet *cidr.IpCidrSet
func (m *v2rayDomainMatcher) Count() int {
return m.count
}
func (m *GeoIPMatcher) Init(cidrs []*CIDR) error {
for _, cidr := range cidrs {
addr, ok := netip.AddrFromSlice(cidr.Ip)
if !ok {
return fmt.Errorf("error when loading GeoIP: invalid IP: %s", cidr.Ip)
}
err := m.cidrSet.AddIpCidr(netip.PrefixFrom(addr, int(cidr.Prefix)))
if err != nil {
return fmt.Errorf("error when loading GeoIP: %w", err)
}
}
return m.cidrSet.Merge()
type notDomainMatcher struct {
DomainMatcher
}
func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) {
m.reverseMatch = isReverseMatch
func (m notDomainMatcher) ApplyDomain(domain string) bool {
return !m.DomainMatcher.ApplyDomain(domain)
}
func NewNotDomainMatcherGroup(matcher DomainMatcher) DomainMatcher {
return notDomainMatcher{matcher}
}
type IPMatcher interface {
Match(ip netip.Addr) bool
Count() int
}
type geoIPMatcher struct {
cidrSet *cidr.IpCidrSet
count int
}
// Match returns true if the given ip is included by the GeoIP.
func (m *GeoIPMatcher) Match(ip netip.Addr) bool {
match := m.cidrSet.IsContain(ip)
if m.reverseMatch {
return !match
func (m *geoIPMatcher) Match(ip netip.Addr) bool {
return m.cidrSet.IsContain(ip)
}
func (m *geoIPMatcher) Count() int {
return m.count
}
func NewGeoIPMatcher(cidrList []*CIDR) (IPMatcher, error) {
m := &geoIPMatcher{
cidrSet: cidr.NewIpCidrSet(),
count: len(cidrList),
}
return match
}
// GeoIPMatcherContainer is a container for GeoIPMatchers. It keeps unique copies of GeoIPMatcher by country code.
type GeoIPMatcherContainer struct {
matchers []*GeoIPMatcher
}
// Add adds a new GeoIP set into the container.
// If the country code of GeoIP is not empty, GeoIPMatcherContainer will try to find an existing one, instead of adding a new one.
func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) {
if len(geoip.CountryCode) > 0 {
for _, m := range c.matchers {
if m.countryCode == geoip.CountryCode && m.reverseMatch == geoip.ReverseMatch {
return m, nil
}
for _, cidr := range cidrList {
addr, ok := netip.AddrFromSlice(cidr.Ip)
if !ok {
return nil, fmt.Errorf("error when loading GeoIP: invalid IP: %s", cidr.Ip)
}
err := m.cidrSet.AddIpCidr(netip.PrefixFrom(addr, int(cidr.Prefix)))
if err != nil {
return nil, fmt.Errorf("error when loading GeoIP: %w", err)
}
}
m := &GeoIPMatcher{
countryCode: geoip.CountryCode,
reverseMatch: geoip.ReverseMatch,
cidrSet: cidr.NewIpCidrSet(),
}
if err := m.Init(geoip.Cidr); err != nil {
return nil, err
}
if len(geoip.CountryCode) > 0 {
c.matchers = append(c.matchers, m)
}
return m, nil
}
var globalGeoIPContainer GeoIPMatcherContainer
type MultiGeoIPMatcher struct {
matchers []*GeoIPMatcher
}
func NewGeoIPMatcher(geoip *GeoIP) (*GeoIPMatcher, error) {
matcher, err := globalGeoIPContainer.Add(geoip)
err := m.cidrSet.Merge()
if err != nil {
return nil, err
}
return matcher, nil
return m, nil
}
func (m *MultiGeoIPMatcher) ApplyIp(ip netip.Addr) bool {
for _, matcher := range m.matchers {
if matcher.Match(ip) {
return true
}
}
return false
type notIPMatcher struct {
IPMatcher
}
func NewMultiGeoIPMatcher(geoips []*GeoIP) (*MultiGeoIPMatcher, error) {
var matchers []*GeoIPMatcher
for _, geoip := range geoips {
matcher, err := globalGeoIPContainer.Add(geoip)
if err != nil {
return nil, err
}
matchers = append(matchers, matcher)
}
matcher := &MultiGeoIPMatcher{
matchers: matchers,
}
return matcher, nil
func (m notIPMatcher) Match(ip netip.Addr) bool {
return !m.IPMatcher.Match(ip)
}
func NewNotIpMatcherGroup(matcher IPMatcher) IPMatcher {
return notIPMatcher{matcher}
}
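The ReverseMatch flag and GeoIPMatcherContainer are gone; negation (the `not` flag whose parsing is elided in the hunk above) is now a thin decorator around the plain matcher interfaces, so a negated rule wraps the cached matcher instead of building a second CIDR set. A minimal sketch of the same decorator idea over a tiny IP matcher; the types here are stand-ins, not the router package's:

```go
package main

import (
	"fmt"
	"net/netip"
)

// IPMatcher mirrors the interface above.
type IPMatcher interface {
	Match(ip netip.Addr) bool
	Count() int
}

type prefixMatcher struct{ p netip.Prefix }

func (m prefixMatcher) Match(ip netip.Addr) bool { return m.p.Contains(ip) }
func (m prefixMatcher) Count() int               { return 1 }

// notIPMatcher embeds the wrapped matcher, so Count() is inherited and
// only Match() is inverted -- the same shape as NewNotIpMatcherGroup.
type notIPMatcher struct{ IPMatcher }

func (m notIPMatcher) Match(ip netip.Addr) bool { return !m.IPMatcher.Match(ip) }

func main() {
	cn := prefixMatcher{netip.MustParsePrefix("203.0.113.0/24")} // placeholder range
	notCN := notIPMatcher{cn}

	ip := netip.MustParseAddr("203.0.113.7")
	fmt.Println(cn.Match(ip), notCN.Match(ip)) // true false
}
```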

View file

@ -5,8 +5,7 @@ import (
"fmt"
"strings"
"golang.org/x/sync/singleflight"
"github.com/metacubex/mihomo/common/singleflight"
"github.com/metacubex/mihomo/component/geodata/router"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
@ -14,8 +13,6 @@ import (
var (
geoMode bool
AutoUpdate bool
UpdateInterval int
geoLoaderName = "memconservative"
geoSiteMatcher = "succinct"
)
@ -26,14 +23,6 @@ func GeodataMode() bool {
return geoMode
}
func GeoAutoUpdate() bool {
return AutoUpdate
}
func GeoUpdateInterval() int {
return UpdateInterval
}
func LoaderName() string {
return geoLoaderName
}
@ -45,12 +34,6 @@ func SiteMatcherName() string {
func SetGeodataMode(newGeodataMode bool) {
geoMode = newGeodataMode
}
func SetGeoAutoUpdate(newAutoUpdate bool) {
AutoUpdate = newAutoUpdate
}
func SetGeoUpdateInterval(newGeoUpdateInterval int) {
UpdateInterval = newGeoUpdateInterval
}
func SetLoader(newLoader string) {
if newLoader == "memc" {
@ -71,21 +54,22 @@ func SetSiteMatcher(newMatcher string) {
func Verify(name string) error {
switch name {
case C.GeositeName:
_, _, err := LoadGeoSiteMatcher("CN")
_, err := LoadGeoSiteMatcher("CN")
return err
case C.GeoipName:
_, _, err := LoadGeoIPMatcher("CN")
_, err := LoadGeoIPMatcher("CN")
return err
default:
return fmt.Errorf("not support name")
}
}
var loadGeoSiteMatcherSF = singleflight.Group{}
var loadGeoSiteMatcherListSF = singleflight.Group[[]*router.Domain]{StoreResult: true}
var loadGeoSiteMatcherSF = singleflight.Group[router.DomainMatcher]{StoreResult: true}
func LoadGeoSiteMatcher(countryCode string) (router.DomainMatcher, int, error) {
func LoadGeoSiteMatcher(countryCode string) (router.DomainMatcher, error) {
if countryCode == "" {
return nil, 0, fmt.Errorf("country code could not be empty")
return nil, fmt.Errorf("country code could not be empty")
}
not := false
@ -97,73 +81,84 @@ func LoadGeoSiteMatcher(countryCode string) (router.DomainMatcher, int, error) {
parts := strings.Split(countryCode, "@")
if len(parts) == 0 {
return nil, 0, errors.New("empty rule")
return nil, errors.New("empty rule")
}
listName := strings.TrimSpace(parts[0])
attrVal := parts[1:]
attrs := parseAttrs(attrVal)
if listName == "" {
return nil, 0, fmt.Errorf("empty listname in rule: %s", countryCode)
return nil, fmt.Errorf("empty listname in rule: %s", countryCode)
}
v, err, shared := loadGeoSiteMatcherSF.Do(listName, func() (interface{}, error) {
geoLoader, err := GetGeoDataLoader(geoLoaderName)
matcherName := listName
if !attrs.IsEmpty() {
matcherName += "@" + attrs.String()
}
matcher, err, shared := loadGeoSiteMatcherSF.Do(matcherName, func() (router.DomainMatcher, error) {
log.Infoln("Load GeoSite rule: %s", matcherName)
domains, err, shared := loadGeoSiteMatcherListSF.Do(listName, func() ([]*router.Domain, error) {
geoLoader, err := GetGeoDataLoader(geoLoaderName)
if err != nil {
return nil, err
}
return geoLoader.LoadGeoSite(listName)
})
if err != nil {
if !shared {
loadGeoSiteMatcherListSF.Forget(listName) // don't store the error result
}
return nil, err
}
return geoLoader.LoadGeoSite(listName)
if attrs.IsEmpty() {
if strings.Contains(countryCode, "@") {
log.Warnln("empty attribute list: %s", countryCode)
}
} else {
filteredDomains := make([]*router.Domain, 0, len(domains))
hasAttrMatched := false
for _, domain := range domains {
if attrs.Match(domain) {
hasAttrMatched = true
filteredDomains = append(filteredDomains, domain)
}
}
if !hasAttrMatched {
log.Warnln("attribute match no rule: geosite: %s", countryCode)
}
domains = filteredDomains
}
/**
linear: linear algorithm
matcher, err := router.NewDomainMatcher(domains)
mph: minimal perfect hash algorithm
*/
if geoSiteMatcher == "mph" {
return router.NewMphMatcherGroup(domains)
} else {
return router.NewSuccinctMatcherGroup(domains)
}
})
if err != nil {
if !shared {
loadGeoSiteMatcherSF.Forget(listName) // don't store the error result
loadGeoSiteMatcherSF.Forget(matcherName) // don't store the error result
}
return nil, 0, err
return nil, err
}
domains := v.([]*router.Domain)
attrs := parseAttrs(attrVal)
if attrs.IsEmpty() {
if strings.Contains(countryCode, "@") {
log.Warnln("empty attribute list: %s", countryCode)
}
} else {
filteredDomains := make([]*router.Domain, 0, len(domains))
hasAttrMatched := false
for _, domain := range domains {
if attrs.Match(domain) {
hasAttrMatched = true
filteredDomains = append(filteredDomains, domain)
}
}
if !hasAttrMatched {
log.Warnln("attribute match no rule: geosite: %s", countryCode)
}
domains = filteredDomains
if not {
matcher = router.NewNotDomainMatcherGroup(matcher)
}
/**
linear: linear algorithm
matcher, err := router.NewDomainMatcher(domains)
mph: minimal perfect hash algorithm
*/
var matcher router.DomainMatcher
if geoSiteMatcher == "mph" {
matcher, err = router.NewMphMatcherGroup(domains, not)
} else {
matcher, err = router.NewSuccinctMatcherGroup(domains, not)
}
if err != nil {
return nil, 0, err
}
return matcher, len(domains), nil
return matcher, nil
}
var loadGeoIPMatcherSF = singleflight.Group{}
var loadGeoIPMatcherSF = singleflight.Group[router.IPMatcher]{StoreResult: true}
func LoadGeoIPMatcher(country string) (*router.GeoIPMatcher, int, error) {
func LoadGeoIPMatcher(country string) (router.IPMatcher, error) {
if len(country) == 0 {
return nil, 0, fmt.Errorf("country code could not be empty")
return nil, fmt.Errorf("country code could not be empty")
}
not := false
@ -173,35 +168,36 @@ func LoadGeoIPMatcher(country string) (*router.GeoIPMatcher, int, error) {
}
country = strings.ToLower(country)
v, err, shared := loadGeoIPMatcherSF.Do(country, func() (interface{}, error) {
matcher, err, shared := loadGeoIPMatcherSF.Do(country, func() (router.IPMatcher, error) {
log.Infoln("Load GeoIP rule: %s", country)
geoLoader, err := GetGeoDataLoader(geoLoaderName)
if err != nil {
return nil, err
}
return geoLoader.LoadGeoIP(country)
cidrList, err := geoLoader.LoadGeoIP(country)
if err != nil {
return nil, err
}
return router.NewGeoIPMatcher(cidrList)
})
if err != nil {
if !shared {
loadGeoIPMatcherSF.Forget(country) // don't store the error result
log.Warnln("Load GeoIP rule: %s", country)
}
return nil, 0, err
return nil, err
}
records := v.([]*router.CIDR)
geoIP := &router.GeoIP{
CountryCode: country,
Cidr: records,
ReverseMatch: not,
if not {
matcher = router.NewNotIpMatcherGroup(matcher)
}
matcher, err := router.NewGeoIPMatcher(geoIP)
if err != nil {
return nil, 0, err
}
return matcher, len(records), nil
return matcher, nil
}
func ClearCache() {
loadGeoSiteMatcherSF = singleflight.Group{}
loadGeoIPMatcherSF = singleflight.Group{}
func ClearGeoSiteCache() {
loadGeoSiteMatcherListSF.Reset()
loadGeoSiteMatcherSF.Reset()
}
func ClearGeoIPCache() {
loadGeoIPMatcherSF.Reset()
}
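Both loaders now return the matcher directly (the domain/CIDR counts moved behind Count()), cache results in typed singleflight groups keyed by list name plus attributes, and forget failed results so a transient load error is retried. A hedged usage sketch — the import path is assumed and the list, attribute, and IP values are placeholders:

```go
package main

import (
	"fmt"
	"log"
	"net/netip"

	// Import path assumed for illustration.
	"github.com/metacubex/mihomo/component/geodata"
)

func main() {
	// "google@ads" loads the google list and keeps only entries carrying
	// the "ads" attribute, as filtered inside the singleflight callback.
	siteMatcher, err := geodata.LoadGeoSiteMatcher("google@ads")
	if err != nil {
		log.Fatalf("geosite: %v", err)
	}
	fmt.Println(siteMatcher.Count(), siteMatcher.ApplyDomain("www.google.com"))

	ipMatcher, err := geodata.LoadGeoIPMatcher("cn")
	if err != nil {
		log.Fatalf("geoip: %v", err)
	}
	fmt.Println(ipMatcher.Count(), ipMatcher.Match(netip.MustParseAddr("1.1.1.1")))
}
```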

View file

@ -12,10 +12,21 @@ import (
"time"
"github.com/metacubex/mihomo/component/ca"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/listener/inner"
)
var (
ua string
)
func UA() string {
return ua
}
func SetUA(UA string) {
ua = UA
}
func HttpRequest(ctx context.Context, url, method string, header map[string][]string, body io.Reader) (*http.Response, error) {
return HttpRequestWithProxy(ctx, url, method, header, body, "")
}
@ -35,7 +46,7 @@ func HttpRequestWithProxy(ctx context.Context, url, method string, header map[st
}
if _, ok := header["User-Agent"]; !ok {
req.Header.Set("User-Agent", C.UA)
req.Header.Set("User-Agent", UA())
}
if err != nil {

View file

@ -1,15 +1,9 @@
package mmdb
import (
"context"
"io"
"net/http"
"os"
"sync"
"time"
mihomoOnce "github.com/metacubex/mihomo/common/once"
mihomoHttp "github.com/metacubex/mihomo/component/http"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
@ -25,26 +19,26 @@ const (
)
var (
IPreader IPReader
ASNreader ASNReader
IPonce sync.Once
ASNonce sync.Once
ipReader IPReader
asnReader ASNReader
ipOnce sync.Once
asnOnce sync.Once
)
func LoadFromBytes(buffer []byte) {
IPonce.Do(func() {
ipOnce.Do(func() {
mmdb, err := maxminddb.FromBytes(buffer)
if err != nil {
log.Fatalln("Can't load mmdb: %s", err.Error())
}
IPreader = IPReader{Reader: mmdb}
ipReader = IPReader{Reader: mmdb}
switch mmdb.Metadata.DatabaseType {
case "sing-geoip":
IPreader.databaseType = typeSing
ipReader.databaseType = typeSing
case "Meta-geoip0":
IPreader.databaseType = typeMetaV0
ipReader.databaseType = typeMetaV0
default:
IPreader.databaseType = typeMaxmind
ipReader.databaseType = typeMaxmind
}
})
}
@ -58,83 +52,45 @@ func Verify(path string) bool {
}
func IPInstance() IPReader {
IPonce.Do(func() {
ipOnce.Do(func() {
mmdbPath := C.Path.MMDB()
log.Infoln("Load MMDB file: %s", mmdbPath)
mmdb, err := maxminddb.Open(mmdbPath)
if err != nil {
log.Fatalln("Can't load MMDB: %s", err.Error())
}
IPreader = IPReader{Reader: mmdb}
ipReader = IPReader{Reader: mmdb}
switch mmdb.Metadata.DatabaseType {
case "sing-geoip":
IPreader.databaseType = typeSing
ipReader.databaseType = typeSing
case "Meta-geoip0":
IPreader.databaseType = typeMetaV0
ipReader.databaseType = typeMetaV0
default:
IPreader.databaseType = typeMaxmind
ipReader.databaseType = typeMaxmind
}
})
return IPreader
}
func DownloadMMDB(path string) (err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, C.MmdbUrl, http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
if err != nil {
return
}
defer resp.Body.Close()
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
return ipReader
}
func ASNInstance() ASNReader {
ASNonce.Do(func() {
asnOnce.Do(func() {
ASNPath := C.Path.ASN()
log.Infoln("Load ASN file: %s", ASNPath)
asn, err := maxminddb.Open(ASNPath)
if err != nil {
log.Fatalln("Can't load ASN: %s", err.Error())
}
ASNreader = ASNReader{Reader: asn}
asnReader = ASNReader{Reader: asn}
})
return ASNreader
}
func DownloadASN(path string) (err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, C.ASNUrl, http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
if err != nil {
return
}
defer resp.Body.Close()
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
return asnReader
}
func ReloadIP() {
mihomoOnce.Reset(&IPonce)
mihomoOnce.Reset(&ipOnce)
}
func ReloadASN() {
mihomoOnce.Reset(&ASNonce)
mihomoOnce.Reset(&asnOnce)
}
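ReloadIP and ReloadASN rely on the project's once helper to reset a sync.Once so the next IPInstance/ASNInstance call reopens the database. That helper is not shown here; purely as an illustration of the behaviour it provides (not the actual implementation, which resets the standard sync.Once in place), a mutex-guarded equivalent looks like this:

```go
package main

import (
	"fmt"
	"sync"
)

// resettableOnce illustrates the reload semantics: Do runs the function
// once per "generation", and Reset starts a new generation.
type resettableOnce struct {
	mu   sync.Mutex
	done bool
}

func (o *resettableOnce) Do(f func()) {
	o.mu.Lock()
	defer o.mu.Unlock()
	if !o.done {
		f()
		o.done = true
	}
}

func (o *resettableOnce) Reset() {
	o.mu.Lock()
	defer o.mu.Unlock()
	o.done = false
}

func main() {
	var once resettableOnce
	load := func() { fmt.Println("open mmdb") }

	once.Do(load) // opens the database
	once.Do(load) // no-op
	once.Reset()  // e.g. after the file on disk was replaced
	once.Do(load) // reopens it
}
```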

View file

@ -1,18 +0,0 @@
//go:build android && cmfa
package mmdb
import "github.com/oschwald/maxminddb-golang"
func InstallOverride(override *maxminddb.Reader) {
newReader := IPReader{Reader: override}
switch override.Metadata.DatabaseType {
case "sing-geoip":
IPreader.databaseType = typeSing
case "Meta-geoip0":
IPreader.databaseType = typeMetaV0
default:
IPreader.databaseType = typeMaxmind
}
IPreader = newReader
}

View file

@ -55,6 +55,11 @@ func NewEventListener(cb func(Type)) (func(), error) {
}
handle := uintptr(0)
// DWORD PowerRegisterSuspendResumeNotification(
// [in] DWORD Flags,
// [in] HANDLE Recipient,
// [out] PHPOWERNOTIFY RegistrationHandle
//);
_, _, err := powerRegisterSuspendResumeNotification.Call(
_DEVICE_NOTIFY_CALLBACK,
uintptr(unsafe.Pointer(&params)),
@ -65,8 +70,11 @@ func NewEventListener(cb func(Type)) (func(), error) {
}
return func() {
// DWORD PowerUnregisterSuspendResumeNotification(
// [in, out] HPOWERNOTIFY RegistrationHandle
//);
_, _, _ = powerUnregisterSuspendResumeNotification.Call(
uintptr(unsafe.Pointer(&handle)),
handle,
)
runtime.KeepAlive(params)
runtime.KeepAlive(handle)

View file

@ -3,6 +3,8 @@ package process
import (
"errors"
"net/netip"
C "github.com/metacubex/mihomo/constant"
)
var (
@ -19,3 +21,18 @@ const (
func FindProcessName(network string, srcIP netip.Addr, srcPort int) (uint32, string, error) {
return findProcessName(network, srcIP, srcPort)
}
// PackageNameResolver
// never change type traits because it's used in CMFA
type PackageNameResolver func(metadata *C.Metadata) (string, error)
// DefaultPackageNameResolver
// never change type traits because it's used in CMFA
var DefaultPackageNameResolver PackageNameResolver
func FindPackageName(metadata *C.Metadata) (string, error) {
if resolver := DefaultPackageNameResolver; resolver != nil {
return resolver(metadata)
}
return "", ErrPlatformNotSupport
}
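The Android-only file was folded into the portable one: FindPackageName now always exists and simply returns ErrPlatformNotSupport until a resolver is installed (CMFA injects one at startup). A hedged sketch of installing a resolver — the import paths and the resolver body are placeholders:

```go
package main

import (
	"fmt"

	// Import paths assumed for illustration.
	"github.com/metacubex/mihomo/component/process"
	C "github.com/metacubex/mihomo/constant"
)

func main() {
	// Install a trivial resolver; a real one would ask the platform
	// (e.g. Android's package manager) for the owning package.
	process.DefaultPackageNameResolver = func(metadata *C.Metadata) (string, error) {
		return "com.example.placeholder", nil // placeholder result
	}

	name, err := process.FindPackageName(&C.Metadata{})
	fmt.Println(name, err)
}
```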

View file

@ -1,16 +0,0 @@
//go:build android && cmfa
package process
import "github.com/metacubex/mihomo/constant"
type PackageNameResolver func(metadata *constant.Metadata) (string, error)
var DefaultPackageNameResolver PackageNameResolver
func FindPackageName(metadata *constant.Metadata) (string, error) {
if resolver := DefaultPackageNameResolver; resolver != nil {
return resolver(metadata)
}
return "", ErrPlatformNotSupport
}

View file

@ -1,9 +0,0 @@
//go:build !(android && cmfa)
package process
import "github.com/metacubex/mihomo/constant"
func FindPackageName(metadata *constant.Metadata) (string, error) {
return "", nil
}

View file

@ -46,12 +46,12 @@ func findProcessName(network string, ip netip.Addr, port int) (uint32, string, e
isIPv4 := ip.Is4()
value, err := syscall.Sysctl(spath)
value, err := unix.SysctlRaw(spath)
if err != nil {
return 0, "", err
}
buf := []byte(value)
buf := value
itemSize := structSize
if network == TCP {
// rup8(sizeof(xtcpcb_n))

View file

@ -2,23 +2,19 @@ package process
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net/netip"
"os"
"path"
"path/filepath"
"runtime"
"strings"
"sync"
"syscall"
"unicode"
"unsafe"
"github.com/metacubex/mihomo/log"
"github.com/mdlayher/netlink"
tun "github.com/metacubex/sing-tun"
"golang.org/x/sys/unix"
)
@ -63,25 +59,11 @@ type inetDiagResponse struct {
INode uint32
}
type MyCallback struct{}
var (
packageManager tun.PackageManager
once sync.Once
)
func (cb *MyCallback) OnPackagesUpdated(packageCount int, sharedCount int) {}
func (cb *MyCallback) NewError(ctx context.Context, err error) {
log.Warnln("%s", err)
}
func findProcessName(network string, ip netip.Addr, srcPort int) (uint32, string, error) {
uid, inode, err := resolveSocketByNetlink(network, ip, srcPort)
if err != nil {
return 0, "", err
}
pp, err := resolveProcessNameByProcSearch(inode, uid)
return uid, pp, err
}
@ -177,44 +159,38 @@ func resolveProcessNameByProcSearch(inode, uid uint32) (string, error) {
if err != nil {
continue
}
if runtime.GOOS == "android" {
if bytes.Equal(buffer[:n], socket) {
return findPackageName(uid), nil
cmdline, err := os.ReadFile(path.Join(processPath, "cmdline"))
if err != nil {
return "", err
}
return splitCmdline(cmdline), nil
}
} else {
if bytes.Equal(buffer[:n], socket) {
return os.Readlink(filepath.Join(processPath, "exe"))
}
}
}
}
return "", fmt.Errorf("process of uid(%d),inode(%d) not found", uid, inode)
}
func findPackageName(uid uint32) string {
once.Do(func() {
callback := &MyCallback{}
var err error
packageManager, err = tun.NewPackageManager(callback)
if err != nil {
log.Warnln("%s", err)
}
err = packageManager.Start()
if err != nil {
log.Warnln("%s", err)
return
}
func splitCmdline(cmdline []byte) string {
cmdline = bytes.Trim(cmdline, " ")
idx := bytes.IndexFunc(cmdline, func(r rune) bool {
return unicode.IsControl(r) || unicode.IsSpace(r)
})
if sharedPackage, loaded := packageManager.SharedPackageByID(uid % 100000); loaded {
return sharedPackage
if idx == -1 {
return filepath.Base(string(cmdline))
}
if packageName, loaded := packageManager.PackageByID(uid % 100000); loaded {
return packageName
}
return ""
return filepath.Base(string(cmdline[:idx]))
}
func isPid(s string) bool {
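On Android the process name is now taken from /proc/<pid>/cmdline instead of a tun.PackageManager; splitCmdline cuts the argv at the first control or space character (cmdline arguments are NUL-separated) and keeps only the base name. A standalone copy for illustration:

```go
package main

import (
	"bytes"
	"fmt"
	"path/filepath"
	"unicode"
)

// splitCmdline reproduced from the diff above for illustration.
func splitCmdline(cmdline []byte) string {
	cmdline = bytes.Trim(cmdline, " ")
	idx := bytes.IndexFunc(cmdline, func(r rune) bool {
		return unicode.IsControl(r) || unicode.IsSpace(r)
	})
	if idx == -1 {
		return filepath.Base(string(cmdline))
	}
	return filepath.Base(string(cmdline[:idx]))
}

func main() {
	// /proc/<pid>/cmdline separates arguments with NUL bytes.
	fmt.Println(splitCmdline([]byte("/system/bin/app_process\x0064\x00...")))
	// Output: app_process
}
```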

View file

@ -1,6 +1,7 @@
package cachefile
import (
"math"
"os"
"sync"
"time"
@ -9,7 +10,7 @@ import (
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
"github.com/sagernet/bbolt"
"github.com/metacubex/bbolt"
)
var (
@ -19,6 +20,7 @@ var (
bucketSelected = []byte("selected")
bucketFakeip = []byte("fakeip")
bucketETag = []byte("etag")
)
// CacheFile store and update the cache file
@ -143,6 +145,59 @@ func (c *CacheFile) FlushFakeIP() error {
return err
}
func (c *CacheFile) SetETagWithHash(url string, hash []byte, etag string) {
if c.DB == nil {
return
}
lenHash := len(hash)
if lenHash > math.MaxUint8 {
return // maybe panic is better
}
data := make([]byte, 1, 1+lenHash+len(etag))
data[0] = uint8(lenHash)
data = append(data, hash...)
data = append(data, etag...)
err := c.DB.Batch(func(t *bbolt.Tx) error {
bucket, err := t.CreateBucketIfNotExists(bucketETag)
if err != nil {
return err
}
return bucket.Put([]byte(url), data)
})
if err != nil {
log.Warnln("[CacheFile] write cache to %s failed: %s", c.DB.Path(), err.Error())
return
}
}
func (c *CacheFile) GetETagWithHash(key string) (hash []byte, etag string) {
if c.DB == nil {
return
}
var value []byte
c.DB.View(func(t *bbolt.Tx) error {
if bucket := t.Bucket(bucketETag); bucket != nil {
if v := bucket.Get([]byte(key)); v != nil {
value = v
}
}
return nil
})
if len(value) == 0 {
return
}
lenHash := int(value[0])
if len(value) < 1+lenHash {
return
}
hash = value[1 : 1+lenHash]
etag = string(value[1+lenHash:])
return
}
func (c *CacheFile) Close() error {
return c.DB.Close()
}
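The etag bucket stores one record per URL: a single length byte, the content hash, then the ETag string, so GetETagWithHash can split the value without a separator (hashes longer than 255 bytes are rejected up front). A tiny sketch of the same encode/decode round trip, independent of bbolt:

```go
package main

import (
	"fmt"
	"math"
)

// encode packs [len(hash)][hash][etag] exactly as SetETagWithHash does.
func encode(hash []byte, etag string) []byte {
	if len(hash) > math.MaxUint8 {
		return nil // the real code simply skips storing such records
	}
	data := make([]byte, 1, 1+len(hash)+len(etag))
	data[0] = uint8(len(hash))
	data = append(data, hash...)
	data = append(data, etag...)
	return data
}

// decode mirrors GetETagWithHash.
func decode(value []byte) (hash []byte, etag string) {
	if len(value) == 0 {
		return
	}
	n := int(value[0])
	if len(value) < 1+n {
		return
	}
	return value[1 : 1+n], string(value[1+n:])
}

func main() {
	rec := encode([]byte{0xde, 0xad, 0xbe, 0xef}, `W/"abc123"`)
	h, e := decode(rec)
	fmt.Printf("%x %s\n", h, e) // deadbeef W/"abc123"
}
```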

View file

@ -46,6 +46,7 @@ type Resolver interface {
LookupIPv6(ctx context.Context, host string) (ips []netip.Addr, err error)
ExchangeContext(ctx context.Context, m *dns.Msg) (msg *dns.Msg, err error)
Invalid() bool
ClearCache()
}
// LookupIPv4WithResolver same as LookupIPv4, but with a resolver

View file

@ -1,35 +1,31 @@
package resource
import (
"bytes"
"crypto/md5"
"context"
"os"
"path/filepath"
"time"
types "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log"
"github.com/sagernet/fswatch"
"github.com/samber/lo"
)
var (
fileMode os.FileMode = 0o666
dirMode os.FileMode = 0o755
)
type Parser[V any] func([]byte) (V, error)
type Fetcher[V any] struct {
ctx context.Context
ctxCancel context.CancelFunc
resourceType string
name string
vehicle types.Vehicle
UpdatedAt time.Time
done chan struct{}
hash [16]byte
updatedAt time.Time
hash types.HashType
parser Parser[V]
interval time.Duration
OnUpdate func(V)
onUpdate func(V)
watcher *fswatch.Watcher
}
func (f *Fetcher[V]) Name() string {
@ -44,93 +40,69 @@ func (f *Fetcher[V]) VehicleType() types.VehicleType {
return f.vehicle.Type()
}
func (f *Fetcher[V]) UpdatedAt() time.Time {
return f.updatedAt
}
func (f *Fetcher[V]) Initial() (V, error) {
var (
buf []byte
err error
isLocal bool
forceUpdate bool
buf []byte
contents V
err error
)
if stat, fErr := os.Stat(f.vehicle.Path()); fErr == nil {
// local file exists, use it first
buf, err = os.ReadFile(f.vehicle.Path())
modTime := stat.ModTime()
f.UpdatedAt = modTime
isLocal = true
if f.interval != 0 && modTime.Add(f.interval).Before(time.Now()) {
log.Warnln("[Provider] %s not updated for a long time, force refresh", f.Name())
forceUpdate = true
contents, _, err = f.loadBuf(buf, types.MakeHash(buf), false)
f.updatedAt = modTime // reset updatedAt to file's modTime
if err == nil {
err = f.startPullLoop(time.Since(modTime) > f.interval)
if err != nil {
return lo.Empty[V](), err
}
return contents, nil
}
} else {
buf, err = f.vehicle.Read()
f.UpdatedAt = time.Now()
}
// parse local file error, fallback to remote
contents, _, err = f.Update()
if err != nil {
return lo.Empty[V](), err
}
var contents V
if forceUpdate {
var forceBuf []byte
if forceBuf, err = f.vehicle.Read(); err == nil {
if contents, err = f.parser(forceBuf); err == nil {
isLocal = false
buf = forceBuf
}
}
}
if err != nil || !forceUpdate {
contents, err = f.parser(buf)
}
err = f.startPullLoop(false)
if err != nil {
if !isLocal {
return lo.Empty[V](), err
}
// parse local file error, fallback to remote
buf, err = f.vehicle.Read()
if err != nil {
return lo.Empty[V](), err
}
contents, err = f.parser(buf)
if err != nil {
return lo.Empty[V](), err
}
isLocal = false
return lo.Empty[V](), err
}
if f.vehicle.Type() != types.File && !isLocal {
if err := safeWrite(f.vehicle.Path(), buf); err != nil {
return lo.Empty[V](), err
}
}
f.hash = md5.Sum(buf)
// pull contents automatically
if f.interval > 0 {
go f.pullLoop()
}
return contents, nil
}
func (f *Fetcher[V]) Update() (V, bool, error) {
buf, err := f.vehicle.Read()
buf, hash, err := f.vehicle.Read(f.ctx, f.hash)
if err != nil {
return lo.Empty[V](), false, err
}
return f.loadBuf(buf, hash, f.vehicle.Type() != types.File)
}
func (f *Fetcher[V]) SideUpdate(buf []byte) (V, bool, error) {
return f.loadBuf(buf, types.MakeHash(buf), true)
}
func (f *Fetcher[V]) loadBuf(buf []byte, hash types.HashType, updateFile bool) (V, bool, error) {
now := time.Now()
hash := md5.Sum(buf)
if bytes.Equal(f.hash[:], hash[:]) {
f.UpdatedAt = now
_ = os.Chtimes(f.vehicle.Path(), now, now)
if f.hash.Equal(hash) {
if updateFile {
_ = os.Chtimes(f.vehicle.Path(), now, now)
}
f.updatedAt = now
return lo.Empty[V](), true, nil
}
if buf == nil { // f.hash has been changed between f.vehicle.Read but should not happen (cause by concurrent)
return lo.Empty[V](), true, nil
}
@ -139,78 +111,103 @@ func (f *Fetcher[V]) Update() (V, bool, error) {
return lo.Empty[V](), false, err
}
if f.vehicle.Type() != types.File {
if err := safeWrite(f.vehicle.Path(), buf); err != nil {
if updateFile {
if err = f.vehicle.Write(buf); err != nil {
return lo.Empty[V](), false, err
}
}
f.UpdatedAt = now
f.updatedAt = now
f.hash = hash
if f.onUpdate != nil {
f.onUpdate(contents)
}
return contents, false, nil
}
func (f *Fetcher[V]) Destroy() error {
if f.interval > 0 {
f.done <- struct{}{}
func (f *Fetcher[V]) Close() error {
f.ctxCancel()
if f.watcher != nil {
_ = f.watcher.Close()
}
return nil
}
func (f *Fetcher[V]) pullLoop() {
initialInterval := f.interval - time.Since(f.UpdatedAt)
func (f *Fetcher[V]) pullLoop(forceUpdate bool) {
initialInterval := f.interval - time.Since(f.updatedAt)
if initialInterval > f.interval {
initialInterval = f.interval
}
if forceUpdate {
log.Warnln("[Provider] %s not updated for a long time, force refresh", f.Name())
f.updateWithLog()
}
timer := time.NewTimer(initialInterval)
defer timer.Stop()
for {
select {
case <-timer.C:
timer.Reset(f.interval)
elm, same, err := f.Update()
if err != nil {
log.Errorln("[Provider] %s pull error: %s", f.Name(), err.Error())
continue
}
if same {
log.Debugln("[Provider] %s's content doesn't change", f.Name())
continue
}
log.Infoln("[Provider] %s's content update", f.Name())
if f.OnUpdate != nil {
f.OnUpdate(elm)
}
case <-f.done:
f.updateWithLog()
case <-f.ctx.Done():
return
}
}
}
func safeWrite(path string, buf []byte) error {
dir := filepath.Dir(path)
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err := os.MkdirAll(dir, dirMode); err != nil {
func (f *Fetcher[V]) startPullLoop(forceUpdate bool) (err error) {
// pull contents automatically
if f.vehicle.Type() == types.File {
f.watcher, err = fswatch.NewWatcher(fswatch.Options{
Path: []string{f.vehicle.Path()},
Direct: true,
Callback: f.updateCallback,
})
if err != nil {
return err
}
err = f.watcher.Start()
if err != nil {
return err
}
} else if f.interval > 0 {
go f.pullLoop(forceUpdate)
}
return
}
func (f *Fetcher[V]) updateCallback(path string) {
f.updateWithLog()
}
func (f *Fetcher[V]) updateWithLog() {
_, same, err := f.Update()
if err != nil {
log.Errorln("[Provider] %s pull error: %s", f.Name(), err.Error())
return
}
return os.WriteFile(path, buf, fileMode)
if same {
log.Debugln("[Provider] %s's content doesn't change", f.Name())
return
}
log.Infoln("[Provider] %s's content update", f.Name())
return
}
func NewFetcher[V any](name string, interval time.Duration, vehicle types.Vehicle, parser Parser[V], onUpdate func(V)) *Fetcher[V] {
ctx, cancel := context.WithCancel(context.Background())
return &Fetcher[V]{
name: name,
vehicle: vehicle,
parser: parser,
done: make(chan struct{}, 8),
OnUpdate: onUpdate,
interval: interval,
ctx: ctx,
ctxCancel: cancel,
name: name,
vehicle: vehicle,
parser: parser,
onUpdate: onUpdate,
interval: interval,
}
}
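The fetcher now owns a cancellable context instead of a done channel, watches file vehicles with fswatch, and moves the "stale for longer than interval" check into the pull loop; Initial falls back to a remote read only when the local copy is missing or fails to parse. A hedged usage sketch for an HTTP-backed provider — import path, URL, and local path are placeholders:

```go
package main

import (
	"log"
	"strings"
	"time"

	// Import path assumed for illustration.
	"github.com/metacubex/mihomo/component/resource"
)

func main() {
	vehicle := resource.NewHTTPVehicle(
		"https://example.com/rules.txt", // placeholder URL
		"/tmp/rules.txt",                // local cache path
		"",                              // no specific proxy
		nil,                             // no extra headers
		20*time.Second,
	)

	parse := func(buf []byte) ([]string, error) {
		return strings.Split(strings.TrimSpace(string(buf)), "\n"), nil
	}
	onUpdate := func(rules []string) { log.Printf("rules updated: %d entries", len(rules)) }

	fetcher := resource.NewFetcher[[]string]("my-rules", time.Hour, vehicle, parse, onUpdate)
	rules, err := fetcher.Initial() // read the local copy or download, then start the pull loop
	if err != nil {
		log.Fatal(err)
	}
	defer fetcher.Close()
	log.Printf("loaded %d rules", len(rules))
}
```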

View file

@ -6,12 +6,45 @@ import (
"io"
"net/http"
"os"
"path/filepath"
"time"
mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/mihomo/component/profile/cachefile"
types "github.com/metacubex/mihomo/constant/provider"
)
const (
DefaultHttpTimeout = time.Second * 20
fileMode os.FileMode = 0o666
dirMode os.FileMode = 0o755
)
var (
etag = false
)
func ETag() bool {
return etag
}
func SetETag(b bool) {
etag = b
}
func safeWrite(path string, buf []byte) error {
dir := filepath.Dir(path)
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err := os.MkdirAll(dir, dirMode); err != nil {
return err
}
}
return os.WriteFile(path, buf, fileMode)
}
type FileVehicle struct {
path string
}
@ -24,23 +57,37 @@ func (f *FileVehicle) Path() string {
return f.path
}
func (f *FileVehicle) Read() ([]byte, error) {
return os.ReadFile(f.path)
func (f *FileVehicle) Url() string {
return "file://" + f.path
}
func (f *FileVehicle) Read(ctx context.Context, oldHash types.HashType) (buf []byte, hash types.HashType, err error) {
buf, err = os.ReadFile(f.path)
if err != nil {
return
}
hash = types.MakeHash(buf)
return
}
func (f *FileVehicle) Proxy() string {
return ""
}
func (f *FileVehicle) Write(buf []byte) error {
return safeWrite(f.path, buf)
}
func NewFileVehicle(path string) *FileVehicle {
return &FileVehicle{path: path}
}
type HTTPVehicle struct {
url string
path string
proxy string
header http.Header
url string
path string
proxy string
header http.Header
timeout time.Duration
}
func (h *HTTPVehicle) Url() string {
@ -59,24 +106,56 @@ func (h *HTTPVehicle) Proxy() string {
return h.proxy
}
func (h *HTTPVehicle) Read() ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
func (h *HTTPVehicle) Write(buf []byte) error {
return safeWrite(h.path, buf)
}
func (h *HTTPVehicle) Read(ctx context.Context, oldHash types.HashType) (buf []byte, hash types.HashType, err error) {
ctx, cancel := context.WithTimeout(ctx, h.timeout)
defer cancel()
resp, err := mihomoHttp.HttpRequestWithProxy(ctx, h.url, http.MethodGet, h.header, nil, h.proxy)
header := h.header
setIfNoneMatch := false
if etag && oldHash.IsValid() {
hashBytes, etag := cachefile.Cache().GetETagWithHash(h.url)
if oldHash.EqualBytes(hashBytes) && etag != "" {
if header == nil {
header = http.Header{}
} else {
header = header.Clone()
}
header.Set("If-None-Match", etag)
setIfNoneMatch = true
}
}
resp, err := mihomoHttp.HttpRequestWithProxy(ctx, h.url, http.MethodGet, header, nil, h.proxy)
if err != nil {
return nil, err
return
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return nil, errors.New(resp.Status)
if setIfNoneMatch && resp.StatusCode == http.StatusNotModified {
return nil, oldHash, nil
}
err = errors.New(resp.Status)
return
}
buf, err := io.ReadAll(resp.Body)
buf, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
return
}
return buf, nil
hash = types.MakeHash(buf)
if etag {
cachefile.Cache().SetETagWithHash(h.url, hash.Bytes(), resp.Header.Get("ETag"))
}
return
}
func NewHTTPVehicle(url string, path string, proxy string, header http.Header) *HTTPVehicle {
return &HTTPVehicle{url, path, proxy, header}
func NewHTTPVehicle(url string, path string, proxy string, header http.Header, timeout time.Duration) *HTTPVehicle {
return &HTTPVehicle{
url: url,
path: path,
proxy: proxy,
header: header,
timeout: timeout,
}
}
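When the cached hash for the URL still matches, HTTPVehicle sends If-None-Match with the stored ETag and maps a 304 response back to the old hash, so callers see an unchanged provider without re-downloading the body. The underlying HTTP exchange, reduced to the standard library — server URL and ETag value are placeholders:

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
)

func fetchWithETag(url, cachedETag string) (body []byte, notModified bool, etag string, err error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, false, "", err
	}
	if cachedETag != "" {
		// Ask the server to skip the body if our copy is still current.
		req.Header.Set("If-None-Match", cachedETag)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, false, "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNotModified {
		return nil, true, cachedETag, nil // keep the previous content and hash
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, false, "", fmt.Errorf("unexpected status: %s", resp.Status)
	}
	body, err = io.ReadAll(resp.Body)
	return body, false, resp.Header.Get("ETag"), err
}

func main() {
	// Placeholder URL and ETag for illustration.
	body, cached, etag, err := fetchWithETag("https://example.com/rules.txt", `W/"abc123"`)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(cached, etag, len(body))
}
```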

View file

@ -2,15 +2,12 @@ package sniffer
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/metacubex/mihomo/common/lru"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/trie"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/sniffer"
"github.com/metacubex/mihomo/log"
@ -22,29 +19,46 @@ var (
ErrNoClue = errors.New("not enough information for making a decision")
)
var Dispatcher *SnifferDispatcher
type SnifferDispatcher struct {
type Dispatcher struct {
enable bool
sniffers map[sniffer.Sniffer]SnifferConfig
forceDomain *trie.DomainSet
skipSNI *trie.DomainSet
skipList *lru.LruCache[string, uint8]
rwMux sync.RWMutex
forceDomain []C.DomainMatcher
skipSrcAddress []C.IpMatcher
skipDstAddress []C.IpMatcher
skipDomain []C.DomainMatcher
skipList *lru.LruCache[netip.AddrPort, uint8]
forceDnsMapping bool
parsePureIp bool
}
- func (sd *SnifferDispatcher) shouldOverride(metadata *C.Metadata) bool {
- return (metadata.Host == "" && sd.parsePureIp) ||
- sd.forceDomain.Has(metadata.Host) ||
- (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping)
+ func (sd *Dispatcher) shouldOverride(metadata *C.Metadata) bool {
+ for _, matcher := range sd.skipDstAddress {
+ if matcher.MatchIp(metadata.DstIP) {
+ return false
+ }
+ }
+ for _, matcher := range sd.skipSrcAddress {
+ if matcher.MatchIp(metadata.SrcIP) {
+ return false
+ }
+ }
+ if metadata.Host == "" && sd.parsePureIp {
+ return true
+ }
+ if metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping {
+ return true
+ }
+ for _, matcher := range sd.forceDomain {
+ if matcher.MatchDomain(metadata.Host) {
+ return true
+ }
+ }
+ return false
}
- func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
+ func (sd *Dispatcher) UDPSniff(packet C.PacketAdapter) bool {
metadata := packet.Metadata()
- if sd.shouldOverride(packet.Metadata()) {
+ if sd.shouldOverride(metadata) {
for sniffer, config := range sd.sniffers {
if sniffer.SupportNetwork() == C.UDP || sniffer.SupportNetwork() == C.ALLNet {
inWhitelist := sniffer.SupportPort(metadata.DstPort)
@ -67,7 +81,7 @@ func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
}
// TCPSniff returns true if the connection is sniffed to have a domain
- func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) bool {
+ func (sd *Dispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) bool {
if sd.shouldOverride(metadata) {
inWhitelist := false
overrideDest := false
@ -85,37 +99,35 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
return false
}
- sd.rwMux.RLock()
- dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
+ dst := metadata.AddrPort()
if count, ok := sd.skipList.Get(dst); ok && count > 5 {
log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
- defer sd.rwMux.RUnlock()
return false
}
- sd.rwMux.RUnlock()
- if host, err := sd.sniffDomain(conn, metadata); err != nil {
+ host, err := sd.sniffDomain(conn, metadata)
+ if err != nil {
sd.cacheSniffFailed(metadata)
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%d] to [%s:%d]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
return false
- } else {
- if sd.skipSNI.Has(host) {
+ }
+ for _, matcher := range sd.skipDomain {
+ if matcher.MatchDomain(host) {
log.Debugln("[Sniffer] Skip sni[%s]", host)
return false
}
- sd.rwMux.RLock()
- sd.skipList.Delete(dst)
- sd.rwMux.RUnlock()
- sd.replaceDomain(metadata, host, overrideDest)
- return true
- }
+ sd.skipList.Delete(dst)
+ sd.replaceDomain(metadata, host, overrideDest)
+ return true
}
return false
}
- func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
+ func (sd *Dispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
metadata.SniffHost = host
if overrideDest {
log.Debugln("[Sniffer] Sniff %s [%s]-->[%s] success, replace domain [%s]-->[%s]",
@ -128,11 +140,11 @@ func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, ov
metadata.DNSMode = C.DNSNormal
}
- func (sd *SnifferDispatcher) Enable() bool {
- return sd.enable
+ func (sd *Dispatcher) Enable() bool {
+ return sd != nil && sd.enable
}
- func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
+ func (sd *Dispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
for s := range sd.sniffers {
if s.SupportNetwork() == C.TCP {
_ = conn.SetReadDeadline(time.Now().Add(1 * time.Second))
@ -175,43 +187,45 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad
return "", ErrorSniffFailed
}
- func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
- sd.rwMux.Lock()
- dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
- count, _ := sd.skipList.Get(dst)
- if count <= 5 {
- count++
- }
- sd.skipList.Set(dst, count)
- sd.rwMux.Unlock()
+ func (sd *Dispatcher) cacheSniffFailed(metadata *C.Metadata) {
+ dst := metadata.AddrPort()
+ sd.skipList.Compute(dst, func(oldValue uint8, loaded bool) (newValue uint8, delete bool) {
+ if oldValue <= 5 {
+ oldValue++
+ }
+ return oldValue, false
+ })
}
- func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
- dispatcher := SnifferDispatcher{
- enable: false,
- }
- return &dispatcher, nil
+ type Config struct {
+ Enable bool
+ Sniffers map[sniffer.Type]SnifferConfig
+ ForceDomain []C.DomainMatcher
+ SkipSrcAddress []C.IpMatcher
+ SkipDstAddress []C.IpMatcher
+ SkipDomain []C.DomainMatcher
+ ForceDnsMapping bool
+ ParsePureIp bool
+ }
- func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig,
- forceDomain *trie.DomainSet, skipSNI *trie.DomainSet,
- forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
- dispatcher := SnifferDispatcher{
- enable: true,
- forceDomain: forceDomain,
- skipSNI: skipSNI,
- skipList: lru.New(lru.WithSize[string, uint8](128), lru.WithAge[string, uint8](600)),
- forceDnsMapping: forceDnsMapping,
- parsePureIp: parsePureIp,
- sniffers: make(map[sniffer.Sniffer]SnifferConfig, 0),
+ func NewDispatcher(snifferConfig *Config) (*Dispatcher, error) {
+ dispatcher := Dispatcher{
+ enable: snifferConfig.Enable,
+ forceDomain: snifferConfig.ForceDomain,
+ skipSrcAddress: snifferConfig.SkipSrcAddress,
+ skipDstAddress: snifferConfig.SkipDstAddress,
+ skipDomain: snifferConfig.SkipDomain,
+ skipList: lru.New(lru.WithSize[netip.AddrPort, uint8](128), lru.WithAge[netip.AddrPort, uint8](600)),
+ forceDnsMapping: snifferConfig.ForceDnsMapping,
+ parsePureIp: snifferConfig.ParsePureIp,
+ sniffers: make(map[sniffer.Sniffer]SnifferConfig, len(snifferConfig.Sniffers)),
}
- for snifferName, config := range snifferConfig {
+ for snifferName, config := range snifferConfig.Sniffers {
s, err := NewSniffer(snifferName, config)
if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName)
- return &SnifferDispatcher{enable: false}, err
+ return &Dispatcher{enable: false}, err
}
dispatcher.sniffers[s] = config
}
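Note how cacheSniffFailed above now folds the read-increment-write into one skipList.Compute call instead of the old RWMutex plus Get/Set, and TCPSniff skips a destination once its counter passes 5. A small self-contained sketch of the same capped failure counter with a plain map and mutex, for comparison only (the real code relies on the lru cache shown in this diff):

package main

import (
	"fmt"
	"net/netip"
	"sync"
)

// failureCounter records per-destination sniff failures, capped at 6,
// mirroring the "count <= 5 then increment" logic in cacheSniffFailed.
type failureCounter struct {
	mu     sync.Mutex
	counts map[netip.AddrPort]uint8
}

func (f *failureCounter) record(dst netip.AddrPort) {
	f.mu.Lock()
	defer f.mu.Unlock()
	if f.counts[dst] <= 5 {
		f.counts[dst]++
	}
}

func (f *failureCounter) shouldSkip(dst netip.AddrPort) bool {
	f.mu.Lock()
	defer f.mu.Unlock()
	return f.counts[dst] > 5
}

func main() {
	fc := &failureCounter{counts: map[netip.AddrPort]uint8{}}
	dst := netip.MustParseAddrPort("192.0.2.1:443")
	for i := 0; i < 10; i++ {
		fc.record(dst)
	}
	fmt.Println(fc.shouldSkip(dst)) // true after repeated failures
}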

View file

@ -3,6 +3,8 @@ package trie
import (
"errors"
"strings"
"unicode"
"unicode/utf8"
)
const (
@ -25,6 +27,14 @@ func ValidAndSplitDomain(domain string) ([]string, bool) {
if domain != "" && domain[len(domain)-1] == '.' {
return nil, false
}
if domain != "" {
if r, _ := utf8.DecodeRuneInString(domain); unicode.IsSpace(r) {
return nil, false
}
if r, _ := utf8.DecodeLastRuneInString(domain); unicode.IsSpace(r) {
return nil, false
}
}
domain = strings.ToLower(domain)
parts := strings.Split(domain, domainStep)
if len(parts) == 1 {
@ -123,27 +133,41 @@ func (t *DomainTrie[T]) Optimize() {
t.root.optimize()
}
- func (t *DomainTrie[T]) Foreach(print func(domain string, data T)) {
+ func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) {
for key, data := range t.root.getChildren() {
- recursion([]string{key}, data, print)
- if data != nil && data.inited {
- print(joinDomain([]string{key}), data.data)
+ recursion([]string{key}, data, fn)
+ if !data.isEmpty() {
+ if !fn(joinDomain([]string{key}), data.data) {
+ return
+ }
}
}
}
- func recursion[T any](items []string, node *Node[T], fn func(domain string, data T)) {
+ func (t *DomainTrie[T]) IsEmpty() bool {
+ if t == nil || t.root == nil {
+ return true
+ }
+ return len(t.root.getChildren()) == 0
+ }
+ func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool {
for key, data := range node.getChildren() {
newItems := append([]string{key}, items...)
- if data != nil && data.inited {
+ if !data.isEmpty() {
domain := joinDomain(newItems)
if domain[0] == domainStepByte {
domain = complexWildcard + domain
}
- fn(domain, data.Data())
+ if !fn(domain, data.Data()) {
+ return false
+ }
}
+ if !recursion(newItems, data, fn) {
+ return false
+ }
- recursion(newItems, data, fn)
}
+ return true
}
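Foreach now takes a callback that returns a bool, and recursion propagates a false return upward, so a walk can stop at the first hit instead of always visiting every node. A usage sketch against the exported API exercised in the tests below (trie.New, Insert, IsEmpty, Foreach):

package main

import (
	"fmt"
	"strings"

	"github.com/metacubex/mihomo/component/trie"
)

func main() {
	tree := trie.New[int]()
	for i, domain := range []string{"example.com", "+.baidu.com", "google.com"} {
		if err := tree.Insert(domain, i); err != nil {
			panic(err)
		}
	}
	fmt.Println(tree.IsEmpty()) // false

	// Returning false from the callback ends the walk, so this stops
	// as soon as the first "+." wildcard entry is visited.
	tree.Foreach(func(domain string, data int) bool {
		fmt.Println(domain, data)
		return !strings.HasPrefix(domain, "+")
	})
}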
func joinDomain(items []string) string {

View file

@ -28,8 +28,9 @@ type qElt struct{ s, e, col int }
// NewDomainSet creates a new *DomainSet struct, from a DomainTrie.
func (t *DomainTrie[T]) NewDomainSet() *DomainSet {
reserveDomains := make([]string, 0)
- t.Foreach(func(domain string, data T) {
+ t.Foreach(func(domain string, data T) bool {
reserveDomains = append(reserveDomains, utils.Reverse(domain))
+ return true
})
// ensure that the same prefix is continuous
// and according to the ascending sequence of length
@ -136,6 +137,46 @@ func (ss *DomainSet) Has(key string) bool {
}
func (ss *DomainSet) keys(f func(key string) bool) {
var currentKey []byte
var traverse func(int, int) bool
traverse = func(nodeId, bmIdx int) bool {
if getBit(ss.leaves, nodeId) != 0 {
if !f(string(currentKey)) {
return false
}
}
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return true
}
nextLabel := ss.labels[bmIdx-nodeId]
currentKey = append(currentKey, nextLabel)
nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1
if !traverse(nextNodeId, nextBmIdx) {
return false
}
currentKey = currentKey[:len(currentKey)-1]
}
}
traverse(0, 0)
return
}
func (ss *DomainSet) Foreach(f func(key string) bool) {
ss.keys(func(key string) bool {
return f(utils.Reverse(key))
})
}
// MatchDomain implements C.DomainMatcher
func (ss *DomainSet) MatchDomain(domain string) bool {
return ss.Has(domain)
}
func setBit(bm *[]uint64, i int, v int) {
for i>>6 >= len(*bm) {
*bm = append(*bm, 0)

View file

@ -0,0 +1,115 @@
package trie
import (
"encoding/binary"
"errors"
"io"
)
func (ss *DomainSet) WriteBin(w io.Writer) (err error) {
// version
_, err = w.Write([]byte{1})
if err != nil {
return err
}
// leaves
err = binary.Write(w, binary.BigEndian, int64(len(ss.leaves)))
if err != nil {
return err
}
for _, d := range ss.leaves {
err = binary.Write(w, binary.BigEndian, d)
if err != nil {
return err
}
}
// labelBitmap
err = binary.Write(w, binary.BigEndian, int64(len(ss.labelBitmap)))
if err != nil {
return err
}
for _, d := range ss.labelBitmap {
err = binary.Write(w, binary.BigEndian, d)
if err != nil {
return err
}
}
// labels
err = binary.Write(w, binary.BigEndian, int64(len(ss.labels)))
if err != nil {
return err
}
_, err = w.Write(ss.labels)
if err != nil {
return err
}
return nil
}
func ReadDomainSetBin(r io.Reader) (ds *DomainSet, err error) {
// version
version := make([]byte, 1)
_, err = io.ReadFull(r, version)
if err != nil {
return nil, err
}
if version[0] != 1 {
return nil, errors.New("version is invalid")
}
ds = &DomainSet{}
var length int64
// leaves
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return nil, err
}
if length < 1 {
return nil, errors.New("length is invalid")
}
ds.leaves = make([]uint64, length)
for i := int64(0); i < length; i++ {
err = binary.Read(r, binary.BigEndian, &ds.leaves[i])
if err != nil {
return nil, err
}
}
// labelBitmap
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return nil, err
}
if length < 1 {
return nil, errors.New("length is invalid")
}
ds.labelBitmap = make([]uint64, length)
for i := int64(0); i < length; i++ {
err = binary.Read(r, binary.BigEndian, &ds.labelBitmap[i])
if err != nil {
return nil, err
}
}
// labels
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return nil, err
}
if length < 1 {
return nil, errors.New("length is invalid")
}
ds.labels = make([]byte, length)
_, err = io.ReadFull(r, ds.labels)
if err != nil {
return nil, err
}
ds.init()
return ds, nil
}
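WriteBin and ReadDomainSetBin give the compiled DomainSet a small versioned binary form (a version byte, then the length-prefixed leaves, labelBitmap and labels arrays), so a built rule set can be cached and reloaded without re-parsing the source text. A round-trip sketch using only the functions defined above plus the trie API from the tests:

package main

import (
	"bytes"
	"fmt"

	"github.com/metacubex/mihomo/component/trie"
)

func main() {
	tree := trie.New[struct{}]()
	for _, domain := range []string{"example.com", "+.baidu.com"} {
		if err := tree.Insert(domain, struct{}{}); err != nil {
			panic(err)
		}
	}
	set := tree.NewDomainSet()

	// Serialize the succinct set ...
	var buf bytes.Buffer
	if err := set.WriteBin(&buf); err != nil {
		panic(err)
	}

	// ... and load it back without touching the original trie.
	loaded, err := trie.ReadDomainSetBin(&buf)
	if err != nil {
		panic(err)
	}
	fmt.Println(loaded.Has("www.baidu.com")) // expected: true (matched via the +. wildcard)
	fmt.Println(loaded.Has("example.org"))   // expected: false
}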

View file

@ -1,12 +1,29 @@
package trie_test
import (
"golang.org/x/exp/slices"
"testing"
"github.com/metacubex/mihomo/component/trie"
"github.com/stretchr/testify/assert"
)
func testDump(t *testing.T, tree *trie.DomainTrie[struct{}], set *trie.DomainSet) {
var dataSrc []string
tree.Foreach(func(domain string, data struct{}) bool {
dataSrc = append(dataSrc, domain)
return true
})
slices.Sort(dataSrc)
var dataSet []string
set.Foreach(func(key string) bool {
dataSet = append(dataSet, key)
return true
})
slices.Sort(dataSet)
assert.Equal(t, dataSrc, dataSet)
}
func TestDomainSet(t *testing.T) {
tree := trie.New[struct{}]()
domainSet := []string{
@ -23,6 +40,7 @@ func TestDomainSet(t *testing.T) {
for _, domain := range domainSet {
assert.NoError(t, tree.Insert(domain, struct{}{}))
}
assert.False(t, tree.IsEmpty())
set := tree.NewDomainSet()
assert.NotNil(t, set)
assert.True(t, set.Has("test.cn"))
@ -33,6 +51,7 @@ func TestDomainSet(t *testing.T) {
assert.True(t, set.Has("google.com"))
assert.False(t, set.Has("qq.com"))
assert.False(t, set.Has("www.baidu.com"))
testDump(t, tree, set)
}
func TestDomainSetComplexWildcard(t *testing.T) {
@ -50,11 +69,13 @@ func TestDomainSetComplexWildcard(t *testing.T) {
for _, domain := range domainSet {
assert.NoError(t, tree.Insert(domain, struct{}{}))
}
assert.False(t, tree.IsEmpty())
set := tree.NewDomainSet()
assert.NotNil(t, set)
assert.False(t, set.Has("google.com"))
assert.True(t, set.Has("www.baidu.com"))
assert.True(t, set.Has("test.test.baidu.com"))
testDump(t, tree, set)
}
func TestDomainSetWildcard(t *testing.T) {
@ -71,6 +92,7 @@ func TestDomainSetWildcard(t *testing.T) {
for _, domain := range domainSet {
assert.NoError(t, tree.Insert(domain, struct{}{}))
}
assert.False(t, tree.IsEmpty())
set := tree.NewDomainSet()
assert.NotNil(t, set)
assert.True(t, set.Has("www.baidu.com"))
@ -82,4 +104,5 @@ func TestDomainSetWildcard(t *testing.T) {
assert.False(t, set.Has("a.www.google.com"))
assert.False(t, set.Has("test.qq.com"))
assert.False(t, set.Has("test.test.test.qq.com"))
testDump(t, tree, set)
}

View file

@ -121,8 +121,20 @@ func TestTrie_Foreach(t *testing.T) {
assert.NoError(t, tree.Insert(domain, localIP))
}
count := 0
- tree.Foreach(func(domain string, data netip.Addr) {
+ tree.Foreach(func(domain string, data netip.Addr) bool {
count++
+ return true
})
assert.Equal(t, 7, count)
}
func TestTrie_Space(t *testing.T) {
validDomain := func(domain string) bool {
_, ok := trie.ValidAndSplitDomain(domain)
return ok
}
assert.True(t, validDomain("google.com"))
assert.False(t, validDomain(" google.com"))
assert.False(t, validDomain(" google.com "))
assert.True(t, validDomain("Mijia Cloud"))
}

View file

@ -230,14 +230,14 @@ func clean() {
// MaxPackageFileSize is a maximum package file length in bytes. The largest
// package whose size is limited by this constant currently has the size of
- // approximately 9 MiB.
+ // approximately 32 MiB.
const MaxPackageFileSize = 32 * 1024 * 1024
// Download package file and save it to disk
func downloadPackageFile() (err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, packageURL, http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
resp, err := mihomoHttp.HttpRequest(ctx, packageURL, http.MethodGet, nil, nil)
if err != nil {
return fmt.Errorf("http request failed: %w", err)
}
@ -418,7 +418,7 @@ func copyFile(src, dst string) error {
func getLatestVersion() (version string, err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, versionURL, http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
resp, err := mihomoHttp.HttpRequest(ctx, versionURL, http.MethodGet, nil, nil)
if err != nil {
return "", fmt.Errorf("get Latest Version fail: %w", err)
}

View file

@ -1,6 +1,7 @@
package updater
import (
"context"
"errors"
"fmt"
"os"
@ -8,92 +9,199 @@ import (
"time"
"github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/batch"
"github.com/metacubex/mihomo/component/geodata"
_ "github.com/metacubex/mihomo/component/geodata/standard"
"github.com/metacubex/mihomo/component/mmdb"
"github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log"
"github.com/oschwald/maxminddb-golang"
)
var (
- UpdatingGeo atomic.Bool
+ autoUpdate bool
+ updateInterval int
+ updatingGeo atomic.Bool
)
- func updateGeoDatabases() error {
- defer runtime.GC()
- geoLoader, err := geodata.GetGeoDataLoader("standard")
+ func GeoAutoUpdate() bool {
return autoUpdate
}
func GeoUpdateInterval() int {
return updateInterval
}
func SetGeoAutoUpdate(newAutoUpdate bool) {
autoUpdate = newAutoUpdate
}
func SetGeoUpdateInterval(newGeoUpdateInterval int) {
updateInterval = newGeoUpdateInterval
}
func UpdateMMDB() (err error) {
vehicle := resource.NewHTTPVehicle(geodata.MmdbUrl(), C.Path.MMDB(), "", nil, defaultHttpTimeout)
var oldHash P.HashType
if buf, err := os.ReadFile(vehicle.Path()); err == nil {
oldHash = P.MakeHash(buf)
}
data, hash, err := vehicle.Read(context.Background(), oldHash)
if err != nil {
- return err
+ return fmt.Errorf("can't download MMDB database file: %w", err)
}
if oldHash.Equal(hash) { // same hash, ignored
return nil
}
if len(data) == 0 {
return fmt.Errorf("can't download MMDB database file: no data")
}
if C.GeodataMode {
data, err := downloadForBytes(C.GeoIpUrl)
if err != nil {
return fmt.Errorf("can't download GeoIP database file: %w", err)
}
instance, err := maxminddb.FromBytes(data)
if err != nil {
return fmt.Errorf("invalid MMDB database file: %s", err)
}
_ = instance.Close()
if _, err = geoLoader.LoadIPByBytes(data, "cn"); err != nil {
return fmt.Errorf("invalid GeoIP database file: %s", err)
}
defer mmdb.ReloadIP()
mmdb.IPInstance().Reader.Close() // mmdb is loaded with mmap, so it needs to be closed before overwriting the file
if err = vehicle.Write(data); err != nil {
return fmt.Errorf("can't save MMDB database file: %w", err)
}
return nil
}
if err = saveFile(data, C.Path.GeoIP()); err != nil {
return fmt.Errorf("can't save GeoIP database file: %w", err)
}
} else {
defer mmdb.ReloadIP()
data, err := downloadForBytes(C.MmdbUrl)
if err != nil {
return fmt.Errorf("can't download MMDB database file: %w", err)
}
instance, err := maxminddb.FromBytes(data)
if err != nil {
return fmt.Errorf("invalid MMDB database file: %s", err)
}
_ = instance.Close()
mmdb.IPInstance().Reader.Close() // mmdb is loaded with mmap, so it needs to be closed before overwriting the file
if err = saveFile(data, C.Path.MMDB()); err != nil {
return fmt.Errorf("can't save MMDB database file: %w", err)
}
func UpdateASN() (err error) {
vehicle := resource.NewHTTPVehicle(geodata.ASNUrl(), C.Path.ASN(), "", nil, defaultHttpTimeout)
var oldHash P.HashType
if buf, err := os.ReadFile(vehicle.Path()); err == nil {
oldHash = P.MakeHash(buf)
}
data, hash, err := vehicle.Read(context.Background(), oldHash)
if err != nil {
return fmt.Errorf("can't download ASN database file: %w", err)
}
if oldHash.Equal(hash) { // same hash, ignored
return nil
}
if len(data) == 0 {
return fmt.Errorf("can't download ASN database file: no data")
}
if C.ASNEnable {
defer mmdb.ReloadASN()
data, err := downloadForBytes(C.ASNUrl)
if err != nil {
return fmt.Errorf("can't download ASN database file: %w", err)
}
instance, err := maxminddb.FromBytes(data)
if err != nil {
return fmt.Errorf("invalid ASN database file: %s", err)
}
_ = instance.Close()
instance, err := maxminddb.FromBytes(data)
if err != nil {
return fmt.Errorf("invalid ASN database file: %s", err)
}
_ = instance.Close()
defer mmdb.ReloadASN()
mmdb.ASNInstance().Reader.Close() // mmdb is loaded with mmap, so it needs to be closed before overwriting the file
if err = vehicle.Write(data); err != nil {
return fmt.Errorf("can't save ASN database file: %w", err)
}
return nil
}
mmdb.ASNInstance().Reader.Close()
if err = saveFile(data, C.Path.ASN()); err != nil {
return fmt.Errorf("can't save ASN database file: %w", err)
}
func UpdateGeoIp() (err error) {
geoLoader, err := geodata.GetGeoDataLoader("standard")
vehicle := resource.NewHTTPVehicle(geodata.GeoIpUrl(), C.Path.GeoIP(), "", nil, defaultHttpTimeout)
var oldHash P.HashType
if buf, err := os.ReadFile(vehicle.Path()); err == nil {
oldHash = P.MakeHash(buf)
}
data, hash, err := vehicle.Read(context.Background(), oldHash)
if err != nil {
return fmt.Errorf("can't download GeoIP database file: %w", err)
}
if oldHash.Equal(hash) { // same hash, ignored
return nil
}
if len(data) == 0 {
return fmt.Errorf("can't download GeoIP database file: no data")
}
data, err := downloadForBytes(C.GeoSiteUrl)
if _, err = geoLoader.LoadIPByBytes(data, "cn"); err != nil {
return fmt.Errorf("invalid GeoIP database file: %s", err)
}
defer geodata.ClearGeoIPCache()
if err = vehicle.Write(data); err != nil {
return fmt.Errorf("can't save GeoIP database file: %w", err)
}
return nil
}
func UpdateGeoSite() (err error) {
geoLoader, err := geodata.GetGeoDataLoader("standard")
vehicle := resource.NewHTTPVehicle(geodata.GeoSiteUrl(), C.Path.GeoSite(), "", nil, defaultHttpTimeout)
var oldHash P.HashType
if buf, err := os.ReadFile(vehicle.Path()); err == nil {
oldHash = P.MakeHash(buf)
}
data, hash, err := vehicle.Read(context.Background(), oldHash)
if err != nil {
return fmt.Errorf("can't download GeoSite database file: %w", err)
}
if oldHash.Equal(hash) { // same hash, ignored
return nil
}
if len(data) == 0 {
return fmt.Errorf("can't download GeoSite database file: no data")
}
if _, err = geoLoader.LoadSiteByBytes(data, "cn"); err != nil {
return fmt.Errorf("invalid GeoSite database file: %s", err)
}
if err = saveFile(data, C.Path.GeoSite()); err != nil {
defer geodata.ClearGeoSiteCache()
if err = vehicle.Write(data); err != nil {
return fmt.Errorf("can't save GeoSite database file: %w", err)
}
return nil
}
geodata.ClearCache()
func updateGeoDatabases() error {
defer runtime.GC()
b, _ := batch.New[interface{}](context.Background())
if geodata.GeoIpEnable() {
if geodata.GeodataMode() {
b.Go("UpdateGeoIp", func() (_ interface{}, err error) {
err = UpdateGeoIp()
return
})
} else {
b.Go("UpdateMMDB", func() (_ interface{}, err error) {
err = UpdateMMDB()
return
})
}
}
if geodata.ASNEnable() {
b.Go("UpdateASN", func() (_ interface{}, err error) {
err = UpdateASN()
return
})
}
if geodata.GeoSiteEnable() {
b.Go("UpdateGeoSite", func() (_ interface{}, err error) {
err = UpdateGeoSite()
return
})
}
if e := b.Wait(); e != nil {
return e.Err
}
return nil
}
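The rewritten updateGeoDatabases fans the independent downloads out through the internal batch helper and returns the first error from b.Wait(). The same shape can be sketched with golang.org/x/sync/errgroup instead of the batch package (illustrative only, not the mihomo API):

package geoupdate

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

// updateAll runs independent update jobs concurrently and returns the
// first error encountered, mirroring the batch.Go/Wait flow above.
func updateAll(jobs map[string]func() error) error {
	var g errgroup.Group
	for name, job := range jobs {
		name, job := name, job // capture loop variables (pre-Go 1.22 semantics)
		g.Go(func() error {
			if err := job(); err != nil {
				return fmt.Errorf("%s: %w", name, err)
			}
			return nil
		})
	}
	return g.Wait()
}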
@ -103,12 +211,12 @@ var ErrGetDatabaseUpdateSkip = errors.New("GEO database is updating, skip")
func UpdateGeoDatabases() error {
log.Infoln("[GEO] Start updating GEO database")
- if UpdatingGeo.Load() {
+ if updatingGeo.Load() {
return ErrGetDatabaseUpdateSkip
}
- UpdatingGeo.Store(true)
- defer UpdatingGeo.Store(false)
+ updatingGeo.Store(true)
+ defer updatingGeo.Store(false)
log.Infoln("[GEO] Updating GEO database")
@ -122,7 +230,7 @@ func UpdateGeoDatabases() error {
func getUpdateTime() (err error, time time.Time) {
var fileInfo os.FileInfo
- if C.GeodataMode {
+ if geodata.GeodataMode() {
fileInfo, err = os.Stat(C.Path.GeoIP())
if err != nil {
return err, time
@ -137,14 +245,14 @@ func getUpdateTime() (err error, time time.Time) {
return nil, fileInfo.ModTime()
}
- func RegisterGeoUpdater(onSuccess func()) {
- if C.GeoUpdateInterval <= 0 {
- log.Errorln("[GEO] Invalid update interval: %d", C.GeoUpdateInterval)
+ func RegisterGeoUpdater() {
+ if updateInterval <= 0 {
+ log.Errorln("[GEO] Invalid update interval: %d", updateInterval)
return
}
go func() {
- ticker := time.NewTicker(time.Duration(C.GeoUpdateInterval) * time.Hour)
+ ticker := time.NewTicker(time.Duration(updateInterval) * time.Hour)
defer ticker.Stop()
err, lastUpdate := getUpdateTime()
@ -154,22 +262,18 @@ func RegisterGeoUpdater(onSuccess func()) {
}
log.Infoln("[GEO] last update time %s", lastUpdate)
- if lastUpdate.Add(time.Duration(C.GeoUpdateInterval) * time.Hour).Before(time.Now()) {
- log.Infoln("[GEO] Database has not been updated for %v, update now", time.Duration(C.GeoUpdateInterval)*time.Hour)
+ if lastUpdate.Add(time.Duration(updateInterval) * time.Hour).Before(time.Now()) {
+ log.Infoln("[GEO] Database has not been updated for %v, update now", time.Duration(updateInterval)*time.Hour)
if err := UpdateGeoDatabases(); err != nil {
log.Errorln("[GEO] Failed to update GEO database: %s", err.Error())
return
- } else {
- onSuccess()
}
}
for range ticker.C {
log.Infoln("[GEO] updating database every %d hours", C.GeoUpdateInterval)
log.Infoln("[GEO] updating database every %d hours", updateInterval)
if err := UpdateGeoDatabases(); err != nil {
log.Errorln("[GEO] Failed to update GEO database: %s", err.Error())
- } else {
- onSuccess()
}
}
}()

Some files were not shown because too many files have changed in this diff