"""Generate a mock model for LLVM tests for Register Allocation.

The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters.
"""
# By construction, the mock model's "priority" output is only a simple function
# of its input features; it does not implement a real priority heuristic.
import json
import os
import sys

import tensorflow as tf
# Name of the model's single output (the priority decision).
POLICY_DECISION_LABEL = "priority"
# JSON output spec written next to the saved model. "logging_name" is derived
# from POLICY_DECISION_LABEL so the two cannot drift apart, and json.dumps
# guarantees the spec is well-formed JSON.
# NOTE(review): "name"/"port"/"type"/"shape" must match the output tensor of
# the saved graph and the decision TensorSpec expected on the LLVM side -
# confirm if the model's output changes.
POLICY_OUTPUT_SPEC = json.dumps(
    [
        {
            "logging_name": POLICY_DECISION_LABEL,
            "tensor_spec": {
                "name": "StatefulPartitionedCall",
                "port": 0,
                "type": "float",
                "shape": [1],
            },
        }
    ],
    indent=4,
)
PER_LIVEINTERVAL_INT64_FEATURE_LIST = ["li_size", "stage"]
PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST = ["weight"]
# All per-liveinterval feature names: the float32 ones first, then the int64 ones.
PER_LIVEINTERVAL_FEATURE_LIST = (
    PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST + PER_LIVEINTERVAL_INT64_FEATURE_LIST
)
# Context feature names; get_input_signature() declares these same names with
# explicit dtypes.
CONTEXT_FEATURE_LIST = ("discount", "reward", "step_type")
def get_input_signature():
    """Return the input signature of the mock model's ``action`` function.

    The result is a dict mapping each feature name to a scalar
    ``tf.TensorSpec`` (``shape=()``) whose ``name`` is the feature name:

    * ``PER_LIVEINTERVAL_INT64_FEATURE_LIST`` features as ``tf.int64``,
    * ``PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST`` features as ``tf.float32``,
    * the context features ``discount`` and ``reward`` as ``tf.float32``,
    * the context feature ``step_type`` as ``tf.int32``.
    """
    # (dtype, feature names) groups, in the order the specs are inserted.
    # The context names are spelled out because their dtypes differ; keep them
    # in sync with CONTEXT_FEATURE_LIST.
    dtype_groups = (
        (tf.int64, PER_LIVEINTERVAL_INT64_FEATURE_LIST),
        (tf.float32, PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST),
        (tf.float32, ["discount", "reward"]),
        (tf.int32, ["step_type"]),
    )
    return {
        key: tf.TensorSpec(dtype=dtype, shape=(), name=key)
        for dtype, keys in dtype_groups
        for key in keys
    }
def get_output_spec_path(path):
    """Return the path of the JSON output spec inside the model directory ``path``."""
    spec_filename = "output_spec.json"
    return os.path.join(path, spec_filename)
def build_mock_model(path):
    """Build the mock priority model and save it under ``path``.

    Saves a TF SavedModel to ``path`` with a single ``"action"`` signature
    whose inputs are described by ``get_input_signature()``. Then writes
    ``POLICY_OUTPUT_SPEC`` to ``get_output_spec_path(path)``.

    The model is not trained. Its ``POLICY_DECISION_LABEL`` output is the sum
    of every input feature cast to float32, plus a variable initialised to 0.
    """
    module = tf.Module()
    # We have to set this useless variable in order for the TF C API to
    # correctly intake the saved model.
    module.var = tf.Variable(0, dtype=tf.float32)

    def action(*inputs):
        # inputs[0] is the feature dict described by get_input_signature().
        features = inputs[0]
        per_liveinterval_sum = tf.reduce_sum(
            [
                tf.cast(features[key], tf.float32)
                for key in PER_LIVEINTERVAL_FEATURE_LIST
            ],
            axis=0,
        )
        context_sum = tf.reduce_sum(
            [tf.cast(features[key], tf.float32) for key in CONTEXT_FEATURE_LIST]
        )
        # module.var is 0, so adding it leaves the sum unchanged. It is read
        # here presumably so the saved graph actually depends on the variable
        # (see the comment where it is created).
        result = per_liveinterval_sum + context_sum + module.var
        return {POLICY_DECISION_LABEL: result}

    module.action = tf.function()(action)
    # Named `signatures` rather than reusing `action`, which would shadow the
    # function defined above.
    signatures = {
        "action": module.action.get_concrete_function(get_input_signature())
    }

    tf.saved_model.save(module, path, signatures=signatures)
    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, "w", encoding="utf-8") as f:
        print(f"Writing output spec to {output_spec_path}.")
        f.write(POLICY_OUTPUT_SPEC)
def main(argv):
    """Command-line entry point; ``argv`` is ``[program, model_path]``.

    Builds the mock model at ``model_path`` via ``build_mock_model``. Exits
    with a usage message (``SystemExit``) when the argument count is wrong.
    """
    # Explicit check instead of ``assert``: asserts are stripped under -O,
    # which would turn a missing argument into an IndexError and silently
    # ignore extra arguments.
    if len(argv) != 2:
        prog = argv[0] if argv else "prog"
        sys.exit(f"usage: {prog} <model_path>")
    model_path = argv[1]
    build_mock_model(model_path)


if __name__ == "__main__":
    main(sys.argv)