From 485d7e23b3d88facd4ee0251dd4d527ed8483ffd Mon Sep 17 00:00:00 2001 From: Thomas Stromberg Date: Mon, 28 Jul 2025 12:48:09 -0400 Subject: [PATCH] Merge https://github.com/avast/retry-go/pull/103 with test --- retry.go | 5 +++-- retry_test.go | 24 +++++++++++++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/retry.go b/retry.go index 5338985..e87451a 100644 --- a/retry.go +++ b/retry.go @@ -188,8 +188,6 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( break } - config.onRetry(n, err) - for errToCheck, attempts := range attemptsForError { if errors.Is(err, errToCheck) { attempts-- @@ -202,6 +200,9 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( if n == config.attempts-1 { break } + + config.onRetry(n, err) + n++ select { case <-config.timer.After(delay(config, n, err)): diff --git a/retry_test.go b/retry_test.go index 1ee3739..f9dfaec 100644 --- a/retry_test.go +++ b/retry_test.go @@ -35,7 +35,7 @@ func TestDoWithDataAllFailed(t *testing.T) { assert.Len(t, err, 10) fmt.Println(err.Error()) assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format") - assert.Equal(t, uint(45), retrySum, "right count of retry") + assert.Equal(t, uint(36), retrySum, "right count of retry") } func TestDoFirstOk(t *testing.T) { @@ -632,6 +632,28 @@ func BenchmarkDoWithDataNoErrors(b *testing.B) { } } +func TestOnRetryNotCalledOnLastAttempt(t *testing.T) { + callCount := 0 + onRetryCalls := make([]uint, 0) + + err := Do( + func() error { + callCount++ + return errors.New("test error") + }, + Attempts(3), + OnRetry(func(n uint, err error) { + onRetryCalls = append(onRetryCalls, n) + }), + Delay(time.Nanosecond), + ) + + assert.Error(t, err) + assert.Equal(t, 3, callCount, "function should be called 3 times") + assert.Equal(t, []uint{0, 1}, onRetryCalls, "onRetry should only be called for first 2 attempts, not the final one") + assert.Len(t, onRetryCalls, 2, "onRetry should be called exactly 2 times (not on last attempt)") +} + func TestIsRecoverable(t *testing.T) { err := errors.New("err") assert.True(t, IsRecoverable(err))