diff options
Diffstat (limited to 'staslib/iputil.py')
-rw-r--r-- | staslib/iputil.py | 169 |
1 files changed, 169 insertions, 0 deletions
diff --git a/staslib/iputil.py b/staslib/iputil.py new file mode 100644 index 0000000..9199a49 --- /dev/null +++ b/staslib/iputil.py @@ -0,0 +1,169 @@ +# Copyright (c) 2022, Dell Inc. or its subsidiaries. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# See the LICENSE file for details. +# +# This file is part of NVMe STorage Appliance Services (nvme-stas). +# +# Authors: Martin Belanger <Martin.Belanger@dell.com> + +'''A collection of IP address and network interface utilities''' + +import socket +import logging +import ipaddress +from staslib import conf + +RTM_NEWADDR = 20 +RTM_GETADDR = 22 +NLM_F_REQUEST = 0x01 +NLM_F_ROOT = 0x100 +NLMSG_DONE = 3 +IFLA_ADDRESS = 1 +NLMSGHDR_SZ = 16 +IFADDRMSG_SZ = 8 +RTATTR_SZ = 4 + +# Netlink request (Get address command) +GETADDRCMD = ( + # BEGIN: struct nlmsghdr + b'\0' * 4 # nlmsg_len (placeholder - actual length calculated below) + + (RTM_GETADDR).to_bytes(2, byteorder='little', signed=False) # nlmsg_type + + (NLM_F_REQUEST | NLM_F_ROOT).to_bytes(2, byteorder='little', signed=False) # nlmsg_flags + + b'\0' * 2 # nlmsg_seq + + b'\0' * 2 # nlmsg_pid + # END: struct nlmsghdr + + b'\0' * 8 # struct ifaddrmsg +) +GETADDRCMD = len(GETADDRCMD).to_bytes(4, byteorder='little') + GETADDRCMD[4:] # nlmsg_len + + +# ****************************************************************************** +def get_ipaddress_obj(ipaddr): + '''@brief Return a IPv4Address or IPv6Address depending on whether @ipaddr + is a valid IPv4 or IPv6 address. Return None otherwise.''' + try: + ip = ipaddress.ip_address(ipaddr) + except ValueError: + return None + + return ip + + +# ****************************************************************************** +def _data_matches_ip(data_family, data, ip): + if data_family == socket.AF_INET: + try: + other_ip = ipaddress.IPv4Address(data) + except ValueError: + return False + if ip.version == 6: + ip = ip.ipv4_mapped + elif data_family == socket.AF_INET6: + try: + other_ip = ipaddress.IPv6Address(data) + except ValueError: + return False + if ip.version == 4: + other_ip = other_ip.ipv4_mapped + else: + return False + + return other_ip == ip + + +# ****************************************************************************** +def iface_of(src_addr): + '''@brief Find the interface that has src_addr as one of its assigned IP addresses. + @param src_addr: The IP address to match + @type src_addr: Instance of ipaddress.IPv4Address or ipaddress.IPv6Address + ''' + with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW) as sock: + sock.sendall(GETADDRCMD) + nlmsg = sock.recv(8192) + nlmsg_idx = 0 + while True: + if nlmsg_idx >= len(nlmsg): + nlmsg += sock.recv(8192) + + nlmsg_type = int.from_bytes(nlmsg[nlmsg_idx + 4 : nlmsg_idx + 6], byteorder='little', signed=False) + if nlmsg_type == NLMSG_DONE: + break + + if nlmsg_type != RTM_NEWADDR: + break + + nlmsg_len = int.from_bytes(nlmsg[nlmsg_idx : nlmsg_idx + 4], byteorder='little', signed=False) + if nlmsg_len % 4: # Is msg length not a multiple of 4? + break + + ifaddrmsg_indx = nlmsg_idx + NLMSGHDR_SZ + ifa_family = nlmsg[ifaddrmsg_indx] + ifa_index = int.from_bytes(nlmsg[ifaddrmsg_indx + 4 : ifaddrmsg_indx + 8], byteorder='little', signed=False) + + rtattr_indx = ifaddrmsg_indx + IFADDRMSG_SZ + while rtattr_indx < (nlmsg_idx + nlmsg_len): + rta_len = int.from_bytes(nlmsg[rtattr_indx : rtattr_indx + 2], byteorder='little', signed=False) + rta_type = int.from_bytes(nlmsg[rtattr_indx + 2 : rtattr_indx + 4], byteorder='little', signed=False) + if rta_type == IFLA_ADDRESS: + data = nlmsg[rtattr_indx + RTATTR_SZ : rtattr_indx + rta_len] + if _data_matches_ip(ifa_family, data, src_addr): + return socket.if_indextoname(ifa_index) + + rta_len = (rta_len + 3) & ~3 # Round up to multiple of 4 + rtattr_indx += rta_len # Move to next rtattr + + nlmsg_idx += nlmsg_len # Move to next Netlink message + + return '' + + +# ****************************************************************************** +def get_interface(src_addr): + '''Get interface for given source address + @param src_addr: The source address + @type src_addr: str + ''' + if not src_addr: + return '' + + src_addr = src_addr.split('%')[0] # remove scope-id (if any) + src_addr = get_ipaddress_obj(src_addr) + return '' if src_addr is None else iface_of(src_addr) + + +# ****************************************************************************** +def remove_invalid_addresses(controllers: list): + '''@brief Remove controllers with invalid addresses from the list of controllers. + @param controllers: List of TIDs + ''' + service_conf = conf.SvcConf() + valid_controllers = list() + for controller in controllers: + if controller.transport in ('tcp', 'rdma'): + # Let's make sure that traddr is + # syntactically a valid IPv4 or IPv6 address. + ip = get_ipaddress_obj(controller.traddr) + if ip is None: + logging.warning('%s IP address is not valid', controller) + continue + + # Let's make sure the address family is enabled. + if ip.version not in service_conf.ip_family: + logging.debug( + '%s ignored because IPv%s is disabled in %s', + controller, + ip.version, + service_conf.conf_file, + ) + continue + + valid_controllers.append(controller) + + elif controller.transport in ('fc', 'loop'): + # At some point, need to validate FC addresses as well... + valid_controllers.append(controller) + + else: + logging.warning('Invalid transport %s', controller.transport) + + return valid_controllers |