1 """Reader for training log.
3 See lib/Analysis/TrainingLogger.cpp for a description of the format.
11 from typing
import List
, Optional
14 "float": ctypes
.c_float
,
15 "double": ctypes
.c_double
,
16 "int8_t": ctypes
.c_int8
,
17 "uint8_t": ctypes
.c_uint8
,
18 "int16_t": ctypes
.c_int16
,
19 "uint16_t": ctypes
.c_uint16
,
20 "int32_t": ctypes
.c_int32
,
21 "uint32_t": ctypes
.c_uint32
,
22 "int64_t": ctypes
.c_int64
,
23 "uint64_t": ctypes
.c_uint64
,
27 @dataclasses.dataclass(frozen
=True)
35 def from_dict(d
: dict):
38 shape
= [int(e
) for e
in d
["shape"]]
39 element_type_str
= d
["type"]
40 if element_type_str
not in _element_types
:
41 raise ValueError(f
"uknown type: {element_type_str}")
46 element_type
=_element_types
[element_type_str
],
51 def __init__(self
, spec
: TensorSpec
, buffer: bytes
):
54 self
._view
= ctypes
.cast(self
._buffer
, ctypes
.POINTER(self
._spec
.element_type
))
55 self
._len
= math
.prod(self
._spec
.shape
)
57 def spec(self
) -> TensorSpec
:
60 def __len__(self
) -> int:
def __getitem__(self, index):
    """Return the element stored at position `index` of the tensor view.

    Args:
        index: position to read; valid positions are 0 .. len - 1.

    Raises:
        IndexError: if `index` is negative or not below the stored length.
    """
    # Same predicate as a plain bounds check; negative indices are rejected
    # rather than wrapped around.
    outside = index < 0 or index >= self._len
    if not outside:
        return self._view[index]
    raise IndexError(f"Index {index} out of range [0..{self._len})")
69 def read_tensor(fs
: io
.BufferedReader
, ts
: TensorSpec
) -> TensorValue
:
70 size
= math
.prod(ts
.shape
) * ctypes
.sizeof(ts
.element_type
)
72 return TensorValue(ts
, data
)
def pretty_print_tensor_value(tv: TensorValue):
    """Print one tensor on a single line as `<name>: v0,v1,...`.

    The name comes from `tv.spec().name`; the values are the elements
    obtained by iterating `tv`, each converted with `str` and joined by
    commas (no spaces).
    """
    label = tv.spec().name
    rendered = ",".join(map(str, tv))
    print(f"{label}: {rendered}")
def read_header(f: io.BufferedReader):
    """Parse the JSON header line of a training log.

    Reads exactly one line from `f` and decodes it as JSON.

    Args:
        f: stream positioned at the header line.

    Returns:
        A 3-tuple `(tensor_specs, score_spec, advice_spec)`:
        `tensor_specs` is a list with one TensorSpec per entry of the
        header's "features" list; `score_spec` and `advice_spec` are the
        TensorSpec built from the "score" / "advice" entries, or None when
        the corresponding key is absent from the header.
    """
    parsed = json.loads(f.readline())

    feature_specs = [TensorSpec.from_dict(entry) for entry in parsed["features"]]

    # "score" and "advice" are optional header keys.
    score_spec = None
    if "score" in parsed:
        score_spec = TensorSpec.from_dict(parsed["score"])

    advice_spec = None
    if "advice" in parsed:
        advice_spec = TensorSpec.from_dict(parsed["advice"])

    return feature_specs, score_spec, advice_spec
87 def read_one_observation(
88 context
: Optional
[str],
91 tensor_specs
: List
[TensorSpec
],
92 score_spec
: Optional
[TensorSpec
],
94 event
= json
.loads(event_str
)
95 if "context" in event
:
96 context
= event
["context"]
97 event
= json
.loads(f
.readline())
98 observation_id
= int(event
["observation"])
100 for ts
in tensor_specs
:
101 features
.append(read_tensor(f
, ts
))
104 if score_spec
is not None:
105 score_header
= json
.loads(f
.readline())
106 assert int(score_header
["outcome"]) == observation_id
107 score
= read_tensor(f
, score_spec
)
109 return context
, observation_id
, features
, score
112 def read_stream(fname
: str):
113 with io
.BufferedReader(io
.FileIO(fname
, "rb")) as f
:
114 tensor_specs
, score_spec
, _
= read_header(f
)
117 event_str
= f
.readline()
120 context
, observation_id
, features
, score
= read_one_observation(
121 context
, event_str
, f
, tensor_specs
, score_spec
123 yield context
, observation_id
, features
, score
128 for ctx
, obs_id
, features
, score
in read_stream(args
[1]):
129 if last_context
!= ctx
:
130 print(f
"context: {ctx}")
132 print(f
"observation: {obs_id}")
134 pretty_print_tensor_value(fv
)
136 pretty_print_tensor_value(score
)
139 if __name__
== "__main__":