diff --git a/README.md b/README.md index df0c5e6..96a3cdd 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ contexts: | Service | Feature | Status | |---------|---------|--------| | EC2 | SSM Session Manager (connect to EC2 instances) | ✅ Implemented | +| EC2 | Security Group Browser (list/filter SGs, view inbound/outbound rules) | ✅ Implemented | | VPC | VPC Browser (VPCs → subnets → available IPs) | ✅ Implemented | | RDS | RDS Browser (list, start/stop, failover, Aurora cluster support) | ✅ Implemented | | Route53 | ListHostedZones | 🚧 Coming Soon | diff --git a/internal/app/app.go b/internal/app/app.go index c173ea3..838f6a2 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -31,6 +31,8 @@ const ( screenRoute53RecordDetail screenSecretList screenSecretDetail + screenSecurityGroupList + screenSecurityGroupDetail screenContextPicker screenContextAdd screenLoading @@ -109,6 +111,14 @@ type Model struct { secretFilterActive bool selectedSecret *awsservice.SecretDetail + // Security Group browser state + securityGroups []awsservice.SecurityGroup + filteredSecurityGroups []awsservice.SecurityGroup + sgIdx int + sgFilter string + sgFilterActive bool + selectedSecurityGroup *awsservice.SecurityGroup + // Context picker configPath string ctxList []config.ContextInfo @@ -240,6 +250,13 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.screen = screenSecretDetail return m, nil + case securityGroupsLoadedMsg: + m.securityGroups = msg.securityGroups + m.filteredSecurityGroups = msg.securityGroups + m.sgIdx = 0 + m.screen = screenSecurityGroupList + return m, nil + case rdsActionDoneMsg: if msg.err != nil { m.errMsg = msg.err.Error() @@ -365,6 +382,10 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m.updateSecretList(msg) case screenSecretDetail: return m.updateSecretDetail(msg) + case screenSecurityGroupList: + return m.updateSecurityGroupList(msg) + case screenSecurityGroupDetail: + return m.updateSecurityGroupDetail(msg) case screenContextPicker: return m.updateContextPicker(msg) case screenContextAdd: @@ -434,6 +455,9 @@ func (m Model) updateFeatureList(msg tea.KeyMsg) (tea.Model, tea.Cmd) { case domain.FeatureSecretsBrowser: m.screen = screenLoading return m, m.loadSecrets() + case domain.FeatureSecurityGroupBrowser: + m.screen = screenLoading + return m, m.loadSecurityGroups() } } } @@ -487,6 +511,10 @@ func (m Model) View() string { v = m.viewSecretList() case screenSecretDetail: v = m.viewSecretDetail() + case screenSecurityGroupList: + v = m.viewSecurityGroupList() + case screenSecurityGroupDetail: + v = m.viewSecurityGroupDetail() case screenContextPicker: v = m.viewContextPicker() case screenContextAdd: diff --git a/internal/app/app_test.go b/internal/app/app_test.go index a3e1c30..1982a79 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -588,3 +588,131 @@ func TestViewFitsTerminalHeight(t *testing.T) { t.Errorf("view output has %d lines, exceeds terminal height %d", len(lines), m.height) } } + +// --- Security Group tests --- + +func TestSecurityGroupListNavigation(t *testing.T) { + m := New(testConfig(), "") + m.screen = screenSecurityGroupList + m.securityGroups = []awsservice.SecurityGroup{ + {GroupID: "sg-1", Name: "web", VPCID: "vpc-1"}, + {GroupID: "sg-2", Name: "db", VPCID: "vpc-1"}, + } + m.filteredSecurityGroups = m.securityGroups + + updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'j'}}) + model := updated.(Model) + if model.sgIdx != 1 { + t.Errorf("expected sgIdx 1, got %d", model.sgIdx) + } + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'k'}}) + model = updated.(Model) + if model.sgIdx != 0 { + t.Errorf("expected sgIdx 0, got %d", model.sgIdx) + } +} + +func TestSecurityGroupListEnterGoesToDetail(t *testing.T) { + m := New(testConfig(), "") + m.screen = screenSecurityGroupList + m.securityGroups = []awsservice.SecurityGroup{ + {GroupID: "sg-1", Name: "web", VPCID: "vpc-1"}, + } + m.filteredSecurityGroups = m.securityGroups + + updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + model := updated.(Model) + if model.screen != screenSecurityGroupDetail { + t.Errorf("expected detail screen, got %d", model.screen) + } + if model.selectedSecurityGroup == nil { + t.Error("selectedSecurityGroup should not be nil") + } +} + +func TestSecurityGroupDetailEscGoesBack(t *testing.T) { + m := New(testConfig(), "") + m.screen = screenSecurityGroupDetail + m.selectedSecurityGroup = &awsservice.SecurityGroup{GroupID: "sg-1"} + + updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEsc}) + model := updated.(Model) + if model.screen != screenSecurityGroupList { + t.Errorf("expected list screen, got %d", model.screen) + } +} + +func TestSecurityGroupFilter(t *testing.T) { + m := New(testConfig(), "") + m.screen = screenSecurityGroupList + m.securityGroups = []awsservice.SecurityGroup{ + {GroupID: "sg-1", Name: "web-sg", VPCID: "vpc-1"}, + {GroupID: "sg-2", Name: "db-sg", VPCID: "vpc-1"}, + } + m.filteredSecurityGroups = m.securityGroups + + // Activate filter + updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}}) + model := updated.(Model) + if !model.sgFilterActive { + t.Error("filter should be active") + } + + // Type "web" + for _, ch := range "web" { + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{ch}}) + model = updated.(Model) + } + if len(model.filteredSecurityGroups) != 1 { + t.Errorf("expected 1 filtered SG, got %d", len(model.filteredSecurityGroups)) + } +} + +func TestSecurityGroupDetailView(t *testing.T) { + m := New(testConfig(), "") + m.screen = screenSecurityGroupDetail + m.height = 30 + m.selectedSecurityGroup = &awsservice.SecurityGroup{ + GroupID: "sg-aaa", + Name: "web-sg", + Description: "Web servers", + VPCID: "vpc-111", + IngressRules: []awsservice.SecurityGroupRule{ + {Protocol: "tcp", FromPort: 443, ToPort: 443, CIDRV4: "0.0.0.0/0", Description: "HTTPS"}, + }, + EgressRules: []awsservice.SecurityGroupRule{ + {Protocol: "-1", CIDRV4: "0.0.0.0/0"}, + }, + } + + v := m.View() + if !strings.Contains(v, "sg-aaa") { + t.Error("detail view should contain group ID") + } + if !strings.Contains(v, "Inbound Rules") { + t.Error("detail view should show inbound rules section") + } + if !strings.Contains(v, "Outbound Rules") { + t.Error("detail view should show outbound rules section") + } + if !strings.Contains(v, "443") { + t.Error("detail view should show port 443") + } +} + +func TestSecurityGroupBrowserInCatalog(t *testing.T) { + catalog := domain.Catalog() + for _, svc := range catalog { + if svc.Name == domain.ServiceEC2 { + for _, feat := range svc.Features { + if feat.Kind == domain.FeatureSecurityGroupBrowser { + return + } + } + t.Error("EC2 should have Security Group Browser feature") + return + } + } + t.Error("EC2 service not found") +} diff --git a/internal/app/messages.go b/internal/app/messages.go index 0d75265..7954af6 100644 --- a/internal/app/messages.go +++ b/internal/app/messages.go @@ -82,3 +82,7 @@ type secretsLoadedMsg struct { type secretDetailLoadedMsg struct { detail *awsservice.SecretDetail } + +type securityGroupsLoadedMsg struct { + securityGroups []awsservice.SecurityGroup +} diff --git a/internal/app/screen_securitygroup.go b/internal/app/screen_securitygroup.go new file mode 100644 index 0000000..1823659 --- /dev/null +++ b/internal/app/screen_securitygroup.go @@ -0,0 +1,268 @@ +package app + +import ( + "context" + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + + awsservice "unic/internal/services/aws" +) + +func (m Model) loadSecurityGroups() tea.Cmd { + return func() tea.Msg { + ctx := context.Background() + repo, err := awsservice.NewAwsRepository(ctx, m.cfg) + if err != nil { + return errMsg{err: err} + } + m.awsRepo = repo + + sgs, err := repo.ListSecurityGroups(ctx) + if err != nil { + return errMsg{err: err} + } + if len(sgs) == 0 { + return errMsg{err: fmt.Errorf("no security groups found")} + } + return securityGroupsLoadedMsg{securityGroups: sgs} + } +} + +func (m Model) updateSecurityGroupList(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + key := msg.String() + + if m.sgFilterActive { + switch key { + case "esc": + m.sgFilterActive = false + case "enter": + m.sgFilterActive = false + case "backspace": + if len(m.sgFilter) > 0 { + m.sgFilter = m.sgFilter[:len(m.sgFilter)-1] + m.applySecurityGroupFilter() + } + default: + if len(key) == 1 { + m.sgFilter += key + m.applySecurityGroupFilter() + } + } + return m, nil + } + + switch key { + case "q", "esc": + m.screen = screenFeatureList + m.sgFilter = "" + m.filteredSecurityGroups = m.securityGroups + m.sgIdx = 0 + case "up", "k": + if m.sgIdx > 0 { + m.sgIdx-- + } + case "down", "j": + if m.sgIdx < len(m.filteredSecurityGroups)-1 { + m.sgIdx++ + } + case "/": + m.sgFilterActive = true + case "r": + m.screen = screenLoading + m.sgFilter = "" + m.sgIdx = 0 + return m, m.loadSecurityGroups() + case "enter": + if len(m.filteredSecurityGroups) > 0 && m.sgIdx < len(m.filteredSecurityGroups) { + selected := m.filteredSecurityGroups[m.sgIdx] + m.selectedSecurityGroup = &selected + m.screen = screenSecurityGroupDetail + } + } + return m, nil +} + +func (m *Model) applySecurityGroupFilter() { + if m.sgFilter == "" { + m.filteredSecurityGroups = m.securityGroups + } else { + query := strings.ToLower(m.sgFilter) + var result []awsservice.SecurityGroup + for _, sg := range m.securityGroups { + if strings.Contains(sg.FilterText(), query) { + result = append(result, sg) + } + } + m.filteredSecurityGroups = result + } + m.sgIdx = 0 +} + +func (m Model) updateSecurityGroupDetail(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "q", "esc": + m.screen = screenSecurityGroupList + } + return m, nil +} + +func (m Model) viewSecurityGroupList() string { + var b strings.Builder + b.WriteString(m.renderStatusBar()) + b.WriteString(titleStyle.Render("Security Groups")) + b.WriteString("\n") + + if m.sgFilterActive { + b.WriteString(filterStyle.Render(fmt.Sprintf("Filter: %s▏", m.sgFilter))) + } else if m.sgFilter != "" { + b.WriteString(dimStyle.Render(fmt.Sprintf("Filter: %s", m.sgFilter))) + } + b.WriteString("\n\n") + + if len(m.filteredSecurityGroups) == 0 { + b.WriteString(dimStyle.Render(" No matching security groups")) + b.WriteString("\n") + } else { + visibleLines := max(m.height-8, 5) + start := 0 + if m.sgIdx >= visibleLines { + start = m.sgIdx - visibleLines + 1 + } + end := min(start+visibleLines, len(m.filteredSecurityGroups)) + + for i := start; i < end; i++ { + sg := m.filteredSecurityGroups[i] + cursor := " " + style := normalStyle + if i == m.sgIdx { + cursor = "> " + style = selectedStyle + } + b.WriteString(style.Render(fmt.Sprintf("%s%s", cursor, sg.DisplayTitle()))) + b.WriteString("\n") + } + + b.WriteString("\n") + b.WriteString(dimStyle.Render(fmt.Sprintf(" %d/%d security groups", len(m.filteredSecurityGroups), len(m.securityGroups)))) + } + + b.WriteString("\n") + b.WriteString(dimStyle.Render("↑/↓: navigate • /: filter • r: refresh • enter: detail • esc: back • H: home")) + return b.String() +} + +func (m Model) viewSecurityGroupDetail() string { + if m.selectedSecurityGroup == nil { + return "" + } + sg := m.selectedSecurityGroup + var b strings.Builder + b.WriteString(m.renderStatusBar()) + b.WriteString(titleStyle.Render("Security Group Detail")) + b.WriteString("\n\n") + + labelStyle := lipgloss.NewStyle().Width(16) + b.WriteString(normalStyle.Render(fmt.Sprintf(" %s%s", labelStyle.Render("Group ID"), sg.GroupID))) + b.WriteString("\n") + b.WriteString(normalStyle.Render(fmt.Sprintf(" %s%s", labelStyle.Render("Name"), sg.Name))) + b.WriteString("\n") + b.WriteString(normalStyle.Render(fmt.Sprintf(" %s%s", labelStyle.Render("Description"), sg.Description))) + b.WriteString("\n") + b.WriteString(normalStyle.Render(fmt.Sprintf(" %s%s", labelStyle.Render("VPC ID"), sg.VPCID))) + b.WriteString("\n") + + // Inbound rules + b.WriteString("\n") + b.WriteString(titleStyle.Render("Inbound Rules")) + b.WriteString("\n") + if len(sg.IngressRules) == 0 { + b.WriteString(dimStyle.Render(" No inbound rules")) + b.WriteString("\n") + } else { + protoCol := lipgloss.NewStyle().Width(8) + portCol := lipgloss.NewStyle().Width(14) + b.WriteString(dimStyle.Render(" " + protoCol.Render("PROTO") + portCol.Render("PORT") + "SOURCE")) + b.WriteString("\n") + for _, rule := range sg.IngressRules { + proto := rule.Protocol + if proto == "-1" { + proto = "All" + } + portRange := "All" + if rule.Protocol != "-1" { + if rule.FromPort == rule.ToPort { + portRange = fmt.Sprintf("%d", rule.FromPort) + } else { + portRange = fmt.Sprintf("%d-%d", rule.FromPort, rule.ToPort) + } + } + source := rule.CIDRV4 + if source == "" { + source = rule.CIDRV6 + } + if source == "" && rule.ReferencedSGID != "" { + source = rule.ReferencedSGID + } + if source == "" { + source = "-" + } + row := " " + protoCol.Render(proto) + portCol.Render(portRange) + source + if rule.Description != "" { + row += dimStyle.Render(" " + rule.Description) + } + b.WriteString(normalStyle.Render(row)) + b.WriteString("\n") + } + } + + // Outbound rules + b.WriteString("\n") + b.WriteString(titleStyle.Render("Outbound Rules")) + b.WriteString("\n") + if len(sg.EgressRules) == 0 { + b.WriteString(dimStyle.Render(" No outbound rules")) + b.WriteString("\n") + } else { + protoCol := lipgloss.NewStyle().Width(8) + portCol := lipgloss.NewStyle().Width(14) + b.WriteString(dimStyle.Render(" " + protoCol.Render("PROTO") + portCol.Render("PORT") + "DESTINATION")) + b.WriteString("\n") + for _, rule := range sg.EgressRules { + proto := rule.Protocol + if proto == "-1" { + proto = "All" + } + portRange := "All" + if rule.Protocol != "-1" { + if rule.FromPort == rule.ToPort { + portRange = fmt.Sprintf("%d", rule.FromPort) + } else { + portRange = fmt.Sprintf("%d-%d", rule.FromPort, rule.ToPort) + } + } + dest := rule.CIDRV4 + if dest == "" { + dest = rule.CIDRV6 + } + if dest == "" && rule.ReferencedSGID != "" { + dest = rule.ReferencedSGID + } + if dest == "" { + dest = "-" + } + row := " " + protoCol.Render(proto) + portCol.Render(portRange) + dest + if rule.Description != "" { + row += dimStyle.Render(" " + rule.Description) + } + b.WriteString(normalStyle.Render(row)) + b.WriteString("\n") + } + } + + b.WriteString("\n") + b.WriteString(dimStyle.Render("esc: back • H: home")) + return b.String() +} diff --git a/internal/domain/catalog.go b/internal/domain/catalog.go index 9b05f5d..c126343 100644 --- a/internal/domain/catalog.go +++ b/internal/domain/catalog.go @@ -10,6 +10,10 @@ func Catalog() []Service { Kind: FeatureSSMSession, Description: "Start an SSM session to an EC2 instance", }, + { + Kind: FeatureSecurityGroupBrowser, + Description: "Browse security groups and view inbound/outbound rules", + }, }, }, { diff --git a/internal/domain/catalog_test.go b/internal/domain/catalog_test.go index 21595c6..37357b9 100644 --- a/internal/domain/catalog_test.go +++ b/internal/domain/catalog_test.go @@ -84,3 +84,21 @@ func TestRDSHasBrowserFeature(t *testing.T) { t.Error("RDS service not found in catalog") } + +func TestEC2HasSecurityGroupBrowserFeature(t *testing.T) { + services := Catalog() + + for _, svc := range services { + if svc.Name == ServiceEC2 { + for _, feat := range svc.Features { + if feat.Kind == FeatureSecurityGroupBrowser { + return + } + } + t.Error("EC2 service should have Security Group Browser feature") + return + } + } + + t.Error("EC2 service not found in catalog") +} diff --git a/internal/domain/model.go b/internal/domain/model.go index 0d7eeb4..b967415 100644 --- a/internal/domain/model.go +++ b/internal/domain/model.go @@ -19,7 +19,8 @@ const ( FeatureVPCBrowser FeatureKind = "VPC Browser" FeatureRDSBrowser FeatureKind = "RDS Browser" FeatureRoute53Browser FeatureKind = "Route53 Browser" - FeatureSecretsBrowser FeatureKind = "Secrets Manager Browser" + FeatureSecretsBrowser FeatureKind = "Secrets Manager Browser" + FeatureSecurityGroupBrowser FeatureKind = "Security Group Browser" ) // Feature describes a selectable feature under an AWS service. diff --git a/internal/services/aws/repository.go b/internal/services/aws/repository.go index 3055398..8b9727d 100644 --- a/internal/services/aws/repository.go +++ b/internal/services/aws/repository.go @@ -69,6 +69,7 @@ type EC2ClientAPI interface { DescribeVpcs(ctx context.Context, params *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) DescribeNetworkInterfaces(ctx context.Context, params *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) + DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) } // CallerIdentity holds the result of sts:GetCallerIdentity. diff --git a/internal/services/aws/securitygroup.go b/internal/services/aws/securitygroup.go new file mode 100644 index 0000000..3e32cd0 --- /dev/null +++ b/internal/services/aws/securitygroup.go @@ -0,0 +1,93 @@ +package aws + +import ( + "context" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" +) + +// ListSecurityGroups returns all security groups in the current account/region. +func (r *AwsRepository) ListSecurityGroups(ctx context.Context) ([]SecurityGroup, error) { + output, err := r.EC2Client.DescribeSecurityGroups(ctx, &ec2.DescribeSecurityGroupsInput{}) + if err != nil { + return nil, err + } + + sgs := make([]SecurityGroup, 0, len(output.SecurityGroups)) + for _, sg := range output.SecurityGroups { + group := SecurityGroup{ + GroupID: awssdk.ToString(sg.GroupId), + Name: awssdk.ToString(sg.GroupName), + Description: awssdk.ToString(sg.Description), + VPCID: awssdk.ToString(sg.VpcId), + } + + // Check if this is the default security group + if group.Name == "default" { + group.IsDefault = true + } + + for _, perm := range sg.IpPermissions { + base := SecurityGroupRule{ + Protocol: awssdk.ToString(perm.IpProtocol), + FromPort: awssdk.ToInt32(perm.FromPort), + ToPort: awssdk.ToInt32(perm.ToPort), + } + for _, ipRange := range perm.IpRanges { + rule := base + rule.CIDRV4 = awssdk.ToString(ipRange.CidrIp) + rule.Description = awssdk.ToString(ipRange.Description) + group.IngressRules = append(group.IngressRules, rule) + } + for _, ipv6Range := range perm.Ipv6Ranges { + rule := base + rule.CIDRV6 = awssdk.ToString(ipv6Range.CidrIpv6) + rule.Description = awssdk.ToString(ipv6Range.Description) + group.IngressRules = append(group.IngressRules, rule) + } + for _, sgRef := range perm.UserIdGroupPairs { + rule := base + rule.ReferencedSGID = awssdk.ToString(sgRef.GroupId) + rule.Description = awssdk.ToString(sgRef.Description) + group.IngressRules = append(group.IngressRules, rule) + } + // If no specific source, add the base rule + if len(perm.IpRanges) == 0 && len(perm.Ipv6Ranges) == 0 && len(perm.UserIdGroupPairs) == 0 { + group.IngressRules = append(group.IngressRules, base) + } + } + + for _, perm := range sg.IpPermissionsEgress { + base := SecurityGroupRule{ + Protocol: awssdk.ToString(perm.IpProtocol), + FromPort: awssdk.ToInt32(perm.FromPort), + ToPort: awssdk.ToInt32(perm.ToPort), + } + for _, ipRange := range perm.IpRanges { + rule := base + rule.CIDRV4 = awssdk.ToString(ipRange.CidrIp) + rule.Description = awssdk.ToString(ipRange.Description) + group.EgressRules = append(group.EgressRules, rule) + } + for _, ipv6Range := range perm.Ipv6Ranges { + rule := base + rule.CIDRV6 = awssdk.ToString(ipv6Range.CidrIpv6) + rule.Description = awssdk.ToString(ipv6Range.Description) + group.EgressRules = append(group.EgressRules, rule) + } + for _, sgRef := range perm.UserIdGroupPairs { + rule := base + rule.ReferencedSGID = awssdk.ToString(sgRef.GroupId) + rule.Description = awssdk.ToString(sgRef.Description) + group.EgressRules = append(group.EgressRules, rule) + } + if len(perm.IpRanges) == 0 && len(perm.Ipv6Ranges) == 0 && len(perm.UserIdGroupPairs) == 0 { + group.EgressRules = append(group.EgressRules, base) + } + } + + sgs = append(sgs, group) + } + return sgs, nil +} diff --git a/internal/services/aws/securitygroup_model.go b/internal/services/aws/securitygroup_model.go new file mode 100644 index 0000000..4a26e91 --- /dev/null +++ b/internal/services/aws/securitygroup_model.go @@ -0,0 +1,72 @@ +package aws + +import ( + "fmt" + "strings" +) + +// SecurityGroup holds essential information about an EC2 security group. +type SecurityGroup struct { + GroupID string + Name string + Description string + VPCID string + IsDefault bool + IngressRules []SecurityGroupRule + EgressRules []SecurityGroupRule +} + +// DisplayTitle returns a formatted string for list display. +func (sg SecurityGroup) DisplayTitle() string { + defaultMark := "" + if sg.IsDefault { + defaultMark = " [default]" + } + return fmt.Sprintf("%s (%s) - %s%s", sg.Name, sg.GroupID, sg.VPCID, defaultMark) +} + +// FilterText returns a lowercase string for keyword matching. +func (sg SecurityGroup) FilterText() string { + return strings.ToLower(fmt.Sprintf("%s %s %s %s", sg.Name, sg.GroupID, sg.VPCID, sg.Description)) +} + +// SecurityGroupRule represents an inbound or outbound rule. +type SecurityGroupRule struct { + Protocol string + FromPort int32 + ToPort int32 + CIDRV4 string + CIDRV6 string + ReferencedSGID string + Description string +} + +// DisplayTitle returns a formatted string for rule display. +func (r SecurityGroupRule) DisplayTitle() string { + proto := r.Protocol + if proto == "-1" { + proto = "All" + } + + portRange := "All" + if r.Protocol != "-1" { + if r.FromPort == r.ToPort { + portRange = fmt.Sprintf("%d", r.FromPort) + } else { + portRange = fmt.Sprintf("%d-%d", r.FromPort, r.ToPort) + } + } + + source := r.CIDRV4 + if source == "" { + source = r.CIDRV6 + } + if source == "" && r.ReferencedSGID != "" { + source = r.ReferencedSGID + } + if source == "" { + source = "-" + } + + return fmt.Sprintf("%s %s %s", proto, portRange, source) +} diff --git a/internal/services/aws/securitygroup_test.go b/internal/services/aws/securitygroup_test.go new file mode 100644 index 0000000..75850e5 --- /dev/null +++ b/internal/services/aws/securitygroup_test.go @@ -0,0 +1,163 @@ +package aws + +import ( + "context" + "testing" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +func TestListSecurityGroups_Success(t *testing.T) { + mock := &mockEC2Client{ + describeSecurityGroupsFunc: func(_ context.Context, _ *ec2.DescribeSecurityGroupsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return &ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []types.SecurityGroup{ + { + GroupId: awssdk.String("sg-aaa"), + GroupName: awssdk.String("web-sg"), + Description: awssdk.String("Web servers"), + VpcId: awssdk.String("vpc-111"), + IpPermissions: []types.IpPermission{ + { + IpProtocol: awssdk.String("tcp"), + FromPort: awssdk.Int32(443), + ToPort: awssdk.Int32(443), + IpRanges: []types.IpRange{ + {CidrIp: awssdk.String("0.0.0.0/0"), Description: awssdk.String("HTTPS")}, + }, + }, + { + IpProtocol: awssdk.String("tcp"), + FromPort: awssdk.Int32(22), + ToPort: awssdk.Int32(22), + UserIdGroupPairs: []types.UserIdGroupPair{ + {GroupId: awssdk.String("sg-bastion"), Description: awssdk.String("SSH from bastion")}, + }, + }, + }, + IpPermissionsEgress: []types.IpPermission{ + { + IpProtocol: awssdk.String("-1"), + IpRanges: []types.IpRange{ + {CidrIp: awssdk.String("0.0.0.0/0")}, + }, + }, + }, + }, + { + GroupId: awssdk.String("sg-bbb"), + GroupName: awssdk.String("default"), + VpcId: awssdk.String("vpc-111"), + }, + }, + }, nil + }, + } + + repo := &AwsRepository{EC2Client: mock} + sgs, err := repo.ListSecurityGroups(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(sgs) != 2 { + t.Fatalf("expected 2 security groups, got %d", len(sgs)) + } + + // First SG + sg := sgs[0] + if sg.GroupID != "sg-aaa" { + t.Errorf("expected GroupID sg-aaa, got %s", sg.GroupID) + } + if sg.Name != "web-sg" { + t.Errorf("expected Name web-sg, got %s", sg.Name) + } + if sg.IsDefault { + t.Error("expected IsDefault false") + } + if len(sg.IngressRules) != 2 { + t.Fatalf("expected 2 ingress rules, got %d", len(sg.IngressRules)) + } + if sg.IngressRules[0].CIDRV4 != "0.0.0.0/0" { + t.Errorf("expected first ingress CIDR 0.0.0.0/0, got %s", sg.IngressRules[0].CIDRV4) + } + if sg.IngressRules[1].ReferencedSGID != "sg-bastion" { + t.Errorf("expected second ingress ref SG sg-bastion, got %s", sg.IngressRules[1].ReferencedSGID) + } + if len(sg.EgressRules) != 1 { + t.Fatalf("expected 1 egress rule, got %d", len(sg.EgressRules)) + } + + // Second SG (default) + if !sgs[1].IsDefault { + t.Error("expected default SG to have IsDefault true") + } +} + +func TestListSecurityGroups_Error(t *testing.T) { + mock := &mockEC2Client{ + describeSecurityGroupsFunc: func(_ context.Context, _ *ec2.DescribeSecurityGroupsInput, _ ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + return nil, context.DeadlineExceeded + }, + } + + repo := &AwsRepository{EC2Client: mock} + _, err := repo.ListSecurityGroups(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestSecurityGroupDisplayTitle(t *testing.T) { + sg := SecurityGroup{GroupID: "sg-aaa", Name: "web-sg", VPCID: "vpc-111"} + title := sg.DisplayTitle() + if title != "web-sg (sg-aaa) - vpc-111" { + t.Errorf("unexpected DisplayTitle: %s", title) + } + + sgDefault := SecurityGroup{GroupID: "sg-bbb", Name: "default", VPCID: "vpc-111", IsDefault: true} + titleDefault := sgDefault.DisplayTitle() + if titleDefault != "default (sg-bbb) - vpc-111 [default]" { + t.Errorf("unexpected DisplayTitle for default: %s", titleDefault) + } +} + +func TestSecurityGroupRuleDisplayTitle(t *testing.T) { + tests := []struct { + name string + rule SecurityGroupRule + expected string + }{ + { + name: "TCP single port with CIDR", + rule: SecurityGroupRule{Protocol: "tcp", FromPort: 443, ToPort: 443, CIDRV4: "0.0.0.0/0"}, + expected: "tcp 443 0.0.0.0/0", + }, + { + name: "TCP port range", + rule: SecurityGroupRule{Protocol: "tcp", FromPort: 1024, ToPort: 65535, CIDRV4: "10.0.0.0/8"}, + expected: "tcp 1024-65535 10.0.0.0/8", + }, + { + name: "All traffic", + rule: SecurityGroupRule{Protocol: "-1", CIDRV4: "0.0.0.0/0"}, + expected: "All All 0.0.0.0/0", + }, + { + name: "SG reference", + rule: SecurityGroupRule{Protocol: "tcp", FromPort: 22, ToPort: 22, ReferencedSGID: "sg-bastion"}, + expected: "tcp 22 sg-bastion", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.rule.DisplayTitle() + if got != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, got) + } + }) + } +} diff --git a/internal/services/aws/vpc_test.go b/internal/services/aws/vpc_test.go index d6162dc..4f7ef93 100644 --- a/internal/services/aws/vpc_test.go +++ b/internal/services/aws/vpc_test.go @@ -16,6 +16,7 @@ type mockEC2Client struct { describeSubnetsFunc func(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) describeInstancesFunc func(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) describeNetworkInterfacesFunc func(ctx context.Context, params *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error) + describeSecurityGroupsFunc func(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) } func (m *mockEC2Client) DescribeVpcs(ctx context.Context, params *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) { @@ -40,6 +41,13 @@ func (m *mockEC2Client) DescribeNetworkInterfaces(ctx context.Context, params *e return &ec2.DescribeNetworkInterfacesOutput{}, nil } +func (m *mockEC2Client) DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error) { + if m.describeSecurityGroupsFunc != nil { + return m.describeSecurityGroupsFunc(ctx, params, optFns...) + } + return &ec2.DescribeSecurityGroupsOutput{}, nil +} + // --- VPC tests --- func TestListVPCs_Success(t *testing.T) {