summaryrefslogtreecommitdiffstats
path: root/staslib/iputil.py
blob: 9199a491d18711d9e872c27f71f3b5ce43cb1218 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Copyright (c) 2022, Dell Inc. or its subsidiaries.  All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# See the LICENSE file for details.
#
# This file is part of NVMe STorage Appliance Services (nvme-stas).
#
# Authors: Martin Belanger <Martin.Belanger@dell.com>

'''A collection of IP address and network interface utilities'''

import socket
import logging
import ipaddress
from staslib import conf

# Netlink/rtnetlink protocol constants (numeric values as defined by the
# Linux kernel headers <linux/netlink.h> and <linux/rtnetlink.h>).
RTM_NEWADDR = 20  # nlmsg_type of the address messages sent back by the kernel
RTM_GETADDR = 22  # nlmsg_type of the "get address" request
NLM_F_REQUEST = 0x01  # nlmsg_flags: this message is a request
NLM_F_ROOT = 0x100  # nlmsg_flags: return the complete table (dump)
NLMSG_DONE = 3  # nlmsg_type that terminates a multi-part reply
IFLA_ADDRESS = 1  # rta_type carrying the address (same value as IFA_ADDRESS)
NLMSGHDR_SZ = 16  # sizeof(struct nlmsghdr)
IFADDRMSG_SZ = 8  # sizeof(struct ifaddrmsg)
RTATTR_SZ = 4  # sizeof(struct rtattr)

# Netlink request (Get address command)
#
# struct nlmsghdr is: __u32 nlmsg_len, __u16 nlmsg_type, __u16 nlmsg_flags,
# __u32 nlmsg_seq, __u32 nlmsg_pid (16 bytes in total). The sequence number
# and port ID must therefore each occupy 4 bytes so that the header is exactly
# NLMSGHDR_SZ bytes and the ifaddrmsg payload starts at the right offset.
GETADDRCMD = (
    # BEGIN: struct nlmsghdr
    b'\0' * 4  # nlmsg_len (placeholder - actual length calculated below)
    + (RTM_GETADDR).to_bytes(2, byteorder='little', signed=False)  # nlmsg_type
    + (NLM_F_REQUEST | NLM_F_ROOT).to_bytes(2, byteorder='little', signed=False)  # nlmsg_flags
    + b'\0' * 4  # nlmsg_seq
    + b'\0' * 4  # nlmsg_pid
    # END: struct nlmsghdr
    + b'\0' * IFADDRMSG_SZ  # struct ifaddrmsg (all zeroes: no filtering)
)
GETADDRCMD = len(GETADDRCMD).to_bytes(4, byteorder='little') + GETADDRCMD[4:]  # nlmsg_len


# ******************************************************************************
def get_ipaddress_obj(ipaddr):
    '''@brief Convert @ipaddr to an address object.
    @param ipaddr: Candidate IP address (anything ipaddress.ip_address() accepts)
    @return: An ipaddress.IPv4Address or ipaddress.IPv6Address when @ipaddr is
             a valid IPv4 or IPv6 address, None otherwise.
    '''
    try:
        return ipaddress.ip_address(ipaddr)
    except ValueError:
        # Not parsable as either an IPv4 or an IPv6 address
        return None


# ******************************************************************************
def _data_matches_ip(data_family, data, ip):
    '''@brief Tell whether the raw address @data is the same address as @ip.
    @param data_family: Address family of @data (socket.AF_INET or socket.AF_INET6)
    @param data: Raw address, in a form accepted by the ipaddress.IPv4Address /
                 ipaddress.IPv6Address constructors
    @param ip: The address to compare against
    @type ip: Instance of ipaddress.IPv4Address or ipaddress.IPv6Address
    @return: True when both designate the same address, False otherwise
             (including when @data cannot be parsed or the family is unknown).
    '''
    is_v4_data = data_family == socket.AF_INET
    is_v6_data = data_family == socket.AF_INET6

    if is_v4_data:
        address_class = ipaddress.IPv4Address
    elif is_v6_data:
        address_class = ipaddress.IPv6Address
    else:
        return False  # Neither IPv4 nor IPv6: nothing to compare

    try:
        candidate = address_class(data)
    except ValueError:
        return False

    # When the two addresses belong to different families, compare them in
    # the IPv4 domain. ipv4_mapped is None for an IPv6 address that is not an
    # IPv4-mapped address, in which case the final comparison is False.
    if is_v4_data and ip.version == 6:
        ip = ip.ipv4_mapped
    elif is_v6_data and ip.version == 4:
        candidate = candidate.ipv4_mapped

    return candidate == ip


