diff options
Diffstat (limited to 'tools/net/ynl/lib')
-rw-r--r-- | tools/net/ynl/lib/.gitignore | 1 | ||||
-rw-r--r-- | tools/net/ynl/lib/Makefile | 28 | ||||
-rw-r--r-- | tools/net/ynl/lib/__init__.py | 8 | ||||
-rw-r--r-- | tools/net/ynl/lib/nlspec.py | 545 | ||||
-rw-r--r-- | tools/net/ynl/lib/ynl.c | 901 | ||||
-rw-r--r-- | tools/net/ynl/lib/ynl.h | 237 | ||||
-rw-r--r-- | tools/net/ynl/lib/ynl.py | 729 |
7 files changed, 2449 insertions, 0 deletions
diff --git a/tools/net/ynl/lib/.gitignore b/tools/net/ynl/lib/.gitignore new file mode 100644 index 0000000000..c18dd8d83c --- /dev/null +++ b/tools/net/ynl/lib/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/tools/net/ynl/lib/Makefile b/tools/net/ynl/lib/Makefile new file mode 100644 index 0000000000..d2e50fd0a5 --- /dev/null +++ b/tools/net/ynl/lib/Makefile @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: GPL-2.0 + +CC=gcc +CFLAGS=-std=gnu11 -O2 -W -Wall -Wextra -Wno-unused-parameter -Wshadow +ifeq ("$(DEBUG)","1") + CFLAGS += -g -fsanitize=address -fsanitize=leak -static-libasan +endif + +SRCS=$(wildcard *.c) +OBJS=$(patsubst %.c,%.o,${SRCS}) + +include $(wildcard *.d) + +all: ynl.a + +ynl.a: $(OBJS) + ar rcs $@ $(OBJS) +clean: + rm -f *.o *.d *~ + +hardclean: clean + rm -f *.a + +%.o: %.c + $(COMPILE.c) -MMD -c -o $@ $< + +.PHONY: all clean +.DEFAULT_GOAL=all diff --git a/tools/net/ynl/lib/__init__.py b/tools/net/ynl/lib/__init__.py new file mode 100644 index 0000000000..f7eaa07783 --- /dev/null +++ b/tools/net/ynl/lib/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause + +from .nlspec import SpecAttr, SpecAttrSet, SpecEnumEntry, SpecEnumSet, \ + SpecFamily, SpecOperation +from .ynl import YnlFamily, Netlink + +__all__ = ["SpecAttr", "SpecAttrSet", "SpecEnumEntry", "SpecEnumSet", + "SpecFamily", "SpecOperation", "YnlFamily", "Netlink"] diff --git a/tools/net/ynl/lib/nlspec.py b/tools/net/ynl/lib/nlspec.py new file mode 100644 index 0000000000..37bcb4d8b3 --- /dev/null +++ b/tools/net/ynl/lib/nlspec.py @@ -0,0 +1,545 @@ +# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause + +import collections +import importlib +import os +import yaml + + +# To be loaded dynamically as needed +jsonschema = None + + +class SpecElement: + """Netlink spec element. + + Abstract element of the Netlink spec. Implements the dictionary interface + for access to the raw spec. Supports iterative resolution of dependencies + across elements and class inheritance levels. The elements of the spec + may refer to each other, and although loops should be very rare, having + to maintain correct ordering of instantiation is painful, so the resolve() + method should be used to perform parts of init which require access to + other parts of the spec. + + Attributes: + yaml raw spec as loaded from the spec file + family back reference to the full family + + name name of the entity as listed in the spec (optional) + ident_name name which can be safely used as identifier in code (optional) + """ + def __init__(self, family, yaml): + self.yaml = yaml + self.family = family + + if 'name' in self.yaml: + self.name = self.yaml['name'] + self.ident_name = self.name.replace('-', '_') + + self._super_resolved = False + family.add_unresolved(self) + + def __getitem__(self, key): + return self.yaml[key] + + def __contains__(self, key): + return key in self.yaml + + def get(self, key, default=None): + return self.yaml.get(key, default) + + def resolve_up(self, up): + if not self._super_resolved: + up.resolve() + self._super_resolved = True + + def resolve(self): + pass + + +class SpecEnumEntry(SpecElement): + """ Entry within an enum declared in the Netlink spec. + + Attributes: + doc documentation string + enum_set back reference to the enum + value numerical value of this enum (use accessors in most situations!) + + Methods: + raw_value raw value, i.e. the id in the enum, unlike user value which is a mask for flags + user_value user value, same as raw value for enums, for flags it's the mask + """ + def __init__(self, enum_set, yaml, prev, value_start): + if isinstance(yaml, str): + yaml = {'name': yaml} + super().__init__(enum_set.family, yaml) + + self.doc = yaml.get('doc', '') + self.enum_set = enum_set + + if 'value' in yaml: + self.value = yaml['value'] + elif prev: + self.value = prev.value + 1 + else: + self.value = value_start + + def has_doc(self): + return bool(self.doc) + + def raw_value(self): + return self.value + + def user_value(self, as_flags=None): + if self.enum_set['type'] == 'flags' or as_flags: + return 1 << self.value + else: + return self.value + + +class SpecEnumSet(SpecElement): + """ Enum type + + Represents an enumeration (list of numerical constants) + as declared in the "definitions" section of the spec. + + Attributes: + type enum or flags + entries entries by name + entries_by_val entries by value + Methods: + get_mask for flags compute the mask of all defined values + """ + def __init__(self, family, yaml): + super().__init__(family, yaml) + + self.type = yaml['type'] + + prev_entry = None + value_start = self.yaml.get('value-start', 0) + self.entries = dict() + self.entries_by_val = dict() + for entry in self.yaml['entries']: + e = self.new_entry(entry, prev_entry, value_start) + self.entries[e.name] = e + self.entries_by_val[e.raw_value()] = e + prev_entry = e + + def new_entry(self, entry, prev_entry, value_start): + return SpecEnumEntry(self, entry, prev_entry, value_start) + + def has_doc(self): + if 'doc' in self.yaml: + return True + for entry in self.entries.values(): + if entry.has_doc(): + return True + return False + + def get_mask(self, as_flags=None): + mask = 0 + for e in self.entries.values(): + mask += e.user_value(as_flags) + return mask + + +class SpecAttr(SpecElement): + """ Single Netlink atttribute type + + Represents a single attribute type within an attr space. + + Attributes: + value numerical ID when serialized + attr_set Attribute Set containing this attr + is_multi bool, attr may repeat multiple times + struct_name string, name of struct definition + sub_type string, name of sub type + len integer, optional byte length of binary types + display_hint string, hint to help choose format specifier + when displaying the value + """ + def __init__(self, family, attr_set, yaml, value): + super().__init__(family, yaml) + + self.value = value + self.attr_set = attr_set + self.is_multi = yaml.get('multi-attr', False) + self.struct_name = yaml.get('struct') + self.sub_type = yaml.get('sub-type') + self.byte_order = yaml.get('byte-order') + self.len = yaml.get('len') + self.display_hint = yaml.get('display-hint') + + +class SpecAttrSet(SpecElement): + """ Netlink Attribute Set class. + + Represents a ID space of attributes within Netlink. + + Note that unlike other elements, which expose contents of the raw spec + via the dictionary interface Attribute Set exposes attributes by name. + + Attributes: + attrs ordered dict of all attributes (indexed by name) + attrs_by_val ordered dict of all attributes (indexed by value) + subset_of parent set if this is a subset, otherwise None + """ + def __init__(self, family, yaml): + super().__init__(family, yaml) + + self.subset_of = self.yaml.get('subset-of', None) + + self.attrs = collections.OrderedDict() + self.attrs_by_val = collections.OrderedDict() + + if self.subset_of is None: + val = 1 + for elem in self.yaml['attributes']: + if 'value' in elem: + val = elem['value'] + + attr = self.new_attr(elem, val) + self.attrs[attr.name] = attr + self.attrs_by_val[attr.value] = attr + val += 1 + else: + real_set = family.attr_sets[self.subset_of] + for elem in self.yaml['attributes']: + attr = real_set[elem['name']] + self.attrs[attr.name] = attr + self.attrs_by_val[attr.value] = attr + + def new_attr(self, elem, value): + return SpecAttr(self.family, self, elem, value) + + def __getitem__(self, key): + return self.attrs[key] + + def __contains__(self, key): + return key in self.attrs + + def __iter__(self): + yield from self.attrs + + def items(self): + return self.attrs.items() + + +class SpecStructMember(SpecElement): + """Struct member attribute + + Represents a single struct member attribute. + + Attributes: + type string, type of the member attribute + byte_order string or None for native byte order + enum string, name of the enum definition + len integer, optional byte length of binary types + display_hint string, hint to help choose format specifier + when displaying the value + """ + def __init__(self, family, yaml): + super().__init__(family, yaml) + self.type = yaml['type'] + self.byte_order = yaml.get('byte-order') + self.enum = yaml.get('enum') + self.len = yaml.get('len') + self.display_hint = yaml.get('display-hint') + + +class SpecStruct(SpecElement): + """Netlink struct type + + Represents a C struct definition. + + Attributes: + members ordered list of struct members + """ + def __init__(self, family, yaml): + super().__init__(family, yaml) + + self.members = [] + for member in yaml.get('members', []): + self.members.append(self.new_member(family, member)) + + def new_member(self, family, elem): + return SpecStructMember(family, elem) + + def __iter__(self): + yield from self.members + + def items(self): + return self.members.items() + + +class SpecOperation(SpecElement): + """Netlink Operation + + Information about a single Netlink operation. + + Attributes: + value numerical ID when serialized, None if req/rsp values differ + + req_value numerical ID when serialized, user -> kernel + rsp_value numerical ID when serialized, user <- kernel + is_call bool, whether the operation is a call + is_async bool, whether the operation is a notification + is_resv bool, whether the operation does not exist (it's just a reserved ID) + attr_set attribute set name + fixed_header string, optional name of fixed header struct + + yaml raw spec as loaded from the spec file + """ + def __init__(self, family, yaml, req_value, rsp_value): + super().__init__(family, yaml) + + self.value = req_value if req_value == rsp_value else None + self.req_value = req_value + self.rsp_value = rsp_value + + self.is_call = 'do' in yaml or 'dump' in yaml + self.is_async = 'notify' in yaml or 'event' in yaml + self.is_resv = not self.is_async and not self.is_call + self.fixed_header = self.yaml.get('fixed-header', family.fixed_header) + + # Added by resolve: + self.attr_set = None + delattr(self, "attr_set") + + def resolve(self): + self.resolve_up(super()) + + if 'attribute-set' in self.yaml: + attr_set_name = self.yaml['attribute-set'] + elif 'notify' in self.yaml: + msg = self.family.msgs[self.yaml['notify']] + attr_set_name = msg['attribute-set'] + elif self.is_resv: + attr_set_name = '' + else: + raise Exception(f"Can't resolve attribute set for op '{self.name}'") + if attr_set_name: + self.attr_set = self.family.attr_sets[attr_set_name] + + +class SpecMcastGroup(SpecElement): + """Netlink Multicast Group + + Information about a multicast group. + + Value is only used for classic netlink families that use the + netlink-raw schema. Genetlink families use dynamic ID allocation + where the ids of multicast groups get resolved at runtime. Value + will be None for genetlink families. + + Attributes: + name name of the mulitcast group + value integer id of this multicast group for netlink-raw or None + yaml raw spec as loaded from the spec file + """ + def __init__(self, family, yaml): + super().__init__(family, yaml) + self.value = self.yaml.get('value') + + +class SpecFamily(SpecElement): + """ Netlink Family Spec class. + + Netlink family information loaded from a spec (e.g. in YAML). + Takes care of unfolding implicit information which can be skipped + in the spec itself for brevity. + + The class can be used like a dictionary to access the raw spec + elements but that's usually a bad idea. + + Attributes: + proto protocol type (e.g. genetlink) + msg_id_model enum-model for operations (unified, directional etc.) + license spec license (loaded from an SPDX tag on the spec) + + attr_sets dict of attribute sets + msgs dict of all messages (index by name) + ops dict of all valid requests / responses + ntfs dict of all async events + consts dict of all constants/enums + fixed_header string, optional name of family default fixed header struct + mcast_groups dict of all multicast groups (index by name) + """ + def __init__(self, spec_path, schema_path=None, exclude_ops=None): + with open(spec_path, "r") as stream: + prefix = '# SPDX-License-Identifier: ' + first = stream.readline().strip() + if not first.startswith(prefix): + raise Exception('SPDX license tag required in the spec') + self.license = first[len(prefix):] + + stream.seek(0) + spec = yaml.safe_load(stream) + + self._resolution_list = [] + + super().__init__(self, spec) + + self._exclude_ops = exclude_ops if exclude_ops else [] + + self.proto = self.yaml.get('protocol', 'genetlink') + self.msg_id_model = self.yaml['operations'].get('enum-model', 'unified') + + if schema_path is None: + schema_path = os.path.dirname(os.path.dirname(spec_path)) + f'/{self.proto}.yaml' + if schema_path: + global jsonschema + + with open(schema_path, "r") as stream: + schema = yaml.safe_load(stream) + + if jsonschema is None: + jsonschema = importlib.import_module("jsonschema") + + jsonschema.validate(self.yaml, schema) + + self.attr_sets = collections.OrderedDict() + self.msgs = collections.OrderedDict() + self.req_by_value = collections.OrderedDict() + self.rsp_by_value = collections.OrderedDict() + self.ops = collections.OrderedDict() + self.ntfs = collections.OrderedDict() + self.consts = collections.OrderedDict() + self.mcast_groups = collections.OrderedDict() + + last_exception = None + while len(self._resolution_list) > 0: + resolved = [] + unresolved = self._resolution_list + self._resolution_list = [] + + for elem in unresolved: + try: + elem.resolve() + except (KeyError, AttributeError) as e: + self._resolution_list.append(elem) + last_exception = e + continue + + resolved.append(elem) + + if len(resolved) == 0: + raise last_exception + + def new_enum(self, elem): + return SpecEnumSet(self, elem) + + def new_attr_set(self, elem): + return SpecAttrSet(self, elem) + + def new_struct(self, elem): + return SpecStruct(self, elem) + + def new_operation(self, elem, req_val, rsp_val): + return SpecOperation(self, elem, req_val, rsp_val) + + def new_mcast_group(self, elem): + return SpecMcastGroup(self, elem) + + def add_unresolved(self, elem): + self._resolution_list.append(elem) + + def _dictify_ops_unified(self): + self.fixed_header = self.yaml['operations'].get('fixed-header') + val = 1 + for elem in self.yaml['operations']['list']: + if 'value' in elem: + val = elem['value'] + + op = self.new_operation(elem, val, val) + val += 1 + + self.msgs[op.name] = op + + def _dictify_ops_directional(self): + self.fixed_header = self.yaml['operations'].get('fixed-header') + req_val = rsp_val = 1 + for elem in self.yaml['operations']['list']: + if 'notify' in elem or 'event' in elem: + if 'value' in elem: + rsp_val = elem['value'] + req_val_next = req_val + rsp_val_next = rsp_val + 1 + req_val = None + elif 'do' in elem or 'dump' in elem: + mode = elem['do'] if 'do' in elem else elem['dump'] + + v = mode.get('request', {}).get('value', None) + if v: + req_val = v + v = mode.get('reply', {}).get('value', None) + if v: + rsp_val = v + + rsp_inc = 1 if 'reply' in mode else 0 + req_val_next = req_val + 1 + rsp_val_next = rsp_val + rsp_inc + else: + raise Exception("Can't parse directional ops") + + if req_val == req_val_next: + req_val = None + if rsp_val == rsp_val_next: + rsp_val = None + + skip = False + for exclude in self._exclude_ops: + skip |= bool(exclude.match(elem['name'])) + if not skip: + op = self.new_operation(elem, req_val, rsp_val) + + req_val = req_val_next + rsp_val = rsp_val_next + + self.msgs[op.name] = op + + def find_operation(self, name): + """ + For a given operation name, find and return operation spec. + """ + for op in self.yaml['operations']['list']: + if name == op['name']: + return op + return None + + def resolve(self): + self.resolve_up(super()) + + definitions = self.yaml.get('definitions', []) + for elem in definitions: + if elem['type'] == 'enum' or elem['type'] == 'flags': + self.consts[elem['name']] = self.new_enum(elem) + elif elem['type'] == 'struct': + self.consts[elem['name']] = self.new_struct(elem) + else: + self.consts[elem['name']] = elem + + for elem in self.yaml['attribute-sets']: + attr_set = self.new_attr_set(elem) + self.attr_sets[elem['name']] = attr_set + + if self.msg_id_model == 'unified': + self._dictify_ops_unified() + elif self.msg_id_model == 'directional': + self._dictify_ops_directional() + + for op in self.msgs.values(): + if op.req_value is not None: + self.req_by_value[op.req_value] = op + if op.rsp_value is not None: + self.rsp_by_value[op.rsp_value] = op + if not op.is_async and 'attribute-set' in op: + self.ops[op.name] = op + elif op.is_async: + self.ntfs[op.name] = op + + mcgs = self.yaml.get('mcast-groups') + if mcgs: + for elem in mcgs['list']: + mcg = self.new_mcast_group(elem) + self.mcast_groups[elem['name']] = mcg diff --git a/tools/net/ynl/lib/ynl.c b/tools/net/ynl/lib/ynl.c new file mode 100644 index 0000000000..514e0d69e7 --- /dev/null +++ b/tools/net/ynl/lib/ynl.c @@ -0,0 +1,901 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +#include <errno.h> +#include <poll.h> +#include <string.h> +#include <stdlib.h> +#include <linux/types.h> + +#include <libmnl/libmnl.h> +#include <linux/genetlink.h> + +#include "ynl.h" + +#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(*arr)) + +#define __yerr_msg(yse, _msg...) \ + ({ \ + struct ynl_error *_yse = (yse); \ + \ + if (_yse) { \ + snprintf(_yse->msg, sizeof(_yse->msg) - 1, _msg); \ + _yse->msg[sizeof(_yse->msg) - 1] = 0; \ + } \ + }) + +#define __yerr_code(yse, _code...) \ + ({ \ + struct ynl_error *_yse = (yse); \ + \ + if (_yse) { \ + _yse->code = _code; \ + } \ + }) + +#define __yerr(yse, _code, _msg...) \ + ({ \ + __yerr_msg(yse, _msg); \ + __yerr_code(yse, _code); \ + }) + +#define __perr(yse, _msg) __yerr(yse, errno, _msg) + +#define yerr_msg(_ys, _msg...) __yerr_msg(&(_ys)->err, _msg) +#define yerr(_ys, _code, _msg...) __yerr(&(_ys)->err, _code, _msg) +#define perr(_ys, _msg) __yerr(&(_ys)->err, errno, _msg) + +/* -- Netlink boiler plate */ +static int +ynl_err_walk_report_one(struct ynl_policy_nest *policy, unsigned int type, + char *str, int str_sz, int *n) +{ + if (!policy) { + if (*n < str_sz) + *n += snprintf(str, str_sz, "!policy"); + return 1; + } + + if (type > policy->max_attr) { + if (*n < str_sz) + *n += snprintf(str, str_sz, "!oob"); + return 1; + } + + if (!policy->table[type].name) { + if (*n < str_sz) + *n += snprintf(str, str_sz, "!name"); + return 1; + } + + if (*n < str_sz) + *n += snprintf(str, str_sz - *n, + ".%s", policy->table[type].name); + return 0; +} + +static int +ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off, + struct ynl_policy_nest *policy, char *str, int str_sz, + struct ynl_policy_nest **nest_pol) +{ + unsigned int astart_off, aend_off; + const struct nlattr *attr; + unsigned int data_len; + unsigned int type; + bool found = false; + int n = 0; + + if (!policy) { + if (n < str_sz) + n += snprintf(str, str_sz, "!policy"); + return n; + } + + data_len = end - start; + + mnl_attr_for_each_payload(start, data_len) { + astart_off = (char *)attr - (char *)start; + aend_off = astart_off + mnl_attr_get_payload_len(attr); + if (aend_off <= off) + continue; + + found = true; + break; + } + if (!found) + return 0; + + off -= astart_off; + + type = mnl_attr_get_type(attr); + + if (ynl_err_walk_report_one(policy, type, str, str_sz, &n)) + return n; + + if (!off) { + if (nest_pol) + *nest_pol = policy->table[type].nest; + return n; + } + + if (!policy->table[type].nest) { + if (n < str_sz) + n += snprintf(str, str_sz, "!nest"); + return n; + } + + off -= sizeof(struct nlattr); + start = mnl_attr_get_payload(attr); + end = start + mnl_attr_get_payload_len(attr); + + return n + ynl_err_walk(ys, start, end, off, policy->table[type].nest, + &str[n], str_sz - n, nest_pol); +} + +#define NLMSGERR_ATTR_MISS_TYPE (NLMSGERR_ATTR_POLICY + 1) +#define NLMSGERR_ATTR_MISS_NEST (NLMSGERR_ATTR_POLICY + 2) +#define NLMSGERR_ATTR_MAX (NLMSGERR_ATTR_MAX + 2) + +static int +ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh, + unsigned int hlen) +{ + const struct nlattr *tb[NLMSGERR_ATTR_MAX + 1] = {}; + char miss_attr[sizeof(ys->err.msg)]; + char bad_attr[sizeof(ys->err.msg)]; + const struct nlattr *attr; + const char *str = NULL; + + if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS)) + return MNL_CB_OK; + + mnl_attr_for_each(attr, nlh, hlen) { + unsigned int len, type; + + len = mnl_attr_get_payload_len(attr); + type = mnl_attr_get_type(attr); + + if (type > NLMSGERR_ATTR_MAX) + continue; + + tb[type] = attr; + + switch (type) { + case NLMSGERR_ATTR_OFFS: + case NLMSGERR_ATTR_MISS_TYPE: + case NLMSGERR_ATTR_MISS_NEST: + if (len != sizeof(__u32)) + return MNL_CB_ERROR; + break; + case NLMSGERR_ATTR_MSG: + str = mnl_attr_get_payload(attr); + if (str[len - 1]) + return MNL_CB_ERROR; + break; + default: + break; + } + } + + bad_attr[0] = '\0'; + miss_attr[0] = '\0'; + + if (tb[NLMSGERR_ATTR_OFFS]) { + unsigned int n, off; + void *start, *end; + + ys->err.attr_offs = mnl_attr_get_u32(tb[NLMSGERR_ATTR_OFFS]); + + n = snprintf(bad_attr, sizeof(bad_attr), "%sbad attribute: ", + str ? " (" : ""); + + start = mnl_nlmsg_get_payload_offset(ys->nlh, + sizeof(struct genlmsghdr)); + end = mnl_nlmsg_get_payload_tail(ys->nlh); + + off = ys->err.attr_offs; + off -= sizeof(struct nlmsghdr); + off -= sizeof(struct genlmsghdr); + + n += ynl_err_walk(ys, start, end, off, ys->req_policy, + &bad_attr[n], sizeof(bad_attr) - n, NULL); + + if (n >= sizeof(bad_attr)) + n = sizeof(bad_attr) - 1; + bad_attr[n] = '\0'; + } + if (tb[NLMSGERR_ATTR_MISS_TYPE]) { + struct ynl_policy_nest *nest_pol = NULL; + unsigned int n, off, type; + void *start, *end; + int n2; + + type = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_TYPE]); + + n = snprintf(miss_attr, sizeof(miss_attr), "%smissing attribute: ", + bad_attr[0] ? ", " : (str ? " (" : "")); + + start = mnl_nlmsg_get_payload_offset(ys->nlh, + sizeof(struct genlmsghdr)); + end = mnl_nlmsg_get_payload_tail(ys->nlh); + + nest_pol = ys->req_policy; + if (tb[NLMSGERR_ATTR_MISS_NEST]) { + off = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_NEST]); + off -= sizeof(struct nlmsghdr); + off -= sizeof(struct genlmsghdr); + + n += ynl_err_walk(ys, start, end, off, ys->req_policy, + &miss_attr[n], sizeof(miss_attr) - n, + &nest_pol); + } + + n2 = 0; + ynl_err_walk_report_one(nest_pol, type, &miss_attr[n], + sizeof(miss_attr) - n, &n2); + n += n2; + + if (n >= sizeof(miss_attr)) + n = sizeof(miss_attr) - 1; + miss_attr[n] = '\0'; + } + + /* Implicitly depend on ys->err.code already set */ + if (str) + yerr_msg(ys, "Kernel %s: '%s'%s%s%s", + ys->err.code ? "error" : "warning", + str, bad_attr, miss_attr, + bad_attr[0] || miss_attr[0] ? ")" : ""); + else if (bad_attr[0] || miss_attr[0]) + yerr_msg(ys, "Kernel %s: %s%s", + ys->err.code ? "error" : "warning", + bad_attr, miss_attr); + + return MNL_CB_OK; +} + +static int ynl_cb_error(const struct nlmsghdr *nlh, void *data) +{ + const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh); + struct ynl_parse_arg *yarg = data; + unsigned int hlen; + int code; + + code = err->error >= 0 ? err->error : -err->error; + yarg->ys->err.code = code; + errno = code; + + hlen = sizeof(*err); + if (!(nlh->nlmsg_flags & NLM_F_CAPPED)) + hlen += mnl_nlmsg_get_payload_len(&err->msg); + + ynl_ext_ack_check(yarg->ys, nlh, hlen); + + return code ? MNL_CB_ERROR : MNL_CB_STOP; +} + +static int ynl_cb_done(const struct nlmsghdr *nlh, void *data) +{ + struct ynl_parse_arg *yarg = data; + int err; + + err = *(int *)NLMSG_DATA(nlh); + if (err < 0) { + yarg->ys->err.code = -err; + errno = -err; + + ynl_ext_ack_check(yarg->ys, nlh, sizeof(int)); + + return MNL_CB_ERROR; + } + return MNL_CB_STOP; +} + +static int ynl_cb_noop(const struct nlmsghdr *nlh, void *data) +{ + return MNL_CB_OK; +} + +mnl_cb_t ynl_cb_array[NLMSG_MIN_TYPE] = { + [NLMSG_NOOP] = ynl_cb_noop, + [NLMSG_ERROR] = ynl_cb_error, + [NLMSG_DONE] = ynl_cb_done, + [NLMSG_OVERRUN] = ynl_cb_noop, +}; + +/* Attribute validation */ + +int ynl_attr_validate(struct ynl_parse_arg *yarg, const struct nlattr *attr) +{ + struct ynl_policy_attr *policy; + unsigned int type, len; + unsigned char *data; + + data = mnl_attr_get_payload(attr); + len = mnl_attr_get_payload_len(attr); + type = mnl_attr_get_type(attr); + if (type > yarg->rsp_policy->max_attr) { + yerr(yarg->ys, YNL_ERROR_INTERNAL, + "Internal error, validating unknown attribute"); + return -1; + } + + policy = &yarg->rsp_policy->table[type]; + + switch (policy->type) { + case YNL_PT_REJECT: + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Rejected attribute (%s)", policy->name); + return -1; + case YNL_PT_IGNORE: + break; + case YNL_PT_U8: + if (len == sizeof(__u8)) + break; + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Invalid attribute (u8 %s)", policy->name); + return -1; + case YNL_PT_U16: + if (len == sizeof(__u16)) + break; + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Invalid attribute (u16 %s)", policy->name); + return -1; + case YNL_PT_U32: + if (len == sizeof(__u32)) + break; + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Invalid attribute (u32 %s)", policy->name); + return -1; + case YNL_PT_U64: + if (len == sizeof(__u64)) + break; + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Invalid attribute (u64 %s)", policy->name); + return -1; + case YNL_PT_FLAG: + /* Let flags grow into real attrs, why not.. */ + break; + case YNL_PT_NEST: + if (!len || len >= sizeof(*attr)) + break; + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Invalid attribute (nest %s)", policy->name); + return -1; + case YNL_PT_BINARY: + if (!policy->len || len == policy->len) + break; + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Invalid attribute (binary %s)", policy->name); + return -1; + case YNL_PT_NUL_STR: + if ((!policy->len || len <= policy->len) && !data[len - 1]) + break; + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Invalid attribute (string %s)", policy->name); + return -1; + default: + yerr(yarg->ys, YNL_ERROR_ATTR_INVALID, + "Invalid attribute (unknown %s)", policy->name); + return -1; + } + + return 0; +} + +/* Generic code */ + +static void ynl_err_reset(struct ynl_sock *ys) +{ + ys->err.code = 0; + ys->err.attr_offs = 0; + ys->err.msg[0] = 0; +} + +struct nlmsghdr *ynl_msg_start(struct ynl_sock *ys, __u32 id, __u16 flags) +{ + struct nlmsghdr *nlh; + + ynl_err_reset(ys); + + nlh = ys->nlh = mnl_nlmsg_put_header(ys->tx_buf); + nlh->nlmsg_type = id; + nlh->nlmsg_flags = flags; + nlh->nlmsg_seq = ++ys->seq; + + return nlh; +} + +struct nlmsghdr * +ynl_gemsg_start(struct ynl_sock *ys, __u32 id, __u16 flags, + __u8 cmd, __u8 version) +{ + struct genlmsghdr gehdr; + struct nlmsghdr *nlh; + void *data; + + nlh = ynl_msg_start(ys, id, flags); + + memset(&gehdr, 0, sizeof(gehdr)); + gehdr.cmd = cmd; + gehdr.version = version; + + data = mnl_nlmsg_put_extra_header(nlh, sizeof(gehdr)); + memcpy(data, &gehdr, sizeof(gehdr)); + + return nlh; +} + +void ynl_msg_start_req(struct ynl_sock *ys, __u32 id) +{ + ynl_msg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK); +} + +void ynl_msg_start_dump(struct ynl_sock *ys, __u32 id) +{ + ynl_msg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP); +} + +struct nlmsghdr * +ynl_gemsg_start_req(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version) +{ + return ynl_gemsg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK, cmd, version); +} + +struct nlmsghdr * +ynl_gemsg_start_dump(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version) +{ + return ynl_gemsg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP, + cmd, version); +} + +int ynl_recv_ack(struct ynl_sock *ys, int ret) +{ + if (!ret) { + yerr(ys, YNL_ERROR_EXPECT_ACK, + "Expecting an ACK but nothing received"); + return -1; + } + + ret = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE); + if (ret < 0) { + perr(ys, "Socket receive failed"); + return ret; + } + return mnl_cb_run(ys->rx_buf, ret, ys->seq, ys->portid, + ynl_cb_null, ys); +} + +int ynl_cb_null(const struct nlmsghdr *nlh, void *data) +{ + struct ynl_parse_arg *yarg = data; + + yerr(yarg->ys, YNL_ERROR_UNEXPECT_MSG, + "Received a message when none were expected"); + + return MNL_CB_ERROR; +} + +/* Init/fini and genetlink boiler plate */ +static int +ynl_get_family_info_mcast(struct ynl_sock *ys, const struct nlattr *mcasts) +{ + const struct nlattr *entry, *attr; + unsigned int i; + + mnl_attr_for_each_nested(attr, mcasts) + ys->n_mcast_groups++; + + if (!ys->n_mcast_groups) + return 0; + + ys->mcast_groups = calloc(ys->n_mcast_groups, + sizeof(*ys->mcast_groups)); + if (!ys->mcast_groups) + return MNL_CB_ERROR; + + i = 0; + mnl_attr_for_each_nested(entry, mcasts) { + mnl_attr_for_each_nested(attr, entry) { + if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_ID) + ys->mcast_groups[i].id = mnl_attr_get_u32(attr); + if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_NAME) { + strncpy(ys->mcast_groups[i].name, + mnl_attr_get_str(attr), + GENL_NAMSIZ - 1); + ys->mcast_groups[i].name[GENL_NAMSIZ - 1] = 0; + } + } + } + + return 0; +} + +static int ynl_get_family_info_cb(const struct nlmsghdr *nlh, void *data) +{ + struct ynl_parse_arg *yarg = data; + struct ynl_sock *ys = yarg->ys; + const struct nlattr *attr; + bool found_id = true; + + mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr)) { + if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GROUPS) + if (ynl_get_family_info_mcast(ys, attr)) + return MNL_CB_ERROR; + + if (mnl_attr_get_type(attr) != CTRL_ATTR_FAMILY_ID) + continue; + + if (mnl_attr_get_payload_len(attr) != sizeof(__u16)) { + yerr(ys, YNL_ERROR_ATTR_INVALID, "Invalid family ID"); + return MNL_CB_ERROR; + } + + ys->family_id = mnl_attr_get_u16(attr); + found_id = true; + } + + if (!found_id) { + yerr(ys, YNL_ERROR_ATTR_MISSING, "Family ID missing"); + return MNL_CB_ERROR; + } + return MNL_CB_OK; +} + +static int ynl_sock_read_family(struct ynl_sock *ys, const char *family_name) +{ + struct ynl_parse_arg yarg = { .ys = ys, }; + struct nlmsghdr *nlh; + int err; + + nlh = ynl_gemsg_start_req(ys, GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 1); + mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name); + + err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len); + if (err < 0) { + perr(ys, "failed to request socket family info"); + return err; + } + + err = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE); + if (err <= 0) { + perr(ys, "failed to receive the socket family info"); + return err; + } + err = mnl_cb_run2(ys->rx_buf, err, ys->seq, ys->portid, + ynl_get_family_info_cb, &yarg, + ynl_cb_array, ARRAY_SIZE(ynl_cb_array)); + if (err < 0) { + free(ys->mcast_groups); + perr(ys, "failed to receive the socket family info - no such family?"); + return err; + } + + return ynl_recv_ack(ys, err); +} + +struct ynl_sock * +ynl_sock_create(const struct ynl_family *yf, struct ynl_error *yse) +{ + struct ynl_sock *ys; + int one = 1; + + ys = malloc(sizeof(*ys) + 2 * MNL_SOCKET_BUFFER_SIZE); + if (!ys) + return NULL; + memset(ys, 0, sizeof(*ys)); + + ys->family = yf; + ys->tx_buf = &ys->raw_buf[0]; + ys->rx_buf = &ys->raw_buf[MNL_SOCKET_BUFFER_SIZE]; + ys->ntf_last_next = &ys->ntf_first; + + ys->sock = mnl_socket_open(NETLINK_GENERIC); + if (!ys->sock) { + __perr(yse, "failed to create a netlink socket"); + goto err_free_sock; + } + + if (mnl_socket_setsockopt(ys->sock, NETLINK_CAP_ACK, + &one, sizeof(one))) { + __perr(yse, "failed to enable netlink ACK"); + goto err_close_sock; + } + if (mnl_socket_setsockopt(ys->sock, NETLINK_EXT_ACK, + &one, sizeof(one))) { + __perr(yse, "failed to enable netlink ext ACK"); + goto err_close_sock; + } + + ys->seq = random(); + ys->portid = mnl_socket_get_portid(ys->sock); + + if (ynl_sock_read_family(ys, yf->name)) { + if (yse) + memcpy(yse, &ys->err, sizeof(*yse)); + goto err_close_sock; + } + + return ys; + +err_close_sock: + mnl_socket_close(ys->sock); +err_free_sock: + free(ys); + return NULL; +} + +void ynl_sock_destroy(struct ynl_sock *ys) +{ + struct ynl_ntf_base_type *ntf; + + mnl_socket_close(ys->sock); + while ((ntf = ynl_ntf_dequeue(ys))) + ynl_ntf_free(ntf); + free(ys->mcast_groups); + free(ys); +} + +/* YNL multicast handling */ + +void ynl_ntf_free(struct ynl_ntf_base_type *ntf) +{ + ntf->free(ntf); +} + +int ynl_subscribe(struct ynl_sock *ys, const char *grp_name) +{ + unsigned int i; + int err; + + for (i = 0; i < ys->n_mcast_groups; i++) + if (!strcmp(ys->mcast_groups[i].name, grp_name)) + break; + if (i == ys->n_mcast_groups) { + yerr(ys, ENOENT, "Multicast group '%s' not found", grp_name); + return -1; + } + + err = mnl_socket_setsockopt(ys->sock, NETLINK_ADD_MEMBERSHIP, + &ys->mcast_groups[i].id, + sizeof(ys->mcast_groups[i].id)); + if (err < 0) { + perr(ys, "Subscribing to multicast group failed"); + return -1; + } + + return 0; +} + +int ynl_socket_get_fd(struct ynl_sock *ys) +{ + return mnl_socket_get_fd(ys->sock); +} + +struct ynl_ntf_base_type *ynl_ntf_dequeue(struct ynl_sock *ys) +{ + struct ynl_ntf_base_type *ntf; + + if (!ynl_has_ntf(ys)) + return NULL; + + ntf = ys->ntf_first; + ys->ntf_first = ntf->next; + if (ys->ntf_last_next == &ntf->next) + ys->ntf_last_next = &ys->ntf_first; + + return ntf; +} + +static int ynl_ntf_parse(struct ynl_sock *ys, const struct nlmsghdr *nlh) +{ + struct ynl_parse_arg yarg = { .ys = ys, }; + const struct ynl_ntf_info *info; + struct ynl_ntf_base_type *rsp; + struct genlmsghdr *gehdr; + int ret; + + gehdr = mnl_nlmsg_get_payload(nlh); + if (gehdr->cmd >= ys->family->ntf_info_size) + return MNL_CB_ERROR; + info = &ys->family->ntf_info[gehdr->cmd]; + if (!info->cb) + return MNL_CB_ERROR; + + rsp = calloc(1, info->alloc_sz); + rsp->free = info->free; + yarg.data = rsp->data; + yarg.rsp_policy = info->policy; + + ret = info->cb(nlh, &yarg); + if (ret <= MNL_CB_STOP) + goto err_free; + + rsp->family = nlh->nlmsg_type; + rsp->cmd = gehdr->cmd; + + *ys->ntf_last_next = rsp; + ys->ntf_last_next = &rsp->next; + + return MNL_CB_OK; + +err_free: + info->free(rsp); + return MNL_CB_ERROR; +} + +static int ynl_ntf_trampoline(const struct nlmsghdr *nlh, void *data) +{ + return ynl_ntf_parse((struct ynl_sock *)data, nlh); +} + +int ynl_ntf_check(struct ynl_sock *ys) +{ + ssize_t len; + int err; + + do { + /* libmnl doesn't let us pass flags to the recv to make + * it non-blocking so we need to poll() or peek() :| + */ + struct pollfd pfd = { }; + + pfd.fd = mnl_socket_get_fd(ys->sock); + pfd.events = POLLIN; + err = poll(&pfd, 1, 1); + if (err < 1) + return err; + + len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, + MNL_SOCKET_BUFFER_SIZE); + if (len < 0) + return len; + + err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid, + ynl_ntf_trampoline, ys, + ynl_cb_array, NLMSG_MIN_TYPE); + if (err < 0) + return err; + } while (err > 0); + + return 0; +} + +/* YNL specific helpers used by the auto-generated code */ + +struct ynl_dump_list_type *YNL_LIST_END = (void *)(0xb4d123); + +void ynl_error_unknown_notification(struct ynl_sock *ys, __u8 cmd) +{ + yerr(ys, YNL_ERROR_UNKNOWN_NTF, + "Unknown notification message type '%d'", cmd); +} + +int ynl_error_parse(struct ynl_parse_arg *yarg, const char *msg) +{ + yerr(yarg->ys, YNL_ERROR_INV_RESP, "Error parsing response: %s", msg); + return MNL_CB_ERROR; +} + +static int +ynl_check_alien(struct ynl_sock *ys, const struct nlmsghdr *nlh, __u32 rsp_cmd) +{ + struct genlmsghdr *gehdr; + + if (mnl_nlmsg_get_payload_len(nlh) < sizeof(*gehdr)) { + yerr(ys, YNL_ERROR_INV_RESP, + "Kernel responded with truncated message"); + return -1; + } + + gehdr = mnl_nlmsg_get_payload(nlh); + if (gehdr->cmd != rsp_cmd) + return ynl_ntf_parse(ys, nlh); + + return 0; +} + +static int ynl_req_trampoline(const struct nlmsghdr *nlh, void *data) +{ + struct ynl_req_state *yrs = data; + int ret; + + ret = ynl_check_alien(yrs->yarg.ys, nlh, yrs->rsp_cmd); + if (ret) + return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK; + + return yrs->cb(nlh, &yrs->yarg); +} + +int ynl_exec(struct ynl_sock *ys, struct nlmsghdr *req_nlh, + struct ynl_req_state *yrs) +{ + ssize_t len; + int err; + + err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len); + if (err < 0) + return err; + + do { + len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, + MNL_SOCKET_BUFFER_SIZE); + if (len < 0) + return len; + + err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid, + ynl_req_trampoline, yrs, + ynl_cb_array, NLMSG_MIN_TYPE); + if (err < 0) + return err; + } while (err > 0); + + return 0; +} + +static int ynl_dump_trampoline(const struct nlmsghdr *nlh, void *data) +{ + struct ynl_dump_state *ds = data; + struct ynl_dump_list_type *obj; + struct ynl_parse_arg yarg = {}; + int ret; + + ret = ynl_check_alien(ds->ys, nlh, ds->rsp_cmd); + if (ret) + return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK; + + obj = calloc(1, ds->alloc_sz); + if (!obj) + return MNL_CB_ERROR; + + if (!ds->first) + ds->first = obj; + if (ds->last) + ds->last->next = obj; + ds->last = obj; + + yarg.ys = ds->ys; + yarg.rsp_policy = ds->rsp_policy; + yarg.data = &obj->data; + + return ds->cb(nlh, &yarg); +} + +static void *ynl_dump_end(struct ynl_dump_state *ds) +{ + if (!ds->first) + return YNL_LIST_END; + + ds->last->next = YNL_LIST_END; + return ds->first; +} + +int ynl_exec_dump(struct ynl_sock *ys, struct nlmsghdr *req_nlh, + struct ynl_dump_state *yds) +{ + ssize_t len; + int err; + + err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len); + if (err < 0) + return err; + + do { + len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, + MNL_SOCKET_BUFFER_SIZE); + if (len < 0) + goto err_close_list; + + err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid, + ynl_dump_trampoline, yds, + ynl_cb_array, NLMSG_MIN_TYPE); + if (err < 0) + goto err_close_list; + } while (err > 0); + + yds->first = ynl_dump_end(yds); + return 0; + +err_close_list: + yds->first = ynl_dump_end(yds); + return -1; +} diff --git a/tools/net/ynl/lib/ynl.h b/tools/net/ynl/lib/ynl.h new file mode 100644 index 0000000000..9eafa3552c --- /dev/null +++ b/tools/net/ynl/lib/ynl.h @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +#ifndef __YNL_C_H +#define __YNL_C_H 1 + +#include <stddef.h> +#include <libmnl/libmnl.h> +#include <linux/genetlink.h> +#include <linux/types.h> + +struct mnl_socket; +struct nlmsghdr; + +/* + * User facing code + */ + +struct ynl_ntf_base_type; +struct ynl_ntf_info; +struct ynl_sock; + +enum ynl_error_code { + YNL_ERROR_NONE = 0, + __YNL_ERRNO_END = 4096, + YNL_ERROR_INTERNAL, + YNL_ERROR_EXPECT_ACK, + YNL_ERROR_EXPECT_MSG, + YNL_ERROR_UNEXPECT_MSG, + YNL_ERROR_ATTR_MISSING, + YNL_ERROR_ATTR_INVALID, + YNL_ERROR_UNKNOWN_NTF, + YNL_ERROR_INV_RESP, +}; + +/** + * struct ynl_error - error encountered by YNL + * @code: errno (low values) or YNL error code (enum ynl_error_code) + * @attr_offs: offset of bad attribute (for very advanced users) + * @msg: error message + * + * Error information for when YNL operations fail. + * Users should interact with the err member of struct ynl_sock directly. + * The main exception to that rule is ynl_sock_create(). + */ +struct ynl_error { + enum ynl_error_code code; + unsigned int attr_offs; + char msg[512]; +}; + +/** + * struct ynl_family - YNL family info + * Family description generated by codegen. Pass to ynl_sock_create(). + */ +struct ynl_family { +/* private: */ + const char *name; + const struct ynl_ntf_info *ntf_info; + unsigned int ntf_info_size; +}; + +/** + * struct ynl_sock - YNL wrapped netlink socket + * @err: YNL error descriptor, cleared on every request. + */ +struct ynl_sock { + struct ynl_error err; + +/* private: */ + const struct ynl_family *family; + struct mnl_socket *sock; + __u32 seq; + __u32 portid; + __u16 family_id; + + unsigned int n_mcast_groups; + struct { + unsigned int id; + char name[GENL_NAMSIZ]; + } *mcast_groups; + + struct ynl_ntf_base_type *ntf_first; + struct ynl_ntf_base_type **ntf_last_next; + + struct nlmsghdr *nlh; + struct ynl_policy_nest *req_policy; + unsigned char *tx_buf; + unsigned char *rx_buf; + unsigned char raw_buf[]; +}; + +struct ynl_sock * +ynl_sock_create(const struct ynl_family *yf, struct ynl_error *e); +void ynl_sock_destroy(struct ynl_sock *ys); + +#define ynl_dump_foreach(dump, iter) \ + for (typeof(dump->obj) *iter = &dump->obj; \ + !ynl_dump_obj_is_last(iter); \ + iter = ynl_dump_obj_next(iter)) + +int ynl_subscribe(struct ynl_sock *ys, const char *grp_name); +int ynl_socket_get_fd(struct ynl_sock *ys); +int ynl_ntf_check(struct ynl_sock *ys); + +/** + * ynl_has_ntf() - check if socket has *parsed* notifications + * @ys: active YNL socket + * + * Note that this does not take into account notifications sitting + * in netlink socket, just the notifications which have already been + * read and parsed (e.g. during a ynl_ntf_check() call). + */ +static inline bool ynl_has_ntf(struct ynl_sock *ys) +{ + return ys->ntf_last_next != &ys->ntf_first; +} +struct ynl_ntf_base_type *ynl_ntf_dequeue(struct ynl_sock *ys); + +void ynl_ntf_free(struct ynl_ntf_base_type *ntf); + +/* + * YNL internals / low level stuff + */ + +/* Generic mnl helper code */ + +enum ynl_policy_type { + YNL_PT_REJECT = 1, + YNL_PT_IGNORE, + YNL_PT_NEST, + YNL_PT_FLAG, + YNL_PT_BINARY, + YNL_PT_U8, + YNL_PT_U16, + YNL_PT_U32, + YNL_PT_U64, + YNL_PT_NUL_STR, +}; + +struct ynl_policy_attr { + enum ynl_policy_type type; + unsigned int len; + const char *name; + struct ynl_policy_nest *nest; +}; + +struct ynl_policy_nest { + unsigned int max_attr; + struct ynl_policy_attr *table; +}; + +struct ynl_parse_arg { + struct ynl_sock *ys; + struct ynl_policy_nest *rsp_policy; + void *data; +}; + +struct ynl_dump_list_type { + struct ynl_dump_list_type *next; + unsigned char data[] __attribute__ ((aligned (8))); +}; +extern struct ynl_dump_list_type *YNL_LIST_END; + +static inline bool ynl_dump_obj_is_last(void *obj) +{ + unsigned long uptr = (unsigned long)obj; + + uptr -= offsetof(struct ynl_dump_list_type, data); + return uptr == (unsigned long)YNL_LIST_END; +} + +static inline void *ynl_dump_obj_next(void *obj) +{ + unsigned long uptr = (unsigned long)obj; + struct ynl_dump_list_type *list; + + uptr -= offsetof(struct ynl_dump_list_type, data); + list = (void *)uptr; + uptr = (unsigned long)list->next; + uptr += offsetof(struct ynl_dump_list_type, data); + + return (void *)uptr; +} + +struct ynl_ntf_base_type { + __u16 family; + __u8 cmd; + struct ynl_ntf_base_type *next; + void (*free)(struct ynl_ntf_base_type *ntf); + unsigned char data[] __attribute__ ((aligned (8))); +}; + +extern mnl_cb_t ynl_cb_array[NLMSG_MIN_TYPE]; + +struct nlmsghdr * +ynl_gemsg_start_req(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version); +struct nlmsghdr * +ynl_gemsg_start_dump(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version); + +int ynl_attr_validate(struct ynl_parse_arg *yarg, const struct nlattr *attr); + +int ynl_recv_ack(struct ynl_sock *ys, int ret); +int ynl_cb_null(const struct nlmsghdr *nlh, void *data); + +/* YNL specific helpers used by the auto-generated code */ + +struct ynl_req_state { + struct ynl_parse_arg yarg; + mnl_cb_t cb; + __u32 rsp_cmd; +}; + +struct ynl_dump_state { + struct ynl_sock *ys; + struct ynl_policy_nest *rsp_policy; + void *first; + struct ynl_dump_list_type *last; + size_t alloc_sz; + mnl_cb_t cb; + __u32 rsp_cmd; +}; + +struct ynl_ntf_info { + struct ynl_policy_nest *policy; + mnl_cb_t cb; + size_t alloc_sz; + void (*free)(struct ynl_ntf_base_type *ntf); +}; + +int ynl_exec(struct ynl_sock *ys, struct nlmsghdr *req_nlh, + struct ynl_req_state *yrs); +int ynl_exec_dump(struct ynl_sock *ys, struct nlmsghdr *req_nlh, + struct ynl_dump_state *yds); + +void ynl_error_unknown_notification(struct ynl_sock *ys, __u8 cmd); +int ynl_error_parse(struct ynl_parse_arg *yarg, const char *msg); + +#endif diff --git a/tools/net/ynl/lib/ynl.py b/tools/net/ynl/lib/ynl.py new file mode 100644 index 0000000000..13c4b019a8 --- /dev/null +++ b/tools/net/ynl/lib/ynl.py @@ -0,0 +1,729 @@ +# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause + +from collections import namedtuple +import functools +import os +import random +import socket +import struct +from struct import Struct +import yaml +import ipaddress +import uuid + +from .nlspec import SpecFamily + +# +# Generic Netlink code which should really be in some library, but I can't quickly find one. +# + + +class Netlink: + # Netlink socket + SOL_NETLINK = 270 + + NETLINK_ADD_MEMBERSHIP = 1 + NETLINK_CAP_ACK = 10 + NETLINK_EXT_ACK = 11 + NETLINK_GET_STRICT_CHK = 12 + + # Netlink message + NLMSG_ERROR = 2 + NLMSG_DONE = 3 + + NLM_F_REQUEST = 1 + NLM_F_ACK = 4 + NLM_F_ROOT = 0x100 + NLM_F_MATCH = 0x200 + + NLM_F_REPLACE = 0x100 + NLM_F_EXCL = 0x200 + NLM_F_CREATE = 0x400 + NLM_F_APPEND = 0x800 + + NLM_F_CAPPED = 0x100 + NLM_F_ACK_TLVS = 0x200 + + NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH + + NLA_F_NESTED = 0x8000 + NLA_F_NET_BYTEORDER = 0x4000 + + NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER + + # Genetlink defines + NETLINK_GENERIC = 16 + + GENL_ID_CTRL = 0x10 + + # nlctrl + CTRL_CMD_GETFAMILY = 3 + + CTRL_ATTR_FAMILY_ID = 1 + CTRL_ATTR_FAMILY_NAME = 2 + CTRL_ATTR_MAXATTR = 5 + CTRL_ATTR_MCAST_GROUPS = 7 + + CTRL_ATTR_MCAST_GRP_NAME = 1 + CTRL_ATTR_MCAST_GRP_ID = 2 + + # Extack types + NLMSGERR_ATTR_MSG = 1 + NLMSGERR_ATTR_OFFS = 2 + NLMSGERR_ATTR_COOKIE = 3 + NLMSGERR_ATTR_POLICY = 4 + NLMSGERR_ATTR_MISS_TYPE = 5 + NLMSGERR_ATTR_MISS_NEST = 6 + + +class NlError(Exception): + def __init__(self, nl_msg): + self.nl_msg = nl_msg + + def __str__(self): + return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}" + + +class NlAttr: + ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) + type_formats = { + 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), + 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), + 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")), + 's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")), + 'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")), + 's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")), + 'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")), + 's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q")) + } + + def __init__(self, raw, offset): + self._len, self._type = struct.unpack("HH", raw[offset:offset + 4]) + self.type = self._type & ~Netlink.NLA_TYPE_MASK + self.payload_len = self._len + self.full_len = (self.payload_len + 3) & ~3 + self.raw = raw[offset + 4:offset + self.payload_len] + + @classmethod + def get_format(cls, attr_type, byte_order=None): + format = cls.type_formats[attr_type] + if byte_order: + return format.big if byte_order == "big-endian" \ + else format.little + return format.native + + @classmethod + def formatted_string(cls, raw, display_hint): + if display_hint == 'mac': + formatted = ':'.join('%02x' % b for b in raw) + elif display_hint == 'hex': + formatted = bytes.hex(raw, ' ') + elif display_hint in [ 'ipv4', 'ipv6' ]: + formatted = format(ipaddress.ip_address(raw)) + elif display_hint == 'uuid': + formatted = str(uuid.UUID(bytes=raw)) + else: + formatted = raw + return formatted + + def as_scalar(self, attr_type, byte_order=None): + format = self.get_format(attr_type, byte_order) + return format.unpack(self.raw)[0] + + def as_strz(self): + return self.raw.decode('ascii')[:-1] + + def as_bin(self): + return self.raw + + def as_c_array(self, type): + format = self.get_format(type) + return [ x[0] for x in format.iter_unpack(self.raw) ] + + def as_struct(self, members): + value = dict() + offset = 0 + for m in members: + # TODO: handle non-scalar members + if m.type == 'binary': + decoded = self.raw[offset:offset+m['len']] + offset += m['len'] + elif m.type in NlAttr.type_formats: + format = self.get_format(m.type, m.byte_order) + [ decoded ] = format.unpack_from(self.raw, offset) + offset += format.size + if m.display_hint: + decoded = self.formatted_string(decoded, m.display_hint) + value[m.name] = decoded + return value + + def __repr__(self): + return f"[type:{self.type} len:{self._len}] {self.raw}" + + +class NlAttrs: + def __init__(self, msg): + self.attrs = [] + + offset = 0 + while offset < len(msg): + attr = NlAttr(msg, offset) + offset += attr.full_len + self.attrs.append(attr) + + def __iter__(self): + yield from self.attrs + + def __repr__(self): + msg = '' + for a in self.attrs: + if msg: + msg += '\n' + msg += repr(a) + return msg + + +class NlMsg: + def __init__(self, msg, offset, attr_space=None): + self.hdr = msg[offset:offset + 16] + + self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \ + struct.unpack("IHHII", self.hdr) + + self.raw = msg[offset + 16:offset + self.nl_len] + + self.error = 0 + self.done = 0 + + extack_off = None + if self.nl_type == Netlink.NLMSG_ERROR: + self.error = struct.unpack("i", self.raw[0:4])[0] + self.done = 1 + extack_off = 20 + elif self.nl_type == Netlink.NLMSG_DONE: + self.done = 1 + extack_off = 4 + + self.extack = None + if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off: + self.extack = dict() + extack_attrs = NlAttrs(self.raw[extack_off:]) + for extack in extack_attrs: + if extack.type == Netlink.NLMSGERR_ATTR_MSG: + self.extack['msg'] = extack.as_strz() + elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE: + self.extack['miss-type'] = extack.as_scalar('u32') + elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST: + self.extack['miss-nest'] = extack.as_scalar('u32') + elif extack.type == Netlink.NLMSGERR_ATTR_OFFS: + self.extack['bad-attr-offs'] = extack.as_scalar('u32') + else: + if 'unknown' not in self.extack: + self.extack['unknown'] = [] + self.extack['unknown'].append(extack) + + if attr_space: + # We don't have the ability to parse nests yet, so only do global + if 'miss-type' in self.extack and 'miss-nest' not in self.extack: + miss_type = self.extack['miss-type'] + if miss_type in attr_space.attrs_by_val: + spec = attr_space.attrs_by_val[miss_type] + desc = spec['name'] + if 'doc' in spec: + desc += f" ({spec['doc']})" + self.extack['miss-type'] = desc + + def cmd(self): + return self.nl_type + + def __repr__(self): + msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n" + if self.error: + msg += '\terror: ' + str(self.error) + if self.extack: + msg += '\textack: ' + repr(self.extack) + return msg + + +class NlMsgs: + def __init__(self, data, attr_space=None): + self.msgs = [] + + offset = 0 + while offset < len(data): + msg = NlMsg(data, offset, attr_space=attr_space) + offset += msg.nl_len + self.msgs.append(msg) + + def __iter__(self): + yield from self.msgs + + +genl_family_name_to_id = None + + +def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None): + # we prepend length in _genl_msg_finalize() + if seq is None: + seq = random.randint(1, 1024) + nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) + genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0) + return nlmsg + genlmsg + + +def _genl_msg_finalize(msg): + return struct.pack("I", len(msg) + 4) + msg + + +def _genl_load_families(): + with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock: + sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) + + msg = _genl_msg(Netlink.GENL_ID_CTRL, + Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP, + Netlink.CTRL_CMD_GETFAMILY, 1) + msg = _genl_msg_finalize(msg) + + sock.send(msg, 0) + + global genl_family_name_to_id + genl_family_name_to_id = dict() + + while True: + reply = sock.recv(128 * 1024) + nms = NlMsgs(reply) + for nl_msg in nms: + if nl_msg.error: + print("Netlink error:", nl_msg.error) + return + if nl_msg.done: + return + + gm = GenlMsg(nl_msg) + fam = dict() + for attr in NlAttrs(gm.raw): + if attr.type == Netlink.CTRL_ATTR_FAMILY_ID: + fam['id'] = attr.as_scalar('u16') + elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME: + fam['name'] = attr.as_strz() + elif attr.type == Netlink.CTRL_ATTR_MAXATTR: + fam['maxattr'] = attr.as_scalar('u32') + elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS: + fam['mcast'] = dict() + for entry in NlAttrs(attr.raw): + mcast_name = None + mcast_id = None + for entry_attr in NlAttrs(entry.raw): + if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME: + mcast_name = entry_attr.as_strz() + elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID: + mcast_id = entry_attr.as_scalar('u32') + if mcast_name and mcast_id is not None: + fam['mcast'][mcast_name] = mcast_id + if 'name' in fam and 'id' in fam: + genl_family_name_to_id[fam['name']] = fam + + +class GenlMsg: + def __init__(self, nl_msg): + self.nl = nl_msg + self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0) + self.raw = nl_msg.raw[4:] + + def cmd(self): + return self.genl_cmd + + def __repr__(self): + msg = repr(self.nl) + msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n" + for a in self.raw_attrs: + msg += '\t\t' + repr(a) + '\n' + return msg + + +class NetlinkProtocol: + def __init__(self, family_name, proto_num): + self.family_name = family_name + self.proto_num = proto_num + + def _message(self, nl_type, nl_flags, seq=None): + if seq is None: + seq = random.randint(1, 1024) + nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) + return nlmsg + + def message(self, flags, command, version, seq=None): + return self._message(command, flags, seq) + + def _decode(self, nl_msg): + return nl_msg + + def decode(self, ynl, nl_msg): + msg = self._decode(nl_msg) + fixed_header_size = 0 + if ynl: + op = ynl.rsp_by_value[msg.cmd()] + fixed_header_size = ynl._fixed_header_size(op) + msg.raw_attrs = NlAttrs(msg.raw[fixed_header_size:]) + return msg + + def get_mcast_id(self, mcast_name, mcast_groups): + if mcast_name not in mcast_groups: + raise Exception(f'Multicast group "{mcast_name}" not present in the spec') + return mcast_groups[mcast_name].value + + +class GenlProtocol(NetlinkProtocol): + def __init__(self, family_name): + super().__init__(family_name, Netlink.NETLINK_GENERIC) + + global genl_family_name_to_id + if genl_family_name_to_id is None: + _genl_load_families() + + self.genl_family = genl_family_name_to_id[family_name] + self.family_id = genl_family_name_to_id[family_name]['id'] + + def message(self, flags, command, version, seq=None): + nlmsg = self._message(self.family_id, flags, seq) + genlmsg = struct.pack("BBH", command, version, 0) + return nlmsg + genlmsg + + def _decode(self, nl_msg): + return GenlMsg(nl_msg) + + def get_mcast_id(self, mcast_name, mcast_groups): + if mcast_name not in self.genl_family['mcast']: + raise Exception(f'Multicast group "{mcast_name}" not present in the family') + return self.genl_family['mcast'][mcast_name] + + +# +# YNL implementation details. +# + + +class YnlFamily(SpecFamily): + def __init__(self, def_path, schema=None): + super().__init__(def_path, schema) + + self.include_raw = False + + try: + if self.proto == "netlink-raw": + self.nlproto = NetlinkProtocol(self.yaml['name'], + self.yaml['protonum']) + else: + self.nlproto = GenlProtocol(self.yaml['name']) + except KeyError: + raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") + + self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) + self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) + self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) + self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) + + self.async_msg_ids = set() + self.async_msg_queue = [] + + for msg in self.msgs.values(): + if msg.is_async: + self.async_msg_ids.add(msg.rsp_value) + + for op_name, op in self.ops.items(): + bound_f = functools.partial(self._op, op_name) + setattr(self, op.ident_name, bound_f) + + + def ntf_subscribe(self, mcast_name): + mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) + self.sock.bind((0, 0)) + self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, + mcast_id) + + def _add_attr(self, space, name, value): + try: + attr = self.attr_sets[space][name] + except KeyError: + raise Exception(f"Space '{space}' has no attribute '{name}'") + nl_type = attr.value + if attr["type"] == 'nest': + nl_type |= Netlink.NLA_F_NESTED + attr_payload = b'' + for subname, subvalue in value.items(): + attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue) + elif attr["type"] == 'flag': + attr_payload = b'' + elif attr["type"] == 'string': + attr_payload = str(value).encode('ascii') + b'\x00' + elif attr["type"] == 'binary': + if isinstance(value, bytes): + attr_payload = value + elif isinstance(value, str): + attr_payload = bytes.fromhex(value) + else: + raise Exception(f'Unknown type for binary attribute, value: {value}') + elif attr['type'] in NlAttr.type_formats: + format = NlAttr.get_format(attr['type'], attr.byte_order) + attr_payload = format.pack(int(value)) + else: + raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') + + pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) + return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad + + def _decode_enum(self, raw, attr_spec): + enum = self.consts[attr_spec['enum']] + if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']: + i = 0 + value = set() + while raw: + if raw & 1: + value.add(enum.entries_by_val[i].name) + raw >>= 1 + i += 1 + else: + value = enum.entries_by_val[raw].name + return value + + def _decode_binary(self, attr, attr_spec): + if attr_spec.struct_name: + members = self.consts[attr_spec.struct_name] + decoded = attr.as_struct(members) + for m in members: + if m.enum: + decoded[m.name] = self._decode_enum(decoded[m.name], m) + elif attr_spec.sub_type: + decoded = attr.as_c_array(attr_spec.sub_type) + else: + decoded = attr.as_bin() + if attr_spec.display_hint: + decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint) + return decoded + + def _decode_array_nest(self, attr, attr_spec): + decoded = [] + offset = 0 + while offset < len(attr.raw): + item = NlAttr(attr.raw, offset) + offset += item.full_len + + subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) + decoded.append({ item.type: subattrs }) + return decoded + + def _decode(self, attrs, space): + attr_space = self.attr_sets[space] + rsp = dict() + for attr in attrs: + try: + attr_spec = attr_space.attrs_by_val[attr.type] + except KeyError: + raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") + if attr_spec["type"] == 'nest': + subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes']) + decoded = subdict + elif attr_spec["type"] == 'string': + decoded = attr.as_strz() + elif attr_spec["type"] == 'binary': + decoded = self._decode_binary(attr, attr_spec) + elif attr_spec["type"] == 'flag': + decoded = True + elif attr_spec["type"] in NlAttr.type_formats: + decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) + elif attr_spec["type"] == 'array-nest': + decoded = self._decode_array_nest(attr, attr_spec) + else: + raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') + + if 'enum' in attr_spec: + decoded = self._decode_enum(decoded, attr_spec) + + if not attr_spec.is_multi: + rsp[attr_spec['name']] = decoded + elif attr_spec.name in rsp: + rsp[attr_spec.name].append(decoded) + else: + rsp[attr_spec.name] = [decoded] + + return rsp + + def _decode_extack_path(self, attrs, attr_set, offset, target): + for attr in attrs: + try: + attr_spec = attr_set.attrs_by_val[attr.type] + except KeyError: + raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") + if offset > target: + break + if offset == target: + return '.' + attr_spec.name + + if offset + attr.full_len <= target: + offset += attr.full_len + continue + if attr_spec['type'] != 'nest': + raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") + offset += 4 + subpath = self._decode_extack_path(NlAttrs(attr.raw), + self.attr_sets[attr_spec['nested-attributes']], + offset, target) + if subpath is None: + return None + return '.' + attr_spec.name + subpath + + return None + + def _decode_extack(self, request, op, extack): + if 'bad-attr-offs' not in extack: + return + + msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set)) + offset = 20 + self._fixed_header_size(op) + path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, + extack['bad-attr-offs']) + if path: + del extack['bad-attr-offs'] + extack['bad-attr'] = path + + def _fixed_header_size(self, op): + if op.fixed_header: + fixed_header_members = self.consts[op.fixed_header].members + size = 0 + for m in fixed_header_members: + format = NlAttr.get_format(m.type, m.byte_order) + size += format.size + return size + else: + return 0 + + def _decode_fixed_header(self, msg, name): + fixed_header_members = self.consts[name].members + fixed_header_attrs = dict() + offset = 0 + for m in fixed_header_members: + format = NlAttr.get_format(m.type, m.byte_order) + [ value ] = format.unpack_from(msg.raw, offset) + offset += format.size + if m.enum: + value = self._decode_enum(value, m) + fixed_header_attrs[m.name] = value + return fixed_header_attrs + + def handle_ntf(self, decoded): + msg = dict() + if self.include_raw: + msg['raw'] = decoded + op = self.rsp_by_value[decoded.cmd()] + attrs = self._decode(decoded.raw_attrs, op.attr_set.name) + if op.fixed_header: + attrs.update(self._decode_fixed_header(decoded, op.fixed_header)) + + msg['name'] = op['name'] + msg['msg'] = attrs + self.async_msg_queue.append(msg) + + def check_ntf(self): + while True: + try: + reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT) + except BlockingIOError: + return + + nms = NlMsgs(reply) + for nl_msg in nms: + if nl_msg.error: + print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) + print(nl_msg) + continue + if nl_msg.done: + print("Netlink done while checking for ntf!?") + continue + + decoded = self.nlproto.decode(self, nl_msg) + if decoded.cmd() not in self.async_msg_ids: + print("Unexpected msg id done while checking for ntf", decoded) + continue + + self.handle_ntf(decoded) + + def operation_do_attributes(self, name): + """ + For a given operation name, find and return a supported + set of attributes (as a dict). + """ + op = self.find_operation(name) + if not op: + return None + + return op['do']['request']['attributes'].copy() + + def _op(self, method, vals, flags, dump=False): + op = self.ops[method] + + nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK + for flag in flags or []: + nl_flags |= flag + if dump: + nl_flags |= Netlink.NLM_F_DUMP + + req_seq = random.randint(1024, 65535) + msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) + fixed_header_members = [] + if op.fixed_header: + fixed_header_members = self.consts[op.fixed_header].members + for m in fixed_header_members: + value = vals.pop(m.name) if m.name in vals else 0 + format = NlAttr.get_format(m.type, m.byte_order) + msg += format.pack(value) + for name, value in vals.items(): + msg += self._add_attr(op.attr_set.name, name, value) + msg = _genl_msg_finalize(msg) + + self.sock.send(msg, 0) + + done = False + rsp = [] + while not done: + reply = self.sock.recv(128 * 1024) + nms = NlMsgs(reply, attr_space=op.attr_set) + for nl_msg in nms: + if nl_msg.extack: + self._decode_extack(msg, op, nl_msg.extack) + + if nl_msg.error: + raise NlError(nl_msg) + if nl_msg.done: + if nl_msg.extack: + print("Netlink warning:") + print(nl_msg) + done = True + break + + decoded = self.nlproto.decode(self, nl_msg) + + # Check if this is a reply to our request + if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value: + if decoded.cmd() in self.async_msg_ids: + self.handle_ntf(decoded) + continue + else: + print('Unexpected message: ' + repr(decoded)) + continue + + rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) + if op.fixed_header: + rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header)) + rsp.append(rsp_msg) + + if not rsp: + return None + if not dump and len(rsp) == 1: + return rsp[0] + return rsp + + def do(self, method, vals, flags): + return self._op(method, vals, flags) + + def dump(self, method, vals): + return self._op(method, vals, [], dump=True) |