summaryrefslogtreecommitdiffstats
path: root/staslib/iputil.py
diff options
context:
space:
mode:
Diffstat (limited to 'staslib/iputil.py')
-rw-r--r--staslib/iputil.py169
1 files changed, 169 insertions, 0 deletions
diff --git a/staslib/iputil.py b/staslib/iputil.py
new file mode 100644
index 0000000..9199a49
--- /dev/null
+++ b/staslib/iputil.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2022, Dell Inc. or its subsidiaries. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+# See the LICENSE file for details.
+#
+# This file is part of NVMe STorage Appliance Services (nvme-stas).
+#
+# Authors: Martin Belanger <Martin.Belanger@dell.com>
+
+'''A collection of IP address and network interface utilities'''
+
+import socket
+import logging
+import ipaddress
+from staslib import conf
+
+RTM_NEWADDR = 20
+RTM_GETADDR = 22
+NLM_F_REQUEST = 0x01
+NLM_F_ROOT = 0x100
+NLMSG_DONE = 3
+IFLA_ADDRESS = 1
+NLMSGHDR_SZ = 16
+IFADDRMSG_SZ = 8
+RTATTR_SZ = 4
+
+# Netlink request (Get address command)
+GETADDRCMD = (
+ # BEGIN: struct nlmsghdr
+ b'\0' * 4 # nlmsg_len (placeholder - actual length calculated below)
+ + (RTM_GETADDR).to_bytes(2, byteorder='little', signed=False) # nlmsg_type
+ + (NLM_F_REQUEST | NLM_F_ROOT).to_bytes(2, byteorder='little', signed=False) # nlmsg_flags
+ + b'\0' * 2 # nlmsg_seq
+ + b'\0' * 2 # nlmsg_pid
+ # END: struct nlmsghdr
+ + b'\0' * 8 # struct ifaddrmsg
+)
+GETADDRCMD = len(GETADDRCMD).to_bytes(4, byteorder='little') + GETADDRCMD[4:] # nlmsg_len
+
+
+# ******************************************************************************
+def get_ipaddress_obj(ipaddr):
+ '''@brief Return a IPv4Address or IPv6Address depending on whether @ipaddr
+ is a valid IPv4 or IPv6 address. Return None otherwise.'''
+ try:
+ ip = ipaddress.ip_address(ipaddr)
+ except ValueError:
+ return None
+
+ return ip
+
+
+# ******************************************************************************
+def _data_matches_ip(data_family, data, ip):
+ if data_family == socket.AF_INET:
+ try:
+ other_ip = ipaddress.IPv4Address(data)
+ except ValueError:
+ return False
+ if ip.version == 6:
+ ip = ip.ipv4_mapped
+ elif data_family == socket.AF_INET6:
+ try:
+ other_ip = ipaddress.IPv6Address(data)
+ except ValueError:
+ return False
+ if ip.version == 4:
+ other_ip = other_ip.ipv4_mapped
+ else:
+ return False
+
+ return other_ip == ip
+
+
+# ******************************************************************************
+def iface_of(src_addr):
+ '''@brief Find the interface that has src_addr as one of its assigned IP addresses.
+ @param src_addr: The IP address to match
+ @type src_addr: Instance of ipaddress.IPv4Address or ipaddress.IPv6Address
+ '''
+ with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW) as sock:
+ sock.sendall(GETADDRCMD)
+ nlmsg = sock.recv(8192)
+ nlmsg_idx = 0
+ while True:
+ if nlmsg_idx >= len(nlmsg):
+ nlmsg += sock.recv(8192)
+
+ nlmsg_type = int.from_bytes(nlmsg[nlmsg_idx + 4 : nlmsg_idx + 6], byteorder='little', signed=False)
+ if nlmsg_type == NLMSG_DONE:
+ break
+
+ if nlmsg_type != RTM_NEWADDR:
+ break
+
+ nlmsg_len = int.from_bytes(nlmsg[nlmsg_idx : nlmsg_idx + 4], byteorder='little', signed=False)
+ if nlmsg_len % 4: # Is msg length not a multiple of 4?
+ break
+
+ ifaddrmsg_indx = nlmsg_idx + NLMSGHDR_SZ
+ ifa_family = nlmsg[ifaddrmsg_indx]
+ ifa_index = int.from_bytes(nlmsg[ifaddrmsg_indx + 4 : ifaddrmsg_indx + 8], byteorder='little', signed=False)
+
+ rtattr_indx = ifaddrmsg_indx + IFADDRMSG_SZ
+ while rtattr_indx < (nlmsg_idx + nlmsg_len):
+ rta_len = int.from_bytes(nlmsg[rtattr_indx : rtattr_indx + 2], byteorder='little', signed=False)
+ rta_type = int.from_bytes(nlmsg[rtattr_indx + 2 : rtattr_indx + 4], byteorder='little', signed=False)
+ if rta_type == IFLA_ADDRESS:
+ data = nlmsg[rtattr_indx + RTATTR_SZ : rtattr_indx + rta_len]
+ if _data_matches_ip(ifa_family, data, src_addr):
+ return socket.if_indextoname(ifa_index)
+
+ rta_len = (rta_len + 3) & ~3 # Round up to multiple of 4
+ rtattr_indx += rta_len # Move to next rtattr
+
+ nlmsg_idx += nlmsg_len # Move to next Netlink message
+
+ return ''
+
+
+# ******************************************************************************
+def get_interface(src_addr):
+ '''Get interface for given source address
+ @param src_addr: The source address
+ @type src_addr: str
+ '''
+ if not src_addr:
+ return ''
+
+ src_addr = src_addr.split('%')[0] # remove scope-id (if any)
+ src_addr = get_ipaddress_obj(src_addr)
+ return '' if src_addr is None else iface_of(src_addr)
+
+
+# ******************************************************************************
+def remove_invalid_addresses(controllers: list):
+ '''@brief Remove controllers with invalid addresses from the list of controllers.
+ @param controllers: List of TIDs
+ '''
+ service_conf = conf.SvcConf()
+ valid_controllers = list()
+ for controller in controllers:
+ if controller.transport in ('tcp', 'rdma'):
+ # Let's make sure that traddr is
+ # syntactically a valid IPv4 or IPv6 address.
+ ip = get_ipaddress_obj(controller.traddr)
+ if ip is None:
+ logging.warning('%s IP address is not valid', controller)
+ continue
+
+ # Let's make sure the address family is enabled.
+ if ip.version not in service_conf.ip_family:
+ logging.debug(
+ '%s ignored because IPv%s is disabled in %s',
+ controller,
+ ip.version,
+ service_conf.conf_file,
+ )
+ continue
+
+ valid_controllers.append(controller)
+
+ elif controller.transport in ('fc', 'loop'):
+ # At some point, need to validate FC addresses as well...
+ valid_controllers.append(controller)
+
+ else:
+ logging.warning('Invalid transport %s', controller.transport)
+
+ return valid_controllers