Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Test / TestDialectInterfaces.cpp
blob64add8cef36986ac44f6916a6babbcac2c4e6a0d
1 //===- TestDialectInterfaces.cpp - Test dialect interface definitions -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Interfaces/FoldInterfaces.h"
12 #include "mlir/Reducer/ReductionPatternInterface.h"
13 #include "mlir/Transforms/InliningUtils.h"
15 using namespace mlir;
16 using namespace test;
18 //===----------------------------------------------------------------------===//
19 // TestDialect Interfaces
20 //===----------------------------------------------------------------------===//
22 namespace {
24 /// Testing the correctness of some traits.
25 static_assert(
26 llvm::is_detected<OpTrait::has_implicit_terminator_t,
27 SingleBlockImplicitTerminatorOp>::value,
28 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
29 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
30 SingleBlockImplicitTerminatorOp>::value,
31 "hasSingleBlockImplicitTerminator does not match "
32 "SingleBlockImplicitTerminatorOp");
34 struct TestResourceBlobManagerInterface
35 : public ResourceBlobManagerDialectInterfaceBase<
36 TestDialectResourceBlobHandle> {
37 using ResourceBlobManagerDialectInterfaceBase<
38 TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
/// Tags written with writeVarInt by TestBytecodeDialectInterface to identify
/// which test type or attribute follows in the bytecode stream, and compared
/// against the uint64_t varints read back. The numeric values are part of the
/// serialized form: changing them makes previously written test bytecode
/// unreadable. (Already file-local: this sits inside the enclosing anonymous
/// namespace, so no nested one is needed.)
enum test_encoding : uint64_t { k_attr_params = 0, k_test_i32 = 99 };
45 // Test support for interacting with the Bytecode reader/writer.
46 struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
47 using BytecodeDialectInterface::BytecodeDialectInterface;
48 TestBytecodeDialectInterface(Dialect *dialect)
49 : BytecodeDialectInterface(dialect) {}
51 LogicalResult writeType(Type type,
52 DialectBytecodeWriter &writer) const final {
53 if (auto concreteType = llvm::dyn_cast<TestI32Type>(type)) {
54 writer.writeVarInt(test_encoding::k_test_i32);
55 return success();
57 return failure();
60 Type readType(DialectBytecodeReader &reader) const final {
61 uint64_t encoding;
62 if (failed(reader.readVarInt(encoding)))
63 return Type();
64 if (encoding == test_encoding::k_test_i32)
65 return TestI32Type::get(getContext());
66 return Type();
69 LogicalResult writeAttribute(Attribute attr,
70 DialectBytecodeWriter &writer) const final {
71 if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
72 writer.writeVarInt(test_encoding::k_attr_params);
73 writer.writeVarInt(concreteAttr.getV0());
74 writer.writeVarInt(concreteAttr.getV1());
75 return success();
77 return failure();
80 Attribute readAttribute(DialectBytecodeReader &reader) const final {
81 auto versionOr = reader.getDialectVersion<test::TestDialect>();
82 // Assume current version if not available through the reader.
83 const auto version =
84 (succeeded(versionOr))
85 ? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
86 : TestDialectVersion();
87 if (version.major_ < 2)
88 return readAttrOldEncoding(reader);
89 if (version.major_ == 2 && version.minor_ == 0)
90 return readAttrNewEncoding(reader);
91 // Forbid reading future versions by returning nullptr.
92 return Attribute();
95 // Emit a specific version of the dialect.
96 void writeVersion(DialectBytecodeWriter &writer) const final {
97 // Construct the current dialect version.
98 test::TestDialectVersion versionToEmit;
100 // Check if a target version to emit was specified on the writer configs.
101 auto versionOr = writer.getDialectVersion<test::TestDialect>();
102 if (succeeded(versionOr))
103 versionToEmit =
104 *reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
105 writer.writeVarInt(versionToEmit.major_); // major
106 writer.writeVarInt(versionToEmit.minor_); // minor
109 std::unique_ptr<DialectVersion>
110 readVersion(DialectBytecodeReader &reader) const final {
111 uint64_t major_, minor_;
112 if (failed(reader.readVarInt(major_)) || failed(reader.readVarInt(minor_)))
113 return nullptr;
114 auto version = std::make_unique<TestDialectVersion>();
115 version->major_ = major_;
116 version->minor_ = minor_;
117 return version;
120 LogicalResult upgradeFromVersion(Operation *topLevelOp,
121 const DialectVersion &version_) const final {
122 const auto &version = static_cast<const TestDialectVersion &>(version_);
123 if ((version.major_ == 2) && (version.minor_ == 0))
124 return success();
125 if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
126 return topLevelOp->emitError()
127 << "current test dialect version is 2.0, can't parse version: "
128 << version.major_ << "." << version.minor_;
130 // Prior version 2.0, the old op supported only a single attribute called
131 // "dimensions". We can perform the upgrade.
132 topLevelOp->walk([](TestVersionedOpA op) {
133 // Prior version 2.0, `readProperties` did not process the modifier
134 // attribute. Handle that according to the version here.
135 auto &prop = op.getProperties();
136 prop.modifier = BoolAttr::get(op->getContext(), false);
138 return success();
141 private:
142 Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
143 uint64_t encoding;
144 if (failed(reader.readVarInt(encoding)) ||
145 encoding != test_encoding::k_attr_params)
146 return Attribute();
147 // The new encoding has v0 first, v1 second.
148 uint64_t v0, v1;
149 if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1)))
150 return Attribute();
151 return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
152 static_cast<int>(v1));
155 Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
156 uint64_t encoding;
157 if (failed(reader.readVarInt(encoding)) ||
158 encoding != test_encoding::k_attr_params)
159 return Attribute();
160 // The old encoding has v1 first, v0 second.
161 uint64_t v0, v1;
162 if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0)))
163 return Attribute();
164 return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
165 static_cast<int>(v1));
169 // Test support for interacting with the AsmPrinter.
170 struct TestOpAsmInterface : public OpAsmDialectInterface {
171 using OpAsmDialectInterface::OpAsmDialectInterface;
172 TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
173 : OpAsmDialectInterface(dialect), blobManager(mgr) {}
175 //===------------------------------------------------------------------===//
176 // Aliases
177 //===------------------------------------------------------------------===//
179 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
180 StringAttr strAttr = dyn_cast<StringAttr>(attr);
181 if (!strAttr)
182 return AliasResult::NoAlias;
184 // Check the contents of the string attribute to see what the test alias
185 // should be named.
186 std::optional<StringRef> aliasName =
187 StringSwitch<std::optional<StringRef>>(strAttr.getValue())
188 .Case("alias_test:dot_in_name", StringRef("test.alias"))
189 .Case("alias_test:trailing_digit", StringRef("test_alias0"))
190 .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
191 .Case("alias_test:prefixed_symbol", StringRef("%test"))
192 .Case("alias_test:sanitize_conflict_a",
193 StringRef("test_alias_conflict0"))
194 .Case("alias_test:sanitize_conflict_b",
195 StringRef("test_alias_conflict0_"))
196 .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
197 .Default(std::nullopt);
198 if (!aliasName)
199 return AliasResult::NoAlias;
201 os << *aliasName;
202 return AliasResult::FinalAlias;
205 AliasResult getAlias(Type type, raw_ostream &os) const final {
206 if (auto tupleType = dyn_cast<TupleType>(type)) {
207 if (tupleType.size() > 0 &&
208 llvm::all_of(tupleType.getTypes(), [](Type elemType) {
209 return isa<SimpleAType>(elemType);
210 })) {
211 os << "test_tuple";
212 return AliasResult::FinalAlias;
215 if (auto intType = dyn_cast<TestIntegerType>(type)) {
216 if (intType.getSignedness() ==
217 TestIntegerType::SignednessSemantics::Unsigned &&
218 intType.getWidth() == 8) {
219 os << "test_ui8";
220 return AliasResult::FinalAlias;
223 if (auto recType = dyn_cast<TestRecursiveType>(type)) {
224 if (recType.getName() == "type_to_alias") {
225 // We only make alias for a specific recursive type.
226 os << "testrec";
227 return AliasResult::FinalAlias;
230 if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
231 os << recAliasType.getName();
232 return AliasResult::FinalAlias;
234 return AliasResult::NoAlias;
237 //===------------------------------------------------------------------===//
238 // Resources
239 //===------------------------------------------------------------------===//
241 std::string
242 getResourceKey(const AsmDialectResourceHandle &handle) const override {
243 return cast<TestDialectResourceBlobHandle>(handle).getKey().str();
246 FailureOr<AsmDialectResourceHandle>
247 declareResource(StringRef key) const final {
248 return blobManager.insert(key);
251 LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
252 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
253 if (failed(blob))
254 return failure();
256 // Update the blob for this entry.
257 blobManager.update(entry.getKey(), std::move(*blob));
258 return success();
261 void
262 buildResources(Operation *op,
263 const SetVector<AsmDialectResourceHandle> &referencedResources,
264 AsmResourceBuilder &provider) const final {
265 blobManager.buildResources(provider, referencedResources.getArrayRef());
268 private:
269 /// The blob manager for the dialect.
270 TestResourceBlobManagerInterface &blobManager;
273 struct TestDialectFoldInterface : public DialectFoldInterface {
274 using DialectFoldInterface::DialectFoldInterface;
276 /// Registered hook to check if the given region, which is attached to an
277 /// operation that is *not* isolated from above, should be used when
278 /// materializing constants.
279 bool shouldMaterializeInto(Region *region) const final {
280 // If this is a one region operation, then insert into it.
281 return isa<OneRegionOp>(region->getParentOp());
285 /// This class defines the interface for handling inlining with standard
286 /// operations.
287 struct TestInlinerInterface : public DialectInlinerInterface {
288 using DialectInlinerInterface::DialectInlinerInterface;
290 //===--------------------------------------------------------------------===//
291 // Analysis Hooks
292 //===--------------------------------------------------------------------===//
294 bool isLegalToInline(Operation *call, Operation *callable,
295 bool wouldBeCloned) const final {
296 // Don't allow inlining calls that are marked `noinline`.
297 return !call->hasAttr("noinline");
299 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
300 // Inlining into test dialect regions is legal.
301 return true;
303 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
304 return true;
307 bool shouldAnalyzeRecursively(Operation *op) const final {
308 // Analyze recursively if this is not a functional region operation, it
309 // froms a separate functional scope.
310 return !isa<FunctionalRegionOp>(op);
313 //===--------------------------------------------------------------------===//
314 // Transformation Hooks
315 //===--------------------------------------------------------------------===//
317 /// Handle the given inlined terminator by replacing it with a new operation
318 /// as necessary.
319 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
320 // Only handle "test.return" here.
321 auto returnOp = dyn_cast<TestReturnOp>(op);
322 if (!returnOp)
323 return;
325 // Replace the values directly with the return operands.
326 assert(returnOp.getNumOperands() == valuesToRepl.size());
327 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
328 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
331 /// Attempt to materialize a conversion for a type mismatch between a call
332 /// from this dialect, and a callable region. This method should generate an
333 /// operation that takes 'input' as the only operand, and produces a single
334 /// result of 'resultType'. If a conversion can not be generated, nullptr
335 /// should be returned.
336 Operation *materializeCallConversion(OpBuilder &builder, Value input,
337 Type resultType,
338 Location conversionLoc) const final {
339 // Only allow conversion for i16/i32 types.
340 if (!(resultType.isSignlessInteger(16) ||
341 resultType.isSignlessInteger(32)) ||
342 !(input.getType().isSignlessInteger(16) ||
343 input.getType().isSignlessInteger(32)))
344 return nullptr;
345 return builder.create<TestCastOp>(conversionLoc, resultType, input);
348 Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
349 Value argument,
350 DictionaryAttr argumentAttrs) const final {
351 if (!argumentAttrs.contains("test.handle_argument"))
352 return argument;
353 return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(),
354 argument);
357 Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
358 Value result, DictionaryAttr resultAttrs) const final {
359 if (!resultAttrs.contains("test.handle_result"))
360 return result;
361 return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(),
362 result);
365 void processInlinedCallBlocks(
366 Operation *call,
367 iterator_range<Region::iterator> inlinedBlocks) const final {
368 if (!isa<ConversionCallOp>(call))
369 return;
371 // Set attributed on all ops in the inlined blocks.
372 for (Block &block : inlinedBlocks) {
373 block.walk([&](Operation *op) {
374 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
380 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
381 public:
382 TestReductionPatternInterface(Dialect *dialect)
383 : DialectReductionPatternInterface(dialect) {}
385 void populateReductionPatterns(RewritePatternSet &patterns) const final {
386 populateTestReductionPatterns(patterns);
390 } // namespace
392 void TestDialect::registerInterfaces() {
393 auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
394 addInterface<TestOpAsmInterface>(blobInterface);
396 addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
397 TestReductionPatternInterface, TestBytecodeDialectInterface>();