drm/dp_mst: Add helper to get port number at specific LCT from RAD
[drm/drm-misc.git] / tools / net / ynl / lib / ynl.py
blob01ec01a90e763ce62645eb91d7ee6e53de283d94
1 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
3 from collections import namedtuple
4 from enum import Enum
5 import functools
6 import os
7 import random
8 import socket
9 import struct
10 from struct import Struct
11 import sys
12 import yaml
13 import ipaddress
14 import uuid
15 import queue
16 import selectors
17 import time
19 from .nlspec import SpecFamily
22 # Generic Netlink code which should really be in some library, but I can't quickly find one.
class Netlink:
    """Numeric constants for netlink and generic netlink used by this module."""

    # Netlink socket level and socket options
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message types
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Netlink message flags. The three groups after REQUEST/ACK reuse the
    # same bit positions (0x100, 0x200, ...).
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9

    NLM_F_REPLACE = 1 << 8
    NLM_F_EXCL = 1 << 9
    NLM_F_CREATE = 1 << 10
    NLM_F_APPEND = 1 << 11

    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits stored in the attribute type field; NLA_TYPE_MASK covers
    # both so they can be stripped from the type.
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute type names, numbered from 1 in the order listed; used to
    # turn the numeric type found in a policy dump into a name.
    AttrType = Enum('AttrType', [
        'flag', 'u8', 'u16', 'u32', 'u64',
        's8', 's16', 's32', 's64',
        'binary', 'string', 'nul-string',
        'nested', 'nested-array',
        'bitfield32', 'sint', 'uint',
    ])
