diff --git a/ipwhois/net.py b/ipwhois/net.py index 8793921..1f1ef06 100644 --- a/ipwhois/net.py +++ b/ipwhois/net.py @@ -43,7 +43,10 @@ if sys.version_info >= (3, 3): # pragma: no cover from ipaddress import (ip_address, IPv4Address, - IPv6Address) + IPv6Address, + ip_network, + IPv4Network, + IPv6Network) else: # pragma: no cover from ipaddr import (IPAddress as ip_address, IPv4Address, @@ -117,10 +120,15 @@ def __init__(self, address, timeout=5, proxy_opener=None, # IPv4Address or IPv6Address if isinstance(address, IPv4Address) or isinstance( - address, IPv6Address): + address, IPv6Address) or isinstance( + address, IPv4Network) or isinstance(address, IPv6Network): self.address = address + elif isinstance(address, str) and '/' in address: + + self.address = ip_network(address) + else: # Use ipaddress package exception handling. @@ -161,8 +169,16 @@ def __init__(self, address, timeout=5, proxy_opener=None, if self.version == 4: - # Check if no ASN/whois resolution needs to occur. - is_defined = ipv4_is_defined(self.address_str) + if isinstance(self.address, IPv4Network): + + # Check if no ASN/whois resolution needs to occur with the + # first IP in the subnet. + is_defined = ipv4_is_defined(next(self.address.hosts())) + + else: + + # Check if no ASN/whois resolution needs to occur. + is_defined = ipv4_is_defined(self.address_str) if is_defined[0]: @@ -182,8 +198,16 @@ def __init__(self, address, timeout=5, proxy_opener=None, else: - # Check if no ASN/whois resolution needs to occur. - is_defined = ipv6_is_defined(self.address_str) + if isinstance(self.address, IPv6Network): + + # Check if no ASN/whois resolution needs to occur with the + # first IP in the subnet. + is_defined = ipv6_is_defined(next(self.address.hosts())) + + else: + + # Check if no ASN/whois resolution needs to occur. + is_defined = ipv6_is_defined(self.address_str) if is_defined[0]: diff --git a/ipwhois/tests/test_net.py b/ipwhois/tests/test_net.py index aedf6ac..1b72a68 100644 --- a/ipwhois/tests/test_net.py +++ b/ipwhois/tests/test_net.py @@ -16,23 +16,35 @@ class TestNet(TestCommon): def test_ip_invalid(self): self.assertRaises(ValueError, Net, '192.168.0.256') self.assertRaises(ValueError, Net, 'fe::80::') + self.assertRaises(ValueError, Net, '192.256.0.0/16') + self.assertRaises(ValueError, Net, 'fe::80::/10') def test_ip_defined(self): if sys.version_info >= (3, 3): - from ipaddress import (IPv4Address, IPv6Address) + from ipaddress import (IPv4Address, IPv6Address, + IPv4Network, IPv6Network) else: - from ipaddr import (IPv4Address, IPv6Address) + from ipaddr import (IPv4Address, IPv6Address, + IPv4Network, IPv6Network) self.assertRaises(IPDefinedError, Net, '192.168.0.1') self.assertRaises(IPDefinedError, Net, 'fe80::') self.assertRaises(IPDefinedError, Net, IPv4Address('192.168.0.1')) self.assertRaises(IPDefinedError, Net, IPv6Address('fe80::')) + self.assertRaises(IPDefinedError, Net, '192.168.0.0/16') + self.assertRaises(IPDefinedError, Net, 'fe80::/10') + self.assertRaises(IPDefinedError, Net, IPv4Network('192.168.0.0/16')) + self.assertRaises(IPDefinedError, Net, IPv6Network('fe80::/10')) def test_ip_version(self): result = Net('74.125.225.229') self.assertEqual(result.version, 4) result = Net('2001:4860:4860::8888') self.assertEqual(result.version, 6) + result = Net('8.8.8.0/24') + self.assertEqual(result.version, 4) + result = Net('2001:4860::/32') + self.assertEqual(result.version, 6) def test_timeout(self): result = Net('74.125.225.229')