trxcon/l1sched: clarify TDMA Fn (mod 26) maps
[osmocom-bb.git] / src / target / trx_toolkit / codec.py
blobc57060096889db41cc55deb7691b523039c994a2
1 # -*- coding: utf-8 -*-
3 '''
4 Very simple (performance oriented) declarative message codec.
5 Inspired by Pycrate and Scapy.
6 '''
8 # TRX Toolkit
10 # (C) 2021 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
11 # Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
13 # All Rights Reserved
15 # This program is free software; you can redistribute it and/or modify
16 # it under the terms of the GNU General Public License as published by
17 # the Free Software Foundation; either version 2 of the License, or
18 # (at your option) any later version.
20 # This program is distributed in the hope that it will be useful,
21 # but WITHOUT ANY WARRANTY; without even the implied warranty of
22 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23 # GNU General Public License for more details.
25 from typing import Optional, Callable, Tuple, Any
26 import abc
class ProtocolError(Exception):
    ''' Raised when a protocol (message or field) definition is malformed. '''
class DecodeError(Exception):
    ''' Raised when a field or a message cannot be decoded from bytes. '''
class EncodeError(Exception):
    ''' Raised when a field or a message cannot be encoded into bytes. '''
class Codec(abc.ABC):
    ''' Abstract interface: anything that can be decoded from / encoded to bytes. '''

    @abc.abstractmethod
    def from_bytes(self, vals: dict, data: bytes) -> int:
        ''' Parse value(s) from `data`, store them in `vals`, return an int. '''

    @abc.abstractmethod
    def to_bytes(self, vals: dict) -> bytes:
        ''' Serialize value(s) taken from `vals` and return the octets. '''
class Field(Codec):
    ''' Base class representing one field in a Message.

    Implements the common part of decoding/encoding (presence check,
    length resolution and verification) around the _from_bytes() and
    _to_bytes() hooks, which derived field types must provide.

    Keyword arguments accepted by the constructor:
      len -- fixed length in octets, 0 means "the whole buffer"
             (DEF_LEN is used if omitted);
      <key of DEF_PARAMS> -- overrides the respective default in self.p.
    '''

    # Default length (0 means the whole buffer)
    DEF_LEN = 0 # type: int

    # Default parameters
    DEF_PARAMS = { } # type: dict

    # The following callables are stored per instance; the defaults
    # are installed by __init__() and may be re-assigned afterwards.

    # Presence of a field during decoding and encoding
    ## get_pres: Callable[[dict], bool]
    # Length of a field for self.from_bytes()
    ## get_len: Callable[[dict, bytes], int]
    # Value of a field for self.to_bytes()
    ## get_val: Callable[[dict], Any]

    def __init__(self, name: str, **kw) -> None:
        self.name = name

        self.len = kw.get('len', self.DEF_LEN)
        if self.len == 0: # flexible field
            self.get_len = lambda _, data: len(data)
        else: # fixed length
            # self.len is looked up on every call (not captured by value)
            self.get_len = lambda vals, _: self.len

        # Field is unconditionally present by default
        self.get_pres = lambda vals: True
        # Field takes its value from the given dict by default
        self.get_val = lambda vals: vals[self.name]

        # Additional parameters for derived field types
        self.p = { key : kw.get(key, self.DEF_PARAMS[key])
                   for key in self.DEF_PARAMS }

    def from_bytes(self, vals: dict, data: bytes) -> int:
        ''' Decode the field from the beginning of the given buffer.

        Returns the number of octets consumed (0 if the field is absent).
        Raises DecodeError if the resolved length is negative or exceeds
        the length of the given buffer.
        '''
        if self.get_pres(vals) is False:
            return 0
        length = self.get_len(vals, data)
        # A negative length (e.g. computed from a bogus length indicator)
        # would slice the buffer from the wrong end and make the caller's
        # offset go backwards, so reject it explicitly.
        if length < 0:
            raise DecodeError('Negative length')
        if len(data) < length:
            raise DecodeError('Short read')
        self._from_bytes(vals, data[:length])
        return length

    def to_bytes(self, vals: dict) -> bytes:
        ''' Encode the field into bytes (b'' if the field is absent).

        Raises EncodeError if a fixed-length field gets encoded into
        a different number of octets.
        '''
        if self.get_pres(vals) is False:
            return b''
        data = self._to_bytes(vals)
        if self.len > 0 and len(data) != self.len:
            raise EncodeError('Field length mismatch')
        return data

    @abc.abstractmethod
    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Decode value(s) from the given buffer of bytes. '''
        raise NotImplementedError

    @abc.abstractmethod
    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode value(s) into bytes. '''
        raise NotImplementedError
