diff --git a/.circleci/config.yml b/.circleci/config.yml index 2c57dbe..11f25c3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -25,12 +25,12 @@ jobs: working_directory: ~/repo steps: - checkout - - run: curl -sS https://releases.hashicorp.com/terraform/0.12.7/terraform_0.12.7_linux_amd64.zip -o ./terraform.zip && unzip terraform.zip && sudo mv terraform /usr/local/bin/ && which terraform - - run: go test -cover ./reach/... -test.v + - run: curl -sS https://releases.hashicorp.com/terraform/0.12.15/terraform_0.12.15_linux_amd64.zip -o ./terraform.zip && unzip terraform.zip && sudo mv terraform /usr/local/bin/ && which terraform + - run: go test -cover ./reach/analyzer -test.v -acceptance -log-tf -timeout 60m workflows: version: 2 - commit: + push: jobs: - build - unit_tests diff --git a/README.md b/README.md index 6630ab9..7122da6 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/luhring/reach)](https://goreportcard.com/report/github.com/luhring/reach) [![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/luhring/reach/blob/master/LICENSE) -Reach is a tool for discovering the impact your AWS configuration has on the flow of network traffic. +Reach is a tool for analyzing the network traffic allowed to flow in AWS. Reach doesn't need any access to your network — it simply queries the AWS API for your network configuration. ## Getting Started @@ -50,7 +50,7 @@ $ reach web-instance database-instance $ reach web data ``` -**Note:** Right now, Reach can only analyze the path between two EC2 instances when the instances are **in the same subnet**. Adding support for multiple subnets is the top priority and is currently in development. +**Note:** Right now, Reach can analyze the path between two EC2 instances only when the instances are **_in the same VPC_**. ## Initial Setup @@ -102,7 +102,7 @@ In this case, Reach will provide significantly more detail about the analysis. S ## Feature Ideas - ~~**Same-subnet analysis:** Between two EC2 instances within the same subnet~~ (done!) -- **Same-VPC analysis:** Between two EC2 instances within the same VPC, including for EC2 instances in separate subnets +- ~~**Same-VPC analysis:** Between two EC2 instances within the same VPC, including for EC2 instances in separate subnets~~ (done!) - **IP address analysis:** Between an EC2 instance and a specified IP address that may be outside of AWS entirely (enhancement idea: provide shortcuts for things like the user's own IP address, a specified hostname's resolved IP address, etc.) - **Filtered analysis:** Specify a particular kind of network traffic to analyze (e.g. a single TCP port) and return results only for that filter - **Other AWS resources:** Analyze other kinds of AWS resources than just EC2 instances (e.g. ELB, Lambda, VPC endpoints, etc.) diff --git a/build.sh b/build.sh index 5333e91..5d789dc 100755 --- a/build.sh +++ b/build.sh @@ -1,37 +1,73 @@ #!/bin/bash -set -ex +# This script takes an argument for which OS to build for: darwin, linux, or windows. +# If no argument is provided, the script builds for all three. + +# To build for a specific version, set the `REACH_VERSION` variable to something like "2.0.1" before running the script. + +set -e export REACH_VERSION=${REACH_VERSION:-"0.0.0"} +export SPECIFIED_OS="" + +if [[ -z "$1" ]] +then + export SPECIFIED_OS="$1" +fi set -u export CGO_ENABLED=0 export GOARCH=amd64 -export REACH_DIR_DARWIN=$(printf "reach_%s_darwin_amd64" $REACH_VERSION) -export REACH_DIR_LINUX=$(printf "reach_%s_linux_amd64" $REACH_VERSION) -export REACH_DIR_WINDOWS=$(printf "reach_%s_windows_amd64" $REACH_VERSION) -mkdir -p ./build +set -x + +function build_for_os { + local GOOS="$1" + local REACH_EXECUTABLE -GOOS=darwin go build -a -tags netgo -o "./build/$REACH_DIR_DARWIN/reach" -GOOS=linux go build -a -tags netgo -o "./build/$REACH_DIR_LINUX/reach" -GOOS=windows go build -a -tags netgo -o "./build/$REACH_DIR_WINDOWS/reach.exe" + if [[ "$GOOS" == "windows" ]] + then + REACH_EXECUTABLE="reach.exe" + else + REACH_EXECUTABLE="reach" + fi -cp -nv ./LICENSE ./README.md "./build/$REACH_DIR_DARWIN" -cp -nv ./LICENSE ./README.md "./build/$REACH_DIR_LINUX" -cp -nv ./LICENSE ./README.md "./build/$REACH_DIR_WINDOWS" + local REACH_DIR_FOR_OS + REACH_DIR_FOR_OS=$(printf "reach_%s_%s_amd64" "$REACH_VERSION" "$GOOS") + + mkdir -p "./$REACH_DIR_FOR_OS" + + GOOS=$GOOS go build -a -v -tags netgo -o "./$REACH_DIR_FOR_OS/$REACH_EXECUTABLE" .. + cp -nv ../LICENSE ../README.md "./$REACH_DIR_FOR_OS/" + + if [[ "$GOOS" == "windows" ]] + then + zip "$REACH_DIR_FOR_OS.zip" "./$REACH_DIR_FOR_OS"/* + openssl dgst -sha256 "./$REACH_DIR_FOR_OS.zip" >> ./checksums.txt + else + tar -cvzf "$REACH_DIR_FOR_OS.tar.gz" "./$REACH_DIR_FOR_OS"/* + openssl dgst -sha256 "./$REACH_DIR_FOR_OS.tar.gz" >> ./checksums.txt + fi +} + +rm -rf ./build +mkdir -p ./build pushd ./build - tar -cvzf $REACH_DIR_DARWIN.tar.gz ./$REACH_DIR_DARWIN/* - tar -cvzf $REACH_DIR_LINUX.tar.gz ./$REACH_DIR_LINUX/* - tar -cvzf $REACH_DIR_WINDOWS.tar.gz ./$REACH_DIR_WINDOWS/* + if [[ ! -z "$SPECIFIED_OS" ]] + then + build_for_os "$SPECIFIED_OS" + else + for CURRENT_OS in "darwin" "linux" "windows" + do + build_for_os "$CURRENT_OS" + done + fi - openssl dgst -sha256 ./$REACH_DIR_DARWIN.tar.gz >> ./checksums.txt - openssl dgst -sha256 ./$REACH_DIR_LINUX.tar.gz >> ./checksums.txt - openssl dgst -sha256 ./$REACH_DIR_WINDOWS.tar.gz >> ./checksums.txt + set +x cat ./checksums.txt popd -set +eux +set +eu diff --git a/cmd/assert.go b/cmd/assert.go new file mode 100644 index 0000000..974ff2b --- /dev/null +++ b/cmd/assert.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/mgutz/ansi" + + "github.com/luhring/reach/reach" +) + +func doAssertReachable(analysis reach.Analysis) { + if analysis.PassesAssertReachable() { + exitSuccessfulAssertion("source is able to reach destination") + } else { + exitFailedAssertion("one or more forward or return paths of network traffic is obstructed") + } +} + +func doAssertNotReachable(analysis reach.Analysis) { + if analysis.PassesAssertNotReachable() { + exitSuccessfulAssertion("source is unable to reach destination") + } else { + exitFailedAssertion("source is able to send network traffic to destination") + } +} + +func exitFailedAssertion(text string) { + failedMessage := ansi.Color("assertion failed:", "red+b") + secondaryMessage := ansi.Color(text, "red") + _, _ = fmt.Fprintf(os.Stderr, "\n%v %v\n", failedMessage, secondaryMessage) + + os.Exit(2) +} + +func exitSuccessfulAssertion(text string) { + succeededMessage := ansi.Color("assertion succeeded:", "green+b") + secondaryMessage := ansi.Color(text, "green") + _, _ = fmt.Fprintf(os.Stderr, "\n%v %v\n", succeededMessage, secondaryMessage) + + os.Exit(0) +} diff --git a/cmd/exit.go b/cmd/exit.go index 25f6f14..d11505a 100644 --- a/cmd/exit.go +++ b/cmd/exit.go @@ -2,7 +2,6 @@ package cmd import ( "fmt" - "github.com/mgutz/ansi" "os" ) @@ -11,19 +10,3 @@ func exitWithError(err error) { os.Exit(1) } - -func exitWithFailedAssertion(text string) { - failedMessage := ansi.Color("assertion failed:", "red+b") - secondaryMessage := ansi.Color(text, "red") - _, _ = fmt.Fprintf(os.Stderr, "\n%v %v\n", failedMessage, secondaryMessage) - - os.Exit(2) -} - -func exitWithSuccessfulAssertion(text string) { - succeededMessage := ansi.Color("assertion succeeded:", "green+b") - secondaryMessage := ansi.Color(text, "green") - _, _ = fmt.Fprintf(os.Stderr, "\n%v %v\n", succeededMessage, secondaryMessage) - - os.Exit(0) -} diff --git a/cmd/root.go b/cmd/root.go index 4092b42..fe2cbdc 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,7 +3,6 @@ package cmd import ( "errors" "fmt" - "os" "strings" "github.com/spf13/cobra" @@ -16,11 +15,13 @@ import ( const explainFlag = "explain" const vectorsFlag = "vectors" +const jsonFlag = "json" const assertReachableFlag = "assert-reachable" const assertNotReachableFlag = "assert-not-reachable" var explain bool var showVectors bool +var outputJSON bool var assertReachable bool var assertNotReachable bool @@ -58,7 +59,7 @@ See https://github.com/luhring/reach for documentation.`, } destination.SetRoleToDestination() - if !explain && !showVectors { + if !outputJSON && !explain && !showVectors { fmt.Printf("source: %s\ndestination: %s\n\n", source.ID, destination.ID) } @@ -73,17 +74,16 @@ See https://github.com/luhring/reach for documentation.`, exitWithError(err) } - if explain { + if outputJSON { + fmt.Println(analysis.ToJSON()) + } else if explain { ex := explainer.New(*analysis) fmt.Print(ex.Explain()) } else if showVectors { var vectorOutputs []string for _, v := range analysis.NetworkVectors { - output := "" - output += v.String() - - vectorOutputs = append(vectorOutputs, output) + vectorOutputs = append(vectorOutputs, v.String()) } fmt.Print(strings.Join(vectorOutputs, "\n")) @@ -91,30 +91,32 @@ See https://github.com/luhring/reach for documentation.`, fmt.Print("network traffic allowed from source to destination:" + "\n") fmt.Print(mergedTraffic.ColorStringWithSymbols()) - if len(analysis.NetworkVectors) > 1 { + if len(analysis.NetworkVectors) > 1 { // handling this case with care; this view isn't optimized for multi-vector output! printMergedResultsWarning() + warnIfAnyVectorHasRestrictedReturnTraffic(analysis.NetworkVectors) + } else { + // calculate merged return traffic + mergedReturnTraffic, err := analysis.MergedReturnTraffic() + if err != nil { + exitWithError(err) + } + + restrictedProtocols := mergedTraffic.ProtocolsWithRestrictedReturnPath(mergedReturnTraffic) + if len(restrictedProtocols) > 0 { + found, warnings := explainer.WarningsFromRestrictedReturnPath(restrictedProtocols) + if found { + fmt.Print("\n" + warnings + "\n") + } + } } } - // fmt.Println(analysis.ToJSON()) // for debugging - - const canReach = "source is able to reach destination" - const cannotReach = "source is unable to reach destination" - if assertReachable { - if mergedTraffic.None() { - exitWithFailedAssertion(cannotReach) - } else { - exitWithSuccessfulAssertion(canReach) - } + doAssertReachable(*analysis) } if assertNotReachable { - if mergedTraffic.None() { - exitWithSuccessfulAssertion(cannotReach) - } else { - exitWithFailedAssertion(canReach) - } + doAssertNotReachable(*analysis) } }, } @@ -129,11 +131,7 @@ func Execute() { func init() { rootCmd.Flags().BoolVar(&explain, explainFlag, false, "explain how the configuration was analyzed") rootCmd.Flags().BoolVar(&showVectors, vectorsFlag, false, "show allowed traffic in terms of network vectors") + rootCmd.Flags().BoolVar(&outputJSON, jsonFlag, false, "output full analysis as JSON (overrides other display flags)") rootCmd.Flags().BoolVar(&assertReachable, assertReachableFlag, false, "exit non-zero if no traffic is allowed from source to destination") rootCmd.Flags().BoolVar(&assertNotReachable, assertNotReachableFlag, false, "exit non-zero if any traffic can reach destination from source") } - -func printMergedResultsWarning() { - const mergedResultsWarning = "IMPORTANT: Reach detected more than one network path between the source and destination. Reach calls these paths \"network vectors\". The analysis result shown above is the merging of all network vectors' analysis results. The impact that infrastructure configuration has on actual network reachability might vary based on the way hosts are configured to use their network interfaces, and Reach is unable to access any configuration internal to a host. To see the network reachability across individual network vectors, run the command again with '--vectors'.\n\n" - _, _ = fmt.Fprint(os.Stderr, "\n"+mergedResultsWarning) -} diff --git a/cmd/warnings.go b/cmd/warnings.go new file mode 100644 index 0000000..fbd178c --- /dev/null +++ b/cmd/warnings.go @@ -0,0 +1,24 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/luhring/reach/reach" +) + +func printMergedResultsWarning() { + const mergedResultsWarning = "WARNING: Reach detected more than one network path between the source and destination. Reach calls these paths \"network vectors\". The analysis result shown above is the merging of all network vectors' analysis results. The impact that infrastructure configuration has on actual network reachability might vary based on the way hosts are configured to use their network interfaces, and Reach is unable to access any configuration internal to a host. To see the network reachability across individual network vectors, run the command again with '--" + vectorsFlag + "'.\n" + _, _ = fmt.Fprint(os.Stderr, "\n"+mergedResultsWarning) +} + +func warnIfAnyVectorHasRestrictedReturnTraffic(vectors []reach.NetworkVector) { + for _, v := range vectors { + if !v.ReturnTraffic.All() { + const restrictedVectorReturnTraffic = "WARNING: One or more of the analyzed network vectors has restrictions on network traffic allowed to return from the destination to the source. For details, run the command again with '--" + vectorsFlag + "'.\n" + _, _ = fmt.Fprintf(os.Stderr, "\n"+restrictedVectorReturnTraffic) + + return + } + } +} diff --git a/reach/acceptance/data/tf/ec2_instance_source_and_destination.tf b/reach/acceptance/data/tf/ec2_instance_source_and_destination.tf deleted file mode 100644 index b8813e8..0000000 --- a/reach/acceptance/data/tf/ec2_instance_source_and_destination.tf +++ /dev/null @@ -1,25 +0,0 @@ -resource "aws_instance" "source" { - ami = "${data.aws_ami.ubuntu.id}" - instance_type = "t2.micro" - - tags = { - Name = "source" - } -} - -resource "aws_instance" "destination" { - ami = "${data.aws_ami.ubuntu.id}" - instance_type = "t2.micro" - - tags = { - Name = "destination" - } -} - -output "source_id" { - value = aws_instance.source.id -} - -output "destination_id" { - value = aws_instance.destination.id -} diff --git a/reach/acceptance/data/tf/ec2_instances_same_subnet_all_traffic.tf b/reach/acceptance/data/tf/ec2_instances_same_subnet_all_traffic.tf new file mode 100644 index 0000000..a02bb55 --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_subnet_all_traffic.tf @@ -0,0 +1,28 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + + vpc_security_group_ids = [ + aws_security_group.outbound_allow_all.id + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + vpc_security_group_ids = [ + aws_security_group.inbound_allow_all.id + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/ec2_instances_same_subnet_https_via_two-way_sg_ip_match.tf b/reach/acceptance/data/tf/ec2_instances_same_subnet_https_via_two-way_sg_ip_match.tf new file mode 100644 index 0000000..13ff81a --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_subnet_https_via_two-way_sg_ip_match.tf @@ -0,0 +1,28 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + + vpc_security_group_ids = [ + aws_security_group.outbound_allow_https_to_ip.id + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + vpc_security_group_ids = [ + aws_security_group.inbound_allow_https_from_ip.id + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/ec2_instances_same_subnet_multiple_protocols.tf b/reach/acceptance/data/tf/ec2_instances_same_subnet_multiple_protocols.tf new file mode 100644 index 0000000..12b7e3a --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_subnet_multiple_protocols.tf @@ -0,0 +1,34 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + + vpc_security_group_ids = [ + aws_security_group.no_rules.id, + aws_security_group.outbound_allow_all_tcp.id, + aws_security_group.outbound_allow_all_udp_to_sg_no_rules.id, + aws_security_group.outbound_allow_esp.id, + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + vpc_security_group_ids = [ + aws_security_group.no_rules.id, + aws_security_group.inbound_allow_udp_dns_from_sg_no_rules.id, + aws_security_group.inbound_allow_ssh_from_all_ip_addresses.id, + aws_security_group.inbound_allow_esp.id, + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/ec2_instances_same_subnet_no_security_group_rules.tf b/reach/acceptance/data/tf/ec2_instances_same_subnet_no_security_group_rules.tf new file mode 100644 index 0000000..2b680d7 --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_subnet_no_security_group_rules.tf @@ -0,0 +1,27 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + vpc_security_group_ids = [ + aws_security_group.no_rules.id + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + vpc_security_group_ids = [ + aws_security_group.no_rules.id + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/ec2_instances_same_subnet_ssh.tf b/reach/acceptance/data/tf/ec2_instances_same_subnet_ssh.tf new file mode 100644 index 0000000..c5961e7 --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_subnet_ssh.tf @@ -0,0 +1,28 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + + vpc_security_group_ids = [ + aws_security_group.outbound_allow_all.id + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + vpc_security_group_ids = [ + aws_security_group.inbound_allow_ssh_from_all_ip_addresses.id + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/ec2_instances_same_subnet_udp_dns_via_sg_reference.tf b/reach/acceptance/data/tf/ec2_instances_same_subnet_udp_dns_via_sg_reference.tf new file mode 100644 index 0000000..2a6a7dc --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_subnet_udp_dns_via_sg_reference.tf @@ -0,0 +1,30 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + + vpc_security_group_ids = [ + aws_security_group.no_rules.id, + aws_security_group.outbound_allow_all_udp_to_sg_no_rules.id + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_1.id + + vpc_security_group_ids = [ + aws_security_group.no_rules.id, + aws_security_group.inbound_allow_udp_dns_from_sg_no_rules.id + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/ec2_instances_same_vpc_all_esp.tf b/reach/acceptance/data/tf/ec2_instances_same_vpc_all_esp.tf new file mode 100644 index 0000000..8372ac2 --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_vpc_all_esp.tf @@ -0,0 +1,28 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_2.id + + + vpc_security_group_ids = [ + aws_security_group.outbound_allow_esp.id + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_2_of_2.id + + vpc_security_group_ids = [ + aws_security_group.inbound_allow_all.id + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/ec2_instances_same_vpc_all_traffic.tf b/reach/acceptance/data/tf/ec2_instances_same_vpc_all_traffic.tf new file mode 100644 index 0000000..aaf4320 --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_vpc_all_traffic.tf @@ -0,0 +1,28 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_2.id + + + vpc_security_group_ids = [ + aws_security_group.outbound_allow_all.id + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_2_of_2.id + + vpc_security_group_ids = [ + aws_security_group.inbound_allow_all.id + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/ec2_instances_same_vpc_postgres.tf b/reach/acceptance/data/tf/ec2_instances_same_vpc_postgres.tf new file mode 100644 index 0000000..9fcaca3 --- /dev/null +++ b/reach/acceptance/data/tf/ec2_instances_same_vpc_postgres.tf @@ -0,0 +1,30 @@ +resource "aws_instance" "source" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_1_of_2.id + + + vpc_security_group_ids = [ + aws_security_group.no_rules.id, + aws_security_group.outbound_allow_postgres_to_sg_no_rules.id + ] + + tags = { + Name = "aat_source" + } +} + +resource "aws_instance" "destination" { + ami = data.aws_ami.ubuntu.id + instance_type = "t2.micro" + subnet_id = aws_subnet.subnet_2_of_2.id + + vpc_security_group_ids = [ + aws_security_group.no_rules.id, + aws_security_group.inbound_allow_postgres_from_sg_no_rules.id + ] + + tags = { + Name = "aat_destination" + } +} diff --git a/reach/acceptance/data/tf/network_acl_both_subnets_all_tcp.tf b/reach/acceptance/data/tf/network_acl_both_subnets_all_tcp.tf new file mode 100644 index 0000000..4294734 --- /dev/null +++ b/reach/acceptance/data/tf/network_acl_both_subnets_all_tcp.tf @@ -0,0 +1,30 @@ +resource "aws_network_acl" "both_subnets_all_tcp" { + vpc_id = "${aws_vpc.aat_vpc.id}" + + subnet_ids = [ + aws_subnet.subnet_1_of_2.id, + aws_subnet.subnet_2_of_2.id, + ] + + egress { + protocol = "tcp" + rule_no = 100 + action = "allow" + cidr_block = "0.0.0.0/0" + from_port = 0 + to_port = 65535 + } + + ingress { + protocol = "tcp" + rule_no = 100 + action = "allow" + cidr_block = "0.0.0.0/0" + from_port = 0 + to_port = 65535 + } + + tags = { + Name = "aat_both_subnets_all_tcp" + } +} diff --git a/reach/acceptance/data/tf/network_acl_both_subnets_all_traffic.tf b/reach/acceptance/data/tf/network_acl_both_subnets_all_traffic.tf new file mode 100644 index 0000000..05d830a --- /dev/null +++ b/reach/acceptance/data/tf/network_acl_both_subnets_all_traffic.tf @@ -0,0 +1,30 @@ +resource "aws_network_acl" "both_subnets_all_traffic" { + vpc_id = "${aws_vpc.aat_vpc.id}" + + subnet_ids = [ + aws_subnet.subnet_1_of_2.id, + aws_subnet.subnet_2_of_2.id, + ] + + egress { + protocol = "-1" + rule_no = 100 + action = "allow" + cidr_block = "0.0.0.0/0" + from_port = 0 + to_port = 0 + } + + ingress { + protocol = "-1" + rule_no = 100 + action = "allow" + cidr_block = "0.0.0.0/0" + from_port = 0 + to_port = 0 + } + + tags = { + Name = "aat_both_subnets_all_traffic" + } +} diff --git a/reach/acceptance/data/tf/network_acl_both_subnets_no_traffic.tf b/reach/acceptance/data/tf/network_acl_both_subnets_no_traffic.tf new file mode 100644 index 0000000..b65061c --- /dev/null +++ b/reach/acceptance/data/tf/network_acl_both_subnets_no_traffic.tf @@ -0,0 +1,12 @@ +resource "aws_network_acl" "both_subnets_no_traffic" { + vpc_id = "${aws_vpc.aat_vpc.id}" + + subnet_ids = [ + aws_subnet.subnet_1_of_2.id, + aws_subnet.subnet_2_of_2.id, + ] + + tags = { + Name = "aat_both_subnets_no_traffic" + } +} diff --git a/reach/acceptance/data/tf/network_acl_destination_subnet_tightened_postgres.tf b/reach/acceptance/data/tf/network_acl_destination_subnet_tightened_postgres.tf new file mode 100644 index 0000000..a2c0dd9 --- /dev/null +++ b/reach/acceptance/data/tf/network_acl_destination_subnet_tightened_postgres.tf @@ -0,0 +1,29 @@ +resource "aws_network_acl" "destination_subnet_tightened_postgres" { + vpc_id = "${aws_vpc.aat_vpc.id}" + + subnet_ids = [ + aws_subnet.subnet_2_of_2.id, + ] + + egress { + protocol = "tcp" + rule_no = 100 + action = "allow" + cidr_block = aws_subnet.subnet_1_of_2.cidr_block + from_port = 0 + to_port = 65535 + } + + ingress { + protocol = "tcp" + rule_no = 100 + action = "allow" + cidr_block = aws_subnet.subnet_1_of_2.cidr_block + from_port = 5432 + to_port = 5432 + } + + tags = { + Name = "aat_destination_subnet_tightened_postgres" + } +} diff --git a/reach/acceptance/data/tf/network_acl_source_subnet_tightened_postgres.tf b/reach/acceptance/data/tf/network_acl_source_subnet_tightened_postgres.tf new file mode 100644 index 0000000..c9ec91d --- /dev/null +++ b/reach/acceptance/data/tf/network_acl_source_subnet_tightened_postgres.tf @@ -0,0 +1,29 @@ +resource "aws_network_acl" "source_subnet_tightened_postgres" { + vpc_id = "${aws_vpc.aat_vpc.id}" + + subnet_ids = [ + aws_subnet.subnet_1_of_2.id, + ] + + egress { + protocol = "tcp" + rule_no = 100 + action = "allow" + cidr_block = aws_subnet.subnet_2_of_2.cidr_block + from_port = 5432 + to_port = 5432 + } + + ingress { + protocol = "tcp" + rule_no = 100 + action = "allow" + cidr_block = aws_subnet.subnet_2_of_2.cidr_block + from_port = 0 + to_port = 65535 + } + + tags = { + Name = "aat_source_subnet_tightened_postgres" + } +} diff --git a/reach/acceptance/data/tf/outputs.tf b/reach/acceptance/data/tf/outputs.tf new file mode 100644 index 0000000..e948350 --- /dev/null +++ b/reach/acceptance/data/tf/outputs.tf @@ -0,0 +1,7 @@ +output "source_id" { + value = aws_instance.source.id +} + +output "destination_id" { + value = aws_instance.destination.id +} diff --git a/reach/acceptance/data/tf/security_group_inbound_allow_all.tf b/reach/acceptance/data/tf/security_group_inbound_allow_all.tf new file mode 100644 index 0000000..b78a34d --- /dev/null +++ b/reach/acceptance/data/tf/security_group_inbound_allow_all.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "inbound_allow_all" { + name = "aat_inbound_allow_all" + description = "Allow all inbound traffic" + vpc_id = aws_vpc.aat_vpc.id + + ingress { + from_port = 0 + to_port = 0 + protocol = "-1" + cidr_blocks = ["0.0.0.0/0"] + } +} diff --git a/reach/acceptance/data/tf/security_group_inbound_allow_esp.tf b/reach/acceptance/data/tf/security_group_inbound_allow_esp.tf new file mode 100644 index 0000000..0b0018d --- /dev/null +++ b/reach/acceptance/data/tf/security_group_inbound_allow_esp.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "inbound_allow_esp" { + name = "aat_inbound_allow_esp" + description = "Allow all inbound ESP traffic" + vpc_id = aws_vpc.aat_vpc.id + + ingress { + from_port = 0 + to_port = 0 + protocol = "50" + cidr_blocks = ["0.0.0.0/0"] + } +} diff --git a/reach/acceptance/data/tf/security_group_inbound_allow_https_from_ip.tf b/reach/acceptance/data/tf/security_group_inbound_allow_https_from_ip.tf new file mode 100644 index 0000000..5a686c2 --- /dev/null +++ b/reach/acceptance/data/tf/security_group_inbound_allow_https_from_ip.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "inbound_allow_https_from_ip" { + name = "aat_inbound_allow_https_from_ip" + description = "Allow inbound HTTPS traffic from IP CIDR" + vpc_id = aws_vpc.aat_vpc.id + + ingress { + from_port = 443 + to_port = 443 + protocol = "tcp" + cidr_blocks = [aws_subnet.subnet_1_of_1.cidr_block] + } +} diff --git a/reach/acceptance/data/tf/security_group_inbound_allow_postgres_from_sg_no_rules.tf b/reach/acceptance/data/tf/security_group_inbound_allow_postgres_from_sg_no_rules.tf new file mode 100644 index 0000000..f26c8b6 --- /dev/null +++ b/reach/acceptance/data/tf/security_group_inbound_allow_postgres_from_sg_no_rules.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "inbound_allow_postgres_from_sg_no_rules" { + name = "aat_inbound_allow_postgres_from_sg_no_rules" + description = "Allow all inbound Postgres traffic from SG no rules" + vpc_id = aws_vpc.aat_vpc.id + + ingress { + from_port = 5432 + to_port = 5432 + protocol = "tcp" + security_groups = [aws_security_group.no_rules.id] + } +} diff --git a/reach/acceptance/data/tf/security_group_inbound_allow_ssh.tf b/reach/acceptance/data/tf/security_group_inbound_allow_ssh.tf new file mode 100644 index 0000000..5307dd3 --- /dev/null +++ b/reach/acceptance/data/tf/security_group_inbound_allow_ssh.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "inbound_allow_ssh_from_all_ip_addresses" { + name = "aat_inbound_allow_ssh_from_all_ip_addresses" + description = "Allow all SSH traffic" + vpc_id = aws_vpc.aat_vpc.id + + ingress { + from_port = 22 + to_port = 22 + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } +} diff --git a/reach/acceptance/data/tf/security_group_inbound_allow_udp_dns_from_sg_no_rules.tf b/reach/acceptance/data/tf/security_group_inbound_allow_udp_dns_from_sg_no_rules.tf new file mode 100644 index 0000000..01ff1ee --- /dev/null +++ b/reach/acceptance/data/tf/security_group_inbound_allow_udp_dns_from_sg_no_rules.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "inbound_allow_udp_dns_from_sg_no_rules" { + name = "aat_inbound_allow_udp_dns_from_sg_no_rules" + description = "Allow DNS (UDP) from SG no rules" + vpc_id = aws_vpc.aat_vpc.id + + ingress { + from_port = 53 + to_port = 53 + protocol = "udp" + security_groups = [aws_security_group.no_rules.id] + } +} diff --git a/reach/acceptance/data/tf/security_group_no_rules.tf b/reach/acceptance/data/tf/security_group_no_rules.tf new file mode 100644 index 0000000..2678ec2 --- /dev/null +++ b/reach/acceptance/data/tf/security_group_no_rules.tf @@ -0,0 +1,5 @@ +resource "aws_security_group" "no_rules" { + name = "aat_no_rules" + description = "No rules associated with this group" + vpc_id = aws_vpc.aat_vpc.id +} diff --git a/reach/acceptance/data/tf/security_group_outbound_allow_all.tf b/reach/acceptance/data/tf/security_group_outbound_allow_all.tf new file mode 100644 index 0000000..f09fe67 --- /dev/null +++ b/reach/acceptance/data/tf/security_group_outbound_allow_all.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "outbound_allow_all" { + name = "aat_outbound_allow_all" + description = "Allow all outbound traffic" + vpc_id = aws_vpc.aat_vpc.id + + egress { + from_port = 0 + to_port = 0 + protocol = "-1" + cidr_blocks = ["0.0.0.0/0"] + } +} diff --git a/reach/acceptance/data/tf/security_group_outbound_allow_all_tcp.tf b/reach/acceptance/data/tf/security_group_outbound_allow_all_tcp.tf new file mode 100644 index 0000000..a785b0b --- /dev/null +++ b/reach/acceptance/data/tf/security_group_outbound_allow_all_tcp.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "outbound_allow_all_tcp" { + name = "aat_outbound_allow_all_tcp" + description = "Allow all outbound TCP traffic" + vpc_id = aws_vpc.aat_vpc.id + + egress { + from_port = 0 + to_port = 65535 + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } +} diff --git a/reach/acceptance/data/tf/security_group_outbound_allow_all_udp_to_sg_no_rules.tf b/reach/acceptance/data/tf/security_group_outbound_allow_all_udp_to_sg_no_rules.tf new file mode 100644 index 0000000..33361d6 --- /dev/null +++ b/reach/acceptance/data/tf/security_group_outbound_allow_all_udp_to_sg_no_rules.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "outbound_allow_all_udp_to_sg_no_rules" { + name = "aat_outbound_allow_all_udp_to_sg_no_rules" + description = "Allow all outbound UDP traffic to SG no rules" + vpc_id = aws_vpc.aat_vpc.id + + egress { + from_port = 0 + to_port = 65535 + protocol = "udp" + security_groups = [aws_security_group.no_rules.id] + } +} diff --git a/reach/acceptance/data/tf/security_group_outbound_allow_esp.tf b/reach/acceptance/data/tf/security_group_outbound_allow_esp.tf new file mode 100644 index 0000000..40b5661 --- /dev/null +++ b/reach/acceptance/data/tf/security_group_outbound_allow_esp.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "outbound_allow_esp" { + name = "aat_outbound_allow_esp" + description = "Allow all outbound ESP traffic" + vpc_id = aws_vpc.aat_vpc.id + + egress { + from_port = 0 + to_port = 0 + protocol = "50" + cidr_blocks = ["0.0.0.0/0"] + } +} diff --git a/reach/acceptance/data/tf/security_group_outbound_allow_https_to_ip.tf b/reach/acceptance/data/tf/security_group_outbound_allow_https_to_ip.tf new file mode 100644 index 0000000..6b0ff7b --- /dev/null +++ b/reach/acceptance/data/tf/security_group_outbound_allow_https_to_ip.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "outbound_allow_https_to_ip" { + name = "aat_outbound_allow_https_to_ip" + description = "Allow outbound HTTPS traffic to IP CIDR" + vpc_id = aws_vpc.aat_vpc.id + + egress { + from_port = 443 + to_port = 443 + protocol = "tcp" + cidr_blocks = [aws_subnet.subnet_1_of_1.cidr_block] + } +} diff --git a/reach/acceptance/data/tf/security_group_outbound_allow_postgres_to_sg_no_rules.tf b/reach/acceptance/data/tf/security_group_outbound_allow_postgres_to_sg_no_rules.tf new file mode 100644 index 0000000..09e0e09 --- /dev/null +++ b/reach/acceptance/data/tf/security_group_outbound_allow_postgres_to_sg_no_rules.tf @@ -0,0 +1,12 @@ +resource "aws_security_group" "outbound_allow_postgres_to_sg_no_rules" { + name = "aat_outbound_allow_postgres_to_sg_no_rules" + description = "Allow all outbound Postgres traffic to SG no rules" + vpc_id = aws_vpc.aat_vpc.id + + egress { + from_port = 5432 + to_port = 5432 + protocol = "tcp" + security_groups = [aws_security_group.no_rules.id] + } +} diff --git a/reach/acceptance/data/tf/subnet_pair.tf b/reach/acceptance/data/tf/subnet_pair.tf new file mode 100644 index 0000000..7dfc73d --- /dev/null +++ b/reach/acceptance/data/tf/subnet_pair.tf @@ -0,0 +1,19 @@ +resource "aws_subnet" "subnet_1_of_2" { + vpc_id = aws_vpc.aat_vpc.id + cidr_block = "10.0.1.0/24" + map_public_ip_on_launch = false + + tags = { + Name = "aat_subnet_1_of_2" + } +} + +resource "aws_subnet" "subnet_2_of_2" { + vpc_id = aws_vpc.aat_vpc.id + cidr_block = "10.0.2.0/24" + map_public_ip_on_launch = false + + tags = { + Name = "aat_subnet_2_of_2" + } +} diff --git a/reach/acceptance/data/tf/subnet_single.tf b/reach/acceptance/data/tf/subnet_single.tf new file mode 100644 index 0000000..f26dce1 --- /dev/null +++ b/reach/acceptance/data/tf/subnet_single.tf @@ -0,0 +1,9 @@ +resource "aws_subnet" "subnet_1_of_1" { + vpc_id = aws_vpc.aat_vpc.id + cidr_block = "10.0.1.0/24" + map_public_ip_on_launch = false + + tags = { + Name = "aat_subnet_1_of_1" + } +} diff --git a/reach/acceptance/data/tf/vpc.tf b/reach/acceptance/data/tf/vpc.tf new file mode 100644 index 0000000..00702b0 --- /dev/null +++ b/reach/acceptance/data/tf/vpc.tf @@ -0,0 +1,7 @@ +resource "aws_vpc" "aat_vpc" { + cidr_block = "10.0.0.0/16" + + tags = { + Name = "aat_vpc" + } +} diff --git a/reach/acceptance/terraform/terraform.go b/reach/acceptance/terraform/terraform.go index 85c2845..d5b90bc 100644 --- a/reach/acceptance/terraform/terraform.go +++ b/reach/acceptance/terraform/terraform.go @@ -73,7 +73,7 @@ func (tf *Terraform) CleanUp() error { } // LoadFilesFromDir calls LoadFile for all specified files within the specified directory. -func (tf *Terraform) LoadFilesFromDir(dir string, files []string) error { +func (tf *Terraform) LoadFilesFromDir(dir string, files ...string) error { for _, file := range files { fullPath := path.Join(dir, file) err := tf.LoadFile(fullPath) diff --git a/reach/analysis.go b/reach/analysis.go index 0fbe416..bd065cb 100644 --- a/reach/analysis.go +++ b/reach/analysis.go @@ -36,7 +36,7 @@ func (a *Analysis) MergedTraffic() (TrafficContent, error) { for _, v := range a.NetworkVectors { if t := v.Traffic; t != nil { - mergedTrafficContent, err := result.Merge(*v.Traffic) + mergedTrafficContent, err := result.Merge(*t) if err != nil { return TrafficContent{}, err } @@ -47,3 +47,62 @@ func (a *Analysis) MergedTraffic() (TrafficContent, error) { return result, nil } + +// MergedReturnTraffic gets the return TrafficContent results of each of the analysis's network vectors and returns them as a merged TrafficContent. +func (a *Analysis) MergedReturnTraffic() (TrafficContent, error) { + result := newTrafficContent() + + for _, v := range a.NetworkVectors { + if t := v.ReturnTraffic; t != nil { + mergedTrafficContent, err := result.Merge(*t) + if err != nil { + return TrafficContent{}, err + } + + result = mergedTrafficContent + } + } + + return result, nil +} + +// PassesAssertReachable determines if the analysis implies the source can reach the destination over at least one protocol whose return path is unobstructed. +func (a Analysis) PassesAssertReachable() bool { + forwardTrafficCanReach := false + + // For each vector, see if there is an obstructed path + for _, vector := range a.NetworkVectors { + if !vector.Traffic.None() { + forwardTrafficCanReach = true + + for _, p := range vector.Traffic.Protocols() { + // is return path obstructed (at all) for this protocol? + if protocolReturnTraffic := vector.ReturnTraffic.protocol(p); !protocolReturnTraffic.complete() { + return false + } + } + } + } + + if !forwardTrafficCanReach { + return false + } + + return true +} + +// PassesAssertNotReachable determines if the analysis implies the source has no way to send network traffic to the destination. +func (a Analysis) PassesAssertNotReachable() bool { + // Here, we want to be more careful / conservative. If any traffic can get out to destination, fail, regardless of return traffic. + + forwardTraffic, err := a.MergedTraffic() + if err != nil { + return false + } + + if !forwardTraffic.None() { + return false + } + + return true +} diff --git a/reach/analyzer/analyzer.go b/reach/analyzer/analyzer.go index 1abe67b..0cdf9af 100644 --- a/reach/analyzer/analyzer.go +++ b/reach/analyzer/analyzer.go @@ -31,7 +31,7 @@ func (a *Analyzer) buildResourceCollection(subjects []*reach.Subject, provider a case aws.SubjectKindEC2Instance: id := subject.ID - ec2Instance, err := provider.GetEC2Instance(id) + ec2Instance, err := provider.EC2Instance(id) if err != nil { log.Fatalf("couldn't get resource: %v", err) } @@ -88,13 +88,20 @@ func (a *Analyzer) Analyze(subjects ...*reach.Subject) (*reach.Analysis, error) } trafficContents := reach.TrafficContentsFromFactors(factors) - trafficContent, err := reach.NewTrafficContentFromIntersectingMultiple(trafficContents) if err != nil { return nil, err } + returnTrafficContents := reach.ReturnTrafficContentsFromFactors(factors) + returnTrafficContent, err := reach.NewTrafficContentFromIntersectingMultiple(returnTrafficContents) + if err != nil { + return nil, err + } + processedVector.Traffic = &trafficContent + processedVector.ReturnTraffic = &returnTrafficContent + processedNetworkVectors[i] = processedVector } diff --git a/reach/analyzer/analyzer_test.go b/reach/analyzer/analyzer_test.go new file mode 100644 index 0000000..3c437de --- /dev/null +++ b/reach/analyzer/analyzer_test.go @@ -0,0 +1,314 @@ +package analyzer + +import ( + "log" + "testing" + + "github.com/luhring/reach/reach" + "github.com/luhring/reach/reach/acceptance" + "github.com/luhring/reach/reach/acceptance/terraform" + "github.com/luhring/reach/reach/aws" + "github.com/luhring/reach/reach/set" +) + +func TestAnalyze(t *testing.T) { + acceptance.Check(t) + + type testCase struct { + name string + files []string + expectedForwardTraffic reach.TrafficContent + expectedReturnTraffic reach.TrafficContent + } + + groupings := []struct { + name string + files []string + cases []testCase + }{ + { + "same subnet", + []string{ + "main.tf", + "outputs.tf", + "ami_ubuntu.tf", + "vpc.tf", + "subnet_single.tf", + }, + []testCase{ + { + "no security group rules", + []string{ + "ec2_instances_same_subnet_no_security_group_rules.tf", + "security_group_no_rules.tf", + }, + reach.NewTrafficContentForNoTraffic(), + reach.NewTrafficContentForAllTraffic(), + }, + { + "multiple protocols", + []string{ + "ec2_instances_same_subnet_multiple_protocols.tf", + "security_group_no_rules.tf", + "security_group_outbound_allow_all_udp_to_sg_no_rules.tf", + "security_group_outbound_allow_esp.tf", + "security_group_outbound_allow_all_tcp.tf", + "security_group_inbound_allow_udp_dns_from_sg_no_rules.tf", + "security_group_inbound_allow_esp.tf", + "security_group_inbound_allow_ssh.tf", + }, + trafficAssorted(), + reach.NewTrafficContentForAllTraffic(), + }, + { + "UDP DNS via SG reference", + []string{ + "ec2_instances_same_subnet_udp_dns_via_sg_reference.tf", + "security_group_no_rules.tf", + "security_group_outbound_allow_all_udp_to_sg_no_rules.tf", + "security_group_inbound_allow_udp_dns_from_sg_no_rules.tf", + }, + trafficDNS(), + reach.NewTrafficContentForAllTraffic(), + }, + { + "HTTPS via two-way IP match", + []string{ + "ec2_instances_same_subnet_https_via_two-way_sg_ip_match.tf", + "security_group_outbound_allow_https_to_ip.tf", + "security_group_inbound_allow_https_from_ip.tf", + }, + trafficHTTPS(), + reach.NewTrafficContentForAllTraffic(), + }, + { + "SSH", + []string{ + "ec2_instances_same_subnet_ssh.tf", + "security_group_outbound_allow_all.tf", + "security_group_inbound_allow_ssh.tf", + }, + trafficSSH(), + reach.NewTrafficContentForAllTraffic(), + }, + { + "all traffic", + []string{ + "ec2_instances_same_subnet_all_traffic.tf", + "security_group_outbound_allow_all.tf", + "security_group_inbound_allow_all.tf", + }, + reach.NewTrafficContentForAllTraffic(), + reach.NewTrafficContentForAllTraffic(), + }, + }, + }, + { + "same VPC", + []string{ + "main.tf", + "outputs.tf", + "ami_ubuntu.tf", + "vpc.tf", + "subnet_pair.tf", + }, + []testCase{ + { + "all traffic", + []string{ + "network_acl_both_subnets_all_traffic.tf", + "ec2_instances_same_vpc_all_traffic.tf", + "security_group_outbound_allow_all.tf", + "security_group_inbound_allow_all.tf", + }, + reach.NewTrafficContentForAllTraffic(), + reach.NewTrafficContentForAllTraffic(), + }, + { + "no NACL allow rules", + []string{ + "network_acl_both_subnets_no_traffic.tf", + "ec2_instances_same_vpc_all_traffic.tf", + "security_group_outbound_allow_all.tf", + "security_group_inbound_allow_all.tf", + }, + reach.NewTrafficContentForNoTraffic(), + reach.NewTrafficContentForNoTraffic(), + }, + { + "NACL rules don't match SG rules", + []string{ + "network_acl_both_subnets_all_tcp.tf", + "ec2_instances_same_vpc_all_esp.tf", + "security_group_outbound_allow_esp.tf", + "security_group_inbound_allow_all.tf", + }, + reach.NewTrafficContentForNoTraffic(), + trafficTCP(), // TODO: Revisit return traffic calculation for this scenario + }, + { + "Postgres with tightened rules", + []string{ + "network_acl_source_subnet_tightened_postgres.tf", + "network_acl_destination_subnet_tightened_postgres.tf", + "ec2_instances_same_vpc_postgres.tf", + "security_group_no_rules.tf", + "security_group_outbound_allow_postgres_to_sg_no_rules.tf", + "security_group_inbound_allow_postgres_from_sg_no_rules.tf", + }, + trafficPostgres(), + trafficTCP(), + }, + }, + }, + } + + for _, g := range groupings { + t.Run(g.name, func(t *testing.T) { + for _, tc := range g.cases { + t.Run(tc.name, func(t *testing.T) { + // if tc.name != "Postgres with tightened rules" || g.name != "same VPC" { // TODO: remove this to run full test suite + // t.SkipNow() + // } + + // Setup (and deferred teardown) + tf, err := terraform.New(t) + if err != nil { + t.Fatal(err) + } + defer func() { + err = tf.CleanUp() + if err != nil { + t.Fatal(err) + } + }() + + err = tf.LoadFilesFromDir( + "../acceptance/data/tf", + append(g.files, tc.files...)..., + ) + if err != nil { + t.Fatal(err) + } + + err = tf.Init() + if err != nil { + t.Fatal(err) + } + + defer func() { + err = tf.Destroy() // Putting this before apply so that we're not left with some resources not destroyed after failure from apply step. + if err != nil { + t.Fatal(err) + } + }() + err = tf.PlanAndApply() + if err != nil { + t.Fatal(err) + } + + sourceID, err := tf.Output("source_id") + if err != nil { + t.Fatal(err) + } + destinationID, err := tf.Output("destination_id") + if err != nil { + t.Fatal(err) + } + + source, err := aws.NewEC2InstanceSubject(sourceID, reach.SubjectRoleSource) + if err != nil { + t.Fatal(err) + } + destination, err := aws.NewEC2InstanceSubject(destinationID, reach.SubjectRoleDestination) + if err != nil { + t.Fatal(err) + } + + // Analyze + + analyzer := New() + + log.Print("analyzing...") + analysis, err := analyzer.Analyze(source, destination) + if err != nil { + t.Fatal(err) + } + + // Tests + + log.Print("verifying analysis results...") + + if forwardTraffic := analysis.NetworkVectors[0].Traffic; forwardTraffic.String() != tc.expectedForwardTraffic.String() { // TODO: consider a better comparison method besides strings + t.Errorf("forward traffic -- expected:\n%v\nbut was:\n%v\n", tc.expectedForwardTraffic, forwardTraffic) + } else { + log.Print("✓ forward traffic content is correct") + } + + if returnTraffic := analysis.NetworkVectors[0].ReturnTraffic; returnTraffic.String() != tc.expectedReturnTraffic.String() { + t.Errorf("return traffic -- expected:\n%v\nbut was:\n%v\n", tc.expectedReturnTraffic, returnTraffic) + } else { + log.Print("✓ return traffic content is correct") + } + }) + } + }) + } +} + +func trafficSSH() reach.TrafficContent { + ports, err := set.NewPortSetFromRange(22, 22) + if err != nil { + panic(err) + } + + return reach.NewTrafficContentForPorts(reach.ProtocolTCP, ports) +} + +func trafficHTTPS() reach.TrafficContent { + ports, err := set.NewPortSetFromRange(443, 443) + if err != nil { + panic(err) + } + + return reach.NewTrafficContentForPorts(reach.ProtocolTCP, ports) +} + +func trafficDNS() reach.TrafficContent { + ports, err := set.NewPortSetFromRange(53, 53) + if err != nil { + panic(err) + } + + return reach.NewTrafficContentForPorts(reach.ProtocolUDP, ports) +} + +func trafficESP() reach.TrafficContent { + return reach.NewTrafficContentForCustomProtocol(50, true) +} + +func trafficAssorted() reach.TrafficContent { + tc, err := reach.NewTrafficContentFromMergingMultiple([]reach.TrafficContent{ + trafficDNS(), + trafficSSH(), + trafficESP(), + }) + if err != nil { + panic(err) + } + + return tc +} + +func trafficTCP() reach.TrafficContent { + return reach.NewTrafficContentForPorts(reach.ProtocolTCP, set.NewFullPortSet()) +} + +func trafficPostgres() reach.TrafficContent { + ports, err := set.NewPortSetFromRange(5432, 5432) + if err != nil { + panic(err) + } + + return reach.NewTrafficContentForPorts(reach.ProtocolTCP, ports) +} diff --git a/reach/aws/api/ec2_instance.go b/reach/aws/api/ec2_instance.go index 8c94799..c59f122 100644 --- a/reach/aws/api/ec2_instance.go +++ b/reach/aws/api/ec2_instance.go @@ -9,8 +9,8 @@ import ( reachAWS "github.com/luhring/reach/reach/aws" ) -// GetEC2Instance queries the AWS API for an EC2 instance matching the given ID. -func (provider *ResourceProvider) GetEC2Instance(id string) (*reachAWS.EC2Instance, error) { +// EC2Instance queries the AWS API for an EC2 instance matching the given ID. +func (provider *ResourceProvider) EC2Instance(id string) (*reachAWS.EC2Instance, error) { input := &ec2.DescribeInstancesInput{ InstanceIds: []*string{ aws.String(id), @@ -38,8 +38,8 @@ func (provider *ResourceProvider) GetEC2Instance(id string) (*reachAWS.EC2Instan return &instance, nil } -// GetAllEC2Instances queries the AWS API for all EC2 instances. -func (provider *ResourceProvider) GetAllEC2Instances() ([]reachAWS.EC2Instance, error) { +// AllEC2Instances queries the AWS API for all EC2 instances. +func (provider *ResourceProvider) AllEC2Instances() ([]reachAWS.EC2Instance, error) { const errFormat = "unable to get all EC2 instances: %v" describeInstancesOutput, err := provider.ec2.DescribeInstances(nil) @@ -60,9 +60,9 @@ func (provider *ResourceProvider) GetAllEC2Instances() ([]reachAWS.EC2Instance, func newEC2InstanceFromAPI(instance *ec2.Instance) reachAWS.EC2Instance { return reachAWS.EC2Instance{ ID: aws.StringValue(instance.InstanceId), - NameTag: getNameTag(instance.Tags), + NameTag: nameTag(instance.Tags), State: aws.StringValue(instance.State.Name), - NetworkInterfaceAttachments: getNetworkInterfaceAttachments(instance), + NetworkInterfaceAttachments: networkInterfaceAttachments(instance), } } @@ -79,7 +79,7 @@ func extractEC2Instances(reservations []*ec2.Reservation) ([]reachAWS.EC2Instanc return instances, nil } -func getNetworkInterfaceAttachments(instance *ec2.Instance) []reachAWS.NetworkInterfaceAttachment { +func networkInterfaceAttachments(instance *ec2.Instance) []reachAWS.NetworkInterfaceAttachment { var attachments []reachAWS.NetworkInterfaceAttachment if instance.NetworkInterfaces != nil && len(instance.NetworkInterfaces) > 0 { diff --git a/reach/aws/api/elastic_network_interface.go b/reach/aws/api/elastic_network_interface.go index 8743d4c..0cbd4e0 100644 --- a/reach/aws/api/elastic_network_interface.go +++ b/reach/aws/api/elastic_network_interface.go @@ -9,8 +9,8 @@ import ( reachAWS "github.com/luhring/reach/reach/aws" ) -// GetElasticNetworkInterface queries the AWS API for an elastic network interface matching the given ID. -func (provider *ResourceProvider) GetElasticNetworkInterface(id string) (*reachAWS.ElasticNetworkInterface, error) { +// ElasticNetworkInterface queries the AWS API for an elastic network interface matching the given ID. +func (provider *ResourceProvider) ElasticNetworkInterface(id string) (*reachAWS.ElasticNetworkInterface, error) { input := &ec2.DescribeNetworkInterfacesInput{ NetworkInterfaceIds: []*string{ aws.String(id), @@ -30,23 +30,23 @@ func (provider *ResourceProvider) GetElasticNetworkInterface(id string) (*reachA } func newElasticNetworkInterfaceFromAPI(eni *ec2.NetworkInterface) reachAWS.ElasticNetworkInterface { - publicIPv4Address := getPublicIPAddress(eni.Association) - privateIPv4Addresses := getPrivateIPAddresses(eni.PrivateIpAddresses) - ipv6Addresses := getIPv6Addresses(eni.Ipv6Addresses) + publicIPv4Address := publicIPAddress(eni.Association) + privateIPv4Addresses := privateIPAddresses(eni.PrivateIpAddresses) + ipv6Addresses := ipv6Addresses(eni.Ipv6Addresses) return reachAWS.ElasticNetworkInterface{ ID: aws.StringValue(eni.NetworkInterfaceId), - NameTag: getNameTag(eni.TagSet), + NameTag: nameTag(eni.TagSet), SubnetID: aws.StringValue(eni.SubnetId), VPCID: aws.StringValue(eni.VpcId), - SecurityGroupIDs: getSecurityGroupIDs(eni.Groups), + SecurityGroupIDs: securityGroupIDs(eni.Groups), PublicIPv4Address: publicIPv4Address, PrivateIPv4Addresses: privateIPv4Addresses, IPv6Addresses: ipv6Addresses, } } -func getSecurityGroupID(identifier *ec2.GroupIdentifier) string { +func securityGroupID(identifier *ec2.GroupIdentifier) string { if identifier == nil { return "" } @@ -54,17 +54,17 @@ func getSecurityGroupID(identifier *ec2.GroupIdentifier) string { return aws.StringValue(identifier.GroupId) } -func getSecurityGroupIDs(identifiers []*ec2.GroupIdentifier) []string { +func securityGroupIDs(identifiers []*ec2.GroupIdentifier) []string { ids := make([]string, len(identifiers)) for i, identifier := range identifiers { - ids[i] = getSecurityGroupID(identifier) + ids[i] = securityGroupID(identifier) } return ids } -func getPrivateIPAddress(address *ec2.NetworkInterfacePrivateIpAddress) net.IP { +func privateIPAddress(address *ec2.NetworkInterfacePrivateIpAddress) net.IP { if address == nil { return net.IP{} } @@ -72,17 +72,17 @@ func getPrivateIPAddress(address *ec2.NetworkInterfacePrivateIpAddress) net.IP { return net.ParseIP(aws.StringValue(address.PrivateIpAddress)) } -func getPrivateIPAddresses(addresses []*ec2.NetworkInterfacePrivateIpAddress) []net.IP { +func privateIPAddresses(addresses []*ec2.NetworkInterfacePrivateIpAddress) []net.IP { ips := make([]net.IP, len(addresses)) for i, address := range addresses { - ips[i] = getPrivateIPAddress(address) + ips[i] = privateIPAddress(address) } return ips } -func getIPv6Address(address *ec2.NetworkInterfaceIpv6Address) net.IP { +func ipv6Address(address *ec2.NetworkInterfaceIpv6Address) net.IP { if address == nil { return net.IP{} } @@ -90,17 +90,17 @@ func getIPv6Address(address *ec2.NetworkInterfaceIpv6Address) net.IP { return net.ParseIP(aws.StringValue(address.Ipv6Address)) } -func getIPv6Addresses(addresses []*ec2.NetworkInterfaceIpv6Address) []net.IP { +func ipv6Addresses(addresses []*ec2.NetworkInterfaceIpv6Address) []net.IP { ips := make([]net.IP, len(addresses)) for i, address := range addresses { - ips[i] = getIPv6Address(address) + ips[i] = ipv6Address(address) } return ips } -func getPublicIPAddress(association *ec2.NetworkInterfaceAssociation) net.IP { +func publicIPAddress(association *ec2.NetworkInterfaceAssociation) net.IP { if association == nil { return net.IP{} } diff --git a/reach/aws/api/network_acl.go b/reach/aws/api/network_acl.go index abd6821..5f8b681 100644 --- a/reach/aws/api/network_acl.go +++ b/reach/aws/api/network_acl.go @@ -11,8 +11,8 @@ import ( reachAWS "github.com/luhring/reach/reach/aws" ) -// GetNetworkACL queries the AWS API for a network ACL matching the given ID. -func (provider *ResourceProvider) GetNetworkACL(id string) (*reachAWS.NetworkACL, error) { +// NetworkACL queries the AWS API for a network ACL matching the given ID. +func (provider *ResourceProvider) NetworkACL(id string) (*reachAWS.NetworkACL, error) { input := &ec2.DescribeNetworkAclsInput{ NetworkAclIds: []*string{ aws.String(id), @@ -32,8 +32,8 @@ func (provider *ResourceProvider) GetNetworkACL(id string) (*reachAWS.NetworkACL } func newNetworkACLFromAPI(networkACL *ec2.NetworkAcl) reachAWS.NetworkACL { - inboundRules := getInboundNetworkACLRules(networkACL.Entries) - outboundRules := getOutboundNetworkACLRules(networkACL.Entries) + inboundRules := inboundNetworkACLRules(networkACL.Entries) + outboundRules := outboundNetworkACLRules(networkACL.Entries) return reachAWS.NetworkACL{ ID: aws.StringValue(networkACL.NetworkAclId), @@ -42,17 +42,17 @@ func newNetworkACLFromAPI(networkACL *ec2.NetworkAcl) reachAWS.NetworkACL { } } -func getNetworkACLRulesForSingleDirection(entries []*ec2.NetworkAclEntry, inbound bool) []reachAWS.NetworkACLRule { +func networkACLRulesForSingleDirection(entries []*ec2.NetworkAclEntry, inbound bool) []reachAWS.NetworkACLRule { if entries == nil { return nil } - rules := make([]reachAWS.NetworkACLRule, len(entries)) + rules := []reachAWS.NetworkACLRule{} - for i, entry := range entries { + for _, entry := range entries { if entry != nil { if inbound != aws.BoolValue(entry.Egress) { - rules[i] = getNetworkACLRule(entry) + rules = append(rules, networkACLRule(entry)) } } } @@ -60,15 +60,15 @@ func getNetworkACLRulesForSingleDirection(entries []*ec2.NetworkAclEntry, inboun return rules } -func getInboundNetworkACLRules(entries []*ec2.NetworkAclEntry) []reachAWS.NetworkACLRule { - return getNetworkACLRulesForSingleDirection(entries, true) +func inboundNetworkACLRules(entries []*ec2.NetworkAclEntry) []reachAWS.NetworkACLRule { + return networkACLRulesForSingleDirection(entries, true) } -func getOutboundNetworkACLRules(entries []*ec2.NetworkAclEntry) []reachAWS.NetworkACLRule { - return getNetworkACLRulesForSingleDirection(entries, false) +func outboundNetworkACLRules(entries []*ec2.NetworkAclEntry) []reachAWS.NetworkACLRule { + return networkACLRulesForSingleDirection(entries, false) } -func getNetworkACLRule(entry *ec2.NetworkAclEntry) reachAWS.NetworkACLRule { // note: this function ignores rule direction (inbound vs. outbound) +func networkACLRule(entry *ec2.NetworkAclEntry) reachAWS.NetworkACLRule { // note: this function ignores rule direction (inbound vs. outbound) if entry == nil { return reachAWS.NetworkACLRule{} } diff --git a/reach/aws/api/resource_provider.go b/reach/aws/api/resource_provider.go index 4a160e0..71b9928 100644 --- a/reach/aws/api/resource_provider.go +++ b/reach/aws/api/resource_provider.go @@ -33,7 +33,7 @@ func NewResourceProvider() *ResourceProvider { } } -func getNameTag(tags []*ec2.Tag) string { +func nameTag(tags []*ec2.Tag) string { if tags != nil && len(tags) > 0 { for _, tag := range tags { if aws.StringValue(tag.Key) == "Name" { diff --git a/reach/aws/api/route_table.go b/reach/aws/api/route_table.go index 9679875..722cff4 100644 --- a/reach/aws/api/route_table.go +++ b/reach/aws/api/route_table.go @@ -7,8 +7,8 @@ import ( reachAWS "github.com/luhring/reach/reach/aws" ) -// GetRouteTable queries the AWS API for a route table matching the given ID. -func (provider *ResourceProvider) GetRouteTable(id string) (*reachAWS.RouteTable, error) { +// RouteTable queries the AWS API for a route table matching the given ID. +func (provider *ResourceProvider) RouteTable(id string) (*reachAWS.RouteTable, error) { input := &ec2.DescribeRouteTablesInput{ RouteTableIds: []*string{ aws.String(id), @@ -37,11 +37,11 @@ func newRouteTableFromAPI(routeTable *ec2.RouteTable) reachAWS.RouteTable { } } -func getRouteTableRoutes(routes []*ec2.Route) []reachAWS.RouteTableRoute { +func routeTableRoutes(routes []*ec2.Route) []reachAWS.RouteTableRoute { return nil // TODO: implement } -func getRouteTableRoute(route *ec2.Route) reachAWS.RouteTableRoute { +func routeTableRoute(route *ec2.Route) reachAWS.RouteTableRoute { if route == nil { return reachAWS.RouteTableRoute{} } diff --git a/reach/aws/api/security_group.go b/reach/aws/api/security_group.go index cb78f25..157d388 100644 --- a/reach/aws/api/security_group.go +++ b/reach/aws/api/security_group.go @@ -12,8 +12,8 @@ import ( "github.com/luhring/reach/reach/set" ) -// GetSecurityGroup queries the AWS API for a security group matching the given ID. -func (provider *ResourceProvider) GetSecurityGroup(id string) (*reachAWS.SecurityGroup, error) { +// SecurityGroup queries the AWS API for a security group matching the given ID. +func (provider *ResourceProvider) SecurityGroup(id string) (*reachAWS.SecurityGroup, error) { input := &ec2.DescribeSecurityGroupsInput{ GroupIds: []*string{ aws.String(id), @@ -33,12 +33,12 @@ func (provider *ResourceProvider) GetSecurityGroup(id string) (*reachAWS.Securit } func newSecurityGroupFromAPI(securityGroup *ec2.SecurityGroup) reachAWS.SecurityGroup { - inboundRules := getSecurityGroupRules(securityGroup.IpPermissions) - outboundRules := getSecurityGroupRules(securityGroup.IpPermissionsEgress) + inboundRules := securityGroupRules(securityGroup.IpPermissions) + outboundRules := securityGroupRules(securityGroup.IpPermissionsEgress) return reachAWS.SecurityGroup{ ID: aws.StringValue(securityGroup.GroupId), - NameTag: getNameTag(securityGroup.Tags), + NameTag: nameTag(securityGroup.Tags), GroupName: aws.StringValue(securityGroup.GroupName), VPCID: aws.StringValue(securityGroup.VpcId), InboundRules: inboundRules, @@ -46,7 +46,7 @@ func newSecurityGroupFromAPI(securityGroup *ec2.SecurityGroup) reachAWS.Security } } -func getSecurityGroupRules(inputRules []*ec2.IpPermission) []reachAWS.SecurityGroupRule { +func securityGroupRules(inputRules []*ec2.IpPermission) []reachAWS.SecurityGroupRule { if inputRules == nil { return nil } @@ -55,19 +55,19 @@ func getSecurityGroupRules(inputRules []*ec2.IpPermission) []reachAWS.SecurityGr for i, inputRule := range inputRules { if inputRule != nil { - rules[i] = getSecurityGroupRule(inputRule) + rules[i] = securityGroupRule(inputRule) } } return rules } -func getSecurityGroupRule(rule *ec2.IpPermission) reachAWS.SecurityGroupRule { // note: this function ignores rule direction (inbound vs. outbound) +func securityGroupRule(rule *ec2.IpPermission) reachAWS.SecurityGroupRule { // note: this function ignores rule direction (inbound vs. outbound) if rule == nil { return reachAWS.SecurityGroupRule{} } - tc, err := newTrafficContentFromAWSIPPermission(rule) + tc, err := trafficContentFromAWSIPPermission(rule) if err != nil { panic(err) // TODO: Better error handling } @@ -78,14 +78,14 @@ func getSecurityGroupRule(rule *ec2.IpPermission) reachAWS.SecurityGroupRule { / if rule.UserIdGroupPairs != nil { firstPair := rule.UserIdGroupPairs[0] // if panicking, see above to-do... - targetSecurityGroupReferenceID = getSecurityGroupReferenceID(firstPair) - targetSecurityGroupReferenceAccountID = getSecurityGroupReferenceAccountID(firstPair) + targetSecurityGroupReferenceID = securityGroupReferenceID(firstPair) + targetSecurityGroupReferenceAccountID = securityGroupReferenceAccountID(firstPair) } // TODO: Handle prefix lists (and thus VPC endpoints) // for context: https://docs.aws.amazon.com/vpc/latest/userguide/vpce-gateway.html - targetIPNetworks := getIPNetworksFromSecurityGroupRule(rule.IpRanges, rule.Ipv6Ranges) + targetIPNetworks := ipNetworksFromSecurityGroupRule(rule.IpRanges, rule.Ipv6Ranges) return reachAWS.SecurityGroupRule{ TrafficContent: tc, @@ -117,7 +117,7 @@ func newPortSetFromAWSIPPermission(permission *ec2.IpPermission) (set.PortSet, e return set.NewPortSetFromRange(uint16(from), uint16(to)) } -func getSecurityGroupReferenceID(pair *ec2.UserIdGroupPair) string { +func securityGroupReferenceID(pair *ec2.UserIdGroupPair) string { if pair == nil { return "" } @@ -125,7 +125,7 @@ func getSecurityGroupReferenceID(pair *ec2.UserIdGroupPair) string { return aws.StringValue(pair.GroupId) } -func getSecurityGroupReferenceAccountID(pair *ec2.UserIdGroupPair) string { +func securityGroupReferenceAccountID(pair *ec2.UserIdGroupPair) string { if pair == nil { return "" } @@ -133,7 +133,7 @@ func getSecurityGroupReferenceAccountID(pair *ec2.UserIdGroupPair) string { return aws.StringValue(pair.UserId) } -func getIPNetworksFromSecurityGroupRule(ipv4Ranges []*ec2.IpRange, ipv6Ranges []*ec2.Ipv6Range) []*net.IPNet { +func ipNetworksFromSecurityGroupRule(ipv4Ranges []*ec2.IpRange, ipv6Ranges []*ec2.Ipv6Range) []*net.IPNet { networks := make([]*net.IPNet, len(ipv4Ranges)+len(ipv6Ranges)) for i, block := range ipv4Ranges { @@ -157,7 +157,7 @@ func getIPNetworksFromSecurityGroupRule(ipv4Ranges []*ec2.IpRange, ipv6Ranges [] return networks } -func newTrafficContentFromAWSIPPermission(permission *ec2.IpPermission) (reach.TrafficContent, error) { +func trafficContentFromAWSIPPermission(permission *ec2.IpPermission) (reach.TrafficContent, error) { const errCreation = "unable to create content: %v" protocol, err := convertAWSIPProtocolStringToProtocol(permission.IpProtocol) diff --git a/reach/aws/api/security_group_reference.go b/reach/aws/api/security_group_reference.go index e23a168..9fdb157 100644 --- a/reach/aws/api/security_group_reference.go +++ b/reach/aws/api/security_group_reference.go @@ -4,12 +4,12 @@ import ( reachAWS "github.com/luhring/reach/reach/aws" ) -// GetSecurityGroupReference queries the AWS API for a security group matching the given ID, but returns a security group reference representation instead of the full security group representation. -func (provider *ResourceProvider) GetSecurityGroupReference(id, accountID string) (*reachAWS.SecurityGroupReference, error) { +// SecurityGroupReference queries the AWS API for a security group matching the given ID, but returns a security group reference representation instead of the full security group representation. +func (provider *ResourceProvider) SecurityGroupReference(id, accountID string) (*reachAWS.SecurityGroupReference, error) { // TODO: Incorporate account ID in search. // In the meantime, this will be a known bug, where other accounts are not considered. - sg, err := provider.GetSecurityGroup(id) + sg, err := provider.SecurityGroup(id) if err != nil { return nil, err } diff --git a/reach/aws/api/subnet.go b/reach/aws/api/subnet.go index 31c2f02..247bbf8 100644 --- a/reach/aws/api/subnet.go +++ b/reach/aws/api/subnet.go @@ -7,8 +7,8 @@ import ( reachAWS "github.com/luhring/reach/reach/aws" ) -// GetSubnet queries the AWS API for a subnet matching the given ID. -func (provider *ResourceProvider) GetSubnet(id string) (*reachAWS.Subnet, error) { +// Subnet queries the AWS API for a subnet matching the given ID. +func (provider *ResourceProvider) Subnet(id string) (*reachAWS.Subnet, error) { input := &ec2.DescribeSubnetsInput{ SubnetIds: []*string{ aws.String(id), @@ -23,13 +23,43 @@ func (provider *ResourceProvider) GetSubnet(id string) (*reachAWS.Subnet, error) return nil, err } - subnet := newSubnetFromAPI(result.Subnets[0]) + awsSubnet := result.Subnets[0] + networkACLID, err := provider.networkACLIDFromSubnetID(aws.StringValue(awsSubnet.SubnetId)) + if err != nil { + return nil, err + } + + subnet := newSubnetFromAPI(result.Subnets[0], networkACLID) return &subnet, nil } -func newSubnetFromAPI(subnet *ec2.Subnet) reachAWS.Subnet { +func newSubnetFromAPI(subnet *ec2.Subnet, networkACLID string) reachAWS.Subnet { return reachAWS.Subnet{ - ID: aws.StringValue(subnet.SubnetId), - VPCID: aws.StringValue(subnet.VpcId), + ID: aws.StringValue(subnet.SubnetId), + NetworkACLID: networkACLID, + VPCID: aws.StringValue(subnet.VpcId), } } + +func (provider *ResourceProvider) networkACLIDFromSubnetID(id string) (string, error) { + input := &ec2.DescribeNetworkAclsInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("association.subnet-id"), + Values: []*string{ + aws.String(id), + }, + }, + }, + } + result, err := provider.ec2.DescribeNetworkAcls(input) + if err != nil { + return "", err + } + + if err = ensureSingleResult(len(result.NetworkAcls), "network ACL (via subnet)", id); err != nil { + return "", err + } + + return aws.StringValue(result.NetworkAcls[0].NetworkAclId), nil +} diff --git a/reach/aws/api/vpc.go b/reach/aws/api/vpc.go index 26837c8..5feece7 100644 --- a/reach/aws/api/vpc.go +++ b/reach/aws/api/vpc.go @@ -9,8 +9,8 @@ import ( reachAWS "github.com/luhring/reach/reach/aws" ) -// GetVPC queries the AWS API for a VPC matching the given ID. -func (provider *ResourceProvider) GetVPC(id string) (*reachAWS.VPC, error) { +// VPC queries the AWS API for a VPC matching the given ID. +func (provider *ResourceProvider) VPC(id string) (*reachAWS.VPC, error) { input := &ec2.DescribeVpcsInput{ VpcIds: []*string{ aws.String(id), @@ -30,8 +30,8 @@ func (provider *ResourceProvider) GetVPC(id string) (*reachAWS.VPC, error) { } func newVPCFromAPI(vpc *ec2.Vpc) reachAWS.VPC { - ipv4CIDRs := getCIDRs(vpc.CidrBlockAssociationSet) - ipv6CIDRs := getIPv6CIDRs(vpc.Ipv6CidrBlockAssociationSet) + ipv4CIDRs := cidrs(vpc.CidrBlockAssociationSet) + ipv6CIDRs := ipv6CIDRs(vpc.Ipv6CidrBlockAssociationSet) return reachAWS.VPC{ ID: aws.StringValue(vpc.VpcId), @@ -40,17 +40,17 @@ func newVPCFromAPI(vpc *ec2.Vpc) reachAWS.VPC { } } -func getCIDRs(associationSet []*ec2.VpcCidrBlockAssociation) []net.IPNet { +func cidrs(associationSet []*ec2.VpcCidrBlockAssociation) []net.IPNet { cidrs := make([]net.IPNet, len(associationSet)) for i, association := range associationSet { - cidrs[i] = getCIDR(association) + cidrs[i] = cidr(association) } return cidrs } -func getCIDR(association *ec2.VpcCidrBlockAssociation) net.IPNet { +func cidr(association *ec2.VpcCidrBlockAssociation) net.IPNet { if association == nil { return net.IPNet{} } @@ -63,17 +63,17 @@ func getCIDR(association *ec2.VpcCidrBlockAssociation) net.IPNet { return *cidr } -func getIPv6CIDRs(associationSet []*ec2.VpcIpv6CidrBlockAssociation) []net.IPNet { +func ipv6CIDRs(associationSet []*ec2.VpcIpv6CidrBlockAssociation) []net.IPNet { cidrs := make([]net.IPNet, len(associationSet)) for i, association := range associationSet { - cidrs[i] = getIPv6CIDR(association) + cidrs[i] = ipv6CIDR(association) } return cidrs } -func getIPv6CIDR(association *ec2.VpcIpv6CidrBlockAssociation) net.IPNet { +func ipv6CIDR(association *ec2.VpcIpv6CidrBlockAssociation) net.IPNet { if association == nil { return net.IPNet{} } diff --git a/reach/aws/ec2_instance.go b/reach/aws/ec2_instance.go index e5635c6..3091728 100644 --- a/reach/aws/ec2_instance.go +++ b/reach/aws/ec2_instance.go @@ -54,7 +54,7 @@ func (i EC2Instance) Dependencies(provider ResourceProvider) (*reach.ResourceCol rc := reach.NewResourceCollection() for _, attachment := range i.NetworkInterfaceAttachments { - attachmentDependencies, err := dependenciesForNetworkInterfaceAttachment(attachment, provider) + attachmentDependencies, err := attachment.Dependencies(provider) if err != nil { return nil, err } @@ -64,28 +64,6 @@ func (i EC2Instance) Dependencies(provider ResourceProvider) (*reach.ResourceCol return rc, nil } -func dependenciesForNetworkInterfaceAttachment(attachment NetworkInterfaceAttachment, provider ResourceProvider) (*reach.ResourceCollection, error) { - rc := reach.NewResourceCollection() - - eni, err := provider.GetElasticNetworkInterface(attachment.ElasticNetworkInterfaceID) - if err != nil { - return nil, err - } - rc.Put(reach.ResourceReference{ - Domain: ResourceDomainAWS, - Kind: ResourceKindElasticNetworkInterface, - ID: eni.ID, - }, eni.ToResource()) - - eniDependencies, err := eni.Dependencies(provider) - if err != nil { - return nil, err - } - rc.Merge(eniDependencies) - - return rc, nil -} - func (i EC2Instance) networkPoints(rc *reach.ResourceCollection) []reach.NetworkPoint { var points []reach.NetworkPoint diff --git a/reach/aws/elastic_network_interface.go b/reach/aws/elastic_network_interface.go index 3307c3a..bc594ce 100644 --- a/reach/aws/elastic_network_interface.go +++ b/reach/aws/elastic_network_interface.go @@ -56,7 +56,7 @@ func (eni ElasticNetworkInterface) ToResourceReference() reach.ResourceReference func (eni ElasticNetworkInterface) Dependencies(provider ResourceProvider) (*reach.ResourceCollection, error) { rc := reach.NewResourceCollection() - subnet, err := provider.GetSubnet(eni.SubnetID) + subnet, err := provider.Subnet(eni.SubnetID) if err != nil { return nil, err } @@ -66,7 +66,13 @@ func (eni ElasticNetworkInterface) Dependencies(provider ResourceProvider) (*rea ID: subnet.ID, }, subnet.ToResource()) - vpc, err := provider.GetVPC(eni.VPCID) + subnetDependencies, err := subnet.Dependencies(provider) + if err != nil { + return nil, err + } + rc.Merge(subnetDependencies) + + vpc, err := provider.VPC(eni.VPCID) if err != nil { return nil, err } @@ -77,7 +83,7 @@ func (eni ElasticNetworkInterface) Dependencies(provider ResourceProvider) (*rea }, vpc.ToResource()) for _, sgID := range eni.SecurityGroupIDs { - sg, err := provider.GetSecurityGroup(sgID) + sg, err := provider.SecurityGroup(sgID) if err != nil { return nil, err } diff --git a/reach/aws/explainer.go b/reach/aws/explainer.go index 4d9dec1..48a618c 100644 --- a/reach/aws/explainer.go +++ b/reach/aws/explainer.go @@ -28,12 +28,16 @@ func NewExplainer(analysis reach.Analysis) *Explainer { func (ex *Explainer) NetworkPoint(point reach.NetworkPoint, p reach.Perspective) string { var outputItems []string - if instanceStateFactor, _ := getInstanceStateFactor(point.Factors); instanceStateFactor != nil { - outputItems = append(outputItems, ex.InstanceState(*instanceStateFactor)) + if f, _ := getInstanceStateFactor(point.Factors); f != nil { + outputItems = append(outputItems, ex.InstanceState(*f)) } - if securityGroupRulesFactor, _ := getSecurityGroupRulesFactor(point.Factors); securityGroupRulesFactor != nil { - outputItems = append(outputItems, ex.SecurityGroupRules(*securityGroupRulesFactor, p)) + if f, _ := getSecurityGroupRulesFactor(point.Factors); f != nil { + outputItems = append(outputItems, ex.SecurityGroupRules(*f, p)) + } + + if f, _ := getNetworkACLRulesFactor(point.Factors); f != nil { + outputItems = append(outputItems, ex.NetworkACLRules(*f, p)) } return strings.Join(outputItems, "\n") @@ -73,10 +77,10 @@ func (ex *Explainer) SecurityGroupRules(factor reach.Factor, p reach.Perspective var bodyItems []string - if rules := props.ComponentRules; len(rules) == 0 { + if rules := props.RuleComponents; len(rules) == 0 { bodyItems = append(bodyItems, "no rules that apply to analysis\n") } else { - var ruleViewModels []ruleExplanationViewModel + var ruleViewModels []securityGroupRuleExplanationViewModel for _, rule := range rules { sgRef := ex.analysis.Resources.Get(rule.SecurityGroup) @@ -96,13 +100,13 @@ func (ex *Explainer) SecurityGroupRules(factor reach.Factor, p reach.Perspective case securityGroupRuleMatchBasisSGRef: inclusionReason = fmt.Sprintf( "This rule specifies a security group \"%s\" that is attached to the %s's network interface.", - rule.Match.Value, + rule.Match.Requirement, p.OtherRole, ) case securityGroupRuleMatchBasisIP: inclusionReason = fmt.Sprintf( - "This rule specifies an IP CIDR block \"%s\" that contains the %s's IP address \"%s\".", - originalRule.TargetIPNetworks[0], // TODO: This could show a different network than the matched network, which would be wrong. Include this IPNet in the Match struct to ensure we use the right network here. + "This rule specifies an IP CIDR block \"%s\" that contains the %s's IP address (%s).", + rule.Match.Requirement, p.OtherRole, p.Other.IPAddress, ) @@ -110,13 +114,13 @@ func (ex *Explainer) SecurityGroupRules(factor reach.Factor, p reach.Perspective inclusionReason = fmt.Sprintf("Unknown reason for inclusion. Match basis is '%s'. Please report this.", rule.Match.Basis) } - ruleViewModel := ruleExplanationViewModel{ + model := securityGroupRuleExplanationViewModel{ securityGroupName: sg.Name(), inclusionReason: inclusionReason, allowedTraffic: originalRule.TrafficContent.String(), } - ruleViewModels = append(ruleViewModels, ruleViewModel) + ruleViewModels = append(ruleViewModels, model) } sort.Slice(ruleViewModels, func(i, j int) bool { @@ -167,6 +171,58 @@ func (ex *Explainer) SecurityGroupRules(factor reach.Factor, p reach.Perspective return strings.Join(outputItems, "\n") } +// NetworkACLRules explains the analysis component for the specified network ACL rules factor. +func (ex *Explainer) NetworkACLRules(factor reach.Factor, p reach.Perspective) string { + var outputItems []string + header := fmt.Sprintf( + "%s (including only rules from %s that match %s):", + helper.Bold("network ACL rules"), + p.SelfRole, + p.OtherRole, + ) + outputItems = append(outputItems, header) + + props := factor.Properties.(networkACLRulesFactor) + + var bodyItems []string + + if rules := props.RuleComponentsForwardDirection; len(rules) == 0 { + bodyItems = append(bodyItems, "no rules that apply to analysis\n") + } else { + // forward direction + forwardViewModels := networkACLRuleComponentsToViewModels(props.RuleComponentsForwardDirection, p) + + var forwardExplanation string + for _, model := range forwardViewModels { + forwardExplanation += model.String() + } + bodyItems = append(bodyItems, forwardExplanation) + + // return direction + returnHeader := "rules that affect network traffic returning from destination to source:\n" + bodyItems = append(bodyItems, returnHeader) + + returnViewModels := networkACLRuleComponentsToViewModels(props.RuleComponentsReturnDirection, p) + + var returnExplanation string + for _, model := range returnViewModels { + returnExplanation += model.String() + } + bodyItems = append(bodyItems, returnExplanation) + } + + bodyItems = append(bodyItems, "network traffic allowed based on network ACL rules:") + bodyItems = append(bodyItems, helper.Indent(factor.Traffic.ColorString(), 2)) + + bodyItems = append(bodyItems, "return network traffic allowed based on network ACL rules:") + bodyItems = append(bodyItems, helper.Indent(factor.ReturnTraffic.String(), 2)) + + body := strings.Join(bodyItems, "\n") + outputItems = append(outputItems, helper.Indent(body, 2)) + + return strings.Join(outputItems, "\n") +} + // CheckBothInAWS returns a boolean indicating whether both network points in a network vector are AWS resources. func (ex Explainer) CheckBothInAWS(v reach.NetworkVector) bool { return IsUsedByNetworkPoint(v.Source) && IsUsedByNetworkPoint(v.Destination) @@ -191,23 +247,3 @@ func (ex Explainer) CheckBothInSameSubnet(v reach.NetworkVector) bool { return sourceENI.SubnetID == destinationENI.SubnetID } - -type ruleExplanationViewModel struct { - securityGroupName string - allowedTraffic string - inclusionReason string -} - -func (vm ruleExplanationViewModel) String() string { - output := "- rule\n" - - allowedTrafficHeader := "network traffic allowed:" - allowedTrafficSection := fmt.Sprintf("%s\n%s", allowedTrafficHeader, helper.Indent(vm.allowedTraffic, 2)) - output += helper.Indent(allowedTrafficSection, 4) - - inclusionReasonHeader := "reason for inclusion:" - inclusionReasonSection := fmt.Sprintf("%s\n%s\n", inclusionReasonHeader, helper.Indent(vm.inclusionReason, 2)) - output += helper.Indent(inclusionReasonSection, 4) - - return output -} diff --git a/reach/aws/factors.go b/reach/aws/factors.go index 5784583..847deb6 100644 --- a/reach/aws/factors.go +++ b/reach/aws/factors.go @@ -25,3 +25,13 @@ func getSecurityGroupRulesFactor(factors []reach.Factor) (*reach.Factor, error) return nil, errors.New("no security group rules factor found") } + +func getNetworkACLRulesFactor(factors []reach.Factor) (*reach.Factor, error) { + for _, factor := range factors { + if factor.Kind == FactorKindNetworkACLRules { + return &factor, nil + } + } + + return nil, errors.New("no network ACL rules factor found") +} diff --git a/reach/aws/find_ec2_instance_id.go b/reach/aws/find_ec2_instance_id.go index 3083560..b0a07b8 100644 --- a/reach/aws/find_ec2_instance_id.go +++ b/reach/aws/find_ec2_instance_id.go @@ -7,7 +7,7 @@ import ( // FindEC2InstanceID looks up the instance ID for an EC2 instance using a given resource provider (e.g. an AWS API client) based on the specified search text. The search text can match the entire value or beginning substring for an instance's ID or name tag value, as long as the text matches exactly one EC2 instance. func FindEC2InstanceID(searchText string, provider ResourceProvider) (string, error) { - instances, err := provider.GetAllEC2Instances() + instances, err := provider.AllEC2Instances() if err != nil { return "", err } diff --git a/reach/aws/instance_state_factor.go b/reach/aws/instance_state_factor.go index 81bcdb9..3068b8b 100644 --- a/reach/aws/instance_state_factor.go +++ b/reach/aws/instance_state_factor.go @@ -8,17 +8,21 @@ import ( const FactorKindInstanceState = "InstanceState" func (i EC2Instance) newInstanceStateFactor() reach.Factor { - var tc reach.TrafficContent + var traffic reach.TrafficContent + var returnTraffic reach.TrafficContent if i.isRunning() { - tc = reach.NewTrafficContentForAllTraffic() + traffic = reach.NewTrafficContentForAllTraffic() + returnTraffic = reach.NewTrafficContentForAllTraffic() } else { - tc = reach.NewTrafficContentForNoTraffic() + traffic = reach.NewTrafficContentForNoTraffic() + returnTraffic = reach.NewTrafficContentForNoTraffic() } return reach.Factor{ - Kind: FactorKindInstanceState, - Resource: i.ToResourceReference(), - Traffic: tc, + Kind: FactorKindInstanceState, + Resource: i.ToResourceReference(), + Traffic: traffic, + ReturnTraffic: returnTraffic, } } diff --git a/reach/aws/network_acl.go b/reach/aws/network_acl.go index 4f06222..2eca8e7 100644 --- a/reach/aws/network_acl.go +++ b/reach/aws/network_acl.go @@ -19,3 +19,12 @@ func (nacl NetworkACL) ToResource() reach.Resource { Properties: nacl, } } + +// ToResourceReference returns a resource reference to uniquely identify the network ACL. +func (nacl NetworkACL) ToResourceReference() reach.ResourceReference { + return reach.ResourceReference{ + Domain: ResourceDomainAWS, + Kind: ResourceKindNetworkACL, + ID: nacl.ID, + } +} diff --git a/reach/aws/network_acl_rule.go b/reach/aws/network_acl_rule.go index b871c03..22d1c11 100644 --- a/reach/aws/network_acl_rule.go +++ b/reach/aws/network_acl_rule.go @@ -1,6 +1,7 @@ package aws import ( + "encoding/json" "net" "github.com/luhring/reach/reach" @@ -15,6 +16,23 @@ const ( NetworkACLRuleActionAllow ) +// String returns the string representation of the NetworkACLRuleAction. +func (action NetworkACLRuleAction) String() string { + switch action { + case NetworkACLRuleActionDeny: + return "deny" + case NetworkACLRuleActionAllow: + return "allow" + default: + return "[unknown action]" + } +} + +// MarshalJSON returns the JSON representation of the NetworkACLRuleAction. +func (action NetworkACLRuleAction) MarshalJSON() ([]byte, error) { + return json.Marshal(action.String()) +} + // An NetworkACLRule resource representation. type NetworkACLRule struct { Number int64 @@ -22,3 +40,24 @@ type NetworkACLRule struct { TargetIPNetwork *net.IPNet Action NetworkACLRuleAction } + +// Allows returns a boolean indicating if the rule is allowing traffic. +func (r NetworkACLRule) Allows() bool { + return r.Action == NetworkACLRuleActionAllow +} + +// Denies returns a boolean indicating if the rule is denying traffic. +func (r NetworkACLRule) Denies() bool { + return r.Action == NetworkACLRuleActionDeny +} + +func (r NetworkACLRule) matchByIP(ip net.IP) *networkACLRuleMatch { + if r.TargetIPNetwork.Contains(ip) { + return &networkACLRuleMatch{ + Requirement: *r.TargetIPNetwork, + Value: ip, + } + } + + return nil +} diff --git a/reach/aws/network_acl_rule_direction.go b/reach/aws/network_acl_rule_direction.go new file mode 100644 index 0000000..7f9cd9d --- /dev/null +++ b/reach/aws/network_acl_rule_direction.go @@ -0,0 +1,6 @@ +package aws + +type networkACLRuleDirection string + +const networkACLRuleDirectionInbound networkACLRuleDirection = "inbound" +const networkACLRuleDirectionOutbound networkACLRuleDirection = "outbound" diff --git a/reach/aws/network_acl_rule_explanation_view_model.go b/reach/aws/network_acl_rule_explanation_view_model.go new file mode 100644 index 0000000..d0d1c60 --- /dev/null +++ b/reach/aws/network_acl_rule_explanation_view_model.go @@ -0,0 +1,54 @@ +package aws + +import ( + "fmt" + + "github.com/luhring/reach/reach" + "github.com/luhring/reach/reach/helper" +) + +type networkACLRuleExplanationViewModel struct { + ruleNumber int64 + allowedTraffic string + inclusionReason string +} + +func newNetworkACLRuleExplanationViewModel(rule networkACLRulesFactorComponent, p reach.Perspective) networkACLRuleExplanationViewModel { + inclusionReason := fmt.Sprintf( + "This rule specifies an IP CIDR block \"%s\" that contains the %s's IP address (%s).", + rule.Match.Requirement.String(), + p.OtherRole, + p.Other.IPAddress, + ) + + return networkACLRuleExplanationViewModel{ + ruleNumber: rule.RuleNumber, + allowedTraffic: rule.Traffic.String(), + inclusionReason: inclusionReason, + } +} + +func (model networkACLRuleExplanationViewModel) String() string { + output := fmt.Sprintf("- rule # %d\n", model.ruleNumber) + + allowedTrafficHeader := "network traffic allowed:" + allowedTrafficSection := fmt.Sprintf("%s\n%s", allowedTrafficHeader, helper.Indent(model.allowedTraffic, 2)) + output += helper.Indent(allowedTrafficSection, 4) + + inclusionReasonHeader := "reason for inclusion:" + inclusionReasonSection := fmt.Sprintf("%s\n%s\n", inclusionReasonHeader, helper.Indent(model.inclusionReason, 2)) + output += helper.Indent(inclusionReasonSection, 4) + + return output +} + +func networkACLRuleComponentsToViewModels(rules []networkACLRulesFactorComponent, p reach.Perspective) []networkACLRuleExplanationViewModel { + var models []networkACLRuleExplanationViewModel + + for _, rule := range rules { + model := newNetworkACLRuleExplanationViewModel(rule, p) + models = append(models, model) + } + + return models +} diff --git a/reach/aws/network_acl_rule_match.go b/reach/aws/network_acl_rule_match.go new file mode 100644 index 0000000..308db1a --- /dev/null +++ b/reach/aws/network_acl_rule_match.go @@ -0,0 +1,8 @@ +package aws + +import "net" + +type networkACLRuleMatch struct { + Requirement net.IPNet + Value net.IP +} diff --git a/reach/aws/network_acl_rules_factor.go b/reach/aws/network_acl_rules_factor.go new file mode 100644 index 0000000..03e825b --- /dev/null +++ b/reach/aws/network_acl_rules_factor.go @@ -0,0 +1,144 @@ +package aws + +import ( + "fmt" + "sort" + + "github.com/luhring/reach/reach" +) + +// FactorKindNetworkACLRules specifies the unique name for the network ACL rules kind of factor. +const FactorKindNetworkACLRules = "NetworkACLRules" + +const newNetworkACLRulesFactorErrFmt = "unable to compute network ACL rules factor: %v" + +type networkACLRulesFactor struct { + RuleComponentsForwardDirection []networkACLRulesFactorComponent + RuleComponentsReturnDirection []networkACLRulesFactorComponent +} + +func (eni ElasticNetworkInterface) newNetworkACLRulesFactor( + rc *reach.ResourceCollection, + p reach.Perspective, + awsP perspective, + targetENI *ElasticNetworkInterface, +) (*reach.Factor, error) { + subnetResource := rc.Get(reach.ResourceReference{ + Domain: ResourceDomainAWS, + Kind: ResourceKindSubnet, + ID: eni.SubnetID, + }) + if subnetResource == nil { + return nil, fmt.Errorf("couldn't find subnet: %s", eni.SubnetID) + } + subnet := subnetResource.Properties.(Subnet) + + ref := reach.ResourceReference{ + Domain: ResourceDomainAWS, + Kind: ResourceKindNetworkACL, + ID: subnet.NetworkACLID, + } + + networkACLResource := rc.Get(ref) + if networkACLResource == nil { + return nil, fmt.Errorf("couldn't find network ACL: %s", subnet.NetworkACLID) + } + networkACL := networkACLResource.Properties.(NetworkACL) + + forwardTraffic, forwardComponents, err := networkACL.effectOnForwardTraffic(p, awsP) + if err != nil { + return nil, fmt.Errorf(newNetworkACLRulesFactorErrFmt, err) + } + + returnTraffic, returnComponents, err := networkACL.effectOnReturnTraffic(p, awsP) + if err != nil { + return nil, fmt.Errorf(newNetworkACLRulesFactorErrFmt, err) + } + + props := networkACLRulesFactor{ + RuleComponentsForwardDirection: forwardComponents, + RuleComponentsReturnDirection: returnComponents, + } + + return &reach.Factor{ + Kind: FactorKindNetworkACLRules, + Resource: eni.ToResourceReference(), + Traffic: forwardTraffic, + ReturnTraffic: returnTraffic, + Properties: props, + }, nil +} + +func (nacl NetworkACL) effectOnForwardTraffic(p reach.Perspective, awsP perspective) (reach.TrafficContent, []networkACLRulesFactorComponent, error) { + return nacl.factorComponents(awsP.networkACLRuleDirectionForForwardTraffic, p, awsP) +} + +func (nacl NetworkACL) effectOnReturnTraffic(p reach.Perspective, awsP perspective) (reach.TrafficContent, []networkACLRulesFactorComponent, error) { + return nacl.factorComponents(awsP.networkACLRuleDirectionForReturnTraffic, p, awsP) +} + +func (nacl NetworkACL) rulesForDirection(direction networkACLRuleDirection) []NetworkACLRule { + if direction == networkACLRuleDirectionOutbound { + return nacl.OutboundRules + } + + return nacl.InboundRules +} + +func (nacl NetworkACL) factorComponents(direction networkACLRuleDirection, p reach.Perspective, awsP perspective) (reach.TrafficContent, []networkACLRulesFactorComponent, error) { + rules := nacl.rulesForDirection(direction) + + sort.Slice(rules, func(i, j int) bool { + return rules[i].Number < rules[j].Number + }) + + var trafficContentSegments []reach.TrafficContent + var ruleComponents []networkACLRulesFactorComponent + decidedTraffic := reach.NewTrafficContentForNoTraffic() + + for _, rule := range rules { + // Make sure rule matches + match := rule.matchByIP(p.Other.IPAddress) + if match == nil { + continue // this rule doesn't match + } + + if rule.Allows() { + // Determine what subset of rule traffic affects outcome + effectiveTraffic, err := rule.TrafficContent.Subtract(decidedTraffic) + if err != nil { + return reach.TrafficContent{}, nil, fmt.Errorf(newNetworkACLRulesFactorErrFmt, err) + } + + // add the allowed traffic to the trafficContentSegments + trafficContentSegments = append(trafficContentSegments, effectiveTraffic) + + // add to ruleComponents for the explanation + ruleComponents = append(ruleComponents, networkACLRulesFactorComponent{ + NetworkACL: nacl.ToResourceReference(), + RuleDirection: direction, + RuleNumber: rule.Number, + Match: *match, + Traffic: effectiveTraffic, + }) + } + + var err error + decidedTraffic, err = reach.NewTrafficContentFromMergingMultiple( + []reach.TrafficContent{ + decidedTraffic, + rule.TrafficContent, + }, + ) + if err != nil { + return reach.TrafficContent{}, nil, fmt.Errorf(newNetworkACLRulesFactorErrFmt, err) + } + } + + traffic, err := reach.NewTrafficContentFromMergingMultiple(trafficContentSegments) + if err != nil { + return reach.TrafficContent{}, nil, fmt.Errorf(newNetworkACLRulesFactorErrFmt, err) + } + + return traffic, ruleComponents, nil +} diff --git a/reach/aws/network_acl_rules_factor_component.go b/reach/aws/network_acl_rules_factor_component.go new file mode 100644 index 0000000..38ea534 --- /dev/null +++ b/reach/aws/network_acl_rules_factor_component.go @@ -0,0 +1,11 @@ +package aws + +import "github.com/luhring/reach/reach" + +type networkACLRulesFactorComponent struct { + NetworkACL reach.ResourceReference + RuleDirection networkACLRuleDirection + RuleNumber int64 + Match networkACLRuleMatch + Traffic reach.TrafficContent +} diff --git a/reach/aws/network_interface_attachment.go b/reach/aws/network_interface_attachment.go index ab601ba..e3674b9 100644 --- a/reach/aws/network_interface_attachment.go +++ b/reach/aws/network_interface_attachment.go @@ -1,8 +1,33 @@ package aws +import "github.com/luhring/reach/reach" + // A NetworkInterfaceAttachment resource representation. type NetworkInterfaceAttachment struct { ID string ElasticNetworkInterfaceID string DeviceIndex int64 // e.g. 0 for "eth0" } + +// Dependencies returns a collection of the network interface attachment's resource dependencies. +func (attachment NetworkInterfaceAttachment) Dependencies(provider ResourceProvider) (*reach.ResourceCollection, error) { + rc := reach.NewResourceCollection() + + eni, err := provider.ElasticNetworkInterface(attachment.ElasticNetworkInterfaceID) + if err != nil { + return nil, err + } + rc.Put(reach.ResourceReference{ + Domain: ResourceDomainAWS, + Kind: ResourceKindElasticNetworkInterface, + ID: eni.ID, + }, eni.ToResource()) + + eniDependencies, err := eni.Dependencies(provider) + if err != nil { + return nil, err + } + rc.Merge(eniDependencies) + + return rc, nil +} diff --git a/reach/aws/perspective.go b/reach/aws/perspective.go index bcc7a34..9db83ab 100644 --- a/reach/aws/perspective.go +++ b/reach/aws/perspective.go @@ -1,24 +1,44 @@ package aws type perspective struct { - getSecurityGroupRules func(sg SecurityGroup) []SecurityGroupRule - ruleDirection securityGroupRuleDirection + securityGroupRules func(sg SecurityGroup) []SecurityGroupRule + securityGroupRuleDirection securityGroupRuleDirection + networkACLRulesForForwardTraffic func(nacl NetworkACL) []NetworkACLRule + networkACLRuleDirectionForForwardTraffic networkACLRuleDirection + networkACLRulesForReturnTraffic func(nacl NetworkACL) []NetworkACLRule + networkACLRuleDirectionForReturnTraffic networkACLRuleDirection } func newPerspectiveSourceOriented() perspective { return perspective{ - getSecurityGroupRules: func(sg SecurityGroup) []SecurityGroupRule { + securityGroupRules: func(sg SecurityGroup) []SecurityGroupRule { return sg.OutboundRules }, - ruleDirection: securityGroupRuleDirectionOutbound, + securityGroupRuleDirection: securityGroupRuleDirectionOutbound, + networkACLRulesForForwardTraffic: func(nacl NetworkACL) []NetworkACLRule { + return nacl.OutboundRules + }, + networkACLRuleDirectionForForwardTraffic: networkACLRuleDirectionOutbound, + networkACLRulesForReturnTraffic: func(nacl NetworkACL) []NetworkACLRule { + return nacl.InboundRules + }, + networkACLRuleDirectionForReturnTraffic: networkACLRuleDirectionInbound, } } func newPerspectiveDestinationOriented() perspective { return perspective{ - getSecurityGroupRules: func(sg SecurityGroup) []SecurityGroupRule { + securityGroupRules: func(sg SecurityGroup) []SecurityGroupRule { return sg.InboundRules }, - ruleDirection: securityGroupRuleDirectionInbound, + securityGroupRuleDirection: securityGroupRuleDirectionInbound, + networkACLRulesForForwardTraffic: func(nacl NetworkACL) []NetworkACLRule { + return nacl.InboundRules + }, + networkACLRuleDirectionForForwardTraffic: networkACLRuleDirectionInbound, + networkACLRulesForReturnTraffic: func(nacl NetworkACL) []NetworkACLRule { + return nacl.OutboundRules + }, + networkACLRuleDirectionForReturnTraffic: networkACLRuleDirectionOutbound, } } diff --git a/reach/aws/resource_provider.go b/reach/aws/resource_provider.go index 58300fa..2e00df1 100644 --- a/reach/aws/resource_provider.go +++ b/reach/aws/resource_provider.go @@ -2,13 +2,13 @@ package aws // The ResourceProvider interface wraps all of the necessary methods for accessing AWS-specific resources. type ResourceProvider interface { - GetAllEC2Instances() ([]EC2Instance, error) - GetEC2Instance(id string) (*EC2Instance, error) - GetElasticNetworkInterface(id string) (*ElasticNetworkInterface, error) - GetNetworkACL(id string) (*NetworkACL, error) - GetRouteTable(id string) (*RouteTable, error) - GetSecurityGroup(id string) (*SecurityGroup, error) - GetSecurityGroupReference(id, accountID string) (*SecurityGroupReference, error) - GetSubnet(id string) (*Subnet, error) - GetVPC(id string) (*VPC, error) + AllEC2Instances() ([]EC2Instance, error) + EC2Instance(id string) (*EC2Instance, error) + ElasticNetworkInterface(id string) (*ElasticNetworkInterface, error) + NetworkACL(id string) (*NetworkACL, error) + RouteTable(id string) (*RouteTable, error) + SecurityGroup(id string) (*SecurityGroup, error) + SecurityGroupReference(id, accountID string) (*SecurityGroupReference, error) + Subnet(id string) (*Subnet, error) + VPC(id string) (*VPC, error) } diff --git a/reach/aws/route_table.go b/reach/aws/route_table.go index ba2af65..85464d0 100644 --- a/reach/aws/route_table.go +++ b/reach/aws/route_table.go @@ -24,7 +24,7 @@ func (rt RouteTable) ToResource() reach.Resource { func (rt RouteTable) Dependencies(provider ResourceProvider) (*reach.ResourceCollection, error) { rc := reach.NewResourceCollection() - vpc, err := provider.GetVPC(rt.VPCID) + vpc, err := provider.VPC(rt.VPCID) if err != nil { return nil, err } diff --git a/reach/aws/security_group.go b/reach/aws/security_group.go index 9e5cf6a..61d0df3 100644 --- a/reach/aws/security_group.go +++ b/reach/aws/security_group.go @@ -31,7 +31,7 @@ func (sg SecurityGroup) ToResource() reach.Resource { func (sg SecurityGroup) Dependencies(provider ResourceProvider) (*reach.ResourceCollection, error) { rc := reach.NewResourceCollection() - vpc, err := provider.GetVPC(sg.VPCID) + vpc, err := provider.VPC(sg.VPCID) if err != nil { return nil, err } @@ -47,7 +47,7 @@ func (sg SecurityGroup) Dependencies(provider ResourceProvider) (*reach.Resource // TODO: sg ref IDs shouldn't be strings, they should be pointers, and this check should be for nil not "" if sgRefID := rule.TargetSecurityGroupReferenceID; sgRefID != "" { - sgRef, err := provider.GetSecurityGroupReference(sgRefID, rule.TargetSecurityGroupReferenceAccountID) + sgRef, err := provider.SecurityGroupReference(sgRefID, rule.TargetSecurityGroupReferenceAccountID) if err != nil { return nil, err } diff --git a/reach/aws/security_group_rule.go b/reach/aws/security_group_rule.go index d02dd60..0fbc666 100644 --- a/reach/aws/security_group_rule.go +++ b/reach/aws/security_group_rule.go @@ -18,8 +18,9 @@ func (rule SecurityGroupRule) matchByIP(ip net.IP) *securityGroupRuleMatch { for _, network := range rule.TargetIPNetworks { if network.Contains(ip) { return &securityGroupRuleMatch{ - Basis: securityGroupRuleMatchBasisIP, - Value: ip, + Basis: securityGroupRuleMatchBasisIP, + Requirement: network, + Value: ip, } } } @@ -32,8 +33,9 @@ func (rule SecurityGroupRule) matchBySecurityGroup(eni *ElasticNetworkInterface) for _, targetENISecurityGroupID := range eni.SecurityGroupIDs { if rule.TargetSecurityGroupReferenceID == targetENISecurityGroupID { // TODO: Handle SG Account ID return &securityGroupRuleMatch{ - Basis: securityGroupRuleMatchBasisSGRef, - Value: targetENISecurityGroupID, + Basis: securityGroupRuleMatchBasisSGRef, + Requirement: rule.TargetSecurityGroupReferenceID, + Value: targetENISecurityGroupID, } } } diff --git a/reach/aws/security_group_rule_explanation_view_model.go b/reach/aws/security_group_rule_explanation_view_model.go new file mode 100644 index 0000000..d14df3e --- /dev/null +++ b/reach/aws/security_group_rule_explanation_view_model.go @@ -0,0 +1,27 @@ +package aws + +import ( + "fmt" + + "github.com/luhring/reach/reach/helper" +) + +type securityGroupRuleExplanationViewModel struct { + securityGroupName string + allowedTraffic string + inclusionReason string +} + +func (model securityGroupRuleExplanationViewModel) String() string { + output := "- rule\n" + + allowedTrafficHeader := "network traffic allowed:" + allowedTrafficSection := fmt.Sprintf("%s\n%s", allowedTrafficHeader, helper.Indent(model.allowedTraffic, 2)) + output += helper.Indent(allowedTrafficSection, 4) + + inclusionReasonHeader := "reason for inclusion:" + inclusionReasonSection := fmt.Sprintf("%s\n%s\n", inclusionReasonHeader, helper.Indent(model.inclusionReason, 2)) + output += helper.Indent(inclusionReasonSection, 4) + + return output +} diff --git a/reach/aws/security_group_rule_match.go b/reach/aws/security_group_rule_match.go new file mode 100644 index 0000000..85d3fdb --- /dev/null +++ b/reach/aws/security_group_rule_match.go @@ -0,0 +1,7 @@ +package aws + +type securityGroupRuleMatch struct { + Basis securityGroupRuleMatchBasis + Requirement interface{} + Value interface{} +} diff --git a/reach/aws/security_group_rule_match_basis.go b/reach/aws/security_group_rule_match_basis.go new file mode 100644 index 0000000..f3be8d9 --- /dev/null +++ b/reach/aws/security_group_rule_match_basis.go @@ -0,0 +1,18 @@ +package aws + +type securityGroupRuleMatchBasis string + +const securityGroupRuleMatchBasisIP securityGroupRuleMatchBasis = "IP" +const securityGroupRuleMatchBasisSGRef securityGroupRuleMatchBasis = "SecurityGroupReference" + +// String returns the string representation of a security group rule match. +func (basis securityGroupRuleMatchBasis) String() string { + switch basis { + case securityGroupRuleMatchBasisIP: + return "IP address" + case securityGroupRuleMatchBasisSGRef: + return "attached security group" + default: + return "[unknown match basis]" + } +} diff --git a/reach/aws/security_group_rules_factor.go b/reach/aws/security_group_rules_factor.go index b61511a..c2a085b 100644 --- a/reach/aws/security_group_rules_factor.go +++ b/reach/aws/security_group_rules_factor.go @@ -7,37 +7,8 @@ import ( // FactorKindSecurityGroupRules specifies the unique name for the security group rules kind of factor. const FactorKindSecurityGroupRules = "SecurityGroupRules" -type securityGroupRuleMatchBasis string - -const securityGroupRuleMatchBasisIP securityGroupRuleMatchBasis = "IP" -const securityGroupRuleMatchBasisSGRef securityGroupRuleMatchBasis = "SecurityGroupReference" - type securityGroupRulesFactor struct { - ComponentRules []securityGroupRulesFactorComponent -} - -type securityGroupRulesFactorComponent struct { - SecurityGroup reach.ResourceReference - RuleDirection securityGroupRuleDirection - RuleIndex int - Match securityGroupRuleMatch -} - -type securityGroupRuleMatch struct { - Basis securityGroupRuleMatchBasis - Value interface{} -} - -// String returns the string representation of a security group rule match. -func (basis securityGroupRuleMatchBasis) String() string { - switch basis { - case securityGroupRuleMatchBasisIP: - return "IP address" - case securityGroupRuleMatchBasisSGRef: - return "attached security group" - default: - return "[unknown match basis]" - } + RuleComponents []securityGroupRulesFactorComponent } func (eni ElasticNetworkInterface) newSecurityGroupRulesFactor( @@ -46,7 +17,7 @@ func (eni ElasticNetworkInterface) newSecurityGroupRulesFactor( awsP perspective, targetENI *ElasticNetworkInterface, ) (*reach.Factor, error) { - var componentRules []securityGroupRulesFactorComponent + var ruleComponents []securityGroupRulesFactorComponent var trafficContentSegments []reach.TrafficContent for _, id := range eni.SecurityGroupIDs { @@ -58,7 +29,7 @@ func (eni ElasticNetworkInterface) newSecurityGroupRulesFactor( sg := rc.Get(ref).Properties.(SecurityGroup) - for ruleIndex, rule := range awsP.getSecurityGroupRules(sg) { + for ruleIndex, rule := range awsP.securityGroupRules(sg) { var match *securityGroupRuleMatch // check ip match @@ -72,13 +43,14 @@ func (eni ElasticNetworkInterface) newSecurityGroupRulesFactor( if match != nil { component := securityGroupRulesFactorComponent{ SecurityGroup: ref, - RuleDirection: awsP.ruleDirection, + RuleDirection: awsP.securityGroupRuleDirection, RuleIndex: ruleIndex, Match: *match, + Traffic: rule.TrafficContent, } trafficContentSegments = append(trafficContentSegments, rule.TrafficContent) - componentRules = append(componentRules, component) + ruleComponents = append(ruleComponents, component) } } } @@ -89,13 +61,14 @@ func (eni ElasticNetworkInterface) newSecurityGroupRulesFactor( } props := securityGroupRulesFactor{ - ComponentRules: componentRules, + RuleComponents: ruleComponents, } return &reach.Factor{ - Kind: FactorKindSecurityGroupRules, - Resource: eni.ToResourceReference(), - Traffic: tc, - Properties: props, + Kind: FactorKindSecurityGroupRules, + Resource: eni.ToResourceReference(), + Traffic: tc, + ReturnTraffic: reach.NewTrafficContentForAllTraffic(), + Properties: props, }, nil } diff --git a/reach/aws/security_group_rules_factor_component.go b/reach/aws/security_group_rules_factor_component.go new file mode 100644 index 0000000..11baf30 --- /dev/null +++ b/reach/aws/security_group_rules_factor_component.go @@ -0,0 +1,11 @@ +package aws + +import "github.com/luhring/reach/reach" + +type securityGroupRulesFactorComponent struct { + SecurityGroup reach.ResourceReference + RuleDirection securityGroupRuleDirection + RuleIndex int + Match securityGroupRuleMatch + Traffic reach.TrafficContent +} diff --git a/reach/aws/subnet.go b/reach/aws/subnet.go index ae37e3e..952d624 100644 --- a/reach/aws/subnet.go +++ b/reach/aws/subnet.go @@ -7,8 +7,9 @@ const ResourceKindSubnet = "Subnet" // A Subnet resource representation. type Subnet struct { - ID string - VPCID string + ID string + NetworkACLID string + VPCID string } // ToResource returns the subnet converted to a generalized Reach resource. @@ -23,7 +24,17 @@ func (s Subnet) ToResource() reach.Resource { func (s Subnet) Dependencies(provider ResourceProvider) (*reach.ResourceCollection, error) { rc := reach.NewResourceCollection() - vpc, err := provider.GetVPC(s.VPCID) + networkACL, err := provider.NetworkACL(s.NetworkACLID) + if err != nil { + return nil, err + } + rc.Put(reach.ResourceReference{ + Domain: ResourceDomainAWS, + Kind: ResourceKindNetworkACL, + ID: s.NetworkACLID, + }, networkACL.ToResource()) + + vpc, err := provider.VPC(s.VPCID) if err != nil { return nil, err } diff --git a/reach/aws/vector_analyzer.go b/reach/aws/vector_analyzer.go index fd98f5f..ebf78ad 100644 --- a/reach/aws/vector_analyzer.go +++ b/reach/aws/vector_analyzer.go @@ -30,6 +30,7 @@ func (analyzer VectorAnalyzer) factorsForPerspective(p reach.Perspective) ([]rea } if resourceRef.Kind == ResourceKindElasticNetworkInterface { + // Get ready to evaluate factors eni := analyzer.resourceCollection.Get(resourceRef).Properties.(ElasticNetworkInterface) targetENI := ElasticNetworkInterfaceFromNetworkPoint(p.Other, analyzer.resourceCollection) @@ -40,6 +41,12 @@ func (analyzer VectorAnalyzer) factorsForPerspective(p reach.Perspective) ([]rea awsP = newPerspectiveDestinationOriented() } + // Ensure this is scenario that Reach can analyze + if !sameVPC(&eni, targetENI) { + return nil, fmt.Errorf("error: reach is not yet able to analyze EC2 instances in different VPCs, but that's coming soon! (VPCs: %s, %s)", eni.VPCID, targetENI.VPCID) + } + + // Evaluate factors securityGroupRulesFactor, err := eni.newSecurityGroupRulesFactor( analyzer.resourceCollection, p, @@ -52,13 +59,24 @@ func (analyzer VectorAnalyzer) factorsForPerspective(p reach.Perspective) ([]rea factors = append(factors, *securityGroupRulesFactor) - if !sameVPC(&eni, targetENI) { - return nil, fmt.Errorf("error: reach is not yet able to analyze EC2 instances in different VPCs, but that's coming soon! (VPCs: %s, %s)", eni.VPCID, targetENI.VPCID) + if sameSubnet(&eni, targetENI) { + // There's nothing further to evaluate for this ENI + continue } - if !sameSubnet(&eni, targetENI) { - return nil, fmt.Errorf("error: reach is not yet able to analyze EC2 instances in different subnets, but that's coming soon! (subnets: %s, %s)", eni.SubnetID, targetENI.SubnetID) + // Different subnets, same VPC + + networkACLRulesFactor, err := eni.newNetworkACLRulesFactor( + analyzer.resourceCollection, + p, + awsP, + targetENI, + ) + if err != nil { + return nil, err } + + factors = append(factors, *networkACLRulesFactor) } } } diff --git a/reach/explainer/explainer.go b/reach/explainer/explainer.go index 4a5c71a..d7b10eb 100644 --- a/reach/explainer/explainer.go +++ b/reach/explainer/explainer.go @@ -5,6 +5,8 @@ import ( "log" "strings" + "github.com/mgutz/ansi" + "github.com/luhring/reach/reach" "github.com/luhring/reach/reach/aws" "github.com/luhring/reach/reach/helper" @@ -60,8 +62,11 @@ func (ex *Explainer) ExplainNetworkVector(v reach.NetworkVector) string { outputSections = append(outputSections, helper.Indent(destinationContent, 2)) // final results - results := fmt.Sprintf("%s\n%s", helper.Bold("network traffic allowed from source to destination:"), v.Traffic.ColorStringWithSymbols()) - outputSections = append(outputSections, results) + forwardResults := fmt.Sprintf("%s\n%s", helper.Bold("network traffic allowed from source to destination:"), v.Traffic.ColorStringWithSymbols()) + outputSections = append(outputSections, forwardResults) + + returnResults := fmt.Sprintf("%s\n%s", helper.Bold("network traffic allowed to return from destination to source:"), v.ReturnTraffic.StringWithSymbols()) + outputSections = append(outputSections, returnResults) return strings.Join(outputSections, "\n") } @@ -127,3 +132,40 @@ func (ex *Explainer) NetworkPointName(point reach.NetworkPoint) string { return output } + +// WarningsFromRestrictedReturnPath returns a slice of warning strings based on the input slice of restricted protocols. +func WarningsFromRestrictedReturnPath(restrictedProtocols []reach.RestrictedProtocol) (bool, string) { + if len(restrictedProtocols) == 0 { + return false, "" + } + + var warnings []string + + for _, rp := range restrictedProtocols { + var warning string + + if rp.Protocol == reach.ProtocolTCP { // We have a specific message based on the knowledge that the protocol is TCP. + if rp.NoReturnTraffic { + warning = ansi.Color("All TCP connection attempts will be unsuccessful. No TCP traffic is allowed to return to the source.", "red+b") + } else { + warning = ansi.Color("TCP connection attempts might be unsuccessful. TCP traffic is allowed to return to the source only at particular source ports.", "yellow+b") + } + } else { + firstSentence := fmt.Sprintf("%s-based communication might be unsuccessful.", rp.Protocol) + + var secondSentence string + + if rp.NoReturnTraffic { + secondSentence = fmt.Sprintf("No %s traffic is able to return to the source.", rp.Protocol) + } else { + secondSentence = fmt.Sprintf("Some %s traffic is unable to return to the source.", rp.Protocol) + } + + warning = ansi.Color(fmt.Sprintf("%s %s", firstSentence, secondSentence), "yellow+b") + } + + warnings = append(warnings, warning) + } + + return true, "warnings from return traffic obstructions:\n" + strings.Join(warnings, "\n") +} diff --git a/reach/factor.go b/reach/factor.go index bd4d808..a59b486 100644 --- a/reach/factor.go +++ b/reach/factor.go @@ -2,8 +2,9 @@ package reach // A Factor describes how a particular component of the ingested resources has an impact on the network traffic allowed to flow from a source to a destination. type Factor struct { - Kind string - Resource ResourceReference - Traffic TrafficContent - Properties interface{} `json:"Properties,omitempty"` + Kind string + Resource ResourceReference + Traffic TrafficContent + ReturnTraffic TrafficContent + Properties interface{} `json:"Properties,omitempty"` } diff --git a/reach/network_vector.go b/reach/network_vector.go index fc5f4f1..2d68d35 100644 --- a/reach/network_vector.go +++ b/reach/network_vector.go @@ -8,10 +8,11 @@ import ( // A NetworkVector represents the path between two network points that's able to be analyzed in terms of what kind of network traffic is allowed to flow from point to point. type NetworkVector struct { - ID string - Source NetworkPoint - Destination NetworkPoint - Traffic *TrafficContent + ID string + Source NetworkPoint + Destination NetworkPoint + Traffic *TrafficContent + ReturnTraffic *TrafficContent } // NewNetworkVector creates a new network vector given a source and a destination network point. @@ -38,6 +39,9 @@ func (v NetworkVector) String() string { output += "\n" output += v.Traffic.String() output += "\n" + output += "network traffic allowed to return from destination to source:\n" + output += "\n" + output += v.ReturnTraffic.String() } return output diff --git a/reach/protocol.go b/reach/protocol.go index f151cd8..fd40e8a 100644 --- a/reach/protocol.go +++ b/reach/protocol.go @@ -40,6 +40,11 @@ func (p Protocol) IsCustomProtocol() bool { return p != ProtocolICMPv4 && p != ProtocolTCP && p != ProtocolUDP && p != ProtocolICMPv6 } +// String returns the common name of the IP protocol. +func (p Protocol) String() string { + return ProtocolName(p) +} + // ProtocolName returns the name of an IP protocol given the protocol's assigned number. func ProtocolName(protocol Protocol) string { switch protocol { diff --git a/reach/protocol_content.go b/reach/protocol_content.go index 38e2251..29e23e0 100644 --- a/reach/protocol_content.go +++ b/reach/protocol_content.go @@ -81,6 +81,16 @@ func (pc ProtocolContent) empty() bool { } } +func (pc ProtocolContent) complete() bool { + if pc.isTCPOrUDP() { + return pc.Ports.Complete() + } else if pc.isICMPv4OrICMPv6() { + return pc.ICMP.Complete() + } else { + return *pc.CustomProtocolHasContent + } +} + // String returns the string representation of the protocol content. func (pc ProtocolContent) String() string { protocolName := ProtocolName(pc.Protocol) diff --git a/reach/set/icmp_set.go b/reach/set/icmp_set.go index a2fb0df..94dc236 100644 --- a/reach/set/icmp_set.go +++ b/reach/set/icmp_set.go @@ -245,9 +245,9 @@ func encodeICMPTypeCode(icmpType, icmpCode uint) uint16 { func decodeICMPTypeCode(value uint16) ICMPTypeCode { const bitSize = 8 - var icmpType uint8 = uint8((value & 0b1111111100000000) >> bitSize) + var icmpType = uint8((value & 0b1111111100000000) >> bitSize) - var icmpCode uint8 = uint8((value & 0b0000000011111111)) + var icmpCode = uint8(value & 0b0000000011111111) return ICMPTypeCode{icmpType, icmpCode} } diff --git a/reach/traffic_content.go b/reach/traffic_content.go index 29439c8..b2de244 100644 --- a/reach/traffic_content.go +++ b/reach/traffic_content.go @@ -2,6 +2,7 @@ package reach import ( "encoding/json" + "fmt" "sort" "strings" @@ -139,6 +140,17 @@ func TrafficContentsFromFactors(factors []Factor) []TrafficContent { return result } +// ReturnTrafficContentsFromFactors returns distinct TrafficContent representations from the input factors's return traffic. +func ReturnTrafficContentsFromFactors(factors []Factor) []TrafficContent { + var result []TrafficContent + + for _, factor := range factors { + result = append(result, factor.ReturnTraffic) + } + + return result +} + // Merge performs a set merge operation on two TrafficContents. func (tc *TrafficContent) Merge(other TrafficContent) (TrafficContent, error) { if tc.All() || other.All() { @@ -216,6 +228,30 @@ func (tc *TrafficContent) Intersect(other TrafficContent) (TrafficContent, error return result, nil } +// Subtract performs a set subtraction (self - other) on two TrafficContents. +func (tc *TrafficContent) Subtract(other TrafficContent) (TrafficContent, error) { + if tc.None() || other.All() { + return NewTrafficContentForNoTraffic(), nil + } + + if other.None() { + return *tc, nil + } + + result := newTrafficContent() + + for p, pc := range tc.protocols { + pcDifference, err := pc.subtract(other.protocol(p)) + if err != nil { + return TrafficContent{}, fmt.Errorf("unable to subtract traffic content: %v", err) + } + + result.setProtocolContent(p, pcDifference) + } + + return result, nil +} + // MarshalJSON returns the JSON representation of the TrafficContent. func (tc TrafficContent) MarshalJSON() ([]byte, error) { if tc.None() { @@ -392,6 +428,67 @@ func (tc TrafficContent) None() bool { return tc.indicator == trafficContentIndicatorNone || (tc.indicator == trafficContentIndicatorUnset && len(tc.protocols) == 0) } +// RestrictedProtocol describes an IP protocol whose return traffic has been restricted +type RestrictedProtocol struct { + Protocol Protocol + NoReturnTraffic bool +} + +// ProtocolsWithRestrictedReturnPath returns a list of IP protocols whose communication would be disrupted if return traffic was restricted. +func (tc TrafficContent) ProtocolsWithRestrictedReturnPath(returnTraffic TrafficContent) []RestrictedProtocol { + var restrictedProtocols []RestrictedProtocol + var protocolsToAssess []Protocol + + // if tc specifies all traffic, warn about protocols not listed in returnTraffic + if tc.All() { + protocolsToAssess = []Protocol{ + ProtocolTCP, + ProtocolUDP, + ProtocolICMPv4, + ProtocolICMPv6, + } + } else { + for p := range tc.protocols { + protocolsToAssess = append(protocolsToAssess, p) + } + } + + for _, protocol := range protocolsToAssess { + returnTrafficProtocolContent := returnTraffic.protocol(protocol) + + if !returnTrafficProtocolContent.complete() { + + noReturnTraffic := false + + if returnTrafficProtocolContent.empty() { // return traffic is completely blocked + noReturnTraffic = true + } + + restrictedProtocols = append(restrictedProtocols, RestrictedProtocol{ + Protocol: protocol, + NoReturnTraffic: noReturnTraffic, + }) + } + } + + return restrictedProtocols +} + +// Protocols returns a slice of the IP protocols described by the traffic content. +func (tc TrafficContent) Protocols() []Protocol { + if tc.protocols == nil { + return nil + } + + var result []Protocol + + for protocol := range tc.protocols { + result = append(result, protocol) + } + + return result +} + func (tc *TrafficContent) setProtocolContent(p Protocol, content ProtocolContent) { tc.indicator = trafficContentIndicatorUnset