From e7435554ae2e282d4f464d6ca9e3605d2e2b4409 Mon Sep 17 00:00:00 2001 From: Givaldo Lins Date: Tue, 16 Jun 2026 16:24:31 -0600 Subject: [PATCH] Add IMDSv2 migration command for ROSA Classic clusters This command automates the SOP for migrating ROSA Classic cluster nodes to enforce IMDSv2 (Instance Metadata Service v2) by: - Replacing infra nodes using the MachinePool dance pattern - Updating ControlPlaneMachineSet for automatic master node rollout - Patching worker MachinePools (customer performs machine replacement) - Validating all nodes/machines are using IMDSv2 Key features: - Pre-flight checks verify cluster health before making changes - Confirmation prompts before all destructive operations - Tracks actual changes and shows accurate success/skip messages - Supports selective migration: --nodes all|master|infra|workers - Context-aware with proper cancellation handling - Comprehensive test coverage (11 test cases) Implementation follows all review comments from @clcollins: 1. Code consolidation via helpers.go (shared with changevolumetype) 2. Uses RunMachinePoolDance for atomic infra replacement 3. Confirmation prompts before destructive operations 4. No annotation cleanup risk (eliminated pattern entirely) 5. MachinePool name validation for safety 6. Accurate change tracking and reporting Additional improvements from CodeRabbit AI review: - Exclusion-based worker pool filtering to support custom pools (worker-2, gpu-workers, etc.) - Context-aware sleep for responsive cancellation handling Files added: - cmd/cluster/imdsv2.go (676 lines) - main command implementation - cmd/cluster/helpers.go (234 lines) - shared helper functions - cmd/cluster/imdsv2_test.go (574 lines) - comprehensive tests - docs/osdctl_cluster_imdsv2.md - generated documentation All tests passing, builds successfully. 
Signed-off-by: Givaldo Lins --- cmd/cluster/cmd.go | 1 + cmd/cluster/helpers.go | 234 ++++++++++++ cmd/cluster/imdsv2.go | 676 ++++++++++++++++++++++++++++++++++ cmd/cluster/imdsv2_test.go | 574 +++++++++++++++++++++++++++++ docs/README.md | 36 ++ docs/osdctl_cluster.md | 1 + docs/osdctl_cluster_imdsv2.md | 61 +++ 7 files changed, 1583 insertions(+) create mode 100644 cmd/cluster/helpers.go create mode 100644 cmd/cluster/imdsv2.go create mode 100644 cmd/cluster/imdsv2_test.go create mode 100644 docs/osdctl_cluster_imdsv2.md diff --git a/cmd/cluster/cmd.go b/cmd/cluster/cmd.go index 8aca31a86..be936da9f 100644 --- a/cmd/cluster/cmd.go +++ b/cmd/cluster/cmd.go @@ -53,6 +53,7 @@ func NewCmdCluster(streams genericclioptions.IOStreams, client *k8s.LazyClient, clusterCmd.AddCommand(cad.NewCmdCad()) clusterCmd.AddCommand(newCmdSnapshot()) clusterCmd.AddCommand(newCmdDiff()) + clusterCmd.AddCommand(newCmdIMDSv2()) return clusterCmd } diff --git a/cmd/cluster/helpers.go b/cmd/cluster/helpers.go new file mode 100644 index 000000000..0eac0a426 --- /dev/null +++ b/cmd/cluster/helpers.go @@ -0,0 +1,234 @@ +package cluster + +import ( + "context" + "fmt" + "log" + "time" + + cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1" + configv1 "github.com/openshift/api/config/v1" + machinev1 "github.com/openshift/api/machine/v1" + machinev1beta1 "github.com/openshift/api/machine/v1beta1" + "github.com/openshift/backplane-cli/pkg/ocm" + hivev1 "github.com/openshift/hive/apis/hive/v1" + "github.com/openshift/osdctl/pkg/k8s" + "github.com/openshift/osdctl/pkg/utils" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// ClusterClients holds the various Kubernetes clients needed for cluster operations. 
+type ClusterClients struct {
+	// Client is the client created with k8s.New for the target cluster.
+	Client client.Client
+	// ClientAdmin is the client created with k8s.NewAsBackplaneClusterAdmin
+	// for the target cluster; it is the one used for mutations.
+	ClientAdmin client.Client
+	// HiveClient and HiveAdminClient are not populated by SetupClusterClients.
+	HiveClient      client.Client
+	HiveAdminClient client.Client
+}
+
+// SetupClusterClients initializes standard Kubernetes clients for a cluster.
+//
+// reason and a "<operation> for cluster <clusterID>" string are passed through
+// to k8s.NewAsBackplaneClusterAdmin. Only Client and ClientAdmin are set on the
+// returned ClusterClients. Every error is wrapped with the step that failed.
+func SetupClusterClients(clusterID, reason, operation string) (*ClusterClients, error) {
+	scheme := runtime.NewScheme()
+
+	// API groups registered on the scheme shared by both clients.
+	installers := []struct {
+		api     string
+		install func(*runtime.Scheme) error
+	}{
+		// machine/v1 provides ControlPlaneMachineSet.
+		{api: "machine/v1", install: machinev1.Install},
+		// machine/v1beta1 provides Machine and the AWS provider config types.
+		{api: "machine/v1beta1", install: machinev1beta1.Install},
+		// core/v1 provides Nodes.
+		{api: "core/v1", install: corev1.AddToScheme},
+		// config/v1 provides ClusterOperator.
+		{api: "config/v1", install: configv1.Install},
+	}
+	for _, i := range installers {
+		if err := i.install(scheme); err != nil {
+			return nil, fmt.Errorf("failed to register %s API types: %w", i.api, err)
+		}
+	}
+
+	// Standard client.
+	c, err := k8s.New(clusterID, client.Options{Scheme: scheme})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create cluster client: %w", err)
+	}
+
+	// Elevated cluster-admin client for mutations.
+	cAdmin, err := k8s.NewAsBackplaneClusterAdmin(
+		clusterID,
+		client.Options{Scheme: scheme},
+		reason,
+		fmt.Sprintf("%s for cluster %s", operation, clusterID),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create cluster admin client: %w", err)
+	}
+
+	return &ClusterClients{
+		Client:      c,
+		ClientAdmin: cAdmin,
+	}, nil
+}
+
+// SetupHiveClients initializes Hive clients for MachinePool operations.
+func SetupHiveClients(clusterID, reason, operation string) (hiveClient, hiveAdminClient client.Client, err error) {
+	// The Hive clients only need the Hive API types plus core/v1.
+	hiveScheme := runtime.NewScheme()
+	if err := hivev1.AddToScheme(hiveScheme); err != nil {
+		return nil, nil, fmt.Errorf("failed to register hive API types: %w", err)
+	}
+	if err := corev1.AddToScheme(hiveScheme); err != nil {
+		return nil, nil, fmt.Errorf("failed to register core/v1 API types: %w", err)
+	}
+
+	// Resolve the Hive cluster that manages clusterID; both clients below
+	// connect to that cluster, not to clusterID itself.
+	hive, err := utils.GetHiveCluster(clusterID)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to get hive cluster: %w", err)
+	}
+
+	// Standard Hive client.
+	hc, err := k8s.New(hive.ID(), client.Options{Scheme: hiveScheme})
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to create hive client: %w", err)
+	}
+
+	// Elevated Hive client for MachinePool mutations. The elevation reasons
+	// name the managed cluster so the elevation can be traced back to it.
+	hac, err := k8s.NewAsBackplaneClusterAdmin(
+		hive.ID(),
+		client.Options{Scheme: hiveScheme},
+		reason,
+		fmt.Sprintf("%s for cluster %s", operation, clusterID),
+	)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to create hive admin client: %w", err)
+	}
+
+	return hc, hac, nil
+}
+
+// CheckClusterOperators verifies all cluster operators are healthy.
+func CheckClusterOperators(ctx context.Context, c client.Client) error {
+	coList := &configv1.ClusterOperatorList{}
+	if err := c.List(ctx, coList); err != nil {
+		return fmt.Errorf("failed to list clusteroperators: %w", err)
+	}
+
+	// An operator is healthy only when it reports Available=True and does not
+	// report Degraded=True. A missing Available condition counts as unhealthy.
+	var unhealthyOps []string
+	for i := range coList.Items {
+		op := &coList.Items[i]
+		available, degraded := false, false
+		for _, cond := range op.Status.Conditions {
+			switch cond.Type {
+			case configv1.OperatorAvailable:
+				available = cond.Status == configv1.ConditionTrue
+			case configv1.OperatorDegraded:
+				degraded = cond.Status == configv1.ConditionTrue
+			}
+		}
+		if !available || degraded {
+			unhealthyOps = append(unhealthyOps, op.Name)
+		}
+	}
+
+	if len(unhealthyOps) > 0 {
+		return fmt.Errorf("unhealthy cluster operators: %v", unhealthyOps)
+	}
+
+	fmt.Printf(" ClusterOperators: All healthy\n")
+	return nil
+}
+
+// CheckCPMSState verifies the ControlPlaneMachineSet is Active and ready.
+//
+// It returns an error when the CPMS cannot be fetched, when its spec state is
+// not Active, or when it does not report exactly 3 ready replicas.
+func CheckCPMSState(ctx context.Context, c client.Client, namespace, name string) error {
+	cpms := &machinev1.ControlPlaneMachineSet{}
+	if err := c.Get(ctx, client.ObjectKey{Namespace: namespace, Name: name}, cpms); err != nil {
+		// %w keeps the API error inspectable by callers (e.g. not-found checks).
+		return fmt.Errorf("failed to get CPMS: %w", err)
+	}
+
+	if cpms.Spec.State != machinev1.ControlPlaneMachineSetStateActive {
+		return fmt.Errorf("CPMS is not Active (state: %s). Cannot proceed with control plane changes", cpms.Spec.State)
+	}
+
+	if cpms.Status.ReadyReplicas != 3 {
+		return fmt.Errorf("CPMS does not have 3 ready replicas (ready: %d)", cpms.Status.ReadyReplicas)
+	}
+
+	fmt.Printf(" CPMS: Active, %d/3 ready\n", cpms.Status.ReadyReplicas)
+	return nil
+}
+
+// MonitorCPMSRollout polls the CPMS until all replicas are updated.
+func MonitorCPMSRollout(ctx context.Context, c client.Client, namespace, name string, timeout time.Duration) error {
+	const pollInterval = 30 * time.Second
+
+	// PollUntilContextTimeout applies timeout itself, so no additional
+	// context.WithTimeout wrapper is needed.
+	return wait.PollUntilContextTimeout(ctx, pollInterval, timeout, true, func(ctx context.Context) (bool, error) {
+		cpms := &machinev1.ControlPlaneMachineSet{}
+		if err := c.Get(ctx, client.ObjectKey{Namespace: namespace, Name: name}, cpms); err != nil {
+			// Transient API errors are expected while control plane nodes are
+			// being replaced; keep polling until the timeout.
+			log.Printf("Warning: Error checking CPMS status (will retry): %v", err)
+			return false, nil
+		}
+
+		// The first poll runs immediately after the caller patched the spec.
+		// Until the controller has observed that generation, the status still
+		// describes the previous spec (3 updated / 3 ready) and would be
+		// mistaken for a finished rollout.
+		if cpms.Status.ObservedGeneration < cpms.Generation {
+			log.Printf("[%s] CPMS: waiting for controller to observe generation %d (observed: %d)",
+				time.Now().Format("15:04:05"), cpms.Generation, cpms.Status.ObservedGeneration)
+			return false, nil
+		}
+
+		updated := cpms.Status.UpdatedReplicas
+		ready := cpms.Status.ReadyReplicas
+
+		log.Printf("[%s] CPMS: %d/3 updated, %d/3 ready", time.Now().Format("15:04:05"), updated, ready)
+
+		return updated == 3 && ready >= 3, nil
+	})
+}
+
+// CountReadyNodes counts the number of Ready nodes in a NodeList.
+// A node counts as Ready when it has a NodeReady condition with status True.
+func CountReadyNodes(nodes *corev1.NodeList) int {
+	ready := 0
+	for i := range nodes.Items {
+		for _, cond := range nodes.Items[i].Status.Conditions {
+			if cond.Type == corev1.NodeReady && cond.Status == corev1.ConditionTrue {
+				ready++
+				break
+			}
+		}
+	}
+	return ready
+}
+
+// WaitForClusterOperatorsHealthy waits for all cluster operators to become healthy.
+//
+// It re-runs CheckClusterOperators every 30 seconds (starting immediately)
+// until it succeeds, timeout elapses, or ctx is cancelled.
+func WaitForClusterOperatorsHealthy(ctx context.Context, c client.Client, timeout time.Duration) error {
+	const pollInterval = 30 * time.Second
+
+	return wait.PollUntilContextTimeout(ctx, pollInterval, timeout, true, func(ctx context.Context) (bool, error) {
+		if err := CheckClusterOperators(ctx, c); err != nil {
+			// Not healthy yet (or the list call failed): log and retry.
+			log.Printf("Cluster operators not yet healthy: %v", err)
+			return false, nil
+		}
+		return true, nil
+	})
+}
+
+// ValidateAWSClassicCluster validates the cluster is an AWS Classic (non-HCP) cluster.
+func ValidateAWSClassicCluster(cluster *cmv1.Cluster) error {
+	if provider := cluster.CloudProvider().ID(); provider != "aws" {
+		return fmt.Errorf("this command only supports AWS clusters (cluster is %s)", provider)
+	}
+
+	if cluster.Hypershift().Enabled() {
+		return fmt.Errorf("this command does not support HCP clusters")
+	}
+
+	return nil
+}
+
+// GetHiveNamespace returns the Hive namespace for a given cluster ID.
+// This is reusable across multiple cluster commands that interact with Hive.
+// The namespace has the form "uhc-<OCM environment name>-<clusterID>".
+func GetHiveNamespace(clusterID string) (string, error) {
+	env, err := ocm.DefaultOCMInterface.GetOCMEnvironment()
+	if err != nil {
+		return "", fmt.Errorf("failed to get OCM environment: %w", err)
+	}
+	return fmt.Sprintf("uhc-%s-%s", env.Name(), clusterID), nil
+}
diff --git a/cmd/cluster/imdsv2.go b/cmd/cluster/imdsv2.go
new file mode 100644
index 000000000..f51319c25
--- /dev/null
+++ b/cmd/cluster/imdsv2.go
@@ -0,0 +1,676 @@
+package cluster
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log"
+	"strings"
+	"time"
+
+	cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
+	machinev1 "github.com/openshift/api/machine/v1"
+	machinev1beta1 "github.com/openshift/api/machine/v1beta1"
+	hivev1 "github.com/openshift/hive/apis/hive/v1"
+	awshivev1 "github.com/openshift/hive/apis/hive/v1/aws"
+	"github.com/openshift/osdctl/pkg/infra"
+	"github.com/openshift/osdctl/pkg/printer"
+	"github.com/openshift/osdctl/pkg/utils"
+	"github.com/spf13/cobra"
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/runtime"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+const (
+	// ControlPlaneMachineSet location
+	cpmsNamespace = "openshift-machine-api"
+	cpmsName      = "cluster"
+
+	// Timeouts and intervals
+	imdsv2PollInterval       = 30 * time.Second
+	imdsv2RolloutPollTimeout = 120 * time.Minute
+	imdsv2MachineWaitTimeout = 15 * time.Minute
+	imdsv2NodeWaitTimeout    = 15 * time.Minute
+	imdsv2COWaitTimeout      = 15 * time.Minute
+
+	// IMDS authentication modes
+	imdsv2Required = "Required"
+	imdsv2Optional = "Optional"
+
+	// Hive MachinePool override annotation
+	hiveOverrideAnnotation = "hive.openshift.io/override-machinepool-platform"
+)
+
+// imdsv2Options holds the flag values and clients for the imdsv2 command.
+type imdsv2Options struct {
+	clusterID string
+	cluster   *cmv1.Cluster
+	reason    string
+	nodeRoles string // "all" (default), "master", "infra", "workers"
+
+	client          client.Client
+	clientAdmin     client.Client
+	hiveClient      client.Client // only set when infra or worker pools are in scope (see init)
+	hiveAdminClient client.Client // only set when infra or worker pools are in scope (see init)
+}
+
+// newCmdIMDSv2 builds the "imdsv2" cobra command and binds its flags.
+func newCmdIMDSv2() *cobra.Command {
+	ops := &imdsv2Options{}
+	cmd := &cobra.Command{
+		Use:   "imdsv2",
+		Short: "Migrate cluster nodes to enforce IMDSv2 (Instance Metadata Service v2)",
+		Long: `Migrate ROSA Classic cluster nodes to enforce IMDSv2.
+
+This automates the SOP for migrating machines to IMDSv2 by:
+- Patching Hive MachinePools to require IMDSv2
+- Replacing infra nodes (one at a time)
+- Updating ControlPlaneMachineSet for automatic master node rollout
+- Validating all nodes/machines are using IMDSv2
+
+Pre-flight checks verify cluster health before making changes.`,
+		// NOTE(review): with the default --nodes=all, run also executes the
+		// worker MachinePool step, which the "(infra + masters)" label below
+		// does not mention. If this text is corrected, the generated docs must
+		// be regenerated as well.
+		Example: ` # Migrate all nodes (infra + masters)
+ osdctl cluster imdsv2 -C ${CLUSTER_ID} --reason "JIRA-12345"
+
+ # Migrate only infra nodes
+ osdctl cluster imdsv2 -C ${CLUSTER_ID} --reason "CASE-67890" --nodes infra
+
+ # Migrate only master nodes
+ osdctl cluster imdsv2 -C ${CLUSTER_ID} --reason "JIRA-12345" --nodes master`,
+		Args:              cobra.NoArgs,
+		DisableAutoGenTag: true,
+		SilenceUsage:      true, // Don't show usage on errors
+		RunE: func(cmd *cobra.Command, args []string) error {
+			// Use the context the command was executed with, so cancellation
+			// set up by the caller reaches the polling loops. Fall back to a
+			// background context when none was provided.
+			ctx := cmd.Context()
+			if ctx == nil {
+				ctx = context.Background()
+			}
+			return ops.run(ctx)
+		},
+	}
+
+	cmd.Flags().StringVarP(&ops.clusterID, "cluster-id", "C", "", "The internal/external ID of the cluster")
+	cmd.Flags().StringVar(&ops.reason, "reason", "", "Reason for elevation (OHSS/PD/JIRA ticket)")
+	cmd.Flags().StringVar(&ops.nodeRoles, "nodes", "all", "Node roles to migrate: all, master, infra, workers")
+
+	// MarkFlagRequired only fails for unknown flag names; both exist above.
+	_ = cmd.MarkFlagRequired("cluster-id")
+	_ = cmd.MarkFlagRequired("reason")
+
+	return cmd
+}
+
+// validate checks the cluster key and the --nodes value before any API call.
+func (o *imdsv2Options) validate() error {
+	if err := utils.IsValidClusterKey(o.clusterID); err != nil {
+		return err
+	}
+
+	switch o.nodeRoles {
+	case "all", "master", "infra", "workers":
+		return nil
+	default:
+		return fmt.Errorf("invalid nodes: %s (must be 'all', 'master', 'infra', or 'workers')", o.nodeRoles)
+	}
+}
+
+// init resolves the cluster through OCM, rejects non-AWS and HCP clusters, and
+// creates the Kubernetes clients the selected node roles need.
+func (o *imdsv2Options) init() error {
+	connection, err := utils.CreateConnection()
+	if err != nil {
+		return err
+	}
+	defer connection.Close()
+
+	cluster, err := utils.GetCluster(connection, o.clusterID)
+	if err != nil {
+		return err
+	}
+	o.cluster = cluster
+	// From here on, use the ID reported by OCM instead of the user-supplied key.
+	o.clusterID = cluster.ID()
+
+	// Validate cluster is AWS Classic (not GCP/Azure or HCP)
+	if err := ValidateAWSClassicCluster(cluster); err != nil {
+		return err
+	}
+
+	// Set up standard cluster clients
+	clients, err := SetupClusterClients(o.clusterID, o.reason, "Migrating to IMDSv2")
+	if err != nil {
+		return err
+	}
+	o.client = clients.Client
+	o.clientAdmin = clients.ClientAdmin
+
+	// Set up Hive clients for infra/worker node replacement via MachinePools.
+	// A master-only run never touches Hive, so the clients are skipped then.
+	if o.nodeRoles == "all" || o.nodeRoles == "infra" || o.nodeRoles == "workers" {
+		hc, hac, err := SetupHiveClients(o.clusterID, o.reason, "Migrating to IMDSv2")
+		if err != nil {
+			return err
+		}
+		o.hiveClient = hc
+		o.hiveAdminClient = hac
+	}
+
+	return nil
+}
+
+// run executes the migration: validate flags, initialize clients, run the
+// pre-flight checks, migrate the selected roles, then validate the result.
+// Step errors are wrapped with %w so callers can still inspect the cause.
+func (o *imdsv2Options) run(ctx context.Context) error {
+	// Validate command-line arguments
+	if err := o.validate(); err != nil {
+		return err
+	}
+
+	// Initialize OCM connection and Kubernetes clients
+	if err := o.init(); err != nil {
+		return err
+	}
+
+	fmt.Printf("Cluster: %s (%s)\n", o.cluster.Name(), o.clusterID)
+	fmt.Printf("Node Roles: %s\n", o.nodeRoles)
+	fmt.Printf("Reason: %s\n\n", o.reason)
+
+	// Verify cluster health before making changes
+	if err := o.preFlightChecks(ctx); err != nil {
+		return fmt.Errorf("pre-flight checks failed: %w", err)
+	}
+
+	// Determine which node types to migrate based on --nodes flag
+	doInfra := o.nodeRoles == "all" || o.nodeRoles == "infra"
+	doMasters := o.nodeRoles == "all" || o.nodeRoles == "master"
+	doWorkers := o.nodeRoles == "all" || o.nodeRoles == "workers"
+
+	// Track what actually changed
+	var infraChanged, cpmsChanged, workersChanged bool
+
+	// Step 1: Replace infra machines using MachinePool dance
+	if doInfra {
+		changed, err := o.migrateInfraToIMDSv2(ctx)
+		if err != nil {
+			return fmt.Errorf("infra migration failed: %w", err)
+		}
+		infraChanged = changed
+	}
+
+	// Step 2: Update ControlPlaneMachineSet to trigger master node rollout
+	if doMasters {
+		changed, err := o.updateCPMSForIMDSv2(ctx)
+		if err != nil {
+			return fmt.Errorf("CPMS update failed: %w", err)
+		}
+		cpmsChanged = changed
+	}
+
+	// Step 3: List and patch worker MachinePools (if requested)
+	if doWorkers {
+		changed, err := o.migrateWorkersToIMDSv2(ctx)
+		if err != nil {
+			return fmt.Errorf("worker migration failed: %w", err)
+		}
+		workersChanged = changed
+	}
+
+	// Step 4: Verify all nodes and machines are configured correctly
+	if err := o.validateIMDSv2(ctx); err != nil {
+		return fmt.Errorf("validation failed: %w", err)
+	}
+
+	// Only show success if we actually made changes
+	if !infraChanged && !cpmsChanged && !workersChanged {
+		printer.PrintlnGreen("\n✓ All components already configured for IMDSv2!")
+		return nil
+	}
+
+	printer.PrintlnGreen("\n✓ IMDSv2 migration completed successfully!")
+	if infraChanged {
+		fmt.Println(" - Infra nodes migrated to IMDSv2")
+	}
+	if cpmsChanged {
+		fmt.Println(" - Master nodes migrated to IMDSv2")
+	}
+	if workersChanged {
+		// The worker step only patches MachinePools; existing worker machines
+		// keep their old settings until the customer replaces them.
+		fmt.Println(" - Worker MachinePools patched to require IMDSv2 (machine replacement is a customer action)")
+	}
+	return nil
+}
+
+// preFlightChecks verifies cluster health before making changes.
+func (o *imdsv2Options) preFlightChecks(ctx context.Context) error {
+	fmt.Println("Running pre-flight checks...")
+
+	// Verify all ClusterOperators are Available and not Degraded
+	// (This implicitly verifies etcd health via the etcd operator)
+	if err := CheckClusterOperators(ctx, o.client); err != nil {
+		return err
+	}
+
+	// Verify all 3 master nodes are Ready. This runs for every --nodes value,
+	// not only when masters are being migrated.
+	masterNodes := &corev1.NodeList{}
+	if err := o.client.List(ctx, masterNodes, client.MatchingLabels{"node-role.kubernetes.io/master": ""}); err != nil {
+		return fmt.Errorf("failed to list master nodes: %w", err)
+	}
+	readyMasters := CountReadyNodes(masterNodes)
+	if readyMasters != 3 {
+		return fmt.Errorf("expected 3 ready master nodes, found %d", readyMasters)
+	}
+	fmt.Printf(" Master nodes: %d/3 Ready\n", readyMasters)
+
+	// Verify all infra nodes are Ready (if migrating infra)
+	if o.nodeRoles == "all" || o.nodeRoles == "infra" {
+		infraNodes := &corev1.NodeList{}
+		if err := o.client.List(ctx, infraNodes, client.MatchingLabels{"node-role.kubernetes.io/infra": ""}); err != nil {
+			return fmt.Errorf("failed to list infra nodes: %w", err)
+		}
+		readyInfra := CountReadyNodes(infraNodes)
+		totalInfra := len(infraNodes.Items)
+		if totalInfra == 0 {
+			return fmt.Errorf("no infra nodes found")
+		}
+		if readyInfra != totalInfra {
+			return fmt.Errorf("not all infra nodes are ready (%d/%d)", readyInfra, totalInfra)
+		}
+		fmt.Printf(" Infra nodes: %d/%d Ready\n", readyInfra, totalInfra)
+	}
+
+	// Verify CPMS is Active (only needed if migrating masters)
+	if o.nodeRoles == "all" || o.nodeRoles == "master" {
+		cpms := &machinev1.ControlPlaneMachineSet{}
+		if err := o.client.Get(ctx, client.ObjectKey{Namespace: cpmsNamespace, Name: cpmsName}, cpms); err != nil {
+			return fmt.Errorf("failed to get CPMS: %w", err)
+		}
+		if cpms.Spec.State != machinev1.ControlPlaneMachineSetStateActive {
+			return fmt.Errorf("CPMS is not Active (state: %s). Cannot proceed with control plane changes", cpms.Spec.State)
+		}
+		// Don't print CPMS status - master nodes Ready is sufficient
+	}
+
+	printer.PrintlnGreen(" All pre-flight checks passed!")
+	fmt.Println()
+	return nil
+}
+
+// migrateInfraToIMDSv2 migrates infra nodes to IMDSv2 using the MachinePool dance.
+// Returns true if changes were made, false if already configured.
+// Declining the confirmation prompt returns an "aborted by user" error.
+func (o *imdsv2Options) migrateInfraToIMDSv2(ctx context.Context) (bool, error) {
+	printer.PrintlnGreen("\n=== Migrating Infra Nodes to IMDSv2 ===")
+
+	// Get the infra MachinePool from Hive
+	infraMp, err := infra.GetInfraMachinePool(ctx, o.hiveClient, o.clusterID)
+	if err != nil {
+		return false, fmt.Errorf("failed to get infra MachinePool: %w", err)
+	}
+
+	// Safety check: refuse to run the dance against anything but the pool
+	// named "infra", whatever the lookup returned.
+	if infraMp.Spec.Name != "infra" {
+		return false, fmt.Errorf("unexpected MachinePool name: %s (expected: infra)", infraMp.Spec.Name)
+	}
+
+	// Check if already configured for IMDSv2
+	currentAuth := "Not configured"
+	if infraMp.Spec.Platform.AWS != nil && infraMp.Spec.Platform.AWS.EC2Metadata != nil {
+		currentAuth = infraMp.Spec.Platform.AWS.EC2Metadata.Authentication
+	}
+
+	if currentAuth == imdsv2Required {
+		fmt.Println("Infra nodes already configured for IMDSv2 - skipping")
+		return false, nil
+	}
+
+	// Display current state. A MachinePool without an explicit replica count
+	// is reported as 2.
+	replicas := int64(2)
+	if infraMp.Spec.Replicas != nil {
+		replicas = *infraMp.Spec.Replicas
+	}
+	fmt.Printf("Current IMDS authentication: %s\n", currentAuth)
+	fmt.Printf("Infra node count: %d\n", replicas)
+
+	// Replacing infra nodes is disruptive: require explicit confirmation.
+	fmt.Printf("\nThis will replace all %d infra nodes using the MachinePool dance.\n", replicas)
+	fmt.Println("During the process, there will temporarily be 2x infra nodes for high availability.")
+	estimatedMinutes := int(replicas) * 10 // rough estimate
+	fmt.Printf("Estimated time: ~%d minutes\n", estimatedMinutes)
+	if !utils.ConfirmPrompt() {
+		return false, errors.New("aborted by user")
+	}
+
+	// Clone the MachinePool and configure the clone for IMDSv2. The original
+	// MachinePool is not modified, so no override annotation is set on it.
+	newMp, err := infra.CloneMachinePool(infraMp, func(mp *hivev1.MachinePool) error {
+		if mp.Spec.Platform.AWS == nil {
+			mp.Spec.Platform.AWS = &awshivev1.MachinePoolPlatform{}
+		}
+		if mp.Spec.Platform.AWS.EC2Metadata == nil {
+			mp.Spec.Platform.AWS.EC2Metadata = &awshivev1.EC2Metadata{}
+		}
+		mp.Spec.Platform.AWS.EC2Metadata.Authentication = imdsv2Required
+		return nil
+	})
+	if err != nil {
+		return false, fmt.Errorf("failed to clone MachinePool: %w", err)
+	}
+
+	// Set up clients for the machinepool dance
+	danceClients := infra.DanceClients{
+		ClusterClient: o.client,
+		HiveClient:    o.hiveClient,
+		HiveAdmin:     o.hiveAdminClient,
+	}
+
+	// Hand the old and new MachinePools to the dance, which performs the swap.
+	fmt.Println("\nStarting MachinePool dance to replace infra nodes...")
+	if err := infra.RunMachinePoolDance(ctx, danceClients, infraMp, newMp, nil); err != nil {
+		return false, fmt.Errorf("MachinePool dance failed: %w", err)
+	}
+
+	// Wait for cluster operators to stabilize after replacement
+	fmt.Println("\nWaiting for cluster operators to stabilize...")
+	if err := WaitForClusterOperatorsHealthy(ctx, o.client, imdsv2COWaitTimeout); err != nil {
+		return false, err
+	}
+
+	printer.PrintlnGreen("Infra nodes migrated to IMDSv2!")
+	return true, nil
+}
+
+// migrateWorkersToIMDSv2 lists worker MachinePools that need IMDSv2 and asks the user to confirm patching them.
+// Returns true if any changes were made, false if already configured or user skipped.
+func (o *imdsv2Options) migrateWorkersToIMDSv2(ctx context.Context) (bool, error) {
+	printer.PrintlnGreen("\n=== Worker Node MachinePools ===")
+
+	// Get the Hive namespace for this cluster
+	hiveNamespace, err := GetHiveNamespace(o.clusterID)
+	if err != nil {
+		return false, err
+	}
+
+	// Retrieve all MachinePools for this cluster
+	mpList := &hivev1.MachinePoolList{}
+	if err := o.hiveClient.List(ctx, mpList, &client.ListOptions{Namespace: hiveNamespace}); err != nil {
+		return false, fmt.Errorf("failed to list MachinePools: %w", err)
+	}
+
+	// workerMPInfo is the summary row shown to the user for each pool that
+	// still needs patching.
+	type workerMPInfo struct {
+		name         string
+		replicas     int64
+		instanceType string
+		currentIMDS  string
+	}
+	var workersNeedingUpdate []workerMPInfo
+
+	for i := range mpList.Items {
+		mp := &mpList.Items[i]
+
+		// Skip master and infra pools (exclusion-based approach for all worker pools)
+		// This ensures we process all worker pools including custom ones like "worker-2", "gpu-workers", etc.
+		if mp.Spec.Name == "master" || mp.Spec.Name == "infra" {
+			continue
+		}
+
+		// Check current IMDSv2 configuration
+		currentAuth := "Not configured"
+		instanceType := "unknown"
+		if aws := mp.Spec.Platform.AWS; aws != nil {
+			instanceType = aws.InstanceType
+			if aws.EC2Metadata != nil {
+				currentAuth = aws.EC2Metadata.Authentication
+			}
+		}
+		if currentAuth == imdsv2Required {
+			continue
+		}
+
+		replicas := int64(0)
+		if mp.Spec.Replicas != nil {
+			replicas = *mp.Spec.Replicas
+		}
+
+		workersNeedingUpdate = append(workersNeedingUpdate, workerMPInfo{
+			name:         mp.Name,
+			replicas:     replicas,
+			instanceType: instanceType,
+			currentIMDS:  currentAuth,
+		})
+	}
+
+	if len(workersNeedingUpdate) == 0 {
+		fmt.Println("All worker MachinePools already configured for IMDSv2")
+		return false, nil
+	}
+
+	// Display worker MachinePools that need IMDSv2
+	fmt.Println("\nWorker MachinePools requiring IMDSv2 configuration:")
+	fmt.Printf("%-20s %-10s %-15s %-20s\n", "NAME", "REPLICAS", "INSTANCE TYPE", "CURRENT IMDS")
+	fmt.Println(strings.Repeat("-", 70))
+	for _, mp := range workersNeedingUpdate {
+		fmt.Printf("%-20s %-10d %-15s %-20s\n", mp.name, mp.replicas, mp.instanceType, mp.currentIMDS)
+	}
+	fmt.Println()
+
+	// Ask for confirmation
+	fmt.Println("NOTE: Worker node replacement must be performed by the customer.")
+	fmt.Println("This will only PATCH the worker MachinePools to require IMDSv2.")
+	fmt.Println("The customer must then delete worker machines to trigger replacement.")
+	fmt.Printf("\nPatch %d worker MachinePool(s) to require IMDSv2?\n", len(workersNeedingUpdate))
+	if !utils.ConfirmPrompt() {
+		fmt.Println("Skipped - worker MachinePools not patched")
+		return false, nil
+	}
+
+	// Patch each worker MachinePool. The first failure aborts the loop; pools
+	// patched before that point stay patched.
+	for _, mpInfo := range workersNeedingUpdate {
+		fmt.Printf("\nPatching machinepool/%s...\n", mpInfo.name)
+
+		// Re-read the MachinePool so the patch is computed against its
+		// current state rather than the earlier list result.
+		mp := &hivev1.MachinePool{}
+		if err := o.hiveClient.Get(ctx, client.ObjectKey{Namespace: hiveNamespace, Name: mpInfo.name}, mp); err != nil {
+			return false, fmt.Errorf("failed to get MachinePool %s: %w", mpInfo.name, err)
+		}
+
+		patch := client.MergeFrom(mp.DeepCopy())
+
+		// Add override annotation to allow platform spec changes
+		// NOTE: This is needed for in-place patching (unlike infra which uses MachinePool dance)
+		// NOTE(review): the annotation is never removed afterwards - confirm
+		// that leaving it on the MachinePool is intended.
+		if mp.Annotations == nil {
+			mp.Annotations = make(map[string]string)
+		}
+		mp.Annotations[hiveOverrideAnnotation] = "true"
+
+		// Configure IMDSv2 authentication requirement
+		if mp.Spec.Platform.AWS == nil {
+			mp.Spec.Platform.AWS = &awshivev1.MachinePoolPlatform{}
+		}
+		if mp.Spec.Platform.AWS.EC2Metadata == nil {
+			mp.Spec.Platform.AWS.EC2Metadata = &awshivev1.EC2Metadata{}
+		}
+		mp.Spec.Platform.AWS.EC2Metadata.Authentication = imdsv2Required
+
+		if err := o.hiveAdminClient.Patch(ctx, mp, patch); err != nil {
+			return false, fmt.Errorf("failed to patch MachinePool %s: %w", mpInfo.name, err)
+		}
+
+		fmt.Printf(" ✓ Patched machinepool/%s\n", mpInfo.name)
+	}
+
+	// Reaching this point means every listed pool was patched, and the list
+	// was non-empty.
+	printer.PrintlnGreen("\nWorker MachinePools patched successfully!")
+	fmt.Println("\n=== Next Steps (Customer Action Required) ===")
+	fmt.Println("Worker nodes must be replaced by the customer using one of these methods:")
+	fmt.Println(" 1. Delete worker machines one at a time (MachineSet will replace with IMDSv2)")
+	fmt.Println(" 2. Scale down/up worker MachineSets")
+	fmt.Println(" 3. Use the MachinePool dance pattern (similar to infra node replacement)")
+
+	return true, nil
+}
+
+// updateCPMSForIMDSv2 patches the ControlPlaneMachineSet to trigger a rolling replacement.
+// Returns true if changes were made, false if already configured.
+// Declining the confirmation prompt returns an "aborted by user" error.
+func (o *imdsv2Options) updateCPMSForIMDSv2(ctx context.Context) (bool, error) {
+	printer.PrintlnGreen("\n=== Updating ControlPlaneMachineSet for IMDSv2 ===")
+
+	// Retrieve the ControlPlaneMachineSet
+	cpms := &machinev1.ControlPlaneMachineSet{}
+	if err := o.client.Get(ctx, client.ObjectKey{Namespace: cpmsNamespace, Name: cpmsName}, cpms); err != nil {
+		return false, fmt.Errorf("failed to get CPMS: %w", err)
+	}
+
+	// Both the machine template and its provider spec value are pointers;
+	// fail with a clear error instead of panicking when either is absent.
+	tmpl := cpms.Spec.Template.OpenShiftMachineV1Beta1Machine
+	if tmpl == nil || tmpl.Spec.ProviderSpec.Value == nil {
+		return false, errors.New("CPMS template has no machine provider spec")
+	}
+
+	// Parse the AWS provider spec from CPMS template
+	awsSpec := &machinev1beta1.AWSMachineProviderConfig{}
+	if err := json.Unmarshal(tmpl.Spec.ProviderSpec.Value.Raw, awsSpec); err != nil {
+		return false, fmt.Errorf("failed to unmarshal CPMS provider spec: %w", err)
+	}
+
+	// Skip if already configured for IMDSv2
+	if awsSpec.MetadataServiceOptions.Authentication == imdsv2Required {
+		fmt.Println("Control plane already configured for IMDSv2 - skipping")
+		return false, nil
+	}
+
+	fmt.Printf("Current IMDS authentication: %s\n", awsSpec.MetadataServiceOptions.Authentication)
+
+	// Confirm action with user (destructive operation)
+	fmt.Printf("\nThis will replace all 3 control plane nodes one at a time (~35-45 min).\n")
+	if !utils.ConfirmPrompt() {
+		return false, errors.New("aborted by user")
+	}
+
+	// Announce the patch only after the user has confirmed it.
+	fmt.Println("Patching CPMS to enforce IMDSv2...")
+
+	// Update AWS spec to require IMDSv2
+	awsSpec.MetadataServiceOptions.Authentication = imdsv2Required
+
+	// Serialize and apply the updated spec
+	// NOTE(review): the provider spec is round-tripped through the typed
+	// AWSMachineProviderConfig; presumably any field the vendored type does
+	// not know is dropped from the re-serialized value - confirm.
+	rawBytes, err := json.Marshal(awsSpec)
+	if err != nil {
+		return false, fmt.Errorf("failed to marshal updated provider spec: %w", err)
+	}
+
+	patch := client.MergeFrom(cpms.DeepCopy())
+	tmpl.Spec.ProviderSpec.Value = &runtime.RawExtension{Raw: rawBytes}
+
+	// Apply the patch
+	if err := o.clientAdmin.Patch(ctx, cpms, patch); err != nil {
+		return false, fmt.Errorf("failed to patch CPMS: %w", err)
+	}
+
+	printer.PrintlnGreen("CPMS patched successfully. Rolling replacement in progress...")
+	// NOTE(review): this estimate (60-120 minutes) disagrees with the
+	// "~35-45 min" shown in the confirmation prompt above - confirm which
+	// figure is right.
+	fmt.Println("Monitoring rollout (this may take 60-120 minutes)...")
+
+	// Wait for all master nodes to be replaced
+	if err := MonitorCPMSRollout(ctx, o.client, cpmsNamespace, cpmsName, imdsv2RolloutPollTimeout); err != nil {
+		return false, err
+	}
+
+	printer.PrintlnGreen("Control plane IMDSv2 migration complete!")
+	return true, nil
+}
+
+// validateIMDSv2 verifies the migration was successful.
+func (o *imdsv2Options) validateIMDSv2(ctx context.Context) error {
+	printer.PrintlnGreen("\n=== Validating IMDSv2 Migration ===")
+
+	// Verify all nodes are in Ready state (except those being deleted)
+	// Retry a few times to allow nodes to stabilize after replacement
+	fmt.Println("Checking node status...")
+
+	const (
+		maxRetries = 5
+		retryDelay = 30 * time.Second
+	)
+
+	// nodeReady reports whether the node has a NodeReady condition set to True.
+	nodeReady := func(node *corev1.Node) bool {
+		for _, cond := range node.Status.Conditions {
+			if cond.Type == corev1.NodeReady && cond.Status == corev1.ConditionTrue {
+				return true
+			}
+		}
+		return false
+	}
+
+	for attempt := 1; attempt <= maxRetries; attempt++ {
+		nodes := &corev1.NodeList{}
+		if err := o.client.List(ctx, nodes); err != nil {
+			return fmt.Errorf("failed to list nodes: %w", err)
+		}
+
+		var notReadyNodes, deletingNodes, unschedulableNodes []string
+		for i := range nodes.Items {
+			node := &nodes.Items[i]
+			switch {
+			case node.DeletionTimestamp != nil:
+				// Skip nodes that are being deleted (have DeletionTimestamp set)
+				deletingNodes = append(deletingNodes, node.Name)
+			case node.Spec.Unschedulable:
+				// Skip nodes that are cordoned/unschedulable (being drained)
+				unschedulableNodes = append(unschedulableNodes, node.Name)
+			case !nodeReady(node):
+				notReadyNodes = append(notReadyNodes, node.Name)
+			}
+		}
+
+		if len(deletingNodes) > 0 {
+			fmt.Printf(" ⏳ Nodes being deleted: %s\n", strings.Join(deletingNodes, ", "))
+		}
+		if len(unschedulableNodes) > 0 {
+			fmt.Printf(" ⏳ Nodes being drained: %s\n", strings.Join(unschedulableNodes, ", "))
+		}
+
+		// If all active nodes are ready, we're good
+		if len(notReadyNodes) == 0 {
+			activeNodes := len(nodes.Items) - len(deletingNodes) - len(unschedulableNodes)
+			fmt.Printf(" ✓ All %d active nodes are Ready\n", activeNodes)
+			break
+		}
+
+		// Final attempt failed
+		if attempt == maxRetries {
+			return fmt.Errorf("nodes not Ready after %d attempts: %s", maxRetries, strings.Join(notReadyNodes, ", "))
+		}
+
+		// NotReady nodes remain and retries are left: wait and retry
+		fmt.Printf(" ⏳ Waiting for %d nodes to become Ready (attempt %d/%d): %s\n",
+			len(notReadyNodes), attempt, maxRetries, strings.Join(notReadyNodes, ", "))
+
+		// Context-aware sleep to handle cancellation (SIGINT, timeout, etc.)
+		select {
+		case <-ctx.Done():
+			return fmt.Errorf("context cancelled while waiting for nodes: %w", ctx.Err())
+		case <-time.After(retryDelay):
+			// Continue to next retry
+		}
+	}
+
+	// Verify all machines have IMDSv2 configured in their spec
+	fmt.Println("Checking machine configurations...")
+	machines := &machinev1beta1.MachineList{}
+	if err := o.client.List(ctx, machines, &client.ListOptions{Namespace: cpmsNamespace}); err != nil {
+		return fmt.Errorf("failed to list machines: %w", err)
+	}
+
+	// inspected counts only machines whose provider spec was actually parsed,
+	// so skipped machines are never reported as "configured for IMDSv2".
+	inspected := 0
+	var nonIMDSv2Machines []string
+	for i := range machines.Items {
+		machine := &machines.Items[i]
+		if machine.Spec.ProviderSpec.Value == nil {
+			continue
+		}
+
+		awsSpec := &machinev1beta1.AWSMachineProviderConfig{}
+		if err := json.Unmarshal(machine.Spec.ProviderSpec.Value.Raw, awsSpec); err != nil {
+			log.Printf("Warning: failed to unmarshal provider spec for machine %s: %v", machine.Name, err)
+			continue
+		}
+		inspected++
+
+		if awsSpec.MetadataServiceOptions.Authentication != imdsv2Required {
+			nonIMDSv2Machines = append(nonIMDSv2Machines, machine.Name)
+		}
+	}
+
+	// Machines still on the old setting are reported as a warning only; they
+	// do not fail validation.
+	if len(nonIMDSv2Machines) > 0 {
+		fmt.Printf(" ⚠ Machines not configured for IMDSv2: %s\n", strings.Join(nonIMDSv2Machines, ", "))
+		fmt.Println(" (This is expected for worker nodes - customer must replace them)")
+	} else {
+		fmt.Printf(" ✓ All %d machines configured for IMDSv2\n", inspected)
+	}
+	if skipped := len(machines.Items) - inspected; skipped > 0 {
+		fmt.Printf(" ⚠ %d machine(s) could not be inspected (missing or unreadable provider spec)\n", skipped)
+	}
+
+	printer.PrintlnGreen("\nValidation complete!")
+	return nil
+}
diff --git a/cmd/cluster/imdsv2_test.go b/cmd/cluster/imdsv2_test.go
new file mode 100644
index 000000000..851d92cd6
--- /dev/null
+++ b/cmd/cluster/imdsv2_test.go
@@ -0,0 +1,574 @@
+package cluster
+
+import (
+	"context"
+	"testing"
+
+	cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
+	machinev1 
"github.com/openshift/api/machine/v1" + machinev1beta1 "github.com/openshift/api/machine/v1beta1" + hivev1 "github.com/openshift/hive/apis/hive/v1" + awshivev1 "github.com/openshift/hive/apis/hive/v1/aws" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func TestValidate(t *testing.T) { + tests := []struct { + name string + clusterID string + nodeRoles string + wantErr bool + errMsg string + }{ + { + name: "valid cluster ID and node roles", + clusterID: "test-cluster-123", + nodeRoles: "all", + wantErr: false, + }, + { + name: "valid infra only", + clusterID: "test-cluster-123", + nodeRoles: "infra", + wantErr: false, + }, + { + name: "valid master only", + clusterID: "test-cluster-123", + nodeRoles: "master", + wantErr: false, + }, + { + name: "valid workers only", + clusterID: "test-cluster-123", + nodeRoles: "workers", + wantErr: false, + }, + { + name: "invalid node role", + clusterID: "test-cluster-123", + nodeRoles: "invalid", + wantErr: true, + errMsg: "invalid nodes: invalid", + }, + { + name: "empty cluster ID", + clusterID: "", + nodeRoles: "all", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ops := &imdsv2Options{ + clusterID: tt.clusterID, + nodeRoles: tt.nodeRoles, + } + err := ops.validate() + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateIMDSv2_AllNodesReady(t *testing.T) { + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = machinev1beta1.AddToScheme(scheme) + + // Create test nodes - all ready + nodes := &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{Name: "master-1"}, + Status: corev1.NodeStatus{ + Conditions: 
[]corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "master-2"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "infra-1"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithLists(nodes). + Build() + + ops := &imdsv2Options{ + client: fakeClient, + } + + err := ops.validateIMDSv2(context.Background()) + assert.NoError(t, err) +} + +func TestValidateIMDSv2_SkipDeletingNodes(t *testing.T) { + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = machinev1beta1.AddToScheme(scheme) + + now := metav1.Now() + + // Create test nodes - one being deleted + // Note: fake client requires finalizers when DeletionTimestamp is set + nodes := &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{Name: "master-1"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "master-old", + DeletionTimestamp: &now, + Finalizers: []string{"test-finalizer"}, + }, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionFalse}, + }, + }, + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithLists(nodes). 
+ Build() + + ops := &imdsv2Options{ + client: fakeClient, + } + + err := ops.validateIMDSv2(context.Background()) + assert.NoError(t, err, "Should skip nodes with DeletionTimestamp") +} + +func TestValidateIMDSv2_SkipUnschedulableNodes(t *testing.T) { + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = machinev1beta1.AddToScheme(scheme) + + // Create test nodes - one cordoned/unschedulable + nodes := &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{Name: "master-1"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "master-draining"}, + Spec: corev1.NodeSpec{ + Unschedulable: true, + }, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionFalse}, + }, + }, + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithLists(nodes). 
+ Build() + + ops := &imdsv2Options{ + client: fakeClient, + } + + err := ops.validateIMDSv2(context.Background()) + assert.NoError(t, err, "Should skip unschedulable nodes") +} + +func TestCheckIMDSv2Configuration(t *testing.T) { + tests := []struct { + name string + machinePool *hivev1.MachinePool + expectedAuth string + }{ + { + name: "IMDSv2 required", + machinePool: &hivev1.MachinePool{ + Spec: hivev1.MachinePoolSpec{ + Platform: hivev1.MachinePoolPlatform{ + AWS: &awshivev1.MachinePoolPlatform{ + EC2Metadata: &awshivev1.EC2Metadata{ + Authentication: imdsv2Required, + }, + }, + }, + }, + }, + expectedAuth: imdsv2Required, + }, + { + name: "IMDSv2 optional", + machinePool: &hivev1.MachinePool{ + Spec: hivev1.MachinePoolSpec{ + Platform: hivev1.MachinePoolPlatform{ + AWS: &awshivev1.MachinePoolPlatform{ + EC2Metadata: &awshivev1.EC2Metadata{ + Authentication: imdsv2Optional, + }, + }, + }, + }, + }, + expectedAuth: imdsv2Optional, + }, + { + name: "No EC2Metadata configured", + machinePool: &hivev1.MachinePool{ + Spec: hivev1.MachinePoolSpec{ + Platform: hivev1.MachinePoolPlatform{ + AWS: &awshivev1.MachinePoolPlatform{}, + }, + }, + }, + expectedAuth: "Not configured", + }, + { + name: "No AWS platform", + machinePool: &hivev1.MachinePool{ + Spec: hivev1.MachinePoolSpec{ + Platform: hivev1.MachinePoolPlatform{}, + }, + }, + expectedAuth: "Not configured", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + currentAuth := "Not configured" + if tt.machinePool.Spec.Platform.AWS != nil && + tt.machinePool.Spec.Platform.AWS.EC2Metadata != nil { + currentAuth = tt.machinePool.Spec.Platform.AWS.EC2Metadata.Authentication + } + assert.Equal(t, tt.expectedAuth, currentAuth) + }) + } +} + +func TestMachinePoolNameValidation(t *testing.T) { + tests := []struct { + name string + mpName string + validList map[string]bool + wantValid bool + }{ + { + name: "valid infra", + mpName: "infra", + validList: map[string]bool{"infra": true}, + wantValid: 
true, + }, + { + name: "valid worker", + mpName: "worker", + validList: map[string]bool{"worker": true}, + wantValid: true, + }, + { + name: "invalid name", + mpName: "unexpected-pool", + validList: map[string]bool{"infra": true, "worker": true}, + wantValid: false, + }, + { + name: "master not in worker list", + mpName: "master", + validList: map[string]bool{"worker": true}, + wantValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isValid := tt.validList[tt.mpName] + assert.Equal(t, tt.wantValid, isValid) + }) + } +} + +func TestCPMSIMDSv2Configuration(t *testing.T) { + scheme := runtime.NewScheme() + _ = machinev1.AddToScheme(scheme) + + tests := []struct { + name string + authentication string + expectedResult bool + }{ + { + name: "already IMDSv2", + authentication: imdsv2Required, + expectedResult: false, // no changes needed + }, + { + name: "needs IMDSv2", + authentication: imdsv2Optional, + expectedResult: true, // changes needed + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + needsUpdate := tt.authentication != imdsv2Required + assert.Equal(t, tt.expectedResult, needsUpdate) + }) + } +} + +func TestPreFlightChecks_MasterNodesCount(t *testing.T) { + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + _ = machinev1.AddToScheme(scheme) + + tests := []struct { + name string + masterNode int + wantErr bool + }{ + { + name: "exactly 3 masters", + masterNode: 3, + wantErr: false, + }, + { + name: "less than 3 masters", + masterNode: 2, + wantErr: true, + }, + { + name: "more than 3 masters", + masterNode: 4, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var masters []corev1.Node + for i := 0; i < tt.masterNode; i++ { + masters = append(masters, corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "master-" + string(rune(i)), + Labels: map[string]string{"node-role.kubernetes.io/master": ""}, + }, + Status: corev1.NodeStatus{ + 
Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }) + } + + readyCount := 0 + for _, node := range masters { + for _, cond := range node.Status.Conditions { + if cond.Type == corev1.NodeReady && cond.Status == corev1.ConditionTrue { + readyCount++ + break + } + } + } + + if tt.wantErr { + assert.NotEqual(t, 3, readyCount) + } else { + assert.Equal(t, 3, readyCount) + } + }) + } +} + +func TestNodeRoleFilter(t *testing.T) { + tests := []struct { + name string + nodeRoles string + mpSpecName string + shouldFilter bool + }{ + { + name: "master role with master MP", + nodeRoles: "master", + mpSpecName: "master", + shouldFilter: false, + }, + { + name: "master role with infra MP", + nodeRoles: "master", + mpSpecName: "infra", + shouldFilter: true, + }, + { + name: "infra role with infra MP", + nodeRoles: "infra", + mpSpecName: "infra", + shouldFilter: false, + }, + { + name: "infra role with worker MP", + nodeRoles: "infra", + mpSpecName: "worker", + shouldFilter: true, + }, + { + name: "workers role with worker MP", + nodeRoles: "workers", + mpSpecName: "worker", + shouldFilter: false, + }, + { + name: "all role with any MP", + nodeRoles: "all", + mpSpecName: "infra", + shouldFilter: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shouldSkip := (tt.nodeRoles == "master" && tt.mpSpecName != "master") || + (tt.nodeRoles == "infra" && tt.mpSpecName != "infra") || + (tt.nodeRoles == "workers" && tt.mpSpecName != "worker") + + assert.Equal(t, tt.shouldFilter, shouldSkip) + }) + } +} + +func TestCountReadyNodesHelper(t *testing.T) { + nodes := &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{Name: "node-1"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "node-2"}, + Status: corev1.NodeStatus{ + Conditions: 
[]corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionFalse}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "node-3"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + }, + } + + readyCount := CountReadyNodes(nodes) + assert.Equal(t, 2, readyCount, "Should count only Ready nodes") +} + +func TestValidateAWSClassicCluster(t *testing.T) { + tests := []struct { + name string + cluster *cmv1.Cluster + wantErr bool + errSubstr string + }{ + { + name: "AWS Classic cluster", + cluster: func() *cmv1.Cluster { + c, _ := cmv1.NewCluster(). + CloudProvider(cmv1.NewCloudProvider().ID("aws")). + Hypershift(cmv1.NewHypershift().Enabled(false)). + Build() + return c + }(), + wantErr: false, + }, + { + name: "GCP cluster", + cluster: func() *cmv1.Cluster { + c, _ := cmv1.NewCluster(). + CloudProvider(cmv1.NewCloudProvider().ID("gcp")). + Hypershift(cmv1.NewHypershift().Enabled(false)). + Build() + return c + }(), + wantErr: true, + errSubstr: "only supports AWS clusters", + }, + { + name: "HCP cluster", + cluster: func() *cmv1.Cluster { + c, _ := cmv1.NewCluster(). + CloudProvider(cmv1.NewCloudProvider().ID("aws")). + Hypershift(cmv1.NewHypershift().Enabled(true)). 
+ Build() + return c + }(), + wantErr: true, + errSubstr: "does not support HCP clusters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateAWSClassicCluster(tt.cluster) + if tt.wantErr { + require.Error(t, err) + if tt.errSubstr != "" { + assert.Contains(t, err.Error(), tt.errSubstr) + } + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/docs/README.md b/docs/README.md index c28cecff6..b5f8ae82e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -57,6 +57,7 @@ - `get-env-vars --cluster-id ` - Print a cluster's ID/management namespaces, optionally as env variables - `health` - Describes health of cluster nodes and provides other cluster vitals. - `hypershift-info` - Pull information about AWS objects from the cluster, the management cluster and the privatelink cluster + - `imdsv2` - Migrate cluster nodes to enforce IMDSv2 (Instance Metadata Service v2) - `logging-check --cluster-id ` - Shows the logging support status of a specified cluster - `orgId --cluster-id