diff options
Diffstat (limited to 'netaddr/ip/sets.py')
-rw-r--r-- | netaddr/ip/sets.py | 748 |
1 files changed, 748 insertions, 0 deletions
diff --git a/netaddr/ip/sets.py b/netaddr/ip/sets.py new file mode 100644 index 0000000..6db896d --- /dev/null +++ b/netaddr/ip/sets.py @@ -0,0 +1,748 @@ +#----------------------------------------------------------------------------- +# Copyright (c) 2008 by David P. D. Moss. All rights reserved. +# +# Released under the BSD license. See the LICENSE file for details. +#----------------------------------------------------------------------------- +"""Set based operations for IP addresses and subnets.""" + +import itertools as _itertools + +from netaddr.ip import (IPNetwork, IPAddress, IPRange, cidr_merge, + cidr_exclude, iprange_to_cidrs) + +from netaddr.compat import _sys_maxint, _dict_keys, _int_type + + +def _subtract(supernet, subnets, subnet_idx, ranges): + """Calculate IPSet([supernet]) - IPSet(subnets). + + Assumptions: subnets is sorted, subnet_idx points to the first + element in subnets that is a subnet of supernet. + + Results are appended to the ranges parameter as tuples of in format + (version, first, last). Return value is the first subnet_idx that + does not point to a subnet of supernet (or len(subnets) if all + subsequents items are a subnet of supernet). + """ + version = supernet._module.version + subnet = subnets[subnet_idx] + if subnet.first > supernet.first: + ranges.append((version, supernet.first, subnet.first - 1)) + + subnet_idx += 1 + prev_subnet = subnet + while subnet_idx < len(subnets): + cur_subnet = subnets[subnet_idx] + + if cur_subnet not in supernet: + break + if prev_subnet.last + 1 == cur_subnet.first: + # two adjacent, non-mergable IPNetworks + pass + else: + ranges.append((version, prev_subnet.last + 1, cur_subnet.first - 1)) + + subnet_idx += 1 + prev_subnet = cur_subnet + + first = prev_subnet.last + 1 + last = supernet.last + if first <= last: + ranges.append((version, first, last)) + + return subnet_idx + + +def _iter_merged_ranges(sorted_ranges): + """Iterate over sorted_ranges, merging where possible + + Sorted ranges must be a sorted iterable of (version, first, last) tuples. + Merging occurs for pairs like [(4, 10, 42), (4, 43, 100)] which is merged + into (4, 10, 100), and leads to return value + ( IPAddress(10, 4), IPAddress(100, 4) ), which is suitable input for the + iprange_to_cidrs function. + """ + if not sorted_ranges: + return + + current_version, current_start, current_stop = sorted_ranges[0] + + for next_version, next_start, next_stop in sorted_ranges[1:]: + if next_start == current_stop + 1 and next_version == current_version: + # Can be merged. + current_stop = next_stop + continue + # Cannot be merged. + yield (IPAddress(current_start, current_version), + IPAddress(current_stop, current_version)) + current_start = next_start + current_stop = next_stop + current_version = next_version + yield (IPAddress(current_start, current_version), + IPAddress(current_stop, current_version)) + + +class IPSet(object): + """ + Represents an unordered collection (set) of unique IP addresses and + subnets. + + """ + __slots__ = ('_cidrs', '__weakref__') + + def __init__(self, iterable=None, flags=0): + """ + Constructor. + + :param iterable: (optional) an iterable containing IP addresses, + subnets or ranges. + + :param flags: decides which rules are applied to the interpretation + of the addr value. See the :class:`IPAddress` documentation + for supported constant values. + + """ + if isinstance(iterable, IPNetwork): + self._cidrs = {iterable.cidr: True} + elif isinstance(iterable, IPRange): + self._cidrs = dict.fromkeys( + iprange_to_cidrs(iterable[0], iterable[-1]), True) + elif isinstance(iterable, IPSet): + self._cidrs = dict.fromkeys(iterable.iter_cidrs(), True) + else: + self._cidrs = {} + if iterable is not None: + mergeable = [] + for addr in iterable: + if isinstance(addr, _int_type): + addr = IPAddress(addr, flags=flags) + mergeable.append(addr) + + for cidr in cidr_merge(mergeable): + self._cidrs[cidr] = True + + def __getstate__(self): + """:return: Pickled state of an ``IPSet`` object.""" + return tuple([cidr.__getstate__() for cidr in self._cidrs]) + + def __setstate__(self, state): + """ + :param state: data used to unpickle a pickled ``IPSet`` object. + + """ + self._cidrs = dict.fromkeys( + (IPNetwork((value, prefixlen), version=version) + for value, prefixlen, version in state), + True) + + def _compact_single_network(self, added_network): + """ + Same as compact(), but assume that added_network is the only change and + that this IPSet was properly compacted before added_network was added. + This allows to perform compaction much faster. added_network must + already be present in self._cidrs. + """ + added_first = added_network.first + added_last = added_network.last + added_version = added_network.version + + # Check for supernets and subnets of added_network. + if added_network._prefixlen == added_network._module.width: + # This is a single IP address, i.e. /32 for IPv4 or /128 for IPv6. + # It does not have any subnets, so we only need to check for its + # potential supernets. + for potential_supernet in added_network.supernet(): + if potential_supernet in self._cidrs: + del self._cidrs[added_network] + return + else: + # IPNetworks from self._cidrs that are subnets of added_network. + to_remove = [] + for cidr in self._cidrs: + if (cidr._module.version != added_version or cidr == added_network): + # We found added_network or some network of a different version. + continue + first = cidr.first + last = cidr.last + if first >= added_first and last <= added_last: + # cidr is a subnet of added_network. Remember to remove it. + to_remove.append(cidr) + elif first <= added_first and last >= added_last: + # cidr is a supernet of added_network. Remove added_network. + del self._cidrs[added_network] + # This IPSet was properly compacted before. Since added_network + # is removed now, it must again be properly compacted -> done. + assert (not to_remove) + return + for item in to_remove: + del self._cidrs[item] + + # Check if added_network can be merged with another network. + + # Note that merging can only happen between networks of the same + # prefixlen. This just leaves 2 candidates: The IPNetworks just before + # and just after the added_network. + # This can be reduced to 1 candidate: 10.0.0.0/24 and 10.0.1.0/24 can + # be merged into into 10.0.0.0/23. But 10.0.1.0/24 and 10.0.2.0/24 + # cannot be merged. With only 1 candidate, we might as well make a + # dictionary lookup. + shift_width = added_network._module.width - added_network.prefixlen + while added_network.prefixlen != 0: + # figure out if the least significant bit of the network part is 0 or 1. + the_bit = (added_network._value >> shift_width) & 1 + if the_bit: + candidate = added_network.previous() + else: + candidate = added_network.next() + + if candidate not in self._cidrs: + # The only possible merge does not work -> merge done + return + # Remove added_network&candidate, add merged network. + del self._cidrs[candidate] + del self._cidrs[added_network] + added_network.prefixlen -= 1 + # Be sure that we set the host bits to 0 when we move the prefixlen. + # Otherwise, adding 255.255.255.255/32 will result in a merged + # 255.255.255.255/24 network, but we want 255.255.255.0/24. + shift_width += 1 + added_network._value = (added_network._value >> shift_width) << shift_width + self._cidrs[added_network] = True + + def compact(self): + """ + Compact internal list of `IPNetwork` objects using a CIDR merge. + """ + cidrs = cidr_merge(self._cidrs) + self._cidrs = dict.fromkeys(cidrs, True) + + def __hash__(self): + """ + Raises ``TypeError`` if this method is called. + + .. note:: IPSet objects are not hashable and cannot be used as \ + dictionary keys or as members of other sets. \ + """ + raise TypeError('IP sets are unhashable!') + + def __contains__(self, ip): + """ + :param ip: An IP address or subnet. + + :return: ``True`` if IP address or subnet is a member of this IP set. + """ + # Iterating over self._cidrs is an O(n) operation: 1000 items in + # self._cidrs would mean 1000 loops. Iterating over all possible + # supernets loops at most 32 times for IPv4 or 128 times for IPv6, + # no matter how many CIDRs this object contains. + supernet = IPNetwork(ip) + if supernet in self._cidrs: + return True + while supernet._prefixlen: + supernet._prefixlen -= 1 + if supernet in self._cidrs: + return True + return False + + def __nonzero__(self): + """Return True if IPSet contains at least one IP, else False""" + return bool(self._cidrs) + + __bool__ = __nonzero__ # Python 3.x. + + def __iter__(self): + """ + :return: an iterator over the IP addresses within this IP set. + """ + return _itertools.chain(*sorted(self._cidrs)) + + def iter_cidrs(self): + """ + :return: an iterator over individual IP subnets within this IP set. + """ + return sorted(self._cidrs) + + def add(self, addr, flags=0): + """ + Adds an IP address or subnet or IPRange to this IP set. Has no effect if + it is already present. + + Note that where possible the IP address or subnet is merged with other + members of the set to form more concise CIDR blocks. + + :param addr: An IP address or subnet in either string or object form, or + an IPRange object. + + :param flags: decides which rules are applied to the interpretation + of the addr value. See the :class:`IPAddress` documentation + for supported constant values. + + """ + if isinstance(addr, IPRange): + new_cidrs = dict.fromkeys( + iprange_to_cidrs(addr[0], addr[-1]), True) + self._cidrs.update(new_cidrs) + self.compact() + return + if isinstance(addr, IPNetwork): + # Networks like 10.1.2.3/8 need to be normalized to 10.0.0.0/8 + addr = addr.cidr + elif isinstance(addr, _int_type): + addr = IPNetwork(IPAddress(addr, flags=flags)) + else: + addr = IPNetwork(addr) + + self._cidrs[addr] = True + self._compact_single_network(addr) + + def remove(self, addr, flags=0): + """ + Removes an IP address or subnet or IPRange from this IP set. Does + nothing if it is not already a member. + + Note that this method behaves more like discard() found in regular + Python sets because it doesn't raise KeyError exceptions if the + IP address or subnet is question does not exist. It doesn't make sense + to fully emulate that behaviour here as IP sets contain groups of + individual IP addresses as individual set members using IPNetwork + objects. + + :param addr: An IP address or subnet, or an IPRange. + + :param flags: decides which rules are applied to the interpretation + of the addr value. See the :class:`IPAddress` documentation + for supported constant values. + + """ + if isinstance(addr, IPRange): + cidrs = iprange_to_cidrs(addr[0], addr[-1]) + for cidr in cidrs: + self.remove(cidr) + return + + if isinstance(addr, _int_type): + addr = IPAddress(addr, flags=flags) + else: + addr = IPNetwork(addr) + + # This add() is required for address blocks provided that are larger + # than blocks found within the set but have overlaps. e.g. :- + # + # >>> IPSet(['192.0.2.0/24']).remove('192.0.2.0/23') + # IPSet([]) + # + self.add(addr) + + remainder = None + matching_cidr = None + + # Search for a matching CIDR and exclude IP from it. + for cidr in self._cidrs: + if addr in cidr: + remainder = cidr_exclude(cidr, addr) + matching_cidr = cidr + break + + # Replace matching CIDR with remaining CIDR elements. + if remainder is not None: + del self._cidrs[matching_cidr] + for cidr in remainder: + self._cidrs[cidr] = True + # No call to self.compact() is needed. Removing an IPNetwork cannot + # create mergeable networks. + + def pop(self): + """ + Removes and returns an arbitrary IP address or subnet from this IP + set. + + :return: An IP address or subnet. + """ + return self._cidrs.popitem()[0] + + def isdisjoint(self, other): + """ + :param other: an IP set. + + :return: ``True`` if this IP set has no elements (IP addresses + or subnets) in common with other. Intersection *must* be an + empty set. + """ + result = self.intersection(other) + return not result + + def copy(self): + """:return: a shallow copy of this IP set.""" + obj_copy = self.__class__() + obj_copy._cidrs.update(self._cidrs) + return obj_copy + + def update(self, iterable, flags=0): + """ + Update the contents of this IP set with the union of itself and + other IP set. + + :param iterable: an iterable containing IP addresses, subnets or ranges. + + :param flags: decides which rules are applied to the interpretation + of the addr value. See the :class:`IPAddress` documentation + for supported constant values. + + """ + if isinstance(iterable, IPSet): + self._cidrs = dict.fromkeys( + (ip for ip in cidr_merge(_dict_keys(self._cidrs) + + _dict_keys(iterable._cidrs))), True) + return + elif isinstance(iterable, (IPNetwork, IPRange)): + self.add(iterable) + return + + if not hasattr(iterable, '__iter__'): + raise TypeError('an iterable was expected!') + # An iterable containing IP addresses or subnets. + mergeable = [] + for addr in iterable: + if isinstance(addr, _int_type): + addr = IPAddress(addr, flags=flags) + mergeable.append(addr) + + for cidr in cidr_merge(_dict_keys(self._cidrs) + mergeable): + self._cidrs[cidr] = True + + self.compact() + + def clear(self): + """Remove all IP addresses and subnets from this IP set.""" + self._cidrs = {} + + def __eq__(self, other): + """ + :param other: an IP set + + :return: ``True`` if this IP set is equivalent to the ``other`` IP set, + ``False`` otherwise. + """ + try: + return self._cidrs == other._cidrs + except AttributeError: + return NotImplemented + + def __ne__(self, other): + """ + :param other: an IP set + + :return: ``False`` if this IP set is equivalent to the ``other`` IP set, + ``True`` otherwise. + """ + try: + return self._cidrs != other._cidrs + except AttributeError: + return NotImplemented + + def __lt__(self, other): + """ + :param other: an IP set + + :return: ``True`` if this IP set is less than the ``other`` IP set, + ``False`` otherwise. + """ + if not hasattr(other, '_cidrs'): + return NotImplemented + + return self.size < other.size and self.issubset(other) + + def issubset(self, other): + """ + :param other: an IP set. + + :return: ``True`` if every IP address and subnet in this IP set + is found within ``other``. + """ + for cidr in self._cidrs: + if cidr not in other: + return False + return True + + __le__ = issubset + + def __gt__(self, other): + """ + :param other: an IP set. + + :return: ``True`` if this IP set is greater than the ``other`` IP set, + ``False`` otherwise. + """ + if not hasattr(other, '_cidrs'): + return NotImplemented + + return self.size > other.size and self.issuperset(other) + + def issuperset(self, other): + """ + :param other: an IP set. + + :return: ``True`` if every IP address and subnet in other IP set + is found within this one. + """ + if not hasattr(other, '_cidrs'): + return NotImplemented + + for cidr in other._cidrs: + if cidr not in self: + return False + return True + + __ge__ = issuperset + + def union(self, other): + """ + :param other: an IP set. + + :return: the union of this IP set and another as a new IP set + (combines IP addresses and subnets from both sets). + """ + ip_set = self.copy() + ip_set.update(other) + return ip_set + + __or__ = union + + def intersection(self, other): + """ + :param other: an IP set. + + :return: the intersection of this IP set and another as a new IP set. + (IP addresses and subnets common to both sets). + """ + result_cidrs = {} + + own_nets = sorted(self._cidrs) + other_nets = sorted(other._cidrs) + own_idx = 0 + other_idx = 0 + own_len = len(own_nets) + other_len = len(other_nets) + while own_idx < own_len and other_idx < other_len: + own_cur = own_nets[own_idx] + other_cur = other_nets[other_idx] + + if own_cur == other_cur: + result_cidrs[own_cur] = True + own_idx += 1 + other_idx += 1 + elif own_cur in other_cur: + result_cidrs[own_cur] = True + own_idx += 1 + elif other_cur in own_cur: + result_cidrs[other_cur] = True + other_idx += 1 + else: + # own_cur and other_cur have nothing in common + if own_cur < other_cur: + own_idx += 1 + else: + other_idx += 1 + + # We ran out of networks in own_nets or other_nets. Either way, there + # can be no further result_cidrs. + result = IPSet() + result._cidrs = result_cidrs + return result + + __and__ = intersection + + def symmetric_difference(self, other): + """ + :param other: an IP set. + + :return: the symmetric difference of this IP set and another as a new + IP set (all IP addresses and subnets that are in exactly one + of the sets). + """ + # In contrast to intersection() and difference(), we cannot construct + # the result_cidrs easily. Some cidrs may have to be merged, e.g. for + # IPSet(["10.0.0.0/32"]).symmetric_difference(IPSet(["10.0.0.1/32"])). + result_ranges = [] + + own_nets = sorted(self._cidrs) + other_nets = sorted(other._cidrs) + own_idx = 0 + other_idx = 0 + own_len = len(own_nets) + other_len = len(other_nets) + while own_idx < own_len and other_idx < other_len: + own_cur = own_nets[own_idx] + other_cur = other_nets[other_idx] + + if own_cur == other_cur: + own_idx += 1 + other_idx += 1 + elif own_cur in other_cur: + own_idx = _subtract(other_cur, own_nets, own_idx, result_ranges) + other_idx += 1 + elif other_cur in own_cur: + other_idx = _subtract(own_cur, other_nets, other_idx, result_ranges) + own_idx += 1 + else: + # own_cur and other_cur have nothing in common + if own_cur < other_cur: + result_ranges.append((own_cur._module.version, + own_cur.first, own_cur.last)) + own_idx += 1 + else: + result_ranges.append((other_cur._module.version, + other_cur.first, other_cur.last)) + other_idx += 1 + + # If the above loop terminated because it processed all cidrs of + # "other", then any remaining cidrs in self must be part of the result. + while own_idx < own_len: + own_cur = own_nets[own_idx] + result_ranges.append((own_cur._module.version, + own_cur.first, own_cur.last)) + own_idx += 1 + + # If the above loop terminated because it processed all cidrs of + # self, then any remaining cidrs in "other" must be part of the result. + while other_idx < other_len: + other_cur = other_nets[other_idx] + result_ranges.append((other_cur._module.version, + other_cur.first, other_cur.last)) + other_idx += 1 + + result = IPSet() + for start, stop in _iter_merged_ranges(result_ranges): + cidrs = iprange_to_cidrs(start, stop) + for cidr in cidrs: + result._cidrs[cidr] = True + return result + + __xor__ = symmetric_difference + + def difference(self, other): + """ + :param other: an IP set. + + :return: the difference between this IP set and another as a new IP + set (all IP addresses and subnets that are in this IP set but + not found in the other.) + """ + result_ranges = [] + result_cidrs = {} + + own_nets = sorted(self._cidrs) + other_nets = sorted(other._cidrs) + own_idx = 0 + other_idx = 0 + own_len = len(own_nets) + other_len = len(other_nets) + while own_idx < own_len and other_idx < other_len: + own_cur = own_nets[own_idx] + other_cur = other_nets[other_idx] + + if own_cur == other_cur: + own_idx += 1 + other_idx += 1 + elif own_cur in other_cur: + own_idx += 1 + elif other_cur in own_cur: + other_idx = _subtract(own_cur, other_nets, other_idx, + result_ranges) + own_idx += 1 + else: + # own_cur and other_cur have nothing in common + if own_cur < other_cur: + result_cidrs[own_cur] = True + own_idx += 1 + else: + other_idx += 1 + + # If the above loop terminated because it processed all cidrs of + # "other", then any remaining cidrs in self must be part of the result. + while own_idx < own_len: + result_cidrs[own_nets[own_idx]] = True + own_idx += 1 + + for start, stop in _iter_merged_ranges(result_ranges): + for cidr in iprange_to_cidrs(start, stop): + result_cidrs[cidr] = True + + result = IPSet() + result._cidrs = result_cidrs + return result + + __sub__ = difference + + def __len__(self): + """ + :return: the cardinality of this IP set (i.e. sum of individual IP \ + addresses). Raises ``IndexError`` if size > maxint (a Python \ + limitation). Use the .size property for subnets of any size. + """ + size = self.size + if size > _sys_maxint: + raise IndexError( + "range contains more than %d (sys.maxint) IP addresses!" + "Use the .size property instead." % _sys_maxint) + return size + + @property + def size(self): + """ + The cardinality of this IP set (based on the number of individual IP + addresses including those implicitly defined in subnets). + """ + return sum([cidr.size for cidr in self._cidrs]) + + def __repr__(self): + """:return: Python statement to create an equivalent object""" + return 'IPSet(%r)' % [str(c) for c in sorted(self._cidrs)] + + __str__ = __repr__ + + def iscontiguous(self): + """ + Returns True if the members of the set form a contiguous IP + address range (with no gaps), False otherwise. + + :return: ``True`` if the ``IPSet`` object is contiguous. + """ + cidrs = self.iter_cidrs() + if len(cidrs) > 1: + previous = cidrs[0][0] + for cidr in cidrs: + if cidr[0] != previous: + return False + previous = cidr[-1] + 1 + return True + + def iprange(self): + """ + Generates an IPRange for this IPSet, if all its members + form a single contiguous sequence. + + Raises ``ValueError`` if the set is not contiguous. + + :return: An ``IPRange`` for all IPs in the IPSet. + """ + if self.iscontiguous(): + cidrs = self.iter_cidrs() + if not cidrs: + return None + return IPRange(cidrs[0][0], cidrs[-1][-1]) + else: + raise ValueError("IPSet is not contiguous") + + def iter_ipranges(self): + """Generate the merged IPRanges for this IPSet. + + In contrast to self.iprange(), this will work even when the IPSet is + not contiguous. Adjacent IPRanges will be merged together, so you + get the minimal number of IPRanges. + """ + sorted_ranges = [(cidr._module.version, cidr.first, cidr.last) for + cidr in self.iter_cidrs()] + + for start, stop in _iter_merged_ranges(sorted_ranges): + yield IPRange(start, stop) |