diff --git a/vkubelet/pod_test.go b/vkubelet/pod_test.go index 9e02c788f..aa1773bb8 100644 --- a/vkubelet/pod_test.go +++ b/vkubelet/pod_test.go @@ -2,9 +2,11 @@ package vkubelet import ( "context" + "path" "testing" - "github.com/virtual-kubelet/virtual-kubelet/providers/mock" + "github.com/cpuguy83/strongerrors" + pkgerrors "github.com/pkg/errors" testutil "github.com/virtual-kubelet/virtual-kubelet/test/util" "gotest.tools/assert" is "gotest.tools/assert/cmp" @@ -12,60 +14,81 @@ import ( "k8s.io/client-go/kubernetes/fake" ) -type FakeProvider struct { - *mock.MockProvider - createFn func() - updateFn func() +type mockProvider struct { + pods map[string]*corev1.Pod + + creates int + updates int + deletes int } -func (f *FakeProvider) CreatePod(ctx context.Context, pod *corev1.Pod) error { - f.createFn() - return f.MockProvider.CreatePod(ctx, pod) +func (m *mockProvider) CreatePod(ctx context.Context, pod *corev1.Pod) error { + m.pods[path.Join(pod.GetNamespace(), pod.GetName())] = pod + m.creates++ + return nil } -func (f *FakeProvider) UpdatePod(ctx context.Context, pod *corev1.Pod) error { - f.updateFn() - return f.MockProvider.CreatePod(ctx, pod) +func (m *mockProvider) UpdatePod(ctx context.Context, pod *corev1.Pod) error { + m.pods[path.Join(pod.GetNamespace(), pod.GetName())] = pod + m.updates++ + return nil +} + +func (m *mockProvider) GetPod(ctx context.Context, namespace, name string) (*corev1.Pod, error) { + p := m.pods[path.Join(namespace, name)] + if p == nil { + return nil, strongerrors.NotFound(pkgerrors.New("not found")) + } + return p, nil +} + +func (m *mockProvider) GetPodStatus(ctx context.Context, namespace, name string) (*corev1.PodStatus, error) { + p := m.pods[path.Join(namespace, name)] + if p == nil { + return nil, strongerrors.NotFound(pkgerrors.New("not found")) + } + return &p.Status, nil +} + +func (m *mockProvider) DeletePod(ctx context.Context, p *corev1.Pod) error { + delete(m.pods, path.Join(p.GetNamespace(), p.GetName())) + m.deletes++ + return nil +} + +func (m *mockProvider) GetPods(_ context.Context) ([]*corev1.Pod, error) { + ls := make([]*corev1.Pod, 0, len(m.pods)) + for _, p := range ls { + ls = append(ls, p) + } + return ls, nil } type TestServer struct { *Server - mock *FakeProvider + mock *mockProvider client *fake.Clientset } -func newMockProvider(t *testing.T) (*mock.MockProvider, error) { - return mock.NewMockProviderMockConfig( - mock.MockConfig{}, - "vk123", - "linux", - "127.0.0.1", - 443, - ) +func newMockProvider() *mockProvider { + return &mockProvider{pods: make(map[string]*corev1.Pod)} } -func newTestServer(t *testing.T) *TestServer { - - mockProvider, err := newMockProvider(t) - assert.Check(t, is.Nil(err)) - +func newTestServer() *TestServer { fk8s := fake.NewSimpleClientset() - fakeProvider := &FakeProvider{ - MockProvider: mockProvider, - } - rm := testutil.FakeResourceManager() + p := newMockProvider() tsvr := &TestServer{ Server: &Server{ namespace: "default", nodeName: "vk123", - provider: fakeProvider, + provider: p, resourceManager: rm, k8sClient: fk8s, }, - mock: fakeProvider, + mock: p, client: fk8s, } return tsvr @@ -146,7 +169,7 @@ func TestPodHashingDifferent(t *testing.T) { } func TestPodCreateNewPod(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer() pod := &corev1.Pod{} pod.ObjectMeta.Namespace = "default" @@ -166,25 +189,16 @@ func TestPodCreateNewPod(t *testing.T) { }, } - created := false - updated := false - // The pod doesn't exist, we should invoke the CreatePod() method of the provider - svr.mock.createFn = func() { - created = true - } - svr.mock.updateFn = func() { - updated = true - } er := testutil.FakeEventRecorder(5) err := svr.createOrUpdatePod(context.Background(), pod, er) assert.Check(t, is.Nil(err)) // createOrUpdate called CreatePod but did not call UpdatePod because the pod did not exist - assert.Check(t, created) - assert.Check(t, !updated) + assert.Check(t, is.Equal(svr.mock.creates, 1)) + assert.Check(t, is.Equal(svr.mock.updates, 0)) } func TestPodUpdateExisting(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer() pod := &corev1.Pod{} pod.ObjectMeta.Namespace = "default" @@ -204,17 +218,10 @@ func TestPodUpdateExisting(t *testing.T) { }, } - err := svr.mock.MockProvider.CreatePod(context.Background(), pod) + err := svr.provider.CreatePod(context.Background(), pod) assert.Check(t, is.Nil(err)) - created := false - updated := false - // The pod doesn't exist, we should invoke the CreatePod() method of the provider - svr.mock.createFn = func() { - created = true - } - svr.mock.updateFn = func() { - updated = true - } + assert.Check(t, is.Equal(svr.mock.creates, 1)) + assert.Check(t, is.Equal(svr.mock.updates, 0)) pod2 := &corev1.Pod{} pod2.ObjectMeta.Namespace = "default" @@ -239,12 +246,12 @@ func TestPodUpdateExisting(t *testing.T) { assert.Check(t, is.Nil(err)) // createOrUpdate didn't call CreatePod but did call UpdatePod because the spec changed - assert.Check(t, !created) - assert.Check(t, updated) + assert.Check(t, is.Equal(svr.mock.creates, 1)) + assert.Check(t, is.Equal(svr.mock.updates, 1)) } func TestPodNoSpecChange(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer() pod := &corev1.Pod{} pod.ObjectMeta.Namespace = "default" @@ -264,23 +271,16 @@ func TestPodNoSpecChange(t *testing.T) { }, } - err := svr.mock.MockProvider.CreatePod(context.Background(), pod) + err := svr.mock.CreatePod(context.Background(), pod) assert.Check(t, is.Nil(err)) - created := false - updated := false - // The pod doesn't exist, we should invoke the CreatePod() method of the provider - svr.mock.createFn = func() { - created = true - } - svr.mock.updateFn = func() { - updated = true - } + assert.Check(t, is.Equal(svr.mock.creates, 1)) + assert.Check(t, is.Equal(svr.mock.updates, 0)) er := testutil.FakeEventRecorder(5) err = svr.createOrUpdatePod(context.Background(), pod, er) assert.Check(t, is.Nil(err)) // createOrUpdate didn't call CreatePod or UpdatePod, spec didn't change - assert.Check(t, !created) - assert.Check(t, !updated) + assert.Check(t, is.Equal(svr.mock.creates, 1)) + assert.Check(t, is.Equal(svr.mock.updates, 0)) }