diff --git a/pkg/ip/cidr.go b/pkg/ip/cidr.go index 156351a..0c66232 100644 --- a/pkg/ip/cidr.go +++ b/pkg/ip/cidr.go @@ -146,3 +146,45 @@ func IsBroadcast(ip net.IP, network *net.IPNet) bool { binary.BigEndian.Uint32(network.IP.To4())|^binary.BigEndian.Uint32(net.IP(network.Mask).To4())) return ip.Equal(masked) } + +// GetSubnetGen returns generator function that can be called multiple times +// to generate subnet for the network with the prefix size. +// The function always returns non-nil function. +// The generator function will return nil If subnet can't be generate +// (invalid input args provided, or no more subnets available for the network). +// Example: +// _, network, _ := net.ParseCIDR("192.168.0.0/23") +// gen := GetSubnetGen(network, 25) +// println(gen().String()) // 192.168.0.0/25 +// println(gen().String()) // 192.168.0.128/25 +// println(gen().String()) // 192.168.1.0/25 +// println(gen().String()) // 192.168.1.128/25 +// println(gen().String()) // - no more ranges available +func GetSubnetGen(network *net.IPNet, prefixSize uint) func() *net.IPNet { + networkOnes, netBitsTotal := network.Mask.Size() + if prefixSize < uint(networkOnes) || prefixSize > uint(netBitsTotal) { + return func() *net.IPNet { return nil } + } + isIPv6 := false + if network.IP.To4() == nil { + isIPv6 = true + } + networkIPAsInt := ipToInt(network.IP) + subnetIPCount := big.NewInt(0).Exp(big.NewInt(2), big.NewInt(int64(netBitsTotal)-int64(prefixSize)), nil) + subnetCount := big.NewInt(0).Exp(big.NewInt(2), big.NewInt(int64(prefixSize)-int64(networkOnes)), nil) + + curSubnetIndex := big.NewInt(0) + + return func() *net.IPNet { + if curSubnetIndex.Cmp(subnetCount) >= 0 { + return nil + } + subnetIPAsInt := big.NewInt(0).Add(networkIPAsInt, big.NewInt(0).Mul(subnetIPCount, curSubnetIndex)) + curSubnetIndex.Add(curSubnetIndex, big.NewInt(1)) + subnetIP := intToIP(subnetIPAsInt, isIPv6) + if subnetIP == nil { + return nil + } + return &net.IPNet{IP: subnetIP, Mask: net.CIDRMask(int(prefixSize), netBitsTotal)} + } +} diff --git a/pkg/ip/cidr_test.go b/pkg/ip/cidr_test.go index a6fc49b..1f98499 100644 --- a/pkg/ip/cidr_test.go +++ b/pkg/ip/cidr_test.go @@ -319,4 +319,59 @@ var _ = Describe("CIDR functions", func() { Expect(result).To(Equal(test.result)) } }) + + Context("GetSubnetGen", func() { + It("Invalid args - prefix is larger then network", func() { + _, net, _ := net.ParseCIDR("192.168.0.0/16") + gen := GetSubnetGen(net, 8) + Expect(gen).NotTo(BeNil()) + Expect(gen()).To(BeNil()) + }) + It("Invalid args - prefix is too small for IPv4", func() { + _, net, _ := net.ParseCIDR("192.168.0.0/16") + gen := GetSubnetGen(net, 120) + Expect(gen).NotTo(BeNil()) + Expect(gen()).To(BeNil()) + }) + It("Valid - single subnet IPv4", func() { + _, net, _ := net.ParseCIDR("192.168.0.0/24") + gen := GetSubnetGen(net, 24) + Expect(gen).NotTo(BeNil()) + Expect(gen().String()).To(Equal("192.168.0.0/24")) + Expect(gen()).To(BeNil()) + }) + It("Valid - single subnet IPv6", func() { + _, net, _ := net.ParseCIDR("2002:0:0:1234::/64") + gen := GetSubnetGen(net, 64) + Expect(gen).NotTo(BeNil()) + Expect(gen().String()).To(Equal("2002:0:0:1234::/64")) + Expect(gen()).To(BeNil()) + }) + It("valid - IPv4", func() { + _, net, _ := net.ParseCIDR("192.168.4.0/23") + gen := GetSubnetGen(net, 25) + Expect(gen).NotTo(BeNil()) + Expect(gen().String()).To(Equal("192.168.4.0/25")) + Expect(gen().String()).To(Equal("192.168.4.128/25")) + Expect(gen().String()).To(Equal("192.168.5.0/25")) + Expect(gen().String()).To(Equal("192.168.5.128/25")) + Expect(gen()).To(BeNil()) + }) + It("valid - IPv6", func() { + _, net, _ := net.ParseCIDR("2002:0:0:1234::/64") + gen := GetSubnetGen(net, 124) + Expect(gen).NotTo(BeNil()) + Expect(gen().String()).To(Equal("2002:0:0:1234::/124")) + Expect(gen().String()).To(Equal("2002:0:0:1234::10/124")) + Expect(gen().String()).To(Equal("2002:0:0:1234::20/124")) + }) + It("valid - large IPv6 subnet (overflow test)", func() { + _, net, _ := net.ParseCIDR("::/0") + gen := GetSubnetGen(net, 127) + Expect(gen).NotTo(BeNil()) + Expect(gen().String()).To(Equal("::/127")) + Expect(gen().String()).To(Equal("::2/127")) + Expect(gen().String()).To(Equal("::4/127")) + }) + }) })