-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
439 additions
and
126 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright 2025 The Outline Authors | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
// Package callback provides a thread-safe mechanism for managing and invoking callbacks. | ||
package callback | ||
|
||
import ( | ||
"fmt" | ||
"log/slog" | ||
"sync" | ||
) | ||
|
||
// Token can be used to uniquely identify a registered callback. | ||
type Token string | ||
|
||
// Callback is an interface that can be implemented to receive callbacks. | ||
type Callback interface { | ||
OnCall(data string) | ||
} | ||
|
||
var ( | ||
mu sync.RWMutex | ||
callbacks = make(map[uint32]Callback) | ||
nextCbID uint32 = 1 | ||
) | ||
|
||
// New registers a new callback and returns a unique callback token. | ||
func New(c Callback) Token { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
|
||
id := nextCbID | ||
nextCbID++ | ||
callbacks[id] = c | ||
slog.Debug("callback created", "id", id) | ||
return getTokenByID(id) | ||
} | ||
|
||
// Delete removes a callback identified by the token. | ||
// | ||
// Calling this function is safe even if the callback has not been registered. | ||
func Delete(token Token) { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
|
||
if id, err := getIDByToken(token); err == nil { | ||
delete(callbacks, id) | ||
slog.Debug("callback deleted", "id", id) | ||
} else { | ||
slog.Warn("invalid callback token", "err", err, "token", token) | ||
} | ||
} | ||
|
||
// Call executes a callback identified by the token. | ||
// | ||
// Calling this function is safe even if the callback has not been registered. | ||
func Call(token Token, data string) { | ||
id, err := getIDByToken(token) | ||
if err != nil { | ||
slog.Warn("invalid callback token", "err", err, "token", token) | ||
return | ||
} | ||
|
||
mu.RLock() | ||
cb, ok := callbacks[id] | ||
mu.RUnlock() | ||
|
||
if !ok { | ||
slog.Warn("callback not yet created", "id", id, "token", token) | ||
return | ||
} | ||
slog.Debug("invoking callback", "id", id, "data", data) | ||
cb.OnCall(data) | ||
} | ||
|
||
// getTokenByID creates a string-based callback token from a number-based internal ID. | ||
func getTokenByID(id uint32) Token { | ||
return Token(fmt.Sprintf("cbid-%d", id)) | ||
} | ||
|
||
// getIDByToken parses a number-based internal ID from a string-based callback token. | ||
func getIDByToken(token Token) (uint32, error) { | ||
var id uint32 | ||
_, err := fmt.Sscanf(string(token), "cbid-%d", &id) | ||
return id, err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
// Copyright 2025 The Outline Authors | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package callback | ||
|
||
import ( | ||
"fmt" | ||
"sync" | ||
"sync/atomic" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func Test_New(t *testing.T) { | ||
curID := nextCbID | ||
token := New(&testCallback{}) | ||
require.Equal(t, curID+1, nextCbID) | ||
require.Contains(t, callbacks, curID) | ||
|
||
require.NotEmpty(t, token) | ||
require.Equal(t, fmt.Sprintf("cbid-%d", curID), string(token)) | ||
|
||
id, err := getIDByToken(token) | ||
require.NoError(t, err) | ||
require.Contains(t, callbacks, id) | ||
require.Equal(t, id, curID) | ||
} | ||
|
||
func Test_Delete(t *testing.T) { | ||
curID := nextCbID | ||
token := New(&testCallback{}) | ||
require.Contains(t, callbacks, curID) | ||
|
||
Delete(token) | ||
require.NotContains(t, callbacks, curID) | ||
require.Equal(t, curID+1, nextCbID) | ||
|
||
Delete("invalid-token") | ||
require.NotContains(t, callbacks, curID) | ||
require.Equal(t, curID+1, nextCbID) | ||
|
||
Delete("cbid-99999999") | ||
require.NotContains(t, callbacks, curID) | ||
require.Equal(t, curID+1, nextCbID) | ||
} | ||
|
||
func Test_Call(t *testing.T) { | ||
c := &testCallback{} | ||
token := New(c) | ||
c.requireEqual(t, 0, "") | ||
|
||
Call(token, "arg1") | ||
c.requireEqual(t, 1, "arg1") | ||
|
||
Call("invalid-token", "arg1") | ||
c.requireEqual(t, 1, "arg1") // No change | ||
|
||
Call(token, "arg2") | ||
c.requireEqual(t, 2, "arg2") | ||
|
||
Call("cbid-99999999", "arg3") | ||
c.requireEqual(t, 2, "arg2") // No change | ||
} | ||
|
||
func Test_ConcurrentCreate(t *testing.T) { | ||
const numTokens = 1000 | ||
|
||
curID := nextCbID | ||
originalLen := len(callbacks) | ||
var wg sync.WaitGroup | ||
|
||
tokens := make([]Token, numTokens) | ||
wg.Add(numTokens) | ||
for i := 0; i < numTokens; i++ { | ||
go func(i int) { | ||
defer wg.Done() | ||
tokens[i] = New(&testCallback{}) | ||
require.NotEmpty(t, tokens[i]) | ||
require.Regexp(t, `^cbid-\d+$`, tokens[i]) | ||
}(i) | ||
} | ||
wg.Wait() | ||
|
||
require.Len(t, callbacks, originalLen+numTokens) | ||
require.Equal(t, curID+numTokens, nextCbID) | ||
tokenSet := make(map[Token]bool) | ||
for _, token := range tokens { | ||
require.False(t, tokenSet[token], "Duplicate token found: %s", token) | ||
tokenSet[token] = true | ||
|
||
id, err := getIDByToken(token) | ||
require.NoError(t, err) | ||
require.Contains(t, callbacks, id) | ||
} | ||
} | ||
|
||
func Test_ConcurrentCall(t *testing.T) { | ||
const numInvocations = 1000 | ||
|
||
curID := nextCbID | ||
originalLen := len(callbacks) | ||
|
||
c := &testCallback{} | ||
token := New(c) | ||
|
||
var wg sync.WaitGroup | ||
wg.Add(numInvocations) | ||
for i := 0; i < numInvocations; i++ { | ||
go func(i int) { | ||
defer wg.Done() | ||
Call(token, fmt.Sprintf("data-%d", i)) | ||
}(i) | ||
} | ||
wg.Wait() | ||
|
||
require.Equal(t, int32(numInvocations), c.cnt.Load()) | ||
require.Regexp(t, `^data-\d+$`, c.lastData.Load()) | ||
|
||
require.Len(t, callbacks, originalLen+1) | ||
require.Equal(t, curID+1, nextCbID) | ||
} | ||
|
||
func Test_ConcurrentDelete(t *testing.T) { | ||
const ( | ||
numTokens = 50 | ||
numDeletes = 1000 | ||
) | ||
|
||
curID := nextCbID | ||
originalLen := len(callbacks) | ||
|
||
tokens := make([]Token, numTokens) | ||
for i := 0; i < numTokens; i++ { | ||
tokens[i] = New(&testCallback{}) | ||
} | ||
require.Len(t, callbacks, originalLen+numTokens) | ||
require.Equal(t, curID+numTokens, nextCbID) | ||
|
||
var wg sync.WaitGroup | ||
wg.Add(numDeletes) | ||
for i := 0; i < numDeletes; i++ { | ||
go func(i int) { | ||
defer wg.Done() | ||
Delete(tokens[i%numTokens]) | ||
}(i) | ||
} | ||
wg.Wait() | ||
|
||
require.Len(t, callbacks, originalLen) | ||
require.Equal(t, curID+numTokens, nextCbID) | ||
} | ||
|
||
// testCallback is a mock implementation of callback.Callback for testing. | ||
type testCallback struct { | ||
cnt atomic.Int32 | ||
lastData atomic.Value | ||
} | ||
|
||
func (tc *testCallback) OnCall(data string) { | ||
tc.cnt.Add(1) | ||
tc.lastData.Store(data) | ||
} | ||
|
||
func (tc *testCallback) requireEqual(t *testing.T, cnt int32, data string) { | ||
require.Equal(t, cnt, tc.cnt.Load()) | ||
if cnt == 0 { | ||
require.Nil(t, tc.lastData.Load()) | ||
} else { | ||
require.Equal(t, data, tc.lastData.Load()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.