Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Test / TestOpDefs.cpp
blobb268e549b93ab6a1ac6f1552904ce6347f890e07
1 //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Dialect/Tensor/IR/Tensor.h"
12 #include "mlir/IR/Verifier.h"
13 #include "mlir/Interfaces/FunctionImplementation.h"
14 #include "mlir/Interfaces/MemorySlotInterfaces.h"
16 using namespace mlir;
17 using namespace test;
19 //===----------------------------------------------------------------------===//
20 // TestBranchOp
21 //===----------------------------------------------------------------------===//
23 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
24 assert(index == 0 && "invalid successor index");
25 return SuccessorOperands(getTargetOperandsMutable());
28 //===----------------------------------------------------------------------===//
29 // TestProducingBranchOp
30 //===----------------------------------------------------------------------===//
32 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
33 assert(index <= 1 && "invalid successor index");
34 if (index == 1)
35 return SuccessorOperands(getFirstOperandsMutable());
36 return SuccessorOperands(getSecondOperandsMutable());
39 //===----------------------------------------------------------------------===//
40 // TestInternalBranchOp
41 //===----------------------------------------------------------------------===//
43 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
44 assert(index <= 1 && "invalid successor index");
45 if (index == 0)
46 return SuccessorOperands(0, getSuccessOperandsMutable());
47 return SuccessorOperands(1, getErrorOperandsMutable());
50 //===----------------------------------------------------------------------===//
51 // TestCallOp
52 //===----------------------------------------------------------------------===//
54 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
55 // Check that the callee attribute was specified.
56 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
57 if (!fnAttr)
58 return emitOpError("requires a 'callee' symbol reference attribute");
59 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
60 return emitOpError() << "'" << fnAttr.getValue()
61 << "' does not reference a valid function";
62 return success();
65 //===----------------------------------------------------------------------===//
66 // FoldToCallOp
67 //===----------------------------------------------------------------------===//
69 namespace {
70 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
71 using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
73 LogicalResult matchAndRewrite(FoldToCallOp op,
74 PatternRewriter &rewriter) const override {
75 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
76 op.getCalleeAttr(), ValueRange());
77 return success();
80 } // namespace
82 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
83 MLIRContext *context) {
84 results.add<FoldToCallOpPattern>(context);
87 //===----------------------------------------------------------------------===//
88 // IsolatedRegionOp - test parsing passthrough operands
89 //===----------------------------------------------------------------------===//
91 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
92 OperationState &result) {
93 // Parse the input operand.
94 OpAsmParser::Argument argInfo;
95 argInfo.type = parser.getBuilder().getIndexType();
96 if (parser.parseOperand(argInfo.ssaName) ||
97 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
98 return failure();
100 // Parse the body region, and reuse the operand info as the argument info.
101 Region *body = result.addRegion();
102 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
105 void IsolatedRegionOp::print(OpAsmPrinter &p) {
106 p << ' ';
107 p.printOperand(getOperand());
108 p.shadowRegionArgs(getRegion(), getOperand());
109 p << ' ';
110 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
113 //===----------------------------------------------------------------------===//
114 // SSACFGRegionOp
115 //===----------------------------------------------------------------------===//
117 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
118 return RegionKind::SSACFG;
121 //===----------------------------------------------------------------------===//
122 // GraphRegionOp
123 //===----------------------------------------------------------------------===//
125 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
126 return RegionKind::Graph;
129 //===----------------------------------------------------------------------===//
130 // IsolatedGraphRegionOp
131 //===----------------------------------------------------------------------===//
133 RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) {
134 return RegionKind::Graph;
137 //===----------------------------------------------------------------------===//
138 // AffineScopeOp
139 //===----------------------------------------------------------------------===//
141 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
142 // Parse the body region, and reuse the operand info as the argument info.
143 Region *body = result.addRegion();
144 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
147 void AffineScopeOp::print(OpAsmPrinter &p) {
148 p << " ";
149 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
152 //===----------------------------------------------------------------------===//
153 // TestRemoveOpWithInnerOps
154 //===----------------------------------------------------------------------===//
156 namespace {
157 struct TestRemoveOpWithInnerOps
158 : public OpRewritePattern<TestOpWithRegionPattern> {
159 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
161 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
163 LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
164 PatternRewriter &rewriter) const override {
165 rewriter.eraseOp(op);
166 return success();
169 } // namespace
171 //===----------------------------------------------------------------------===//
172 // TestOpWithRegionPattern
173 //===----------------------------------------------------------------------===//
175 void TestOpWithRegionPattern::getCanonicalizationPatterns(
176 RewritePatternSet &results, MLIRContext *context) {
177 results.add<TestRemoveOpWithInnerOps>(context);
180 //===----------------------------------------------------------------------===//
181 // TestOpWithRegionFold
182 //===----------------------------------------------------------------------===//
184 OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
185 return getOperand();
188 //===----------------------------------------------------------------------===//
189 // TestOpConstant
190 //===----------------------------------------------------------------------===//
192 OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
194 //===----------------------------------------------------------------------===//
195 // TestOpWithVariadicResultsAndFolder
196 //===----------------------------------------------------------------------===//
198 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
199 FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
200 for (Value input : this->getOperands()) {
201 results.push_back(input);
203 return success();
206 //===----------------------------------------------------------------------===//
207 // TestOpInPlaceFold
208 //===----------------------------------------------------------------------===//
210 OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
211 // Exercise the fact that an operation created with createOrFold should be
212 // allowed to access its parent block.
213 assert(getOperation()->getBlock() &&
214 "expected that operation is not unlinked");
216 if (adaptor.getOp() && !getProperties().attr) {
217 // The folder adds "attr" if not present.
218 getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
219 return getResult();
221 return {};
224 //===----------------------------------------------------------------------===//
225 // OpWithInferTypeInterfaceOp
226 //===----------------------------------------------------------------------===//
228 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
229 MLIRContext *, std::optional<Location> location, ValueRange operands,
230 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
231 SmallVectorImpl<Type> &inferredReturnTypes) {
232 if (operands[0].getType() != operands[1].getType()) {
233 return emitOptionalError(location, "operand type mismatch ",
234 operands[0].getType(), " vs ",
235 operands[1].getType());
237 inferredReturnTypes.assign({operands[0].getType()});
238 return success();
241 //===----------------------------------------------------------------------===//
242 // OpWithShapedTypeInferTypeInterfaceOp
243 //===----------------------------------------------------------------------===//
245 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
246 MLIRContext *context, std::optional<Location> location,
247 ValueShapeRange operands, DictionaryAttr attributes,
248 OpaqueProperties properties, RegionRange regions,
249 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
250 // Create return type consisting of the last element of the first operand.
251 auto operandType = operands.front().getType();
252 auto sval = dyn_cast<ShapedType>(operandType);
253 if (!sval)
254 return emitOptionalError(location, "only shaped type operands allowed");
255 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
256 auto type = IntegerType::get(context, 17);
258 Attribute encoding;
259 if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
260 encoding = rankedTy.getEncoding();
261 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
262 return success();
265 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
266 OpBuilder &builder, ValueRange operands,
267 llvm::SmallVectorImpl<Value> &shapes) {
268 shapes = SmallVector<Value, 1>{
269 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
270 return success();
273 //===----------------------------------------------------------------------===//
274 // OpWithResultShapeInterfaceOp
275 //===----------------------------------------------------------------------===//
277 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
278 OpBuilder &builder, ValueRange operands,
279 llvm::SmallVectorImpl<Value> &shapes) {
280 Location loc = getLoc();
281 shapes.reserve(operands.size());
282 for (Value operand : llvm::reverse(operands)) {
283 auto rank = cast<RankedTensorType>(operand.getType()).getRank();
284 auto currShape = llvm::to_vector<4>(
285 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
286 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
287 }));
288 shapes.push_back(builder.create<tensor::FromElementsOp>(
289 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
290 currShape));
292 return success();
295 //===----------------------------------------------------------------------===//
296 // OpWithResultShapePerDimInterfaceOp
297 //===----------------------------------------------------------------------===//
299 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
300 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
301 Location loc = getLoc();
302 shapes.reserve(getNumOperands());
303 for (Value operand : llvm::reverse(getOperands())) {
304 auto tensorType = cast<RankedTensorType>(operand.getType());
305 auto currShape = llvm::to_vector<4>(llvm::map_range(
306 llvm::seq<int64_t>(0, tensorType.getRank()),
307 [&](int64_t dim) -> OpFoldResult {
308 return tensorType.isDynamicDim(dim)
309 ? static_cast<OpFoldResult>(
310 builder.createOrFold<tensor::DimOp>(loc, operand,
311 dim))
312 : static_cast<OpFoldResult>(
313 builder.getIndexAttr(tensorType.getDimSize(dim)));
314 }));
315 shapes.emplace_back(std::move(currShape));
317 return success();
320 //===----------------------------------------------------------------------===//
321 // SideEffectOp
322 //===----------------------------------------------------------------------===//
324 namespace {
325 /// A test resource for side effects.
326 struct TestResource : public SideEffects::Resource::Base<TestResource> {
327 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
329 StringRef getName() final { return "<Test>"; }
331 } // namespace
333 void SideEffectOp::getEffects(
334 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
335 // Check for an effects attribute on the op instance.
336 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
337 if (!effectsAttr)
338 return;
340 for (Attribute element : effectsAttr) {
341 DictionaryAttr effectElement = cast<DictionaryAttr>(element);
343 // Get the specific memory effect.
344 MemoryEffects::Effect *effect =
345 StringSwitch<MemoryEffects::Effect *>(
346 cast<StringAttr>(effectElement.get("effect")).getValue())
347 .Case("allocate", MemoryEffects::Allocate::get())
348 .Case("free", MemoryEffects::Free::get())
349 .Case("read", MemoryEffects::Read::get())
350 .Case("write", MemoryEffects::Write::get());
352 // Check for a non-default resource to use.
353 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
354 if (effectElement.get("test_resource"))
355 resource = TestResource::get();
357 // Check for a result to affect.
358 if (effectElement.get("on_result"))
359 effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
360 else if (Attribute ref = effectElement.get("on_reference"))
361 effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
362 else
363 effects.emplace_back(effect, resource);
367 void SideEffectOp::getEffects(
368 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
369 testSideEffectOpGetEffect(getOperation(), effects);
372 void SideEffectWithRegionOp::getEffects(
373 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
374 // Check for an effects attribute on the op instance.
375 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
376 if (!effectsAttr)
377 return;
379 for (Attribute element : effectsAttr) {
380 DictionaryAttr effectElement = cast<DictionaryAttr>(element);
382 // Get the specific memory effect.
383 MemoryEffects::Effect *effect =
384 StringSwitch<MemoryEffects::Effect *>(
385 cast<StringAttr>(effectElement.get("effect")).getValue())
386 .Case("allocate", MemoryEffects::Allocate::get())
387 .Case("free", MemoryEffects::Free::get())
388 .Case("read", MemoryEffects::Read::get())
389 .Case("write", MemoryEffects::Write::get());
391 // Check for a non-default resource to use.
392 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
393 if (effectElement.get("test_resource"))
394 resource = TestResource::get();
396 // Check for a result to affect.
397 if (effectElement.get("on_result"))
398 effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
399 else if (effectElement.get("on_operand"))
400 effects.emplace_back(effect, &getOperation()->getOpOperands()[0],
401 resource);
402 else if (effectElement.get("on_argument"))
403 effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0),
404 resource);
405 else if (Attribute ref = effectElement.get("on_reference"))
406 effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
407 else
408 effects.emplace_back(effect, resource);
412 void SideEffectWithRegionOp::getEffects(
413 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
414 testSideEffectOpGetEffect(getOperation(), effects);
417 //===----------------------------------------------------------------------===//
418 // StringAttrPrettyNameOp
419 //===----------------------------------------------------------------------===//
421 // This op has fancy handling of its SSA result name.
422 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
423 OperationState &result) {
424 // Add the result types.
425 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
426 result.addTypes(parser.getBuilder().getIntegerType(32));
428 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
429 return failure();
431 // If the attribute dictionary contains no 'names' attribute, infer it from
432 // the SSA name (if specified).
433 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
434 return attr.getName() == "names";
437 // If there was no name specified, check to see if there was a useful name
438 // specified in the asm file.
439 if (hadNames || parser.getNumResults() == 0)
440 return success();
442 SmallVector<StringRef, 4> names;
443 auto *context = result.getContext();
445 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
446 auto resultName = parser.getResultName(i);
447 StringRef nameStr;
448 if (!resultName.first.empty() && !isdigit(resultName.first[0]))
449 nameStr = resultName.first;
451 names.push_back(nameStr);
454 auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
455 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
456 return success();
459 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
460 // Note that we only need to print the "name" attribute if the asmprinter
461 // result name disagrees with it. This can happen in strange cases, e.g.
462 // when there are conflicts.
463 bool namesDisagree = getNames().size() != getNumResults();
465 SmallString<32> resultNameStr;
466 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
467 resultNameStr.clear();
468 llvm::raw_svector_ostream tmpStream(resultNameStr);
469 p.printOperand(getResult(i), tmpStream);
471 auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
472 if (!expectedName ||
473 tmpStream.str().drop_front() != expectedName.getValue()) {
474 namesDisagree = true;
478 if (namesDisagree)
479 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
480 else
481 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
484 // We set the SSA name in the asm syntax to the contents of the name
485 // attribute.
486 void StringAttrPrettyNameOp::getAsmResultNames(
487 function_ref<void(Value, StringRef)> setNameFn) {
489 auto value = getNames();
490 for (size_t i = 0, e = value.size(); i != e; ++i)
491 if (auto str = dyn_cast<StringAttr>(value[i]))
492 if (!str.getValue().empty())
493 setNameFn(getResult(i), str.getValue());
496 //===----------------------------------------------------------------------===//
497 // CustomResultsNameOp
498 //===----------------------------------------------------------------------===//
500 void CustomResultsNameOp::getAsmResultNames(
501 function_ref<void(Value, StringRef)> setNameFn) {
502 ArrayAttr value = getNames();
503 for (size_t i = 0, e = value.size(); i != e; ++i)
504 if (auto str = dyn_cast<StringAttr>(value[i]))
505 if (!str.empty())
506 setNameFn(getResult(i), str.getValue());
509 //===----------------------------------------------------------------------===//
510 // ResultTypeWithTraitOp
511 //===----------------------------------------------------------------------===//
513 LogicalResult ResultTypeWithTraitOp::verify() {
514 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
515 return success();
516 return emitError("result type should have trait 'TestTypeTrait'");
519 //===----------------------------------------------------------------------===//
520 // AttrWithTraitOp
521 //===----------------------------------------------------------------------===//
523 LogicalResult AttrWithTraitOp::verify() {
524 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
525 return success();
526 return emitError("'attr' attribute should have trait 'TestAttrTrait'");
529 //===----------------------------------------------------------------------===//
530 // RegionIfOp
531 //===----------------------------------------------------------------------===//
533 void RegionIfOp::print(OpAsmPrinter &p) {
534 p << " ";
535 p.printOperands(getOperands());
536 p << ": " << getOperandTypes();
537 p.printArrowTypeList(getResultTypes());
538 p << " then ";
539 p.printRegion(getThenRegion(),
540 /*printEntryBlockArgs=*/true,
541 /*printBlockTerminators=*/true);
542 p << " else ";
543 p.printRegion(getElseRegion(),
544 /*printEntryBlockArgs=*/true,
545 /*printBlockTerminators=*/true);
546 p << " join ";
547 p.printRegion(getJoinRegion(),
548 /*printEntryBlockArgs=*/true,
549 /*printBlockTerminators=*/true);
552 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
553 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
554 SmallVector<Type, 2> operandTypes;
556 result.regions.reserve(3);
557 Region *thenRegion = result.addRegion();
558 Region *elseRegion = result.addRegion();
559 Region *joinRegion = result.addRegion();
561 // Parse operand, type and arrow type lists.
562 if (parser.parseOperandList(operandInfos) ||
563 parser.parseColonTypeList(operandTypes) ||
564 parser.parseArrowTypeList(result.types))
565 return failure();
567 // Parse all attached regions.
568 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
569 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
570 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
571 return failure();
573 return parser.resolveOperands(operandInfos, operandTypes,
574 parser.getCurrentLocation(), result.operands);
577 OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
578 assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
579 "invalid region index");
580 return getOperands();
583 void RegionIfOp::getSuccessorRegions(
584 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
585 // We always branch to the join region.
586 if (!point.isParent()) {
587 if (point != getJoinRegion())
588 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
589 else
590 regions.push_back(RegionSuccessor(getResults()));
591 return;
594 // The then and else regions are the entry regions of this op.
595 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
596 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
599 void RegionIfOp::getRegionInvocationBounds(
600 ArrayRef<Attribute> operands,
601 SmallVectorImpl<InvocationBounds> &invocationBounds) {
602 // Each region is invoked at most once.
603 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
606 //===----------------------------------------------------------------------===//
607 // AnyCondOp
608 //===----------------------------------------------------------------------===//
610 void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
611 SmallVectorImpl<RegionSuccessor> &regions) {
612 // The parent op branches into the only region, and the region branches back
613 // to the parent op.
614 if (point.isParent())
615 regions.emplace_back(&getRegion());
616 else
617 regions.emplace_back(getResults());
620 void AnyCondOp::getRegionInvocationBounds(
621 ArrayRef<Attribute> operands,
622 SmallVectorImpl<InvocationBounds> &invocationBounds) {
623 invocationBounds.emplace_back(1, 1);
626 //===----------------------------------------------------------------------===//
627 // SingleBlockImplicitTerminatorOp
628 //===----------------------------------------------------------------------===//
630 /// Testing the correctness of some traits.
631 static_assert(
632 llvm::is_detected<OpTrait::has_implicit_terminator_t,
633 SingleBlockImplicitTerminatorOp>::value,
634 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
635 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
636 SingleBlockImplicitTerminatorOp>::value,
637 "hasSingleBlockImplicitTerminator does not match "
638 "SingleBlockImplicitTerminatorOp");
640 //===----------------------------------------------------------------------===//
641 // SingleNoTerminatorCustomAsmOp
642 //===----------------------------------------------------------------------===//
644 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
645 OperationState &state) {
646 Region *body = state.addRegion();
647 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
648 return failure();
649 return success();
652 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
653 printer.printRegion(
654 getRegion(), /*printEntryBlockArgs=*/false,
655 // This op has a single block without terminators. But explicitly mark
656 // as not printing block terminators for testing.
657 /*printBlockTerminators=*/false);
660 //===----------------------------------------------------------------------===//
661 // TestVerifiersOp
662 //===----------------------------------------------------------------------===//
664 LogicalResult TestVerifiersOp::verify() {
665 if (!getRegion().hasOneBlock())
666 return emitOpError("`hasOneBlock` trait hasn't been verified");
668 Operation *definingOp = getInput().getDefiningOp();
669 if (definingOp && failed(mlir::verify(definingOp)))
670 return emitOpError("operand hasn't been verified");
672 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
673 // loop.
674 mlir::emitRemark(getLoc(), "success run of verifier");
676 return success();
679 LogicalResult TestVerifiersOp::verifyRegions() {
680 if (!getRegion().hasOneBlock())
681 return emitOpError("`hasOneBlock` trait hasn't been verified");
683 for (Block &block : getRegion())
684 for (Operation &op : block)
685 if (failed(mlir::verify(&op)))
686 return emitOpError("nested op hasn't been verified");
688 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
689 // loop.
690 mlir::emitRemark(getLoc(), "success run of region verifier");
692 return success();
695 //===----------------------------------------------------------------------===//
696 // Test InferIntRangeInterface
697 //===----------------------------------------------------------------------===//
699 //===----------------------------------------------------------------------===//
700 // TestWithBoundsOp
702 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
703 SetIntRangeFn setResultRanges) {
704 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
707 //===----------------------------------------------------------------------===//
708 // TestWithBoundsRegionOp
710 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
711 OperationState &result) {
712 if (parser.parseOptionalAttrDict(result.attributes))
713 return failure();
715 // Parse the input argument
716 OpAsmParser::Argument argInfo;
717 if (failed(parser.parseArgument(argInfo, true)))
718 return failure();
720 // Parse the body region, and reuse the operand info as the argument info.
721 Region *body = result.addRegion();
722 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
725 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
726 p.printOptionalAttrDict((*this)->getAttrs());
727 p << ' ';
728 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
729 /*omitType=*/false);
730 p << ' ';
731 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
734 void TestWithBoundsRegionOp::inferResultRanges(
735 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
736 Value arg = getRegion().getArgument(0);
737 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
740 //===----------------------------------------------------------------------===//
741 // TestIncrementOp
743 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
744 SetIntRangeFn setResultRanges) {
745 const ConstantIntRanges &range = argRanges[0];
746 APInt one(range.umin().getBitWidth(), 1);
747 setResultRanges(getResult(),
748 {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
749 range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
752 //===----------------------------------------------------------------------===//
753 // TestReflectBoundsOp
755 void TestReflectBoundsOp::inferResultRanges(
756 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
757 const ConstantIntRanges &range = argRanges[0];
758 MLIRContext *ctx = getContext();
759 Builder b(ctx);
760 Type sIntTy, uIntTy;
761 // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
762 // Types for the Attributes.
763 Type type = getElementTypeOrSelf(getType());
764 if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
765 unsigned bitwidth = intTy.getWidth();
766 sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
767 uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
768 } else
769 sIntTy = uIntTy = type;
771 setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
772 setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
773 setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
774 setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
775 setResultRanges(getResult(), range);
778 //===----------------------------------------------------------------------===//
779 // ConversionFuncOp
780 //===----------------------------------------------------------------------===//
782 ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
783 OperationState &result) {
784 auto buildFuncType =
785 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
786 function_interface_impl::VariadicFlag,
787 std::string &) { return builder.getFunctionType(argTypes, results); };
789 return function_interface_impl::parseFunctionOp(
790 parser, result, /*allowVariadic=*/false,
791 getFunctionTypeAttrName(result.name), buildFuncType,
792 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
795 void ConversionFuncOp::print(OpAsmPrinter &p) {
796 function_interface_impl::printFunctionOp(
797 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
798 getArgAttrsAttrName(), getResAttrsAttrName());
801 //===----------------------------------------------------------------------===//
802 // ReifyBoundOp
803 //===----------------------------------------------------------------------===//
805 mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
806 if (getType() == "EQ")
807 return mlir::presburger::BoundType::EQ;
808 if (getType() == "LB")
809 return mlir::presburger::BoundType::LB;
810 if (getType() == "UB")
811 return mlir::presburger::BoundType::UB;
812 llvm_unreachable("invalid bound type");
815 LogicalResult ReifyBoundOp::verify() {
816 if (isa<ShapedType>(getVar().getType())) {
817 if (!getDim().has_value())
818 return emitOpError("expected 'dim' attribute for shaped type variable");
819 } else if (getVar().getType().isIndex()) {
820 if (getDim().has_value())
821 return emitOpError("unexpected 'dim' attribute for index variable");
822 } else {
823 return emitOpError("expected index-typed variable or shape type variable");
825 if (getConstant() && getScalable())
826 return emitOpError("'scalable' and 'constant' are mutually exlusive");
827 if (getScalable() != getVscaleMin().has_value())
828 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
829 if (getScalable() != getVscaleMax().has_value())
830 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
831 return success();
834 ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
835 if (getDim().has_value())
836 return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
837 return ValueBoundsConstraintSet::Variable(getVar());
840 //===----------------------------------------------------------------------===//
841 // CompareOp
842 //===----------------------------------------------------------------------===//
844 ValueBoundsConstraintSet::ComparisonOperator
845 CompareOp::getComparisonOperator() {
846 if (getCmp() == "EQ")
847 return ValueBoundsConstraintSet::ComparisonOperator::EQ;
848 if (getCmp() == "LT")
849 return ValueBoundsConstraintSet::ComparisonOperator::LT;
850 if (getCmp() == "LE")
851 return ValueBoundsConstraintSet::ComparisonOperator::LE;
852 if (getCmp() == "GT")
853 return ValueBoundsConstraintSet::ComparisonOperator::GT;
854 if (getCmp() == "GE")
855 return ValueBoundsConstraintSet::ComparisonOperator::GE;
856 llvm_unreachable("invalid comparison operator");
859 mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
860 if (!getLhsMap())
861 return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
862 SmallVector<Value> mapOperands(
863 getVarOperands().slice(0, getLhsMap()->getNumInputs()));
864 return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
867 mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
868 int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
869 if (!getRhsMap())
870 return ValueBoundsConstraintSet::Variable(
871 getVarOperands()[rhsOperandsBegin]);
872 SmallVector<Value> mapOperands(
873 getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
874 return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
877 LogicalResult CompareOp::verify() {
878 if (getCompose() && (getLhsMap() || getRhsMap()))
879 return emitOpError(
880 "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
881 int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
882 expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
883 if (getVarOperands().size() != size_t(expectedNumOperands))
884 return emitOpError("expected ")
885 << expectedNumOperands << " operands, but got "
886 << getVarOperands().size();
887 return success();
890 //===----------------------------------------------------------------------===//
891 // TestOpInPlaceSelfFold
892 //===----------------------------------------------------------------------===//
894 OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
895 if (!getFolded()) {
896 // The folder adds the "folded" if not present.
897 setFolded(true);
898 return getResult();
900 return {};
903 //===----------------------------------------------------------------------===//
904 // TestOpFoldWithFoldAdaptor
905 //===----------------------------------------------------------------------===//
907 OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
908 int64_t sum = 0;
909 if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
910 sum += value.getValue().getSExtValue();
912 for (Attribute attr : adaptor.getVariadic())
913 if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
914 sum += 2 * value.getValue().getSExtValue();
916 for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
917 for (Attribute attr : attrs)
918 if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
919 sum += 3 * value.getValue().getSExtValue();
921 sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
923 return IntegerAttr::get(getType(), sum);
926 //===----------------------------------------------------------------------===//
927 // OpWithInferTypeAdaptorInterfaceOp
928 //===----------------------------------------------------------------------===//
930 LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
931 MLIRContext *, std::optional<Location> location,
932 OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
933 SmallVectorImpl<Type> &inferredReturnTypes) {
934 if (adaptor.getX().getType() != adaptor.getY().getType()) {
935 return emitOptionalError(location, "operand type mismatch ",
936 adaptor.getX().getType(), " vs ",
937 adaptor.getY().getType());
939 inferredReturnTypes.assign({adaptor.getX().getType()});
940 return success();
943 //===----------------------------------------------------------------------===//
944 // OpWithRefineTypeInterfaceOp
945 //===----------------------------------------------------------------------===//
947 // TODO: We should be able to only define either inferReturnType or
948 // refineReturnType, currently only refineReturnType can be omitted.
949 LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
950 MLIRContext *context, std::optional<Location> location, ValueRange operands,
951 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
952 SmallVectorImpl<Type> &returnTypes) {
953 returnTypes.clear();
954 return OpWithRefineTypeInterfaceOp::refineReturnTypes(
955 context, location, operands, attributes, properties, regions,
956 returnTypes);
959 LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
960 MLIRContext *, std::optional<Location> location, ValueRange operands,
961 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
962 SmallVectorImpl<Type> &returnTypes) {
963 if (operands[0].getType() != operands[1].getType()) {
964 return emitOptionalError(location, "operand type mismatch ",
965 operands[0].getType(), " vs ",
966 operands[1].getType());
968 // TODO: Add helper to make this more concise to write.
969 if (returnTypes.empty())
970 returnTypes.resize(1, nullptr);
971 if (returnTypes[0] && returnTypes[0] != operands[0].getType())
972 return emitOptionalError(location,
973 "required first operand and result to match");
974 returnTypes[0] = operands[0].getType();
975 return success();
978 //===----------------------------------------------------------------------===//
979 // OpWithShapedTypeInferTypeAdaptorInterfaceOp
980 //===----------------------------------------------------------------------===//
982 LogicalResult
983 OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
984 MLIRContext *context, std::optional<Location> location,
985 OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
986 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
987 // Create return type consisting of the last element of the first operand.
988 auto operandType = adaptor.getOperand1().getType();
989 auto sval = dyn_cast<ShapedType>(operandType);
990 if (!sval)
991 return emitOptionalError(location, "only shaped type operands allowed");
992 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
993 auto type = IntegerType::get(context, 17);
995 Attribute encoding;
996 if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
997 encoding = rankedTy.getEncoding();
998 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
999 return success();
1002 LogicalResult
1003 OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
1004 OpBuilder &builder, ValueRange operands,
1005 llvm::SmallVectorImpl<Value> &shapes) {
1006 shapes = SmallVector<Value, 1>{
1007 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1008 return success();
1011 //===----------------------------------------------------------------------===//
1012 // TestOpWithPropertiesAndInferredType
1013 //===----------------------------------------------------------------------===//
1015 LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
1016 MLIRContext *context, std::optional<Location>, ValueRange operands,
1017 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1018 SmallVectorImpl<Type> &inferredReturnTypes) {
1020 Adaptor adaptor(operands, attributes, properties, regions);
1021 inferredReturnTypes.push_back(IntegerType::get(
1022 context, adaptor.getLhs() + adaptor.getProperties().rhs));
1023 return success();
1026 //===----------------------------------------------------------------------===//
1027 // LoopBlockOp
1028 //===----------------------------------------------------------------------===//
1030 void LoopBlockOp::getSuccessorRegions(
1031 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1032 regions.emplace_back(&getBody(), getBody().getArguments());
1033 if (point.isParent())
1034 return;
1036 regions.emplace_back((*this)->getResults());
1039 OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1040 assert(point == getBody());
1041 return MutableOperandRange(getInitMutable());
1044 //===----------------------------------------------------------------------===//
1045 // LoopBlockTerminatorOp
1046 //===----------------------------------------------------------------------===//
1048 MutableOperandRange
1049 LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
1050 if (point.isParent())
1051 return getExitArgMutable();
1052 return getNextIterArgMutable();
1055 //===----------------------------------------------------------------------===//
// TestNoTerminatorOp
1057 //===----------------------------------------------------------------------===//
1059 void TestNoTerminatorOp::getSuccessorRegions(
1060 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
1062 //===----------------------------------------------------------------------===//
// ManualCppOpWithFold
1064 //===----------------------------------------------------------------------===//
1066 OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
1067 // Just a simple fold for testing purposes that reads an operands constant
1068 // value and returns it.
1069 if (!attributes.empty())
1070 return attributes.front();
1071 return nullptr;
1074 //===----------------------------------------------------------------------===//
1075 // Tensor/Buffer Ops
1076 //===----------------------------------------------------------------------===//
1078 void ReadBufferOp::getEffects(
1079 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1080 &effects) {
1081 // The buffer operand is read.
1082 effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(),
1083 SideEffects::DefaultResource::get());
1084 // The buffer contents are dumped.
1085 effects.emplace_back(MemoryEffects::Write::get(),
1086 SideEffects::DefaultResource::get());
1089 //===----------------------------------------------------------------------===//
1090 // Test Dataflow
1091 //===----------------------------------------------------------------------===//
1093 //===----------------------------------------------------------------------===//
1094 // TestCallAndStoreOp
1096 CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
1097 return getCallee();
1100 void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1101 setCalleeAttr(callee.get<SymbolRefAttr>());
1104 Operation::operand_range TestCallAndStoreOp::getArgOperands() {
1105 return getCalleeOperands();
1108 MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
1109 return getCalleeOperandsMutable();
1112 //===----------------------------------------------------------------------===//
1113 // TestCallOnDeviceOp
1115 CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
1116 return getCallee();
1119 void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1120 setCalleeAttr(callee.get<SymbolRefAttr>());
1123 Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
1124 return getForwardedOperands();
1127 MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
1128 return getForwardedOperandsMutable();
1131 //===----------------------------------------------------------------------===//
1132 // TestStoreWithARegion
1134 void TestStoreWithARegion::getSuccessorRegions(
1135 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1136 if (point.isParent())
1137 regions.emplace_back(&getBody(), getBody().front().getArguments());
1138 else
1139 regions.emplace_back();
1142 //===----------------------------------------------------------------------===//
1143 // TestStoreWithALoopRegion
1145 void TestStoreWithALoopRegion::getSuccessorRegions(
1146 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1147 // Both the operation itself and the region may be branching into the body or
1148 // back into the operation itself. It is possible for the operation not to
1149 // enter the body.
1150 regions.emplace_back(
1151 RegionSuccessor(&getBody(), getBody().front().getArguments()));
1152 regions.emplace_back();
1155 //===----------------------------------------------------------------------===//
1156 // TestVersionedOpA
1157 //===----------------------------------------------------------------------===//
1159 LogicalResult
1160 TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
1161 mlir::OperationState &state) {
1162 auto &prop = state.getOrAddProperties<Properties>();
1163 if (mlir::failed(reader.readAttribute(prop.dims)))
1164 return mlir::failure();
1166 // Check if we have a version. If not, assume we are parsing the current
1167 // version.
1168 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1169 if (succeeded(maybeVersion)) {
1170 // If version is less than 2.0, there is no additional attribute to parse.
1171 // We can materialize missing properties post parsing before verification.
1172 const auto *version =
1173 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1174 if ((version->major_ < 2)) {
1175 return success();
1179 if (mlir::failed(reader.readAttribute(prop.modifier)))
1180 return mlir::failure();
1181 return mlir::success();
1184 void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
1185 auto &prop = getProperties();
1186 writer.writeAttribute(prop.dims);
1188 auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
1189 if (succeeded(maybeVersion)) {
1190 // If version is less than 2.0, there is no additional attribute to write.
1191 const auto *version =
1192 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1193 if ((version->major_ < 2)) {
1194 llvm::outs() << "downgrading op properties...\n";
1195 return;
1198 writer.writeAttribute(prop.modifier);
1201 //===----------------------------------------------------------------------===//
1202 // TestOpWithVersionedProperties
1203 //===----------------------------------------------------------------------===//
1205 llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
1206 mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
1207 uint64_t value1, value2 = 0;
1208 if (failed(reader.readVarInt(value1)))
1209 return failure();
1211 // Check if we have a version. If not, assume we are parsing the current
1212 // version.
1213 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1214 bool needToParseAnotherInt = true;
1215 if (succeeded(maybeVersion)) {
1216 // If version is less than 2.0, there is no additional attribute to parse.
1217 // We can materialize missing properties post parsing before verification.
1218 const auto *version =
1219 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1220 if ((version->major_ < 2))
1221 needToParseAnotherInt = false;
1223 if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
1224 return failure();
1226 prop.value1 = value1;
1227 prop.value2 = value2;
1228 return success();
1231 void TestOpWithVersionedProperties::writeToMlirBytecode(
1232 mlir::DialectBytecodeWriter &writer,
1233 const test::VersionedProperties &prop) {
1234 writer.writeVarInt(prop.value1);
1235 writer.writeVarInt(prop.value2);
1238 //===----------------------------------------------------------------------===//
1239 // TestMultiSlotAlloca
1240 //===----------------------------------------------------------------------===//
1242 llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
1243 SmallVector<MemorySlot> slots;
1244 for (Value result : getResults()) {
1245 slots.push_back(MemorySlot{
1246 result, cast<MemRefType>(result.getType()).getElementType()});
1248 return slots;
1251 Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
1252 OpBuilder &builder) {
1253 return builder.create<TestOpConstant>(getLoc(), slot.elemType,
1254 builder.getI32IntegerAttr(42));
1257 void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
1258 BlockArgument argument,
1259 OpBuilder &builder) {
1260 // Not relevant for testing.
1263 /// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
1264 static std::optional<TestMultiSlotAlloca>
1265 createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
1266 TestMultiSlotAlloca oldOp) {
1268 if (oldOp.getNumResults() == 1) {
1269 oldOp.erase();
1270 return std::nullopt;
1273 SmallVector<Type> newTypes;
1274 SmallVector<Value> remainingValues;
1276 for (Value oldResult : oldOp.getResults()) {
1277 if (oldResult == slot.ptr)
1278 continue;
1279 remainingValues.push_back(oldResult);
1280 newTypes.push_back(oldResult.getType());
1283 OpBuilder::InsertionGuard guard(builder);
1284 builder.setInsertionPoint(oldOp);
1285 auto replacement =
1286 builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
1287 for (auto [oldResult, newResult] :
1288 llvm::zip_equal(remainingValues, replacement.getResults()))
1289 oldResult.replaceAllUsesWith(newResult);
1291 oldOp.erase();
1292 return replacement;
1295 std::optional<PromotableAllocationOpInterface>
1296 TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1297 Value defaultValue,
1298 OpBuilder &builder) {
1299 if (defaultValue && defaultValue.use_empty())
1300 defaultValue.getDefiningOp()->erase();
1301 return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1304 SmallVector<DestructurableMemorySlot>
1305 TestMultiSlotAlloca::getDestructurableSlots() {
1306 SmallVector<DestructurableMemorySlot> slots;
1307 for (Value result : getResults()) {
1308 auto memrefType = cast<MemRefType>(result.getType());
1309 auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType);
1310 if (!destructurable)
1311 continue;
1313 std::optional<DenseMap<Attribute, Type>> destructuredType =
1314 destructurable.getSubelementIndexMap();
1315 if (!destructuredType)
1316 continue;
1317 slots.emplace_back(
1318 DestructurableMemorySlot{{result, memrefType}, *destructuredType});
1320 return slots;
1323 DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
1324 const DestructurableMemorySlot &slot,
1325 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
1326 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
1327 OpBuilder::InsertionGuard guard(builder);
1328 builder.setInsertionPointAfter(*this);
1330 DenseMap<Attribute, MemorySlot> slotMap;
1332 for (Attribute usedIndex : usedIndices) {
1333 Type elemType = slot.subelementTypes.lookup(usedIndex);
1334 MemRefType elemPtr = MemRefType::get({}, elemType);
1335 auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
1336 newAllocators.push_back(subAlloca);
1337 slotMap.try_emplace<MemorySlot>(usedIndex,
1338 {subAlloca.getResult(0), elemType});
1341 return slotMap;
1344 std::optional<DestructurableAllocationOpInterface>
1345 TestMultiSlotAlloca::handleDestructuringComplete(
1346 const DestructurableMemorySlot &slot, OpBuilder &builder) {
1347 return createNewMultiAllocaWithoutSlot(slot, builder, *this);