diff --git a/net/ipfamily_test.go b/net/ipfamily_test.go index 78f2cdc1..13dd954d 100644 --- a/net/ipfamily_test.go +++ b/net/ipfamily_test.go @@ -24,167 +24,169 @@ import ( func TestDualStackIPs(t *testing.T) { testCases := []struct { + desc string ips []string - errMessage string expectedResult bool expectError bool }{ { + desc: "should fail because length is not at least 2", ips: []string{"1.1.1.1"}, - errMessage: "should fail because length is not at least 2", expectedResult: false, expectError: false, }, { + desc: "should fail because length is not at least 2", ips: []string{}, - errMessage: "should fail because length is not at least 2", expectedResult: false, expectError: false, }, { + desc: "should fail because all are v4", ips: []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"}, - errMessage: "should fail because all are v4", expectedResult: false, expectError: false, }, { + desc: "should fail because all are v6", ips: []string{"fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "fd92:20ba:ca:34f7:ffff:ffff:ffff:fff0", "fd92:20ba:ca:34f7:ffff:ffff:ffff:fff1"}, - errMessage: "should fail because all are v6", expectedResult: false, expectError: false, }, { + desc: "should fail because 2nd ip is invalid", ips: []string{"1.1.1.1", "not-a-valid-ip"}, - errMessage: "should fail because 2nd ip is invalid", expectedResult: false, expectError: true, }, { + desc: "should fail because 1st ip is invalid", ips: []string{"not-a-valid-ip", "fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff"}, - errMessage: "should fail because 1st ip is invalid", expectedResult: false, expectError: true, }, { + desc: "should fail despite dual-stack because 3rd ip is invalid", + ips: []string{"1.1.1.1", "fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "not-a-valid-ip"}, + expectedResult: false, + expectError: true, + }, + { + desc: "dual-stack ipv4-primary", ips: []string{"1.1.1.1", "fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff"}, - errMessage: "expected success, but found failure", expectedResult: true, expectError: false, }, { + desc: "dual-stack, multiple ipv6", ips: []string{"fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "1.1.1.1", "fd92:20ba:ca:34f7:ffff:ffff:ffff:fff0"}, - errMessage: "expected success, but found failure", expectedResult: true, expectError: false, }, { + desc: "dual-stack, multiple ipv4", ips: []string{"1.1.1.1", "fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "10.0.0.0"}, - errMessage: "expected success, but found failure", expectedResult: true, expectError: false, }, { + desc: "dual-stack, ipv6-primary", ips: []string{"fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "1.1.1.1"}, - errMessage: "expected success, but found failure", expectedResult: true, expectError: false, }, } // for each test case, test the regular func and the string func for _, tc := range testCases { - dualStack, err := IsDualStackIPStrings(tc.ips) - if err == nil && tc.expectError { - t.Errorf("%s", tc.errMessage) - continue - } - if err != nil && !tc.expectError { - t.Errorf("failed to run test case for %v, error: %v", tc.ips, err) - continue - } - if dualStack != tc.expectedResult { - t.Errorf("%v for %v", tc.errMessage, tc.ips) - } - } + t.Run(tc.desc, func(t *testing.T) { + dualStack, err := IsDualStackIPStrings(tc.ips) + if err == nil && tc.expectError { + t.Fatalf("expected an error from IsDualStackIPStrings") + } + if err != nil && !tc.expectError { + t.Fatalf("unexpected error from IsDualStackIPStrings: %v", err) + } + if dualStack != tc.expectedResult { + t.Errorf("expected IsDualStackIPStrings=%v, got %v", tc.expectedResult, dualStack) + } - for _, tc := range testCases { - ips := make([]net.IP, 0, len(tc.ips)) - for _, ip := range tc.ips { - parsedIP := ParseIPSloppy(ip) - ips = append(ips, parsedIP) - } - dualStack, err := IsDualStackIPs(ips) - if err == nil && tc.expectError { - t.Errorf("%s", tc.errMessage) - continue - } - if err != nil && !tc.expectError { - t.Errorf("failed to run test case for %v, error: %v", tc.ips, err) - continue - } - if dualStack != tc.expectedResult { - t.Errorf("%v for %v", tc.errMessage, tc.ips) - } + ips := make([]net.IP, 0, len(tc.ips)) + for _, ip := range tc.ips { + parsedIP := ParseIPSloppy(ip) + ips = append(ips, parsedIP) + } + dualStack, err = IsDualStackIPs(ips) + if err == nil && tc.expectError { + t.Fatalf("expected an error from IsDualStackIPs") + } + if err != nil && !tc.expectError { + t.Fatalf("unexpected error from IsDualStackIPs: %v", err) + } + if dualStack != tc.expectedResult { + t.Errorf("expected IsDualStackIPs=%v, got %v", tc.expectedResult, dualStack) + } + }) } } func TestDualStackCIDRs(t *testing.T) { testCases := []struct { + desc string cidrs []string - errMessage string expectedResult bool expectError bool }{ { + desc: "should fail because length is not at least 2", cidrs: []string{"10.10.10.10/8"}, - errMessage: "should fail because length is not at least 2", expectedResult: false, expectError: false, }, { + desc: "should fail because length is not at least 2", cidrs: []string{}, - errMessage: "should fail because length is not at least 2", expectedResult: false, expectError: false, }, { + desc: "should fail because all cidrs are v4", cidrs: []string{"10.10.10.10/8", "20.20.20.20/8", "30.30.30.30/8"}, - errMessage: "should fail because all cidrs are v4", expectedResult: false, expectError: false, }, { + desc: "should fail because all cidrs are v6", cidrs: []string{"2000::/10", "3000::/10"}, - errMessage: "should fail because all cidrs are v6", expectedResult: false, expectError: false, }, { + desc: "should fail because 2nd cidr is invalid", cidrs: []string{"10.10.10.10/8", "not-a-valid-cidr"}, - errMessage: "should fail because 2nd cidr is invalid", expectedResult: false, expectError: true, }, { + desc: "should fail because 1st cidr is invalid", cidrs: []string{"not-a-valid-ip", "2000::/10"}, - errMessage: "should fail because 1st cidr is invalid", expectedResult: false, expectError: true, }, { + desc: "dual-stack, ipv4-primary", cidrs: []string{"10.10.10.10/8", "2000::/10"}, - errMessage: "expected success, but found failure", expectedResult: true, expectError: false, }, { + desc: "dual-stack, ipv6-primary", cidrs: []string{"2000::/10", "10.10.10.10/8"}, - errMessage: "expected success, but found failure", expectedResult: true, expectError: false, }, { + desc: "dual-stack, multiple IPv6", cidrs: []string{"2000::/10", "10.10.10.10/8", "3000::/10"}, - errMessage: "expected success, but found failure", expectedResult: true, expectError: false, }, @@ -192,39 +194,35 @@ func TestDualStackCIDRs(t *testing.T) { // for each test case, test the regular func and the string func for _, tc := range testCases { - dualStack, err := IsDualStackCIDRStrings(tc.cidrs) - if err == nil && tc.expectError { - t.Errorf("%s", tc.errMessage) - continue - } - if err != nil && !tc.expectError { - t.Errorf("failed to run test case for %v, error: %v", tc.cidrs, err) - continue - } - if dualStack != tc.expectedResult { - t.Errorf("%v for %v", tc.errMessage, tc.cidrs) - } - } + t.Run(tc.desc, func(t *testing.T) { + dualStack, err := IsDualStackCIDRStrings(tc.cidrs) + if err == nil && tc.expectError { + t.Fatalf("expected an error from IsDualStackCIDRStrings") + } + if err != nil && !tc.expectError { + t.Fatalf("unexpected error from IsDualStackCIDRStrings: %v", err) + } + if dualStack != tc.expectedResult { + t.Errorf("expected IsDualStackCIDRStrings=%v, got %v", tc.expectedResult, dualStack) + } - for _, tc := range testCases { - cidrs := make([]*net.IPNet, 0, len(tc.cidrs)) - for _, cidr := range tc.cidrs { - _, parsedCIDR, _ := ParseCIDRSloppy(cidr) - cidrs = append(cidrs, parsedCIDR) - } + cidrs := make([]*net.IPNet, 0, len(tc.cidrs)) + for _, cidr := range tc.cidrs { + _, parsedCIDR, _ := ParseCIDRSloppy(cidr) + cidrs = append(cidrs, parsedCIDR) + } - dualStack, err := IsDualStackCIDRs(cidrs) - if err == nil && tc.expectError { - t.Errorf("%s", tc.errMessage) - continue - } - if err != nil && !tc.expectError { - t.Errorf("failed to run test case for %v, error: %v", tc.cidrs, err) - continue - } - if dualStack != tc.expectedResult { - t.Errorf("%v for %v", tc.errMessage, tc.cidrs) - } + dualStack, err = IsDualStackCIDRs(cidrs) + if err == nil && tc.expectError { + t.Fatalf("expected an error from IsDualStackCIDRs") + } + if err != nil && !tc.expectError { + t.Fatalf("unexpected error from IsDualStackCIDRs: %v", err) + } + if dualStack != tc.expectedResult { + t.Errorf("expected IsDualStackCIDRs=%v, got %v", tc.expectedResult, dualStack) + } + }) } } diff --git a/net/v2/convert.go b/net/v2/convert.go new file mode 100644 index 00000000..f37e0c87 --- /dev/null +++ b/net/v2/convert.go @@ -0,0 +1,202 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "net" + "net/netip" + "strings" +) + +// AddrFromIP converts a net.IP to a netip.Addr. Given valid input this will always +// succeed; it will return the invalid netip.Addr on nil or garbage input. +// +// Use this rather than netip.AddrFromSlice(), which (despite the claims of its +// documentation) does not always do what you would expect if you pass it a net.IP. +func AddrFromIP(ip net.IP) netip.Addr { + // Naively using netip.AddrFromSlice() gives unexpected results: + // + // ip := net.ParseIP("1.2.3.4") + // addr, _ := netip.AddrFromSlice(ip) + // addr.String() => "::ffff:1.2.3.4" + // addr.Is4() => false + // addr.Is6() => true + // + // This is because net.IP and netip.Addr have different ideas about how to handle + // "IPv4-mapped IPv6" addresses, but netip.AddrFromSlice ignores that fact. + // + // In net.IP, parsing either "1.2.3.4" or "::ffff:1.2.3.4", will give you the + // same result: + // + // ip1 := net.ParseIP("1.2.3.4") + // []byte(ip1) => []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 1, 2, 3, 4} + // ip1.String() => "1.2.3.4" + // ip2 := net.ParseIP("::ffff:1.2.3.4") + // []byte(ip2) => []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 1, 2, 3, 4} + // ip2.String() => "1.2.3.4" + // + // net.IP normally stores IPv4 addresses as 16-byte IPv4-mapped IPv6 addresses, + // but it hides that from the user, and it never stringifies an IPv4 IP to an + // IPv4-mapped IPv6 form, even if that was the format you started with. + // + // net.IP *can* represent IPv4 addresses in a 4-byte format, but this is treated + // as completly equivalent to the 16-byte representation: + // + // ip4 := ip1.To4() + // []byte(ip4) => []byte{1, 2, 3, 4} + // ip4.String() => "1.2.3.4" + // ip1.Equal(ip4) => true + // + // netip.Addr, on the other hand, treats "plain" IPv4 and IPv4-mapped IPv6 as two + // completely separate things: + // + // a1 := netip.MustParseAddr("1.2.3.4") + // a2 := netip.MustParseAddr("::ffff:1.2.3.4") + // a1.String() => "1.2.3.4" + // a2.String() => "::ffff:1.2.3.4" + // a1 == a2 => false + // + // which would be fine, except that netip.AddrFromSlice breaks net.IP's normal + // semantics by converting the 4-byte and 16-byte net.IP forms to different + // netip.Addr values, giving the confusing results above. + // + // In order to correctly convert an IPv4 address from net.IP to netip.Addr, you + // need to either call .To4() on it before converting, or call .Unmap() on it + // after converting. (The latter option is slightly simpler for us here because we + // can just do it unconditionally, since it's a no-op in the IPv6 and invalid + // cases). + + addr, _ := netip.AddrFromSlice(ip) + return addr.Unmap() +} + +// IPFromAddr converts a netip.Addr to a net.IP. Given valid input this will always +// succeed; it will return nil if addr is the invalid netip.Addr. +func IPFromAddr(addr netip.Addr) net.IP { + // addr.AsSlice() returns: + // - a []byte of length 4 if addr is a normal IPv4 address + // - a []byte of length 16 if addr is an IPv6 address (including IPv4-mapped IPv6) + // - nil if addr is the zero Addr (which is the only other possibility) + // + // Any of those values can be correctly cast directly to a net.IP. + // + // Note that we don't bother to do any "cleanup" here like in the AddrFromIP case, + // so converting a plain IPv4 netip.Addr to net.IP gives a different result than + // converting an IPv4-mapped IPv6 netip.Addr: + // + // ip1 := netutils.IPFromAddr(netip.MustParseAddr("1.2.3.4")) + // []byte(ip1) => []byte{1, 2, 3, 4} + // + // ip2 := netutils.IPFromAddr(netip.MustParseAddr("::ffff:1.2.3.4")) + // []byte(ip2) => []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 1, 2, 3, 4} + // + // However, the net.IP API treats the two values as the same anyway, so it doesn't + // matter. + // + // ip1.String() => "1.2.3.4" + // ip2.String() => "1.2.3.4" + // ip2.Equal(ip1) => true + + return net.IP(addr.AsSlice()) +} + +// IPFromInterfaceAddr can be used to extract the underlying IP address value (as a +// net.IP) from the return values of net.InterfaceAddrs(), net.Interface.Addrs(), or +// net.Interface.MulticastAddrs(). (net.Addr is also used in some other APIs, but this +// function should not be used on net.Addrs that are not "interface addresses".) +func IPFromInterfaceAddr(ifaddr net.Addr) net.IP { + // On both Linux and Windows, the values returned from the "interface address" + // methods are currently *net.IPNet for unicast addresses or *net.IPAddr for + // multicast addresses. + if ipnet, ok := ifaddr.(*net.IPNet); ok { + return ipnet.IP + } else if ipaddr, ok := ifaddr.(*net.IPAddr); ok { + return ipaddr.IP + } + + // Try to deal with other similar types... in particular, this is needed for + // some existing unit tests... + addrStr := ifaddr.String() + // If it has a subnet length (like net.IPNet) or optional zone identifier (like + // net.IPAddr), trim that away. + if end := strings.IndexAny(addrStr, "/%"); end != -1 { + addrStr = addrStr[:end] + } + // What's left is either an IP address, or something we can't parse. + ip, _ := ParseIP(addrStr) + return ip +} + +// AddrFromInterfaceAddr can be used to extract the underlying IP address value (as a +// netip.Addr) from the return values of net.InterfaceAddrs(), net.Interface.Addrs(), or +// net.Interface.MulticastAddrs(). (net.Addr is also used in some other APIs, but this +// function should not be used on net.Addrs that are not "interface addresses".) +func AddrFromInterfaceAddr(ifaddr net.Addr) netip.Addr { + return AddrFromIP(IPFromInterfaceAddr(ifaddr)) +} + +// PrefixFromIPNet converts a *net.IPNet to a netip.Prefix. Given valid input this will +// always succeed; it will return the invalid netip.Prefix on nil or garbage input. +func PrefixFromIPNet(ipnet *net.IPNet) netip.Prefix { + if ipnet == nil { + return netip.Prefix{} + } + + addr := AddrFromIP(ipnet.IP) + if !addr.IsValid() { + return netip.Prefix{} + } + + prefixLen, bits := ipnet.Mask.Size() + if prefixLen == 0 && bits == 0 { + // non-CIDR Mask representation; not representible as a netip.Prefix + return netip.Prefix{} + } + if bits == 128 && addr.Is4() && (bits-prefixLen <= 32) { + // In the same way that net.IP allows an IPv4 IP to be either 4 or 16 + // bytes (32 or 128 bits), *net.IPNet allows an IPv4 CIDR to have either a + // 32-bit or a 128-bit mask. If the mask is 128 bits, we discard the + // leftmost 96 bits. + prefixLen -= 128 - 32 + } else if bits != addr.BitLen() { + // invalid IPv4/IPv6 mix + return netip.Prefix{} + } + + return netip.PrefixFrom(addr, prefixLen) +} + +// IPNetFromPrefix converts a netip.Prefix to a *net.IPNet. Given valid input this will +// always succeed; it will return nil if prefix is the invalid netip.Prefix or is +// otherwise invalid. +func IPNetFromPrefix(prefix netip.Prefix) *net.IPNet { + addr := prefix.Addr() + bits := prefix.Bits() + if bits == -1 || !addr.IsValid() { + return nil + } + addrLen := addr.BitLen() + + // (As with IPFromAddr, a plain IPv4 netip.Prefix and an equivalent IPv4-mapped + // IPv6 netip.Prefix will get converted to distinct *net.IPNet values, but + // *net.IPNet will treat them equivalently.) + + return &net.IPNet{ + IP: IPFromAddr(addr), + Mask: net.CIDRMask(bits, addrLen), + } +} diff --git a/net/v2/convert_test.go b/net/v2/convert_test.go new file mode 100644 index 00000000..96798ff9 --- /dev/null +++ b/net/v2/convert_test.go @@ -0,0 +1,319 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "fmt" + "net" + "net/netip" + "testing" +) + +func TestAddrFromIP(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestIPs { + if tc.skipConvert { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, ip := range tc.ips { + addr := AddrFromIP(ip) + if tc.addrs[0] != addr { + t.Errorf("IP %d %#v %s converted to addr %q, but expected %q", i+1, ip, ip, addr, tc.addrs[0]) + } + + // No net.IP should convert to an IPv4-mapped IPv6 netip.Addr + if addr.Is4In6() { + t.Errorf("AddrFromIP() converted IP %d %#v %s to IPv4-mapped IPv6 Addr %#v %s", i+1, ip, ip, addr, addr) + } + // And thus every value should round-trip. + rtIP := IPFromAddr(addr) + if !ip.Equal(rtIP) { + t.Errorf("IP %d %#v %s round-tripped to %#v %s", i+1, ip, ip, rtIP, rtIP) + } + } + }) + } + + // See test cases in ips_test.go + for _, tc := range badTestIPs { + if tc.skipConvert { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, ip := range tc.ips { + addr := AddrFromIP(ip) + if addr.IsValid() { + t.Errorf("Expected IP %d %#v to convert to invalid netip.Addr but got %#v %s", i+1, ip, addr, addr) + } + } + }) + } +} + +func TestIPFromAddr(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestIPs { + if tc.skipConvert { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, addr := range tc.addrs { + ip := IPFromAddr(addr) + if !ip.Equal(tc.ips[0]) { + t.Errorf("addr %d %#v %s converted to ip %q, but expected %q", i, addr, addr, ip, tc.ips[0]) + } + + // As long as addr is not IPv4-mapped IPv6, it should round-trip. + if !addr.Is4In6() { + rtAddr := AddrFromIP(ip) + if addr != rtAddr { + t.Errorf("Addr %d %#v %s round-tripped to %#v %s", i+1, addr, addr, rtAddr, rtAddr) + } + } + } + }) + } + + // Conversion of IPv4-mapped IPv6 is asymmetric because netip.Addr distinguishes + // plain IPv4 from IPv4-mapped IPv6, while net.IP does not. The "IPv4-mapped IPv6" + // test case in goodTestIPs covers most of the cases, but goodTestIPs has no way + // to describe the asymmetric part. + t.Run("IPv4-mapped IPv6 conversion from netip.Addr", func(t *testing.T) { + addr := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 1, 2, 3, 4}) + if !addr.Is4In6() { + panic("failed to create IPv4-mapped IPv6 netip.Addr?") + } + + ip := IPFromAddr(addr) + expectedIP := net.IP{1, 2, 3, 4} + if !ip.Equal(expectedIP) { + t.Errorf("netip.Addr %q converted to %q, expected %q", addr, ip, expectedIP) + } + rtAddr := AddrFromIP(ip) + if rtAddr == addr { + t.Errorf("IPv4-mapped IPv6 netip.Addr unexpectedly round-tripped through net.IP!") + } + }) + + // See test cases in ips_test.go + for _, tc := range badTestIPs { + if tc.skipConvert { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, addr := range tc.addrs { + ip := IPFromAddr(addr) + if ip != nil { + t.Errorf("Expected Addr %d %#v to convert to invalid net.IP but got %#v %s", i+1, addr, ip, ip) + } + } + }) + } +} + +type dummyNetAddr string + +func (d dummyNetAddr) Network() string { + return "dummy" +} +func (d dummyNetAddr) String() string { + return string(d) +} + +func TestIPFromInterfaceAddr_AddrFromInterfaceAddr(t *testing.T) { + testCases := []struct { + desc string + ifaddr net.Addr + out string + }{ + { + desc: "net.IPNet", + ifaddr: &net.IPNet{IP: net.IP{192, 168, 1, 1}, Mask: net.CIDRMask(24, 32)}, + out: "192.168.1.1", + }, + { + desc: "net.IPAddr", + ifaddr: &net.IPAddr{IP: net.IP{192, 168, 1, 2}}, + out: "192.168.1.2", + }, + { + desc: "net.IPAddr with zone", + ifaddr: &net.IPAddr{IP: net.IP{192, 168, 1, 3}, Zone: "eth0"}, + out: "192.168.1.3", + }, + { + desc: "net.TCPAddr", + ifaddr: &net.TCPAddr{IP: net.IP{192, 168, 1, 4}, Port: 80}, + out: "", + }, + { + desc: "unknown plain IP", + ifaddr: dummyNetAddr("192.168.1.5"), + out: "192.168.1.5", + }, + { + desc: "unknown CIDR", + ifaddr: dummyNetAddr("192.168.1.6/24"), + out: "192.168.1.6", + }, + { + desc: "unknown IP with zone", + ifaddr: dummyNetAddr("192.168.1.7%eth0"), + out: "192.168.1.7", + }, + { + desc: "unknown sockaddr", + ifaddr: dummyNetAddr("192.168.1.8:80"), + out: "", + }, + { + desc: "unknown junk", + ifaddr: dummyNetAddr("junk"), + out: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ip := IPFromInterfaceAddr(tc.ifaddr) + addr := AddrFromInterfaceAddr(tc.ifaddr) + if tc.out == "" { + if ip != nil { + t.Errorf("expected IPFromInterfaceAddr to return nil but got %q", ip.String()) + } + if addr.IsValid() { + t.Errorf("expected AddrFromInterfaceAddr to return zero but got %q", addr.String()) + } + } else { + if ip.String() != tc.out { + t.Errorf("expected IPFromInterfaceAddr to return %q but got %q", tc.out, ip.String()) + } + if addr.String() != tc.out { + t.Errorf("expected AddrFromInterfaceAddr to return %q but got %q", tc.out, addr.String()) + } + } + }) + } +} + +func TestPrefixFromIPNet(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestCIDRs { + if tc.skipConvert { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, ipnet := range tc.ipnets { + prefix := PrefixFromIPNet(ipnet) + if tc.prefixes[0] != prefix { + t.Errorf("IPNet %d %#v %s converted to prefix %q, but expected %q", i+1, *ipnet, ipnet, prefix, tc.prefixes[0]) + } + + // No net.IPNet should convert to an IPv4-mapped IPv6 netip.Prefix + if prefix.Addr().Is4In6() { + t.Errorf("PrefixFromIPNet() converted IPNet %d %#v %s to IPv4-mapped IPv6 prefix %#v %s", i+1, *ipnet, ipnet, prefix, prefix) + } + // And thus every value should round-trip. + rtIPNet := IPNetFromPrefix(prefix) + if rtIPNet.String() != ipnet.String() { + t.Errorf("IPNet %d %#v %s round-tripped to %#v %s", i+1, *ipnet, ipnet, *rtIPNet, rtIPNet) + } + } + }) + } + + // See test cases in ips_test.go + for _, tc := range badTestCIDRs { + if tc.skipConvert { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, ipnet := range tc.ipnets { + prefix := PrefixFromIPNet(ipnet) + if prefix.IsValid() { + str := "" + if ipnet != nil { + str = fmt.Sprintf("%#v", *ipnet) + } + t.Errorf("Expected IPNet %d %s to convert to invalid netip.Prefix but got %#v %s", i+1, str, prefix, prefix) + } + } + }) + } +} + +func TestIPNetFromPrefix(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestCIDRs { + if tc.skipConvert { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, prefix := range tc.prefixes { + ipnet := IPNetFromPrefix(prefix) + if ipnet.String() != tc.ipnets[0].String() { + t.Errorf("prefix %d %#v %s converted to ipnet %q, but expected %q", i, prefix, prefix, ipnet, tc.ipnets[0]) + } + + // As long as addr is not IPv4-mapped IPv6, it should round-trip. + if !prefix.Addr().Is4In6() { + rtPrefix := PrefixFromIPNet(ipnet) + if prefix != rtPrefix { + t.Errorf("prefix %d %#v %s round-tripped to %#v %s", i+1, prefix, prefix, rtPrefix, rtPrefix) + } + } + } + }) + } + + // Conversion of IPv4-mapped IPv6 is asymmetric because netip.Addr distinguishes + // plain IPv4 from IPv4-mapped IPv6, while net.IP does not. The "IPv4-mapped IPv6" + // test case in goodTestCIDRs covers most of the cases, but goodTestCIDRs has no way + // to describe the asymmetric part. + t.Run("IPv4-mapped IPv6 conversion from netip.Prefix", func(t *testing.T) { + prefix := netip.MustParsePrefix("::ffff:1.2.3.0/120") + if !prefix.Addr().Is4In6() { + panic("failed to create IPv4-mapped IPv6 netip.Addr?") + } + + ipnet := IPNetFromPrefix(prefix) + expected := "1.2.3.0/24" + if ipnet.String() != expected { + t.Errorf("netip.Prefix %q converted to %q, expected %q", prefix, ipnet.String(), expected) + } + rtPrefix := PrefixFromIPNet(ipnet) + if rtPrefix == prefix { + t.Errorf("IPv4-mapped IPv6 netip.Prefix unexpectedly round-tripped through net.IPNet!") + } + }) + + // See test cases in ips_test.go + for _, tc := range badTestCIDRs { + if tc.skipConvert { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, prefix := range tc.prefixes { + ipnet := IPNetFromPrefix(prefix) + if ipnet != nil { + t.Errorf("Expected Prefix %d %#v to convert to invalid net.IPNet but got %#v %s", i+1, prefix, *ipnet, ipnet) + } + } + }) + } +} diff --git a/net/v2/ipfamily.go b/net/v2/ipfamily.go new file mode 100644 index 00000000..93b71b2e --- /dev/null +++ b/net/v2/ipfamily.go @@ -0,0 +1,196 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "net" + "net/netip" +) + +// IPFamily refers to the IP family of an address or CIDR value. Its values are +// intentionally identical to those of "k8s.io/api/core/v1".IPFamily and +// "k8s.io/discovery/v1".AddressType, so you can cast values between these types. +type IPFamily string + +const ( + // IPv4 indicates an IPv4 IP or CIDR. + IPv4 IPFamily = "IPv4" + // IPv6 indicates an IPv4 IP or CIDR. + IPv6 IPFamily = "IPv6" + + // IPFamilyUnknown indicates an unspecified or invalid IP family. + IPFamilyUnknown IPFamily = "" +) + +type ipOrString interface { + net.IP | netip.Addr | string +} + +type cidrOrString interface { + *net.IPNet | netip.Prefix | string +} + +// IPFamilyOf returns the IP family of val (or IPFamilyUnknown if val is nil or invalid). +// IPv6-encoded IPv4 addresses (e.g., "::ffff:1.2.3.4") are considered IPv4. val can be a +// net.IP, a netip.Addr, or a string containing a single IP address. +// +// Note that "k8s.io/utils/net/v2".IPFamily intentionally has identical values to +// "k8s.io/api/core/v1".IPFamily and "k8s.io/discovery/v1".AddressType, so you can cast +// the return value of this function to those types. +func IPFamilyOf[T ipOrString](val T) IPFamily { + switch typedVal := interface{}(val).(type) { + case net.IP: + switch { + case typedVal.To4() != nil: + return IPv4 + case typedVal.To16() != nil: + return IPv6 + } + case netip.Addr: + switch { + case typedVal.Is4(), typedVal.Is4In6(): + return IPv4 + case typedVal.Is6(): + return IPv6 + } + case string: + if ip, _ := ParseIP(typedVal); ip != nil { + return IPFamilyOf(ip) + } + } + + return IPFamilyUnknown +} + +// IsIPv4 returns true if IPFamilyOf(val) is IPv4 (and false if it is IPv6 or invalid). +func IsIPv4[T ipOrString](val T) bool { + return IPFamilyOf(val) == IPv4 +} + +// IsIPv6 returns true if IPFamilyOf(val) is IPv6 (and false if it is IPv4 or invalid). +func IsIPv6[T ipOrString](val T) bool { + return IPFamilyOf(val) == IPv6 +} + +// IsDualStack returns true if vals contains at least one IPv4 address and at least one +// IPv6 address (and no invalid values). +func IsDualStack[T ipOrString](vals []T) bool { + v4Found := false + v6Found := false + for _, val := range vals { + switch IPFamilyOf(val) { + case IPv4: + v4Found = true + case IPv6: + v6Found = true + default: + return false + } + } + + return (v4Found && v6Found) +} + +// IsDualStackPair returns true if vals contains exactly 1 IPv4 address and 1 IPv6 address +// (in either order). +func IsDualStackPair[T ipOrString](vals []T) bool { + return len(vals) == 2 && IsDualStack(vals) +} + +// IPFamilyOfCIDR returns the IP family of val (or IPFamilyUnknown if val is nil or +// invalid). IPv6-encoded IPv4 addresses (e.g., "::ffff:1.2.3.0/120") are considered IPv4. +// val can be a *net.IPNet, a netip.Prefix, or a string containing a single CIDR value. +// +// Note that "k8s.io/utils/net/v2".IPFamily intentionally has identical values to +// "k8s.io/api/core/v1".IPFamily and "k8s.io/discovery/v1".AddressType, so you can cast +// the return value of this function to those types. +func IPFamilyOfCIDR[T cidrOrString](val T) IPFamily { + switch typedVal := interface{}(val).(type) { + case *net.IPNet: + if typedVal != nil { + family := IPFamilyOf(typedVal.IP) + // An IPv6 CIDR must have a 128-bit mask. An IPv4 CIDR must have a + // 32- or 128-bit mask. (Any other mask length is invalid.) + _, masklen := typedVal.Mask.Size() + if masklen == 128 || (family == IPv4 && masklen == 32) { + return family + } + } + case netip.Prefix: + if !typedVal.IsValid() { + return IPFamilyUnknown + } + return IPFamilyOf(typedVal.Addr()) + case string: + if ipnet, _ := ParseIPNet(typedVal); ipnet != nil { + return IPFamilyOf(ipnet.IP) + } + } + + return IPFamilyUnknown +} + +// IsIPv4CIDR returns true if IPFamilyOfCIDR(val) is IPv4 (and false if it is IPv6 or invalid). +func IsIPv4CIDR[T cidrOrString](val T) bool { + return IPFamilyOfCIDR(val) == IPv4 +} + +// IsIPv6CIDR returns true if IPFamilyOfCIDR(val) is IPv6 (and false if it is IPv4 or invalid). +func IsIPv6CIDR[T cidrOrString](val T) bool { + return IPFamilyOfCIDR(val) == IPv6 +} + +// IsDualStackCIDRs returns true if vals contains at least one IPv4 CIDR value and at +// least one IPv6 CIDR value (and no invalid values). +func IsDualStackCIDRs[T cidrOrString](vals []T) bool { + v4Found := false + v6Found := false + for _, val := range vals { + switch IPFamilyOfCIDR(val) { + case IPv4: + v4Found = true + case IPv6: + v6Found = true + default: + return false + } + } + + return (v4Found && v6Found) +} + +// IsDualStackCIDRPair returns true if vals contains exactly 1 IPv4 CIDR value and 1 IPv6 +// CIDR value (in either order). +func IsDualStackCIDRPair[T cidrOrString](vals []T) bool { + return len(vals) == 2 && IsDualStackCIDRs(vals) +} + +// OtherIPFamily returns the other IP family from ipFamily. +// +// Note that "k8s.io/utils/net/v2".IPFamily intentionally has identical values to +// "k8s.io/api/core/v1".IPFamily and "k8s.io/discovery/v1".AddressType, so you can cast +// the input/output values of this function between these types. +func OtherIPFamily(ipFamily IPFamily) IPFamily { + switch ipFamily { + case IPv4: + return IPv6 + case IPv6: + return IPv4 + default: + return IPFamilyUnknown + } +} diff --git a/net/v2/ipfamily_test.go b/net/v2/ipfamily_test.go new file mode 100644 index 00000000..886addff --- /dev/null +++ b/net/v2/ipfamily_test.go @@ -0,0 +1,340 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "fmt" + "net" + "net/netip" + "testing" +) + +func TestIsDualStack(t *testing.T) { + testCases := []struct { + desc string + ips []string + expectedResult bool + }{ + { + desc: "should fail because length is not at least 2", + ips: []string{"1.1.1.1"}, + expectedResult: false, + }, + { + desc: "should fail because length is not at least 2", + ips: []string{}, + expectedResult: false, + }, + { + desc: "should fail because all are v4", + ips: []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"}, + expectedResult: false, + }, + { + desc: "should fail because all are v6", + ips: []string{"fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "fd92:20ba:ca:34f7:ffff:ffff:ffff:fff0", "fd92:20ba:ca:34f7:ffff:ffff:ffff:fff1"}, + expectedResult: false, + }, + { + desc: "should fail because 2nd ip is invalid", + ips: []string{"1.1.1.1", "not-a-valid-ip"}, + expectedResult: false, + }, + { + desc: "should fail because 1st ip is invalid", + ips: []string{"not-a-valid-ip", "fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff"}, + expectedResult: false, + }, + { + desc: "should fail despite dual-stack because 3rd ip is invalid", + ips: []string{"1.1.1.1", "fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "not-a-valid-ip"}, + expectedResult: false, + }, + { + desc: "dual-stack ipv4-primary", + ips: []string{"1.1.1.1", "fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff"}, + expectedResult: true, + }, + { + desc: "dual-stack, multiple ipv6", + ips: []string{"fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "1.1.1.1", "fd92:20ba:ca:34f7:ffff:ffff:ffff:fff0"}, + expectedResult: true, + }, + { + desc: "dual-stack, multiple ipv4", + ips: []string{"1.1.1.1", "fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "10.0.0.0"}, + expectedResult: true, + }, + { + desc: "dual-stack, ipv6-primary", + ips: []string{"fd92:20ba:ca:34f7:ffff:ffff:ffff:ffff", "1.1.1.1"}, + expectedResult: true, + }, + } + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + netips := make([]net.IP, len(tc.ips)) + addrs := make([]netip.Addr, len(tc.ips)) + for i := range tc.ips { + netips[i], _ = ParseIP(tc.ips[i]) + addrs[i], _ = ParseAddr(tc.ips[i]) + } + + dualStack := IsDualStack(tc.ips) + if dualStack != tc.expectedResult { + t.Errorf("expected %v, []string got %v", tc.expectedResult, dualStack) + } + if IsDualStackPair(tc.ips) != (dualStack && len(tc.ips) == 2) { + t.Errorf("IsDualStackIPPair gave wrong result for []string") + } + + dualStack = IsDualStack(netips) + if dualStack != tc.expectedResult { + t.Errorf("expected %v []net.IP got %v", tc.expectedResult, dualStack) + } + if IsDualStackPair(netips) != (dualStack && len(tc.ips) == 2) { + t.Errorf("IsDualStackIPPair gave wrong result for []net.IP") + } + + dualStack = IsDualStack(addrs) + if dualStack != tc.expectedResult { + t.Errorf("expected %v []netip.Addr got %v", tc.expectedResult, dualStack) + } + if IsDualStackPair(addrs) != (dualStack && len(tc.ips) == 2) { + t.Errorf("IsDualStackIPPair gave wrong result for []netip.Addr") + } + }) + } +} + +func TestIsDualStackCIDRs(t *testing.T) { + testCases := []struct { + desc string + cidrs []string + expectedResult bool + }{ + { + desc: "should fail because length is not at least 2", + cidrs: []string{"10.10.10.10/8"}, + expectedResult: false, + }, + { + desc: "should fail because length is not at least 2", + cidrs: []string{}, + expectedResult: false, + }, + { + desc: "should fail because all cidrs are v4", + cidrs: []string{"10.10.10.10/8", "20.20.20.20/8", "30.30.30.30/8"}, + expectedResult: false, + }, + { + desc: "should fail because all cidrs are v6", + cidrs: []string{"2000::/10", "3000::/10"}, + expectedResult: false, + }, + { + desc: "should fail because 2nd cidr is invalid", + cidrs: []string{"10.10.10.10/8", "not-a-valid-cidr"}, + expectedResult: false, + }, + { + desc: "should fail because 1st cidr is invalid", + cidrs: []string{"not-a-valid-ip", "2000::/10"}, + expectedResult: false, + }, + { + desc: "dual-stack, ipv4-primary", + cidrs: []string{"10.10.10.10/8", "2000::/10"}, + expectedResult: true, + }, + { + desc: "dual-stack, ipv6-primary", + cidrs: []string{"2000::/10", "10.10.10.10/8"}, + expectedResult: true, + }, + { + desc: "dual-stack, multiple IPv6", + cidrs: []string{"2000::/10", "10.10.10.10/8", "3000::/10"}, + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ipnets := make([]*net.IPNet, len(tc.cidrs)) + prefixes := make([]netip.Prefix, len(tc.cidrs)) + for i := range tc.cidrs { + ipnets[i], _ = ParseIPNet(tc.cidrs[i]) + prefixes[i], _ = ParsePrefix(tc.cidrs[i]) + } + + dualStack := IsDualStackCIDRs(tc.cidrs) + if dualStack != tc.expectedResult { + t.Errorf("expected %v []string got %v", tc.expectedResult, dualStack) + } + if IsDualStackCIDRPair(tc.cidrs) != (dualStack && len(tc.cidrs) == 2) { + t.Errorf("IsDualStackCIDRPair gave wrong result for []string") + } + + dualStack = IsDualStackCIDRs(ipnets) + if dualStack != tc.expectedResult { + t.Errorf("expected %v []*net.IPNet got %v", tc.expectedResult, dualStack) + } + if IsDualStackCIDRPair(ipnets) != (dualStack && len(tc.cidrs) == 2) { + t.Errorf("IsDualStackCIDRPair gave wrong result for []*net.IPNet") + } + + dualStack = IsDualStackCIDRs(prefixes) + if dualStack != tc.expectedResult { + t.Errorf("expected %v []netip.Prefix got %v", tc.expectedResult, dualStack) + } + if IsDualStackCIDRPair(prefixes) != (dualStack && len(tc.cidrs) == 2) { + t.Errorf("IsDualStackCIDRPair gave wrong result for []netip.Prefix") + } + }) + } +} + +func checkOneIPFamily(t *testing.T, ip string, expectedFamily, family IPFamily, isIPv4, isIPv6 bool) { + t.Helper() + if family != expectedFamily { + t.Errorf("Expect %q family %q, got %q", ip, expectedFamily, family) + } + if isIPv4 != (expectedFamily == IPv4) { + t.Errorf("Expect %q ipv4 %v, got %v", ip, expectedFamily == IPv4, isIPv6) + } + if isIPv6 != (expectedFamily == IPv6) { + t.Errorf("Expect %q ipv6 %v, got %v", ip, expectedFamily == IPv6, isIPv6) + } +} + +func TestIPFamilyOf(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestIPs { + if tc.skipFamily { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for _, str := range tc.strings { + family := IPFamilyOf(str) + isIPv4 := IsIPv4(str) + isIPv6 := IsIPv6(str) + checkOneIPFamily(t, str, tc.family, family, isIPv4, isIPv6) + } + for _, ip := range tc.ips { + family := IPFamilyOf(ip) + isIPv4 := IsIPv4(ip) + isIPv6 := IsIPv6(ip) + checkOneIPFamily(t, ip.String(), tc.family, family, isIPv4, isIPv6) + } + for _, addr := range tc.addrs { + family := IPFamilyOf(addr) + isIPv4 := IsIPv4(addr) + isIPv6 := IsIPv6(addr) + checkOneIPFamily(t, addr.String(), tc.family, family, isIPv4, isIPv6) + } + }) + } + + // See test cases in ips_test.go + for _, tc := range badTestIPs { + if tc.skipFamily { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for _, ip := range tc.ips { + family := IPFamilyOf(ip) + isIPv4 := IsIPv4(ip) + isIPv6 := IsIPv6(ip) + checkOneIPFamily(t, fmt.Sprintf("%#v", ip), IPFamilyUnknown, family, isIPv4, isIPv6) + } + for _, addr := range tc.addrs { + family := IPFamilyOf(addr) + isIPv4 := IsIPv4(addr) + isIPv6 := IsIPv6(addr) + checkOneIPFamily(t, fmt.Sprintf("%#v", addr), IPFamilyUnknown, family, isIPv4, isIPv6) + } + for _, str := range tc.strings { + family := IPFamilyOf(str) + isIPv4 := IsIPv4(str) + isIPv6 := IsIPv6(str) + checkOneIPFamily(t, str, IPFamilyUnknown, family, isIPv4, isIPv6) + } + }) + } +} + +func TestIPFamilyOfCIDR(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestCIDRs { + if tc.skipFamily { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for _, str := range tc.strings { + family := IPFamilyOfCIDR(str) + isIPv4 := IsIPv4CIDR(str) + isIPv6 := IsIPv6CIDR(str) + checkOneIPFamily(t, str, tc.family, family, isIPv4, isIPv6) + } + for _, ipnet := range tc.ipnets { + family := IPFamilyOfCIDR(ipnet) + isIPv4 := IsIPv4CIDR(ipnet) + isIPv6 := IsIPv6CIDR(ipnet) + checkOneIPFamily(t, ipnet.String(), tc.family, family, isIPv4, isIPv6) + } + for _, prefix := range tc.prefixes { + family := IPFamilyOfCIDR(prefix) + isIPv4 := IsIPv4CIDR(prefix) + isIPv6 := IsIPv6CIDR(prefix) + checkOneIPFamily(t, prefix.String(), tc.family, family, isIPv4, isIPv6) + } + }) + } + + // See test cases in ips_test.go + for _, tc := range badTestCIDRs { + if tc.skipFamily { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for _, ipnet := range tc.ipnets { + family := IPFamilyOfCIDR(ipnet) + isIPv4 := IsIPv4CIDR(ipnet) + isIPv6 := IsIPv6CIDR(ipnet) + str := "" + if ipnet != nil { + str = fmt.Sprintf("%#v", *ipnet) + } + checkOneIPFamily(t, str, IPFamilyUnknown, family, isIPv4, isIPv6) + } + for _, prefix := range tc.prefixes { + family := IPFamilyOfCIDR(prefix) + isIPv4 := IsIPv4CIDR(prefix) + isIPv6 := IsIPv6CIDR(prefix) + checkOneIPFamily(t, fmt.Sprintf("%#v", prefix), IPFamilyUnknown, family, isIPv4, isIPv6) + } + for _, str := range tc.strings { + family := IPFamilyOfCIDR(str) + isIPv4 := IsIPv4CIDR(str) + isIPv6 := IsIPv6CIDR(str) + checkOneIPFamily(t, str, IPFamilyUnknown, family, isIPv4, isIPv6) + } + }) + } +} diff --git a/net/v2/ips_test.go b/net/v2/ips_test.go new file mode 100644 index 00000000..8013b815 --- /dev/null +++ b/net/v2/ips_test.go @@ -0,0 +1,797 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "net" + "net/netip" + "testing" +) + +// testIP represents a set of equivalent IP address representations. +type testIP struct { + desc string + family IPFamily + strings []string + ips []net.IP + addrs []netip.Addr + + skipFamily bool + skipParse bool + skipConvert bool +} + +// goodTestIPs are "good" test IP values. For each item: +// +// Preconditions (not involving functions in netutils): +// - Each element of .ips is the same (i.e., .Equal()). +// - Each element of .ips stringifies to .strings[0]. +// - Each element of .addrs is the same (i.e., ==). +// - Each element of .addrs stringifies to .strings[0]. +// +// IPFamily tests (unless `skipFamily: true`): +// - Each element of .strings should be identified as .family. +// - Each element of .ips should be identified as .family. +// - Each element of .addrs should be identified as .family. +// +// Parsing tests (unless `skipParse: true`): +// - Each element of .strings should parse to a value equal to .ips[0]. +// - Each element of .strings should parse to a value equal to .addrs[0]. +// +// Conversion tests (unless `skipConvert: true`): +// - Each element of .ips should convert to a value equal to .addrs[0]. +// - Each element of .addrs should convert to a value equal to .ips[0]. +var goodTestIPs = []testIP{ + { + desc: "IPv4", + family: IPv4, + strings: []string{ + "192.168.0.5", + "192.168.000.005", + }, + ips: []net.IP{ + net.IPv4(192, 168, 0, 5), + {192, 168, 0, 5}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 192, 168, 0, 5}, + net.ParseIP("192.168.0.5"), + func() net.IP { ip, _, _ := net.ParseCIDR("192.168.0.5/24"); return ip }(), + func() net.IP { _, ipnet, _ := net.ParseCIDR("192.168.0.5/32"); return ipnet.IP }(), + }, + addrs: []netip.Addr{ + netip.AddrFrom4([4]byte{192, 168, 0, 5}), + netip.MustParseAddr("192.168.0.5"), + netip.MustParsePrefix("192.168.0.5/24").Addr(), + }, + }, + { + desc: "IPv4 all-zeros", + family: IPv4, + strings: []string{ + "0.0.0.0", + "000.000.000.000", + }, + ips: []net.IP{ + net.IPv4zero, + net.IPv4(0, 0, 0, 0), + {0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0, 0, 0, 0}, + net.ParseIP("0.0.0.0"), + }, + addrs: []netip.Addr{ + netip.IPv4Unspecified(), + netip.AddrFrom4([4]byte{0, 0, 0, 0}), + netip.MustParseAddr("0.0.0.0"), + }, + }, + { + desc: "IPv4 broadcast", + family: IPv4, + strings: []string{ + "255.255.255.255", + }, + ips: []net.IP{ + net.IPv4bcast, + net.IPv4(255, 255, 255, 255), + {255, 255, 255, 255}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 255, 255, 255, 255}, + net.ParseIP("255.255.255.255"), + // A /32 IPMask is equivalent to 255.255.255.255 + func() net.IP { _, ipnet, _ := net.ParseCIDR("1.2.3.4/32"); return net.IP(ipnet.Mask) }(), + }, + addrs: []netip.Addr{ + netip.AddrFrom4([4]byte{0xFF, 0xFF, 0xFF, 0xFF}), + netip.MustParseAddr("255.255.255.255"), + }, + }, + { + desc: "IPv6", + family: IPv6, + strings: []string{ + "2001:db8::5", + "2001:0db8::0005", + "2001:DB8::5", + }, + ips: []net.IP{ + {0x20, 0x01, 0x0D, 0xB8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x05}, + net.ParseIP("2001:db8::5"), + func() net.IP { ip, _, _ := net.ParseCIDR("2001:db8::5/64"); return ip }(), + func() net.IP { _, ipnet, _ := net.ParseCIDR("2001:db8::5/128"); return ipnet.IP }(), + }, + addrs: []netip.Addr{ + netip.AddrFrom16([16]byte{0x20, 0x01, 0x0D, 0xB8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x05}), + netip.MustParseAddr("2001:db8::5"), + netip.MustParsePrefix("2001:db8::5/64").Addr(), + }, + }, + { + desc: "IPv6 all-zeros", + family: IPv6, + strings: []string{ + "::", + "0:0:0:0:0:0:0:0", + "0000:0000:0000:0000:0000:0000:0000:0000", + "0::0", + }, + ips: []net.IP{ + net.IPv6zero, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + net.ParseIP("::"), + // ::/0 has an IP, network base IP, and Mask that are all + // equivalent to :: + func() net.IP { ip, _, _ := net.ParseCIDR("::/0"); return ip }(), + func() net.IP { _, ipnet, _ := net.ParseCIDR("::/0"); return ipnet.IP }(), + func() net.IP { _, ipnet, _ := net.ParseCIDR("::/0"); return net.IP(ipnet.Mask) }(), + }, + addrs: []netip.Addr{ + netip.IPv6Unspecified(), + netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), + netip.MustParseAddr("::"), + netip.MustParsePrefix("::/0").Addr(), + }, + }, + { + desc: "IPv6 loopback", + family: IPv6, + strings: []string{ + "::1", + "0000:0000:0000:0000:0000:0000:0000:0001", + }, + ips: []net.IP{ + net.IPv6loopback, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + net.ParseIP("::1"), + }, + addrs: []netip.Addr{ + netip.IPv6Loopback(), + netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), + netip.MustParseAddr("::1"), + }, + }, + { + desc: "IPv4-mapped IPv6", + // net.IP can represent an IPv4 address internally as either a 4-byte + // value or a 16-byte value, but it treats the two forms as equivalent. + // Because IPv4-mapped IPv6 is annoying, we make our ParseAddr() behave + // this way too, even though that's *not* how netip.ParseAddr() behaves. + // + // This test case confirms that: + // - The 4-byte and 16-byte forms of a given net.IP compare as .Equal(). + // - Our parsers parse the plain IPv4 and IPv4-mapped IPv6 forms of an + // IPv4 string to the same thing. + // - The 4-byte and 16-byte forms of a given net.IP, and the 4-byte + // (but *not* 16-byte) form of netip.Addr, all stringify to the plain + // IPv4 string form (i.e., .strings[0]). + family: IPv4, + strings: []string{ + "192.168.0.5", + "::ffff:192.168.0.5", + "::ffff:0192.0168.0000.0005", + }, + ips: []net.IP{ + {192, 168, 0, 5}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 192, 168, 0, 5}, + net.IPv4(192, 168, 0, 5).To4(), + net.IPv4(192, 168, 0, 5).To16(), + net.ParseIP("192.168.0.5").To4(), + net.ParseIP("192.168.0.5").To16(), + net.ParseIP("::ffff:192.168.0.5").To4(), + net.ParseIP("::ffff:192.168.0.5").To16(), + }, + addrs: []netip.Addr{ + netip.AddrFrom4([4]byte{192, 168, 0, 5}), + netip.MustParseAddr("192.168.0.5"), + }, + }, + { + desc: "IPv4-mapped IPv6 (netip.Addr)", + // In constrast to net.IP, netip.Addr considers plain IPv4 and IPv4-mapped + // IPv6 to be distinct things, and netip.ParseAddr will parse the plain + // IPv4 and IPv4-mapped IPv6 strings into distinct netip.Addr values + // (where the IPv4-mapped IPv6 netip.Addr value does not correspond + // exactly to any net.IP value). + family: IPv4, + strings: []string{ + "::ffff:192.168.0.5", + }, + addrs: []netip.Addr{ + netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 192, 168, 0, 5}), + netip.MustParseAddr("::ffff:192.168.0.5"), + }, + + // Skip the parsing tests, because no netutils method will parse + // .strings[0] to .addrs[0]. + skipParse: true, + + // Skip the conversion tests, because there is no net.IP value that + // unambiguously corresponds to these netip.Addr values. TestIPFromAddr() + // has a special case to test that an IPv4-mapped IPv6 netip.Addr maps to + // the expected net.IP value (which then doesn't round-trip back to the + // original netip.Addr value). + skipConvert: true, + }, +} + +// badTestIPs are bad test IP values. For each item: +// +// IPFamily tests (unless `skipFamily: true`): +// - Each element of .strings should be identified as IPFamilyUnknown. +// - Each element of .ips should be identified as IPFamilyUnknown. +// - Each element of .addrs should be identified as IPFamilyUnknown. +// +// Parsing tests (unless `skipParse: true`): +// - Each element of .strings should fail to parse. +// - Each element of .ips should stringify to an error value that fails to parse. +// - Each element of .addrs should stringify to an error value that fails to parse. +// +// Conversion tests (unless `skipConvert: true`: +// - Each element of .ips should convert to an invalid netip.Addr. +// - Each element of .addrs should convert to a nil net.IP. +var badTestIPs = []testIP{ + { + desc: "empty string is not an IP", + strings: []string{ + "", + }, + }, + { + desc: "random non-IP string is not an IP", + strings: []string{ + "bad ip", + }, + }, + { + desc: "domain name is not an IP", + strings: []string{ + "www.example.com", + }, + }, + { + desc: "mangled IPv4 addresses are invalid", + strings: []string{ + "1.2.3.400", + "1.2..4", + "1.2.3", + "1.2.3.4.5", + }, + }, + { + desc: "mangled IPv6 addresses are invalid", + strings: []string{ + "1:2::12345", + "1::2::3", + "1:2:::3", + "1:2:3", + "1:2:3:4:5:6:7:8:9", + "1:2:3:4::6:7:8:9", + }, + }, + { + desc: "IPs do not have ports or brackets", + strings: []string{ + "1.2.3.4:80", + "[2001:db8::5]", + "[2001:db8::5]:80", + "www.example.com:80", + }, + }, + { + desc: "IPs with zones are invalid", + strings: []string{ + "169.254.169.254%eth0", + "fe80::1234%eth0", + }, + }, + { + desc: "CIDR strings are not IPs", + strings: []string{ + "1.2.3.0/24", + "2001:db8::/64", + }, + }, + { + desc: "IPs with whitespace are invalid", + strings: []string{ + " 1.2.3.4", + "1.2.3.4 ", + " 2001:db8::5", + "2001:db8::5 ", + }, + }, + { + desc: "nil is an invalid net.IP", + ips: []net.IP{ + nil, + }, + }, + { + desc: "a byte slice of length other than 4 or 16 is an invalid net.IP", + ips: []net.IP{ + {}, + {1, 2, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + }, + }, + { + desc: "the zero netip.Addr is invalid", + addrs: []netip.Addr{ + {}, + }, + }, +} + +// testCIDR represents a set of equivalent CIDR representations. +type testCIDR struct { + desc string + ifaddr bool + family IPFamily + strings []string + ipnets []*net.IPNet + prefixes []netip.Prefix + + skipFamily bool + skipParse bool + skipConvert bool +} + +// goodTestCIDRs are "good" test CIDR values. For each item: +// +// Preconditions: +// - Each element of .ipnets stringifies to .strings[0]. +// - Each element of .prefixes is the same (i.e., ==). +// - Each element of .prefixes stringifies to .strings[0]. +// +// IPFamily tests (unless `skipFamily: true`): +// - Each element of .strings should be identified as .family. +// - Each element of .ipnets should be identified as .family. +// - Each element of .prefixes should be identified as .family. +// +// Parsing tests (unless `skipParse: true`): +// - Each element of .strings should parse to a value "equal" to .ipnets[0] +// (via ParseIPNet if .ifaddr is false, or ParseIPAsIPNet if .ifaddr is true). +// - Each element of .strings should parse to a value equal to .prefixes[0] +// (via ParsePrefix if .ifaddr is false, or ParseAddrAsPrefix if .ifaddr is true). +// +// Conversion tests (unless `skipConvert: true`): +// - Each element of .ipnets should convert to a value equal to .prefixes[0]. +// - Each element of .prefixes should convert to a value "equal" to .ipnets[0]. +// +// (Unlike net.IP, *net.IPNet has no `.Equal()` method, and testing equality "by hand" is +// complicated (there are 4 equivalent representations of every IPv4 CIDR value), so we +// just consider two *net.IPNet values to be equal if they stringify to the same value.) +var goodTestCIDRs = []testCIDR{ + { + desc: "IPv4", + family: IPv4, + strings: []string{ + "1.2.3.0/24", + }, + ipnets: []*net.IPNet{ + {IP: net.IPv4(1, 2, 3, 0), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("1.2.3.0"), Mask: net.CIDRMask(24, 32)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("1.2.3.0/24"); return ipnet }(), + }, + prefixes: []netip.Prefix{ + netip.MustParsePrefix("1.2.3.0/24"), + netip.PrefixFrom(netip.MustParseAddr("1.2.3.0"), 24), + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 0}), 24), + }, + }, + { + desc: "IPv4, single IP", + family: IPv4, + strings: []string{ + "1.1.1.1/32", + }, + ipnets: []*net.IPNet{ + {IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(32, 32)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("1.1.1.1/32"); return ipnet }(), + }, + prefixes: []netip.Prefix{ + netip.MustParsePrefix("1.1.1.1/32"), + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32), + }, + }, + { + desc: "IPv4, all IPs", + family: IPv4, + strings: []string{ + "0.0.0.0/0", + "000.000.000.000/000", + }, + ipnets: []*net.IPNet{ + {IP: net.IPv4zero.To4(), Mask: net.IPMask(net.IPv4zero.To4())}, + {IP: net.IPv4(0, 0, 0, 0), Mask: net.CIDRMask(0, 32)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("0.0.0.0/0"); return ipnet }(), + }, + prefixes: []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/0"), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0, 0, 0, 0}), 0), + netip.PrefixFrom(netip.IPv4Unspecified(), 0), + }, + }, + { + desc: "IPv4 ifaddr (masked)", + ifaddr: false, + // This tests that if you try to parse an "ifaddr-style" CIDR string with + // ParseIPNet/ParsePrefix, the return value has the bits beyond the prefix + // length masked out. + family: IPv4, + strings: []string{ + "1.2.3.0/24", + "1.2.3.4/24", + "1.2.3.255/24", + }, + ipnets: []*net.IPNet{ + {IP: net.IPv4(1, 2, 3, 0), Mask: net.CIDRMask(24, 32)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("1.2.3.0/24"); return ipnet }(), + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("1.2.3.4/24"); return ipnet }(), + }, + prefixes: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 0}), 24), + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24).Masked(), + netip.MustParsePrefix("1.2.3.0/24"), + netip.MustParsePrefix("1.2.3.4/24").Masked(), + }, + }, + { + desc: "IPv4 ifaddr", + ifaddr: true, + family: IPv4, + strings: []string{ + "1.2.3.4/24", + }, + ipnets: []*net.IPNet{ + {IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}, + }, + prefixes: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), + netip.MustParsePrefix("1.2.3.4/24"), + }, + }, + { + desc: "IPv6", + family: IPv6, + strings: []string{ + "2001:db8::/64", + "2001:db8:0:0:0:0:0:0/64", + "2001:DB8::/64", + }, + ipnets: []*net.IPNet{ + {IP: net.IP{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Mask: net.IPMask{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {IP: net.ParseIP("2001:db8::"), Mask: net.CIDRMask(64, 128)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("2001:db8::/64"); return ipnet }(), + }, + prefixes: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/64"), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), 64), + }, + }, + { + desc: "IPv6, all IPs", + family: IPv6, + strings: []string{ + "::/0", + }, + ipnets: []*net.IPNet{ + {IP: net.IPv6zero, Mask: net.IPMask(net.IPv6zero)}, + {IP: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Mask: net.CIDRMask(0, 128)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("::/0"); return ipnet }(), + }, + prefixes: []netip.Prefix{ + netip.MustParsePrefix("::/0"), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), 0), + netip.PrefixFrom(netip.IPv6Unspecified(), 0), + }, + }, + { + desc: "IPv6, single IP", + family: IPv6, + strings: []string{ + "::1/128", + }, + ipnets: []*net.IPNet{ + {IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, + {IP: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, Mask: net.CIDRMask(128, 128)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("::1/128"); return ipnet }(), + }, + prefixes: []netip.Prefix{ + netip.MustParsePrefix("::1/128"), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), 128), + }, + }, + { + desc: "IPv6 ifaddr (masked)", + ifaddr: false, + // This tests that if you try to parse an "ifaddr-style" CIDR string with + // ParseIPNet, it value has the bits beyond the prefix length masked out. + family: IPv6, + strings: []string{ + "2001:db8::/64", + "2001:db8::1/64", + "2001:db8::f00f:f0f0:0f0f:000f/64", + }, + ipnets: []*net.IPNet{ + {IP: net.IP{0x20, 0x01, 0x0D, 0xB8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Mask: net.CIDRMask(64, 128)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("2001:db8::/64"); return ipnet }(), + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("2001:db8::1/64"); return ipnet }(), + }, + prefixes: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), 64), + netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), 64).Masked(), + netip.MustParsePrefix("2001:db8::/64"), + netip.MustParsePrefix("2001:db8::1/64").Masked(), + }, + }, + { + desc: "IPv6 ifaddr", + ifaddr: true, + family: IPv6, + strings: []string{ + "2001:db8::1/64", + }, + ipnets: []*net.IPNet{ + {IP: net.IP{0x20, 0x01, 0x0D, 0xB8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, Mask: net.CIDRMask(64, 128)}, + }, + prefixes: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), 64), + netip.MustParsePrefix("2001:db8::1/64"), + }, + }, + { + desc: "IPv4-mapped IPv6", + // As in the IP tests, confirm that plain IPv4 and IPv4-mapped IPv6 are + // treated as equivalent. + family: IPv4, + strings: []string{ + "1.1.1.0/24", + "::ffff:1.1.1.0/120", + "::ffff:01.01.01.00/0120", + }, + ipnets: []*net.IPNet{ + {IP: net.IP{1, 1, 1, 0}, Mask: net.CIDRMask(24, 32)}, + {IP: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 1, 1, 1, 0}, Mask: net.CIDRMask(120, 128)}, + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("1.1.1.0/24"); return ipnet }(), + func() *net.IPNet { _, ipnet, _ := net.ParseCIDR("::ffff:1.1.1.0/120"); return ipnet }(), + + // Explicitly test each of the 4 different combinations of 4-byte + // or 16-byte IP and 4-byte or 16-byte Mask, all of which should + // compare as equal and re-stringify to "1.1.1.0/24". + {IP: net.IP{1, 1, 1, 0}, Mask: net.IPMask{255, 255, 255, 0}}, + {IP: net.IP{1, 1, 1, 0}, Mask: net.IPMask{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0}}, + {IP: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 1, 1, 1, 0}, Mask: net.IPMask{255, 255, 255, 0}}, + {IP: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 1, 1, 1, 0}, Mask: net.IPMask{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0}}, + }, + prefixes: []netip.Prefix{ + netip.MustParsePrefix("1.1.1.0/24"), + netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 0}), 24), + }, + }, + { + // As in the IP/Addr tests, additional checks for IPv4-mapped IPv6 netip + // values. + desc: "IPv4-mapped IPv6 (netip.Prefix)", + family: IPv4, + strings: []string{ + "::ffff:1.1.1.0/120", + }, + prefixes: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 1, 1, 1, 0}), 120), + netip.MustParsePrefix("::ffff:1.1.1.0/120"), + }, + + // Skip the parsing tests, because no netutils method will parse + // .strings[0] to .prefixes[0]. + skipParse: true, + + // Skip the conversion tests, because there is no *net.IPNet value that + // unambiguously corresponds to these netip.Prefix values. + // TestIPNetFromPrefix() has a special case to test that a netip.Prefix + // with an IPv4-mapped IPv6 address maps to the expected *net.IPNet value + // (which then doesn't round-trip back to the original netip.Prefix value). + skipConvert: true, + }, +} + +// badTestCIDRs are bad test CIDR values. For each item: +// +// IPFamily tests (unless `skipFamily: true`): +// - Each element of .strings should be identified as IPFamilyUnknown. +// - Each element of .ipnets should be identified as IPFamilyUnknown. +// - Each element of .prefixes should be identified as IPFamilyUnknown. +// +// Parsing tests (unless `skipParse: true`): +// - Each element of .strings should fail to parse. +// - Each element of .ipnets should stringify to some error value that fails to parse. +// - Each element of .prefixes should stringify to some error value that fails to parse. +// +// Conversion tests (unless `skipConvert: true`): +// - Each element of .ipnets should convert to an invalid netip.Prefix. +// - Each element of .prefixes should convert to a nil *net.IPNet. +var badTestCIDRs = []testCIDR{ + { + desc: "empty string is not a CIDR", + strings: []string{ + "", + }, + }, + { + desc: "random unparseable string is not a CIDR", + strings: []string{ + "bad cidr", + }, + }, + { + desc: "CIDR with invalid IP is invalid", + strings: []string{ + "1.2.300.0/24", + "2001:db8000::/64", + }, + }, + { + desc: "CIDR with invalid prefix length is invalid", + strings: []string{ + "1.2.3.4/64", + "2001:db8::5/192", + "1.2.3.0/-8", + "1.2.3.0/+24", + }, + }, + { + desc: "URLs (that aren't also valid CIDRs) are invalid", + strings: []string{ + "www.example.com/24", + "192.168.0.1/0/99", + }, + }, + { + desc: "plain IP is not a CIDR", + strings: []string{ + "1.2.3.4", + "2001:db8::1", + }, + }, + { + desc: "CIDR with whitespace is invalid", + strings: []string{ + " 1.2.3.0/24", + "1.2.3.0/24 ", + }, + }, + { + desc: "nil is an invalid IPNet", + ipnets: []*net.IPNet{ + nil, + }, + }, + { + desc: "IPNet containing invalid IP is invalid", + ipnets: []*net.IPNet{ + {IP: net.IP{0x1}, Mask: net.CIDRMask(24, 32)}, + }, + }, + { + desc: "IPNet containing non-CIDR Mask is invalid", + ipnets: []*net.IPNet{ + {IP: net.IP{192, 168, 0, 0}, Mask: net.IPMask{255, 0, 255, 0}}, + }, + }, + { + desc: "IPNet containing IPv6 IP and IPv4 Mask is invalid", + ipnets: []*net.IPNet{ + {IP: net.IP{0x20, 0x01, 0x0D, 0xB8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Mask: net.CIDRMask(24, 32)}, + }, + }, + { + desc: "the zero netip.Prefix is invalid", + family: IPFamilyUnknown, + prefixes: []netip.Prefix{ + {}, + }, + }, + { + desc: "Prefix containing an invalid Addr is invalid", + family: IPFamilyUnknown, + prefixes: []netip.Prefix{ + netip.PrefixFrom(netip.Addr{}, 24), + }, + }, + { + desc: "Prefix containing a negative length is invalid", + family: IPv4, + prefixes: []netip.Prefix{ + netip.PrefixFrom(netip.IPv4Unspecified(), -1), + }, + }, + { + desc: "Prefix containing a too-long length is invalid", + family: IPv4, + prefixes: []netip.Prefix{ + netip.PrefixFrom(netip.IPv4Unspecified(), 64), + }, + }, +} + +// TestGoodTestIPs confirms the Preconditions for goodTestIPs. +func TestGoodTestIPs(t *testing.T) { + for _, tc := range goodTestIPs { + t.Run(tc.desc, func(t *testing.T) { + for i, ip := range tc.ips { + if !ip.Equal(tc.ips[0]) { + t.Errorf("BAD TEST DATA: IP %d %#v %q does not equal %#v %q", i+1, ip, ip, tc.ips[0], tc.ips[0]) + } + str := ip.String() + if str != tc.strings[0] { + t.Errorf("BAD TEST DATA: IP %d %#v %q does not stringify to %q", i+1, ip, ip, tc.strings[0]) + } + } + + for i, addr := range tc.addrs { + if addr != tc.addrs[0] { + t.Errorf("BAD TEST DATA: Addr %d %#v %q does not equal %#v %q", i+1, addr, addr, tc.addrs[0], tc.addrs[0]) + } + str := addr.String() + if str != tc.strings[0] { + t.Errorf("BAD TEST DATA: Addr %d %#v %q does not stringify to %q", i+1, addr, addr, tc.strings[0]) + } + } + }) + } +} + +// TestGoodTestCIDRs confirms the Preconditions for goodTestCIDRs. +func TestGoodTestCIDRs(t *testing.T) { + for _, tc := range goodTestCIDRs { + t.Run(tc.desc, func(t *testing.T) { + for i, ipnet := range tc.ipnets { + if ipnet.String() != tc.strings[0] { + t.Errorf("BAD TEST DATA: IPNet %d %#v %q does not stringify to %q", i+1, ipnet, ipnet, tc.strings[0]) + } + } + + for i, prefix := range tc.prefixes { + if prefix != tc.prefixes[0] { + t.Errorf("BAD TEST DATA: Prefix %d %#v %q does not equal %#v %q", i+1, prefix, prefix, tc.prefixes[0], tc.prefixes[0]) + } + str := prefix.String() + if str != tc.strings[0] { + t.Errorf("BAD TEST DATA: Prefix %d %#v %q does not stringify to %q", i+1, prefix, prefix, tc.strings[0]) + } + } + }) + } +} diff --git a/net/v2/multi_listen.go b/net/v2/multi_listen.go new file mode 100644 index 00000000..e5d50805 --- /dev/null +++ b/net/v2/multi_listen.go @@ -0,0 +1,195 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" +) + +// connErrPair pairs conn and error which is returned by accept on sub-listeners. +type connErrPair struct { + conn net.Conn + err error +} + +// multiListener implements net.Listener +type multiListener struct { + listeners []net.Listener + wg sync.WaitGroup + + // connCh passes accepted connections, from child listeners to parent. + connCh chan connErrPair + // stopCh communicates from parent to child listeners. + stopCh chan struct{} + closed atomic.Bool +} + +// compile time check to ensure *multiListener implements net.Listener +var _ net.Listener = &multiListener{} + +// MultiListen returns net.Listener which can listen on and accept connections for +// the given network on multiple addresses. Internally it uses stdlib to create +// sub-listener and multiplexes connection requests using go-routines. +// The network must be "tcp", "tcp4" or "tcp6". +// It follows the semantics of net.Listen that primarily means: +// 1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen +// listens on all available unicast and anycast IP addresses of the local system. +// 2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively. +// 3. The host can accept names (e.g, localhost) and it will create a listener for at +// most one of the host's IP. +func MultiListen(ctx context.Context, network string, addrs ...string) (net.Listener, error) { + var lc net.ListenConfig + return multiListen( + ctx, + network, + addrs, + func(ctx context.Context, network, address string) (net.Listener, error) { + return lc.Listen(ctx, network, address) + }) +} + +// multiListen implements MultiListen by consuming stdlib functions as dependency allowing +// mocking for unit-testing. +func multiListen( + ctx context.Context, + network string, + addrs []string, + listenFunc func(ctx context.Context, network, address string) (net.Listener, error), +) (net.Listener, error) { + if !(network == "tcp" || network == "tcp4" || network == "tcp6") { + return nil, fmt.Errorf("network %q not supported", network) + } + if len(addrs) == 0 { + return nil, fmt.Errorf("no address provided to listen on") + } + + ml := &multiListener{ + connCh: make(chan connErrPair), + stopCh: make(chan struct{}), + } + for _, addr := range addrs { + l, err := listenFunc(ctx, network, addr) + if err != nil { + // close all the sub-listeners and exit + _ = ml.Close() + return nil, err + } + ml.listeners = append(ml.listeners, l) + } + + for _, l := range ml.listeners { + ml.wg.Add(1) + go func(l net.Listener) { + defer ml.wg.Done() + for { + // Accept() is blocking, unless ml.Close() is called, in which + // case it will return immediately with an error. + conn, err := l.Accept() + // This assumes that ANY error from Accept() will terminate the + // sub-listener. We could maybe be more precise, but it + // doesn't seem necessary. + terminate := err != nil + + select { + case ml.connCh <- connErrPair{conn: conn, err: err}: + case <-ml.stopCh: + // In case we accepted a connection AND were stopped, and + // this select-case was chosen, just throw away the + // connection. This avoids potentially blocking on connCh + // or leaking a connection. + if conn != nil { + _ = conn.Close() + } + terminate = true + } + // Make sure we don't loop on Accept() returning an error and + // the select choosing the channel case. + if terminate { + return + } + } + }(l) + } + return ml, nil +} + +// Accept implements net.Listener. It waits for and returns a connection from +// any of the sub-listener. +func (ml *multiListener) Accept() (net.Conn, error) { + // wait for any sub-listener to enqueue an accepted connection + connErr, ok := <-ml.connCh + if !ok { + // The channel will be closed only when Close() is called on the + // multiListener. Closing of this channel implies that all + // sub-listeners are also closed, which causes a "use of closed + // network connection" error on their Accept() calls. We return the + // same error for multiListener.Accept() if multiListener.Close() + // has already been called. + return nil, fmt.Errorf("use of closed network connection") + } + return connErr.conn, connErr.err +} + +// Close implements net.Listener. It will close all sub-listeners and wait for +// the go-routines to exit. +func (ml *multiListener) Close() error { + // Make sure this can be called repeatedly without explosions. + if !ml.closed.CompareAndSwap(false, true) { + return fmt.Errorf("use of closed network connection") + } + + // Tell all sub-listeners to stop. + close(ml.stopCh) + + // Closing the listeners causes Accept() to immediately return an error in + // the sub-listener go-routines. + for _, l := range ml.listeners { + _ = l.Close() + } + + // Wait for all the sub-listener go-routines to exit. + ml.wg.Wait() + close(ml.connCh) + + // Drain any already-queued connections. + for connErr := range ml.connCh { + if connErr.conn != nil { + _ = connErr.conn.Close() + } + } + return nil +} + +// Addr is an implementation of the net.Listener interface. It always returns +// the address of the first listener. Callers should use conn.LocalAddr() to +// obtain the actual local address of the sub-listener. +func (ml *multiListener) Addr() net.Addr { + return ml.listeners[0].Addr() +} + +// Addrs is like Addr, but returns the address for all registered listeners. +func (ml *multiListener) Addrs() []net.Addr { + var ret []net.Addr + for _, l := range ml.listeners { + ret = append(ret, l.Addr()) + } + return ret +} diff --git a/net/v2/multi_listen_test.go b/net/v2/multi_listen_test.go new file mode 100644 index 00000000..e2121f9b --- /dev/null +++ b/net/v2/multi_listen_test.go @@ -0,0 +1,545 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync/atomic" + "testing" + "time" +) + +type fakeCon struct { + remoteAddr net.Addr +} + +func (f *fakeCon) Read(_ []byte) (n int, err error) { + return 0, nil +} + +func (f *fakeCon) Write(_ []byte) (n int, err error) { + return 0, nil +} + +func (f *fakeCon) Close() error { + return nil +} + +func (f *fakeCon) LocalAddr() net.Addr { + return nil +} + +func (f *fakeCon) RemoteAddr() net.Addr { + return f.remoteAddr +} + +func (f *fakeCon) SetDeadline(_ time.Time) error { + return nil +} + +func (f *fakeCon) SetReadDeadline(_ time.Time) error { + return nil +} + +func (f *fakeCon) SetWriteDeadline(_ time.Time) error { + return nil +} + +var _ net.Conn = &fakeCon{} + +type fakeListener struct { + addr net.Addr + index int + err error + closed atomic.Bool + connErrPairs []connErrPair +} + +func (f *fakeListener) Accept() (net.Conn, error) { + if f.index < len(f.connErrPairs) { + index := f.index + connErr := f.connErrPairs[index] + f.index++ + return connErr.conn, connErr.err + } + for { + if f.closed.Load() { + return nil, fmt.Errorf("use of closed network connection") + } + } +} + +func (f *fakeListener) Close() error { + f.closed.Store(true) + return nil +} + +func (f *fakeListener) Addr() net.Addr { + return f.addr +} + +var _ net.Listener = &fakeListener{} + +func listenFuncFactory(listeners []*fakeListener) func(_ context.Context, network string, address string) (net.Listener, error) { + index := 0 + return func(_ context.Context, network string, address string) (net.Listener, error) { + if index < len(listeners) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + listener := listeners[index] + addr := &net.TCPAddr{ + IP: MustParseIP(host), + Port: port, + } + if err != nil { + return nil, err + } + listener.addr = addr + index++ + + if listener.err != nil { + return nil, listener.err + } + return listener, nil + } + return nil, nil + } +} + +func TestMultiListen(t *testing.T) { + testCases := []struct { + name string + network string + addrs []string + fakeListeners []*fakeListener + errString string + }{ + { + name: "unsupported network", + network: "udp", + errString: "network \"udp\" not supported", + }, + { + name: "no host", + network: "tcp", + errString: "no address provided to listen on", + }, + { + name: "valid", + network: "tcp", + addrs: []string{"127.0.0.1:12345"}, + fakeListeners: []*fakeListener{{connErrPairs: []connErrPair{}}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, tc.network, tc.addrs, listenFuncFactory(tc.fakeListeners)) + + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + if ml != nil { + err = ml.Close() + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + } + }) + } +} + +func TestMultiListen_Addr(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, listenFuncFactory( + []*fakeListener{{}, {}, {}}, + )) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + + if ml.Addr().String() != "10.10.10.10:5000" { + t.Errorf("Expected '10.10.10.10:5000' but got '%s'", ml.Addr().String()) + } + + err = ml.Close() + if err != nil { + t.Errorf("Did not expect error: %v", err) + } +} + +func TestMultiListen_Addrs(t *testing.T) { + ctx := context.TODO() + addrs := []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"} + ml, err := multiListen(ctx, "tcp", addrs, listenFuncFactory( + []*fakeListener{{}, {}, {}}, + )) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + + gotAddrs := ml.(*multiListener).Addrs() + for i := range gotAddrs { + if gotAddrs[i].String() != addrs[i] { + t.Errorf("expected %q; got %q", addrs[i], gotAddrs[i].String()) + } + + } + + err = ml.Close() + if err != nil { + t.Errorf("Did not expect error: %v", err) + } +} + +func TestMultiListen_Close(t *testing.T) { + testCases := []struct { + name string + addrs []string + runner func(listener net.Listener, acceptCalls int) error + fakeListeners []*fakeListener + acceptCalls int + errString string + }{ + { + name: "close", + addrs: []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{}, {}, {}}, + }, + { + name: "close with pending connections", + addrs: []string{"10.10.10.10:5001", "192.168.1.10:5002", "127.0.0.1:5003"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("10.10.10.10"), Port: 50001}}, + }}}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("192.168.1.10"), Port: 50002}}, + }, + }}, { + connErrPairs: []connErrPair{{ + conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 50003}}, + }}, + }}, + }, + { + name: "close with no pending connections", + addrs: []string{"10.10.10.10:3001", "192.168.1.10:3002", "127.0.0.1:3003"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("10.10.10.10"), Port: 50001}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("192.168.1.10"), Port: 50002}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 50003}}}, + }, + }}, + acceptCalls: 3, + }, + { + name: "close on close", + addrs: []string{"10.10.10.10:5000", "192.168.1.10:5000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + + err = ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{}, {}, {}}, + errString: "use of closed network connection", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", tc.addrs, listenFuncFactory(tc.fakeListeners)) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + err = tc.runner(ml, tc.acceptCalls) + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + + for _, f := range tc.fakeListeners { + if !f.closed.Load() { + t.Errorf("Expeted sub-listener to be closed") + } + } + }) + } +} + +func TestMultiListen_Accept(t *testing.T) { + testCases := []struct { + name string + addrs []string + runner func(listener net.Listener, acceptCalls int) error + fakeListeners []*fakeListener + acceptCalls int + errString string + }{ + { + name: "accept all connections", + addrs: []string{"10.10.10.10:3000", "192.168.1.103:4000", "127.0.0.1:5000"}, + runner: func(ml net.Listener, acceptCalls int) error { + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("10.10.10.10"), Port: 50001}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("192.168.1.10"), Port: 50002}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 50003}}}, + }, + }}, + acceptCalls: 3, + }, + { + name: "accept some connections", + addrs: []string{"10.10.10.10:3000", "192.168.1.103:4000", "172.16.20.10:5000", "127.0.0.1:6000"}, + runner: func(ml net.Listener, acceptCalls int) error { + + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + + } + err := ml.Close() + if err != nil { + return err + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("10.10.10.10"), Port: 30001}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("192.168.1.10"), Port: 40001}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("192.168.1.10"), Port: 40002}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("172.16.20.10"), Port: 50001}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("172.16.20.10"), Port: 50002}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("172.16.20.10"), Port: 50003}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("172.16.20.10"), Port: 50004}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 60001}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 60002}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 60003}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 60004}}}, + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 60005}}}, + }, + }}, + acceptCalls: 3, + }, + { + name: "accept on closed listener", + addrs: []string{"10.10.10.10:3001", "192.168.1.10:3002", "127.0.0.1:3003"}, + runner: func(ml net.Listener, acceptCalls int) error { + err := ml.Close() + if err != nil { + return err + } + for i := 0; i < acceptCalls; i++ { + _, err := ml.Accept() + if err != nil { + return err + } + } + return nil + }, + fakeListeners: []*fakeListener{{ + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("10.10.10.10"), Port: 50001}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("192.168.1.10"), Port: 50002}}}, + }}, { + connErrPairs: []connErrPair{ + {conn: &fakeCon{remoteAddr: &net.TCPAddr{IP: MustParseIP("127.0.0.1"), Port: 50003}}}, + }, + }}, + acceptCalls: 1, + errString: "use of closed network connection", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + ml, err := multiListen(ctx, "tcp", tc.addrs, listenFuncFactory(tc.fakeListeners)) + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + + err = tc.runner(ml, tc.acceptCalls) + if tc.errString != "" { + assertError(t, tc.errString, err) + } else { + assertNoError(t, err) + } + }) + } +} + +func TestMultiListen_HTTP(t *testing.T) { + ctx := context.TODO() + ml, err := MultiListen(ctx, "tcp", ":0", ":0", ":0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + addrs := ml.(*multiListener).Addrs() + if len(addrs) != 3 { + t.Fatalf("expected 3 listeners, got %v", addrs) + } + + // serve http on multi-listener + handler := func(w http.ResponseWriter, _ *http.Request) { + io.WriteString(w, "hello") + } + server := http.Server{ + Handler: http.HandlerFunc(handler), + } + go func() { _ = server.Serve(ml) }() + defer server.Close() + + // Wait for server + awake := false + for i := 0; i < 5; i++ { + _, err = http.Get("http://" + addrs[0].String()) + if err == nil { + awake = true + break + } + time.Sleep(50 * time.Millisecond) + } + if !awake { + t.Fatalf("http server did not respond in time") + } + + // HTTP GET on each address. + for _, addr := range addrs { + _, err = http.Get("http://" + addr.String()) + if err != nil { + t.Errorf("error connecting to %q: %v", addr.String(), err) + } + } +} + +func assertError(t *testing.T, errString string, err error) { + if err == nil { + t.Errorf("Expected error '%s' but got none", errString) + } + if err.Error() != errString { + t.Errorf("Expected error '%s' but got '%s'", errString, err.Error()) + } +} + +func assertNoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Did not expect error: %v", err) + } +} diff --git a/net/v2/net.go b/net/v2/net.go new file mode 100644 index 00000000..233733cd --- /dev/null +++ b/net/v2/net.go @@ -0,0 +1,77 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "errors" + "fmt" + "math" + "math/big" + "net" + "strconv" +) + +// ParsePort parses a string representing an IP port. If the string is not a +// valid port number, this returns an error. +func ParsePort(port string, allowZero bool) (int, error) { + portInt, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return 0, err + } + if portInt == 0 && !allowZero { + return 0, errors.New("0 is not a valid port number") + } + return int(portInt), nil +} + +// BigForIP creates a big.Int based on the provided net.IP +func BigForIP(ip net.IP) *big.Int { + // NOTE: Convert to 16-byte representation so we can + // handle v4 and v6 values the same way. + return big.NewInt(0).SetBytes(ip.To16()) +} + +// AddIPOffset adds the provided integer offset to a base big.Int representing a net.IP +// NOTE: If you started with a v4 address and overflow it, you get a v6 result. +func AddIPOffset(base *big.Int, offset int) net.IP { + r := big.NewInt(0).Add(base, big.NewInt(int64(offset))).Bytes() + r = append(make([]byte, 16), r...) + return net.IP(r[len(r)-16:]) +} + +// RangeSize returns the size of a range in valid addresses. +// returns the size of the subnet (or math.MaxInt64 if the range size would overflow int64) +func RangeSize(subnet *net.IPNet) int64 { + ones, bits := subnet.Mask.Size() + if bits == 32 && (bits-ones) >= 31 || bits == 128 && (bits-ones) >= 127 { + return 0 + } + // this checks that we are not overflowing an int64 + if bits-ones >= 63 { + return math.MaxInt64 + } + return int64(1) << uint(bits-ones) +} + +// GetIndexedIP returns a net.IP that is subnet.IP + index in the contiguous IP space. +func GetIndexedIP(subnet *net.IPNet, index int) (net.IP, error) { + ip := AddIPOffset(BigForIP(subnet.IP), index) + if !subnet.Contains(ip) { + return nil, fmt.Errorf("can't generate IP with index %d from subnet. subnet too small. subnet: %q", index, subnet) + } + return ip, nil +} diff --git a/net/v2/net_test.go b/net/v2/net_test.go new file mode 100644 index 00000000..e9770f2e --- /dev/null +++ b/net/v2/net_test.go @@ -0,0 +1,207 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "testing" +) + +func TestParsePort(t *testing.T) { + var tests = []struct { + name string + port string + allowZero bool + expectedPort int + expectedError bool + }{ + { + name: "valid port: 1", + port: "1", + expectedPort: 1, + }, + { + name: "valid port: 1234", + port: "1234", + expectedPort: 1234, + }, + { + name: "valid port: 65535", + port: "65535", + expectedPort: 65535, + }, + { + name: "invalid port: not a number", + port: "a", + expectedError: true, + allowZero: false, + }, + { + name: "invalid port: too small", + port: "0", + expectedError: true, + }, + { + name: "invalid port: negative", + port: "-10", + expectedError: true, + }, + { + name: "invalid port: too big", + port: "65536", + expectedError: true, + }, + { + name: "zero port: allowed", + port: "0", + allowZero: true, + }, + { + name: "zero port: not allowed", + port: "0", + expectedError: true, + }, + } + + for _, rt := range tests { + t.Run(rt.name, func(t *testing.T) { + actualPort, actualError := ParsePort(rt.port, rt.allowZero) + + if actualError != nil && !rt.expectedError { + t.Errorf("%s unexpected failure: %v", rt.name, actualError) + return + } + if actualError == nil && rt.expectedError { + t.Errorf("%s passed when expected to fail", rt.name) + return + } + if actualPort != rt.expectedPort { + t.Errorf("%s returned wrong port: got %d, expected %d", rt.name, actualPort, rt.expectedPort) + } + }) + } +} + +func TestRangeSize(t *testing.T) { + testCases := []struct { + name string + cidr string + addrs int64 + }{ + { + name: "supported IPv4 cidr", + cidr: "192.168.1.0/24", + addrs: 256, + }, + { + name: "unsupported IPv4 cidr", + cidr: "192.168.1.0/1", + addrs: 0, + }, + { + name: "unsupported IPv6 mask", + cidr: "2001:db8::/1", + addrs: 0, + }, + } + + for _, tc := range testCases { + cidr, err := ParseIPNet(tc.cidr) + if err != nil { + t.Errorf("failed to parse cidr for test %s, unexpected error: '%s'", tc.name, err) + } + if size := RangeSize(cidr); size != tc.addrs { + t.Errorf("test %s failed. %s should have a range size of %d, got %d", + tc.name, tc.cidr, tc.addrs, size) + } + } +} + +func TestGetIndexedIP(t *testing.T) { + testCases := []struct { + cidr string + index int + expectError bool + expectedIP string + }{ + { + cidr: "192.168.1.0/24", + index: 20, + expectError: false, + expectedIP: "192.168.1.20", + }, + { + cidr: "192.168.1.0/30", + index: 10, + expectError: true, + }, + { + cidr: "192.168.1.0/24", + index: 255, + expectError: false, + expectedIP: "192.168.1.255", + }, + { + cidr: "255.255.255.0/24", + index: 256, + expectError: true, + }, + { + cidr: "fd:11:b2:be::/120", + index: 20, + expectError: false, + expectedIP: "fd:11:b2:be::14", + }, + { + cidr: "fd:11:b2:be::/126", + index: 10, + expectError: true, + }, + { + cidr: "fd:11:b2:be::/120", + index: 255, + expectError: false, + expectedIP: "fd:11:b2:be::ff", + }, + { + cidr: "00:00:00:be::/120", + index: 255, + expectError: false, + expectedIP: "::be:0:0:0:ff", + }, + } + + for _, tc := range testCases { + subnet, err := ParseIPNet(tc.cidr) + if err != nil { + t.Errorf("failed to parse cidr %s, unexpected error: '%s'", tc.cidr, err) + } + + ip, err := GetIndexedIP(subnet, tc.index) + if err == nil && tc.expectError || err != nil && !tc.expectError { + t.Errorf("expectedError is %v and err is %s", tc.expectError, err) + continue + } + + if err == nil { + ipString := ip.String() + if ipString != tc.expectedIP { + t.Errorf("expected %s but instead got %s", tc.expectedIP, ipString) + } + } + + } +} diff --git a/net/v2/parse.go b/net/v2/parse.go new file mode 100644 index 00000000..b7e9e8b9 --- /dev/null +++ b/net/v2/parse.go @@ -0,0 +1,259 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "fmt" + "net" + "net/netip" + + forkednet "k8s.io/utils/internal/third_party/forked/golang/net" +) + +// ParseIP parses an IPv4 or IPv6 address to a net.IP. This accepts both fully-valid IP +// addresses and irregular/ambiguous forms, making it usable for both validated and +// non-validated input strings. It should be used instead of net.ParseIP (which rejects +// some strings we need to accept for backward compatibility) and the old +// netutilsv1.ParseIPSloppy. +// +// Compare ParseAddr, which returns a netip.Addr but is otherwise identical. +func ParseIP(ipStr string) (net.IP, error) { + // Note: if we want to get rid of forkednet, we should be able to use some + // invocation of regexp.ReplaceAllString to get rid of leading 0s in ipStr. + ip := forkednet.ParseIP(ipStr) + if ip != nil { + return ip, nil + } + + if ipStr == "" { + return nil, fmt.Errorf("expected an IP address") + } + // NB: we use forkednet.ParseCIDR directly, not ParseIPNet, to avoid recursing + // between ParseIP and ParseIPNet. + if _, _, err := forkednet.ParseCIDR(ipStr); err == nil { + return nil, fmt.Errorf("expected an IP address, got a CIDR value") + } + return nil, fmt.Errorf("not a valid IP address") +} + +// ParseAddr parses an IPv4 or IPv6 address to a netip.Addr. This accepts both fully-valid +// IP addresses and irregular/ambiguous forms, making it usable for both validated and +// non-validated input strings. As compared to netip.ParseAddr: +// +// - It accepts IPv4 IPs with extra leading "0"s, for backward compatibility. +// - It rejects IPs with attached zone identifiers (e.g. `"fe80::1234%eth0"`). +// - It converts "IPv4-mapped IPv6" addresses (e.g. `"::ffff:1.2.3.4"`) to the +// corresponding "plain" IPv4 values (e.g. `"1.2.3.4"`). That is, it never returns an +// address for which `Is4In6()` would return `true`. +// +// Compare ParseIP, which returns a net.IP but is otherwise identical. +func ParseAddr(ipStr string) (netip.Addr, error) { + // To ensure identical parsing, we use ParseIP and then convert. (If ParseIP + // returns a nil ip, AddrFromIP will convert that to the zero/invalid netip.Addr, + // which is what we want.) + ip, err := ParseIP(ipStr) + return AddrFromIP(ip), err +} + +// ParseIPNet parses an IPv4 or IPv6 CIDR string representing a subnet or mask, to a +// *net.IPNet. This accepts both fully-valid CIDR values and irregular/ambiguous forms, +// making it usable for both validated and non-validated input strings. It should be used +// instead of net.ParseCIDR (which rejects some strings that we need to accept for +// backward-compatibility) and the old netutilsv1.ParseCIDRSloppy. +// +// The return value is equivalent to the second return value from net.ParseCIDR. Note that +// this means that if the CIDR string has bits set beyond the prefix length (e.g., the "5" +// in "192.168.1.5/24"), those bits are simply discarded. Use ParseIPAsIPNet instead if +// you want a *net.IPNet value containing the complete IP. +// +// Compare ParsePrefix, which returns a netip.Prefix but is otherwise identical. +func ParseIPNet(cidrStr string) (*net.IPNet, error) { + _, ipnet, err := parseIPNetInternal(cidrStr) + return ipnet, err +} + +// ParseIPAsIPNet parses an IPv4 or IPv6 CIDR string representing an IP address and subnet +// mask, to a *net.IPNet. This accepts both fully-valid CIDR values and +// irregular/ambiguous forms, making it usable for both validated and non-validated input +// strings. It should be used instead of net.ParseCIDR (which rejects some strings that we +// need to accept for backward-compatibility) and the old netutilsv1.ParseCIDRSloppy. +// +// The return value combines the two return values of net.ParseCIDR; the returned +// *net.IPNet's IP field will contain a net.IP corresponding to the full IP address from +// the CIDR string, while the Mask field will contain a net.IPMask corresponding to the +// CIDR's prefix length. +// +// Compare ParseAddrAsPrefix, which returns a netip.Prefix, but is otherwise identical. +func ParseIPAsIPNet(cidrStr string) (*net.IPNet, error) { + ip, ipnet, err := parseIPNetInternal(cidrStr) + if err != nil { + return nil, err + } + return &net.IPNet{IP: ip, Mask: ipnet.Mask}, nil +} + +func parseIPNetInternal(cidrStr string) (net.IP, *net.IPNet, error) { + // Note: if we want to get rid of forkednet, we should be able to use some + // invocation of regexp.ReplaceAllString to get rid of leading 0s in cidrStr. + if ip, ipnet, err := forkednet.ParseCIDR(cidrStr); err == nil { + return ip, ipnet, nil + } + + if cidrStr == "" { + return nil, nil, fmt.Errorf("expected a CIDR value") + } + // NB: we use forkednet.ParseIP directly, not our own ParseIP, to avoid recursing + // between ParseIPNet and ParseIP. + if forkednet.ParseIP(cidrStr) != nil { + return nil, nil, fmt.Errorf("expected a CIDR value, but got IP address") + } + return nil, nil, fmt.Errorf("not a valid CIDR value") +} + +// ParsePrefix parses an IPv4 or IPv6 CIDR string representing a subnet or mask, to a +// netip.Prefix. This accepts both fully-valid CIDR values and irregular/ambiguous forms, +// making it usable for both validated and non-validated input strings. As compared to +// netip.ParsePrefix: +// +// - It accepts IPv4 IPs with extra leading "0"s, for backward compatibility. +// - It converts "IPv4-mapped IPv6" addresses (e.g. `"::ffff:1.2.3.0/120"`) to the +// corresponding "plain" IPv4 values (e.g. `"1.2.3.0/24"`). That is, it never returns +// a prefix for which `.Addr().Is4In6()` would return `true`. +// - When given a CIDR string with bits set beyond the prefix length, like +// `"192.168.1.5/24"`, it discards those extra bits (the equivalent of calling +// .Masked() on the return value of netip.ParsePrefix). Use ParseAddrAsPrefix instead +// if you want a netip.Prefix value containing the complete IP. +// +// Compare ParseIPNet, which returns a *net.IPNet but is otherwise identical. +func ParsePrefix(cidrStr string) (netip.Prefix, error) { + // To ensure identical parsing, we use ParseIPNet and then convert. (If ParseIPNet + // returns nil, PrefixFromIPNet will convert that to the zero/invalid + // netip.Prefix, which is what we want.) + ipnet, err := ParseIPNet(cidrStr) + return PrefixFromIPNet(ipnet), err +} + +// ParseAddrAsPrefix parses an IPv4 or IPv6 CIDR string representing an IP address and +// subnet mask, to a netip.Prefix. This accepts both fully-valid CIDR values and +// irregular/ambiguous forms, making it usable for both validated and non-validated input +// strings. As with ParsePrefix, this should be used rather than netip.ParsePrefix, +// for backward-compatibility and better handling of ambiguous values. +// +// The return value is identical to the value returned from ParsePrefix except in the +// case of CIDR strings with bits set beyond the prefix length, which are preserved by +// ParseAddrAsPrefix but discarded by ParsePrefix. +// +// Compare ParseIPAsIPNet, which returns a *net.IPNet, but is otherwise identical. +func ParseAddrAsPrefix(cidrStr string) (netip.Prefix, error) { + // To ensure identical parsing, we use ParseIPAsIPNet and then convert. (If + // ParseIPAsIPNet returns nil, PrefixFromIPNet will convert that to the + // zero/invalid netip.Prefix, which is what we want.) + ipnet, err := ParseIPAsIPNet(cidrStr) + return PrefixFromIPNet(ipnet), err +} + +type parser[T any] func(string) (T, error) + +func must[T any](parse parser[T]) func(string) T { + return func(str string) T { + ret, err := parse(str) + if err != nil { + panic(err) + } + return ret + } +} + +// MustParseIP is like ParseIP, but it panics on error instead of returning an error value. +var MustParseIP = must(ParseIP) + +// MustParseIPNet is like ParseIPNet, but it panics on error instead of returning an error value. +var MustParseIPNet = must(ParseIPNet) + +// MustParseAddr is like ParseAddr, but it panics on error instead of returning an error value. +var MustParseAddr = must(ParseAddr) + +// MustParsePrefix is like ParsePrefix, but it panics on error instead of returning an error value. +var MustParsePrefix = must(ParsePrefix) + +// MustParseIPAsIPNet is like ParseIPAsIPNet, but it panics on error instead of returning an error value. +var MustParseIPAsIPNet = must(ParseIPAsIPNet) + +// MustParseAddrAsPrefix is like ParseAddrAsPrefix, but it panics on error instead of returning an error value. +var MustParseAddrAsPrefix = must(ParseAddrAsPrefix) + +type listParser[T any] func(...string) ([]T, error) + +func list[T any](parse parser[T]) listParser[T] { + return func(strs ...string) ([]T, error) { + var err error + ret := make([]T, len(strs)) + for i, str := range strs { + ret[i], err = parse(str) + if err != nil { + return nil, err + } + } + return ret, nil + } +} + +// ParseIPList parses a list of strings with ParseIP and returns a []net.IP or an error. +var ParseIPList = list(ParseIP) + +// ParseIPNetList parses a list of strings with ParseIPNet and returns a []*net.IPNet or an error. +var ParseIPNetList = list(ParseIPNet) + +// ParseAddrList parses a list of strings with ParseAddr and returns a []netip.Addr or an error. +var ParseAddrList = list(ParseAddr) + +// ParsePrefixList parses a list of strings with ParsePrefix and returns a []netip.Prefix or an error. +var ParsePrefixList = list(ParsePrefix) + +// ParseIPAsIPNetList parses a list of strings with ParseIPAsIPNet and returns a []*net.IPNet or an error. +var ParseIPAsIPNetList = list(ParseIPAsIPNet) + +// ParseAddrAsPrefixList parses a list of strings with ParseAddrAsPrefix and returns a []netip.Prefix or an error. +var ParseAddrAsPrefixList = list(ParseAddrAsPrefix) + +func mustlist[T any](parse listParser[T]) func(...string) []T { + return func(strs ...string) []T { + ret, err := parse(strs...) + if err != nil { + panic(err) + } + return ret + } +} + +// MustParseIPList parses a list of strings with ParseIP and returns a []net.IP or else panics on error. +var MustParseIPList = mustlist(ParseIPList) + +// MustParseIPNetList parses a list of strings with ParseIPNet and returns a []*net.IPNet or else panics on error. +var MustParseIPNetList = mustlist(ParseIPNetList) + +// MustParseAddrList parses a list of strings with ParseAddr and returns a []netip.Addr or else panics on error. +var MustParseAddrList = mustlist(ParseAddrList) + +// MustParsePrefixList parses a list of strings with ParsePrefix and returns a []netip.Prefix or else panics on error. +var MustParsePrefixList = mustlist(ParsePrefixList) + +// MustParseIPAsIPNetList parses a list of strings with ParseIPAsIPNet and returns a []*net.IPNet or else panics on error. +var MustParseIPAsIPNetList = mustlist(ParseIPAsIPNetList) + +// MustParseAddrAsPrefixList parses a list of strings with ParseAddrAsPrefix and returns a []netip.Prefix or else panics on error. +var MustParseAddrAsPrefixList = mustlist(ParseAddrAsPrefixList) diff --git a/net/v2/parse_test.go b/net/v2/parse_test.go new file mode 100644 index 00000000..c42d1f04 --- /dev/null +++ b/net/v2/parse_test.go @@ -0,0 +1,276 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "testing" +) + +func TestParseIP(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestIPs { + if tc.skipParse { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, str := range tc.strings { + ip, err := ParseIP(str) + if err != nil { + t.Errorf("expected %q to parse, but got error %v", str, err) + } + if !ip.Equal(tc.ips[0]) { + t.Errorf("expected string %d %q to parse equal to IP %#v %q but got %#v (%q)", i+1, str, tc.ips[0], tc.ips[0].String(), ip, ip.String()) + } + } + }) + } + + // See test cases in ips_test.go + for _, tc := range badTestIPs { + if tc.skipParse { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, ip := range tc.ips { + errStr := ip.String() + parsedIP, _ := ParseIP(errStr) + if parsedIP != nil { + t.Errorf("expected IP %d %#v (%q) to not re-parse but got %#v (%q)", i+1, ip, errStr, parsedIP, parsedIP.String()) + } + } + + for i, addr := range tc.addrs { + errStr := addr.String() + parsedIP, _ := ParseIP(errStr) + if parsedIP != nil { + t.Errorf("expected Addr %d %#v (%q) to not re-parse but got %#v (%q)", i+1, addr, errStr, parsedIP, parsedIP.String()) + } + } + + for i, str := range tc.strings { + ip, _ := ParseIP(str) + if ip != nil { + t.Errorf("expected string %d %q to not parse but got %#v (%q)", i+1, str, ip, ip.String()) + } + } + }) + } +} + +func TestParseAddr(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestIPs { + if tc.skipParse { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, str := range tc.strings { + addr, err := ParseAddr(str) + if err != nil { + t.Errorf("expected %q to parse, but got error %v", str, err) + } + if addr != tc.addrs[0] { + t.Errorf("expected string %d %q to parse equal to Addr %#v %q but got %#v (%q)", i+1, str, tc.addrs[0], tc.addrs[0].String(), addr, addr.String()) + } + } + }) + } + + for _, tc := range badTestIPs { + if tc.skipParse { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, ip := range tc.ips { + errStr := ip.String() + parsedAddr, err := ParseAddr(errStr) + if err == nil { + t.Errorf("expected IP %d %#v (%q) to not re-parse but got %#v (%q)", i+1, ip, errStr, parsedAddr, parsedAddr.String()) + } + } + + for i, addr := range tc.addrs { + errStr := addr.String() + parsedAddr, err := ParseAddr(errStr) + if err == nil { + t.Errorf("expected Addr %d %#v (%q) to not re-parse but got %#v (%q)", i+1, addr, errStr, parsedAddr, parsedAddr.String()) + } + } + + for i, str := range tc.strings { + addr, err := ParseAddr(str) + if err == nil { + t.Errorf("expected string %d %q to not parse but got %#v (%q)", i+1, str, addr, addr.String()) + } + } + }) + } +} + +func TestParseIPNet(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestCIDRs { + if tc.skipParse { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, str := range tc.strings { + ipnet, err := ParseIPNet(str) + if err != nil { + t.Errorf("expected %q to parse, but got error %v", str, err) + } + ifaddr, err := ParseIPAsIPNet(str) + if err != nil { + t.Errorf("expected %q to parse via ParseIPAsIPNet, but got error %v", str, err) + } + + if tc.ifaddr { + // The test case expects ParseIPNet and + // ParseIPAsIPNet to parse to different values. + if ipnet.String() == ifaddr.String() { + t.Errorf("expected %q to parse differently with ParseIPNet and ParseIPAsIPNet but got %q for both", str, ipnet.String()) + } + // In this case, it's the ParseIPAsIPNet value + // that should re-stringify correctly. (ParseIPNet + // will have discarded the trailing bits.) + ipnet = ifaddr + } else { + // Some strings might parse to the same value and + // others might parse to different values. + // However, in all cases, the ParseIPAsIPNet value + // should be the same as the ParseIPNet value + // after masking it. + if !ipnet.IP.Equal(ifaddr.IP.Mask(ifaddr.Mask)) { + t.Errorf("expected %q to parse similarly with ParseIPNet and ParseIPAsIPNet but got IPs %q and %q->%q", str, ipnet.IP, ifaddr, ifaddr.IP.Mask(ifaddr.Mask)) + } + } + + if ipnet.String() != tc.ipnets[0].String() { + t.Errorf("expected string %d %q to parse and re-stringify to %q but got %q", i+1, str, tc.ipnets[0].String(), ipnet.String()) + } + } + }) + } + + // See test cases in ips_test.go + for _, tc := range badTestCIDRs { + if tc.skipParse { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, ipnet := range tc.ipnets { + errStr := ipnet.String() + parsedIPNet, err := ParseIPNet(errStr) + if err == nil { + t.Errorf("expected IPNet %d %q to not parse but got %#v (%q)", i+1, errStr, *parsedIPNet, parsedIPNet.String()) + } + } + + for i, prefix := range tc.prefixes { + errStr := prefix.String() + parsedIPNet, err := ParseIPNet(errStr) + if err == nil { + t.Errorf("expected Prefix %d %#v %q to not parse but got %#v (%q)", i+1, prefix, errStr, *parsedIPNet, parsedIPNet.String()) + } + } + + for i, str := range tc.strings { + ipnet, err := ParseIPNet(str) + if err == nil { + t.Errorf("expected string %d %q to not parse but got %#v (%q)", i+1, str, *ipnet, ipnet.String()) + } + } + }) + } +} + +func TestParsePrefix(t *testing.T) { + // See test cases in ips_test.go + for _, tc := range goodTestCIDRs { + if tc.skipParse { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, str := range tc.strings { + prefix, err := ParsePrefix(str) + if err != nil { + t.Errorf("expected %q to parse, but got error %v", str, err) + } + ifaddr, err := ParseAddrAsPrefix(str) + if err != nil { + t.Errorf("expected %q to parse via ParseAddrAsPrefix, but got error %v", str, err) + } + + if tc.ifaddr { + // The test case expects ParsePrefix and + // ParseAddrAsPrefix to parse to different values. + if prefix == ifaddr { + t.Errorf("expected %q to parse differently with ParsePrefix and ParseAddrAsPrefix but got %#v %q for both", str, prefix, prefix) + } + // In this case, it's the ParseAddrAsPrefix value + // that should re-stringify correctly. (ParsePrefix + // will have discarded the trailing bits.) + prefix = ifaddr + } else { + // Some strings might parse to the same value and + // others might parse to different values. + // However, in all cases, the ParseAddrAsPrefix + // value should be the same as the ParsePrefix + // value after masking it. + if prefix != ifaddr.Masked() { + t.Errorf("expected %q to parse similarly with ParsePrefix and ParseAddrAsPrefix but got %q and %q->%q", str, prefix, ifaddr, ifaddr.Masked()) + } + } + + if prefix != tc.prefixes[0] { + t.Errorf("expected string %d %q to parse equal to Prefix %#v %q but got %#v (%q)", i+1, str, tc.prefixes[0], tc.prefixes[0].String(), prefix, prefix.String()) + } + } + }) + } + + // See test cases in ips_test.go + for _, tc := range badTestCIDRs { + if tc.skipParse { + continue + } + t.Run(tc.desc, func(t *testing.T) { + for i, ipnet := range tc.ipnets { + errStr := ipnet.String() + parsedPrefix, err := ParsePrefix(errStr) + if err == nil { + t.Errorf("expected IPNet %d %q to not parse but got %#v (%q)", i+1, errStr, parsedPrefix, parsedPrefix.String()) + } + } + + for i, prefix := range tc.prefixes { + errStr := prefix.String() + parsedPrefix, err := ParsePrefix(errStr) + if err == nil { + t.Errorf("expected Prefix %d %q to not parse but got %#v (%q)", i+1, errStr, parsedPrefix, parsedPrefix.String()) + } + } + + for i, str := range tc.strings { + prefix, err := ParsePrefix(str) + if err == nil { + t.Errorf("expected string %d %q to not parse but got %#v (%q)", i+1, str, prefix, prefix.String()) + } + } + }) + } +}