From eaaccffcef0de091fcc1f6a073f3f56d096987ae Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 4 Apr 2025 10:55:16 +0800 Subject: [PATCH] fix: race in Single.Do --- common/singledo/singledo.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/common/singledo/singledo.go b/common/singledo/singledo.go index 4e843ca2..6f0c3bd4 100644 --- a/common/singledo/singledo.go +++ b/common/singledo/singledo.go @@ -13,25 +13,26 @@ type call[T any] struct { type Single[T any] struct { mux sync.Mutex - last time.Time wait time.Duration call *call[T] result *Result[T] } type Result[T any] struct { - Val T - Err error + Val T + Err error + Time time.Time } // Do single.Do likes sync.singleFlight func (s *Single[T]) Do(fn func() (T, error)) (v T, err error, shared bool) { s.mux.Lock() - now := time.Now() - if now.Before(s.last.Add(s.wait)) { + result := s.result + if result != nil && time.Since(result.Time) < s.wait { s.mux.Unlock() - return s.result.Val, s.result.Err, true + return result.Val, result.Err, true } + s.result = nil // The result has expired, clear it if callM := s.call; callM != nil { s.mux.Unlock() @@ -47,15 +48,19 @@ func (s *Single[T]) Do(fn func() (T, error)) (v T, err error, shared bool) { callM.wg.Done() s.mux.Lock() - s.call = nil - s.result = &Result[T]{callM.val, callM.err} - s.last = now + if s.call == callM { // maybe reset when fn is running + s.call = nil + s.result = &Result[T]{callM.val, callM.err, time.Now()} + } s.mux.Unlock() return callM.val, callM.err, false } func (s *Single[T]) Reset() { - s.last = time.Time{} + s.mux.Lock() + s.call = nil + s.result = nil + s.mux.Unlock() } func NewSingle[T any](wait time.Duration) *Single[T] {