1 //===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Implementation file for the abstraction of a tensor type, and JSON loading
// utilities for tensor specs.
12 //===----------------------------------------------------------------------===//
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Config/config.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Analysis/TensorSpec.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/JSON.h"
22 #include "llvm/Support/ManagedStatic.h"
23 #include "llvm/Support/raw_ostream.h"
32 #define TFUTILS_GETDATATYPE_IMPL(T, E) \
33 template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
35 SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL
)
37 #undef TFUTILS_GETDATATYPE_IMPL
// Table of printable tensor type names, sized by the numeric value of
// TensorType::Total. The first entry is "INVALID"; the entries that follow are
// the stringified first macro argument (#T) of every SUPPORTED_TENSOR_TYPES
// row, produced by expanding that list with TFUTILS_GETNAME_IMPL.
39 static std::array
<std::string
, static_cast<size_t>(TensorType::Total
)>
40 TensorTypeNames
{"INVALID",
41 #define TFUTILS_GETNAME_IMPL(T, _) #T,
42 SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL
)
43 #undef TFUTILS_GETNAME_IMPL
/// Looks up the printable name of \p TT in TensorTypeNames, using the
/// enumerator's value cast to size_t as the index. The index is not
/// range-checked here.
46 StringRef
toString(TensorType TT
) {
47 return TensorTypeNames
[static_cast<size_t>(TT
)];
/// Writes this spec to \p OS as the attributes "name", "type" (the textual
/// form returned by toString), "port", and "shape" (an array holding one
/// value per dimension).
50 void TensorSpec::toJSON(json::OStream
&OS
) const {
52 OS
.attribute("name", name());
53 OS
.attribute("type", toString(type()));
54 OS
.attribute("port", port());
55 OS
.attributeArray("shape", [&]() {
// Each dimension is cast to int64_t before it is written out.
56 for (size_t D
: shape())
57 OS
.value(static_cast<int64_t>(D
));
62 TensorSpec::TensorSpec(const std::string
&Name
, int Port
, TensorType Type
,
63 size_t ElementSize
, const std::vector
<int64_t> &Shape
)
64 : Name(Name
), Port(Port
), Type(Type
), Shape(Shape
),
65 ElementCount(std::accumulate(Shape
.begin(), Shape
.end(), 1,
66 std::multiplies
<int64_t>())),
67 ElementSize(ElementSize
) {}
/// Builds a TensorSpec from a JSON value.
///
/// \p Value is read through a json::ObjectMapper rooted at "tensor_spec" and
/// is expected to provide the properties "name" (string), "type" (string),
/// "port" (int) and "shape" (array of int64_t). Problems are reported through
/// \p Ctx via emitError, with a message of the form
/// "Unable to parse JSON Value as spec (<reason>): <S>".
///
/// The "type" string is compared against the stringified C++ type name of
/// each SUPPORTED_TENSOR_TYPES row; on a match the result is
/// TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort).
69 std::optional
<TensorSpec
> getTensorSpecFromJSON(LLVMContext
&Ctx
,
70 const json::Value
&Value
) {
// Error helper: forwards the reason in Message, followed by the contents of
// S, to the context's error handler.
72 [&](const llvm::Twine
&Message
) -> std::optional
<TensorSpec
> {
74 llvm::raw_string_ostream
OS(S
);
76 Ctx
.emitError("Unable to parse JSON Value as spec (" + Message
+ "): " + S
);
79 // FIXME: accept a Path as a parameter, and use it for error reporting.
80 json::Path::Root
Root("tensor_spec");
81 json::ObjectMapper
Mapper(Value
, Root
);
83 return EmitError("Value is not a dict");
85 std::string TensorName
;
87 std::string TensorType
;
88 std::vector
<int64_t> TensorShape
;
// Read each required property in turn; the error text names the property
// that could not be mapped.
90 if (!Mapper
.map
<std::string
>("name", TensorName
))
91 return EmitError("'name' property not present or not a string");
92 if (!Mapper
.map
<std::string
>("type", TensorType
))
93 return EmitError("'type' property not present or not a string");
94 if (!Mapper
.map
<int>("port", TensorPort
))
95 return EmitError("'port' property not present or not an int");
96 if (!Mapper
.map
<std::vector
<int64_t>>("shape", TensorShape
))
97 return EmitError("'shape' property not present or not an int array");
// Dispatch on the textual type name: one string comparison is generated per
// supported tensor type.
99 #define PARSE_TYPE(T, E) \
100 if (TensorType == #T) \
101 return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
102 SUPPORTED_TENSOR_TYPES(PARSE_TYPE
)
107 std::string
tensorValueToString(const char *Buffer
, const TensorSpec
&Spec
) {
108 switch (Spec
.type()) {
109 #define _IMR_DBG_PRINTER(T, N) \
110 case TensorType::N: { \
111 const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \
112 auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \
114 llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \
116 SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER
)
117 #undef _IMR_DBG_PRINTER
118 case TensorType::Total
:
119 case TensorType::Invalid
:
120 llvm_unreachable("invalid tensor type");
122 // To appease warnings about not all control paths returning a value.