From fc8b67a902a1722aab32b33492c4f63819eb4c2b Mon Sep 17 00:00:00 2001
From: Amir Raminfar
Date: Fri, 2 Feb 2024 13:56:48 -0800
Subject: [PATCH] fix: fixes race issues (#2747)

---
 .reflex                            |  2 +-
 internal/docker/client.go          |  7 ++--
 internal/docker/container_store.go | 55 ++++++++++++++++++------------
 internal/docker/stats_collector.go | 29 ++++++++--------
 internal/utils/ring_buffer.go      | 18 +++++-----
 internal/utils/ring_buffer_test.go | 12 -------
 6 files changed, 60 insertions(+), 63 deletions(-)

diff --git a/.reflex b/.reflex
index 47b44891..19ce32bd 100644
--- a/.reflex
+++ b/.reflex
@@ -1 +1 @@
--r '\.(go)$' -R 'node_modules' -G '\*\_test.go' -s -- go run main.go --level debug
+-r '\.(go)$' -R 'node_modules' -G '\*\_test.go' -s -- go run -race main.go --level debug
diff --git a/internal/docker/client.go b/internal/docker/client.go
index 7ab46c5f..670ea789 100644
--- a/internal/docker/client.go
+++ b/internal/docker/client.go
@@ -306,11 +306,8 @@ func (d *_client) Events(ctx context.Context, messages chan<- ContainerEvent) <-
 			select {
 			case <-ctx.Done():
 				return
-			case err, ok := <-errors:
-				if !ok {
-					log.Errorf("docker events channel closed")
-				}
-				log.Warnf("error while listening to docker events: %v", err)
+			case err := <-errors:
+				log.Fatalf("error while listening to docker events: %v. Exiting...", err)
 			case message, ok := <-dockerMessages:
 				if !ok {
 					log.Errorf("docker events channel closed")
diff --git a/internal/docker/container_store.go b/internal/docker/container_store.go
index 9d3917df..86532645 100644
--- a/internal/docker/container_store.go
+++ b/internal/docker/container_store.go
@@ -2,22 +2,23 @@ package docker
 
 import (
 	"context"
+	"sync"
 
 	log "github.com/sirupsen/logrus"
 )
 
 type ContainerStore struct {
-	containers     map[string]*Container
+	containers     sync.Map
 	client         Client
 	statsCollector *StatsCollector
-	subscribers    map[context.Context]chan ContainerEvent
+	subscribers    sync.Map
 }
 
 func NewContainerStore(client Client) *ContainerStore {
 	s := &ContainerStore{
-		containers:     make(map[string]*Container),
+		containers:     sync.Map{},
 		client:         client,
-		subscribers:    make(map[context.Context]chan ContainerEvent),
+		subscribers:    sync.Map{},
 		statsCollector: NewStatsCollector(client),
 	}
 
@@ -28,10 +29,11 @@ func NewContainerStore(client Client) *ContainerStore {
 }
 
 func (s *ContainerStore) List() []Container {
-	containers := make([]Container, 0, len(s.containers))
-	for _, c := range s.containers {
-		containers = append(containers, *c)
-	}
+	containers := make([]Container, 0)
+	s.containers.Range(func(_, value any) bool {
+		containers = append(containers, value.(Container))
+		return true
+	})
 
 	return containers
 }
@@ -41,7 +43,7 @@ func (s *ContainerStore) Client() Client {
 }
 
 func (s *ContainerStore) Subscribe(ctx context.Context, events chan ContainerEvent) {
-	s.subscribers[ctx] = events
+	s.subscribers.Store(ctx, events)
 }
 
 func (s *ContainerStore) SubscribeStats(ctx context.Context, stats chan ContainerStat) {
@@ -56,7 +58,7 @@ func (s *ContainerStore) init(ctx context.Context) {
 
 	for _, c := range containers {
 		c := c // create a new variable to avoid capturing the loop variable
-		s.containers[c.ID] = &c
+		s.containers.Store(c.ID, c)
 	}
 
 	events := make(chan ContainerEvent)
@@ -72,14 +74,18 @@ func (s *ContainerStore) init(ctx context.Context) {
 			switch event.Name {
 			case "start":
 				if container, err := s.client.FindContainer(event.ActorID); err == nil {
-					s.containers[container.ID] = &container
+					log.Debugf("container %s started", container.ID)
+					s.containers.Store(container.ID, container)
 				}
 			case "destroy":
 				log.Debugf("container %s destroyed", event.ActorID)
-				delete(s.containers, event.ActorID)
+				s.containers.Delete(event.ActorID)
 
 			case "die":
-				if container, ok := s.containers[event.ActorID]; ok {
+				if value, ok := s.containers.Load(event.ActorID); ok {
+					container := value.(Container)
 					log.Debugf("container %s died", container.ID)
 					container.State = "exited"
+					// value is a copy of the stored Container, so write the change back.
+					s.containers.Store(event.ActorID, container)
 				}
@@ -88,22 +94,27 @@ func (s *ContainerStore) init(ctx context.Context) {
 				if event.Name == "health_status: healthy" {
 					healthy = "healthy"
 				}
-				if container, ok := s.containers[event.ActorID]; ok {
+				if value, ok := s.containers.Load(event.ActorID); ok {
+					container := value.(Container)
 					log.Debugf("container %s is %s", container.ID, healthy)
 					container.Health = healthy
+					// value is a copy of the stored Container, so write the change back.
+					s.containers.Store(event.ActorID, container)
 				}
 			}
-
-			for ctx, sub := range s.subscribers {
+			s.subscribers.Range(func(key, value any) bool {
 				select {
-				case sub <- event:
-				case <-ctx.Done():
-					delete(s.subscribers, ctx)
+				case value.(chan ContainerEvent) <- event:
+				case <-key.(context.Context).Done():
+					s.subscribers.Delete(key)
 				}
-			}
+				return true
+			})
+
 		case stat := <-stats:
-			if container, ok := s.containers[stat.ID]; ok {
-				container.Stats.Push(stat)
+			if container, ok := s.containers.Load(stat.ID); ok {
+				stat.ID = ""
+				container.(Container).Stats.Push(stat)
 			}
 		case <-ctx.Done():
 			return
diff --git a/internal/docker/stats_collector.go b/internal/docker/stats_collector.go
index dcbd561e..f2c1fad5 100644
--- a/internal/docker/stats_collector.go
+++ b/internal/docker/stats_collector.go
@@ -4,28 +4,29 @@ import (
 	"context"
 	"errors"
 	"io"
+	"sync"
 
 	log "github.com/sirupsen/logrus"
 )
 
 type StatsCollector struct {
 	stream      chan ContainerStat
-	subscribers map[context.Context]chan ContainerStat
+	subscribers sync.Map
 	client      Client
-	cancelers   map[string]context.CancelFunc
+	cancelers   sync.Map
 }
 
 func NewStatsCollector(client Client) *StatsCollector {
 	return &StatsCollector{
 		stream:      make(chan ContainerStat),
-		subscribers: make(map[context.Context]chan ContainerStat),
+		subscribers: sync.Map{},
 		client:      client,
-		cancelers:   make(map[string]context.CancelFunc),
+		cancelers:   sync.Map{},
 	}
 }
 
 func (c *StatsCollector) Subscribe(ctx context.Context, stats chan ContainerStat) {
-	c.subscribers[ctx] = stats
+	c.subscribers.Store(ctx, stats)
 }
 
 func (sc *StatsCollector) StartCollecting(ctx context.Context) {
@@ -34,7 +35,7 @@ func (sc *StatsCollector) StartCollecting(ctx context.Context) {
 			if c.State == "running" {
 				go func(client Client, id string) {
 					ctx, cancel := context.WithCancel(ctx)
-					sc.cancelers[id] = cancel
+					sc.cancelers.Store(id, cancel)
 					if err := client.ContainerStats(ctx, id, sc.stream); err != nil {
 						if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) {
 							log.Errorf("unexpected error when streaming container stats: %v", err)
@@ -62,9 +63,8 @@ func (sc *StatsCollector) StartCollecting(ctx context.Context) {
 				}(sc.client, event.ActorID)
 
 			case "die":
-				if cancel, ok := sc.cancelers[event.ActorID]; ok {
-					cancel()
-					delete(sc.cancelers, event.ActorID)
+				if cancel, ok := sc.cancelers.LoadAndDelete(event.ActorID); ok {
+					cancel.(context.CancelFunc)()
 				}
 			}
 		}
@@ -75,13 +75,14 @@ func (sc *StatsCollector) StartCollecting(ctx context.Context) {
 		case <-ctx.Done():
 			return
 		case stat := <-sc.stream:
-			for c, sub := range sc.subscribers {
+			sc.subscribers.Range(func(key, value any) bool {
 				select {
-				case sub <- stat:
-				case <-c.Done():
-					delete(sc.subscribers, c)
+				case value.(chan ContainerStat) <- stat:
+				case <-key.(context.Context).Done():
+					sc.subscribers.Delete(key)
 				}
-			}
+				return true
+			})
 		}
 	}
 }
diff --git a/internal/utils/ring_buffer.go b/internal/utils/ring_buffer.go
index 725a93b1..42475899 100644
--- a/internal/utils/ring_buffer.go
+++ b/internal/utils/ring_buffer.go
@@ -1,11 +1,15 @@
 package utils
 
-import "encoding/json"
+import (
+	"encoding/json"
+	"sync"
+)
 
 type RingBuffer[T any] struct {
 	Size  int
 	data  []T
 	start int
+	mutex sync.RWMutex
 }
 
 func NewRingBuffer[T any](size int) *RingBuffer[T] {
@@ -16,6 +20,8 @@ func NewRingBuffer[T any](size int) *RingBuffer[T] {
 }
 
 func (r *RingBuffer[T]) Push(data T) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
 	if len(r.data) == r.Size {
 		r.data[r.start] = data
 		r.start = (r.start + 1) % r.Size
@@ -25,6 +31,8 @@ func (r *RingBuffer[T]) Push(data T) {
 }
 
 func (r *RingBuffer[T]) Data() []T {
+	r.mutex.RLock()
+	defer r.mutex.RUnlock()
 	if len(r.data) == r.Size {
 		return append(r.data[r.start:], r.data[:r.start]...)
 	} else {
@@ -32,14 +40,6 @@ func (r *RingBuffer[T]) Data() []T {
 	}
 }
 
-func (r *RingBuffer[T]) Len() int {
-	return len(r.data)
-}
-
-func (r *RingBuffer[T]) Full() bool {
-	return len(r.data) == r.Size
-}
-
 func (r *RingBuffer[T]) MarshalJSON() ([]byte, error) {
 	return json.Marshal(r.Data())
 }
diff --git a/internal/utils/ring_buffer_test.go b/internal/utils/ring_buffer_test.go
index e064d42a..6141892c 100644
--- a/internal/utils/ring_buffer_test.go
+++ b/internal/utils/ring_buffer_test.go
@@ -8,22 +8,10 @@ import (
 func TestRingBuffer(t *testing.T) {
 	rb := NewRingBuffer[int](3)
 
-	if rb.Len() != 0 {
-		t.Errorf("Expected length to be 0, got %d", rb.Len())
-	}
-
 	rb.Push(1)
 	rb.Push(2)
 	rb.Push(3)
 
-	if rb.Len() != 3 {
-		t.Errorf("Expected length to be 3, got %d", rb.Len())
-	}
-
-	if !rb.Full() {
-		t.Errorf("Expected buffer to be full")
-	}
-
 	data := rb.Data()
 	expectedData := []int{1, 2, 3}
 	if !reflect.DeepEqual(data, expectedData) {