class Buf(Field):
    ''' Opaque octet string, stored and emitted as-is. '''

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        # Field.from_bytes() has already cut `data` to the field length
        vals[self.name] = data

    def _to_bytes(self, vals: dict) -> bytes:
        # TODO: handle len(self.get_val()) < self.get_len()
        octets = self.get_val(vals)
        return octets
class Spare(Field):
    ''' Filler octets: skipped on decoding, generated on encoding. '''

    # Default parameters
    DEF_PARAMS = {
        'filler' : b'\x00',
    }

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        # Nothing to store: the content of a spare field is not decoded
        return None

    def _to_bytes(self, vals: dict) -> bytes:
        # No buffer is available when encoding, hence the empty one here
        count = self.get_len(vals, b'')
        return self.p['filler'] * count
class Uint(Field):
    ''' Unsigned big endian integer field (one octet by default).

    The value seen by the user is (raw * mult + offset), where
    'mult' and 'offset' are optional constructor parameters.
    '''

    # Uint8 by default
    DEF_LEN = 1

    # Default parameters
    DEF_PARAMS = {
        'offset' : 0,
        'mult' : 1,
    }

    # Big endian, unsigned
    SIGN = False
    BO = 'big'

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        raw = int.from_bytes(data, self.BO, signed=self.SIGN)
        scaled = raw * self.p['mult']
        vals[self.name] = scaled + self.p['offset']

    def _to_bytes(self, vals: dict) -> bytes:
        # Reverse of the transformation applied in _from_bytes()
        unbiased = self.get_val(vals) - self.p['offset']
        raw = unbiased // self.p['mult']
        return raw.to_bytes(self.len, self.BO, signed=self.SIGN)
class Uint16BE(Uint):
    ''' Unsigned integer: 16 bits, big endian. '''
    DEF_LEN = 2 # octets


class Uint16LE(Uint16BE):
    ''' Unsigned integer: 16 bits, little endian. '''
    BO = 'little'


class Uint32BE(Uint):
    ''' Unsigned integer: 32 bits, big endian. '''
    DEF_LEN = 4 # octets


class Uint32LE(Uint32BE):
    ''' Unsigned integer: 32 bits, little endian. '''
    BO = 'little'
class Int(Uint):
    ''' Signed counterpart of Uint (same defaults otherwise). '''
    SIGN = True


class Int16BE(Int):
    ''' Signed integer: 16 bits, big endian. '''
    DEF_LEN = 2 # octets


class Int16LE(Int16BE):
    ''' Signed integer: 16 bits, little endian. '''
    BO = 'little'


class Int32BE(Int):
    ''' Signed integer: 32 bits, big endian. '''
    DEF_LEN = 4 # octets


class Int32LE(Int32BE):
    ''' Signed integer: 32 bits, little endian. '''
    BO = 'little'
