1 """Generate a mock model for LLVM tests.
3 The generated model is not a neural net - it is just a tf.function with the
4 correct input and output parameters. By construction, the mock model will always
12 import tensorflow
as tf
14 POLICY_DECISION_LABEL
= "inlining_decision"
15 POLICY_OUTPUT_SPEC
= """
18 "logging_name": "inlining_decision",
20 "name": "StatefulPartitionedCall",
32 # pylint: disable=g-complex-comprehension
33 def get_input_signature():
34 """Returns the list of features for LLVM inlining."""
37 tf
.TensorSpec(dtype
=tf
.int64
, shape
=(), name
=key
)
39 "caller_basic_block_count",
40 "caller_conditionally_executed_blocks",
42 "callee_basic_block_count",
43 "callee_conditionally_executed_blocks",
54 "call_argument_setup",
55 "load_relative_intrinsic",
56 "lowered_call_arg_setup",
57 "indirect_call_penalty",
59 "case_cluster_penalty",
61 "unsimplified_common_instructions",
64 "simplified_instructions",
66 "constant_offset_ptr_args",
69 "last_call_to_static_bonus",
72 "nested_inline_cost_estimate",
74 "is_callee_avail_external",
75 "is_caller_avail_external",
82 tf
.TensorSpec(dtype
=tf
.float32
, shape
=(), name
=key
)
83 for key
in ["discount", "reward"]
89 [tf
.TensorSpec(dtype
=tf
.int32
, shape
=(), name
=key
) for key
in ["step_type"]]
def get_output_signature():
    """Return the label naming the policy's output (``POLICY_DECISION_LABEL``).

    ``get_signature`` stores this value under its ``"output"`` key, and
    ``build_mock_model`` uses it as the key of the output dict it builds.
    """
    decision_label = POLICY_DECISION_LABEL
    return decision_label
def get_output_spec():
    """Return the output-spec text (``POLICY_OUTPUT_SPEC``).

    ``get_signature`` stores this under its ``"output_spec"`` key, and
    ``build_mock_model`` writes it verbatim to the file named by
    ``get_output_spec_path`` inside the model directory.
    """
    spec_text = POLICY_OUTPUT_SPEC
    return spec_text
def get_output_spec_path(path):
    """Return the location of the output-spec JSON file under ``path``.

    Args:
      path: directory the mock model is saved to (``build_mock_model`` passes
        the same ``path`` it hands to ``tf.saved_model.save``).

    Returns:
      ``path`` joined with ``"output_spec.json"`` using ``os.path.join``
      (no normalization is applied to ``path``).
    """
    spec_filename = "output_spec.json"
    joined = os.path.join(path, spec_filename)
    return joined
106 def build_mock_model(path
, signature
, advice
):
107 """Build and save the mock model with the given signature"""
111 return {signature
["output"]: tf
.constant(value
=advice
, dtype
=tf
.int64
)}
113 module
.action
= tf
.function()(action
)
114 action
= {"action": module
.action
.get_concrete_function(signature
["inputs"])}
115 tf
.saved_model
.save(module
, path
, signatures
=action
)
117 output_spec_path
= get_output_spec_path(path
)
118 with
open(output_spec_path
, "w") as f
:
119 print(f
"Writing output spec to {output_spec_path}.")
120 f
.write(signature
["output_spec"])
125 "inputs": get_input_signature(),
126 "output": get_output_signature(),
127 "output_spec": get_output_spec(),
132 assert len(argv
) == 2 or (len(argv
) == 3 and argv
[2] == "never")
135 print(f
"Output model to: [{argv[1]}]")
140 print(f
"The model will always return: {constant_advice}")
142 signature
= get_signature()
143 build_mock_model(model_path
, signature
, constant_advice
)
146 if __name__
== "__main__":