# ******************************************************************************
def iface_of(src_addr):
    '''@brief Find the interface that has src_addr as one of its assigned IP addresses.

    The list of addresses is obtained by sending a "get address" request
    (GETADDRCMD) over a Netlink socket and walking the messages, and the
    attributes within each message, that come back.

    @param src_addr: The IP address to match
    @type src_addr: Instance of ipaddress.IPv4Address or ipaddress.IPv6Address
    @return: The name of the interface owning src_addr, or an empty string
             when no match is found or when the reply cannot be parsed.
             Errors raised by the socket calls are not caught here.
    '''
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW) as sock:
        sock.sendall(GETADDRCMD)
        nlmsg = sock.recv(8192)
        nlmsg_idx = 0
        while True:
            if nlmsg_idx >= len(nlmsg):
                # Everything received so far was consumed: fetch more data
                nlmsg += sock.recv(8192)

            nlmsg_type = int.from_bytes(nlmsg[nlmsg_idx + 4 : nlmsg_idx + 6], byteorder='little', signed=False)
            if nlmsg_type == NLMSG_DONE:
                break

            if nlmsg_type != RTM_NEWADDR:
                break

            nlmsg_len = int.from_bytes(nlmsg[nlmsg_idx : nlmsg_idx + 4], byteorder='little', signed=False)
            if nlmsg_len % 4:  # Is msg length not a multiple of 4?
                break

            if nlmsg_len < NLMSGHDR_SZ:
                # A message cannot be shorter than its own header. Without
                # this check a length of 0 would never advance nlmsg_idx and
                # this loop would spin forever.
                break

            ifaddrmsg_indx = nlmsg_idx + NLMSGHDR_SZ
            ifa_family = nlmsg[ifaddrmsg_indx]
            ifa_index = int.from_bytes(nlmsg[ifaddrmsg_indx + 4 : ifaddrmsg_indx + 8], byteorder='little', signed=False)

            rtattr_indx = ifaddrmsg_indx + IFADDRMSG_SZ
            while rtattr_indx < (nlmsg_idx + nlmsg_len):
                rta_len = int.from_bytes(nlmsg[rtattr_indx : rtattr_indx + 2], byteorder='little', signed=False)
                if rta_len < RTATTR_SZ:
                    # Malformed attribute (shorter than its header). Stop
                    # parsing this message rather than looping forever on an
                    # index that no longer advances.
                    break

                rta_type = int.from_bytes(nlmsg[rtattr_indx + 2 : rtattr_indx + 4], byteorder='little', signed=False)
                if rta_type == IFLA_ADDRESS:
                    data = nlmsg[rtattr_indx + RTATTR_SZ : rtattr_indx + rta_len]
                    if _data_matches_ip(ifa_family, data, src_addr):
                        return socket.if_indextoname(ifa_index)

                rta_len = (rta_len + 3) & ~3  # Round up to multiple of 4
                rtattr_indx += rta_len  # Move to next rtattr

            nlmsg_idx += nlmsg_len  # Move to next Netlink message

    return ''


# ******************************************************************************
def get_interface(src_addr):
    '''@brief Get the name of the interface owning the given source address.
    @param src_addr: The source address (may carry a "%scope-id" suffix)
    @type src_addr: str
    @return: The interface name, or an empty string when @src_addr is empty,
             is not a valid IP address, or is not assigned to any interface.
    '''
    if not src_addr:
        return ''

    # Only the part preceding the first '%' is the address proper; what
    # follows (if anything) is the scope-id.
    bare_addr, _, _ = src_addr.partition('%')
    ip_obj = get_ipaddress_obj(bare_addr)
    if ip_obj is None:
        return ''

    return iface_of(ip_obj)


# ******************************************************************************
def remove_invalid_addresses(controllers: list):
    '''@brief Filter out the controllers whose address cannot be used.

    A 'tcp' or 'rdma' controller is kept only when its traddr is a
    syntactically valid IPv4/IPv6 address whose IP version is listed in the
    service configuration's ip_family. 'fc' and 'loop' controllers are kept
    as is. Controllers with any other transport are dropped.

    @param controllers: List of TIDs
    @return: A new list containing only the controllers that were kept.
    '''
    svc_conf = conf.SvcConf()
    kept = []
    for tid in controllers:
        transport = tid.transport

        if transport in ('fc', 'loop'):
            # At some point, need to validate FC addresses as well...
            kept.append(tid)
            continue

        if transport not in ('tcp', 'rdma'):
            logging.warning('Invalid transport %s', transport)
            continue

        # IP-based transport: traddr must be syntactically
        # a valid IPv4 or IPv6 address.
        addr = get_ipaddress_obj(tid.traddr)
        if addr is None:
            logging.warning('%s IP address is not valid', tid)
            continue

        # The address family must also be enabled in the configuration.
        if addr.version not in svc_conf.ip_family:
            logging.debug(
                '%s ignored because IPv%s is disabled in %s',
                tid,
                addr.version,
                svc_conf.conf_file,
            )
            continue

        kept.append(tid)

    return kept