class BitFieldSet(Field):
    ''' A set of bit-fields packed into a whole number of octets.

    The field name is the name of the (derived) class.

    Keyword arguments accepted by the constructor:
      set   -- tuple of bit-fields (STRUCT is used if omitted);
      order -- 'big' (MSB first, default) or 'little' / 'lsb' (LSB first);
      len   -- overall length in octets; if omitted (0), it is calculated
               from the bit lengths of the fields, rounded up.

    Raises ProtocolError if the set is not a tuple, or if the fields
    do not fit into the overall length.
    '''

    # Default parameters
    DEF_PARAMS = {
        # Default field order (MSB first)
        'order' : 'big',
    }

    # To be defined by derived types
    STRUCT = () # type: Tuple['BitField', ...]

    def __init__(self, **kw) -> None:
        Field.__init__(self, self.__class__.__name__, **kw)

        self._fields = kw.get('set', self.STRUCT)
        # isinstance() (rather than an exact type match) so that
        # tuple subclasses, e.g. named tuples, are accepted too
        if not isinstance(self._fields, tuple):
            raise ProtocolError('Expected a tuple')

        # LSB first is basically reversed order
        if self.p['order'] in ('little', 'lsb'):
            self._fields = self._fields[::-1]

        # Calculate the overall field length (in octets, rounded up)
        if self.len == 0:
            bl_sum = sum(f.bl for f in self._fields)
            self.len = (bl_sum + 7) // 8

        # Re-define self.get_len() since we always know the length
        self.get_len = lambda vals, data: self.len

        # Pre-calculate offset and mask for each field.
        # NOTE: both are stored in the bit-field objects themselves.
        offset = self.len * 8
        for f in self._fields:
            if f.bl > offset:
                raise ProtocolError(f, 'BitFieldSet overflow')
            f.offset = offset - f.bl
            f.mask = 2 ** f.bl - 1
            offset -= f.bl

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Decode all bit-fields of the set from the given octets. '''
        blob = int.from_bytes(data, byteorder='big') # intentionally using 'big' here
        for f in self._fields:
            f.dec_val(vals, blob)

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode all bit-fields of the set into self.len octets. '''
        blob = 0x00
        for f in self._fields:
            blob |= f.enc_val(vals)
        return blob.to_bytes(self.len, byteorder='big')
class BitField:
    ''' A single named bit-field, to be used as a member of a BitFieldSet. '''

    # Position (shift) and mask of the field within the set;
    # both get assigned by BitFieldSet
    offset = 0 # type: int
    mask = 0 # type: int

    class Spare:
        ''' Unnamed padding bits: encoded as zeroes, ignored when decoding. '''

        def __init__(self, bl: int) -> None:
            self.name = None
            self.bl = bl

        def enc_val(self, vals: dict) -> int:
            return 0

        def dec_val(self, vals: dict, blob: int) -> None:
            return None # Nothing to decode

    def __init__(self, name: str, bl: int, **kw) -> None:
        # A bit-field must be at least one bit long
        if bl < 1:
            raise ProtocolError('Incorrect bit-field length')

        self.name = name
        self.bl = bl

        # (Optional) fixed value for encoding and decoding
        self.val = kw.get('val') # type: Optional[int]

    def enc_val(self, vals: dict) -> int:
        ''' Return the field value shifted to its position in the set. '''
        # The fixed value (if any) takes precedence over the given dict
        val = vals[self.name] if self.val is None else self.val
        return (val & self.mask) << self.offset

    def dec_val(self, vals: dict, blob: int) -> None:
        ''' Extract the field value from the given blob into the dict.

        Raises DecodeError if a fixed value is expected and does not match.
        '''
        decoded = (blob >> self.offset) & self.mask
        vals[self.name] = decoded
        if self.val is None or decoded == self.val:
            return
        raise DecodeError('Unexpected value %d, expected %d'
                          % (decoded, self.val))
