Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions charts/gpu-provisioner/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ controller:
value: "false"
- name: E2E_TEST_MODE
value: "false"
- name: CLOUD_PROVIDER # possible values can be aks or arc
value: "aks"
envFrom: []
# -- Resources for the controller pod.
resources:
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ require (
contrib.go.opencensus.io/exporter/ocagent v0.7.1-0.20200907061046-05415f1de66d // indirect
contrib.go.opencensus.io/exporter/prometheus v0.4.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcontainerservice/armhybridcontainerservice v1.0.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 // indirect
github.com/Azure/go-autorest v14.2.0+incompatible // indirect
github.com/Azure/go-autorest/autorest/adal v0.9.24 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.9.0 h1:H+U3Gk9zY56G3u872L82bk4
github.com/Azure/azure-sdk-for-go/sdk/internal v1.9.0/go.mod h1:mgrmMSgaLp9hmax62XQTd0N4aAqSE5E0DulSpVYK7vc=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0 h1:0nGmzwBv5ougvzfGPCO2ljFRHvun57KpNrVCMrlk0ns=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0/go.mod h1:gYq8wyDgv6JLhGbAU6gg8amCPgQWRE+aCvrV2gyzdfs=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcontainerservice/armhybridcontainerservice v1.0.0 h1:crtqxU3LRy2UEPkQJGhJM1KUPf9q0jfIrr1m2o6UAd4=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcontainerservice/armhybridcontainerservice v1.0.0/go.mod h1:amJDuQ3h8RzdvCxHSwEwn1t9cRQ7KXaoL2OeGSFVj24=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 h1:Dd+RhdJn0OTtVGaeDLZpcumkIVCtA/3/Fo42+eoYvVM=
Expand Down
12 changes: 6 additions & 6 deletions pkg/cloudprovider/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"time"

"github.com/awslabs/operatorpkg/status"
"github.com/azure/gpu-provisioner/pkg/providers/instance"
"github.com/azure/gpu-provisioner/pkg/providers"
"github.com/samber/lo"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -35,11 +35,11 @@ import (
var _ cloudprovider.CloudProvider = &CloudProvider{}

type CloudProvider struct {
instanceProvider *instance.Provider
instanceProvider providers.InstanceProvider
kubeClient client.Client
}

func New(instanceProvider *instance.Provider, kubeClient client.Client) *CloudProvider {
func New(instanceProvider providers.InstanceProvider, kubeClient client.Client) *CloudProvider {
return &CloudProvider{
instanceProvider: instanceProvider,
kubeClient: kubeClient,
Expand Down Expand Up @@ -108,7 +108,7 @@ func (c *CloudProvider) GetSupportedNodeClasses() []status.Object {
return []status.Object{}
}

func (c *CloudProvider) instanceToNodeClaim(ctx context.Context, instanceObj *instance.Instance) *karpenterv1.NodeClaim {
func (c *CloudProvider) instanceToNodeClaim(ctx context.Context, instanceObj *providers.Instance) *karpenterv1.NodeClaim {
nodeClaim := &karpenterv1.NodeClaim{}
if instanceObj == nil {
return nodeClaim
Expand All @@ -133,8 +133,8 @@ func (c *CloudProvider) instanceToNodeClaim(ctx context.Context, instanceObj *in

nodeClaim.Labels = labels
nodeClaim.Annotations = annotations
if timestamp, ok := labels[instance.NodeClaimCreationLabel]; ok {
if creationTime, err := time.Parse(instance.CreationTimestampLayout, timestamp); err == nil {
if timestamp, ok := labels[providers.NodeClaimCreationLabel]; ok {
if creationTime, err := time.Parse(providers.CreationTimestampLayout, timestamp); err == nil {
nodeClaim.CreationTimestamp = metav1.Time{Time: creationTime}
}
}
Expand Down
61 changes: 47 additions & 14 deletions pkg/operator/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ import (
"os"

"github.com/azure/gpu-provisioner/pkg/auth"
"github.com/azure/gpu-provisioner/pkg/providers"
"github.com/azure/gpu-provisioner/pkg/providers/arcinstance"
"github.com/azure/gpu-provisioner/pkg/providers/instance"
"knative.dev/pkg/logging"
"sigs.k8s.io/karpenter/pkg/operator"
)

// Operator is injected into the AWS CloudProvider's factories
// Operator is injected into the CloudProvider's factories
type Operator struct {
*operator.Operator
InstanceProvider *instance.Provider
InstanceProvider providers.InstanceProvider
}

func NewOperator(ctx context.Context, operator *operator.Operator) (context.Context, *Operator) {
Expand All @@ -38,20 +40,51 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont
logging.FromContext(ctx).Errorf("creating Azure config, %s", err)
}

azClient, err := instance.CreateAzClient(azConfig)
if err != nil {
logging.FromContext(ctx).Errorf("creating Azure client, %s", err)
// Let us panic here, instead of crashing in the following code.
// TODO: move this to an init container
panic(fmt.Sprintf("Configure azure client fails. Please ensure federatedcredential has been created for identity %s.", os.Getenv("AZURE_CLIENT_ID")))
// Get cloud provider type from environment variable
cloudProvider := os.Getenv("CLOUD_PROVIDER")
if cloudProvider == "" {
cloudProvider = "aks" // default to AKS
}

var instanceProvider providers.InstanceProvider

switch cloudProvider {
case "aks":
azClient, err := instance.CreateAzClient(azConfig)
if err != nil {
logging.FromContext(ctx).Errorf("creating Azure client, %s", err)
// Let us panic here, instead of crashing in the following code.
// TODO: move this to an init container
panic(fmt.Sprintf("Configure azure client fails. Please ensure federatedcredential has been created for identity %s.", os.Getenv("AZURE_CLIENT_ID")))
}

instanceProvider = instance.NewProvider(
azClient,
operator.GetClient(),
azConfig.ResourceGroup,
azConfig.ClusterName,
)

case "arc":
arcClient, err := arcinstance.NewArcClient(azConfig.SubscriptionID)
if err != nil {
logging.FromContext(ctx).Errorf("creating Arc client, %s", err)
panic(fmt.Sprintf("Configure Arc client fails: %v", err))
}

instanceProvider = arcinstance.NewProvider(
arcClient,
operator.GetClient(),
azConfig.SubscriptionID,
azConfig.ResourceGroup,
azConfig.ClusterName,
)

default:
panic(fmt.Sprintf("Unsupported CLOUD_PROVIDER: %s. Supported values are 'aks' and 'arc'", cloudProvider))
}

instanceProvider := instance.NewProvider(
azClient,
operator.GetClient(),
azConfig.ResourceGroup,
azConfig.ClusterName,
)
logging.FromContext(ctx).Infof("Using cloud provider: %s", cloudProvider)

return ctx, &Operator{
Operator: operator,
Expand Down
104 changes: 104 additions & 0 deletions pkg/providers/arcinstance/armutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package arcinstance

import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcontainerservice/armhybridcontainerservice"
"github.com/azure/gpu-provisioner/pkg/utils"
"k8s.io/klog/v2"
)

// createAgentPool starts a create-or-update operation for the agent pool
// apName under connectedClusterResourceURI and blocks until the operation
// finishes.
//
// It returns a pointer to the AgentPool carried in the final response. Any
// error from starting the operation or from polling it is returned unchanged
// together with a nil pool.
func createAgentPool(ctx context.Context, client AgentPoolsAPI, connectedClusterResourceURI, apName string, ap armhybridcontainerservice.AgentPool) (*armhybridcontainerservice.AgentPool, error) {
	klog.InfoS("createAgentPool", "agentpool", apName)

	lro, beginErr := client.BeginCreateOrUpdate(ctx, connectedClusterResourceURI, apName, ap, nil)
	if beginErr != nil {
		return nil, beginErr
	}

	// Wait for the long-running operation to reach a terminal state.
	result, pollErr := lro.PollUntilDone(ctx, nil)
	if pollErr != nil {
		return nil, pollErr
	}

	created := &result.AgentPool
	return created, nil
}

// deleteAgentPool starts deletion of the agent pool apName under
// connectedClusterResourceURI and, if the operation was started successfully,
// blocks until it finishes.
//
// Whatever error results (from starting or from polling) is passed through
// utils.ShouldIgnoreNotFoundError before being returned.
// NOTE(review): presumably that helper maps "not found" errors to nil so that
// deleting an already-absent pool succeeds — confirm against pkg/utils.
func deleteAgentPool(ctx context.Context, client AgentPoolsAPI, connectedClusterResourceURI, apName string) error {
	klog.InfoS("deleteAgentPool", "agentpool", apName)

	lro, err := client.BeginDelete(ctx, connectedClusterResourceURI, apName, nil)
	if err == nil {
		// Only poll when the operation actually started; the poll outcome
		// then becomes the error that is filtered below.
		_, err = lro.PollUntilDone(ctx, nil)
	}

	return utils.ShouldIgnoreNotFoundError(err)
}

// getAgentPool fetches the agent pool apName under
// connectedClusterResourceURI.
//
// On success it returns a pointer to the AgentPool contained in the response;
// on failure it returns nil and the client's error unchanged.
func getAgentPool(ctx context.Context, client AgentPoolsAPI, connectedClusterResourceURI, apName string) (*armhybridcontainerservice.AgentPool, error) {
	getResp, getErr := client.Get(ctx, connectedClusterResourceURI, apName, nil)
	if getErr != nil {
		return nil, getErr
	}

	pool := &getResp.AgentPool
	return pool, nil
}

// listAgentPools collects every agent pool reported for
// connectedClusterResourceURI by walking all pages of the client's
// list-by-provisioned-cluster pager.
//
// The first page error aborts the walk and is returned with a nil slice;
// pools gathered from earlier pages are discarded. If the pager yields no
// entries, the returned slice is nil.
func listAgentPools(ctx context.Context, client AgentPoolsAPI, connectedClusterResourceURI string) ([]*armhybridcontainerservice.AgentPool, error) {
	var pools []*armhybridcontainerservice.AgentPool

	for pager := client.NewListByProvisionedClusterPager(connectedClusterResourceURI, nil); pager.More(); {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return nil, err
		}
		pools = append(pools, page.Value...)
	}

	return pools, nil
}
Loading
Loading