1 """Generate a mock model for LLVM tests.
3 The generated model is not a neural net - it is just a tf.function with the
4 correct input and output parameters. By construction, the mock model will always
12 import tensorflow
as tf
14 POLICY_DECISION_LABEL
= "inlining_decision"
15 POLICY_OUTPUT_SPEC
= """
18 "logging_name": "inlining_decision",
20 "name": "StatefulPartitionedCall",
32 # pylint: disable=g-complex-comprehension
33 def get_input_signature():
34 """Returns the list of features for LLVM inlining."""
37 tf
.TensorSpec(dtype
=tf
.int64
, shape
=(), name
=key
)
39 "caller_basic_block_count",
40 "caller_conditionally_executed_blocks",
42 "callee_basic_block_count",
43 "callee_conditionally_executed_blocks",
55 "call_argument_setup",
56 "load_relative_intrinsic",
57 "lowered_call_arg_setup",
58 "indirect_call_penalty",
60 "case_cluster_penalty",
62 "unsimplified_common_instructions",
65 "simplified_instructions",
67 "constant_offset_ptr_args",
70 "last_call_to_static_bonus",
73 "nested_inline_cost_estimate",
81 tf
.TensorSpec(dtype
=tf
.float32
, shape
=(), name
=key
)
82 for key
in ["discount", "reward"]
88 [tf
.TensorSpec(dtype
=tf
.int32
, shape
=(), name
=key
) for key
in ["step_type"]]
def get_output_signature() -> str:
    """Return the label under which the mock model emits its decision.

    The value is the module-level ``POLICY_DECISION_LABEL`` constant.
    """
    return POLICY_DECISION_LABEL
def get_output_spec() -> str:
    """Return the output-spec text that accompanies the saved model.

    The value is the module-level ``POLICY_OUTPUT_SPEC`` string, returned
    unmodified.
    """
    return POLICY_OUTPUT_SPEC
def get_output_spec_path(path: str) -> str:
    """Return the location of ``output_spec.json`` inside directory *path*.

    Args:
        path: Directory that holds (or will hold) the saved model.

    Returns:
        *path* joined with the file name ``output_spec.json``.
    """
    spec_file_name = "output_spec.json"
    return os.path.join(path, spec_file_name)
105 def build_mock_model(path
, signature
):
106 """Build and save the mock model with the given signature"""
110 return {signature
["output"]: tf
.constant(value
=1, dtype
=tf
.int64
)}
112 module
.action
= tf
.function()(action
)
113 action
= {"action": module
.action
.get_concrete_function(signature
["inputs"])}
114 tf
.saved_model
.save(module
, path
, signatures
=action
)
116 output_spec_path
= get_output_spec_path(path
)
117 with
open(output_spec_path
, "w") as f
:
118 print(f
"Writing output spec to {output_spec_path}.")
119 f
.write(signature
["output_spec"])
124 "inputs": get_input_signature(),
125 "output": get_output_signature(),
126 "output_spec": get_output_spec(),
131 assert len(argv
) == 2
134 print(f
"Output model to: [{argv[1]}]")
135 signature
= get_signature()
136 build_mock_model(model_path
, signature
)
139 if __name__
== "__main__":