class NlError(Exception):
    """Exception carrying a netlink message that reported an error.

    ``nl_msg`` is the message object; ``error`` is its ``error`` field
    negated, which is then passed to os.strerror() when formatting.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        description = os.strerror(self.error)
        return f"Netlink error: {description}\n{self.nl_msg}"
class ConfigError(Exception):
    """Raised when the configuration passed in by the caller is not usable."""
class NlAttr:
    """One netlink attribute (4-byte header plus payload) read from a buffer."""

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats per scalar type, for each byte order
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute that starts at @offset inside @raw."""
        header_end = offset + 4
        self._len, self._type = struct.unpack("HH", raw[offset:header_end])
        # The type field carries flag bits; strip them for the plain type
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        # Note: this is the length from the header, it includes the header
        self.payload_len = self._len
        # Length rounded up to a multiple of 4
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[header_end:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for @attr_type in the requested byte order.

        A falsy @byte_order selects the native format, "big-endian" the
        big-endian one and any other value the little-endian one.
        """
        formats = cls.type_formats[attr_type]
        if not byte_order:
            return formats.native
        if byte_order == "big-endian":
            return formats.big
        return formats.little

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a single scalar of @attr_type."""
        fmt = self.get_format(attr_type, byte_order)
        (value,) = fmt.unpack(self.raw)
        return value

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a 4 or 8 byte payload, picking the width from its length.

        The first character of @attr_type supplies the signedness prefix
        of the concrete type name (e.g. 'u' + '32').
        """
        size = len(self.raw)
        if size not in (4, 8):
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(size * 8)
        fmt = self.get_format(real_type, byte_order)
        (value,) = fmt.unpack(self.raw)
        return value

    def as_strz(self):
        """Decode the payload as ASCII and drop its last character."""
        text = self.raw.decode('ascii')
        return text[:-1]

    def as_bin(self):
        """Return the payload bytes unchanged."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a list of native-order scalars of @type."""
        fmt = self.get_format(type)
        return [item[0] for item in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
class NlAttrs:
    """Sequence of NlAttr objects parsed back to back from a buffer."""

    def __init__(self, msg, offset=0):
        """Parse attributes from @msg, starting at @offset, to its end.

        Raises:
            Exception: if an attribute reports a length of zero. Such an
                attribute would never advance the offset, so parsing
                would otherwise loop forever.
        """
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if not attr.full_len:
                raise Exception(f"Zero-length netlink attribute at offset {offset}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)
class NlMsg:
    """One netlink message: 16-byte header, payload and decoded extack."""

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message that starts at @offset inside @msg.

        @attr_space, when given, is used to replace a numeric 'miss-type'
        extack value with the attribute name (and doc) from the spec.
        """
        self.hdr = msg[offset : offset + 16]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Both ERROR and DONE messages start with a signed 32-bit code;
        # only the offset at which extack attributes begin differs.
        # NOTE(review): 20 presumably skips the code plus an echoed
        # 16-byte request header - confirm against the netlink docs.
        extack_off = None
        if self.nl_type in (Netlink.NLMSG_ERROR, Netlink.NLMSG_DONE):
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 20 if self.nl_type == Netlink.NLMSG_ERROR else 4

        self.extack = None
        if not (self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off):
            return

        self.extack = dict()
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }
        for extack in NlAttrs(self.raw[extack_off:]):
            if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                self.extack['msg'] = extack.as_strz()
            elif extack.type in u32_keys:
                self.extack[u32_keys[extack.type]] = extack.as_scalar('u32')
            elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                self.extack['policy'] = self._decode_policy(extack.raw)
            else:
                self.extack.setdefault('unknown', []).append(extack)

        if not attr_space:
            return
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type in attr_space.attrs_by_val:
            spec = attr_space.attrs_by_val[miss_type]
            self.extack['miss-type'] = spec['name']
            if 'doc' in spec:
                self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode a nested policy description into a dict."""
        scalar_fields = {
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
        }
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_id).name
            elif attr.type in scalar_fields:
                key, scalar_type = scalar_fields[attr.type]
                policy[key] = attr.as_scalar(scalar_type)
        return policy

    def cmd(self):
        """Return the netlink message type."""
        return self.nl_type

    def __repr__(self):
        parts = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            parts.append('\n\terror: ' + str(self.error))
        if self.extack:
            parts.append('\n\textack: ' + repr(self.extack))
        return ''.join(parts)
class NlMsgs:
    """Sequence of NlMsg objects parsed back to back from a receive buffer."""

    def __init__(self, data, attr_space=None):
        """Parse every message contained in @data.

        @attr_space is passed through to each NlMsg.

        Raises:
            Exception: if a message reports a length smaller than the
                16-byte header. Such a message would make the parser
                loop forever or re-read overlapping bytes.
        """
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                raise Exception(
                    f"Invalid netlink message length {msg.nl_len} at offset {offset}")
            # Messages in a buffer start on 4-byte boundaries
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
# Generic netlink families keyed by family name. Stays None until
# _genl_load_families() replaces it with a dict.
genl_family_name_to_id = None
296 def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
297 # we prepend length in _genl_msg_finalize()
298 if seq is None:
299 seq = random.randint(1, 1024)
300 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
301 genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
302 return nlmsg + genlmsg
305 def _genl_msg_finalize(msg):
306 return struct.pack("I", len(msg) + 4) + msg
def _genl_load_families():
    """Dump the generic netlink families and cache them by name.

    Fills the module-level genl_family_name_to_id dict; each entry holds
    the attributes seen for the family ('id', 'name', 'maxattr',
    'mcast'). Returns early, after printing the code, if a message in
    the reply reports an error.
    """
    global genl_family_name_to_id

    def parse_mcast_groups(groups_attr):
        # Each entry is a nest holding a group name and a group id
        groups = dict()
        for entry in NlAttrs(groups_attr.raw):
            grp_name = None
            grp_id = None
            for entry_attr in NlAttrs(entry.raw):
                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                    grp_name = entry_attr.as_strz()
                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                    grp_id = entry_attr.as_scalar('u32')
            if grp_name and grp_id is not None:
                groups[grp_name] = grp_id
        return groups

    request_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg(Netlink.GENL_ID_CTRL, request_flags,
                            Netlink.CTRL_CMD_GETFAMILY, 1)
        request = _genl_msg_finalize(request)

        sock.send(request, 0)

        # Only reset the cache once the request has been sent
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = parse_mcast_groups(attr)
                # Families missing a name or an id cannot be looked up
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
class GenlMsg:
    """Generic netlink view of a netlink message.

    Splits the 4-byte genetlink header (command, version, reserved) off
    the front of ``nl_msg.raw``; the remainder is kept in ``raw``.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the generic netlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is not set by __init__; it is attached later by
        # NetlinkProtocol.decode(), so it may be missing here.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
class NetlinkProtocol:
    """Message framing for a raw netlink family."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack a netlink header without its leading length field.

        A random sequence number in [1, 1024] is used when @seq is None.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; @command is used as the message type.

        @version is not used here.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Return @nl_msg unchanged; subclasses may wrap it."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach parsed attributes to the message as ``raw_attrs``.

        When @op is None it is looked up in ynl.rsp_by_value using the
        message's command. Attribute parsing starts after the fixed
        header of the operation.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        attrs_offset = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, attrs_offset)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the value of multicast group @mcast_name from the spec."""
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        group = mcast_groups[mcast_name]
        return group.value

    def msghdr_size(self):
        """Return the size in bytes of the message header (16)."""
        return 16
class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family."""

    def __init__(self, family_name):
        """Resolve @family_name through the family cache.

        The cache is loaded on first use. A KeyError propagates if the
        family is not in it.
        """
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Build netlink + genetlink headers for @command."""
        nl_header = self._message(self.family_id, flags, seq)
        genl_header = struct.pack("BBH", command, version, 0)
        return nl_header + genl_header

    def _decode(self, nl_msg):
        """Wrap @nl_msg in a GenlMsg."""
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of @mcast_name from the cached family entry.

        @mcast_groups is not used here.
        """
        known_groups = self.genl_family['mcast']
        if mcast_name not in known_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return known_groups[mcast_name]

    def msghdr_size(self):
        """Return the base header size plus 4 bytes of genetlink header."""
        return super().msghdr_size() + 4
class SpaceAttrs:
    """Chain of (attribute space, values) scopes, innermost first."""

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        """Create a scope for @attr_space / @attrs, nested inside @outer."""
        innermost = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [innermost]
        if outer:
            self.scopes += outer.scopes

    def lookup(self, name):
        """Return the value of @name from the first scope that defines it.

        Raises an Exception if the first scope defining @name holds no
        value for it, or if no scope defines @name at all.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            spec_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
457 # YNL implementation details.
461 class YnlFamily(SpecFamily):
def __init__(self, def_path, schema=None, process_unknown=False,
             recv_size=0):
    """Load the spec, open the netlink socket and bind operation methods.

    @recv_size is the buffer size for recv(); 0 selects 131072. Values
    below 4000 raise ConfigError. One attribute per operation, named by
    the operation's ident_name, is set on the instance.
    """
    super().__init__(def_path, schema)

    self.include_raw = False
    self.process_unknown = process_unknown

    try:
        if self.proto == "netlink-raw":
            self.nlproto = NetlinkProtocol(self.yaml['name'],
                                           self.yaml['protonum'])
        else:
            self.nlproto = GenlProtocol(self.yaml['name'])
    except KeyError:
        raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    self._recv_dbg = False
    # Note that netlink will use conservative (min) message size for
    # the first dump recv() on the socket, our setting will only matter
    # from the second recv() on.
    self._recv_size = recv_size or 131072
    # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
    # for a message, so smaller receive sizes will lead to truncation.
    # Note that the min size for other families may be larger than 4k!
    if self._recv_size < 4000:
        raise ConfigError()

    self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
    for option in (Netlink.NETLINK_CAP_ACK,
                   Netlink.NETLINK_EXT_ACK,
                   Netlink.NETLINK_GET_STRICT_CHK):
        self.sock.setsockopt(Netlink.SOL_NETLINK, option, 1)

    # Response values of the messages the spec marks as asynchronous
    self.async_msg_ids = set()
    self.async_msg_queue = queue.Queue()

    for msg in self.msgs.values():
        if msg.is_async:
            self.async_msg_ids.add(msg.rsp_value)

    # Expose every operation as a method: family.<ident_name>(vals, ...)
    for op_name, op in self.ops.items():
        setattr(self, op.ident_name, functools.partial(self._op, op_name))
def ntf_subscribe(self, mcast_name):
    """Bind the socket and join the multicast group named @mcast_name."""
    group_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
    self.sock.bind((0, 0))
    self.sock.setsockopt(Netlink.SOL_NETLINK,
                         Netlink.NETLINK_ADD_MEMBERSHIP,
                         group_id)
def set_recv_dbg(self, enabled):
    """Turn printing of received messages to stderr on or off."""
    self._recv_dbg = enabled
515 def _recv_dbg_print(self, reply, nl_msgs):
516 if not self._recv_dbg:
517 return
518 print("Recv: read", len(reply), "bytes,",
519 len(nl_msgs.msgs), "messages", file=sys.stderr)
520 for nl_msg in nl_msgs:
521 print(" ", nl_msg, file=sys.stderr)
523 def _encode_enum(self, attr_spec, value):
524 enum = self.consts[attr_spec['enum']]
525 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
526 scalar = 0
527 if isinstance(value, str):
528 value = [value]
529 for single_value in value:
530 scalar += enum.entries[single_value].user_value(as_flags = True)
531 return scalar
532 else:
533 return enum.entries[value].user_value()
535 def _get_scalar(self, attr_spec, value):
536 try:
537 return int(value)
538 except (ValueError, TypeError) as e:
539 if 'enum' not in attr_spec:
540 raise e
541 return self._encode_enum(attr_spec, value)
def _add_attr(self, space, name, value, search_attrs):
    """Encode attribute @name of attribute set @space into bytes.

    @value is the user-supplied value. @search_attrs is the SpaceAttrs
    chain used to resolve sub-message selectors. Returns the attribute
    header, payload and padding to a 4-byte boundary. A multi-attr given
    a list yields the concatenation of one attribute per element; a
    falsy flag yields b''.

    Raises an Exception for an unknown attribute name, an unsupported
    attribute type or an unknown sub-message attribute set.
    """
    try:
        attr = self.attr_sets[space][name]
    except KeyError:
        raise Exception(f"Space '{space}' has no attribute '{name}'")
    nl_type = attr.value

    if attr.is_multi and isinstance(value, list):
        attr_payload = b''
        for subvalue in value:
            attr_payload += self._add_attr(space, name, subvalue, search_attrs)
        return attr_payload

    if attr["type"] == 'nest':
        nl_type |= Netlink.NLA_F_NESTED
        attr_payload = b''
        # The nested values belong to the nested attribute set, so that
        # set (not the parent one) must back the lookup scope.
        sub_space = attr['nested-attributes']
        sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
        for subname, subvalue in value.items():
            attr_payload += self._add_attr(sub_space,
                                           subname, subvalue, sub_attrs)
    elif attr["type"] == 'flag':
        if not value:
            # If value is absent or false then skip attribute creation.
            return b''
        attr_payload = b''
    elif attr["type"] == 'string':
        attr_payload = str(value).encode('ascii') + b'\x00'
    elif attr["type"] == 'binary':
        if isinstance(value, bytes):
            attr_payload = value
        elif isinstance(value, str):
            attr_payload = bytes.fromhex(value)
        elif isinstance(value, dict) and attr.struct_name:
            attr_payload = self._encode_struct(attr.struct_name, value)
        else:
            raise Exception(f'Unknown type for binary attribute, value: {value}')
    elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
        scalar = self._get_scalar(attr, value)
        if attr.is_auto_scalar:
            # Auto-sized scalars go out as 32 bit when the value fits
            attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
        else:
            attr_type = attr["type"]
        fmt = NlAttr.get_format(attr_type, attr.byte_order)
        attr_payload = fmt.pack(scalar)
    elif attr['type'] == "bitfield32":
        scalar_value = self._get_scalar(attr, value["value"])
        scalar_selector = self._get_scalar(attr, value["selector"])
        attr_payload = struct.pack("II", scalar_value, scalar_selector)
    elif attr['type'] == 'sub-message':
        msg_format = self._resolve_selector(attr, search_attrs)
        attr_payload = b''
        if msg_format.fixed_header:
            attr_payload += self._encode_struct(msg_format.fixed_header, value)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                nl_type |= Netlink.NLA_F_NESTED
                # SpaceAttrs needs the attribute set object, not its name
                sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                       value, search_attrs)
                for subname, subvalue in value.items():
                    attr_payload += self._add_attr(msg_format.attr_set,
                                                   subname, subvalue, sub_attrs)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
    else:
        raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

    pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
    return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
611 def _decode_enum(self, raw, attr_spec):
612 enum = self.consts[attr_spec['enum']]
613 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
614 i = 0
615 value = set()
616 while raw:
617 if raw & 1:
618 value.add(enum.entries_by_val[i].name)
619 raw >>= 1
620 i += 1
621 else:
622 value = enum.entries_by_val[raw].name
623 return value
625 def _decode_binary(self, attr, attr_spec):
626 if attr_spec.struct_name:
627 decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
628 elif attr_spec.sub_type:
629 decoded = attr.as_c_array(attr_spec.sub_type)
630 else:
631 decoded = attr.as_bin()
632 if attr_spec.display_hint:
633 decoded = self._formatted_string(decoded, attr_spec.display_hint)
634 return decoded
def _decode_array_attr(self, attr, attr_spec):
    """Decode an array attribute into a list, one entry per element.

    Nested elements become ``{element type: decoded attrs}`` dicts;
    binary and scalar elements are appended directly, formatted when the
    spec has a display hint. Raises an Exception for other sub-types.
    """
    decoded = []
    payload = attr.raw
    offset = 0
    while offset < len(payload):
        item = NlAttr(payload, offset)
        offset += item.full_len

        sub_type = attr_spec["sub-type"]
        if sub_type == 'nest':
            subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
            decoded.append({ item.type: subattrs })
            continue

        if sub_type == 'binary':
            subattrs = item.as_bin()
        elif sub_type in NlAttr.type_formats:
            subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
        else:
            raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')

        if attr_spec.display_hint:
            subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
        decoded.append(subattrs)
    return decoded
def _decode_nest_type_value(self, attr, attr_spec):
    """Decode a nest whose wrapping attribute types carry values.

    For each name in the spec's 'type-value' list, one level of nesting
    is unwrapped and that level's attribute type is stored under the
    name. The innermost payload is then decoded with the spec's
    'nested-attributes' set and merged into the result.
    """
    decoded = {}
    current = attr
    for key in attr_spec['type-value']:
        current = NlAttr(current.raw, 0)
        decoded[key] = current.type
    nested = self._decode(NlAttrs(current.raw), attr_spec['nested-attributes'])
    decoded.update(nested)
    return decoded
670 def _decode_unknown(self, attr):
671 if attr.is_nest:
672 return self._decode(NlAttrs(attr.raw), None)
673 else:
674 return attr.as_bin()
676 def _rsp_add(self, rsp, name, is_multi, decoded):
677 if is_multi == None:
678 if name in rsp and type(rsp[name]) is not list:
679 rsp[name] = [rsp[name]]
680 is_multi = True
681 else:
682 is_multi = False
684 if not is_multi:
685 rsp[name] = decoded
686 elif name in rsp:
687 rsp[name].append(decoded)
688 else:
689 rsp[name] = [decoded]
691 def _resolve_selector(self, attr_spec, search_attrs):
692 sub_msg = attr_spec.sub_message
693 if sub_msg not in self.sub_msgs:
694 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
695 sub_msg_spec = self.sub_msgs[sub_msg]
697 selector = attr_spec.selector
698 value = search_attrs.lookup(selector)
699 if value not in sub_msg_spec.formats:
700 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
702 spec = sub_msg_spec.formats[value]
703 return spec
705 def _decode_sub_msg(self, attr, attr_spec, search_attrs):
706 msg_format = self._resolve_selector(attr_spec, search_attrs)
707 decoded = {}
708 offset = 0
709 if msg_format.fixed_header:
710 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
711 offset = self._struct_size(msg_format.fixed_header)
712 if msg_format.attr_set:
713 if msg_format.attr_set in self.attr_sets:
714 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
715 decoded.update(subdict)
716 else:
717 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
718 return decoded
720 def _decode(self, attrs, space, outer_attrs = None):
721 rsp = dict()
722 if space:
723 attr_space = self.attr_sets[space]
724 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
726 for attr in attrs:
727 try:
728 attr_spec = attr_space.attrs_by_val[attr.type]
729 except (KeyError, UnboundLocalError):
730 if not self.process_unknown:
731 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
732 attr_name = f"UnknownAttr({attr.type})"
733 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
734 continue
736 if attr_spec["type"] == 'nest':
737 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
738 decoded = subdict
739 elif attr_spec["type"] == 'string':
740 decoded = attr.as_strz()
741 elif attr_spec["type"] == 'binary':
742 decoded = self._decode_binary(attr, attr_spec)
743 elif attr_spec["type"] == 'flag':
744 decoded = True
745 elif attr_spec.is_auto_scalar:
746 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
747 elif attr_spec["type"] in NlAttr.type_formats:
748 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
749 if 'enum' in attr_spec:
750 decoded = self._decode_enum(decoded, attr_spec)
751 elif attr_spec.display_hint:
752 decoded = self._formatted_string(decoded, attr_spec.display_hint)
753 elif attr_spec["type"] == 'indexed-array':
754 decoded = self._decode_array_attr(attr, attr_spec)
755 elif attr_spec["type"] == 'bitfield32':
756 value, selector = struct.unpack("II", attr.raw)
757 if 'enum' in attr_spec:
758 value = self._decode_enum(value, attr_spec)
759 selector = self._decode_enum(selector, attr_spec)
760 decoded = {"value": value, "selector": selector}
761 elif attr_spec["type"] == 'sub-message':
762 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
763 elif attr_spec["type"] == 'nest-type-value':
764 decoded = self._decode_nest_type_value(attr, attr_spec)
765 else:
766 if not self.process_unknown:
767 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
768 decoded = self._decode_unknown(attr)
770 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
772 return rsp
774 def _decode_extack_path(self, attrs, attr_set, offset, target):
775 for attr in attrs:
776 try:
777 attr_spec = attr_set.attrs_by_val[attr.type]
778 except KeyError:
779 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
780 if offset > target:
781 break
782 if offset == target:
783 return '.' + attr_spec.name
785 if offset + attr.full_len <= target:
786 offset += attr.full_len
787 continue
788 if attr_spec['type'] != 'nest':
789 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
790 offset += 4
791 subpath = self._decode_extack_path(NlAttrs(attr.raw),
792 self.attr_sets[attr_spec['nested-attributes']],
793 offset, target)
794 if subpath is None:
795 return None
796 return '.' + attr_spec.name + subpath
798 return None
800 def _decode_extack(self, request, op, extack):
801 if 'bad-attr-offs' not in extack:
802 return
804 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
805 offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
806 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
807 extack['bad-attr-offs'])
808 if path:
809 del extack['bad-attr-offs']
810 extack['bad-attr'] = path
812 def _struct_size(self, name):
813 if name:
814 members = self.consts[name].members
815 size = 0
816 for m in members:
817 if m.type in ['pad', 'binary']:
818 if m.struct:
819 size += self._struct_size(m.struct)
820 else:
821 size += m.len
822 else:
823 format = NlAttr.get_format(m.type, m.byte_order)
824 size += format.size
825 return size
826 else:
827 return 0
829 def _decode_struct(self, data, name):
830 members = self.consts[name].members
831 attrs = dict()
832 offset = 0
833 for m in members:
834 value = None
835 if m.type == 'pad':
836 offset += m.len
837 elif m.type == 'binary':
838 if m.struct:
839 len = self._struct_size(m.struct)
840 value = self._decode_struct(data[offset : offset + len],
841 m.struct)
842 offset += len
843 else:
844 value = data[offset : offset + m.len]
845 offset += m.len
846 else:
847 format = NlAttr.get_format(m.type, m.byte_order)
848 [ value ] = format.unpack_from(data, offset)
849 offset += format.size
850 if value is not None:
851 if m.enum:
852 value = self._decode_enum(value, m)
853 elif m.display_hint:
854 value = self._formatted_string(value, m.display_hint)
855 attrs[m.name] = value
856 return attrs
858 def _encode_struct(self, name, vals):
859 members = self.consts[name].members
860 attr_payload = b''
861 for m in members:
862 value = vals.pop(m.name) if m.name in vals else None
863 if m.type == 'pad':
864 attr_payload += bytearray(m.len)
865 elif m.type == 'binary':
866 if m.struct:
867 if value is None:
868 value = dict()
869 attr_payload += self._encode_struct(m.struct, value)
870 else:
871 if value is None:
872 attr_payload += bytearray(m.len)
873 else:
874 attr_payload += bytes.fromhex(value)
875 else:
876 if value is None:
877 value = 0
878 format = NlAttr.get_format(m.type, m.byte_order)
879 attr_payload += format.pack(value)
880 return attr_payload
882 def _formatted_string(self, raw, display_hint):
883 if display_hint == 'mac':
884 formatted = ':'.join('%02x' % b for b in raw)
885 elif display_hint == 'hex':
886 if isinstance(raw, int):
887 formatted = hex(raw)
888 else:
889 formatted = bytes.hex(raw, ' ')
890 elif display_hint in [ 'ipv4', 'ipv6' ]:
891 formatted = format(ipaddress.ip_address(raw))
892 elif display_hint == 'uuid':
893 formatted = str(uuid.UUID(bytes=raw))
894 else:
895 formatted = raw
896 return formatted
def handle_ntf(self, decoded):
    """Decode a notification message and put it on the async queue.

    The queued dict has 'name' and 'msg' keys, plus 'raw' holding
    @decoded when self.include_raw is set.
    """
    op = self.rsp_by_value[decoded.cmd()]
    attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
    if op.fixed_header:
        attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

    msg = {'raw': decoded} if self.include_raw else {}
    msg['name'] = op['name']
    msg['msg'] = attrs
    self.async_msg_queue.put(msg)
def check_ntf(self):
    """Drain the socket without blocking and queue any notifications.

    Returns once recv() would block. Error messages, done messages and
    messages whose command is not a known async id are reported with
    print() and skipped.
    """
    while True:
        try:
            reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
        except BlockingIOError:
            return

        nms = NlMsgs(reply)
        self._recv_dbg_print(reply, nms)
        for nl_msg in nms:
            if nl_msg.error:
                print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                print(nl_msg)
            elif nl_msg.done:
                print("Netlink done while checking for ntf!?")
            else:
                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() in self.async_msg_ids:
                    self.handle_ntf(decoded)
                else:
                    print("Unexpected msg id while checking for ntf", decoded)
def poll_ntf(self, duration=None):
    """Yield queued notifications, waiting on the socket for more.

    @duration is the total time to keep polling, in seconds; None polls
    forever. Already-queued notifications are yielded before the
    deadline is checked.
    """
    # A monotonic clock keeps the deadline immune to wall-clock changes
    start_time = time.monotonic()
    # The context manager closes the selector when the generator
    # finishes or is closed by the caller.
    with selectors.DefaultSelector() as selector:
        selector.register(self.sock, selectors.EVENT_READ)

        while True:
            try:
                yield self.async_msg_queue.get_nowait()
            except queue.Empty:
                if duration is not None:
                    timeout = start_time + duration - time.monotonic()
                    if timeout <= 0:
                        return
                else:
                    timeout = None
                events = selector.select(timeout)
                if events:
                    self.check_ntf()
def operation_do_attributes(self, name):
    """
    For a given operation name, find and return a supported
    set of attributes (as a dict).
    """
    op = self.find_operation(name)
    return op['do']['request']['attributes'].copy() if op else None
def _encode_message(self, op, vals, flags, req_seq):
    """Encode a complete request message for operation @op.

    REQUEST and ACK are always set; every entry of @flags is OR-ed in.
    Fixed header members are taken out of @vals by _encode_struct()
    before the remaining entries are encoded as attributes.
    """
    nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
    for flag in flags or []:
        nl_flags |= flag

    parts = [self.nlproto.message(nl_flags, op.req_value, 1, req_seq)]
    if op.fixed_header:
        parts.append(self._encode_struct(op.fixed_header, vals))
    search_attrs = SpaceAttrs(op.attr_set, vals)
    for name, value in vals.items():
        parts.append(self._add_attr(op.attr_set.name, name, value, search_attrs))
    return _genl_msg_finalize(b''.join(parts))
def _ops(self, ops):
    """Send a batch of requests and collect one response per request.

    @ops is an iterable of (method, vals, flags) tuples. All requests
    are sent in one payload with consecutive sequence numbers. The entry
    for each request is a list for dumps (or several replies), a single
    dict for one reply and None for no reply. Notifications received
    meanwhile are queued through handle_ntf().

    Raises NlError if the kernel reports an error for any message.
    """
    reqs_by_seq = {}
    req_seq = random.randint(1024, 65535)
    payload = b''
    for (method, vals, flags) in ops:
        op = self.ops[method]
        msg = self._encode_message(op, vals, flags, req_seq)
        # Normalize flags so the membership test below never sees None
        reqs_by_seq[req_seq] = (op, msg, flags or [])
        payload += msg
        req_seq += 1

    if not reqs_by_seq:
        # Nothing to send, and no op to interpret replies with
        return []

    self.sock.send(payload, 0)

    done = False
    rsp = []
    op_rsp = []
    # Attribute set used to annotate extack while parsing a buffer. It
    # follows the most recently matched request and is kept apart from
    # 'op', which becomes None for messages that match no request.
    extack_space = op.attr_set
    while not done:
        reply = self.sock.recv(self._recv_size)
        nms = NlMsgs(reply, attr_space=extack_space)
        self._recv_dbg_print(reply, nms)
        for nl_msg in nms:
            if nl_msg.nl_seq in reqs_by_seq:
                (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                extack_space = op.attr_set
                if nl_msg.extack:
                    self._decode_extack(req_msg, op, nl_msg.extack)
            else:
                op = None
                req_flags = []

            if nl_msg.error:
                raise NlError(nl_msg)
            if nl_msg.done:
                if nl_msg.extack:
                    print("Netlink warning:")
                    print(nl_msg)

                if Netlink.NLM_F_DUMP in req_flags:
                    rsp.append(op_rsp)
                elif not op_rsp:
                    rsp.append(None)
                elif len(op_rsp) == 1:
                    rsp.append(op_rsp[0])
                else:
                    rsp.append(op_rsp)
                op_rsp = []

                del reqs_by_seq[nl_msg.nl_seq]
                done = len(reqs_by_seq) == 0
                break

            decoded = self.nlproto.decode(self, nl_msg, op)

            # Check if this is a reply to our request
            if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                if decoded.cmd() in self.async_msg_ids:
                    self.handle_ntf(decoded)
                    continue
                else:
                    print('Unexpected message: ' + repr(decoded))
                    continue

            rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
            if op.fixed_header:
                rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
            op_rsp.append(rsp_msg)

    return rsp
1048 def _op(self, method, vals, flags=None, dump=False):
1049 req_flags = flags or []
1050 if dump:
1051 req_flags.append(Netlink.NLM_F_DUMP)
1053 ops = [(method, vals, req_flags)]
1054 return self._ops(ops)[0]
def do(self, method, vals, flags=None):
    """Run operation @method as a plain request and return its response."""
    return self._op(method, vals, flags=flags)
def dump(self, method, vals):
    """Run operation @method as a dump and return its response."""
    response = self._op(method, vals, dump=True)
    return response
def do_multi(self, ops):
    """Run several operations as one batch; returns the list from _ops()."""
    responses = self._ops(ops)
    return responses