GPU support in ACI provider (#563)

* GPU support in ACI provider
This commit is contained in:
Vipin Duleb
2019-04-02 18:11:35 -07:00
committed by Brian Goff
parent 1dadd46e20
commit bab9c59ac8
11 changed files with 557 additions and 63 deletions

View File

@@ -35,6 +35,7 @@ const (
fakeClientSecret = "VGhpcyBpcyBhIHNlY3JldAo="
fakeTenantID = "8cb81aca-83fe-4c6f-b667-4ec09c45a8bf"
fakeNodeName = "vk"
fakeRegion = "eastus"
)
// Test make registry credential
@@ -199,6 +200,166 @@ func TestCreatePodWithResourceRequestOnly(t *testing.T) {
}
}
// TestCreatePodWithGPU tests creating a pod that asks for GPUs without a GPU
// SKU annotation. The mocked resource provider manifest lists gpuSKU first for
// fakeRegion, and the test expects the created container group to use that SKU
// for both GPU requests and GPU limits.
func TestCreatePodWithGPU(t *testing.T) {
	aadServerMocker := NewAADMock()
	aciServerMocker := NewACIMock()

	podName := "pod-" + uuid.New().String()
	podNamespace := "ns-" + uuid.New().String()
	gpuSKU := aci.GPUSKU("sku-" + uuid.New().String())

	// The manifest callback must be installed before the provider is created.
	// NOTE(review): presumably the provider reads the manifest during
	// construction - confirm against NewACIProvider.
	aciServerMocker.OnGetRPManifest = func() (int, interface{}) {
		manifest := &aci.ResourceProviderManifest{
			Metadata: &aci.ResourceProviderMetadata{
				GPURegionalSKUs: []*aci.GPURegionalSKU{
					&aci.GPURegionalSKU{
						Location: fakeRegion,
						SKUs:     []aci.GPUSKU{gpuSKU, aci.K80, aci.P100},
					},
				},
			},
		}
		return http.StatusOK, manifest
	}

	provider, err := createTestProvider(aadServerMocker, aciServerMocker)
	if err != nil {
		t.Fatalf("failed to create the test provider. %s", err.Error())
	}

	aciServerMocker.OnCreate = func(subscription, resourceGroup, containerGroup string, cg *aci.ContainerGroup) (int, interface{}) {
		assert.Check(t, is.Equal(fakeSubscription, subscription), "Subscription doesn't match")
		assert.Check(t, is.Equal(fakeResourceGroup, resourceGroup), "Resource group doesn't match")
		assert.Check(t, is.Equal(podNamespace+"-"+podName, containerGroup), "Container group name is not expected")

		// assert.Check records a failure but keeps going, so every
		// dereference below is guarded by the check that proves it is safe.
		if !assert.Check(t, cg != nil, "Container group is nil") {
			return http.StatusOK, cg
		}
		containers := cg.ContainerGroupProperties.Containers
		assert.Check(t, containers != nil, "Containers should not be nil")
		if !assert.Check(t, is.Equal(1, len(containers)), "1 Container is expected") {
			return http.StatusOK, cg
		}
		container := containers[0]
		assert.Check(t, is.Equal("nginx", container.Name), "Container nginx is expected")

		requests := container.Resources.Requests
		if assert.Check(t, requests != nil, "Container resource requests should not be nil") {
			assert.Check(t, is.Equal(1.98, requests.CPU), "Request CPU is not expected")
			assert.Check(t, is.Equal(3.4, requests.MemoryInGB), "Request Memory is not expected")
			if assert.Check(t, requests.GPU != nil, "Requests GPU should not be nil") {
				assert.Check(t, is.Equal(int32(10), requests.GPU.Count), "Requests GPU Count is not expected")
				assert.Check(t, is.Equal(gpuSKU, requests.GPU.SKU), "Requests GPU SKU is not expected")
			}
		}

		limits := container.Resources.Limits
		if assert.Check(t, limits != nil, "Container resource limits should not be nil") &&
			assert.Check(t, limits.GPU != nil, "Limits GPU should not be nil") {
			assert.Check(t, is.Equal(int32(10), limits.GPU.Count), "Limits GPU Count is not expected")
			assert.Check(t, is.Equal(gpuSKU, limits.GPU.SKU), "Limits GPU SKU is not expected")
		}

		return http.StatusOK, cg
	}

	// The GPU count is given only as a limit; CPU and memory only as requests.
	pod := &v1.Pod{
		ObjectMeta: metav1.ObjectMeta{
			Name:      podName,
			Namespace: podNamespace,
		},
		Spec: v1.PodSpec{
			Containers: []v1.Container{
				v1.Container{
					Name: "nginx",
					Resources: v1.ResourceRequirements{
						Requests: v1.ResourceList{
							"cpu":    resource.MustParse("1.981"),
							"memory": resource.MustParse("3.49G"),
						},
						Limits: v1.ResourceList{
							gpuResourceName: resource.MustParse("10"),
						},
					},
				},
			},
		},
	}

	if err := provider.CreatePod(context.Background(), pod); err != nil {
		t.Fatal("Failed to create pod", err)
	}
}
// TestCreatePodWithGPUSKU tests creating a pod whose GPU SKU is chosen through
// the gpuTypeAnnotation pod annotation. The mocked manifest lists gpuSKU last
// for fakeRegion, and the test expects the created container group to carry
// the annotated SKU for both GPU requests and GPU limits.
func TestCreatePodWithGPUSKU(t *testing.T) {
	aadServerMocker := NewAADMock()
	aciServerMocker := NewACIMock()

	podName := "pod-" + uuid.New().String()
	podNamespace := "ns-" + uuid.New().String()
	gpuSKU := aci.GPUSKU("sku-" + uuid.New().String())

	// The manifest callback must be installed before the provider is created.
	// NOTE(review): presumably the provider reads the manifest during
	// construction - confirm against NewACIProvider.
	aciServerMocker.OnGetRPManifest = func() (int, interface{}) {
		manifest := &aci.ResourceProviderManifest{
			Metadata: &aci.ResourceProviderMetadata{
				GPURegionalSKUs: []*aci.GPURegionalSKU{
					&aci.GPURegionalSKU{
						Location: fakeRegion,
						SKUs:     []aci.GPUSKU{aci.K80, aci.P100, gpuSKU},
					},
				},
			},
		}
		return http.StatusOK, manifest
	}

	provider, err := createTestProvider(aadServerMocker, aciServerMocker)
	if err != nil {
		t.Fatalf("failed to create the test provider. %s", err.Error())
	}

	aciServerMocker.OnCreate = func(subscription, resourceGroup, containerGroup string, cg *aci.ContainerGroup) (int, interface{}) {
		assert.Check(t, is.Equal(fakeSubscription, subscription), "Subscription doesn't match")
		assert.Check(t, is.Equal(fakeResourceGroup, resourceGroup), "Resource group doesn't match")
		assert.Check(t, is.Equal(podNamespace+"-"+podName, containerGroup), "Container group name is not expected")

		// assert.Check records a failure but keeps going, so every
		// dereference below is guarded by the check that proves it is safe.
		if !assert.Check(t, cg != nil, "Container group is nil") {
			return http.StatusOK, cg
		}
		containers := cg.ContainerGroupProperties.Containers
		assert.Check(t, containers != nil, "Containers should not be nil")
		if !assert.Check(t, is.Equal(1, len(containers)), "1 Container is expected") {
			return http.StatusOK, cg
		}
		container := containers[0]
		assert.Check(t, is.Equal("nginx", container.Name), "Container nginx is expected")

		requests := container.Resources.Requests
		if assert.Check(t, requests != nil, "Container resource requests should not be nil") {
			assert.Check(t, is.Equal(1.98, requests.CPU), "Request CPU is not expected")
			assert.Check(t, is.Equal(3.4, requests.MemoryInGB), "Request Memory is not expected")
			if assert.Check(t, requests.GPU != nil, "Requests GPU should not be nil") {
				assert.Check(t, is.Equal(int32(1), requests.GPU.Count), "Requests GPU Count is not expected")
				assert.Check(t, is.Equal(gpuSKU, requests.GPU.SKU), "Requests GPU SKU is not expected")
			}
		}

		limits := container.Resources.Limits
		if assert.Check(t, limits != nil, "Container resource limits should not be nil") &&
			assert.Check(t, limits.GPU != nil, "Limits GPU should not be nil") {
			assert.Check(t, is.Equal(int32(1), limits.GPU.Count), "Limits GPU Count is not expected")
			assert.Check(t, is.Equal(gpuSKU, limits.GPU.SKU), "Limits GPU SKU is not expected")
		}

		return http.StatusOK, cg
	}

	// The annotation names the SKU; the GPU count is given only as a limit.
	pod := &v1.Pod{
		ObjectMeta: metav1.ObjectMeta{
			Name:      podName,
			Namespace: podNamespace,
			Annotations: map[string]string{
				gpuTypeAnnotation: string(gpuSKU),
			},
		},
		Spec: v1.PodSpec{
			Containers: []v1.Container{
				v1.Container{
					Name: "nginx",
					Resources: v1.ResourceRequirements{
						Requests: v1.ResourceList{
							"cpu":    resource.MustParse("1.981"),
							"memory": resource.MustParse("3.49G"),
						},
						Limits: v1.ResourceList{
							gpuResourceName: resource.MustParse("1"),
						},
					},
				},
			},
		},
	}

	if err := provider.CreatePod(context.Background(), pod); err != nil {
		t.Fatal("Failed to create pod", err)
	}
}
// Tests create pod with both resource request and limit.
func TestCreatePodWithResourceRequestAndLimit(t *testing.T) {
_, aciServerMocker, provider, err := prepareMocks()
@@ -314,7 +475,7 @@ func TestGetPodsWithoutResourceRequestsLimits(t *testing.T) {
},
},
Resources: aci.ResourceRequirements{
Requests: &aci.ResourceRequests{
Requests: &aci.ComputeResources{
CPU: 0.99,
MemoryInGB: 1.5,
},
@@ -385,7 +546,7 @@ func TestGetPodWithoutResourceRequestsLimits(t *testing.T) {
},
},
Resources: aci.ResourceRequirements{
Requests: &aci.ResourceRequests{
Requests: &aci.ComputeResources{
CPU: 0.99,
MemoryInGB: 1.5,
},
@@ -412,6 +573,94 @@ func TestGetPodWithoutResourceRequestsLimits(t *testing.T) {
pod.Spec.Containers[0].Resources.Requests.Memory().Value()), "Containers[0].Resources.Requests.Memory doesn't match")
}
// TestGetPodWithGPU tests that GetPod surfaces the GPU count of a container
// group in both the resource requests and the resource limits of the returned
// pod's first container, alongside the CPU and memory requests.
func TestGetPodWithGPU(t *testing.T) {
	_, aciServerMocker, provider, err := prepareMocks()
	if err != nil {
		t.Fatal("Unable to prepare the mocks", err)
	}

	podName := "pod-" + uuid.New().String()
	podNamespace := "ns-" + uuid.New().String()

	aciServerMocker.OnGetContainerGroup = func(subscription, resourceGroup, containerGroup string) (int, interface{}) {
		// Use the non-fatal assert.Check form inside the callback, matching
		// the other mock callbacks in this file.
		// NOTE(review): the callback presumably runs on the mock server's
		// goroutine, where FailNow-style assertions must not be called - confirm.
		assert.Check(t, is.Equal(fakeSubscription, subscription), "Subscription doesn't match")
		assert.Check(t, is.Equal(fakeResourceGroup, resourceGroup), "Resource group doesn't match")
		assert.Check(t, is.Equal(podNamespace+"-"+podName, containerGroup), "Container group name is not expected")

		return http.StatusOK, aci.ContainerGroup{
			Tags: map[string]string{
				"NodeName": fakeNodeName,
			},
			ContainerGroupProperties: aci.ContainerGroupProperties{
				ProvisioningState: "Creating",
				Containers: []aci.Container{
					aci.Container{
						Name: "nginx",
						ContainerProperties: aci.ContainerProperties{
							Image:   "nginx",
							Command: []string{"nginx", "-g", "daemon off;"},
							Ports: []aci.ContainerPort{
								{
									Protocol: aci.ContainerNetworkProtocolTCP,
									Port:     80,
								},
							},
							Resources: aci.ResourceRequirements{
								Requests: &aci.ComputeResources{
									CPU:        0.99,
									MemoryInGB: 1.5,
									GPU: &aci.GPUResource{
										Count: 5,
										SKU:   aci.P100,
									},
								},
								Limits: &aci.ComputeResources{
									GPU: &aci.GPUResource{
										Count: 5,
										SKU:   aci.P100,
									},
								},
							},
						},
					},
				},
			},
		}
	}

	pod, err := provider.GetPod(context.Background(), podNamespace, podName)
	if err != nil {
		t.Fatal("Failed to get pod", err)
	}

	// assert.Check does not stop the test, so bail out explicitly before
	// dereferencing a nil pod or indexing into an empty container list.
	if !assert.Check(t, pod != nil, "Response pod should not be nil") {
		return
	}
	if !assert.Check(t, len(pod.Spec.Containers) > 0, "Containers should not be empty") {
		return
	}
	resources := pod.Spec.Containers[0].Resources

	assert.Check(t, resources.Requests != nil, "Containers[0].Resources.Requests should not be nil")
	assert.Check(
		t,
		is.Equal(ptrQuantity(resource.MustParse("0.99")).Value(), resources.Requests.Cpu().Value()),
		"Containers[0].Resources.Requests.CPU doesn't match")
	assert.Check(
		t,
		is.Equal(ptrQuantity(resource.MustParse("1.5G")).Value(), resources.Requests.Memory().Value()),
		"Containers[0].Resources.Requests.Memory doesn't match")

	gpuQuantity, ok := resources.Requests[gpuResourceName]
	assert.Check(t, is.Equal(ok, true), "Containers[0].Resources.Requests.GPU should not be nil")
	assert.Check(
		t,
		is.Equal(ptrQuantity(resource.MustParse("5")).Value(), ptrQuantity(gpuQuantity).Value()),
		"Containers[0].Resources.Requests.GPU.Count doesn't match")

	assert.Check(t, resources.Limits != nil, "Containers[0].Resources.Limits should not be nil")
	gpuQuantity, ok = resources.Limits[gpuResourceName]
	assert.Check(t, is.Equal(ok, true), "Containers[0].Resources.Limits.GPU should not be nil")
	assert.Check(
		t,
		is.Equal(ptrQuantity(resource.MustParse("5")).Value(), ptrQuantity(gpuQuantity).Value()),
		"Containers[0].Resources.Limits.GPU.Count doesn't match")
}
func TestGetPodWithContainerID(t *testing.T) {
_, aciServerMocker, provider, err := prepareMocks()
@@ -451,7 +700,7 @@ func TestGetPodWithContainerID(t *testing.T) {
},
},
Resources: aci.ResourceRequirements{
Requests: &aci.ResourceRequests{
Requests: &aci.ComputeResources{
CPU: 0.99,
MemoryInGB: 1.5,
},
@@ -531,6 +780,30 @@ func prepareMocks() (*AADMock, *ACIMock, *ACIProvider, error) {
aadServerMocker := NewAADMock()
aciServerMocker := NewACIMock()
aciServerMocker.OnGetRPManifest = func() (int, interface{}) {
manifest := &aci.ResourceProviderManifest{
Metadata: &aci.ResourceProviderMetadata{
GPURegionalSKUs: []*aci.GPURegionalSKU{
&aci.GPURegionalSKU{
Location: fakeRegion,
SKUs: []aci.GPUSKU{aci.K80, aci.P100, aci.V100},
},
},
},
}
return http.StatusOK, manifest
}
provider, err := createTestProvider(aadServerMocker, aciServerMocker)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to create the test provider %s", err.Error())
}
return aadServerMocker, aciServerMocker, provider, nil
}
func createTestProvider(aadServerMocker *AADMock, aciServerMocker*ACIMock) (*ACIProvider, error) {
auth := azure.NewAuthentication(
azure.PublicCloud.Name,
fakeClientID,
@@ -543,7 +816,7 @@ func prepareMocks() (*AADMock, *ACIMock, *ACIProvider, error) {
file, err := ioutil.TempFile("", "auth.json")
if err != nil {
return nil, nil, nil, err
return nil, err
}
defer os.Remove(file.Name())
@@ -552,23 +825,25 @@ func prepareMocks() (*AADMock, *ACIMock, *ACIProvider, error) {
json.NewEncoder(b).Encode(auth)
if _, err := file.Write(b.Bytes()); err != nil {
return nil, nil, nil, err
return nil, err
}
os.Setenv("AZURE_AUTH_LOCATION", file.Name())
os.Setenv("ACI_RESOURCE_GROUP", fakeResourceGroup)
os.Setenv("ACI_REGION", fakeRegion)
rm, err := manager.NewResourceManager(nil, nil, nil)
if err != nil {
return nil, nil, nil, err
return nil, err
}
provider, err := NewACIProvider("example.toml", rm, fakeNodeName, "Linux", "0.0.0.0", 10250)
if err != nil {
return nil, nil, nil, err
return nil, err
}
return aadServerMocker, aciServerMocker, provider, nil
return provider, nil
}
func ptrQuantity(q resource.Quantity) *resource.Quantity {