diff --git a/adapters/outbound/base.go b/adapters/outbound/base.go index 26fa5d56..f4da21ce 100644 --- a/adapters/outbound/base.go +++ b/adapters/outbound/base.go @@ -6,11 +6,12 @@ import ( "errors" "net" "net/http" - "sync/atomic" "time" "github.com/Dreamacro/clash/common/queue" C "github.com/Dreamacro/clash/constant" + + "go.uber.org/atomic" ) type Base struct { @@ -95,11 +96,11 @@ func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn { type Proxy struct { C.ProxyAdapter history *queue.Queue - alive uint32 + alive *atomic.Bool } func (p *Proxy) Alive() bool { - return atomic.LoadUint32(&p.alive) > 0 + return p.alive.Load() } func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { @@ -111,7 +112,7 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { conn, err := p.ProxyAdapter.DialContext(ctx, metadata) if err != nil { - atomic.StoreUint32(&p.alive, 0) + p.alive.Store(false) } return conn, err } @@ -128,7 +129,7 @@ func (p *Proxy) DelayHistory() []C.DelayHistory { // LastDelay return last history record. if proxy is not alive, return the max value of uint16. func (p *Proxy) LastDelay() (delay uint16) { var max uint16 = 0xffff - if atomic.LoadUint32(&p.alive) == 0 { + if !p.alive.Load() { return max } @@ -159,11 +160,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) { // URLTest get the delay for the specified URL func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { defer func() { - if err == nil { - atomic.StoreUint32(&p.alive, 1) - } else { - atomic.StoreUint32(&p.alive, 0) - } + p.alive.Store(err == nil) record := C.DelayHistory{Time: time.Now()} if err == nil { record.Delay = t @@ -219,5 +216,5 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { } func NewProxy(adapter C.ProxyAdapter) *Proxy { - return &Proxy{adapter, queue.New(10), 1} + return &Proxy{adapter, queue.New(10), atomic.NewBool(true)} } diff --git a/common/observable/observable_test.go b/common/observable/observable_test.go index 6dd6ee42..cb16ad39 100644 --- a/common/observable/observable_test.go +++ b/common/observable/observable_test.go @@ -2,11 +2,11 @@ package observable import ( "sync" - "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) func iterator(item []interface{}) chan interface{} { @@ -33,25 +33,25 @@ func TestObservable(t *testing.T) { assert.Equal(t, count, 5) } -func TestObservable_MutilSubscribe(t *testing.T) { +func TestObservable_MultiSubscribe(t *testing.T) { iter := iterator([]interface{}{1, 2, 3, 4, 5}) src := NewObservable(iter) ch1, _ := src.Subscribe() ch2, _ := src.Subscribe() - var count int32 + var count = atomic.NewInt32(0) var wg sync.WaitGroup wg.Add(2) waitCh := func(ch <-chan interface{}) { for range ch { - atomic.AddInt32(&count, 1) + count.Inc() } wg.Done() } go waitCh(ch1) go waitCh(ch2) wg.Wait() - assert.Equal(t, int32(10), count) + assert.Equal(t, int32(10), count.Load()) } func TestObservable_UnSubscribe(t *testing.T) { diff --git a/common/singledo/singledo_test.go b/common/singledo/singledo_test.go index c9c58e58..2b0d5988 100644 --- a/common/singledo/singledo_test.go +++ b/common/singledo/singledo_test.go @@ -2,17 +2,17 @@ package singledo import ( "sync" - "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) func TestBasic(t *testing.T) { single := NewSingle(time.Millisecond * 30) foo := 0 - var shardCount int32 = 0 + var shardCount = atomic.NewInt32(0) call := func() (interface{}, error) { foo++ time.Sleep(time.Millisecond * 5) @@ -26,7 +26,7 @@ func TestBasic(t *testing.T) { go func() { _, _, shard := single.Do(call) if shard { - atomic.AddInt32(&shardCount, 1) + shardCount.Inc() } wg.Done() }() @@ -34,7 +34,7 @@ func TestBasic(t *testing.T) { wg.Wait() assert.Equal(t, 1, foo) - assert.Equal(t, int32(4), shardCount) + assert.Equal(t, int32(4), shardCount.Load()) } func TestTimer(t *testing.T) { diff --git a/go.mod b/go.mod index 0b62577d..455ca0ec 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/oschwald/geoip2-golang v1.4.0 github.com/sirupsen/logrus v1.7.0 github.com/stretchr/testify v1.6.1 + go.uber.org/atomic v1.7.0 golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 golang.org/x/net v0.0.0-20201020065357-d65d470038a5 golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520 diff --git a/go.sum b/go.sum index 1495863c..f99949d5 100644 --- a/go.sum +++ b/go.sum @@ -25,9 +25,12 @@ github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/tunnel/manager.go b/tunnel/manager.go index 1493bc00..784d57d9 100644 --- a/tunnel/manager.go +++ b/tunnel/manager.go @@ -2,26 +2,34 @@ package tunnel import ( "sync" - "sync/atomic" "time" + + "go.uber.org/atomic" ) var DefaultManager *Manager func init() { - DefaultManager = &Manager{} + DefaultManager = &Manager{ + uploadTemp: atomic.NewInt64(0), + downloadTemp: atomic.NewInt64(0), + uploadBlip: atomic.NewInt64(0), + downloadBlip: atomic.NewInt64(0), + uploadTotal: atomic.NewInt64(0), + downloadTotal: atomic.NewInt64(0), + } go DefaultManager.handle() } type Manager struct { connections sync.Map - uploadTemp int64 - downloadTemp int64 - uploadBlip int64 - downloadBlip int64 - uploadTotal int64 - downloadTotal int64 + uploadTemp *atomic.Int64 + downloadTemp *atomic.Int64 + uploadBlip *atomic.Int64 + downloadBlip *atomic.Int64 + uploadTotal *atomic.Int64 + downloadTotal *atomic.Int64 } func (m *Manager) Join(c tracker) { @@ -33,17 +41,17 @@ func (m *Manager) Leave(c tracker) { } func (m *Manager) PushUploaded(size int64) { - atomic.AddInt64(&m.uploadTemp, size) - atomic.AddInt64(&m.uploadTotal, size) + m.uploadTemp.Add(size) + m.uploadTotal.Add(size) } func (m *Manager) PushDownloaded(size int64) { - atomic.AddInt64(&m.downloadTemp, size) - atomic.AddInt64(&m.downloadTotal, size) + m.downloadTemp.Add(size) + m.downloadTotal.Add(size) } func (m *Manager) Now() (up int64, down int64) { - return atomic.LoadInt64(&m.uploadBlip), atomic.LoadInt64(&m.downloadBlip) + return m.uploadBlip.Load(), m.downloadBlip.Load() } func (m *Manager) Snapshot() *Snapshot { @@ -54,29 +62,29 @@ func (m *Manager) Snapshot() *Snapshot { }) return &Snapshot{ - UploadTotal: atomic.LoadInt64(&m.uploadTotal), - DownloadTotal: atomic.LoadInt64(&m.downloadTotal), + UploadTotal: m.uploadTotal.Load(), + DownloadTotal: m.downloadTotal.Load(), Connections: connections, } } func (m *Manager) ResetStatistic() { - m.uploadTemp = 0 - m.uploadBlip = 0 - m.uploadTotal = 0 - m.downloadTemp = 0 - m.downloadBlip = 0 - m.downloadTotal = 0 + m.uploadTemp.Store(0) + m.uploadBlip.Store(0) + m.uploadTotal.Store(0) + m.downloadTemp.Store(0) + m.downloadBlip.Store(0) + m.downloadTotal.Store(0) } func (m *Manager) handle() { ticker := time.NewTicker(time.Second) for range ticker.C { - atomic.StoreInt64(&m.uploadBlip, atomic.LoadInt64(&m.uploadTemp)) - atomic.StoreInt64(&m.uploadTemp, 0) - atomic.StoreInt64(&m.downloadBlip, atomic.LoadInt64(&m.downloadTemp)) - atomic.StoreInt64(&m.downloadTemp, 0) + m.uploadBlip.Store(m.uploadTemp.Load()) + m.uploadTemp.Store(0) + m.downloadBlip.Store(m.downloadTemp.Load()) + m.downloadTemp.Store(0) } } diff --git a/tunnel/tracker.go b/tunnel/tracker.go index 19e79110..dcb81e7f 100644 --- a/tunnel/tracker.go +++ b/tunnel/tracker.go @@ -2,11 +2,12 @@ package tunnel import ( "net" - "sync/atomic" "time" C "github.com/Dreamacro/clash/constant" + "github.com/gofrs/uuid" + "go.uber.org/atomic" ) type tracker interface { @@ -15,14 +16,14 @@ type tracker interface { } type trackerInfo struct { - UUID uuid.UUID `json:"id"` - Metadata *C.Metadata `json:"metadata"` - UploadTotal int64 `json:"upload"` - DownloadTotal int64 `json:"download"` - Start time.Time `json:"start"` - Chain C.Chain `json:"chains"` - Rule string `json:"rule"` - RulePayload string `json:"rulePayload"` + UUID uuid.UUID `json:"id"` + Metadata *C.Metadata `json:"metadata"` + UploadTotal *atomic.Int64 `json:"upload"` + DownloadTotal *atomic.Int64 `json:"download"` + Start time.Time `json:"start"` + Chain C.Chain `json:"chains"` + Rule string `json:"rule"` + RulePayload string `json:"rulePayload"` } type tcpTracker struct { @@ -39,7 +40,7 @@ func (tt *tcpTracker) Read(b []byte) (int, error) { n, err := tt.Conn.Read(b) download := int64(n) tt.manager.PushDownloaded(download) - atomic.AddInt64(&tt.DownloadTotal, download) + tt.DownloadTotal.Add(download) return n, err } @@ -47,7 +48,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) tt.manager.PushUploaded(upload) - atomic.AddInt64(&tt.UploadTotal, upload) + tt.UploadTotal.Add(upload) return n, err } @@ -63,11 +64,13 @@ func newTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R Conn: conn, manager: manager, trackerInfo: &trackerInfo{ - UUID: uuid, - Start: time.Now(), - Metadata: metadata, - Chain: conn.Chains(), - Rule: "", + UUID: uuid, + Start: time.Now(), + Metadata: metadata, + Chain: conn.Chains(), + Rule: "", + UploadTotal: atomic.NewInt64(0), + DownloadTotal: atomic.NewInt64(0), }, } @@ -94,7 +97,7 @@ func (ut *udpTracker) ReadFrom(b []byte) (int, net.Addr, error) { n, addr, err := ut.PacketConn.ReadFrom(b) download := int64(n) ut.manager.PushDownloaded(download) - atomic.AddInt64(&ut.DownloadTotal, download) + ut.DownloadTotal.Add(download) return n, addr, err } @@ -102,7 +105,7 @@ func (ut *udpTracker) WriteTo(b []byte, addr net.Addr) (int, error) { n, err := ut.PacketConn.WriteTo(b, addr) upload := int64(n) ut.manager.PushUploaded(upload) - atomic.AddInt64(&ut.UploadTotal, upload) + ut.UploadTotal.Add(upload) return n, err } @@ -118,11 +121,13 @@ func newUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, ru PacketConn: conn, manager: manager, trackerInfo: &trackerInfo{ - UUID: uuid, - Start: time.Now(), - Metadata: metadata, - Chain: conn.Chains(), - Rule: "", + UUID: uuid, + Start: time.Now(), + Metadata: metadata, + Chain: conn.Chains(), + Rule: "", + UploadTotal: atomic.NewInt64(0), + DownloadTotal: atomic.NewInt64(0), }, }