# -*- coding: utf-8 -*-

'''
Very simple (performance oriented) declarative message codec.
Inspired by Pycrate and Scapy.
'''

# (C) 2021 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
# Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

from typing import Optional, Callable, Tuple, Any

class CodecError(Exception):
    ''' Common base class of all errors raised by this codec.

    Catching CodecError handles protocol definition, decoding and
    encoding errors with a single except clause.
    '''


class ProtocolError(CodecError):
    ''' Error in a protocol definition. '''


class DecodeError(CodecError):
    ''' Error during decoding of a field/message. '''


class EncodeError(CodecError):
    ''' Error during encoding of a field/message. '''

class Codec:
    ''' Base class providing encoding and decoding API. '''

    def from_bytes(self, vals: dict, data: bytes) -> int:
        ''' Decode value(s) from the given buffer of bytes.

        Decoded value(s) are stored into 'vals'; the number of
        consumed octets is returned.
        '''
        raise NotImplementedError

    def to_bytes(self, vals: dict) -> bytes:
        ''' Encode value(s) taken from 'vals' into bytes. '''
        raise NotImplementedError


class Field(Codec):
    ''' Base class representing one field in a Message.

    Derived types implement _from_bytes() / _to_bytes(); this class deals
    with presence, length and length checking.

    Keyword arguments:
      len -- fixed length in octets; 0 (the default) means a flexible
             field consuming the whole remaining buffer when decoding;
      any key of DEF_PARAMS -- overrides that default, see self.p.
    '''

    # Default length (0 means the whole buffer)
    DEF_LEN = 0 # type: int

    # Default parameters for derived field types (see self.p)
    DEF_PARAMS = { } # type: dict

    # Presence of a field during decoding and encoding
    ## get_pres: Callable[[dict], bool]
    # Length of a field for self.from_bytes()
    ## get_len: Callable[[dict, bytes], int]
    # Value of a field for self.to_bytes()
    ## get_val: Callable[[dict], Any]

    def __init__(self, name: str, **kw) -> None:
        # Key under which this field's value is stored in 'vals'
        self.name = name

        self.len = kw.get('len', self.DEF_LEN)
        if self.len == 0: # flexible field: take the whole buffer
            self.get_len = lambda _, data: len(data)
        else: # fixed length
            self.get_len = lambda vals, _: self.len

        # Field is unconditionally present by default
        self.get_pres = lambda vals: True
        # Field takes its value from the given dict by default
        self.get_val = lambda vals: vals[self.name]

        # Additional parameters for derived field types
        self.p = { key : kw.get(key, default)
                   for key, default in self.DEF_PARAMS.items() }

    def from_bytes(self, vals: dict, data: bytes) -> int:
        ''' Decode this field from the beginning of 'data' into 'vals'.

        Returns the number of consumed octets (0 if the field is absent).
        Raises DecodeError if 'data' is shorter than the field.
        '''
        # NOTE: only an explicit False marks the field as absent
        if self.get_pres(vals) is False:
            return 0
        length = self.get_len(vals, data)
        if len(data) < length:
            raise DecodeError('Short read')
        self._from_bytes(vals, data[:length])
        return length

    def to_bytes(self, vals: dict) -> bytes:
        ''' Encode this field using the values from 'vals'.

        Returns b'' if the field is absent.  Raises EncodeError if a
        fixed-length field encodes to a different number of octets.
        '''
        if self.get_pres(vals) is False:
            return b''
        data = self._to_bytes(vals)
        if self.len > 0 and len(data) != self.len:
            raise EncodeError('Field length mismatch')
        return data

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        ''' Decode value(s) from the given buffer of bytes. '''
        raise NotImplementedError

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode value(s) into bytes. '''
        raise NotImplementedError

class Buf(Field):
    ''' A sequence of octets.

    Decoding stores the raw octets under self.name; encoding emits the
    value returned by self.get_val() as is.
    '''

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        vals[self.name] = data

    def _to_bytes(self, vals: dict) -> bytes:
        # TODO: handle len(self.get_val()) < self.get_len()
        # (for fixed-length fields a mismatch is rejected by Field.to_bytes())
        return self.get_val(vals)

class Spare(Field):
    ''' Spare filling for RFU fields or padding.

    Decoding skips the octets; encoding emits get_len() repetitions of
    the 'filler' parameter.
    '''

    # Default parameters
    DEF_PARAMS = {
        # NOTE(review): zero-octet padding assumed as the default filler
        'filler' : b'\x00',
    }

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        pass # Just ignore it

    def _to_bytes(self, vals: dict) -> bytes:
        # A spare field has no value to measure, so the length is taken
        # against an empty buffer: flexible-length spares encode to b''.
        return self.p['filler'] * self.get_len(vals, b'')

class Uint(Field):
    ''' An integer field: unsigned, N bits, big endian.

    Parameters (self.p) define a linear transform of the raw value:
      decoded = raw * mult + offset
      raw     = (value - offset) // mult
    The encoding uses floor division, so values that are not on the
    'mult' grid are silently rounded down.
    '''

    # Uint8 by default (length in octets; must be > 0 for encoding)
    DEF_LEN = 1

    # Default parameters: identity transform
    DEF_PARAMS = {
        'offset' : 0,
        'mult'   : 1,
    }

    # Big endian, unsigned
    SIGN = False
    BO = 'big'

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        val = int.from_bytes(data, self.BO, signed=self.SIGN)
        vals[self.name] = val * self.p['mult'] + self.p['offset']

    def _to_bytes(self, vals: dict) -> bytes:
        val = (self.get_val(vals) - self.p['offset']) // self.p['mult']
        return val.to_bytes(self.len, self.BO, signed=self.SIGN)

class Uint16BE(Uint):
    ''' Unsigned 16-bit integer, big endian. '''
    DEF_LEN = 16 // 8

class Uint16LE(Uint16BE):
    ''' Unsigned 16-bit integer, little endian. '''
    BO = 'little'

class Uint32BE(Uint):
    ''' Unsigned 32-bit integer, big endian. '''
    DEF_LEN = 32 // 8

class Uint32LE(Uint32BE):
    ''' Unsigned 32-bit integer, little endian. '''
    BO = 'little'

class Int(Uint):
    ''' Signed variant of Uint (two's complement). '''
    SIGN = True

class Int16BE(Int):
    ''' Signed 16-bit integer, big endian. '''
    DEF_LEN = 16 // 8

class Int16LE(Int16BE):
    ''' Signed 16-bit integer, little endian. '''
    BO = 'little'

class Int32BE(Int):
    ''' Signed 32-bit integer, big endian. '''
    DEF_LEN = 32 // 8

class Int32LE(Int32BE):
    ''' Signed 32-bit integer, little endian. '''
    BO = 'little'

class BitFieldSet(Field):
    ''' A set of bit-fields packed into a whole number of octets.

    The bit-fields (BitField instances) come from the 'set' keyword
    argument or, by default, from the class attribute STRUCT.  The
    'order' parameter selects MSB-first ('big', default) or LSB-first
    ('little' / 'lsb') placement of the fields.
    '''

    # Default parameters
    DEF_PARAMS = {
        # Default field order (MSB first)
        'order' : 'big',
    }

    # To be defined by derived types
    STRUCT = () # type: Tuple['BitField', ...]

    def __init__(self, **kw) -> None:
        Field.__init__(self, self.__class__.__name__, **kw)

        self._fields = kw.get('set', self.STRUCT)
        if not isinstance(self._fields, tuple):
            raise ProtocolError('Expected a tuple')

        # LSB first is basically reversed order
        if self.p['order'] in ('little', 'lsb'):
            self._fields = self._fields[::-1]

        # Calculate the overall field length unless given explicitly,
        # rounding up to a whole number of octets
        if self.len == 0:
            bl_sum = sum(f.bl for f in self._fields)
            self.len = (bl_sum + 7) // 8

        # Re-define self.get_len() since we always know the length
        self.get_len = lambda vals, data: self.len

        # Pre-calculate offset and mask for each field, starting at the MSB.
        # NOTE: offset/mask are stored on the BitField objects themselves,
        # so BitField instances shared between sets must agree on layout.
        offset = self.len * 8
        for f in self._fields:
            # Only possible when 'len' was given explicitly and is too small
            if f.bl > offset:
                raise ProtocolError(f, 'BitFieldSet overflow')
            f.offset = offset - f.bl
            f.mask = 2 ** f.bl - 1
            offset -= f.bl

    def _from_bytes(self, vals: dict, data: bytes) -> None:
        # Offsets are counted from the LSB of the whole set, hence the
        # octets are always read big endian ('order' only affects the
        # order of the bit-fields, not of the octets).
        blob = int.from_bytes(data, byteorder='big')
        for f in self._fields:
            f.dec_val(vals, blob)

    def _to_bytes(self, vals: dict) -> bytes:
        blob = 0x00
        for f in self._fields:
            blob |= f.enc_val(vals)
        return blob.to_bytes(self.len, byteorder='big')

class BitField:
    ''' One field in a BitFieldSet.

    Arguments:
      name -- key under which the value is stored in 'vals';
      bl   -- length in bits (must be >= 1);
      val  -- (optional) fixed value: used instead of vals[name] when
              encoding, and enforced when decoding.
    '''

    # Special fields for BitFieldSet (assigned by BitFieldSet.__init__())
    offset = 0 # type: int
    mask = 0 # type: int

    class Spare:
        ''' Spare filling in a BitFieldSet. '''

        def __init__(self, bl: int) -> None:
            self.name = None
            self.bl = bl

        def enc_val(self, vals: dict) -> int:
            return 0 # spare bits are always encoded as zero

        def dec_val(self, vals: dict, blob: int) -> None:
            pass # Just ignore it

    def __init__(self, name: str, bl: int, **kw) -> None:
        if bl < 1: # Ensure proper length
            raise ProtocolError('Incorrect bit-field length')

        self.name = name
        self.bl = bl

        # (Optional) fixed value for encoding and decoding
        self.val = kw.get('val', None) # type: Optional[int]

    def enc_val(self, vals: dict) -> int:
        ''' Return this field's value masked and shifted into place. '''
        val = vals[self.name] if self.val is None else self.val
        return (val & self.mask) << self.offset

    def dec_val(self, vals: dict, blob: int) -> None:
        ''' Extract this field's value from 'blob' into vals[self.name].

        Raises DecodeError if a fixed value is set and does not match.
        '''
        vals[self.name] = (blob >> self.offset) & self.mask
        if (self.val is not None) and (vals[self.name] != self.val):
            raise DecodeError('Unexpected value %d, expected %d'
                    % (vals[self.name], self.val))

class Envelope:
    ''' A group of related fields.

    Derived types list their fields (Codec instances) in STRUCT.  The
    values are kept in the dict self.c, which is also exposed through
    item access (envelope['name']).
    '''

    STRUCT = () # type: Tuple[Codec, ...]

    def __init__(self, check_len: bool = True) -> None:
        # TODO: ensure unique field names in self.STRUCT
        self.c = { } # type: dict
        # Whether octets left over after decoding STRUCT are an error
        self.check_len = check_len

    def __getitem__(self, key: str) -> Any:
        return self.c[key]

    def __setitem__(self, key: str, val: Any) -> None:
        self.c[key] = val

    def __delitem__(self, key: str) -> None:
        del self.c[key]

    def check(self, vals: dict) -> None:
        ''' Check the content before encoding and after decoding.
            Raise exceptions (e.g. ValueError) if something is wrong.

            Do not assert for every possible error (e.g. a negative value
            for a Uint field) if an exception will be thrown by the field's
            to_bytes() method anyway. Only additional constraints here.
        '''

    def from_bytes(self, data: bytes) -> int:
        ''' Decode 'data' into self.c, discarding the old content.
            Returns the number of consumed octets. '''
        self.c.clear() # forget the old content
        return self._from_bytes(self.c, data)

    def to_bytes(self) -> bytes:
        ''' Encode the content of self.c into bytes. '''
        return self._to_bytes(self.c)

    def _from_bytes(self, vals: dict, data: bytes, offset: int = 0) -> int:
        ''' Decode all fields of STRUCT from 'data' (starting at 'offset')
            into 'vals'.  Returns the offset past the last decoded field. '''
        f = None # field being decoded, reported on error
        try: # Fields throw exceptions
            for f in self.STRUCT:
                offset += f.from_bytes(vals, data[offset:])
        except Exception as e:
            # Add contextual info
            raise DecodeError(self, f, offset) from e
        if self.check_len and len(data) != offset:
            raise DecodeError(self, 'Unhandled tail octets: %s'
                    % data[offset:].hex())
        self.check(vals) # Check the content after decoding (raises exceptions)
        return offset

    def _to_bytes(self, vals: dict) -> bytes:
        ''' Encode all fields of STRUCT using the values from 'vals'. '''
        def proc(f) -> bytes:
            try: # Fields throw exceptions
                return f.to_bytes(vals)
            except Exception as e:
                # Add contextual info
                raise EncodeError(self, f) from e
        self.check(vals) # Check the content before encoding (raises exceptions)
        return b''.join([proc(f) for f in self.STRUCT])

    class F(Field):
        ''' Field wrapper: embeds an Envelope as a field of another one.
            The nested values are stored as a dict under the field name. '''

        def __init__(self, e: 'Envelope', name: str, **kw) -> None:
            Field.__init__(self, name, **kw)
            self.e = e

        def _from_bytes(self, vals: dict, data: bytes) -> None:
            vals[self.name] = { }
            self.e._from_bytes(vals[self.name], data)

        def _to_bytes(self, vals: dict) -> bytes:
            return self.e._to_bytes(self.get_val(vals))

    def f(self, name: str, **kw) -> Field:
        ''' Wrap this Envelope into a Field named 'name'. '''
        return self.F(self, name, **kw)

class Sequence:
    ''' A sequence of repeating elements (e.g. TLVs).

    The element type is an Envelope given either as the class attribute
    ITEM or as the 'item' keyword argument.
    '''

    # The item of sequence
    ITEM = None # type: Optional[Envelope]

    def __init__(self, **kw) -> None:
        if (self.ITEM is None) and ('item' not in kw):
            raise ProtocolError('Missing Sequence item')
        self._item = kw.get('item', self.ITEM) # type: Envelope
        # Items are decoded back to back, so each one must be allowed to
        # leave trailing octets for the next item.
        # NOTE: this modifies the (possibly shared) Envelope instance.
        self._item.check_len = False

    def from_bytes(self, data: bytes) -> list:
        ''' Decode 'data' into a list of dicts, one per item.
            Raises DecodeError if an item consumes no octets. '''
        proc = self._item._from_bytes
        vseq, offset = [], 0
        length = len(data)

        while offset < length:
            vseq.append({ }) # new item of sequence
            consumed = proc(vseq[-1], data[offset:])
            if consumed == 0: # no progress: would loop forever
                raise DecodeError(self, 'Sequence item consumed no octets')
            offset += consumed

        return vseq

    def to_bytes(self, vseq: list) -> bytes:
        ''' Encode a list of item dicts into bytes. '''
        proc = self._item._to_bytes
        return b''.join([proc(v) for v in vseq])

    class F(Field):
        ''' Field wrapper: embeds a Sequence as a field of an Envelope.
            The decoded list of item dicts is stored under the field name. '''

        def __init__(self, s: 'Sequence', name: str, **kw) -> None:
            Field.__init__(self, name, **kw)
            self.s = s

        def _from_bytes(self, vals: dict, data: bytes) -> None:
            vals[self.name] = self.s.from_bytes(data)

        def _to_bytes(self, vals: dict) -> bytes:
            return self.s.to_bytes(self.get_val(vals))

    def f(self, name: str, **kw) -> Field:
        ''' Wrap this Sequence into a Field named 'name'. '''
        return self.F(self, name, **kw)