//===- TFUtils.cpp - tensorflow evaluation utilities ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for interfacing with tensorflow C APIs.
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TF_API)

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"

#include "google/protobuf/text_format.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/core/example/example.pb.h"
#include <cassert>
#include <cstring>
#include <numeric>

using namespace llvm;

using google::protobuf::Message;
using google::protobuf::TextFormat;
static cl::opt<bool>
    ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
                     cl::desc("Output textual (human-readable) protobuf."));
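// Note: the hidden flag above selects, at Logger::flush time, between a
// binary-serialized and a textual (TextFormat) protobuf training log. A tool
// that parses LLVM cl::opt flags could pass -tfutils-text-log to get a
// human-readable log (illustrative; the exact driver spelling depends on the
// embedding tool).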
namespace {

using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
using TFSessionOptionsPtr =
    std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
struct TFInitializer {
  TFInitializer() {
    assert(!IsInitialized && "TFInitialized should be called only once");
    int Argc = 1;
    const char *Name = "";
    const char **NamePtr = &Name;
    TF_InitMain(Name, &Argc, const_cast<char ***>(&NamePtr));
    IsInitialized = true;
  }
  bool IsInitialized = false;
};
llvm::ManagedStatic<TFInitializer> TFLibInitializer;

bool ensureInitTF() { return TFLibInitializer->IsInitialized; }
TFGraphPtr createTFGraph() {
  return TFGraphPtr(TF_NewGraph(), &TF_DeleteGraph);
}

TFStatusPtr createTFStatus() {
  return TFStatusPtr(TF_NewStatus(), &TF_DeleteStatus);
}

TFSessionOptionsPtr createTFSessionOptions() {
  return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
}
} // namespace
namespace llvm {
class EvaluationResultImpl {
public:
  EvaluationResultImpl(size_t OutputSize)
      : OutputSize(OutputSize), Output(OutputSize) {}

  ~EvaluationResultImpl() {
    for (auto *P : Output)
      if (P)
        TF_DeleteTensor(P);
  }

  EvaluationResultImpl(const EvaluationResultImpl &) = delete;
  EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
  std::vector<TF_Tensor *> &getOutput() { return Output; }

private:
  const size_t OutputSize;
  std::vector<TF_Tensor *> Output;
};
size_t TensorSpec::getElementByteSize() const {
  return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
}
TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
                       const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
                                   std::multiplies<int64_t>())) {}
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                           const json::Value &Value) {
  auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
    std::string S;
    llvm::raw_string_ostream OS(S);
    OS << Value;
    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
    return None;
  };
  // FIXME: accept a Path as a parameter, and use it for error reporting.
  json::Path::Root Root("tensor_spec");
  json::ObjectMapper Mapper(Value, Root);
  if (!Mapper)
    return EmitError("Value is not a dict");

  std::string TensorName;
  int TensorPort = -1;
  std::string TensorType;
  std::vector<int64_t> TensorShape;

  if (!Mapper.map<std::string>("name", TensorName))
    return EmitError("'name' property not present or not a string");
  if (!Mapper.map<std::string>("type", TensorType))
    return EmitError("'type' property not present or not a string");
  if (!Mapper.map<int>("port", TensorPort))
    return EmitError("'port' property not present or not an int");
  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
    return EmitError("'shape' property not present or not an int array");

#define PARSE_TYPE(T, E)                                                       \
  if (TensorType == #T)                                                        \
    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
  TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
  return EmitError("unsupported 'type'");
}
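// Illustrative example (hypothetical tensor name): a JSON value accepted by
// getTensorSpecFromJSON, given the four properties mapped above, looks like:
//   {"name": "input", "port": 0, "type": "int64_t", "shape": [1, 4]}
// where "type" must be spelled exactly as one of the TFUTILS_SUPPORTED_TYPES
// type names, since PARSE_TYPE compares against the stringified type.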
Optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
                StringRef ModelPath, StringRef SpecFileOverride) {
  SmallVector<char, 128> OutputSpecsPath;
  StringRef FileName = SpecFileOverride;
  if (FileName.empty()) {
    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
  }

  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
  if (!BufferOrError) {
    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
                  BufferOrError.getError().message());
    return None;
  }
  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  if (!ParsedJSONValues) {
    Ctx.emitError("Could not parse specs file: " + FileName);
    return None;
  }
  auto ValuesArray = ParsedJSONValues->getAsArray();
  if (!ValuesArray) {
    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
                  "logging_name:<name>} dictionaries");
    return None;
  }
  std::vector<LoggedFeatureSpec> Ret;
  for (const auto &Value : *ValuesArray)
    if (const auto *Obj = Value.getAsObject())
      if (const auto *SpecPart = Obj->get("tensor_spec"))
        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
          if (auto LoggingName = Obj->getString("logging_name")) {
            if (!TensorSpec->isElementType<int64_t>() &&
                !TensorSpec->isElementType<int32_t>() &&
                !TensorSpec->isElementType<float>()) {
              Ctx.emitError(
                  "Only int64, int32, and float tensors are supported. "
                  "Found unsupported type for tensor named " +
                  TensorSpec->name());
              return None;
            }
            Ret.push_back({*TensorSpec, LoggingName->str()});
          }

  if (ValuesArray->size() != Ret.size()) {
    Ctx.emitError(
        "Unable to parse output spec. It should be a json file containing an "
        "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
        "with a json object describing a TensorSpec; and a 'logging_name' key, "
        "which is a string to use as name when logging this tensor in the "
        "training log.");
    return None;
  }
  if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
    Ctx.emitError("The first output spec must describe the decision tensor, "
                  "and must have the logging_name " +
                  StringRef(ExpectedDecisionName));
    return None;
  }
  return Ret;
}
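// Illustrative example of an output_spec.json file accepted by
// loadOutputSpecs (tensor and logging names are hypothetical). The first
// entry's logging_name must equal ExpectedDecisionName:
//   [
//     {
//       "logging_name": "decision",
//       "tensor_spec": {"name": "output", "port": 0, "type": "int64_t",
//                       "shape": [1]}
//     }
//   ]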
class TFModelEvaluatorImpl {
public:
  TFModelEvaluatorImpl(StringRef SavedModelPath,
                       const std::vector<TensorSpec> &InputSpecs,
                       function_ref<TensorSpec(size_t)> GetOutputSpecs,
                       size_t OutputSpecsSize, const char *Tags);

  bool isValid() const { return IsValid; }
  size_t OutputSize() const { return OutputFeed.size(); }

  void evaluate(TF_Tensor **Output, TF_Status *Status) {
    TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
                  Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
                  nullptr, 0, nullptr, Status);
  }

  void initInput(size_t Index, TF_DataType Type,
                 const std::vector<int64_t> &Dimensions);
  const std::vector<TF_Tensor *> &getInput() const { return Input; }

  ~TFModelEvaluatorImpl();

private:
  /// The objects necessary for carrying out an evaluation of the SavedModel.
  /// They are expensive to set up, and we maintain them across all the
  /// evaluations of the model.
  TF_Session *Session = nullptr;
  TFGraphPtr Graph;
  TFSessionOptionsPtr Options;

  /// The specification of the input nodes.
  std::vector<TF_Output> InputFeed;

  /// The input tensors. They must match, by index, the corresponding
  /// InputFeed value. We set up the tensors once and just mutate their
  /// scalars before each evaluation. The input tensors keep their value
  /// after an evaluation.
  std::vector<TF_Tensor *> Input;

  /// The specification of the output nodes. When evaluating, the tensors in
  /// the output tensor vector must match, by index, the corresponding element
  /// in the OutputFeed.
  std::vector<TF_Output> OutputFeed;

  void invalidate() { IsValid = false; }

  bool IsValid = true;

  /// Reusable utility for ensuring we can bind the requested Name to a node in
  /// the SavedModel Graph.
  bool checkReportAndInvalidate(const TF_Output &Output,
                                const TensorSpec &OutputSpec);
};
class LoggerDataImpl {
  const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
  const TensorSpec RewardSpec;
  const bool IncludeReward;

  std::vector<tensorflow::FeatureList> FeatureLists;
  tensorflow::FeatureList Reward;

  bool isSelfConsistent(const tensorflow::SequenceExample &SE,
                        size_t NrRecords) const {
    bool Ret = true;
    for (const auto &TSpecs : LoggedFeatureSpecs) {
      const auto &Name = TSpecs.getLoggingName();
      const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
      if (NrRecords != static_cast<size_t>(FL.size())) {
        dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
               << NrRecords << " got " << FL.size() << "\n";
        Ret = false;
      }
    }
    if (IncludeReward && static_cast<size_t>(SE.feature_lists()
                                                 .feature_list()
                                                 .at(RewardSpec.name())
                                                 .feature()
                                                 .size()) != NrRecords) {
      dbgs() << "[TF-UTILS]: reward is missing records.\n";
      Ret = false;
    }
    return Ret;
  }

  void transferLog(tensorflow::SequenceExample &SE) {
    auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
    if (IncludeReward)
      (*FL)[RewardSpec.name()] = std::move(Reward);
    assert(FeatureLists.size() == LoggedFeatureSpecs.size());
    for (size_t I = 0; I < FeatureLists.size(); ++I) {
      const auto &LFS = LoggedFeatureSpecs[I];
      (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
    }
  }

public:
  LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
                 const TensorSpec &RewardSpec, bool IncludeReward)
      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
        IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}

  // Flush the logged info to a stream and clear the log contents.
  void flush(raw_ostream &OS) {
    size_t NrRecords = getNrRecords();
    (void)NrRecords;
    tensorflow::SequenceExample SE;
    transferLog(SE);
    assert(isSelfConsistent(SE, NrRecords));
    std::string OutStr;
    if (ProtobufTextMode)
      google::protobuf::TextFormat::PrintToString(SE, &OutStr);
    else
      OutStr = SE.SerializeAsString();

    OS << OutStr;
  }

  char *addNewTensor(size_t FeatureID) {
    const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
    if (Spec.isElementType<float>()) {
      auto *RF = FeatureLists[FeatureID]
                     .add_feature()
                     ->mutable_float_list()
                     ->mutable_value();
      RF->Resize(Spec.getElementCount(), 0.0);
      return reinterpret_cast<char *>(RF->mutable_data());
    } else if (Spec.isElementType<int32_t>() ||
               Spec.isElementType<int64_t>()) {
      auto *RF = FeatureLists[FeatureID]
                     .add_feature()
                     ->mutable_int64_list()
                     ->mutable_value();
      RF->Resize(Spec.getElementCount(), 0);
      return reinterpret_cast<char *>(RF->mutable_data());
    }
    llvm_unreachable("Unsupported tensor type.");
  }

  template <typename T> void logReward(T Value) {
    assert(IncludeReward);
    if (RewardSpec.isElementType<float>())
      Reward.add_feature()->mutable_float_list()->add_value(Value);
    else if (RewardSpec.isElementType<int32_t>() ||
             RewardSpec.isElementType<int64_t>())
      Reward.add_feature()->mutable_int64_list()->add_value(Value);
    else
      llvm_unreachable("Unsupported tensor type.");
  }

  size_t getNrRecords() const {
    return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
  }
};
} // namespace llvm
TFModelEvaluatorImpl::TFModelEvaluatorImpl(
    StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
    function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
    const char *Tags = "serve")
    : Graph(createTFGraph()), Options(createTFSessionOptions()),
      InputFeed(InputSpecs.size()), Input(InputSpecs.size()),
      OutputFeed(OutputSpecsSize) {
  if (!ensureInitTF()) {
    errs() << "Tensorflow should have been initialized";
    return;
  }
  auto Status = createTFStatus();

  Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
                                         SavedModelPath.str().c_str(), &Tags, 1,
                                         Graph.get(), nullptr, Status.get());
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
    errs() << TF_Message(Status.get());
    invalidate();
  }
  for (size_t I = 0; I < InputSpecs.size(); ++I) {
    auto &InputSpec = InputSpecs[I];
    InputFeed[I] = {
        TF_GraphOperationByName(Graph.get(), (InputSpec.name()).c_str()),
        InputSpec.port()};
    if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
      return;
    initInput(I, static_cast<TF_DataType>(InputSpec.typeIndex()),
              InputSpec.shape());
  }
  for (size_t I = 0; I < OutputSpecsSize; ++I) {
    auto OutputSpec = GetOutputSpecs(I);
    OutputFeed[I] = {
        TF_GraphOperationByName(Graph.get(), (OutputSpec.name()).c_str()),
        OutputSpec.port()};
    if (!checkReportAndInvalidate(OutputFeed[I], OutputSpec))
      return;
  }
}
TFModelEvaluator::TFModelEvaluator(
    StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
    function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
    const char *Tags)
    : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, GetOutputSpecs,
                                    OutputSpecsSize, Tags)) {
  if (!Impl->isValid())
    Impl.reset();
}

TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
                                   const std::vector<TensorSpec> &InputSpecs,
                                   const std::vector<TensorSpec> &OutputSpecs,
                                   const char *Tags)
    : TFModelEvaluator(
          SavedModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I]; },
          OutputSpecs.size(), Tags) {}
TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
  for (auto *T : Input) {
    TF_DeleteTensor(T);
  }
  if (Session == nullptr)
    return;
  auto Status = createTFStatus();
  TF_DeleteSession(Session, Status.get());
  Session = nullptr;
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
    errs() << "Could not delete TF session";
}

bool TFModelEvaluatorImpl::checkReportAndInvalidate(
    const TF_Output &Output, const TensorSpec &OutputSpec) {
  if (Output.oper)
    return true;
  errs() << "Could not find TF_Output named: " + OutputSpec.name();
  IsValid = false;
  return IsValid;
}
Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
  if (!isValid())
    return None;
  std::unique_ptr<EvaluationResultImpl> Ret =
      std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
  auto Status = createTFStatus();
  Impl->evaluate(Ret->getOutput().data(), Status.get());
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
    errs() << TF_Message(Status.get());
    Impl.reset();
    return None;
  }
  return EvaluationResult(std::move(Ret));
}
void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
                                     const std::vector<int64_t> &Dimensions) {
  int64_t TotalSize = TF_DataTypeSize(Type);
  for (auto &D : Dimensions)
    TotalSize *= D;

  Input[Index] =
      TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
  std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
}
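// For example, a TF_FLOAT tensor (4 bytes per element) with Dimensions
// {1, 10} gets TotalSize = 4 * 1 * 10 = 40 bytes, zero-initialized above.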
void *TFModelEvaluator::getUntypedInput(size_t Index) {
  return TF_TensorData(Impl->getInput()[Index]);
}

TFModelEvaluator::EvaluationResult::EvaluationResult(
    std::unique_ptr<EvaluationResultImpl> Impl)
    : Impl(std::move(Impl)) {}

TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
    : Impl(std::move(Other.Impl)) {}

TFModelEvaluator::EvaluationResult &
TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
  Impl = std::move(Other.Impl);
  return *this;
}

void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
  return TF_TensorData(Impl->getOutput()[Index]);
}

const void *
TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
  return TF_TensorData(Impl->getOutput()[Index]);
}
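// Illustrative end-to-end sketch (model path, tensor names, and shapes are
// hypothetical), using only APIs defined in this file and its header:
//
//   std::vector<TensorSpec> Inputs{
//       TensorSpec::createSpec<int64_t>("input", {1})};
//   std::vector<TensorSpec> Outputs{
//       TensorSpec::createSpec<int64_t>("output", {1})};
//   TFModelEvaluator Evaluator("/path/to/saved_model", Inputs, Outputs);
//   *reinterpret_cast<int64_t *>(Evaluator.getUntypedInput(0)) = 42;
//   if (auto Result = Evaluator.evaluate()) {
//     int64_t Decision =
//         *reinterpret_cast<int64_t *>(Result->getUntypedTensorValue(0));
//     (void)Decision;
//   }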
#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
  template <> int TensorSpec::getDataType<T>() { return E; }

TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)

#undef TFUTILS_GETDATATYPE_IMPL

TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
TFModelEvaluator::~TFModelEvaluator() {}

Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
               const TensorSpec &RewardSpec, bool IncludeReward)
    : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
      IncludeReward(IncludeReward),
      LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
                                                  IncludeReward)) {}

Logger::~Logger() {}
#define LOG_REWARD(NAME, TYPE)                                                 \
  void Logger::log##NAME##Reward(TYPE Value) {                                 \
    assert(IncludeReward);                                                     \
    LoggerData->logReward(Value);                                              \
  }

LOG_REWARD(Float, float)
LOG_REWARD(Int32, int32_t)
LOG_REWARD(Int64, int64_t)
#undef LOG_REWARD

#define LOG_FINAL_REWARD(NAME, TYPE)                                           \
  void Logger::log##NAME##FinalReward(TYPE Value) {                            \
    assert(RewardSpec.isElementType<TYPE>());                                  \
    for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                    \
      log##NAME##Reward(0);                                                    \
    log##NAME##Reward(Value);                                                  \
  }

LOG_FINAL_REWARD(Float, float)
LOG_FINAL_REWARD(Int32, int32_t)
LOG_FINAL_REWARD(Int64, int64_t)
#undef LOG_FINAL_REWARD
void Logger::logFloatValue(size_t FeatureID, const float *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}

void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}

void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}

void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
  const auto &Spec = FeatureSpecs[FeatureID].Spec;
  char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
  if (Spec.isElementType<int32_t>())
    for (size_t I = 0; I < Spec.getElementCount(); ++I)
      (reinterpret_cast<int64_t *>(Buff))[I] =
          static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
  else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
    std::memcpy(Buff, RawData,
                Spec.getElementCount() * Spec.getElementByteSize());
  else
    llvm_unreachable("Unsupported tensor type");
}
char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
  return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
}

void Logger::flush(raw_ostream &OS) { LoggerData->flush(OS); }
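// Illustrative Logger usage sketch (specs are hypothetical; note that int32
// features are widened to int64 in the log, per logSpecifiedTensorValue):
//
//   std::vector<LoggedFeatureSpec> Specs{
//       {TensorSpec::createSpec<int64_t>("feature0", {1}), None}};
//   Logger Log(Specs, TensorSpec::createSpec<float>("reward", {1}),
//              /*IncludeReward=*/true);
//   int64_t F = 1;
//   Log.logInt64Value(0, &F);
//   Log.logFloatReward(1.0f);
//   Log.flush(outs()); // emits one serialized tensorflow::SequenceExample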
#endif // defined(LLVM_HAVE_TF_API)