# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause # # pylint: disable=missing-class-docstring, missing-function-docstring # pylint: disable=too-many-branches, too-many-locals, too-many-instance-attributes # pylint: disable=too-many-lines """ YAML Netlink Library An implementation of the genetlink and raw netlink protocols. """ from collections import namedtuple from enum import Enum import functools import os import random import socket import struct from struct import Struct import sys import ipaddress import uuid import queue import selectors import time from .nlspec import SpecFamily # # Generic Netlink code which should really be in some library, but I can't quickly find one. # class YnlException(Exception): pass # pylint: disable=too-few-public-methods class Netlink: # Netlink socket SOL_NETLINK = 270 NETLINK_ADD_MEMBERSHIP = 1 NETLINK_LISTEN_ALL_NSID = 8 NETLINK_CAP_ACK = 10 NETLINK_EXT_ACK = 11 NETLINK_GET_STRICT_CHK = 12 # Netlink message NLMSG_ERROR = 2 NLMSG_DONE = 3 NLM_F_REQUEST = 1 NLM_F_ACK = 4 NLM_F_ROOT = 0x100 NLM_F_MATCH = 0x200 NLM_F_REPLACE = 0x100 NLM_F_EXCL = 0x200 NLM_F_CREATE = 0x400 NLM_F_APPEND = 0x800 NLM_F_CAPPED = 0x100 NLM_F_ACK_TLVS = 0x200 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH NLA_F_NESTED = 0x8000 NLA_F_NET_BYTEORDER = 0x4000 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER # Genetlink defines NETLINK_GENERIC = 16 GENL_ID_CTRL = 0x10 # nlctrl CTRL_CMD_GETFAMILY = 3 CTRL_CMD_GETPOLICY = 10 CTRL_ATTR_FAMILY_ID = 1 CTRL_ATTR_FAMILY_NAME = 2 CTRL_ATTR_MAXATTR = 5 CTRL_ATTR_MCAST_GROUPS = 7 CTRL_ATTR_POLICY = 8 CTRL_ATTR_OP_POLICY = 9 CTRL_ATTR_OP = 10 CTRL_ATTR_MCAST_GRP_NAME = 1 CTRL_ATTR_MCAST_GRP_ID = 2 CTRL_ATTR_POLICY_DO = 1 CTRL_ATTR_POLICY_DUMP = 2 # Extack types NLMSGERR_ATTR_MSG = 1 NLMSGERR_ATTR_OFFS = 2 NLMSGERR_ATTR_COOKIE = 3 NLMSGERR_ATTR_POLICY = 4 NLMSGERR_ATTR_MISS_TYPE = 5 NLMSGERR_ATTR_MISS_NEST = 6 # Policy types NL_POLICY_TYPE_ATTR_TYPE = 1 NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2 NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3 NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4 NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5 NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6 NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7 NL_POLICY_TYPE_ATTR_POLICY_IDX = 8 NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9 NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10 NL_POLICY_TYPE_ATTR_PAD = 11 NL_POLICY_TYPE_ATTR_MASK = 12 AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'binary', 'string', 'nul-string', 'nested', 'nested-array', 'bitfield32', 'sint', 'uint']) class NlError(Exception): def __init__(self, nl_msg): self.nl_msg = nl_msg self.error = -nl_msg.error def __str__(self): msg = "Netlink error: " extack = self.nl_msg.extack.copy() if self.nl_msg.extack else {} if 'msg' in extack: msg += extack['msg'] + ': ' del extack['msg'] msg += os.strerror(self.error) if extack: msg += ' ' + str(extack) return msg class ConfigError(Exception): pass class NlPolicy: """Kernel policy for one mode (do or dump) of one operation. Returned by YnlFamily.get_policy(). Attributes of the policy are accessible as attributes of the object. Nested policies can be accessed indexing the object like a dictionary:: pol = ynl.get_policy('page-pool-stats-get', 'do') pol['info'].type # 'nested' pol['info']['id'].type # 'uint' pol['info']['id'].min_value # 1 Each policy entry always has a 'type' attribute (e.g. u32, string, nested). Optional attributes depending on the 'type': min-value, max-value, min-length, max-length, mask. Policies can form infinite nesting loops. These loops are trimmed when policy is converted to a dict with pol.to_dict(). """ def __init__(self, ynl, policy_idx, policy_table, attr_set, props=None): self._policy_idx = policy_idx self._policy_table = policy_table self._ynl = ynl self._props = props or {} self._entries = {} self._cache = {} if policy_idx is not None and policy_idx in policy_table: for attr_id, decoded in policy_table[policy_idx].items(): if attr_set and attr_id in attr_set.attrs_by_val: spec = attr_set.attrs_by_val[attr_id] name = spec['name'] else: spec = None name = f'attr-{attr_id}' self._entries[name] = (spec, decoded) def __getitem__(self, name): """Descend into a nested policy by attribute name.""" if name not in self._cache: spec, decoded = self._entries[name] props = dict(decoded) child_idx = None child_set = None if 'policy-idx' in props: child_idx = props.pop('policy-idx') if spec and 'nested-attributes' in spec.yaml: child_set = self._ynl.attr_sets[spec.yaml['nested-attributes']] self._cache[name] = NlPolicy(self._ynl, child_idx, self._policy_table, child_set, props) return self._cache[name] def __getattr__(self, name): """Access this policy entry's own properties (type, min-value, etc.). Underscores in the name are converted to dashes, so that pol.min_value looks up "min-value". """ key = name.replace('_', '-') try: # Hack for level-0 which we still want to have .type but we don't # want type to pointlessly show up in the dict / JSON form. if not self._props and name == "type": return "nested" return self._props[key] except KeyError: raise AttributeError(name) def get(self, name, default=None): """Look up a child policy entry by attribute name, with a default.""" try: return self[name] except KeyError: return default def __contains__(self, name): return name in self._entries def __len__(self): return len(self._entries) def __iter__(self): return iter(self._entries) def keys(self): """Return attribute names accepted by this policy.""" return self._entries.keys() def to_dict(self, seen=None): """Convert to a plain dict, suitable for JSON serialization. Nested NlPolicy objects are expanded recursively. Cyclic references are trimmed (resolved to just {"type": "nested"}). """ if seen is None: seen = set() result = dict(self._props) if self._policy_idx is not None: if self._policy_idx not in seen: seen = seen | {self._policy_idx} children = {} for name in self: children[name] = self[name].to_dict(seen) if self._props: result['policy'] = children else: result = children return result def __repr__(self): return repr(self.to_dict()) class NlAttr: ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) type_formats = { 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("h"), Struct("I"), Struct("i"), Struct("Q"), Struct("q"), Struct("(attrs, ...) By default this will execute the as "do", pass dump=True to perform a dump operation. ynl. is a shorthand / convenience wrapper for the following methods which take the op_name as a string: ynl.do(op_name, attrs, flags=None) -- execute a do operation ynl.dump(op_name, attrs) -- execute a dump operation ynl.do_multi(ops) -- batch multiple do operations The flags argument in ynl.do() allows passing in extra NLM_F_* flags which may be necessary for old families. Notification API: ynl.ntf_subscribe(mcast_name) -- join a multicast group ynl.ntf_listen_all_nsid() -- listen on all netns ynl.check_ntf() -- drain pending notifications ynl.poll_ntf(duration=None) -- yield notifications Policy introspection allows querying validation criteria from the running kernel. Allows checking whether kernel supports a given attribute or value. ynl.get_policy(op_name, mode) -- query kernel policy for an op """ def __init__(self, def_path, schema=None, process_unknown=False, recv_size=0): super().__init__(def_path, schema) self.include_raw = False self.process_unknown = process_unknown try: if self.proto == "netlink-raw": self.nlproto = NetlinkProtocol(self.yaml['name'], self.yaml['protonum']) else: self.nlproto = GenlProtocol(self.yaml['name']) except KeyError as err: raise YnlException(f"Family '{self.yaml['name']}' not supported by the kernel") from err self._recv_dbg = False # Note that netlink will use conservative (min) message size for # the first dump recv() on the socket, our setting will only matter # from the second recv() on. self._recv_size = recv_size if recv_size else 131072 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) # for a message, so smaller receive sizes will lead to truncation. # Note that the min size for other families may be larger than 4k! if self._recv_size < 4000: raise ConfigError() self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) self.async_msg_ids = set() self.async_msg_queue = queue.Queue() for msg in self.msgs.values(): if msg.is_async: self.async_msg_ids.add(msg.rsp_value) for op_name, op in self.ops.items(): bound_f = functools.partial(self._op, op_name) setattr(self, op.ident_name, bound_f) def close(self): if self.sock is not None: self.sock.close() self.sock = None def __enter__(self): return self def __exit__(self, exc_type, exc, tb): self.close() def ntf_subscribe(self, mcast_name): mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) self.sock.bind((0, 0)) self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, mcast_id) def ntf_listen_all_nsid(self): """Enable NETLINK_LISTEN_ALL_NSID to receive notifications from all namespaces that have an nsid mapped in the current one.""" self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_LISTEN_ALL_NSID, 1) @staticmethod def _decode_nsid(ancdata): for cmsg_level, cmsg_type, cmsg_data in ancdata: if (cmsg_level == Netlink.SOL_NETLINK and cmsg_type == Netlink.NETLINK_LISTEN_ALL_NSID): nsid = struct.unpack('i', cmsg_data)[0] if nsid >= 0: return nsid return None return None def set_recv_dbg(self, enabled): self._recv_dbg = enabled def _recv_dbg_print(self, reply, nl_msgs): if not self._recv_dbg: return print("Recv: read", len(reply), "bytes,", len(nl_msgs.msgs), "messages", file=sys.stderr) for nl_msg in nl_msgs: print(" ", nl_msg, file=sys.stderr) def _encode_enum(self, attr_spec, value): enum = self.consts[attr_spec['enum']] if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): scalar = 0 if isinstance(value, str): value = [value] for single_value in value: scalar += enum.entries[single_value].user_value(as_flags = True) return scalar return enum.entries[value].user_value() def _get_scalar(self, attr_spec, value): try: return int(value) except (ValueError, TypeError) as e: if 'enum' in attr_spec: return self._encode_enum(attr_spec, value) if attr_spec.display_hint: return self._from_string(value, attr_spec) raise e # pylint: disable=too-many-statements def _add_attr(self, space, name, value, search_attrs): try: attr = self.attr_sets[space][name] except KeyError as err: raise YnlException(f"Space '{space}' has no attribute '{name}'") from err nl_type = attr.value if attr.is_multi and isinstance(value, list): attr_payload = b'' for subvalue in value: attr_payload += self._add_attr(space, name, subvalue, search_attrs) return attr_payload if attr["type"] == 'nest': nl_type |= Netlink.NLA_F_NESTED sub_space = attr['nested-attributes'] attr_payload = self._add_nest_attrs(value, sub_space, search_attrs) elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest': nl_type |= Netlink.NLA_F_NESTED sub_space = attr['nested-attributes'] attr_payload = self._encode_indexed_array(value, sub_space, search_attrs) elif attr["type"] == 'flag': if not value: # If value is absent or false then skip attribute creation. return b'' attr_payload = b'' elif attr["type"] == 'string': attr_payload = str(value).encode('ascii') + b'\x00' elif attr["type"] == 'binary': if value is None: attr_payload = b'' elif isinstance(value, bytes): attr_payload = value elif isinstance(value, str): if attr.display_hint: attr_payload = self._from_string(value, attr) else: attr_payload = bytes.fromhex(value) elif isinstance(value, dict) and attr.struct_name: attr_payload = self._encode_struct(attr.struct_name, value) elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats: format_ = NlAttr.get_format(attr.sub_type) attr_payload = b''.join([format_.pack(x) for x in value]) else: raise YnlException(f'Unknown type for binary attribute, value: {value}') elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: scalar = self._get_scalar(attr, value) if attr.is_auto_scalar: attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') else: attr_type = attr["type"] format_ = NlAttr.get_format(attr_type, attr.byte_order) attr_payload = format_.pack(scalar) elif attr['type'] in "bitfield32": scalar_value = self._get_scalar(attr, value["value"]) scalar_selector = self._get_scalar(attr, value["selector"]) attr_payload = struct.pack("II", scalar_value, scalar_selector) elif attr['type'] == 'sub-message': msg_format, _ = self._resolve_selector(attr, search_attrs) attr_payload = b'' if msg_format.fixed_header: attr_payload += self._encode_struct(msg_format.fixed_header, value) if msg_format.attr_set: if msg_format.attr_set in self.attr_sets: nl_type |= Netlink.NLA_F_NESTED sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) for subname, subvalue in value.items(): attr_payload += self._add_attr(msg_format.attr_set, subname, subvalue, sub_attrs) else: raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}'") else: raise YnlException(f'Unknown type at {space} {name} {value} {attr["type"]}') return self._add_attr_raw(nl_type, attr_payload) def _add_attr_raw(self, nl_type, attr_payload): pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad def _add_nest_attrs(self, value, sub_space, search_attrs): sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs) attr_payload = b'' for subname, subvalue in value.items(): attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs) return attr_payload def _encode_indexed_array(self, vals, sub_space, search_attrs): attr_payload = b'' for i, val in enumerate(vals): idx = i | Netlink.NLA_F_NESTED val_payload = self._add_nest_attrs(val, sub_space, search_attrs) attr_payload += self._add_attr_raw(idx, val_payload) return attr_payload def _get_enum_or_unknown(self, enum, raw): try: name = enum.entries_by_val[raw].name except KeyError as error: if self.process_unknown: name = f"Unknown({raw})" else: raise error return name def _decode_enum(self, raw, attr_spec): enum = self.consts[attr_spec['enum']] if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): i = 0 value = set() while raw: if raw & 1: value.add(self._get_enum_or_unknown(enum, i)) raw >>= 1 i += 1 else: value = self._get_enum_or_unknown(enum, raw) return value def _decode_binary(self, attr, attr_spec): if attr_spec.struct_name: decoded = self._decode_struct(attr.raw, attr_spec.struct_name) elif attr_spec.sub_type: decoded = attr.as_c_array(attr_spec.sub_type) if 'enum' in attr_spec: decoded = [ self._decode_enum(x, attr_spec) for x in decoded ] elif attr_spec.display_hint: decoded = [ self._formatted_string(x, attr_spec.display_hint) for x in decoded ] else: decoded = attr.as_bin() if attr_spec.display_hint: decoded = self._formatted_string(decoded, attr_spec.display_hint) return decoded def _decode_array_attr(self, attr, attr_spec): decoded = [] offset = 0 while offset < len(attr.raw): item = NlAttr(attr.raw, offset) offset += item.full_len if attr_spec["sub-type"] == 'nest': subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) decoded.append({ item.type: subattrs }) elif attr_spec["sub-type"] == 'binary': subattr = item.as_bin() if attr_spec.display_hint: subattr = self._formatted_string(subattr, attr_spec.display_hint) decoded.append(subattr) elif attr_spec["sub-type"] in NlAttr.type_formats: subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) if 'enum' in attr_spec: subattr = self._decode_enum(subattr, attr_spec) elif attr_spec.display_hint: subattr = self._formatted_string(subattr, attr_spec.display_hint) decoded.append(subattr) else: raise YnlException(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') return decoded def _decode_nest_type_value(self, attr, attr_spec): decoded = {} value = attr for name in attr_spec['type-value']: value = NlAttr(value.raw, 0) decoded[name] = value.type subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) decoded.update(subattrs) return decoded def _decode_unknown(self, attr): if attr.is_nest: return self._decode(NlAttrs(attr.raw), None) return attr.as_bin() def _rsp_add(self, rsp, name, is_multi, decoded): if is_multi is None: if name in rsp and not isinstance(rsp[name], list): rsp[name] = [rsp[name]] is_multi = True else: is_multi = False if not is_multi: rsp[name] = decoded elif name in rsp: rsp[name].append(decoded) else: rsp[name] = [decoded] def _resolve_selector(self, attr_spec, search_attrs): sub_msg = attr_spec.sub_message if sub_msg not in self.sub_msgs: raise YnlException(f"No sub-message spec named {sub_msg} for {attr_spec.name}") sub_msg_spec = self.sub_msgs[sub_msg] selector = attr_spec.selector value = search_attrs.lookup(selector) if value not in sub_msg_spec.formats: raise YnlException(f"No message format for '{value}' in sub-message spec '{sub_msg}'") spec = sub_msg_spec.formats[value] return spec, value def _decode_sub_msg(self, attr, attr_spec, search_attrs): msg_format, _ = self._resolve_selector(attr_spec, search_attrs) decoded = {} offset = 0 if msg_format.fixed_header: decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)) offset = self.struct_size(msg_format.fixed_header) if msg_format.attr_set: if msg_format.attr_set in self.attr_sets: subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) decoded.update(subdict) else: raise YnlException(f"Unknown attribute-set '{msg_format.attr_set}' " f"when decoding '{attr_spec.name}'") return decoded # pylint: disable=too-many-statements def _decode(self, attrs, space, outer_attrs = None): rsp = {} search_attrs = {} if space: attr_space = self.attr_sets[space] search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) for attr in attrs: try: attr_spec = attr_space.attrs_by_val[attr.type] except (KeyError, UnboundLocalError) as err: if not self.process_unknown: raise YnlException(f"Space '{space}' has no attribute " f"with value '{attr.type}'") from err attr_name = f"UnknownAttr({attr.type})" self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) continue try: if attr_spec["type"] == 'pad': continue elif attr_spec["type"] == 'nest': subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) decoded = subdict elif attr_spec["type"] == 'string': decoded = attr.as_strz() elif attr_spec["type"] == 'binary': decoded = self._decode_binary(attr, attr_spec) elif attr_spec["type"] == 'flag': decoded = True elif attr_spec.is_auto_scalar: decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) if 'enum' in attr_spec: decoded = self._decode_enum(decoded, attr_spec) elif attr_spec["type"] in NlAttr.type_formats: decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) if 'enum' in attr_spec: decoded = self._decode_enum(decoded, attr_spec) elif attr_spec.display_hint: decoded = self._formatted_string(decoded, attr_spec.display_hint) elif attr_spec["type"] == 'indexed-array': decoded = self._decode_array_attr(attr, attr_spec) elif attr_spec["type"] == 'bitfield32': value, selector = struct.unpack("II", attr.raw) if 'enum' in attr_spec: value = self._decode_enum(value, attr_spec) selector = self._decode_enum(selector, attr_spec) decoded = {"value": value, "selector": selector} elif attr_spec["type"] == 'sub-message': decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) elif attr_spec["type"] == 'nest-type-value': decoded = self._decode_nest_type_value(attr, attr_spec) else: if not self.process_unknown: raise YnlException(f'Unknown {attr_spec["type"]} ' f'with name {attr_spec["name"]}') decoded = self._decode_unknown(attr) self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) except: print(f"Error decoding '{attr_spec.name}' from '{space}'") raise return rsp # pylint: disable=too-many-arguments, too-many-positional-arguments def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs): for attr in attrs: try: attr_spec = attr_set.attrs_by_val[attr.type] except KeyError as err: raise YnlException( f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") from err if offset > target: break if offset == target: return '.' + attr_spec.name if offset + attr.full_len <= target: offset += attr.full_len continue pathname = attr_spec.name if attr_spec['type'] == 'nest': sub_attrs = self.attr_sets[attr_spec['nested-attributes']] search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name'])) elif attr_spec['type'] == 'sub-message': msg_format, value = self._resolve_selector(attr_spec, search_attrs) if msg_format is None: raise YnlException(f"Can't resolve sub-message of " f"{attr_spec['name']} for extack") sub_attrs = self.attr_sets[msg_format.attr_set] pathname += f"({value})" else: raise YnlException(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") offset += 4 subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs, offset, target, search_attrs) if subpath is None: return None return '.' + pathname + subpath return None def _decode_extack(self, request, op, extack, vals): if 'bad-attr-offs' not in extack: return msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) offset = self.nlproto.msghdr_size() + self.struct_size(op.fixed_header) search_attrs = SpaceAttrs(op.attr_set, vals) path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, extack['bad-attr-offs'], search_attrs) if path: del extack['bad-attr-offs'] extack['bad-attr'] = path def struct_size(self, name): if name: members = self.consts[name].members size = 0 for m in members: if m.type in ['pad', 'binary']: if m.struct: size += self.struct_size(m.struct) else: size += m.len else: format_ = NlAttr.get_format(m.type, m.byte_order) size += format_.size return size return 0 def _decode_struct(self, data, name): members = self.consts[name].members attrs = {} offset = 0 for m in members: value = None if m.type == 'pad': offset += m.len elif m.type == 'binary': if m.struct: len_ = self.struct_size(m.struct) value = self._decode_struct(data[offset : offset + len_], m.struct) offset += len_ else: value = data[offset : offset + m.len] offset += m.len else: format_ = NlAttr.get_format(m.type, m.byte_order) [ value ] = format_.unpack_from(data, offset) offset += format_.size if value is not None: if m.enum: value = self._decode_enum(value, m) elif m.display_hint: value = self._formatted_string(value, m.display_hint) attrs[m.name] = value return attrs def _encode_struct(self, name, vals): members = self.consts[name].members attr_payload = b'' for m in members: value = vals.pop(m.name) if m.name in vals else None if m.type == 'pad': attr_payload += bytearray(m.len) elif m.type == 'binary': if m.struct: if value is None: value = {} attr_payload += self._encode_struct(m.struct, value) else: if value is None: attr_payload += bytearray(m.len) else: attr_payload += bytes.fromhex(value) else: if value is None: value = 0 format_ = NlAttr.get_format(m.type, m.byte_order) attr_payload += format_.pack(value) return attr_payload def _formatted_string(self, raw, display_hint): if display_hint == 'mac': formatted = ':'.join(f'{b:02x}' for b in raw) elif display_hint == 'hex': if isinstance(raw, int): formatted = hex(raw) else: formatted = bytes.hex(raw, ' ') elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]: formatted = format(ipaddress.ip_address(raw)) elif display_hint == 'uuid': formatted = str(uuid.UUID(bytes=raw)) else: formatted = raw return formatted def _from_string(self, string, attr_spec): if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']: ip = ipaddress.ip_address(string) if attr_spec['type'] == 'binary': raw = ip.packed else: raw = int(ip) elif attr_spec.display_hint == 'hex': if attr_spec['type'] == 'binary': raw = bytes.fromhex(string) else: raw = int(string, 16) elif attr_spec.display_hint == 'mac': # Parse MAC address in format "00:11:22:33:44:55" or "001122334455" if ':' in string: mac_bytes = [int(x, 16) for x in string.split(':')] else: if len(string) % 2 != 0: raise YnlException(f"Invalid MAC address format: {string}") mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)] raw = bytes(mac_bytes) else: raise YnlException(f"Display hint '{attr_spec.display_hint}' not implemented" f" when parsing '{attr_spec['name']}'") return raw def handle_ntf(self, decoded, nsid=None): msg = {} if self.include_raw: msg['raw'] = decoded op = self.rsp_by_value[decoded.cmd()] attrs = self._decode(decoded.raw_attrs, op.attr_set.name) if op.fixed_header: attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) msg['name'] = op['name'] msg['msg'] = attrs if nsid is not None: msg['nsid'] = nsid self.async_msg_queue.put(msg) def _recvmsg(self, flags=0): reply, ancdata, _, _ = self.sock.recvmsg(self._recv_size, 4096, flags) return reply, ancdata def check_ntf(self): while True: try: reply, ancdata = self._recvmsg(socket.MSG_DONTWAIT) except BlockingIOError: return nsid = self._decode_nsid(ancdata) nms = NlMsgs(reply) self._recv_dbg_print(reply, nms) for nl_msg in nms: if nl_msg.error: print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) print(nl_msg) continue if nl_msg.done: print("Netlink done while checking for ntf!?") continue decoded = self.nlproto.decode(self, nl_msg, None) if decoded.cmd() not in self.async_msg_ids: print("Unexpected msg id while checking for ntf", decoded) continue self.handle_ntf(decoded, nsid) def poll_ntf(self, duration=None): start_time = time.time() selector = selectors.DefaultSelector() selector.register(self.sock, selectors.EVENT_READ) while True: try: yield self.async_msg_queue.get_nowait() except queue.Empty: if duration is not None: timeout = start_time + duration - time.time() if timeout <= 0: return else: timeout = None events = selector.select(timeout) if events: self.check_ntf() def operation_do_attributes(self, name): """ For a given operation name, find and return a supported set of attributes (as a dict). """ op = self.find_operation(name) if not op: return None return op['do']['request']['attributes'].copy() def _encode_message(self, op, vals, flags, req_seq): nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK for flag in flags or []: nl_flags |= flag msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) if op.fixed_header: msg += self._encode_struct(op.fixed_header, vals) search_attrs = SpaceAttrs(op.attr_set, vals) for name, value in vals.items(): msg += self._add_attr(op.attr_set.name, name, value, search_attrs) msg = _genl_msg_finalize(msg) return msg # pylint: disable=too-many-statements def _ops(self, ops): reqs_by_seq = {} req_seq = random.randint(1024, 65535) payload = b'' for (method, vals, flags) in ops: op = self.ops[method] msg = self._encode_message(op, vals, flags, req_seq) reqs_by_seq[req_seq] = (op, vals, msg, flags) payload += msg req_seq += 1 self.sock.send(payload, 0) done = False rsp = [] op_rsp = [] while not done: reply, ancdata = self._recvmsg() nsid = self._decode_nsid(ancdata) nms = NlMsgs(reply) self._recv_dbg_print(reply, nms) for nl_msg in nms: if nl_msg.nl_seq in reqs_by_seq: (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] if nl_msg.extack: nl_msg.annotate_extack(op.attr_set) self._decode_extack(req_msg, op, nl_msg.extack, vals) else: op = None req_flags = [] if nl_msg.error: raise NlError(nl_msg) if nl_msg.done: if nl_msg.extack: print("Netlink warning:") print(nl_msg) if Netlink.NLM_F_DUMP in req_flags: rsp.append(op_rsp) elif not op_rsp: rsp.append(None) elif len(op_rsp) == 1: rsp.append(op_rsp[0]) else: rsp.append(op_rsp) op_rsp = [] del reqs_by_seq[nl_msg.nl_seq] done = len(reqs_by_seq) == 0 break decoded = self.nlproto.decode(self, nl_msg, op) # Check if this is a reply to our request if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: if decoded.cmd() in self.async_msg_ids: self.handle_ntf(decoded, nsid) continue print('Unexpected message: ' + repr(decoded)) continue rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) if op.fixed_header: rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) op_rsp.append(rsp_msg) return rsp def _op(self, method, vals, flags=None, dump=False): req_flags = flags or [] if dump: req_flags.append(Netlink.NLM_F_DUMP) ops = [(method, vals, req_flags)] return self._ops(ops)[0] def do(self, method, vals, flags=None): return self._op(method, vals, flags) def dump(self, method, vals): return self._op(method, vals, dump=True) def do_multi(self, ops): return self._ops(ops) def get_policy(self, op_name, mode): """Query running kernel for the Netlink policy of an operation. Allows checking whether kernel supports a given attribute or value. This method consults the running kernel, not the YAML spec. Args: op_name: operation name as it appears in the YAML spec mode: 'do' or 'dump' Returns: NlPolicy acting as a read-only dict mapping attribute names to their policy properties (type, min/max, nested, etc.), or None if the operation has no policy for the given mode. Empty policy usually implies that the operation rejects all attributes. """ op = self.ops[op_name] op_policy, policy_table = _genl_policy_dump(self.nlproto.family_id, op.req_value) if mode not in op_policy: return None policy_idx = op_policy[mode] return NlPolicy(self, policy_idx, policy_table, op.attr_set)