diff --git a/modules/globallock/globallock.go b/modules/globallock/globallock.go index 707d169f05..aa53557729 100644 --- a/modules/globallock/globallock.go +++ b/modules/globallock/globallock.go @@ -27,20 +27,20 @@ func DefaultLocker() Locker { // Lock tries to acquire a lock for the given key, it uses the default locker. // Read the documentation of Locker.Lock for more information about the behavior. -func Lock(ctx context.Context, key string) (context.Context, ReleaseFunc, error) { +func Lock(ctx context.Context, key string) (ReleaseFunc, error) { return DefaultLocker().Lock(ctx, key) } // TryLock tries to acquire a lock for the given key, it uses the default locker. // Read the documentation of Locker.TryLock for more information about the behavior. -func TryLock(ctx context.Context, key string) (bool, context.Context, ReleaseFunc, error) { +func TryLock(ctx context.Context, key string) (bool, ReleaseFunc, error) { return DefaultLocker().TryLock(ctx, key) } // LockAndDo tries to acquire a lock for the given key and then calls the given function. // It uses the default locker, and it will return an error if failed to acquire the lock. func LockAndDo(ctx context.Context, key string, f func(context.Context) error) error { - ctx, release, err := Lock(ctx, key) + release, err := Lock(ctx, key) if err != nil { return err } @@ -52,7 +52,7 @@ func LockAndDo(ctx context.Context, key string, f func(context.Context) error) e // TryLockAndDo tries to acquire a lock for the given key and then calls the given function. // It uses the default locker, and it will return false if failed to acquire the lock. func TryLockAndDo(ctx context.Context, key string, f func(context.Context) error) (bool, error) { - ok, ctx, release, err := TryLock(ctx, key) + ok, release, err := TryLock(ctx, key) if err != nil { return false, err } diff --git a/modules/globallock/locker.go b/modules/globallock/locker.go index b0764cd71c..682e24d052 100644 --- a/modules/globallock/locker.go +++ b/modules/globallock/locker.go @@ -5,56 +5,34 @@ package globallock import ( "context" - "fmt" ) type Locker interface { // Lock tries to acquire a lock for the given key, it blocks until the lock is acquired or the context is canceled. // - // Lock returns a new context which should be used in the following code. - // The new context will be canceled when the lock is released or lost - yes, it's possible to lose a lock. - // For example, it lost the connection to the redis server while holding the lock. - // If it fails to acquire the lock, the returned context will be the same as the input context. - // // Lock returns a ReleaseFunc to release the lock, it cannot be nil. // It's always safe to call this function even if it fails to acquire the lock, and it will do nothing in that case. // And it's also safe to call it multiple times, but it will only release the lock once. // That's why it's called ReleaseFunc, not UnlockFunc. // But be aware that it's not safe to not call it at all; it could lead to a memory leak. // So a recommended pattern is to use defer to call it: - // ctx, release, err := locker.Lock(ctx, "key") + // release, err := locker.Lock(ctx, "key") // if err != nil { // return err // } // defer release() - // The ReleaseFunc will return the original context which was used to acquire the lock. - // It's useful when you want to continue to do something after releasing the lock. - // At that time, the ctx will be canceled, and you can use the returned context by the ReleaseFunc to continue: - // ctx, release, err := locker.Lock(ctx, "key") - // if err != nil { - // return err - // } - // defer release() - // doSomething(ctx) - // ctx = release() - // doSomethingElse(ctx) - // Please ignore it and use `defer release()` instead if you don't need this, to avoid forgetting to release the lock. // // Lock returns an error if failed to acquire the lock. // Be aware that even the context is not canceled, it's still possible to fail to acquire the lock. // For example, redis is down, or it reached the maximum number of tries. - Lock(ctx context.Context, key string) (context.Context, ReleaseFunc, error) + Lock(ctx context.Context, key string) (ReleaseFunc, error) // TryLock tries to acquire a lock for the given key, it returns immediately. // It follows the same pattern as Lock, but it doesn't block. // And if it fails to acquire the lock because it's already locked, not other reasons like redis is down, // it will return false without any error. - TryLock(ctx context.Context, key string) (bool, context.Context, ReleaseFunc, error) + TryLock(ctx context.Context, key string) (bool, ReleaseFunc, error) } // ReleaseFunc is a function that releases a lock. -// It returns the original context which was used to acquire the lock. -type ReleaseFunc func() context.Context - -// ErrLockReleased is used as context cause when a lock is released -var ErrLockReleased = fmt.Errorf("lock released") +type ReleaseFunc func() diff --git a/modules/globallock/locker_test.go b/modules/globallock/locker_test.go index 15a3c65bb0..bee4d34b34 100644 --- a/modules/globallock/locker_test.go +++ b/modules/globallock/locker_test.go @@ -47,27 +47,24 @@ func TestLocker(t *testing.T) { func testLocker(t *testing.T, locker Locker) { t.Run("lock", func(t *testing.T) { parentCtx := context.Background() - ctx, release, err := locker.Lock(parentCtx, "test") + release, err := locker.Lock(parentCtx, "test") defer release() - assert.NotEqual(t, parentCtx, ctx) // new context should be returned assert.NoError(t, err) func() { - parentCtx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - ctx, release, err := locker.Lock(parentCtx, "test") + release, err := locker.Lock(ctx, "test") defer release() assert.Error(t, err) - assert.Equal(t, parentCtx, ctx) // should return the same context }() release() - assert.Error(t, ctx.Err()) func() { - _, release, err := locker.Lock(context.Background(), "test") + release, err := locker.Lock(context.Background(), "test") defer release() assert.NoError(t, err) @@ -76,29 +73,26 @@ func testLocker(t *testing.T, locker Locker) { t.Run("try lock", func(t *testing.T) { parentCtx := context.Background() - ok, ctx, release, err := locker.TryLock(parentCtx, "test") + ok, release, err := locker.TryLock(parentCtx, "test") defer release() assert.True(t, ok) - assert.NotEqual(t, parentCtx, ctx) // new context should be returned assert.NoError(t, err) func() { - parentCtx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - ok, ctx, release, err := locker.TryLock(parentCtx, "test") + ok, release, err := locker.TryLock(ctx, "test") defer release() assert.False(t, ok) assert.NoError(t, err) - assert.Equal(t, parentCtx, ctx) // should return the same context }() release() - assert.Error(t, ctx.Err()) func() { - ok, _, release, _ := locker.TryLock(context.Background(), "test") + ok, release, _ := locker.TryLock(context.Background(), "test") defer release() assert.True(t, ok) @@ -107,7 +101,7 @@ func testLocker(t *testing.T, locker Locker) { t.Run("wait and acquired", func(t *testing.T) { ctx := context.Background() - _, release, err := locker.Lock(ctx, "test") + release, err := locker.Lock(ctx, "test") require.NoError(t, err) wg := &sync.WaitGroup{} @@ -115,7 +109,7 @@ func testLocker(t *testing.T, locker Locker) { go func() { defer wg.Done() started := time.Now() - _, release, err := locker.Lock(context.Background(), "test") // should be blocked for seconds + release, err := locker.Lock(context.Background(), "test") // should be blocked for seconds defer release() assert.Greater(t, time.Since(started), time.Second) assert.NoError(t, err) @@ -127,34 +121,15 @@ func testLocker(t *testing.T, locker Locker) { wg.Wait() }) - t.Run("continue after release", func(t *testing.T) { - ctx := context.Background() - - ctxBeforeLock := ctx - ctx, release, err := locker.Lock(ctx, "test") - - require.NoError(t, err) - assert.NoError(t, ctx.Err()) - assert.NotEqual(t, ctxBeforeLock, ctx) - - ctxBeforeRelease := ctx - ctx = release() - - assert.NoError(t, ctx.Err()) - assert.Error(t, ctxBeforeRelease.Err()) - - // so it can continue with ctx to do more work - }) - t.Run("multiple release", func(t *testing.T) { ctx := context.Background() - _, release1, err := locker.Lock(ctx, "test") + release1, err := locker.Lock(ctx, "test") require.NoError(t, err) release1() - _, release2, err := locker.Lock(ctx, "test") + release2, err := locker.Lock(ctx, "test") defer release2() require.NoError(t, err) @@ -163,7 +138,7 @@ func testLocker(t *testing.T, locker Locker) { // and it shouldn't affect the other lock release1() - ok, _, release3, err := locker.TryLock(ctx, "test") + ok, release3, err := locker.TryLock(ctx, "test") defer release3() require.NoError(t, err) // It should be able to acquire the lock; @@ -184,28 +159,23 @@ func testRedisLocker(t *testing.T, locker *redisLocker) { // Otherwise, it will affect other tests. t.Run("close", func(t *testing.T) { assert.NoError(t, locker.Close()) - _, _, err := locker.Lock(context.Background(), "test") + _, err := locker.Lock(context.Background(), "test") assert.Error(t, err) }) }() t.Run("failed extend", func(t *testing.T) { - ctx, release, err := locker.Lock(context.Background(), "test") + release, err := locker.Lock(context.Background(), "test") defer release() require.NoError(t, err) // It simulates that there are some problems with extending like network issues or redis server down. v, ok := locker.mutexM.Load("test") require.True(t, ok) - m := v.(*redisMutex) - _, _ = m.mutex.Unlock() // release it to make it impossible to extend + m := v.(*redsync.Mutex) + _, _ = m.Unlock() // release it to make it impossible to extend - select { - case <-time.After(redisLockExpiry + time.Second): - t.Errorf("lock should be expired") - case <-ctx.Done(): - var errTaken *redsync.ErrTaken - assert.ErrorAs(t, context.Cause(ctx), &errTaken) - } + // In current design, callers can't know the lock can't be extended. + // Just keep this case to improve the test coverage. }) } diff --git a/modules/globallock/memory_locker.go b/modules/globallock/memory_locker.go index fb1fc79bd0..3f818d8d43 100644 --- a/modules/globallock/memory_locker.go +++ b/modules/globallock/memory_locker.go @@ -19,18 +19,13 @@ func NewMemoryLocker() Locker { return &memoryLocker{} } -func (l *memoryLocker) Lock(ctx context.Context, key string) (context.Context, ReleaseFunc, error) { - originalCtx := ctx - +func (l *memoryLocker) Lock(ctx context.Context, key string) (ReleaseFunc, error) { if l.tryLock(key) { - ctx, cancel := context.WithCancelCause(ctx) releaseOnce := sync.Once{} - return ctx, func() context.Context { + return func() { releaseOnce.Do(func() { l.locks.Delete(key) - cancel(ErrLockReleased) }) - return originalCtx }, nil } @@ -39,39 +34,31 @@ func (l *memoryLocker) Lock(ctx context.Context, key string) (context.Context, R for { select { case <-ctx.Done(): - return ctx, func() context.Context { return originalCtx }, ctx.Err() + return func() {}, ctx.Err() case <-ticker.C: if l.tryLock(key) { - ctx, cancel := context.WithCancelCause(ctx) releaseOnce := sync.Once{} - return ctx, func() context.Context { + return func() { releaseOnce.Do(func() { l.locks.Delete(key) - cancel(ErrLockReleased) }) - return originalCtx }, nil } } } } -func (l *memoryLocker) TryLock(ctx context.Context, key string) (bool, context.Context, ReleaseFunc, error) { - originalCtx := ctx - +func (l *memoryLocker) TryLock(_ context.Context, key string) (bool, ReleaseFunc, error) { if l.tryLock(key) { - ctx, cancel := context.WithCancelCause(ctx) releaseOnce := sync.Once{} - return true, ctx, func() context.Context { + return true, func() { releaseOnce.Do(func() { - cancel(ErrLockReleased) l.locks.Delete(key) }) - return originalCtx }, nil } - return false, ctx, func() context.Context { return originalCtx }, nil + return false, func() {}, nil } func (l *memoryLocker) tryLock(key string) bool { diff --git a/modules/globallock/redis_locker.go b/modules/globallock/redis_locker.go index 34b2fabfb3..34ed9e389b 100644 --- a/modules/globallock/redis_locker.go +++ b/modules/globallock/redis_locker.go @@ -48,21 +48,21 @@ func NewRedisLocker(connection string) Locker { return l } -func (l *redisLocker) Lock(ctx context.Context, key string) (context.Context, ReleaseFunc, error) { +func (l *redisLocker) Lock(ctx context.Context, key string) (ReleaseFunc, error) { return l.lock(ctx, key, 0) } -func (l *redisLocker) TryLock(ctx context.Context, key string) (bool, context.Context, ReleaseFunc, error) { - ctx, f, err := l.lock(ctx, key, 1) +func (l *redisLocker) TryLock(ctx context.Context, key string) (bool, ReleaseFunc, error) { + f, err := l.lock(ctx, key, 1) var ( errTaken *redsync.ErrTaken errNodeTaken *redsync.ErrNodeTaken ) if errors.As(err, &errTaken) || errors.As(err, &errNodeTaken) { - return false, ctx, f, nil + return false, f, nil } - return err == nil, ctx, f, err + return err == nil, f, err } // Close closes the locker. @@ -76,18 +76,11 @@ func (l *redisLocker) Close() error { return nil } -type redisMutex struct { - mutex *redsync.Mutex - cancel context.CancelCauseFunc -} - -func (l *redisLocker) lock(ctx context.Context, key string, tries int) (context.Context, ReleaseFunc, error) { +func (l *redisLocker) lock(ctx context.Context, key string, tries int) (ReleaseFunc, error) { if l.closed.Load() { - return ctx, func() context.Context { return ctx }, fmt.Errorf("locker is closed") + return func() {}, fmt.Errorf("locker is closed") } - originalCtx := ctx - options := []redsync.Option{ redsync.WithExpiry(redisLockExpiry), } @@ -96,18 +89,13 @@ func (l *redisLocker) lock(ctx context.Context, key string, tries int) (context. } mutex := l.rs.NewMutex(redisLockKeyPrefix+key, options...) if err := mutex.LockContext(ctx); err != nil { - return ctx, func() context.Context { return originalCtx }, err + return func() {}, err } - ctx, cancel := context.WithCancelCause(ctx) - - l.mutexM.Store(key, &redisMutex{ - mutex: mutex, - cancel: cancel, - }) + l.mutexM.Store(key, mutex) releaseOnce := sync.Once{} - return ctx, func() context.Context { + return func() { releaseOnce.Do(func() { l.mutexM.Delete(key) @@ -115,10 +103,7 @@ func (l *redisLocker) lock(ctx context.Context, key string, tries int) (context. // if it failed to unlock, it will be released automatically after the lock expires. // Do not call mutex.UnlockContext(ctx) here, or it will fail to release when ctx has timed out. _, _ = mutex.Unlock() - - cancel(ErrLockReleased) }) - return originalCtx }, nil } @@ -128,16 +113,15 @@ func (l *redisLocker) startExtend() { return } - toExtend := make([]*redisMutex, 0) + toExtend := make([]*redsync.Mutex, 0) l.mutexM.Range(func(_, value any) bool { - m := value.(*redisMutex) + m := value.(*redsync.Mutex) // Extend the lock if it is not expired. // Although the mutex will be removed from the map before it is released, // it still can be expired because of a failed extension. - // If it happens, the cancel function should have been called, - // so it does not need to be extended anymore. - if time.Now().After(m.mutex.Until()) { + // If it happens, it does not need to be extended anymore. + if time.Now().After(m.Until()) { return true } @@ -145,9 +129,8 @@ func (l *redisLocker) startExtend() { return true }) for _, v := range toExtend { - if ok, err := v.mutex.Extend(); !ok { - v.cancel(err) - } + // If it failed to extend, it will be released automatically after the lock expires. + _, _ = v.Extend() } time.AfterFunc(redisLockExpiry/2, l.startExtend)