//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
11 #include "TestTypes.h"
12 #include "mlir/Bytecode/BytecodeImplementation.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/ExtensibleDialect.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/ODSSupport.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/IR/Verifier.h"
27 #include "mlir/Interfaces/CallInterfaces.h"
28 #include "mlir/Interfaces/FunctionImplementation.h"
29 #include "mlir/Interfaces/InferIntRangeInterface.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Transforms/FoldUtils.h"
32 #include "mlir/Transforms/InliningUtils.h"
33 #include "llvm/ADT/STLFunctionalExtras.h"
34 #include "llvm/ADT/SmallString.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/Support/Base64.h"
38 #include "llvm/Support/Casting.h"
40 #include "mlir/Dialect/Arith/IR/Arith.h"
41 #include "mlir/Dialect/DLTI/DLTI.h"
42 #include "mlir/Interfaces/FoldInterfaces.h"
43 #include "mlir/Reducer/ReductionPatternInterface.h"
44 #include "mlir/Transforms/InliningUtils.h"
// Include this before the using namespace lines below to test that we don't
// have namespace dependencies.
51 #include "TestOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// PropertiesWithCustomPrint
//===----------------------------------------------------------------------===//
61 test::setPropertiesFromAttribute(PropertiesWithCustomPrint
&prop
,
63 function_ref
<InFlightDiagnostic()> emitError
) {
64 DictionaryAttr dict
= dyn_cast
<DictionaryAttr
>(attr
);
66 emitError() << "expected DictionaryAttr to set TestProperties";
69 auto label
= dict
.getAs
<mlir::StringAttr
>("label");
71 emitError() << "expected StringAttr for key `label`";
74 auto valueAttr
= dict
.getAs
<IntegerAttr
>("value");
76 emitError() << "expected IntegerAttr for key `value`";
80 prop
.label
= std::make_shared
<std::string
>(label
.getValue());
81 prop
.value
= valueAttr
.getValue().getSExtValue();
86 test::getPropertiesAsAttribute(MLIRContext
*ctx
,
87 const PropertiesWithCustomPrint
&prop
) {
88 SmallVector
<NamedAttribute
> attrs
;
90 attrs
.push_back(b
.getNamedAttr("label", b
.getStringAttr(*prop
.label
)));
91 attrs
.push_back(b
.getNamedAttr("value", b
.getI32IntegerAttr(prop
.value
)));
92 return b
.getDictionaryAttr(attrs
);
95 llvm::hash_code
test::computeHash(const PropertiesWithCustomPrint
&prop
) {
96 return llvm::hash_combine(prop
.value
, StringRef(*prop
.label
));
99 void test::customPrintProperties(OpAsmPrinter
&p
,
100 const PropertiesWithCustomPrint
&prop
) {
101 p
.printKeywordOrString(*prop
.label
);
102 p
<< " is " << prop
.value
;
105 ParseResult
test::customParseProperties(OpAsmParser
&parser
,
106 PropertiesWithCustomPrint
&prop
) {
108 if (parser
.parseKeywordOrString(&label
) || parser
.parseKeyword("is") ||
109 parser
.parseInteger(prop
.value
))
111 prop
.label
= std::make_shared
<std::string
>(std::move(label
));
//===----------------------------------------------------------------------===//
// MyPropStruct
//===----------------------------------------------------------------------===//
119 Attribute
MyPropStruct::asAttribute(MLIRContext
*ctx
) const {
120 return StringAttr::get(ctx
, content
);
124 MyPropStruct::setFromAttr(MyPropStruct
&prop
, Attribute attr
,
125 function_ref
<InFlightDiagnostic()> emitError
) {
126 StringAttr strAttr
= dyn_cast
<StringAttr
>(attr
);
128 emitError() << "Expect StringAttr but got " << attr
;
131 prop
.content
= strAttr
.getValue();
135 llvm::hash_code
MyPropStruct::hash() const {
136 return hash_value(StringRef(content
));
139 LogicalResult
test::readFromMlirBytecode(DialectBytecodeReader
&reader
,
140 MyPropStruct
&prop
) {
142 if (failed(reader
.readString(str
)))
144 prop
.content
= str
.str();
148 void test::writeToMlirBytecode(DialectBytecodeWriter
&writer
,
149 MyPropStruct
&prop
) {
150 writer
.writeOwnedString(prop
.content
);
//===----------------------------------------------------------------------===//
// VersionedProperties
//===----------------------------------------------------------------------===//
158 test::setPropertiesFromAttribute(VersionedProperties
&prop
, Attribute attr
,
159 function_ref
<InFlightDiagnostic()> emitError
) {
160 DictionaryAttr dict
= dyn_cast
<DictionaryAttr
>(attr
);
162 emitError() << "expected DictionaryAttr to set VersionedProperties";
165 auto value1Attr
= dict
.getAs
<IntegerAttr
>("value1");
167 emitError() << "expected IntegerAttr for key `value1`";
170 auto value2Attr
= dict
.getAs
<IntegerAttr
>("value2");
172 emitError() << "expected IntegerAttr for key `value2`";
176 prop
.value1
= value1Attr
.getValue().getSExtValue();
177 prop
.value2
= value2Attr
.getValue().getSExtValue();
181 DictionaryAttr
test::getPropertiesAsAttribute(MLIRContext
*ctx
,
182 const VersionedProperties
&prop
) {
183 SmallVector
<NamedAttribute
> attrs
;
185 attrs
.push_back(b
.getNamedAttr("value1", b
.getI32IntegerAttr(prop
.value1
)));
186 attrs
.push_back(b
.getNamedAttr("value2", b
.getI32IntegerAttr(prop
.value2
)));
187 return b
.getDictionaryAttr(attrs
);
190 llvm::hash_code
test::computeHash(const VersionedProperties
&prop
) {
191 return llvm::hash_combine(prop
.value1
, prop
.value2
);
194 void test::customPrintProperties(OpAsmPrinter
&p
,
195 const VersionedProperties
&prop
) {
196 p
<< prop
.value1
<< " | " << prop
.value2
;
199 ParseResult
test::customParseProperties(OpAsmParser
&parser
,
200 VersionedProperties
&prop
) {
201 if (parser
.parseInteger(prop
.value1
) || parser
.parseVerticalBar() ||
202 parser
.parseInteger(prop
.value2
))
//===----------------------------------------------------------------------===//
// Bytecode support for int64_t array properties
//===----------------------------------------------------------------------===//
211 LogicalResult
test::readFromMlirBytecode(DialectBytecodeReader
&reader
,
212 MutableArrayRef
<int64_t> prop
) {
214 if (failed(reader
.readVarInt(size
)))
216 if (size
!= prop
.size())
217 return reader
.emitError("array size mismach when reading properties: ")
218 << size
<< " vs expected " << prop
.size();
219 for (auto &elt
: prop
) {
221 if (failed(reader
.readVarInt(value
)))
228 void test::writeToMlirBytecode(DialectBytecodeWriter
&writer
,
229 ArrayRef
<int64_t> prop
) {
230 writer
.writeVarInt(prop
.size());
231 for (auto elt
: prop
)
232 writer
.writeVarInt(elt
);
//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//
239 std::unique_ptr
<DynamicOpDefinition
> getDynamicGenericOp(TestDialect
*dialect
) {
240 return DynamicOpDefinition::get(
241 "dynamic_generic", dialect
, [](Operation
*op
) { return success(); },
242 [](Operation
*op
) { return success(); });
// Builds the definition of `dynamic_one_operand_two_results`, a dynamic op
// whose verifier checks that the op has exactly one operand and exactly two
// results, reporting the actual counts in a diagnostic otherwise. The last
// argument, the region verifier, accepts every operation.
// NOTE(review): this copy of the definition is damaged. Each line carries a
// stray numeric prefix. Several lines are absent: the verifier lambda's header
// (`op` is used below but never declared), the diagnostic calls that the
// `<< "expected ..."` fragments continue, the `return` statements, and the
// closing braces. Restore them before building.
245 std::unique_ptr
<DynamicOpDefinition
>
246 getDynamicOneOperandTwoResultsOp(TestDialect
*dialect
) {
247 return DynamicOpDefinition::get(
248 "dynamic_one_operand_two_results", dialect
,
250 if (op
->getNumOperands() != 1) {
252 << "expected 1 operand, but had " << op
->getNumOperands();
255 if (op
->getNumResults() != 2) {
257 << "expected 2 results, but had " << op
->getNumResults();
// Region verifier: accepts every operation.
262 [](Operation
*op
) { return success(); });
265 std::unique_ptr
<DynamicOpDefinition
>
266 getDynamicCustomParserPrinterOp(TestDialect
*dialect
) {
267 auto verifier
= [](Operation
*op
) {
268 if (op
->getNumOperands() == 0 && op
->getNumResults() == 0)
270 op
->emitError() << "operation should have no operands and no results";
273 auto regionVerifier
= [](Operation
*op
) { return success(); };
275 auto parser
= [](OpAsmParser
&parser
, OperationState
&state
) {
276 return parser
.parseKeyword("custom_keyword");
279 auto printer
= [](Operation
*op
, OpAsmPrinter
&printer
, llvm::StringRef
) {
280 printer
<< op
->getName() << " custom_keyword";
283 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect
,
284 verifier
, regionVerifier
, parser
, printer
);
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
291 void test::registerTestDialect(DialectRegistry
®istry
) {
292 registry
.insert
<TestDialect
>();
// Shared effect computation for TestEffectOpInterface, also called by the
// dialect fallback model. It looks up the `effect_parameter` AffineMapAttr on
// `op` and appends a TestEffects::Concrete effect parameterized by that
// attribute to `effects`.
// NOTE(review): this copy is damaged. Each line carries a stray numeric
// prefix. The declarations of the `op` and `effects` parameters are missing,
// as are the lines between the attribute lookup and `emplace_back` and the
// closing brace. Confirm against the header whether a missing (null)
// `effectsAttr` is meant to be skipped before restoring.
295 void test::testSideEffectOpGetEffect(
297 SmallVectorImpl
<SideEffects::EffectInstance
<TestEffects::Effect
>>
// getAttrOfType yields a null attribute when the key is absent or has
// another type.
299 auto effectsAttr
= op
->getAttrOfType
<AffineMapAttr
>("effect_parameter");
303 effects
.emplace_back(TestEffects::Concrete::get(), effectsAttr
);
// Fallback model that gives the unregistered op
// `test.unregistered_side_effect_op` a TestEffectOpInterface implementation.
// TestDialect::getRegisteredInterfaceForOp hands out an instance only for
// that op name. `classof` therefore asserts that it is never dispatched for
// anything else, and `getEffects` forwards to testSideEffectOpGetEffect.
// NOTE(review): this copy is damaged. Each line carries a stray numeric
// prefix. The `isSupportedOp` declaration, the return type and trailing
// qualifiers/closing tokens of `getEffects`, and the closing braces of the
// methods and the struct are absent. Restore them before building.
306 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
307 struct TestOpEffectInterfaceFallback
308 : public TestEffectOpInterface::FallbackModel
<
309 TestOpEffectInterfaceFallback
> {
310 static bool classof(Operation
*op
) {
312 op
->getName().getStringRef() == "test.unregistered_side_effect_op";
313 assert(isSupportedOp
&& "Unexpected dispatch");
314 return isSupportedOp
;
318 getEffects(Operation
*op
,
319 SmallVectorImpl
<SideEffects::EffectInstance
<TestEffects::Effect
>>
321 testSideEffectOpGetEffect(op
, effects
);
// Dialect initialization. It:
//   - registers the attributes and the hand-written ManualCppOpWithFold;
//   - calls registerTestDialectOperations (presumably the ODS-generated ops);
//   - registers three dynamic ops and calls registerInterfaces;
//   - allows unknown operations in this dialect;
//   - heap-allocates the TestEffectOpInterface fallback model, which
//     ~TestDialect deletes.
// NOTE(review): the embedded numbering jumps from 326 to 329 and from 337 to
// 339, so registration calls and the second line of the fallback comment
// appear to be missing from this copy. Each line also carries a stray numeric
// prefix.
325 void TestDialect::initialize() {
326 registerAttributes();
329 addOperations
<ManualCppOpWithFold
>();
330 registerTestDialectOperations(this);
331 registerDynamicOp(getDynamicGenericOp(this));
332 registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
333 registerDynamicOp(getDynamicCustomParserPrinterOp(this));
334 registerInterfaces();
335 allowUnknownOperations();
337 // Instantiate our fallback op interface that we'll use on specific
339 fallbackEffectOpInterfaces
= new TestOpEffectInterfaceFallback
;
342 TestDialect::~TestDialect() {
343 delete static_cast<TestOpEffectInterfaceFallback
*>(
344 fallbackEffectOpInterfaces
);
347 Operation
*TestDialect::materializeConstant(OpBuilder
&builder
, Attribute value
,
348 Type type
, Location loc
) {
349 return builder
.create
<TestOpConstant
>(loc
, type
, value
);
352 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID
,
353 OperationName opName
) {
354 if (opName
.getIdentifier() == "test.unregistered_side_effect_op" &&
355 typeID
== TypeID::get
<TestEffectOpInterface
>())
356 return fallbackEffectOpInterfaces
;
360 LogicalResult
TestDialect::verifyOperationAttribute(Operation
*op
,
361 NamedAttribute namedAttr
) {
362 if (namedAttr
.getName() == "test.invalid_attr")
363 return op
->emitError() << "invalid to use 'test.invalid_attr'";
367 LogicalResult
TestDialect::verifyRegionArgAttribute(Operation
*op
,
368 unsigned regionIndex
,
370 NamedAttribute namedAttr
) {
371 if (namedAttr
.getName() == "test.invalid_attr")
372 return op
->emitError() << "invalid to use 'test.invalid_attr'";
377 TestDialect::verifyRegionResultAttribute(Operation
*op
, unsigned regionIndex
,
378 unsigned resultIndex
,
379 NamedAttribute namedAttr
) {
380 if (namedAttr
.getName() == "test.invalid_attr")
381 return op
->emitError() << "invalid to use 'test.invalid_attr'";
385 std::optional
<Dialect::ParseOpHook
>
386 TestDialect::getParseOperationHook(StringRef opName
) const {
387 if (opName
== "test.dialect_custom_printer") {
388 return ParseOpHook
{[](OpAsmParser
&parser
, OperationState
&state
) {
389 return parser
.parseKeyword("custom_format");
392 if (opName
== "test.dialect_custom_format_fallback") {
393 return ParseOpHook
{[](OpAsmParser
&parser
, OperationState
&state
) {
394 return parser
.parseKeyword("custom_format_fallback");
397 if (opName
== "test.dialect_custom_printer.with.dot") {
398 return ParseOpHook
{[](OpAsmParser
&parser
, OperationState
&state
) {
399 return ParseResult::success();
405 llvm::unique_function
<void(Operation
*, OpAsmPrinter
&)>
406 TestDialect::getOperationPrinter(Operation
*op
) const {
407 StringRef opName
= op
->getName().getStringRef();
408 if (opName
== "test.dialect_custom_printer") {
409 return [](Operation
*op
, OpAsmPrinter
&printer
) {
410 printer
.getStream() << " custom_format";
413 if (opName
== "test.dialect_custom_format_fallback") {
414 return [](Operation
*op
, OpAsmPrinter
&printer
) {
415 printer
.getStream() << " custom_format_fallback";
422 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op
,
423 PatternRewriter
&rewriter
) {
424 rewriter
.replaceOpWithNewOp
<arith::ConstantOp
>(
425 op
, rewriter
.getI32IntegerAttr(42));
// Registers the dialect-wide canonicalization pattern
// `dialectCanonicalizationPattern` (TestDialectCanonicalizerOp ->
// arith.constant 42 : i32) into `results`.
// NOTE(review): each line carries a stray numeric prefix; the closing brace of
// this definition lies past the end of this view.
429 void TestDialect::getCanonicalizationPatterns(
430 RewritePatternSet
&results
) const {
431 results
.add(&dialectCanonicalizationPattern
);