diff --git a/SECURITY.md b/SECURITY.md index b81d717..2a3101f 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -88,8 +88,10 @@ The AWS Multi-ENI Controller requires specific IAM permissions to function prope "ec2:DescribeNetworkInterfaces", "ec2:AttachNetworkInterface", "ec2:DetachNetworkInterface", + "ec2:DescribeInstances", "ec2:DescribeSubnets", - "ec2:DescribeSecurityGroups" + "ec2:DescribeSecurityGroups", + "ec2:ModifyInstanceMetadataOptions" ], "Resource": "*" } diff --git a/pkg/aws/ec2.go b/pkg/aws/ec2.go index b290cb2..17b47b6 100644 --- a/pkg/aws/ec2.go +++ b/pkg/aws/ec2.go @@ -932,7 +932,11 @@ func (c *EC2Client) configureIMDSWithFallback(ctx context.Context, hopLimit int3 } // Strategy 4: Configure all instances in the current VPC (last resort) - if err := c.tryVPCWideConfiguration(ctx, hopLimit); err == nil { + // Resolve VPC ID first to scope the configuration to the correct VPC + vpcID, err := c.resolveCurrentVPCID(ctx) + if err != nil { + c.Logger.Info("Failed to resolve VPC ID, skipping VPC-wide configuration", "error", err.Error()) + } else if err := c.tryVPCWideConfiguration(ctx, hopLimit, vpcID); err == nil { c.Logger.Info("Successfully configured IMDS hop limit using VPC-wide approach") return nil } @@ -1125,7 +1129,11 @@ func (c *EC2Client) tryPrivateIPBasedConfiguration(ctx context.Context, hopLimit } // tryVPCWideConfiguration attempts to configure IMDS for all instances in the VPC (last resort) -func (c *EC2Client) tryVPCWideConfiguration(ctx context.Context, hopLimit int32) error { +func (c *EC2Client) tryVPCWideConfiguration(ctx context.Context, hopLimit int32, vpcID string) error { + if vpcID == "" { + return fmt.Errorf("cannot perform VPC-wide configuration without a VPC ID") + } + // Check if aggressive configuration is enabled aggressiveConfig := os.Getenv("IMDS_AGGRESSIVE_CONFIGURATION") if aggressiveConfig != "true" { @@ -1136,15 +1144,19 @@ func (c *EC2Client) tryVPCWideConfiguration(ctx context.Context, hopLimit int32) // This is a last resort strategy - configure IMDS for all instances that might need it // We'll look for instances that have hop limit 1 and are in running state - c.Logger.Info("Attempting VPC-wide IMDS configuration as last resort") + c.Logger.Info("Attempting VPC-wide IMDS configuration as last resort", "vpcID", vpcID) - // Get all running instances in the region + // Get all running instances in the VPC input := &ec2.DescribeInstancesInput{ Filters: []types.Filter{ { Name: aws.String("instance-state-name"), Values: []string{"running"}, }, + { + Name: aws.String("vpc-id"), + Values: []string{vpcID}, + }, }, } @@ -1201,6 +1213,68 @@ func (c *EC2Client) tryVPCWideConfiguration(ctx context.Context, hopLimit int32) return nil } +// resolveCurrentVPCID determines the VPC ID of the current instance by looking up its private IP +func (c *EC2Client) resolveCurrentVPCID(ctx context.Context) (string, error) { + if c.EC2 == nil { + return "", fmt.Errorf("EC2 client is not initialized") + } + + privateIP, err := c.getPrivateIPFromNetworkInterface() + if err != nil { + return "", fmt.Errorf("failed to get private IP for VPC resolution: %v", err) + } + + if privateIP == "" { + return "", fmt.Errorf("no private IP found for VPC resolution") + } + + c.Logger.V(1).Info("Resolving VPC ID using private IP", "privateIP", privateIP) + + input := &ec2.DescribeInstancesInput{ + Filters: []types.Filter{ + { + Name: aws.String("private-ip-address"), + Values: []string{privateIP}, + }, + { + Name: aws.String("instance-state-name"), + Values: []string{"running", "pending"}, + }, + }, + } + + result, err := c.EC2.DescribeInstances(ctx, input) + if err != nil { + return "", fmt.Errorf("failed to describe instances for VPC resolution: %v", err) + } + + // Collect unique VPC IDs to detect ambiguity from overlapping CIDRs + vpcIDs := make(map[string]struct{}) + for _, reservation := range result.Reservations { + for _, instance := range reservation.Instances { + if instance.VpcId != nil { + vpcIDs[*instance.VpcId] = struct{}{} + } + } + } + + if len(vpcIDs) == 0 { + return "", fmt.Errorf("no VPC ID found for instance with private IP %s", privateIP) + } + + if len(vpcIDs) > 1 { + return "", fmt.Errorf("ambiguous VPC resolution: private IP %s matched instances in %d different VPCs", privateIP, len(vpcIDs)) + } + + // Exactly one VPC matched + for vpcID := range vpcIDs { + c.Logger.V(1).Info("Resolved VPC ID", "vpcID", vpcID, "privateIP", privateIP) + return vpcID, nil + } + + return "", fmt.Errorf("no VPC ID found for instance with private IP %s", privateIP) +} + // configureInstanceIMDS configures IMDS hop limit for a specific instance func (c *EC2Client) configureInstanceIMDS(ctx context.Context, instanceID string, hopLimit int32) error { c.Logger.Info("Configuring IMDS for instance", "instanceID", instanceID, "hopLimit", hopLimit) diff --git a/pkg/aws/imds_vpc_test.go b/pkg/aws/imds_vpc_test.go new file mode 100644 index 0000000..bfd2595 --- /dev/null +++ b/pkg/aws/imds_vpc_test.go @@ -0,0 +1,65 @@ +package aws + +import ( + "context" + "testing" + + "github.com/go-logr/logr/testr" +) + +// TestTryVPCWideConfiguration_EmptyVPCID verifies that tryVPCWideConfiguration +// rejects an empty VPC ID immediately without making any API calls. +func TestTryVPCWideConfiguration_EmptyVPCID(t *testing.T) { + client := &EC2Client{ + Logger: testr.New(t), + } + + err := client.tryVPCWideConfiguration(context.Background(), 2, "") + if err == nil { + t.Fatal("expected error when VPC ID is empty, got nil") + } + + expected := "cannot perform VPC-wide configuration without a VPC ID" + if err.Error() != expected { + t.Errorf("expected error %q, got %q", expected, err.Error()) + } +} + +// TestTryVPCWideConfiguration_AggressiveDisabled verifies that tryVPCWideConfiguration +// returns an error when aggressive configuration is disabled. +func TestTryVPCWideConfiguration_AggressiveDisabled(t *testing.T) { + // t.Setenv automatically restores the original value after the test + t.Setenv("IMDS_AGGRESSIVE_CONFIGURATION", "false") + + client := &EC2Client{ + Logger: testr.New(t), + } + + err := client.tryVPCWideConfiguration(context.Background(), 2, "vpc-12345") + if err == nil { + t.Fatal("expected error when aggressive configuration is disabled, got nil") + } + + expected := "aggressive configuration disabled" + if err.Error() != expected { + t.Errorf("expected error %q, got %q", expected, err.Error()) + } +} + +// TestResolveCurrentVPCID_NilEC2Client verifies that resolveCurrentVPCID +// returns a clear error when the EC2 client is not initialized. +func TestResolveCurrentVPCID_NilEC2Client(t *testing.T) { + client := &EC2Client{ + Logger: testr.New(t), + } + + _, err := client.resolveCurrentVPCID(context.Background()) + if err == nil { + t.Fatal("expected error when EC2 client is nil, got nil") + } + + expected := "EC2 client is not initialized" + if err.Error() != expected { + t.Errorf("expected error %q, got %q", expected, err.Error()) + } +}