@@ -53,6 +53,12 @@ const (
|
||||
maxDNSSearchListChars = 256
|
||||
)
|
||||
|
||||
const (
|
||||
gpuResourceName v1.ResourceName = "nvidia.com/gpu"
|
||||
gpuTypeAnnotation = "virtual-kubelet.io/gpu-type"
|
||||
)
|
||||
|
||||
|
||||
// ACIProvider implements the virtual-kubelet provider interface and communicates with Azure's ACI APIs.
|
||||
type ACIProvider struct {
|
||||
aciClient *aci.Client
|
||||
@@ -64,6 +70,8 @@ type ACIProvider struct {
|
||||
cpu string
|
||||
memory string
|
||||
pods string
|
||||
gpu string
|
||||
gpuSKUs []aci.GPUSKU
|
||||
internalIP string
|
||||
daemonEndpointPort int32
|
||||
diagnostics *aci.ContainerGroupDiagnostics
|
||||
@@ -260,21 +268,8 @@ func NewACIProvider(config string, rm *manager.ResourceManager, nodeName, operat
|
||||
return nil, errors.New(unsupportedRegionMessage)
|
||||
}
|
||||
|
||||
// Set sane defaults for Capacity in case config is not supplied
|
||||
p.cpu = "800"
|
||||
p.memory = "4Ti"
|
||||
p.pods = "800"
|
||||
|
||||
if cpuQuota := os.Getenv("ACI_QUOTA_CPU"); cpuQuota != "" {
|
||||
p.cpu = cpuQuota
|
||||
}
|
||||
|
||||
if memoryQuota := os.Getenv("ACI_QUOTA_MEMORY"); memoryQuota != "" {
|
||||
p.memory = memoryQuota
|
||||
}
|
||||
|
||||
if podsQuota := os.Getenv("ACI_QUOTA_POD"); podsQuota != "" {
|
||||
p.pods = podsQuota
|
||||
if err := p.setupCapacity(context.TODO()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.operatingSystem = operatingSystem
|
||||
@@ -324,6 +319,54 @@ func NewACIProvider(config string, rm *manager.ResourceManager, nodeName, operat
|
||||
return &p, err
|
||||
}
|
||||
|
||||
func (p *ACIProvider) setupCapacity(ctx context.Context) error {
|
||||
ctx, span := trace.StartSpan(ctx, "setupCapacity")
|
||||
defer span.End()
|
||||
logger := log.G(ctx).WithField("method", "setupCapacity")
|
||||
|
||||
// Set sane defaults for Capacity in case config is not supplied
|
||||
p.cpu = "800"
|
||||
p.memory = "4Ti"
|
||||
p.pods = "800"
|
||||
|
||||
if cpuQuota := os.Getenv("ACI_QUOTA_CPU"); cpuQuota != "" {
|
||||
p.cpu = cpuQuota
|
||||
}
|
||||
|
||||
if memoryQuota := os.Getenv("ACI_QUOTA_MEMORY"); memoryQuota != "" {
|
||||
p.memory = memoryQuota
|
||||
}
|
||||
|
||||
if podsQuota := os.Getenv("ACI_QUOTA_POD"); podsQuota != "" {
|
||||
p.pods = podsQuota
|
||||
}
|
||||
|
||||
metadata, err := p.aciClient.GetResourceProviderMetadata(ctx)
|
||||
|
||||
if err != nil {
|
||||
msg := "Unable to fetch the ACI metadata"
|
||||
logger.WithError(err).Error(msg)
|
||||
return err
|
||||
}
|
||||
|
||||
if metadata == nil || metadata.GPURegionalSKUs == nil {
|
||||
logger.Warn("ACI GPU capacity is not enabled. GPU capacity will be disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, regionalSKU := range metadata.GPURegionalSKUs {
|
||||
if strings.EqualFold(regionalSKU.Location, p.region) && len(regionalSKU.SKUs) != 0 {
|
||||
p.gpu = "100"
|
||||
if gpu := os.Getenv("ACI_QUOTA_GPU"); gpu != "" {
|
||||
p.gpu = gpu
|
||||
}
|
||||
p.gpuSKUs = regionalSKU.SKUs
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ACIProvider) setupNetworkProfile(auth *client.Authentication) error {
|
||||
c, err := network.NewClient(auth, p.extraUserAgent)
|
||||
if err != nil {
|
||||
@@ -706,7 +749,7 @@ func (p *ACIProvider) GetPod(ctx context.Context, namespace, name string) (*v1.P
|
||||
defer span.End()
|
||||
ctx = addAzureAttributes(ctx, span, p)
|
||||
|
||||
cg, err, status := p.aciClient.GetContainerGroup(ctx, p.resourceGroup, fmt.Sprintf("%s-%s", namespace, name))
|
||||
cg, status, err := p.aciClient.GetContainerGroup(ctx, p.resourceGroup, fmt.Sprintf("%s-%s", namespace, name))
|
||||
if err != nil {
|
||||
if status != nil && *status == http.StatusNotFound {
|
||||
return nil, nil
|
||||
@@ -728,7 +771,7 @@ func (p *ACIProvider) GetContainerLogs(ctx context.Context, namespace, podName,
|
||||
ctx = addAzureAttributes(ctx, span, p)
|
||||
|
||||
logContent := ""
|
||||
cg, err, _ := p.aciClient.GetContainerGroup(ctx, p.resourceGroup, fmt.Sprintf("%s-%s", namespace, podName))
|
||||
cg, _, err := p.aciClient.GetContainerGroup(ctx, p.resourceGroup, fmt.Sprintf("%s-%s", namespace, podName))
|
||||
if err != nil {
|
||||
return logContent, err
|
||||
}
|
||||
@@ -768,7 +811,7 @@ func (p *ACIProvider) ExecInContainer(name string, uid types.UID, container stri
|
||||
defer errstream.Close()
|
||||
}
|
||||
|
||||
cg, err, _ := p.aciClient.GetContainerGroup(context.TODO(), p.resourceGroup, name)
|
||||
cg, _, err := p.aciClient.GetContainerGroup(context.TODO(), p.resourceGroup, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -789,10 +832,10 @@ func (p *ACIProvider) ExecInContainer(name string, uid types.UID, container stri
|
||||
return err
|
||||
}
|
||||
|
||||
wsUri := xcrsp.WebSocketUri
|
||||
wsURI := xcrsp.WebSocketURI
|
||||
password := xcrsp.Password
|
||||
|
||||
c, _, _ := websocket.DefaultDialer.Dial(wsUri, nil)
|
||||
c, _, _ := websocket.DefaultDialer.Dial(wsURI, nil)
|
||||
c.WriteMessage(websocket.TextMessage, []byte(password)) // Websocket password needs to be sent before WS terminal is active
|
||||
|
||||
// Cleanup on exit
|
||||
@@ -889,11 +932,17 @@ func (p *ACIProvider) GetPods(ctx context.Context) ([]*v1.Pod, error) {
|
||||
|
||||
// Capacity returns a resource list containing the capacity limits set for ACI.
|
||||
func (p *ACIProvider) Capacity(ctx context.Context) v1.ResourceList {
|
||||
return v1.ResourceList{
|
||||
"cpu": resource.MustParse(p.cpu),
|
||||
"memory": resource.MustParse(p.memory),
|
||||
"pods": resource.MustParse(p.pods),
|
||||
resourceList := v1.ResourceList{
|
||||
v1.ResourceCPU: resource.MustParse(p.cpu),
|
||||
v1.ResourceMemory: resource.MustParse(p.memory),
|
||||
v1.ResourcePods: resource.MustParse(p.pods),
|
||||
}
|
||||
|
||||
if p.gpu != "" {
|
||||
resourceList[gpuResourceName] = resource.MustParse(p.gpu)
|
||||
}
|
||||
|
||||
return resourceList
|
||||
}
|
||||
|
||||
// NodeConditions returns a list of conditions (Ready, OutOfDisk, etc), for updates to the node status
|
||||
@@ -1146,7 +1195,7 @@ func (p *ACIProvider) getContainers(pod *v1.Pod) ([]aci.Container, error) {
|
||||
}
|
||||
|
||||
c.Resources = aci.ResourceRequirements{
|
||||
Requests: &aci.ResourceRequests{
|
||||
Requests: &aci.ComputeResources{
|
||||
CPU: cpuRequest,
|
||||
MemoryInGB: memoryRequest,
|
||||
},
|
||||
@@ -1163,10 +1212,29 @@ func (p *ACIProvider) getContainers(pod *v1.Pod) ([]aci.Container, error) {
|
||||
memoryLimit = float64(container.Resources.Limits.Memory().Value()) / 1000000000.00
|
||||
}
|
||||
|
||||
c.Resources.Limits = &aci.ResourceLimits{
|
||||
c.Resources.Limits = &aci.ComputeResources{
|
||||
CPU: cpuLimit,
|
||||
MemoryInGB: memoryLimit,
|
||||
}
|
||||
|
||||
if gpu, ok := container.Resources.Limits[gpuResourceName]; ok {
|
||||
sku, err := p.getGPUSKU(pod)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gpu.Value() == 0 {
|
||||
return nil, errors.New("GPU must be a integer number")
|
||||
}
|
||||
|
||||
gpuResource := &aci.GPUResource{
|
||||
Count: int32(gpu.Value()),
|
||||
SKU: sku,
|
||||
}
|
||||
|
||||
c.Resources.Requests.GPU = gpuResource
|
||||
c.Resources.Limits.GPU = gpuResource
|
||||
}
|
||||
}
|
||||
|
||||
if container.LivenessProbe != nil {
|
||||
@@ -1190,6 +1258,24 @@ func (p *ACIProvider) getContainers(pod *v1.Pod) ([]aci.Container, error) {
|
||||
return containers, nil
|
||||
}
|
||||
|
||||
func (p *ACIProvider) getGPUSKU(pod *v1.Pod) (aci.GPUSKU, error) {
|
||||
if len(p.gpuSKUs) == 0 {
|
||||
return "", fmt.Errorf("The pod requires GPU resource, but ACI doesn't provide GPU enabled container group in region %s", p.region)
|
||||
}
|
||||
|
||||
if desiredSKU, ok := pod.Annotations[gpuTypeAnnotation]; ok {
|
||||
for _, supportedSKU := range p.gpuSKUs {
|
||||
if strings.EqualFold(string(desiredSKU), string(supportedSKU)) {
|
||||
return supportedSKU, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("The pod requires GPU SKU %s, but ACI only supports SKUs %v in region %s", desiredSKU, p.region, p.gpuSKUs)
|
||||
}
|
||||
|
||||
return p.gpuSKUs[0], nil
|
||||
}
|
||||
|
||||
func getProbe(probe *v1.Probe) (*aci.ContainerProbe, error) {
|
||||
|
||||
if probe.Handler.Exec != nil && probe.Handler.HTTPGet != nil {
|
||||
@@ -1376,11 +1462,19 @@ func containerGroupToPod(cg *aci.ContainerGroup) (*v1.Pod, error) {
|
||||
},
|
||||
}
|
||||
|
||||
if c.Resources.Requests.GPU != nil {
|
||||
container.Resources.Requests[gpuResourceName] = resource.MustParse(fmt.Sprintf("%d", c.Resources.Requests.GPU.Count))
|
||||
}
|
||||
|
||||
if c.Resources.Limits != nil {
|
||||
container.Resources.Limits = v1.ResourceList{
|
||||
v1.ResourceCPU: resource.MustParse(fmt.Sprintf("%g", c.Resources.Limits.CPU)),
|
||||
v1.ResourceMemory: resource.MustParse(fmt.Sprintf("%gG", c.Resources.Limits.MemoryInGB)),
|
||||
}
|
||||
|
||||
if c.Resources.Limits.GPU != nil {
|
||||
container.Resources.Limits[gpuResourceName] = resource.MustParse(fmt.Sprintf("%d", c.Resources.Requests.GPU.Count))
|
||||
}
|
||||
}
|
||||
|
||||
containers = append(containers, container)
|
||||
|
||||
@@ -16,12 +16,14 @@ type ACIMock struct {
|
||||
OnCreate func(string, string, string, *aci.ContainerGroup) (int, interface{})
|
||||
OnGetContainerGroups func(string, string) (int, interface{})
|
||||
OnGetContainerGroup func(string, string, string) (int, interface{})
|
||||
OnGetRPManifest func() (int, interface{})
|
||||
}
|
||||
|
||||
const (
|
||||
containerGroupsRoute = "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}/providers/Microsoft.ContainerInstance/containerGroups"
|
||||
containerGroupRoute = containerGroupsRoute + "/{containerGroup}"
|
||||
containerGroupLogRoute = containerGroupRoute + "/containers/{containerName}/logs"
|
||||
resourceProviderRoute = "/providers/Microsoft.ContainerInstance"
|
||||
)
|
||||
|
||||
// NewACIMock creates a new Azure Container Instance mock server.
|
||||
@@ -103,6 +105,22 @@ func (mock *ACIMock) start() {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}).Methods("GET")
|
||||
|
||||
router.HandleFunc(
|
||||
resourceProviderRoute,
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if mock.OnGetRPManifest != nil {
|
||||
statusCode, response := mock.OnGetRPManifest()
|
||||
w.WriteHeader(statusCode)
|
||||
b := new(bytes.Buffer)
|
||||
json.NewEncoder(b).Encode(response)
|
||||
w.Write(b.Bytes())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}).Methods("GET")
|
||||
|
||||
mock.server = httptest.NewServer(router)
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ const (
|
||||
fakeClientSecret = "VGhpcyBpcyBhIHNlY3JldAo="
|
||||
fakeTenantID = "8cb81aca-83fe-4c6f-b667-4ec09c45a8bf"
|
||||
fakeNodeName = "vk"
|
||||
fakeRegion = "eastus"
|
||||
)
|
||||
|
||||
// Test make registry credential
|
||||
@@ -199,6 +200,166 @@ func TestCreatePodWithResourceRequestOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Tests create pod with default GPU SKU.
|
||||
func TestCreatePodWithGPU(t *testing.T) {
|
||||
aadServerMocker := NewAADMock()
|
||||
aciServerMocker := NewACIMock()
|
||||
|
||||
podName := "pod-" + uuid.New().String()
|
||||
podNamespace := "ns-" + uuid.New().String()
|
||||
gpuSKU := aci.GPUSKU("sku-" + uuid.New().String())
|
||||
|
||||
aciServerMocker.OnGetRPManifest = func() (int, interface{}) {
|
||||
manifest := &aci.ResourceProviderManifest{
|
||||
Metadata: &aci.ResourceProviderMetadata{
|
||||
GPURegionalSKUs: []*aci.GPURegionalSKU{
|
||||
&aci.GPURegionalSKU{
|
||||
Location: fakeRegion,
|
||||
SKUs: []aci.GPUSKU{gpuSKU, aci.K80, aci.P100},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return http.StatusOK, manifest
|
||||
}
|
||||
|
||||
provider, err := createTestProvider(aadServerMocker, aciServerMocker)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create the test provider. %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
aciServerMocker.OnCreate = func(subscription, resourceGroup, containerGroup string, cg *aci.ContainerGroup) (int, interface{}) {
|
||||
assert.Check(t, is.Equal(fakeSubscription, subscription), "Subscription doesn't match")
|
||||
assert.Check(t, is.Equal(fakeResourceGroup, resourceGroup), "Resource group doesn't match")
|
||||
assert.Check(t, is.Equal(podNamespace+"-"+podName, containerGroup), "Container group name is not expected")
|
||||
assert.Check(t, cg.ContainerGroupProperties.Containers != nil, "Containers should not be nil")
|
||||
assert.Check(t, is.Equal(1, len(cg.ContainerGroupProperties.Containers)), "1 Container is expected")
|
||||
assert.Check(t, is.Equal("nginx", cg.ContainerGroupProperties.Containers[0].Name), "Container nginx is expected")
|
||||
assert.Check(t, cg.ContainerGroupProperties.Containers[0].Resources.Requests != nil, "Container resource requests should not be nil")
|
||||
assert.Check(t, is.Equal(1.98, cg.ContainerGroupProperties.Containers[0].Resources.Requests.CPU), "Request CPU is not expected")
|
||||
assert.Check(t, is.Equal(3.4, cg.ContainerGroupProperties.Containers[0].Resources.Requests.MemoryInGB), "Request Memory is not expected")
|
||||
assert.Check(t, cg.ContainerGroupProperties.Containers[0].Resources.Requests.GPU != nil, "Requests GPU is not expected")
|
||||
assert.Check(t, is.Equal(int32(10), cg.ContainerGroupProperties.Containers[0].Resources.Requests.GPU.Count), "Requests GPU Count is not expected")
|
||||
assert.Check(t, is.Equal(gpuSKU, cg.ContainerGroupProperties.Containers[0].Resources.Requests.GPU.SKU), "Requests GPU SKU is not expected")
|
||||
assert.Check(t, cg.ContainerGroupProperties.Containers[0].Resources.Limits.GPU != nil, "Limits GPU is not expected")
|
||||
assert.Check(t, is.Equal(int32(10), cg.ContainerGroupProperties.Containers[0].Resources.Limits.GPU.Count), "Requests GPU Count is not expected")
|
||||
assert.Check(t, is.Equal(gpuSKU, cg.ContainerGroupProperties.Containers[0].Resources.Limits.GPU.SKU), "Requests GPU SKU is not expected")
|
||||
|
||||
return http.StatusOK, cg
|
||||
}
|
||||
|
||||
pod := &v1.Pod{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: podName,
|
||||
Namespace: podNamespace,
|
||||
},
|
||||
Spec: v1.PodSpec{
|
||||
Containers: []v1.Container{
|
||||
v1.Container{
|
||||
Name: "nginx",
|
||||
Resources: v1.ResourceRequirements{
|
||||
Requests: v1.ResourceList{
|
||||
"cpu": resource.MustParse("1.981"),
|
||||
"memory": resource.MustParse("3.49G"),
|
||||
},
|
||||
Limits: v1.ResourceList{
|
||||
gpuResourceName: resource.MustParse("10"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := provider.CreatePod(context.Background(), pod); err != nil {
|
||||
t.Fatal("Failed to create pod", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests create pod with GPU SKU in annotation.
|
||||
func TestCreatePodWithGPUSKU(t *testing.T) {
|
||||
aadServerMocker := NewAADMock()
|
||||
aciServerMocker := NewACIMock()
|
||||
|
||||
podName := "pod-" + uuid.New().String()
|
||||
podNamespace := "ns-" + uuid.New().String()
|
||||
gpuSKU := aci.GPUSKU("sku-" + uuid.New().String())
|
||||
|
||||
aciServerMocker.OnGetRPManifest = func() (int, interface{}) {
|
||||
manifest := &aci.ResourceProviderManifest{
|
||||
Metadata: &aci.ResourceProviderMetadata{
|
||||
GPURegionalSKUs: []*aci.GPURegionalSKU{
|
||||
&aci.GPURegionalSKU{
|
||||
Location: fakeRegion,
|
||||
SKUs: []aci.GPUSKU{aci.K80, aci.P100, gpuSKU},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return http.StatusOK, manifest
|
||||
}
|
||||
|
||||
provider, err := createTestProvider(aadServerMocker, aciServerMocker)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create the test provider. %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
aciServerMocker.OnCreate = func(subscription, resourceGroup, containerGroup string, cg *aci.ContainerGroup) (int, interface{}) {
|
||||
assert.Check(t, is.Equal(fakeSubscription, subscription), "Subscription doesn't match")
|
||||
assert.Check(t, is.Equal(fakeResourceGroup, resourceGroup), "Resource group doesn't match")
|
||||
assert.Check(t, cg != nil, "Container group is nil")
|
||||
assert.Check(t, is.Equal(podNamespace+"-"+podName, containerGroup), "Container group name is not expected")
|
||||
assert.Check(t, cg.ContainerGroupProperties.Containers != nil, "Containers should not be nil")
|
||||
assert.Check(t, is.Equal(1, len(cg.ContainerGroupProperties.Containers)), "1 Container is expected")
|
||||
assert.Check(t, is.Equal("nginx", cg.ContainerGroupProperties.Containers[0].Name), "Container nginx is expected")
|
||||
assert.Check(t, cg.ContainerGroupProperties.Containers[0].Resources.Requests != nil, "Container resource requests should not be nil")
|
||||
assert.Check(t, is.Equal(1.98, cg.ContainerGroupProperties.Containers[0].Resources.Requests.CPU), "Request CPU is not expected")
|
||||
assert.Check(t, is.Equal(3.4, cg.ContainerGroupProperties.Containers[0].Resources.Requests.MemoryInGB), "Request Memory is not expected")
|
||||
assert.Check(t, cg.ContainerGroupProperties.Containers[0].Resources.Requests.GPU != nil, "Requests GPU is not expected")
|
||||
assert.Check(t, is.Equal(int32(1), cg.ContainerGroupProperties.Containers[0].Resources.Requests.GPU.Count), "Requests GPU Count is not expected")
|
||||
assert.Check(t, is.Equal(gpuSKU, cg.ContainerGroupProperties.Containers[0].Resources.Requests.GPU.SKU), "Requests GPU SKU is not expected")
|
||||
assert.Check(t, cg.ContainerGroupProperties.Containers[0].Resources.Limits.GPU != nil, "Limits GPU is not expected")
|
||||
assert.Check(t, is.Equal(int32(1), cg.ContainerGroupProperties.Containers[0].Resources.Limits.GPU.Count), "Requests GPU Count is not expected")
|
||||
assert.Check(t, is.Equal(gpuSKU, cg.ContainerGroupProperties.Containers[0].Resources.Limits.GPU.SKU), "Requests GPU SKU is not expected")
|
||||
|
||||
return http.StatusOK, cg
|
||||
}
|
||||
|
||||
pod := &v1.Pod{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: podName,
|
||||
Namespace: podNamespace,
|
||||
Annotations: map[string]string{
|
||||
gpuTypeAnnotation: string(gpuSKU),
|
||||
},
|
||||
},
|
||||
Spec: v1.PodSpec{
|
||||
Containers: []v1.Container{
|
||||
v1.Container{
|
||||
Name: "nginx",
|
||||
Resources: v1.ResourceRequirements{
|
||||
Requests: v1.ResourceList{
|
||||
"cpu": resource.MustParse("1.981"),
|
||||
"memory": resource.MustParse("3.49G"),
|
||||
},
|
||||
Limits: v1.ResourceList{
|
||||
gpuResourceName: resource.MustParse("1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := provider.CreatePod(context.Background(), pod); err != nil {
|
||||
t.Fatal("Failed to create pod", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests create pod with both resource request and limit.
|
||||
func TestCreatePodWithResourceRequestAndLimit(t *testing.T) {
|
||||
_, aciServerMocker, provider, err := prepareMocks()
|
||||
@@ -314,7 +475,7 @@ func TestGetPodsWithoutResourceRequestsLimits(t *testing.T) {
|
||||
},
|
||||
},
|
||||
Resources: aci.ResourceRequirements{
|
||||
Requests: &aci.ResourceRequests{
|
||||
Requests: &aci.ComputeResources{
|
||||
CPU: 0.99,
|
||||
MemoryInGB: 1.5,
|
||||
},
|
||||
@@ -385,7 +546,7 @@ func TestGetPodWithoutResourceRequestsLimits(t *testing.T) {
|
||||
},
|
||||
},
|
||||
Resources: aci.ResourceRequirements{
|
||||
Requests: &aci.ResourceRequests{
|
||||
Requests: &aci.ComputeResources{
|
||||
CPU: 0.99,
|
||||
MemoryInGB: 1.5,
|
||||
},
|
||||
@@ -412,6 +573,94 @@ func TestGetPodWithoutResourceRequestsLimits(t *testing.T) {
|
||||
pod.Spec.Containers[0].Resources.Requests.Memory().Value()), "Containers[0].Resources.Requests.Memory doesn't match")
|
||||
}
|
||||
|
||||
// Tests get pod with GPU.
|
||||
func TestGetPodWithGPU(t *testing.T) {
|
||||
_, aciServerMocker, provider, err := prepareMocks()
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Unable to prepare the mocks", err)
|
||||
}
|
||||
|
||||
podName := "pod-" + uuid.New().String()
|
||||
podNamespace := "ns-" + uuid.New().String()
|
||||
|
||||
aciServerMocker.OnGetContainerGroup = func(subscription, resourceGroup, containerGroup string) (int, interface{}) {
|
||||
assert.Equal(t, fakeSubscription, subscription, "Subscription doesn't match")
|
||||
assert.Equal(t, fakeResourceGroup, resourceGroup, "Resource group doesn't match")
|
||||
assert.Equal(t, podNamespace+"-"+podName, containerGroup, "Container group name is not expected")
|
||||
|
||||
return http.StatusOK, aci.ContainerGroup{
|
||||
Tags: map[string]string{
|
||||
"NodeName": fakeNodeName,
|
||||
},
|
||||
ContainerGroupProperties: aci.ContainerGroupProperties{
|
||||
ProvisioningState: "Creating",
|
||||
Containers: []aci.Container{
|
||||
aci.Container{
|
||||
Name: "nginx",
|
||||
ContainerProperties: aci.ContainerProperties{
|
||||
Image: "nginx",
|
||||
Command: []string{"nginx", "-g", "daemon off;"},
|
||||
Ports: []aci.ContainerPort{
|
||||
{
|
||||
Protocol: aci.ContainerNetworkProtocolTCP,
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
Resources: aci.ResourceRequirements{
|
||||
Requests: &aci.ComputeResources{
|
||||
CPU: 0.99,
|
||||
MemoryInGB: 1.5,
|
||||
GPU: &aci.GPUResource{
|
||||
Count: 5,
|
||||
SKU: aci.P100,
|
||||
},
|
||||
},
|
||||
Limits: &aci.ComputeResources{
|
||||
GPU: &aci.GPUResource{
|
||||
Count: 5,
|
||||
SKU: aci.P100,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pod, err := provider.GetPod(context.Background(), podNamespace, podName)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to get pod", err)
|
||||
}
|
||||
|
||||
assert.Check(t, pod != nil, "Response pod should not be nil")
|
||||
assert.Check(t, pod.Spec.Containers != nil, "Containers should not be nil")
|
||||
assert.Check(t, pod.Spec.Containers[0].Resources.Requests != nil, "Containers[0].Resources.Requests should not be nil")
|
||||
assert.Check(
|
||||
t,
|
||||
is.Equal(ptrQuantity(resource.MustParse("0.99")).Value(), pod.Spec.Containers[0].Resources.Requests.Cpu().Value()),
|
||||
"Containers[0].Resources.Requests.CPU doesn't match")
|
||||
assert.Check(
|
||||
t,
|
||||
is.Equal(ptrQuantity(resource.MustParse("1.5G")).Value(), pod.Spec.Containers[0].Resources.Requests.Memory().Value()),
|
||||
"Containers[0].Resources.Requests.Memory doesn't match")
|
||||
gpuQuantity, ok := pod.Spec.Containers[0].Resources.Requests[gpuResourceName]
|
||||
assert.Check(t, is.Equal(ok, true), "Containers[0].Resources.Requests.GPU should not be nil")
|
||||
assert.Check(
|
||||
t,
|
||||
is.Equal(ptrQuantity(resource.MustParse("5")).Value(), ptrQuantity(gpuQuantity).Value()),
|
||||
"Containers[0].Resources.Requests.GPU.Count doesn't match")
|
||||
assert.Check(t, pod.Spec.Containers[0].Resources.Limits != nil, "Containers[0].Resources.Limits should not be nil")
|
||||
gpuQuantity, ok = pod.Spec.Containers[0].Resources.Limits[gpuResourceName]
|
||||
assert.Check(t, is.Equal(ok, true), "Containers[0].Resources.Requests.GPU should not be nil")
|
||||
assert.Check(
|
||||
t,
|
||||
is.Equal(ptrQuantity(resource.MustParse("5")).Value(), ptrQuantity(gpuQuantity).Value()),
|
||||
"Containers[0].Resources.Limits.GPU.Count doesn't match")
|
||||
}
|
||||
|
||||
func TestGetPodWithContainerID(t *testing.T) {
|
||||
_, aciServerMocker, provider, err := prepareMocks()
|
||||
|
||||
@@ -451,7 +700,7 @@ func TestGetPodWithContainerID(t *testing.T) {
|
||||
},
|
||||
},
|
||||
Resources: aci.ResourceRequirements{
|
||||
Requests: &aci.ResourceRequests{
|
||||
Requests: &aci.ComputeResources{
|
||||
CPU: 0.99,
|
||||
MemoryInGB: 1.5,
|
||||
},
|
||||
@@ -531,6 +780,30 @@ func prepareMocks() (*AADMock, *ACIMock, *ACIProvider, error) {
|
||||
aadServerMocker := NewAADMock()
|
||||
aciServerMocker := NewACIMock()
|
||||
|
||||
aciServerMocker.OnGetRPManifest = func() (int, interface{}) {
|
||||
manifest := &aci.ResourceProviderManifest{
|
||||
Metadata: &aci.ResourceProviderMetadata{
|
||||
GPURegionalSKUs: []*aci.GPURegionalSKU{
|
||||
&aci.GPURegionalSKU{
|
||||
Location: fakeRegion,
|
||||
SKUs: []aci.GPUSKU{aci.K80, aci.P100, aci.V100},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return http.StatusOK, manifest
|
||||
}
|
||||
|
||||
provider, err := createTestProvider(aadServerMocker, aciServerMocker)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to create the test provider %s", err.Error())
|
||||
}
|
||||
|
||||
return aadServerMocker, aciServerMocker, provider, nil
|
||||
}
|
||||
|
||||
func createTestProvider(aadServerMocker *AADMock, aciServerMocker*ACIMock) (*ACIProvider, error) {
|
||||
auth := azure.NewAuthentication(
|
||||
azure.PublicCloud.Name,
|
||||
fakeClientID,
|
||||
@@ -543,7 +816,7 @@ func prepareMocks() (*AADMock, *ACIMock, *ACIProvider, error) {
|
||||
|
||||
file, err := ioutil.TempFile("", "auth.json")
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer os.Remove(file.Name())
|
||||
@@ -552,23 +825,25 @@ func prepareMocks() (*AADMock, *ACIMock, *ACIProvider, error) {
|
||||
json.NewEncoder(b).Encode(auth)
|
||||
|
||||
if _, err := file.Write(b.Bytes()); err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
os.Setenv("AZURE_AUTH_LOCATION", file.Name())
|
||||
os.Setenv("ACI_RESOURCE_GROUP", fakeResourceGroup)
|
||||
os.Setenv("ACI_REGION", fakeRegion)
|
||||
|
||||
rm, err := manager.NewResourceManager(nil, nil, nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
provider, err := NewACIProvider("example.toml", rm, fakeNodeName, "Linux", "0.0.0.0", 10250)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return aadServerMocker, aciServerMocker, provider, nil
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func ptrQuantity(q resource.Quantity) *resource.Quantity {
|
||||
|
||||
Reference in New Issue
Block a user