Revert "[lldb][test] Remove compiler version check and use regex" (#124101)
[llvm-project.git] / llvm / lib / Analysis / models / log_reader.py
blob7080276a0d85d9d395f1780fb07045c76f906b6b
1 """Reader for training log.
3 See lib/Analysis/TrainingLogger.cpp for a description of the format.
4 """
5 import ctypes
6 import dataclasses
7 import io
8 import json
9 import math
10 import sys
11 from typing import List, Optional
# Maps the element type names found in a tensor spec's "type" entry to the
# ctypes type used to interpret the tensor's raw bytes.
_element_types = {
    "float": ctypes.c_float,
    "double": ctypes.c_double,
    # Fixed-width integers, e.g. "int8_t" -> ctypes.c_int8 and
    # "uint8_t" -> ctypes.c_uint8, for 8, 16, 32 and 64 bits.
    **{
        f"{prefix}int{bits}_t": getattr(ctypes, f"c_{prefix}int{bits}")
        for bits in (8, 16, 32, 64)
        for prefix in ("", "u")
    },
}
@dataclasses.dataclass(frozen=True)
class TensorSpec:
    """Immutable description of one tensor recorded in the training log.

    Instances are usually created from a parsed JSON entry via `from_dict`.
    """

    # Value of the spec's "name" entry.
    name: str
    # Value of the spec's "port" entry.
    port: int
    # Dimensions of the tensor, each converted to int.
    shape: List[int]
    # ctypes type of a single element, looked up in `_element_types`.
    element_type: type

    @staticmethod
    def from_dict(d: dict) -> "TensorSpec":
        """Build a TensorSpec from its dictionary (parsed JSON) form.

        Args:
            d: mapping with the keys "name", "port", "shape" and "type".

        Returns:
            The TensorSpec described by `d`.

        Raises:
            KeyError: if one of the required keys is missing.
            ValueError: if "type" does not name a supported element type.
        """
        name = d["name"]
        port = d["port"]
        shape = [int(e) for e in d["shape"]]
        element_type_str = d["type"]
        if element_type_str not in _element_types:
            raise ValueError(f"unknown type: {element_type_str}")
        return TensorSpec(
            name=name,
            port=port,
            shape=shape,
            element_type=_element_types[element_type_str],
        )
class TensorValue:
    """Typed, read-only view over the raw bytes of one tensor.

    Elements are addressed by flat index in `[0, len(self))`, where the
    length is the product of the spec's shape.
    """

    def __init__(self, spec: "TensorSpec", buffer: bytes):
        """Wrap `buffer` as a tensor laid out according to `spec`.

        Args:
            spec: gives the shape and the ctypes element type of the tensor.
            buffer: raw element bytes; must hold at least
                `prod(spec.shape) * sizeof(spec.element_type)` bytes.

        Raises:
            ValueError: if `buffer` is too short for `spec`. Without this
                check, indexing would read memory past the end of the buffer
                (e.g. after a truncated read from the log).
        """
        self._spec = spec
        self._len = math.prod(self._spec.shape)
        required = self._len * ctypes.sizeof(self._spec.element_type)
        if len(buffer) < required:
            raise ValueError(
                f"buffer for tensor {self._spec.name!r} has {len(buffer)} bytes, "
                f"expected at least {required}"
            )
        # The pointer below refers to the memory of `buffer`, so the buffer is
        # stored on the instance next to it.
        self._buffer = buffer
        self._view = ctypes.cast(self._buffer, ctypes.POINTER(self._spec.element_type))

    def spec(self) -> "TensorSpec":
        """Return the spec this value was created with."""
        return self._spec

    def __len__(self) -> int:
        """Return the number of elements (product of the spec's shape)."""
        return self._len

    def __getitem__(self, index):
        """Return the element at flat position `index`.

        Negative indices are not supported.

        Raises:
            IndexError: if `index` is outside `[0, len(self))`. This also
                terminates iteration with `for v in tensor_value`.
        """
        if index < 0 or index >= self._len:
            raise IndexError(f"Index {index} out of range [0..{self._len})")
        return self._view[index]
def read_tensor(fs: io.BufferedReader, ts: TensorSpec) -> TensorValue:
    """Read the raw bytes of one tensor from `fs` and wrap them.

    The number of bytes requested is the product of `ts.shape` multiplied by
    the size of `ts.element_type`.

    Args:
        fs: binary stream positioned at the first byte of the tensor.
        ts: spec describing the tensor to read.

    Returns:
        A TensorValue over the bytes that were read.
    """
    element_count = math.prod(ts.shape)
    byte_count = element_count * ctypes.sizeof(ts.element_type)
    return TensorValue(ts, fs.read(byte_count))
def pretty_print_tensor_value(tv: TensorValue):
    """Print `tv` on one line as `<name>: <v0>,<v1>,...`."""
    name = tv.spec().name
    values = ",".join(map(str, tv))
    print(f"{name}: {values}")
def read_header(f: io.BufferedReader):
    """Parse the JSON header line at the current position of `f`.

    Args:
        f: binary stream whose next line is the JSON header.

    Returns:
        A tuple `(tensor_specs, score_spec, advice_spec)`: the specs listed
        under "features", followed by the specs stored under "score" and
        "advice", each of which is None when the header lacks that key.
    """
    header = json.loads(f.readline())

    def optional_spec(key):
        # Only keys present in the header are converted.
        return TensorSpec.from_dict(header[key]) if key in header else None

    tensor_specs = list(map(TensorSpec.from_dict, header["features"]))
    return tensor_specs, optional_spec("score"), optional_spec("advice")
def read_one_observation(
    context: Optional[str],
    event_str: str,
    f: io.BufferedReader,
    tensor_specs: List[TensorSpec],
    score_spec: Optional[TensorSpec],
):
    """Read one observation (and its score, if any) from the log stream.

    Args:
        context: the context in effect before this event.
        event_str: the JSON line already read from `f` that starts the event.
        f: binary stream positioned right after `event_str`.
        tensor_specs: specs of the feature tensors, in the order they are read.
        score_spec: spec of the score tensor, or None when no score is read.

    Returns:
        A tuple `(context, observation_id, features, score)`; `context` is
        updated if the event carried a "context" key, and `score` is None
        when `score_spec` is None.
    """
    record = json.loads(event_str)
    if "context" in record:
        # A context record is followed by the observation record on the
        # next line.
        context = record["context"]
        record = json.loads(f.readline())
    observation_id = int(record["observation"])

    features = [read_tensor(f, spec) for spec in tensor_specs]
    # Discard the rest of the current line after the feature bytes.
    f.readline()

    if score_spec is None:
        return context, observation_id, features, None

    outcome = json.loads(f.readline())
    # The outcome record must refer to the observation that was just read.
    assert int(outcome["outcome"]) == observation_id
    score = read_tensor(f, score_spec)
    # Discard the rest of the current line after the score bytes.
    f.readline()
    return context, observation_id, features, score
def read_stream(fname: str):
    """Iterate over the observations stored in the log file `fname`.

    The header is read first; then one tuple
    `(context, observation_id, features, score)` is yielded per observation
    until the end of the file. The file stays open while the generator is
    being consumed.
    """
    with io.BufferedReader(io.FileIO(fname, "rb")) as f:
        tensor_specs, score_spec, _ = read_header(f)
        context = None
        # An empty read means end of file.
        while event_str := f.readline():
            observation = read_one_observation(
                context, event_str, f, tensor_specs, score_spec
            )
            # The context carries over to the following observations.
            context = observation[0]
            yield observation
def main(args):
    """Print every observation of the log file named by `args[1]`.

    A "context: ..." line is printed whenever the context changes, followed
    for each observation by an "observation: ..." line, one line per feature
    tensor and, when a score was read, one line for the score.

    Args:
        args: command-line style argument list; `args[1]` is the log path.
    """
    last_context = None
    for ctx, obs_id, features, score in read_stream(args[1]):
        if last_context != ctx:
            print(f"context: {ctx}")
            last_context = ctx
        print(f"observation: {obs_id}")
        for fv in features:
            pretty_print_tensor_value(fv)
        # Test for presence, not truthiness: the truth value of a TensorValue
        # comes from its length, so a zero-element score would be skipped.
        if score is not None:
            pretty_print_tensor_value(score)
# Script entry point: dump the log file given as the first command-line
# argument.
if __name__ == "__main__":
    main(sys.argv)