Revert "[lldb][test] Remove compiler version check and use regex" (#124101)
[llvm-project.git] / llvm / lib / Analysis / models / gen-regalloc-priority-test-model.py
blob 889ddae48b1ffcda8134099720b14e2ebef06d76
1 """Generate a mock model for LLVM tests for Register Allocation.
2 The generated model is not a neural net - it is just a tf.function with the
3 correct input and output parameters.
4 """
## By construction, the mock model's "priority" output is simply the sum of all of its input features (plus a zero-initialized variable): deterministic, but not a meaningful priority.
import json
import os
import sys

import tensorflow as tf
# Name of the model's single output. The output spec below logs it under the
# same name.
POLICY_DECISION_LABEL = "priority"
# Contents of the output_spec.json file written next to the saved model.
# Built with json.dumps (rather than a hand-written JSON literal) so it is
# always valid JSON and its "logging_name" cannot drift from
# POLICY_DECISION_LABEL.
POLICY_OUTPUT_SPEC = (
    json.dumps(
        [
            {
                "logging_name": POLICY_DECISION_LABEL,
                "tensor_spec": {
                    "name": "StatefulPartitionedCall",
                    "port": 0,
                    "type": "float",
                    # NOTE(review): a single float per evaluation; confirm this
                    # matches what the consumer of output_spec.json expects.
                    "shape": [1],
                },
            }
        ],
        indent=4,
    )
    + "\n"
)
# Per-live-interval input features, grouped by dtype.
PER_LIVEINTERVAL_INT64_FEATURE_LIST = ["li_size", "stage"]
PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST = ["weight"]
PER_LIVEINTERVAL_FEATURE_LIST = (
    PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST + PER_LIVEINTERVAL_INT64_FEATURE_LIST
)
# Context inputs that the mock model also consumes.
CONTEXT_FEATURE_LIST = ("discount", "reward", "step_type")
def get_input_signature():
    """Return the input signature of the mock priority model.

    Returns:
      A dict mapping each input feature name to a scalar (``shape=()``)
      ``tf.TensorSpec`` whose ``name`` is the feature name itself. Entries are
      inserted in this order:

      * ``PER_LIVEINTERVAL_INT64_FEATURE_LIST`` as ``tf.int64``,
      * ``PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST`` as ``tf.float32``,
      * ``discount`` and ``reward`` as ``tf.float32``,
      * ``step_type`` as ``tf.int32``.
    """
    # (dtype, feature names) groups, in insertion order. The last two groups
    # together are exactly the names in CONTEXT_FEATURE_LIST, split by dtype.
    spec_groups = (
        (tf.int64, PER_LIVEINTERVAL_INT64_FEATURE_LIST),
        (tf.float32, PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST),
        (tf.float32, ("discount", "reward")),
        (tf.int32, ("step_type",)),
    )
    inputs = {}
    for dtype, keys in spec_groups:
        for key in keys:
            inputs[key] = tf.TensorSpec(dtype=dtype, shape=(), name=key)
    return inputs
def get_output_spec_path(path):
    """Return the path of ``output_spec.json`` inside the model directory *path*."""
    spec_filename = "output_spec.json"
    return os.path.join(path, spec_filename)
def build_mock_model(path):
    """Build the mock priority model and save it, plus its output spec, to *path*.

    The saved model exposes one signature, ``"action"``, whose inputs are the
    specs from ``get_input_signature()``. Its single output,
    ``POLICY_DECISION_LABEL``, is the sum of every input feature (each cast to
    float32) plus a zero-initialized variable: deterministic, not meaningful.

    Args:
      path: Directory the SavedModel is written to. ``POLICY_OUTPUT_SPEC`` is
        also written into it, at ``get_output_spec_path(path)``.
    """
    module = tf.Module()
    # We have to set this useless variable in order for the TF C API to correctly
    # intake it
    module.var = tf.Variable(0, dtype=tf.float32)

    def action(*inputs):
        # inputs[0] is the feature dict described by get_input_signature().
        features = inputs[0]
        s1 = tf.reduce_sum(
            [
                tf.cast(features[key], tf.float32)
                for key in PER_LIVEINTERVAL_FEATURE_LIST
            ],
            axis=0,
        )
        s2 = tf.reduce_sum(
            [tf.cast(features[key], tf.float32) for key in CONTEXT_FEATURE_LIST]
        )
        # The "priority" is just the sum of all features. module.var is 0;
        # adding it presumably keeps the variable used by the traced function.
        result = s1 + s2 + module.var
        return {POLICY_DECISION_LABEL: result}

    module.action = tf.function()(action)
    # Separate name so the inner `action` function is not shadowed.
    concrete_fn = module.action.get_concrete_function(get_input_signature())
    tf.saved_model.save(module, path, signatures={"action": concrete_fn})

    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, "w", encoding="utf-8") as f:
        print(f"Writing output spec to {output_spec_path}.")
        f.write(POLICY_OUTPUT_SPEC)
def main(argv):
    """Script entry point; ``argv`` must be ``[program, model_path]``.

    Builds the mock model into ``model_path`` via ``build_mock_model``.

    Raises:
      SystemExit: with a usage message if ``argv`` does not have exactly two
        elements. This is an explicit check rather than ``assert``, which is
        stripped when Python runs with ``-O``.
    """
    if len(argv) != 2:
        prog = argv[0] if argv else "gen-regalloc-priority-test-model.py"
        raise SystemExit(f"usage: {prog} <model_path>")
    model_path = argv[1]
    build_mock_model(model_path)
if __name__ == "__main__":
    # main() expects the full argv: [script, model_path].
    main(argv=sys.argv)