diff options
Diffstat (limited to 'staslib/iputil.py')
-rw-r--r-- | staslib/iputil.py | 106 |
1 files changed, 33 insertions, 73 deletions
diff --git a/staslib/iputil.py b/staslib/iputil.py index 96c5d56..d5e93dd 100644 --- a/staslib/iputil.py +++ b/staslib/iputil.py @@ -131,66 +131,22 @@ def mac2iface(mac: str): # pylint: disable=too-many-locals # ****************************************************************************** -def _data_matches_ip(data_family, data, ip): - if data_family == socket.AF_INET: - try: - other_ip = ipaddress.IPv4Address(data) - except ValueError: - return False - if ip.version == 6: - ip = ip.ipv4_mapped - elif data_family == socket.AF_INET6: - try: - other_ip = ipaddress.IPv6Address(data) - except ValueError: - return False - if ip.version == 4: - other_ip = other_ip.ipv4_mapped - else: - return False - - return other_ip == ip - - -def _iface_of(src_addr): # pylint: disable=too-many-locals - '''@brief Find the interface that has src_addr as one of its assigned IP addresses. - @param src_addr: The IP address to match - @type src_addr: Instance of ipaddress.IPv4Address or ipaddress.IPv6Address +def ip_equal(ip1, ip2): + '''Check whther two IP addresses are equal. + @param ip1: IPv4Address or IPv6Address object + @param ip2: IPv4Address or IPv6Address object ''' - with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW) as sock: - sock.sendall(GETADDRCMD) - nlmsg = sock.recv(8192) - nlmsg_idx = 0 - while True: - if nlmsg_idx >= len(nlmsg): - nlmsg += sock.recv(8192) - - nlmsghdr = nlmsg[nlmsg_idx : nlmsg_idx + NLMSG_HDRLEN] - nlmsg_len, nlmsg_type, _, _, _ = struct.unpack('<LHHLL', nlmsghdr) - - if nlmsg_type == NLMSG_DONE: - break - - if nlmsg_type == RTM_NEWADDR: - msg_indx = nlmsg_idx + NLMSG_HDRLEN - msg = nlmsg[msg_indx : msg_indx + IFADDRMSG_SZ] # ifaddrmsg - ifa_family, _, _, _, ifa_index = struct.unpack('<BBBBL', msg) - - rtattr_indx = msg_indx + IFADDRMSG_SZ - while rtattr_indx < (nlmsg_idx + nlmsg_len): - rtattr = nlmsg[rtattr_indx : rtattr_indx + RTATTR_SZ] - rta_len, rta_type = struct.unpack('<HH', rtattr) - if rta_type == IFLA_ADDRESS: - data = nlmsg[rtattr_indx + RTATTR_SZ : rtattr_indx + rta_len] - if _data_matches_ip(ifa_family, data, src_addr): - return socket.if_indextoname(ifa_index) + if not isinstance(ip1, ipaddress._BaseAddress): # pylint: disable=protected-access + return False + if not isinstance(ip2, ipaddress._BaseAddress): # pylint: disable=protected-access + return False - rta_len = RTA_ALIGN(rta_len) # Round up to multiple of 4 - rtattr_indx += rta_len # Move to next rtattr + if ip1.version == 4 and ip2.version == 6: + ip2 = ip2.ipv4_mapped + elif ip1.version == 6 and ip2.version == 4: + ip1 = ip1.ipv4_mapped - nlmsg_idx += nlmsg_len # Move to next Netlink message - - return '' + return ip1 == ip2 # ****************************************************************************** @@ -221,21 +177,21 @@ def net_if_addrs(): # pylint: disable=too-many-locals source address. @example: { 'wlp0s20f3': { - 4: ['10.0.0.28'], + 4: [IPv4Address('10.0.0.28')], 6: [ - 'fd5e:9a9e:c5bd:0:5509:890c:1848:3843', - 'fd5e:9a9e:c5bd:0:1fd5:e527:8df7:7912', - '2605:59c8:6128:fb00:c083:1b8:c467:81d2', - '2605:59c8:6128:fb00:e99d:1a02:38e0:ad52', - 'fe80::d71b:e807:d5ee:7614' + IPv6Address('fd5e:9a9e:c5bd:0:5509:890c:1848:3843'), + IPv6Address('fd5e:9a9e:c5bd:0:1fd5:e527:8df7:7912'), + IPv6Address('2605:59c8:6128:fb00:c083:1b8:c467:81d2'), + IPv6Address('2605:59c8:6128:fb00:e99d:1a02:38e0:ad52'), + IPv6Address('fe80::d71b:e807:d5ee:7614'), ], }, 'lo': { - 4: ['127.0.0.1'], - 6: ['::1'], + 4: [IPv4Address('127.0.0.1')], + 6: [IPv6Address('::1')], }, 'docker0': { - 4: ['172.17.0.1'], + 4: [IPv4Address('172.17.0.1')], 6: [] }, } @@ -295,14 +251,18 @@ def net_if_addrs(): # pylint: disable=too-many-locals # ****************************************************************************** -def get_interface(src_addr): +def get_interface(ifaces: dict, src_addr): '''Get interface for given source address - @param src_addr: The source address - @type src_addr: str + @param ifaces: Interface info previously returned by @net_if_addrs() + @param src_addr: IPv4Address or IPv6Address object ''' - if not src_addr: + if not isinstance(src_addr, ipaddress._BaseAddress): # pylint: disable=protected-access return '' - src_addr = src_addr.split('%')[0] # remove scope-id (if any) - src_addr = get_ipaddress_obj(src_addr) - return '' if src_addr is None else _iface_of(src_addr) + for iface, addr_map in ifaces.items(): + for addrs in addr_map.values(): + for addr in addrs: + if ip_equal(src_addr, addr): + return iface + + return '' |