//===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of an MLModelRunner for 'development' mode, i.e. evaluation
// happens off a model that's provided on the command line and is interpreted.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include <optional>

using namespace llvm;

namespace {
struct LoggedFeatureSpec {
  TensorSpec Spec;
  std::optional<std::string> LoggingName;
};
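
// For illustration only: an output_spec.json consumed by loadOutputSpecs()
// below might look like the following. The tensor name and shape are made-up
// values, not ones shipped with LLVM; the decision entry must come first, and
// its logging_name must match the expected decision name.
//
//   [
//     {
//       "logging_name": "inlining_decision",
//       "tensor_spec": {
//         "name": "output_0",
//         "port": 0,
//         "type": "int64_t",
//         "shape": [1]
//       }
//     }
//   ]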

std::optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
                StringRef ModelPath, StringRef SpecFileOverride) {
  SmallVector<char, 128> OutputSpecsPath;
  StringRef FileName = SpecFileOverride;
  if (FileName.empty()) {
    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
  }

  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
  if (!BufferOrError) {
    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
                  BufferOrError.getError().message());
    return std::nullopt;
  }
  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  if (!ParsedJSONValues) {
    Ctx.emitError("Could not parse specs file: " + FileName);
    return std::nullopt;
  }
  auto ValuesArray = ParsedJSONValues->getAsArray();
  if (!ValuesArray) {
    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
                  "logging_name:<name>} dictionaries");
    return std::nullopt;
  }
  std::vector<LoggedFeatureSpec> Ret;
  for (const auto &Value : *ValuesArray)
    if (const auto *Obj = Value.getAsObject())
      if (const auto *SpecPart = Obj->get("tensor_spec"))
        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
          if (auto LoggingName = Obj->getString("logging_name")) {
            if (!TensorSpec->isElementType<int64_t>() &&
                !TensorSpec->isElementType<int32_t>() &&
                !TensorSpec->isElementType<float>()) {
              Ctx.emitError(
                  "Only int64, int32, and float tensors are supported. "
                  "Found unsupported type for tensor named " +
                  TensorSpec->name());
              return std::nullopt;
            }
            Ret.push_back({*TensorSpec, LoggingName->str()});
          }

  if (ValuesArray->size() != Ret.size()) {
    Ctx.emitError(
        "Unable to parse output spec. It should be a json file containing an "
        "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
        "with a json object describing a TensorSpec; and a 'logging_name' key, "
        "which is a string to use as name when logging this tensor in the "
        "training log.");
    return std::nullopt;
  }
  if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
    Ctx.emitError("The first output spec must describe the decision tensor, "
                  "and must have the logging_name " +
                  StringRef(ExpectedDecisionName));
    return std::nullopt;
  }
  return Ret;
}
} // namespace

ModelUnderTrainingRunner::ModelUnderTrainingRunner(
    LLVMContext &Ctx, const std::string &ModelPath,
    const std::vector<TensorSpec> &InputSpecs,
    const std::vector<TensorSpec> &OutputSpecs,
    const std::vector<TensorSpec> &ExtraOutputsForLogging)
    : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
      OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
  Evaluator =
      std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
  if (!Evaluator || !Evaluator->isValid()) {
    Ctx.emitError("Failed to create saved model evaluator");
    Evaluator.reset();
    return;
  }

  for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
    setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
  }
}

void *ModelUnderTrainingRunner::evaluateUntyped() {
  LastEvaluationResult = Evaluator->evaluate();
  if (!LastEvaluationResult.has_value()) {
    Ctx.emitError("Error evaluating model.");
    return nullptr;
  }
  return LastEvaluationResult->getUntypedTensorValue(0);
}

std::unique_ptr<ModelUnderTrainingRunner>
ModelUnderTrainingRunner::createAndEnsureValid(
    LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
    const std::vector<TensorSpec> &InputSpecs,
    StringRef OutputSpecsPathOverride) {
  if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
                                              OutputSpecsPathOverride)) {
    std::unique_ptr<ModelUnderTrainingRunner> MUTR;
    std::vector<TensorSpec> OutputSpecs;
    std::vector<TensorSpec> ExtraOutputsForLogging;
    // All specs are fetched from the model; the entries after the first (the
    // decision tensor) exist solely for logging, under their logging names.
    append_range(OutputSpecs,
                 map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
                   return LFS.Spec;
                 }));
    append_range(ExtraOutputsForLogging,
                 map_range(drop_begin(*MaybeOutputSpecs),
                           [](const LoggedFeatureSpec &LFS) {
                             return TensorSpec(LFS.LoggingName
                                                   ? *LFS.LoggingName
                                                   : LFS.Spec.name(),
                                               LFS.Spec);
                           }));

    MUTR.reset(new ModelUnderTrainingRunner(
        Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
    if (MUTR && MUTR->isValid())
      return MUTR;

    Ctx.emitError("Could not load or create model evaluator.");
    return nullptr;
  }
  Ctx.emitError("Could not load the policy model from the provided path");
  return nullptr;
}
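
// A minimal creation sketch (illustrative; the model path, decision name, and
// feature list are assumptions, not values defined in this file):
//
//   std::vector<TensorSpec> Inputs{
//       TensorSpec::createSpec<int64_t>("some_feature", {1})};
//   auto MUTR = ModelUnderTrainingRunner::createAndEnsureValid(
//       Ctx, "/path/to/model_dir", "inlining_decision", Inputs,
//       /*OutputSpecsPathOverride=*/"");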

#endif // defined(LLVM_HAVE_TFLITE)