class Envelope:
    ''' A group of related fields, decoded and encoded one after another.

    The decoded content is kept in self.c and is accessible
    via the subscript operators (env[key]).
    '''

    STRUCT = () # type: Tuple[Codec, ...]

    def __init__(self, check_len: bool = True):
        # TODO: ensure unique field names in self.STRUCT
        self.c = { } # type: dict
        # Whether octets left after the last field are treated as an error
        self.check_len = check_len

    def __getitem__(self, key: str) -> Any:
        return self.c[key]

    def __setitem__(self, key: str, val: Any) -> None:
        self.c[key] = val

    def __delitem__(self, key: str) -> None:
        del self.c[key]

    def check(self, vals: dict) -> None:
        ''' Check the content before encoding and after decoding.
        Raise exceptions (e.g. ValueError) if something is wrong.

        Do not assert for every possible error (e.g. a negative value
        for a Uint field) if an exception will be thrown by the field's
        to_bytes() method anyway.  Only additional constraints here.
        '''

    def from_bytes(self, data: bytes) -> int:
        ''' Decode the given buffer into self.c (replacing the old content). '''
        self.c.clear()
        return self._from_bytes(self.c, data)

    def to_bytes(self) -> bytes:
        ''' Encode the content of self.c into bytes. '''
        return self._to_bytes(self.c)

    def _from_bytes(self, vals: dict, data: bytes, offset: int = 0) -> int:
        ''' Decode all fields of STRUCT into `vals`, return the final offset.

        Any exception thrown by a field is re-raised as DecodeError.
        '''
        try:
            for field in self.STRUCT:
                offset += field.from_bytes(vals, data[offset:])
        except Exception as exc:
            # Add contextual info: the envelope, the field and the offset
            raise DecodeError(self, field, offset) from exc

        if self.check_len and offset != len(data):
            tail = data[offset:]
            raise DecodeError(self, 'Unhandled tail octets: %s'
                              % tail.hex())

        # Check the content after decoding (raises exceptions)
        self.check(vals)
        return offset

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode all fields of STRUCT using the values from `vals`.

        Any exception thrown by a field is re-raised as EncodeError.
        '''
        # Check the content before encoding (raises exceptions)
        self.check(vals)

        chunks = []
        for field in self.STRUCT:
            try:
                chunk = field.to_bytes(vals)
            except Exception as exc:
                # Add contextual info: the envelope and the field
                raise EncodeError(self, field) from exc
            chunks.append(chunk)
        return b''.join(chunks)

    class F(Field):
        ''' Adapter exposing an Envelope as a Field. '''

        def __init__(self, e: 'Envelope', name: str, **kw) -> None:
            Field.__init__(self, name, **kw)
            self.e = e

        def _from_bytes(self, vals: dict, data: bytes) -> None:
            # Nested content goes into a dict of its own
            nested = { } # type: dict
            vals[self.name] = nested
            self.e._from_bytes(nested, data)

        def _to_bytes(self, vals: dict) -> bytes:
            nested = self.get_val(vals)
            return self.e._to_bytes(nested)

    def f(self, name: str, **kw) -> Field:
        ''' Wrap this envelope into a Field with the given name. '''
        return self.F(self, name, **kw)
class Sequence:
    ''' A sequence of repeating elements (e.g. TLVs).

    The element is an Envelope given either as the ITEM attribute
    of a derived class, or as the 'item' keyword argument.
    '''

    # The item of sequence
    ITEM = None # type: Optional[Envelope]

    def __init__(self, **kw) -> None:
        item = kw.get('item', self.ITEM)
        # Also covers an explicitly given item=None, which would
        # otherwise result in an AttributeError below
        if item is None:
            raise ProtocolError('Missing Sequence item')
        self._item = item # type: Envelope
        # An item is normally followed by further items, so the tail
        # octets check of the item itself has to be disabled.
        # NOTE: this modifies the given Envelope object.
        self._item.check_len = False

    def from_bytes(self, data: bytes) -> list:
        ''' Decode items until the given buffer is exhausted.

        Returns a list of dicts, one per decoded item.  Raises DecodeError
        if an item consumes no octets, since the decoding would otherwise
        never terminate.
        '''
        proc = self._item._from_bytes
        vseq, offset = [], 0
        length = len(data)

        while offset < length:
            vals = { } # new item of sequence
            consumed = proc(vals, data[offset:])
            if consumed <= 0:
                raise DecodeError(self, 'Sequence item consumed no octets')
            vseq.append(vals)
            offset += consumed

        return vseq

    def to_bytes(self, vseq: list) -> bytes:
        ''' Encode the given list of dicts, one item per dict. '''
        proc = self._item._to_bytes
        return b''.join([proc(v) for v in vseq])

    class F(Field):
        ''' Field wrapper. '''

        def __init__(self, s: 'Sequence', name: str, **kw) -> None:
            Field.__init__(self, name, **kw)
            self.s = s

        def _from_bytes(self, vals: dict, data: bytes) -> None:
            vals[self.name] = self.s.from_bytes(data)

        def _to_bytes(self, vals: dict) -> bytes:
            return self.s.to_bytes(self.get_val(vals))

    def f(self, name: str, **kw) -> Field:
        ''' Wrap this sequence into a Field with the given name. '''
        return self.F(self, name, **kw)