From 03d7670dcd05ec7eb606b9b18b0e5236e222e467 Mon Sep 17 00:00:00 2001 From: Alexander Wels Date: Tue, 15 Jul 2025 14:19:30 -0500 Subject: [PATCH] Update DataVolumeTemplate storage class during storage class migration. Signed-off-by: Alexander Wels --- pkg/controller/directvolumemigration/vm.go | 62 +++++++++- .../directvolumemigration/vm_test.go | 107 +++++++++++++++++- 2 files changed, 158 insertions(+), 11 deletions(-) diff --git a/pkg/controller/directvolumemigration/vm.go b/pkg/controller/directvolumemigration/vm.go index 9dd253683..88d39ea98 100644 --- a/pkg/controller/directvolumemigration/vm.go +++ b/pkg/controller/directvolumemigration/vm.go @@ -19,6 +19,7 @@ import ( prometheusapi "github.com/prometheus/client_golang/api" prometheusv1 "github.com/prometheus/client_golang/api/prometheus/v1" corev1 "k8s.io/api/core/v1" + storagev1 "k8s.io/api/storage/v1" k8serrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -28,6 +29,7 @@ import ( "k8s.io/client-go/rest" "k8s.io/utils/ptr" virtv1 "kubevirt.io/api/core/v1" + cdiv1 "kubevirt.io/containerized-data-importer-api/pkg/apis/core/v1beta1" k8sclient "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -40,6 +42,8 @@ const ( PodKind = "Pod" virtLauncherPodLabelSelectorKey = "kubevirt.io" virtLauncherPodLabelSelectorValue = "virt-launcher" + defaultVirtStorageClass = "storageclass.kubevirt.io/is-default-virt-class" + defaultK8sStorageClass = "storageclass.kubernetes.io/is-default-class" ) var ( @@ -483,18 +487,15 @@ func updateVM(client k8sclient.Client, vm *virtv1.VirtualMachine, sourceVolumes, log.V(5).Info("Setting volume migration strategy to migration", "vm", vmCopy.Name) vmCopy.Spec.UpdateVolumesStrategy = ptr.To[virtv1.UpdateVolumesStrategy](virtv1.UpdateVolumesStrategyMigration) + volumeNameToPVCMap := make(map[string]string) + for i := 0; i < len(sourceVolumes); i++ { // Check if we need to update DataVolumeTemplates. - for j, dvTemplate := range vmCopy.Spec.DataVolumeTemplates { - if dvTemplate.Name == sourceVolumes[i] { - log.V(5).Info("Updating DataVolumeTemplate", "source", sourceVolumes[i], "target", targetVolumes[i]) - vmCopy.Spec.DataVolumeTemplates[j].Name = targetVolumes[i] - } - } for j, volume := range vm.Spec.Template.Spec.Volumes { if volume.PersistentVolumeClaim != nil && volume.PersistentVolumeClaim.ClaimName == sourceVolumes[i] { log.V(5).Info("Updating PersistentVolumeClaim", "source", sourceVolumes[i], "target", targetVolumes[i]) vmCopy.Spec.Template.Spec.Volumes[j].PersistentVolumeClaim.ClaimName = targetVolumes[i] + volumeNameToPVCMap[sourceVolumes[i]] = targetVolumes[i] } if volume.DataVolume != nil && volume.DataVolume.Name == sourceVolumes[i] { log.V(5).Info("Updating DataVolume", "source", sourceVolumes[i], "target", targetVolumes[i]) @@ -502,6 +503,23 @@ func updateVM(client k8sclient.Client, vm *virtv1.VirtualMachine, sourceVolumes, return err } vmCopy.Spec.Template.Spec.Volumes[j].DataVolume.Name = targetVolumes[i] + volumeNameToPVCMap[sourceVolumes[i]] = targetVolumes[i] + } + } + for j, dvTemplate := range vmCopy.Spec.DataVolumeTemplates { + if dvTemplate.Name == sourceVolumes[i] { + log.V(5).Info("Updating DataVolumeTemplate", "source", sourceVolumes[i], "target", targetVolumes[i]) + vmCopy.Spec.DataVolumeTemplates[j].Name = targetVolumes[i] + pvcName := volumeNameToPVCMap[sourceVolumes[i]] + sc, err := getStorageClassFromName(client, pvcName, vmCopy.Namespace) + if err != nil { + return err + } + if vmCopy.Spec.DataVolumeTemplates[j].Spec.Storage != nil { + vmCopy.Spec.DataVolumeTemplates[j].Spec.Storage.StorageClassName = ptr.To(sc) + } else if vmCopy.Spec.DataVolumeTemplates[j].Spec.PVC != nil { + vmCopy.Spec.DataVolumeTemplates[j].Spec.PVC.StorageClassName = ptr.To(sc) + } } } } @@ -516,6 +534,38 @@ func updateVM(client k8sclient.Client, vm *virtv1.VirtualMachine, sourceVolumes, return nil } +func getStorageClassFromName(client k8sclient.Client, name, namespace string) (string, error) { + volume := &corev1.PersistentVolumeClaim{} + if err := client.Get(context.TODO(), k8sclient.ObjectKey{Namespace: namespace, Name: name}, volume); err != nil { + if k8serrors.IsNotFound(err) { + return "", nil + } + return "", err + } + if volume.Spec.StorageClassName == nil { + // Find the default storage class + scList := &storagev1.StorageClassList{} + if err := client.List(context.TODO(), scList); err != nil { + return "", err + } + defaultStorageClass := "" + for _, sc := range scList.Items { + if sc.Annotations != nil && sc.Annotations[defaultK8sStorageClass] == "true" { + defaultStorageClass = sc.Name + } + if sc.Annotations != nil && sc.Annotations[defaultVirtStorageClass] == "true" { + return sc.Name, nil + } + } + if defaultStorageClass != "" { + return defaultStorageClass, nil + } + // No default storage class found, return blank + return "", nil + } + return *volume.Spec.StorageClassName, nil +} + func createBlankDataVolumeFromPVC(client k8sclient.Client, targetPvc *corev1.PersistentVolumeClaim) error { dv := &cdiv1.DataVolume{ ObjectMeta: metav1.ObjectMeta{ diff --git a/pkg/controller/directvolumemigration/vm_test.go b/pkg/controller/directvolumemigration/vm_test.go index 945277407..4f12be308 100644 --- a/pkg/controller/directvolumemigration/vm_test.go +++ b/pkg/controller/directvolumemigration/vm_test.go @@ -21,6 +21,7 @@ import ( prometheusv1 "github.com/prometheus/client_golang/api/prometheus/v1" "github.com/prometheus/common/model" corev1 "k8s.io/api/core/v1" + storagev1 "k8s.io/api/storage/v1" k8serrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/rest" @@ -31,11 +32,15 @@ import ( ) const ( - sourcePVC = "source-pvc" - sourceNs = "source-ns" - targetPVC = "target-pvc" - targetNs = "target-ns" - targetDv = "target-dv" + sourcePVC = "source-pvc" + sourceNs = "source-ns" + targetPVC = "target-pvc" + targetNs = "target-ns" + targetDv = "target-dv" + testPVCName = "test-pvc" + testStorageClass = "test-sc" + testDefaultStorageClass = "test-default-sc" + testVirtDefaultStorageClass = "test-virt-default-sc" ) func TestTask_startLiveMigrations(t *testing.T) { @@ -1428,6 +1433,60 @@ func TestTaskBuildSourcePrometheusEndPointURL(t *testing.T) { } } +func TestGetStorageClassFromName(t *testing.T) { + tests := []struct { + name string + client compat.Client + expectedError bool + expectedSc string + }{ + { + name: "no pvcs, no storage class, should return blank", + client: getFakeClientWithObjs(), + expectedError: false, + expectedSc: "", + }, + { + name: "pvcs, with storage class, should return name", + client: getFakeClientWithObjs(createPVC(testPVCName, testNamespace)), + expectedError: false, + expectedSc: testStorageClass, + }, + { + name: "pvcs, no storage class, no default storage class return blank", + client: getFakeClientWithObjs(createNoStorageClassPVC(testPVCName, testNamespace)), + expectedError: false, + expectedSc: "", + }, + { + name: "pvcs, no storage class, default storage class, should return name", + client: getFakeClientWithObjs(createNoStorageClassPVC(testPVCName, testNamespace), createDefaultStorageClass(testDefaultStorageClass)), + expectedError: false, + expectedSc: testDefaultStorageClass, + }, + { + name: "pvcs, no storage class, virt default storage class, should return name", + client: getFakeClientWithObjs(createNoStorageClassPVC(testPVCName, testNamespace), createDefaultStorageClass(testDefaultStorageClass), createVirtDefaultStorageClass(testVirtDefaultStorageClass)), + expectedError: false, + expectedSc: testVirtDefaultStorageClass, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sc, err := getStorageClassFromName(tt.client, testPVCName, testNamespace) + if tt.expectedError { + if err == nil { + t.Errorf("expected error but got nil") + t.FailNow() + } + } + if sc != tt.expectedSc { + t.Errorf("expected %s, got %s", tt.expectedSc, sc) + } + }) + } +} + func getFakeClientWithObjs(obj ...k8sclient.Object) compat.Client { client, _ := fakecompat.NewFakeClient(obj...) return client @@ -1579,6 +1638,44 @@ func createRoute(name, namespace, url string) *routev1.Route { } } +func createPVC(name, namespace string) *corev1.PersistentVolumeClaim { + pvc := createNoStorageClassPVC(name, namespace) + pvc.Spec.StorageClassName = ptr.To(testStorageClass) + return pvc +} + +func createNoStorageClassPVC(name, namespace string) *corev1.PersistentVolumeClaim { + return &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + } +} +func createStorageClass(name string) *storagev1.StorageClass { + return &storagev1.StorageClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + }, + } +} + +func createDefaultStorageClass(name string) *storagev1.StorageClass { + sc := createStorageClass(name) + sc.Annotations = map[string]string{ + defaultK8sStorageClass: "true", + } + return sc +} + +func createVirtDefaultStorageClass(name string) *storagev1.StorageClass { + sc := createStorageClass(name) + sc.Annotations = map[string]string{ + defaultVirtStorageClass: "true", + } + return sc +} + type mockPrometheusClient struct { fakeUrl string responseBody string