From ffa062dc954b4ee739163d435ef90bb5b1ebeb02 Mon Sep 17 00:00:00 2001 From: Tonis Tiigi Date: Wed, 26 Jan 2022 21:09:04 -0800 Subject: [PATCH] util: add waitmap for target synchronization Signed-off-by: Tonis Tiigi --- util/waitmap/waitmap.go | 74 ++++++++++++++++++++++++++++++++++++ util/waitmap/waitmap_test.go | 64 +++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 util/waitmap/waitmap.go create mode 100644 util/waitmap/waitmap_test.go diff --git a/util/waitmap/waitmap.go b/util/waitmap/waitmap.go new file mode 100644 index 00000000..c34b8f0c --- /dev/null +++ b/util/waitmap/waitmap.go @@ -0,0 +1,74 @@ +package waitmap + +import ( + "context" + "sync" +) + +type Map struct { + mu sync.RWMutex + m map[string]interface{} + ch map[string]chan struct{} +} + +func New() *Map { + return &Map{ + m: make(map[string]interface{}), + ch: make(map[string]chan struct{}), + } +} + +func (m *Map) Set(key string, value interface{}) { + m.mu.Lock() + defer m.mu.Unlock() + + m.m[key] = value + + if ch, ok := m.ch[key]; ok { + if ch != nil { + close(ch) + } + } + m.ch[key] = nil +} + +func (m *Map) Get(ctx context.Context, keys ...string) (map[string]interface{}, error) { + if len(keys) == 0 { + return map[string]interface{}{}, nil + } + + if len(keys) > 1 { + out := make(map[string]interface{}) + for _, key := range keys { + mm, err := m.Get(ctx, key) + if err != nil { + return nil, err + } + out[key] = mm[key] + } + return out, nil + } + + key := keys[0] + m.mu.Lock() + ch, ok := m.ch[key] + if !ok { + ch = make(chan struct{}) + m.ch[key] = ch + } + + if ch != nil { + m.mu.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ch: + m.mu.Lock() + } + } + + res := m.m[key] + m.mu.Unlock() + + return map[string]interface{}{key: res}, nil +} diff --git a/util/waitmap/waitmap_test.go b/util/waitmap/waitmap_test.go new file mode 100644 index 00000000..319be611 --- /dev/null +++ b/util/waitmap/waitmap_test.go @@ -0,0 +1,64 @@ +package waitmap + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestGetAfter(t *testing.T) { + m := New() + + m.Set("foo", "bar") + m.Set("bar", "baz") + + ctx := context.TODO() + v, err := m.Get(ctx, "foo", "bar") + require.NoError(t, err) + + require.Equal(t, 2, len(v)) + require.Equal(t, "bar", v["foo"]) + require.Equal(t, "baz", v["bar"]) + + v, err = m.Get(ctx, "foo") + require.NoError(t, err) + require.Equal(t, 1, len(v)) + require.Equal(t, "bar", v["foo"]) +} + +func TestTimeout(t *testing.T) { + m := New() + + m.Set("foo", "bar") + + ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond) + defer cancel() + + _, err := m.Get(ctx, "bar") + require.Error(t, err) + require.True(t, errors.Is(err, context.DeadlineExceeded)) +} + +func TestBlocking(t *testing.T) { + m := New() + + m.Set("foo", "bar") + + go func() { + time.Sleep(100 * time.Millisecond) + m.Set("bar", "baz") + time.Sleep(50 * time.Millisecond) + m.Set("baz", "abc") + }() + + ctx := context.TODO() + v, err := m.Get(ctx, "foo", "bar", "baz") + require.NoError(t, err) + require.Equal(t, 3, len(v)) + require.Equal(t, "bar", v["foo"]) + require.Equal(t, "baz", v["bar"]) + require.Equal(t, "abc", v["baz"]) +}