diff --git a/docs/src/_parts/commands/k8s_bootstrap.md b/docs/src/_parts/commands/k8s_bootstrap.md index 73bedab2f..1ad241588 100644 --- a/docs/src/_parts/commands/k8s_bootstrap.md +++ b/docs/src/_parts/commands/k8s_bootstrap.md @@ -13,7 +13,7 @@ k8s bootstrap [flags] ### Options ``` - --address string microcluster address, defaults to the node IP address + --address string microcluster address or CIDR, defaults to the node IP address --file string path to the YAML file containing your custom cluster bootstrap configuration. Use '-' to read from stdin. -h, --help help for bootstrap --interactive interactively configure the most important cluster options diff --git a/docs/src/_parts/commands/k8s_join-cluster.md b/docs/src/_parts/commands/k8s_join-cluster.md index a1693f57a..d485b902a 100644 --- a/docs/src/_parts/commands/k8s_join-cluster.md +++ b/docs/src/_parts/commands/k8s_join-cluster.md @@ -9,7 +9,7 @@ k8s join-cluster [flags] ### Options ``` - --address string microcluster address, defaults to the node IP address + --address string microcluster address or CIDR, defaults to the node IP address --file string path to the YAML file containing your custom cluster join configuration. Use '-' to read from stdin. -h, --help help for join-cluster --name string node name, defaults to hostname diff --git a/src/k8s/cmd/k8s/k8s_bootstrap.go b/src/k8s/cmd/k8s/k8s_bootstrap.go index 127d3b9cc..5a4dacc9b 100644 --- a/src/k8s/cmd/k8s/k8s_bootstrap.go +++ b/src/k8s/cmd/k8s/k8s_bootstrap.go @@ -16,7 +16,6 @@ import ( cmdutil "github.com/canonical/k8s/cmd/util" "github.com/canonical/k8s/pkg/config" "github.com/canonical/k8s/pkg/utils" - "github.com/canonical/lxd/lxd/util" "github.com/spf13/cobra" "gopkg.in/yaml.v2" ) @@ -73,10 +72,12 @@ func newBootstrapCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { } } - if opts.address == "" { - opts.address = util.NetworkInterfaceAddress() + address, err := utils.ParseAddressString(opts.address, config.DefaultPort) + if err != nil { + cmd.PrintErrf("Error: Failed to parse the address %q.\n\nThe error was: %v\n", opts.address, err) + env.Exit(1) + return } - opts.address = util.CanonicalNetworkAddress(opts.address, config.DefaultPort) client, err := env.Client(cmd.Context()) if err != nil { @@ -126,7 +127,7 @@ func newBootstrapCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { request := apiv1.PostClusterBootstrapRequest{ Name: opts.name, - Address: opts.address, + Address: address, Config: bootstrapConfig, } @@ -147,7 +148,7 @@ func newBootstrapCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { cmd.Flags().BoolVar(&opts.interactive, "interactive", false, "interactively configure the most important cluster options") cmd.Flags().StringVar(&opts.configFile, "file", "", "path to the YAML file containing your custom cluster bootstrap configuration. Use '-' to read from stdin.") cmd.Flags().StringVar(&opts.name, "name", "", "node name, defaults to hostname") - cmd.Flags().StringVar(&opts.address, "address", "", "microcluster address, defaults to the node IP address") + cmd.Flags().StringVar(&opts.address, "address", "", "microcluster address or CIDR, defaults to the node IP address") cmd.Flags().StringVar(&opts.outputFormat, "output-format", "plain", "set the output format to one of plain, json or yaml") cmd.Flags().DurationVar(&opts.timeout, "timeout", 90*time.Second, "the max time to wait for the command to execute") diff --git a/src/k8s/cmd/k8s/k8s_join_cluster.go b/src/k8s/cmd/k8s/k8s_join_cluster.go index 148f32601..0bf5e4ac8 100644 --- a/src/k8s/cmd/k8s/k8s_join_cluster.go +++ b/src/k8s/cmd/k8s/k8s_join_cluster.go @@ -10,7 +10,7 @@ import ( apiv1 "github.com/canonical/k8s/api/v1" cmdutil "github.com/canonical/k8s/cmd/util" "github.com/canonical/k8s/pkg/config" - "github.com/canonical/lxd/lxd/util" + "github.com/canonical/k8s/pkg/utils" "github.com/spf13/cobra" ) @@ -55,10 +55,12 @@ func newJoinClusterCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { opts.name = hostname } - if opts.address == "" { - opts.address = util.NetworkInterfaceAddress() + address, err := utils.ParseAddressString(opts.address, config.DefaultPort) + if err != nil { + cmd.PrintErrf("Error: Failed to parse the address %q.\n\nThe error was: %v\n", opts.address, err) + env.Exit(1) + return } - opts.address = util.CanonicalNetworkAddress(opts.address, config.DefaultPort) client, err := env.Client(cmd.Context()) if err != nil { @@ -100,7 +102,7 @@ func newJoinClusterCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { cobra.OnFinalize(cancel) cmd.PrintErrln("Joining the cluster. This may take a few seconds, please wait.") - if err := client.JoinCluster(ctx, apiv1.JoinClusterRequest{Name: opts.name, Address: opts.address, Token: token, Config: joinClusterConfig}); err != nil { + if err := client.JoinCluster(ctx, apiv1.JoinClusterRequest{Name: opts.name, Address: address, Token: token, Config: joinClusterConfig}); err != nil { cmd.PrintErrf("Error: Failed to join the cluster using the provided token.\n\nThe error was: %v\n", err) env.Exit(1) return @@ -110,7 +112,7 @@ func newJoinClusterCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { }, } cmd.Flags().StringVar(&opts.name, "name", "", "node name, defaults to hostname") - cmd.Flags().StringVar(&opts.address, "address", "", "microcluster address, defaults to the node IP address") + cmd.Flags().StringVar(&opts.address, "address", "", "microcluster address or CIDR, defaults to the node IP address") cmd.Flags().StringVar(&opts.configFile, "file", "", "path to the YAML file containing your custom cluster join configuration. Use '-' to read from stdin.") cmd.Flags().StringVar(&opts.outputFormat, "output-format", "plain", "set the output format to one of plain, json or yaml") cmd.Flags().DurationVar(&opts.timeout, "timeout", 90*time.Second, "the max time to wait for the command to execute") diff --git a/src/k8s/pkg/utils/cidr.go b/src/k8s/pkg/utils/cidr.go index 1c0c279ad..1a2ae0919 100644 --- a/src/k8s/pkg/utils/cidr.go +++ b/src/k8s/pkg/utils/cidr.go @@ -4,9 +4,44 @@ import ( "fmt" "math/big" "net" + "strconv" "strings" + + "github.com/canonical/lxd/lxd/util" ) +// findMatchingNodeAddress returns the IP address of a network interface that belongs to the given CIDR. +func findMatchingNodeAddress(cidr *net.IPNet) (net.IP, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, fmt.Errorf("could not get interface addresses: %w", err) + } + + var selectedIP net.IP + selectedSubnetBits := -1 + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + if cidr.Contains(ipNet.IP) { + _, subnetBits := cidr.Mask.Size() + if selectedSubnetBits == -1 || subnetBits < selectedSubnetBits { + // Prefer the address with the fewest subnet bits + selectedIP = ipNet.IP + selectedSubnetBits = subnetBits + } + } + } + + if selectedIP == nil { + return nil, fmt.Errorf("could not find a matching address for CIDR %q", cidr.String()) + } + + return selectedIP, nil +} + // GetFirstIP returns the first IP address of a subnet. Use big.Int so that it can handle both IPv4 and IPv6 addreses. func GetFirstIP(subnet string) (net.IP, error) { _, cidr, err := net.ParseCIDR(subnet) @@ -37,3 +72,32 @@ func GetKubernetesServiceIPsFromServiceCIDRs(serviceCIDR string) ([]net.IP, erro } return firstIPs, nil } + +// ParseAddressString parses an address string and returns a canonical network address. +func ParseAddressString(address string, port int64) (string, error) { + host, hostPort, err := net.SplitHostPort(address) + if err == nil { + address = host + port, err = strconv.ParseInt(hostPort, 10, 64) + if err != nil { + return "", fmt.Errorf("failed to parse the port from %q: %w", hostPort, err) + } + } + + if port < 0 || port > 65535 { + return "", fmt.Errorf("invalid port number %d", port) + } + + if address == "" { + address = util.NetworkInterfaceAddress() + } else if _, ipNet, err := net.ParseCIDR(address); err == nil { + matchingIP, err := findMatchingNodeAddress(ipNet) + if err != nil { + return "", fmt.Errorf("failed to find a matching node address for %q: %w", address, err) + } + address = matchingIP.String() + } + + return util.CanonicalNetworkAddress(address, port), nil + +} diff --git a/src/k8s/pkg/utils/cidr_test.go b/src/k8s/pkg/utils/cidr_test.go index 4502c10ec..cfee353e3 100644 --- a/src/k8s/pkg/utils/cidr_test.go +++ b/src/k8s/pkg/utils/cidr_test.go @@ -1,9 +1,12 @@ package utils_test import ( + "fmt" + "net" "testing" "github.com/canonical/k8s/pkg/utils" + "github.com/canonical/lxd/lxd/util" . "github.com/onsi/gomega" ) @@ -68,3 +71,44 @@ func TestGetKubernetesServiceIPsFromServiceCIDRs(t *testing.T) { } }) } + +func TestParseAddressString(t *testing.T) { + g := NewWithT(t) + + // Seed the default address + defaultAddress := util.NetworkInterfaceAddress() + ip := net.ParseIP(defaultAddress) + subnetMask := net.CIDRMask(24, 32) + networkAddress := ip.Mask(subnetMask) + // Infer the CIDR notation + networkAddressCIDR := fmt.Sprintf("%s/24", networkAddress.String()) + + for _, tc := range []struct { + name string + address string + port int64 + want string + wantErr bool + }{ + {name: "EmptyAddress", address: "", port: 8080, want: fmt.Sprintf("%s:8080", defaultAddress), wantErr: false}, + {name: "CIDR", address: networkAddressCIDR, port: 8080, want: fmt.Sprintf("%s:8080", defaultAddress), wantErr: false}, + {name: "CIDRAndPort", address: fmt.Sprintf("%s:9090", networkAddressCIDR), port: 8080, want: fmt.Sprintf("%s:9090", defaultAddress), wantErr: false}, + {name: "IPv4", address: "10.0.0.10", port: 8080, want: "10.0.0.10:8080", wantErr: false}, + {name: "IPv4AndPort", address: "10.0.0.10:9090", port: 8080, want: "10.0.0.10:9090", wantErr: false}, + {name: "NonMatchingCIDR", address: "10.10.5.0/24", port: 8080, want: "", wantErr: true}, + {name: "IPv6", address: "fe80::1:234", port: 8080, want: "[fe80::1:234]:8080", wantErr: false}, + {name: "IPv6AndPort", address: "[fe80::1:234]:9090", port: 8080, want: "[fe80::1:234]:9090", wantErr: false}, + {name: "InvalidPort", address: "127.0.0.1:invalid-port", port: 0, want: "", wantErr: true}, + {name: "PortOutOfBounds", address: "10.0.0.10:70799", port: 8080, want: "", wantErr: true}, + } { + t.Run(tc.name, func(t *testing.T) { + got, err := utils.ParseAddressString(tc.address, tc.port) + if tc.wantErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(got).To(Equal(tc.want)) + } + }) + } +}