diff --git a/internal/lock/monitor.go b/internal/lock/monitor.go new file mode 100644 index 000000000..331ddc610 --- /dev/null +++ b/internal/lock/monitor.go @@ -0,0 +1,98 @@ +package lock + +import ( + "sync" +) + +// NewMonitorVariable instantiates an empty monitor variable +func NewMonitorVariable() MonitorVariable { + mv := &monitorVariable{ + versionInvalidationChannel: make(chan struct{}), + } + return mv +} + +// MonitorVariable is a specific monitor variable which allows for channel-subscription to changes to +// the internal value of the MonitorVariable. +type MonitorVariable interface { + Set(value interface{}) + Subscribe() Subscription +} + +// Subscription is not concurrency safe. It must not be shared between multiple goroutines. +type Subscription interface { + // On instantiation, if the value has been set, this will return a closed channel. Otherwise, it will follow the + // standard semantic, which is when the Monitor Variable is updated, this channel will close. The channel is updated + // based on reading Value(). Once a value is read, the channel returned will only be closed if a the Monitor Variable + // is set to a new value. + NewValueReady() <-chan struct{} + // Value returns a value object in a non-blocking fashion. This also means it may return an uninitialized value. + // If the monitor variable has not yet been set, the "Version" of the value will be 0. + Value() Value +} + +type Value struct { + Value interface{} + Version int64 +} + +type monitorVariable struct { + lock sync.Mutex + currentValue interface{} + // 0 indicates uninitialized + currentVersion int64 + versionInvalidationChannel chan struct{} +} + +func (m *monitorVariable) Set(newValue interface{}) { + m.lock.Lock() + defer m.lock.Unlock() + m.currentValue = newValue + m.currentVersion++ + close(m.versionInvalidationChannel) + m.versionInvalidationChannel = make(chan struct{}) +} + +func (m *monitorVariable) Subscribe() Subscription { + m.lock.Lock() + defer m.lock.Unlock() + sub := &subscription{ + mv: m, + } + if m.currentVersion > 0 { + // A value has been set. Set the first versionInvalidationChannel to a closed one. + closedCh := make(chan struct{}) + close(closedCh) + sub.lastVersionReadInvalidationChannel = closedCh + } else { + // The value hasn't yet been initialized. + sub.lastVersionReadInvalidationChannel = m.versionInvalidationChannel + } + + return sub +} + +type subscription struct { + mv *monitorVariable + lastVersionRead int64 + lastVersionReadInvalidationChannel chan struct{} +} + +func (s *subscription) NewValueReady() <-chan struct{} { + /* This lock could be finer grained (on just the subscription) */ + s.mv.lock.Lock() + defer s.mv.lock.Unlock() + return s.lastVersionReadInvalidationChannel +} + +func (s *subscription) Value() Value { + s.mv.lock.Lock() + defer s.mv.lock.Unlock() + val := Value{ + Value: s.mv.currentValue, + Version: s.mv.currentVersion, + } + s.lastVersionRead = s.mv.currentVersion + s.lastVersionReadInvalidationChannel = s.mv.versionInvalidationChannel + return val +} diff --git a/internal/lock/monitor_test.go b/internal/lock/monitor_test.go new file mode 100644 index 000000000..e95e41c3a --- /dev/null +++ b/internal/lock/monitor_test.go @@ -0,0 +1,113 @@ +package lock + +import ( + "sync" + "testing" + "time" + + "golang.org/x/sync/errgroup" + "k8s.io/apimachinery/pkg/util/sets" + + "gotest.tools/assert" + is "gotest.tools/assert/cmp" +) + +func TestMonitorUninitialized(t *testing.T) { + t.Parallel() + mv := NewMonitorVariable() + subscription := mv.Subscribe() + select { + case <-subscription.NewValueReady(): + t.Fatalf("Received value update message: %v", subscription.Value()) + case <-time.After(time.Second): + } +} + +func TestGetUninitialized(t *testing.T) { + mv := NewMonitorVariable() + subscription := mv.Subscribe() + val := subscription.Value() + assert.Assert(t, is.Equal(val.Version, int64(0))) +} + +func TestMonitorSetInitialVersionAfterListen(t *testing.T) { + mv := NewMonitorVariable() + subscription := mv.Subscribe() + go mv.Set("test") + <-subscription.NewValueReady() + assert.Assert(t, is.Equal(subscription.Value().Value, "test")) +} + +func TestMonitorSetInitialVersionBeforeListen(t *testing.T) { + mv := NewMonitorVariable() + subscription := mv.Subscribe() + mv.Set("test") + <-subscription.NewValueReady() + assert.Assert(t, is.Equal(subscription.Value().Value, "test")) +} + +func TestMonitorMultipleVersionsBlock(t *testing.T) { + t.Parallel() + mv := NewMonitorVariable() + subscription := mv.Subscribe() + mv.Set("test") + <-subscription.NewValueReady() + /* This should mark the "current" version as seen */ + val := subscription.Value() + assert.Assert(t, is.Equal(val.Version, int64(1))) + select { + case <-subscription.NewValueReady(): + t.Fatalf("Received value update message: %v", subscription.Value()) + case <-time.After(time.Second): + } +} +func TestMonitorMultipleVersions(t *testing.T) { + t.Parallel() + lock := sync.Mutex{} + lock.Lock() + mv := NewMonitorVariable() + triggers := []int{} + ch := make(chan struct{}, 10) + go func() { + defer lock.Unlock() + subscription := mv.Subscribe() + for { + select { + case <-subscription.NewValueReady(): + val := subscription.Value() + triggers = append(triggers, val.Value.(int)) + ch <- struct{}{} + if val.Value == 9 { + return + } + } + + } + }() + + for i := 0; i < 10; i++ { + mv.Set(i) + // Wait for the trigger to occur + <-ch + } + + // Wait for the goroutine to finish + lock.Lock() + t.Logf("Saw %v triggers", triggers) + assert.Assert(t, is.Len(triggers, 10)) + // Make sure we saw all 10 unique values + assert.Assert(t, is.Equal(sets.NewInt(triggers...).Len(), 10)) +} +func TestMonitorMultipleSubscribers(t *testing.T) { + group := &errgroup.Group{} + mv := NewMonitorVariable() + for i := 0; i < 10; i++ { + sub := mv.Subscribe() + group.Go(func() error { + <-sub.NewValueReady() + return nil + }) + } + mv.Set(1) + _ = group.Wait() +} diff --git a/node/node_ping_controller.go b/node/node_ping_controller.go index d1afb3a43..0d603fd66 100644 --- a/node/node_ping_controller.go +++ b/node/node_ping_controller.go @@ -2,9 +2,9 @@ package node import ( "context" - "sync" "time" + "github.com/virtual-kubelet/virtual-kubelet/internal/lock" "github.com/virtual-kubelet/virtual-kubelet/log" "github.com/virtual-kubelet/virtual-kubelet/trace" "golang.org/x/sync/singleflight" @@ -12,14 +12,10 @@ import ( ) type nodePingController struct { - nodeProvider NodeProvider - pingInterval time.Duration - firstPingCompleted chan struct{} - pingTimeout *time.Duration - - // "Results" - sync.Mutex - result *pingResult + nodeProvider NodeProvider + pingInterval time.Duration + pingTimeout *time.Duration + cond lock.MonitorVariable } type pingResult struct { @@ -37,10 +33,10 @@ func newNodePingController(node NodeProvider, pingInterval time.Duration, timeou } return &nodePingController{ - nodeProvider: node, - pingInterval: pingInterval, - firstPingCompleted: make(chan struct{}), - pingTimeout: timeout, + nodeProvider: node, + pingInterval: pingInterval, + pingTimeout: timeout, + cond: lock.NewMonitorVariable(), } } @@ -87,28 +83,26 @@ func (npc *nodePingController) run(ctx context.Context) { pingResult.pingTime = result.Val.(time.Time) } - npc.Lock() - defer npc.Unlock() - npc.result = &pingResult + npc.cond.Set(&pingResult) span.SetStatus(pingResult.error) } // Run the first check manually checkFunc(ctx) - close(npc.firstPingCompleted) - wait.UntilWithContext(ctx, checkFunc, npc.pingInterval) } +// getResult returns the current ping result in a non-blocking fashion except for the first ping. It waits for the +// first ping to be successful before returning. If the context is cancelled while waiting for that value, it will +// return immediately. func (npc *nodePingController) getResult(ctx context.Context) (*pingResult, error) { + sub := npc.cond.Subscribe() select { case <-ctx.Done(): return nil, ctx.Err() - case <-npc.firstPingCompleted: + case <-sub.NewValueReady(): } - npc.Lock() - defer npc.Unlock() - return npc.result, nil + return sub.Value().Value.(*pingResult), nil }