//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"

using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
// TestBranchOp
//===----------------------------------------------------------------------===//
23 SuccessorOperands
TestBranchOp::getSuccessorOperands(unsigned index
) {
24 assert(index
== 0 && "invalid successor index");
25 return SuccessorOperands(getTargetOperandsMutable());
//===----------------------------------------------------------------------===//
// TestProducingBranchOp
//===----------------------------------------------------------------------===//
32 SuccessorOperands
TestProducingBranchOp::getSuccessorOperands(unsigned index
) {
33 assert(index
<= 1 && "invalid successor index");
35 return SuccessorOperands(getFirstOperandsMutable());
36 return SuccessorOperands(getSecondOperandsMutable());
//===----------------------------------------------------------------------===//
// TestInternalBranchOp
//===----------------------------------------------------------------------===//
43 SuccessorOperands
TestInternalBranchOp::getSuccessorOperands(unsigned index
) {
44 assert(index
<= 1 && "invalid successor index");
46 return SuccessorOperands(0, getSuccessOperandsMutable());
47 return SuccessorOperands(1, getErrorOperandsMutable());
//===----------------------------------------------------------------------===//
// TestCallOp
//===----------------------------------------------------------------------===//
54 LogicalResult
TestCallOp::verifySymbolUses(SymbolTableCollection
&symbolTable
) {
55 // Check that the callee attribute was specified.
56 auto fnAttr
= (*this)->getAttrOfType
<FlatSymbolRefAttr
>("callee");
58 return emitOpError("requires a 'callee' symbol reference attribute");
59 if (!symbolTable
.lookupNearestSymbolFrom
<FunctionOpInterface
>(*this, fnAttr
))
60 return emitOpError() << "'" << fnAttr
.getValue()
61 << "' does not reference a valid function";
//===----------------------------------------------------------------------===//
// FoldToCallOp
//===----------------------------------------------------------------------===//
70 struct FoldToCallOpPattern
: public OpRewritePattern
<FoldToCallOp
> {
71 using OpRewritePattern
<FoldToCallOp
>::OpRewritePattern
;
73 LogicalResult
matchAndRewrite(FoldToCallOp op
,
74 PatternRewriter
&rewriter
) const override
{
75 rewriter
.replaceOpWithNewOp
<func::CallOp
>(op
, TypeRange(),
76 op
.getCalleeAttr(), ValueRange());
82 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet
&results
,
83 MLIRContext
*context
) {
84 results
.add
<FoldToCallOpPattern
>(context
);
//===----------------------------------------------------------------------===//
// IsolatedRegionOp - test parsing passthrough operands
//===----------------------------------------------------------------------===//
91 ParseResult
IsolatedRegionOp::parse(OpAsmParser
&parser
,
92 OperationState
&result
) {
93 // Parse the input operand.
94 OpAsmParser::Argument argInfo
;
95 argInfo
.type
= parser
.getBuilder().getIndexType();
96 if (parser
.parseOperand(argInfo
.ssaName
) ||
97 parser
.resolveOperand(argInfo
.ssaName
, argInfo
.type
, result
.operands
))
100 // Parse the body region, and reuse the operand info as the argument info.
101 Region
*body
= result
.addRegion();
102 return parser
.parseRegion(*body
, argInfo
, /*enableNameShadowing=*/true);
105 void IsolatedRegionOp::print(OpAsmPrinter
&p
) {
107 p
.printOperand(getOperand());
108 p
.shadowRegionArgs(getRegion(), getOperand());
110 p
.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
//===----------------------------------------------------------------------===//
// SSACFGRegionOp
//===----------------------------------------------------------------------===//
117 RegionKind
SSACFGRegionOp::getRegionKind(unsigned index
) {
118 return RegionKind::SSACFG
;
//===----------------------------------------------------------------------===//
// GraphRegionOp
//===----------------------------------------------------------------------===//
125 RegionKind
GraphRegionOp::getRegionKind(unsigned index
) {
126 return RegionKind::Graph
;
//===----------------------------------------------------------------------===//
// IsolatedGraphRegionOp
//===----------------------------------------------------------------------===//
133 RegionKind
IsolatedGraphRegionOp::getRegionKind(unsigned index
) {
134 return RegionKind::Graph
;
//===----------------------------------------------------------------------===//
// AffineScopeOp
//===----------------------------------------------------------------------===//
141 ParseResult
AffineScopeOp::parse(OpAsmParser
&parser
, OperationState
&result
) {
142 // Parse the body region, and reuse the operand info as the argument info.
143 Region
*body
= result
.addRegion();
144 return parser
.parseRegion(*body
, /*arguments=*/{}, /*argTypes=*/{});
147 void AffineScopeOp::print(OpAsmPrinter
&p
) {
149 p
.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
//===----------------------------------------------------------------------===//
// TestRemoveOpWithInnerOps
//===----------------------------------------------------------------------===//
157 struct TestRemoveOpWithInnerOps
158 : public OpRewritePattern
<TestOpWithRegionPattern
> {
159 using OpRewritePattern
<TestOpWithRegionPattern
>::OpRewritePattern
;
161 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
163 LogicalResult
matchAndRewrite(TestOpWithRegionPattern op
,
164 PatternRewriter
&rewriter
) const override
{
165 rewriter
.eraseOp(op
);
//===----------------------------------------------------------------------===//
// TestOpWithRegionPattern
//===----------------------------------------------------------------------===//
175 void TestOpWithRegionPattern::getCanonicalizationPatterns(
176 RewritePatternSet
&results
, MLIRContext
*context
) {
177 results
.add
<TestRemoveOpWithInnerOps
>(context
);
//===----------------------------------------------------------------------===//
// TestOpWithRegionFold
//===----------------------------------------------------------------------===//
184 OpFoldResult
TestOpWithRegionFold::fold(FoldAdaptor adaptor
) {
//===----------------------------------------------------------------------===//
// TestOpConstant
//===----------------------------------------------------------------------===//
192 OpFoldResult
TestOpConstant::fold(FoldAdaptor adaptor
) { return getValue(); }
//===----------------------------------------------------------------------===//
// TestOpWithVariadicResultsAndFolder
//===----------------------------------------------------------------------===//
198 LogicalResult
TestOpWithVariadicResultsAndFolder::fold(
199 FoldAdaptor adaptor
, SmallVectorImpl
<OpFoldResult
> &results
) {
200 for (Value input
: this->getOperands()) {
201 results
.push_back(input
);
//===----------------------------------------------------------------------===//
// TestOpInPlaceFold
//===----------------------------------------------------------------------===//
210 OpFoldResult
TestOpInPlaceFold::fold(FoldAdaptor adaptor
) {
211 // Exercise the fact that an operation created with createOrFold should be
212 // allowed to access its parent block.
213 assert(getOperation()->getBlock() &&
214 "expected that operation is not unlinked");
216 if (adaptor
.getOp() && !getProperties().attr
) {
217 // The folder adds "attr" if not present.
218 getProperties().attr
= dyn_cast_or_null
<IntegerAttr
>(adaptor
.getOp());
//===----------------------------------------------------------------------===//
// OpWithInferTypeInterfaceOp
//===----------------------------------------------------------------------===//
228 LogicalResult
OpWithInferTypeInterfaceOp::inferReturnTypes(
229 MLIRContext
*, std::optional
<Location
> location
, ValueRange operands
,
230 DictionaryAttr attributes
, OpaqueProperties properties
, RegionRange regions
,
231 SmallVectorImpl
<Type
> &inferredReturnTypes
) {
232 if (operands
[0].getType() != operands
[1].getType()) {
233 return emitOptionalError(location
, "operand type mismatch ",
234 operands
[0].getType(), " vs ",
235 operands
[1].getType());
237 inferredReturnTypes
.assign({operands
[0].getType()});
//===----------------------------------------------------------------------===//
// OpWithShapedTypeInferTypeInterfaceOp
//===----------------------------------------------------------------------===//
245 LogicalResult
OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
246 MLIRContext
*context
, std::optional
<Location
> location
,
247 ValueShapeRange operands
, DictionaryAttr attributes
,
248 OpaqueProperties properties
, RegionRange regions
,
249 SmallVectorImpl
<ShapedTypeComponents
> &inferredReturnShapes
) {
250 // Create return type consisting of the last element of the first operand.
251 auto operandType
= operands
.front().getType();
252 auto sval
= dyn_cast
<ShapedType
>(operandType
);
254 return emitOptionalError(location
, "only shaped type operands allowed");
255 int64_t dim
= sval
.hasRank() ? sval
.getShape().front() : ShapedType::kDynamic
;
256 auto type
= IntegerType::get(context
, 17);
259 if (auto rankedTy
= dyn_cast
<RankedTensorType
>(sval
))
260 encoding
= rankedTy
.getEncoding();
261 inferredReturnShapes
.push_back(ShapedTypeComponents({dim
}, type
, encoding
));
265 LogicalResult
OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
266 OpBuilder
&builder
, ValueRange operands
,
267 llvm::SmallVectorImpl
<Value
> &shapes
) {
268 shapes
= SmallVector
<Value
, 1>{
269 builder
.createOrFold
<tensor::DimOp
>(getLoc(), operands
.front(), 0)};
//===----------------------------------------------------------------------===//
// OpWithResultShapeInterfaceOp
//===----------------------------------------------------------------------===//
277 LogicalResult
OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
278 OpBuilder
&builder
, ValueRange operands
,
279 llvm::SmallVectorImpl
<Value
> &shapes
) {
280 Location loc
= getLoc();
281 shapes
.reserve(operands
.size());
282 for (Value operand
: llvm::reverse(operands
)) {
283 auto rank
= cast
<RankedTensorType
>(operand
.getType()).getRank();
284 auto currShape
= llvm::to_vector
<4>(
285 llvm::map_range(llvm::seq
<int64_t>(0, rank
), [&](int64_t dim
) -> Value
{
286 return builder
.createOrFold
<tensor::DimOp
>(loc
, operand
, dim
);
288 shapes
.push_back(builder
.create
<tensor::FromElementsOp
>(
289 getLoc(), RankedTensorType::get({rank
}, builder
.getIndexType()),
//===----------------------------------------------------------------------===//
// OpWithResultShapePerDimInterfaceOp
//===----------------------------------------------------------------------===//
299 LogicalResult
OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
300 OpBuilder
&builder
, ReifiedRankedShapedTypeDims
&shapes
) {
301 Location loc
= getLoc();
302 shapes
.reserve(getNumOperands());
303 for (Value operand
: llvm::reverse(getOperands())) {
304 auto tensorType
= cast
<RankedTensorType
>(operand
.getType());
305 auto currShape
= llvm::to_vector
<4>(llvm::map_range(
306 llvm::seq
<int64_t>(0, tensorType
.getRank()),
307 [&](int64_t dim
) -> OpFoldResult
{
308 return tensorType
.isDynamicDim(dim
)
309 ? static_cast<OpFoldResult
>(
310 builder
.createOrFold
<tensor::DimOp
>(loc
, operand
,
312 : static_cast<OpFoldResult
>(
313 builder
.getIndexAttr(tensorType
.getDimSize(dim
)));
315 shapes
.emplace_back(std::move(currShape
));
//===----------------------------------------------------------------------===//
// Test SideEffect interfaces
//===----------------------------------------------------------------------===//
325 /// A test resource for side effects.
326 struct TestResource
: public SideEffects::Resource::Base
<TestResource
> {
327 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource
)
329 StringRef
getName() final
{ return "<Test>"; }
333 void SideEffectOp::getEffects(
334 SmallVectorImpl
<MemoryEffects::EffectInstance
> &effects
) {
335 // Check for an effects attribute on the op instance.
336 ArrayAttr effectsAttr
= (*this)->getAttrOfType
<ArrayAttr
>("effects");
340 for (Attribute element
: effectsAttr
) {
341 DictionaryAttr effectElement
= cast
<DictionaryAttr
>(element
);
343 // Get the specific memory effect.
344 MemoryEffects::Effect
*effect
=
345 StringSwitch
<MemoryEffects::Effect
*>(
346 cast
<StringAttr
>(effectElement
.get("effect")).getValue())
347 .Case("allocate", MemoryEffects::Allocate::get())
348 .Case("free", MemoryEffects::Free::get())
349 .Case("read", MemoryEffects::Read::get())
350 .Case("write", MemoryEffects::Write::get());
352 // Check for a non-default resource to use.
353 SideEffects::Resource
*resource
= SideEffects::DefaultResource::get();
354 if (effectElement
.get("test_resource"))
355 resource
= TestResource::get();
357 // Check for a result to affect.
358 if (effectElement
.get("on_result"))
359 effects
.emplace_back(effect
, getOperation()->getOpResults()[0], resource
);
360 else if (Attribute ref
= effectElement
.get("on_reference"))
361 effects
.emplace_back(effect
, cast
<SymbolRefAttr
>(ref
), resource
);
363 effects
.emplace_back(effect
, resource
);
367 void SideEffectOp::getEffects(
368 SmallVectorImpl
<TestEffects::EffectInstance
> &effects
) {
369 testSideEffectOpGetEffect(getOperation(), effects
);
372 void SideEffectWithRegionOp::getEffects(
373 SmallVectorImpl
<MemoryEffects::EffectInstance
> &effects
) {
374 // Check for an effects attribute on the op instance.
375 ArrayAttr effectsAttr
= (*this)->getAttrOfType
<ArrayAttr
>("effects");
379 for (Attribute element
: effectsAttr
) {
380 DictionaryAttr effectElement
= cast
<DictionaryAttr
>(element
);
382 // Get the specific memory effect.
383 MemoryEffects::Effect
*effect
=
384 StringSwitch
<MemoryEffects::Effect
*>(
385 cast
<StringAttr
>(effectElement
.get("effect")).getValue())
386 .Case("allocate", MemoryEffects::Allocate::get())
387 .Case("free", MemoryEffects::Free::get())
388 .Case("read", MemoryEffects::Read::get())
389 .Case("write", MemoryEffects::Write::get());
391 // Check for a non-default resource to use.
392 SideEffects::Resource
*resource
= SideEffects::DefaultResource::get();
393 if (effectElement
.get("test_resource"))
394 resource
= TestResource::get();
396 // Check for a result to affect.
397 if (effectElement
.get("on_result"))
398 effects
.emplace_back(effect
, getOperation()->getOpResults()[0], resource
);
399 else if (effectElement
.get("on_operand"))
400 effects
.emplace_back(effect
, &getOperation()->getOpOperands()[0],
402 else if (effectElement
.get("on_argument"))
403 effects
.emplace_back(effect
, getOperation()->getRegion(0).getArgument(0),
405 else if (Attribute ref
= effectElement
.get("on_reference"))
406 effects
.emplace_back(effect
, cast
<SymbolRefAttr
>(ref
), resource
);
408 effects
.emplace_back(effect
, resource
);
412 void SideEffectWithRegionOp::getEffects(
413 SmallVectorImpl
<TestEffects::EffectInstance
> &effects
) {
414 testSideEffectOpGetEffect(getOperation(), effects
);
//===----------------------------------------------------------------------===//
// StringAttrPrettyNameOp
//===----------------------------------------------------------------------===//
421 // This op has fancy handling of its SSA result name.
422 ParseResult
StringAttrPrettyNameOp::parse(OpAsmParser
&parser
,
423 OperationState
&result
) {
424 // Add the result types.
425 for (size_t i
= 0, e
= parser
.getNumResults(); i
!= e
; ++i
)
426 result
.addTypes(parser
.getBuilder().getIntegerType(32));
428 if (parser
.parseOptionalAttrDictWithKeyword(result
.attributes
))
431 // If the attribute dictionary contains no 'names' attribute, infer it from
432 // the SSA name (if specified).
433 bool hadNames
= llvm::any_of(result
.attributes
, [](NamedAttribute attr
) {
434 return attr
.getName() == "names";
437 // If there was no name specified, check to see if there was a useful name
438 // specified in the asm file.
439 if (hadNames
|| parser
.getNumResults() == 0)
442 SmallVector
<StringRef
, 4> names
;
443 auto *context
= result
.getContext();
445 for (size_t i
= 0, e
= parser
.getNumResults(); i
!= e
; ++i
) {
446 auto resultName
= parser
.getResultName(i
);
448 if (!resultName
.first
.empty() && !isdigit(resultName
.first
[0]))
449 nameStr
= resultName
.first
;
451 names
.push_back(nameStr
);
454 auto namesAttr
= parser
.getBuilder().getStrArrayAttr(names
);
455 result
.attributes
.push_back({StringAttr::get(context
, "names"), namesAttr
});
459 void StringAttrPrettyNameOp::print(OpAsmPrinter
&p
) {
460 // Note that we only need to print the "name" attribute if the asmprinter
461 // result name disagrees with it. This can happen in strange cases, e.g.
462 // when there are conflicts.
463 bool namesDisagree
= getNames().size() != getNumResults();
465 SmallString
<32> resultNameStr
;
466 for (size_t i
= 0, e
= getNumResults(); i
!= e
&& !namesDisagree
; ++i
) {
467 resultNameStr
.clear();
468 llvm::raw_svector_ostream
tmpStream(resultNameStr
);
469 p
.printOperand(getResult(i
), tmpStream
);
471 auto expectedName
= dyn_cast
<StringAttr
>(getNames()[i
]);
473 tmpStream
.str().drop_front() != expectedName
.getValue()) {
474 namesDisagree
= true;
479 p
.printOptionalAttrDictWithKeyword((*this)->getAttrs());
481 p
.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
484 // We set the SSA name in the asm syntax to the contents of the name
486 void StringAttrPrettyNameOp::getAsmResultNames(
487 function_ref
<void(Value
, StringRef
)> setNameFn
) {
489 auto value
= getNames();
490 for (size_t i
= 0, e
= value
.size(); i
!= e
; ++i
)
491 if (auto str
= dyn_cast
<StringAttr
>(value
[i
]))
492 if (!str
.getValue().empty())
493 setNameFn(getResult(i
), str
.getValue());
//===----------------------------------------------------------------------===//
// CustomResultsNameOp
//===----------------------------------------------------------------------===//
500 void CustomResultsNameOp::getAsmResultNames(
501 function_ref
<void(Value
, StringRef
)> setNameFn
) {
502 ArrayAttr value
= getNames();
503 for (size_t i
= 0, e
= value
.size(); i
!= e
; ++i
)
504 if (auto str
= dyn_cast
<StringAttr
>(value
[i
]))
506 setNameFn(getResult(i
), str
.getValue());
//===----------------------------------------------------------------------===//
// ResultTypeWithTraitOp
//===----------------------------------------------------------------------===//
513 LogicalResult
ResultTypeWithTraitOp::verify() {
514 if ((*this)->getResultTypes()[0].hasTrait
<TypeTrait::TestTypeTrait
>())
516 return emitError("result type should have trait 'TestTypeTrait'");
//===----------------------------------------------------------------------===//
// AttrWithTraitOp
//===----------------------------------------------------------------------===//
523 LogicalResult
AttrWithTraitOp::verify() {
524 if (getAttr().hasTrait
<AttributeTrait::TestAttrTrait
>())
526 return emitError("'attr' attribute should have trait 'TestAttrTrait'");
//===----------------------------------------------------------------------===//
// RegionIfOp
//===----------------------------------------------------------------------===//
533 void RegionIfOp::print(OpAsmPrinter
&p
) {
535 p
.printOperands(getOperands());
536 p
<< ": " << getOperandTypes();
537 p
.printArrowTypeList(getResultTypes());
539 p
.printRegion(getThenRegion(),
540 /*printEntryBlockArgs=*/true,
541 /*printBlockTerminators=*/true);
543 p
.printRegion(getElseRegion(),
544 /*printEntryBlockArgs=*/true,
545 /*printBlockTerminators=*/true);
547 p
.printRegion(getJoinRegion(),
548 /*printEntryBlockArgs=*/true,
549 /*printBlockTerminators=*/true);
552 ParseResult
RegionIfOp::parse(OpAsmParser
&parser
, OperationState
&result
) {
553 SmallVector
<OpAsmParser::UnresolvedOperand
, 2> operandInfos
;
554 SmallVector
<Type
, 2> operandTypes
;
556 result
.regions
.reserve(3);
557 Region
*thenRegion
= result
.addRegion();
558 Region
*elseRegion
= result
.addRegion();
559 Region
*joinRegion
= result
.addRegion();
561 // Parse operand, type and arrow type lists.
562 if (parser
.parseOperandList(operandInfos
) ||
563 parser
.parseColonTypeList(operandTypes
) ||
564 parser
.parseArrowTypeList(result
.types
))
567 // Parse all attached regions.
568 if (parser
.parseKeyword("then") || parser
.parseRegion(*thenRegion
, {}, {}) ||
569 parser
.parseKeyword("else") || parser
.parseRegion(*elseRegion
, {}, {}) ||
570 parser
.parseKeyword("join") || parser
.parseRegion(*joinRegion
, {}, {}))
573 return parser
.resolveOperands(operandInfos
, operandTypes
,
574 parser
.getCurrentLocation(), result
.operands
);
577 OperandRange
RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point
) {
578 assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point
) &&
579 "invalid region index");
580 return getOperands();
583 void RegionIfOp::getSuccessorRegions(
584 RegionBranchPoint point
, SmallVectorImpl
<RegionSuccessor
> ®ions
) {
585 // We always branch to the join region.
586 if (!point
.isParent()) {
587 if (point
!= getJoinRegion())
588 regions
.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
590 regions
.push_back(RegionSuccessor(getResults()));
594 // The then and else regions are the entry regions of this op.
595 regions
.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
596 regions
.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
599 void RegionIfOp::getRegionInvocationBounds(
600 ArrayRef
<Attribute
> operands
,
601 SmallVectorImpl
<InvocationBounds
> &invocationBounds
) {
602 // Each region is invoked at most once.
603 invocationBounds
.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
//===----------------------------------------------------------------------===//
// AnyCondOp
//===----------------------------------------------------------------------===//
610 void AnyCondOp::getSuccessorRegions(RegionBranchPoint point
,
611 SmallVectorImpl
<RegionSuccessor
> ®ions
) {
612 // The parent op branches into the only region, and the region branches back
614 if (point
.isParent())
615 regions
.emplace_back(&getRegion());
617 regions
.emplace_back(getResults());
620 void AnyCondOp::getRegionInvocationBounds(
621 ArrayRef
<Attribute
> operands
,
622 SmallVectorImpl
<InvocationBounds
> &invocationBounds
) {
623 invocationBounds
.emplace_back(1, 1);
//===----------------------------------------------------------------------===//
// SingleBlockImplicitTerminatorOp
//===----------------------------------------------------------------------===//
630 /// Testing the correctness of some traits.
632 llvm::is_detected
<OpTrait::has_implicit_terminator_t
,
633 SingleBlockImplicitTerminatorOp
>::value
,
634 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
635 static_assert(OpTrait::hasSingleBlockImplicitTerminator
<
636 SingleBlockImplicitTerminatorOp
>::value
,
637 "hasSingleBlockImplicitTerminator does not match "
638 "SingleBlockImplicitTerminatorOp");
//===----------------------------------------------------------------------===//
// SingleNoTerminatorCustomAsmOp
//===----------------------------------------------------------------------===//
644 ParseResult
SingleNoTerminatorCustomAsmOp::parse(OpAsmParser
&parser
,
645 OperationState
&state
) {
646 Region
*body
= state
.addRegion();
647 if (parser
.parseRegion(*body
, /*arguments=*/{}, /*argTypes=*/{}))
652 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter
&printer
) {
654 getRegion(), /*printEntryBlockArgs=*/false,
655 // This op has a single block without terminators. But explicitly mark
656 // as not printing block terminators for testing.
657 /*printBlockTerminators=*/false);
//===----------------------------------------------------------------------===//
// TestVerifiersOp
//===----------------------------------------------------------------------===//
664 LogicalResult
TestVerifiersOp::verify() {
665 if (!getRegion().hasOneBlock())
666 return emitOpError("`hasOneBlock` trait hasn't been verified");
668 Operation
*definingOp
= getInput().getDefiningOp();
669 if (definingOp
&& failed(mlir::verify(definingOp
)))
670 return emitOpError("operand hasn't been verified");
672 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
674 mlir::emitRemark(getLoc(), "success run of verifier");
679 LogicalResult
TestVerifiersOp::verifyRegions() {
680 if (!getRegion().hasOneBlock())
681 return emitOpError("`hasOneBlock` trait hasn't been verified");
683 for (Block
&block
: getRegion())
684 for (Operation
&op
: block
)
685 if (failed(mlir::verify(&op
)))
686 return emitOpError("nested op hasn't been verified");
688 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
690 mlir::emitRemark(getLoc(), "success run of region verifier");
//===----------------------------------------------------------------------===//
// Test InferIntRangeInterface
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// TestWithBoundsOp
//===----------------------------------------------------------------------===//
702 void TestWithBoundsOp::inferResultRanges(ArrayRef
<ConstantIntRanges
> argRanges
,
703 SetIntRangeFn setResultRanges
) {
704 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
//===----------------------------------------------------------------------===//
// TestWithBoundsRegionOp
//===----------------------------------------------------------------------===//
710 ParseResult
TestWithBoundsRegionOp::parse(OpAsmParser
&parser
,
711 OperationState
&result
) {
712 if (parser
.parseOptionalAttrDict(result
.attributes
))
715 // Parse the input argument
716 OpAsmParser::Argument argInfo
;
717 if (failed(parser
.parseArgument(argInfo
, true)))
720 // Parse the body region, and reuse the operand info as the argument info.
721 Region
*body
= result
.addRegion();
722 return parser
.parseRegion(*body
, argInfo
, /*enableNameShadowing=*/false);
725 void TestWithBoundsRegionOp::print(OpAsmPrinter
&p
) {
726 p
.printOptionalAttrDict((*this)->getAttrs());
728 p
.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
731 p
.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
734 void TestWithBoundsRegionOp::inferResultRanges(
735 ArrayRef
<ConstantIntRanges
> argRanges
, SetIntRangeFn setResultRanges
) {
736 Value arg
= getRegion().getArgument(0);
737 setResultRanges(arg
, {getUmin(), getUmax(), getSmin(), getSmax()});
//===----------------------------------------------------------------------===//
// TestIncrementOp
//===----------------------------------------------------------------------===//
743 void TestIncrementOp::inferResultRanges(ArrayRef
<ConstantIntRanges
> argRanges
,
744 SetIntRangeFn setResultRanges
) {
745 const ConstantIntRanges
&range
= argRanges
[0];
746 APInt
one(range
.umin().getBitWidth(), 1);
747 setResultRanges(getResult(),
748 {range
.umin().uadd_sat(one
), range
.umax().uadd_sat(one
),
749 range
.smin().sadd_sat(one
), range
.smax().sadd_sat(one
)});
//===----------------------------------------------------------------------===//
// TestReflectBoundsOp
//===----------------------------------------------------------------------===//
755 void TestReflectBoundsOp::inferResultRanges(
756 ArrayRef
<ConstantIntRanges
> argRanges
, SetIntRangeFn setResultRanges
) {
757 const ConstantIntRanges
&range
= argRanges
[0];
758 MLIRContext
*ctx
= getContext();
761 // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
762 // Types for the Attributes.
763 Type type
= getElementTypeOrSelf(getType());
764 if (auto intTy
= llvm::dyn_cast
<IntegerType
>(type
)) {
765 unsigned bitwidth
= intTy
.getWidth();
766 sIntTy
= b
.getIntegerType(bitwidth
, /*isSigned=*/true);
767 uIntTy
= b
.getIntegerType(bitwidth
, /*isSigned=*/false);
769 sIntTy
= uIntTy
= type
;
771 setUminAttr(b
.getIntegerAttr(uIntTy
, range
.umin()));
772 setUmaxAttr(b
.getIntegerAttr(uIntTy
, range
.umax()));
773 setSminAttr(b
.getIntegerAttr(sIntTy
, range
.smin()));
774 setSmaxAttr(b
.getIntegerAttr(sIntTy
, range
.smax()));
775 setResultRanges(getResult(), range
);
//===----------------------------------------------------------------------===//
// ConversionFuncOp
//===----------------------------------------------------------------------===//
782 ParseResult
ConversionFuncOp::parse(OpAsmParser
&parser
,
783 OperationState
&result
) {
785 [](Builder
&builder
, ArrayRef
<Type
> argTypes
, ArrayRef
<Type
> results
,
786 function_interface_impl::VariadicFlag
,
787 std::string
&) { return builder
.getFunctionType(argTypes
, results
); };
789 return function_interface_impl::parseFunctionOp(
790 parser
, result
, /*allowVariadic=*/false,
791 getFunctionTypeAttrName(result
.name
), buildFuncType
,
792 getArgAttrsAttrName(result
.name
), getResAttrsAttrName(result
.name
));
795 void ConversionFuncOp::print(OpAsmPrinter
&p
) {
796 function_interface_impl::printFunctionOp(
797 p
, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
798 getArgAttrsAttrName(), getResAttrsAttrName());
//===----------------------------------------------------------------------===//
// ReifyBoundOp
//===----------------------------------------------------------------------===//
805 mlir::presburger::BoundType
ReifyBoundOp::getBoundType() {
806 if (getType() == "EQ")
807 return mlir::presburger::BoundType::EQ
;
808 if (getType() == "LB")
809 return mlir::presburger::BoundType::LB
;
810 if (getType() == "UB")
811 return mlir::presburger::BoundType::UB
;
812 llvm_unreachable("invalid bound type");
815 LogicalResult
ReifyBoundOp::verify() {
816 if (isa
<ShapedType
>(getVar().getType())) {
817 if (!getDim().has_value())
818 return emitOpError("expected 'dim' attribute for shaped type variable");
819 } else if (getVar().getType().isIndex()) {
820 if (getDim().has_value())
821 return emitOpError("unexpected 'dim' attribute for index variable");
823 return emitOpError("expected index-typed variable or shape type variable");
825 if (getConstant() && getScalable())
826 return emitOpError("'scalable' and 'constant' are mutually exlusive");
827 if (getScalable() != getVscaleMin().has_value())
828 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
829 if (getScalable() != getVscaleMax().has_value())
830 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
834 ValueBoundsConstraintSet::Variable
ReifyBoundOp::getVariable() {
835 if (getDim().has_value())
836 return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
837 return ValueBoundsConstraintSet::Variable(getVar());
//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//
844 ValueBoundsConstraintSet::ComparisonOperator
845 CompareOp::getComparisonOperator() {
846 if (getCmp() == "EQ")
847 return ValueBoundsConstraintSet::ComparisonOperator::EQ
;
848 if (getCmp() == "LT")
849 return ValueBoundsConstraintSet::ComparisonOperator::LT
;
850 if (getCmp() == "LE")
851 return ValueBoundsConstraintSet::ComparisonOperator::LE
;
852 if (getCmp() == "GT")
853 return ValueBoundsConstraintSet::ComparisonOperator::GT
;
854 if (getCmp() == "GE")
855 return ValueBoundsConstraintSet::ComparisonOperator::GE
;
856 llvm_unreachable("invalid comparison operator");
859 mlir::ValueBoundsConstraintSet::Variable
CompareOp::getLhs() {
861 return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
862 SmallVector
<Value
> mapOperands(
863 getVarOperands().slice(0, getLhsMap()->getNumInputs()));
864 return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands
);
867 mlir::ValueBoundsConstraintSet::Variable
CompareOp::getRhs() {
868 int64_t rhsOperandsBegin
= getLhsMap() ? getLhsMap()->getNumInputs() : 1;
870 return ValueBoundsConstraintSet::Variable(
871 getVarOperands()[rhsOperandsBegin
]);
872 SmallVector
<Value
> mapOperands(
873 getVarOperands().slice(rhsOperandsBegin
, getRhsMap()->getNumInputs()));
874 return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands
);
877 LogicalResult
CompareOp::verify() {
878 if (getCompose() && (getLhsMap() || getRhsMap()))
880 "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
881 int64_t expectedNumOperands
= getLhsMap() ? getLhsMap()->getNumInputs() : 1;
882 expectedNumOperands
+= getRhsMap() ? getRhsMap()->getNumInputs() : 1;
883 if (getVarOperands().size() != size_t(expectedNumOperands
))
884 return emitOpError("expected ")
885 << expectedNumOperands
<< " operands, but got "
886 << getVarOperands().size();
890 //===----------------------------------------------------------------------===//
891 // TestOpInPlaceSelfFold
892 //===----------------------------------------------------------------------===//
894 OpFoldResult
TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor
) {
896 // The folder adds the "folded" if not present.
903 //===----------------------------------------------------------------------===//
904 // TestOpFoldWithFoldAdaptor
905 //===----------------------------------------------------------------------===//
907 OpFoldResult
TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor
) {
909 if (auto value
= dyn_cast_or_null
<IntegerAttr
>(adaptor
.getOp()))
910 sum
+= value
.getValue().getSExtValue();
912 for (Attribute attr
: adaptor
.getVariadic())
913 if (auto value
= dyn_cast_or_null
<IntegerAttr
>(attr
))
914 sum
+= 2 * value
.getValue().getSExtValue();
916 for (ArrayRef
<Attribute
> attrs
: adaptor
.getVarOfVar())
917 for (Attribute attr
: attrs
)
918 if (auto value
= dyn_cast_or_null
<IntegerAttr
>(attr
))
919 sum
+= 3 * value
.getValue().getSExtValue();
921 sum
+= 4 * std::distance(adaptor
.getBody().begin(), adaptor
.getBody().end());
923 return IntegerAttr::get(getType(), sum
);
926 //===----------------------------------------------------------------------===//
927 // OpWithInferTypeAdaptorInterfaceOp
928 //===----------------------------------------------------------------------===//
930 LogicalResult
OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
931 MLIRContext
*, std::optional
<Location
> location
,
932 OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor
,
933 SmallVectorImpl
<Type
> &inferredReturnTypes
) {
934 if (adaptor
.getX().getType() != adaptor
.getY().getType()) {
935 return emitOptionalError(location
, "operand type mismatch ",
936 adaptor
.getX().getType(), " vs ",
937 adaptor
.getY().getType());
939 inferredReturnTypes
.assign({adaptor
.getX().getType()});
943 //===----------------------------------------------------------------------===//
944 // OpWithRefineTypeInterfaceOp
945 //===----------------------------------------------------------------------===//
947 // TODO: We should be able to only define either inferReturnType or
948 // refineReturnType, currently only refineReturnType can be omitted.
949 LogicalResult
OpWithRefineTypeInterfaceOp::inferReturnTypes(
950 MLIRContext
*context
, std::optional
<Location
> location
, ValueRange operands
,
951 DictionaryAttr attributes
, OpaqueProperties properties
, RegionRange regions
,
952 SmallVectorImpl
<Type
> &returnTypes
) {
954 return OpWithRefineTypeInterfaceOp::refineReturnTypes(
955 context
, location
, operands
, attributes
, properties
, regions
,
959 LogicalResult
OpWithRefineTypeInterfaceOp::refineReturnTypes(
960 MLIRContext
*, std::optional
<Location
> location
, ValueRange operands
,
961 DictionaryAttr attributes
, OpaqueProperties properties
, RegionRange regions
,
962 SmallVectorImpl
<Type
> &returnTypes
) {
963 if (operands
[0].getType() != operands
[1].getType()) {
964 return emitOptionalError(location
, "operand type mismatch ",
965 operands
[0].getType(), " vs ",
966 operands
[1].getType());
968 // TODO: Add helper to make this more concise to write.
969 if (returnTypes
.empty())
970 returnTypes
.resize(1, nullptr);
971 if (returnTypes
[0] && returnTypes
[0] != operands
[0].getType())
972 return emitOptionalError(location
,
973 "required first operand and result to match");
974 returnTypes
[0] = operands
[0].getType();
978 //===----------------------------------------------------------------------===//
979 // OpWithShapedTypeInferTypeAdaptorInterfaceOp
980 //===----------------------------------------------------------------------===//
983 OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
984 MLIRContext
*context
, std::optional
<Location
> location
,
985 OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor
,
986 SmallVectorImpl
<ShapedTypeComponents
> &inferredReturnShapes
) {
987 // Create return type consisting of the last element of the first operand.
988 auto operandType
= adaptor
.getOperand1().getType();
989 auto sval
= dyn_cast
<ShapedType
>(operandType
);
991 return emitOptionalError(location
, "only shaped type operands allowed");
992 int64_t dim
= sval
.hasRank() ? sval
.getShape().front() : ShapedType::kDynamic
;
993 auto type
= IntegerType::get(context
, 17);
996 if (auto rankedTy
= dyn_cast
<RankedTensorType
>(sval
))
997 encoding
= rankedTy
.getEncoding();
998 inferredReturnShapes
.push_back(ShapedTypeComponents({dim
}, type
, encoding
));
1003 OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
1004 OpBuilder
&builder
, ValueRange operands
,
1005 llvm::SmallVectorImpl
<Value
> &shapes
) {
1006 shapes
= SmallVector
<Value
, 1>{
1007 builder
.createOrFold
<tensor::DimOp
>(getLoc(), operands
.front(), 0)};
1011 //===----------------------------------------------------------------------===//
1012 // TestOpWithPropertiesAndInferredType
1013 //===----------------------------------------------------------------------===//
1015 LogicalResult
TestOpWithPropertiesAndInferredType::inferReturnTypes(
1016 MLIRContext
*context
, std::optional
<Location
>, ValueRange operands
,
1017 DictionaryAttr attributes
, OpaqueProperties properties
, RegionRange regions
,
1018 SmallVectorImpl
<Type
> &inferredReturnTypes
) {
1020 Adaptor
adaptor(operands
, attributes
, properties
, regions
);
1021 inferredReturnTypes
.push_back(IntegerType::get(
1022 context
, adaptor
.getLhs() + adaptor
.getProperties().rhs
));
1026 //===----------------------------------------------------------------------===//
1028 //===----------------------------------------------------------------------===//
1030 void LoopBlockOp::getSuccessorRegions(
1031 RegionBranchPoint point
, SmallVectorImpl
<RegionSuccessor
> ®ions
) {
1032 regions
.emplace_back(&getBody(), getBody().getArguments());
1033 if (point
.isParent())
1036 regions
.emplace_back((*this)->getResults());
1039 OperandRange
LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point
) {
1040 assert(point
== getBody());
1041 return MutableOperandRange(getInitMutable());
1044 //===----------------------------------------------------------------------===//
1045 // LoopBlockTerminatorOp
1046 //===----------------------------------------------------------------------===//
1049 LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point
) {
1050 if (point
.isParent())
1051 return getExitArgMutable();
1052 return getNextIterArgMutable();
1055 //===----------------------------------------------------------------------===//
1056 // SwitchWithNoBreakOp
1057 //===----------------------------------------------------------------------===//
1059 void TestNoTerminatorOp::getSuccessorRegions(
1060 RegionBranchPoint point
, SmallVectorImpl
<RegionSuccessor
> ®ions
) {}
1062 //===----------------------------------------------------------------------===//
1063 // Test InferIntRangeInterface
1064 //===----------------------------------------------------------------------===//
1066 OpFoldResult
ManualCppOpWithFold::fold(ArrayRef
<Attribute
> attributes
) {
1067 // Just a simple fold for testing purposes that reads an operands constant
1068 // value and returns it.
1069 if (!attributes
.empty())
1070 return attributes
.front();
1074 //===----------------------------------------------------------------------===//
1075 // Tensor/Buffer Ops
1076 //===----------------------------------------------------------------------===//
1078 void ReadBufferOp::getEffects(
1079 SmallVectorImpl
<SideEffects::EffectInstance
<MemoryEffects::Effect
>>
1081 // The buffer operand is read.
1082 effects
.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(),
1083 SideEffects::DefaultResource::get());
1084 // The buffer contents are dumped.
1085 effects
.emplace_back(MemoryEffects::Write::get(),
1086 SideEffects::DefaultResource::get());
1089 //===----------------------------------------------------------------------===//
1091 //===----------------------------------------------------------------------===//
1093 //===----------------------------------------------------------------------===//
1094 // TestCallAndStoreOp
1096 CallInterfaceCallable
TestCallAndStoreOp::getCallableForCallee() {
1100 void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee
) {
1101 setCalleeAttr(callee
.get
<SymbolRefAttr
>());
1104 Operation::operand_range
TestCallAndStoreOp::getArgOperands() {
1105 return getCalleeOperands();
1108 MutableOperandRange
TestCallAndStoreOp::getArgOperandsMutable() {
1109 return getCalleeOperandsMutable();
1112 //===----------------------------------------------------------------------===//
1113 // TestCallOnDeviceOp
1115 CallInterfaceCallable
TestCallOnDeviceOp::getCallableForCallee() {
1119 void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee
) {
1120 setCalleeAttr(callee
.get
<SymbolRefAttr
>());
1123 Operation::operand_range
TestCallOnDeviceOp::getArgOperands() {
1124 return getForwardedOperands();
1127 MutableOperandRange
TestCallOnDeviceOp::getArgOperandsMutable() {
1128 return getForwardedOperandsMutable();
1131 //===----------------------------------------------------------------------===//
1132 // TestStoreWithARegion
1134 void TestStoreWithARegion::getSuccessorRegions(
1135 RegionBranchPoint point
, SmallVectorImpl
<RegionSuccessor
> ®ions
) {
1136 if (point
.isParent())
1137 regions
.emplace_back(&getBody(), getBody().front().getArguments());
1139 regions
.emplace_back();
1142 //===----------------------------------------------------------------------===//
1143 // TestStoreWithALoopRegion
1145 void TestStoreWithALoopRegion::getSuccessorRegions(
1146 RegionBranchPoint point
, SmallVectorImpl
<RegionSuccessor
> ®ions
) {
1147 // Both the operation itself and the region may be branching into the body or
1148 // back into the operation itself. It is possible for the operation not to
1150 regions
.emplace_back(
1151 RegionSuccessor(&getBody(), getBody().front().getArguments()));
1152 regions
.emplace_back();
1155 //===----------------------------------------------------------------------===//
1157 //===----------------------------------------------------------------------===//
1160 TestVersionedOpA::readProperties(mlir::DialectBytecodeReader
&reader
,
1161 mlir::OperationState
&state
) {
1162 auto &prop
= state
.getOrAddProperties
<Properties
>();
1163 if (mlir::failed(reader
.readAttribute(prop
.dims
)))
1164 return mlir::failure();
1166 // Check if we have a version. If not, assume we are parsing the current
1168 auto maybeVersion
= reader
.getDialectVersion
<test::TestDialect
>();
1169 if (succeeded(maybeVersion
)) {
1170 // If version is less than 2.0, there is no additional attribute to parse.
1171 // We can materialize missing properties post parsing before verification.
1172 const auto *version
=
1173 reinterpret_cast<const TestDialectVersion
*>(*maybeVersion
);
1174 if ((version
->major_
< 2)) {
1179 if (mlir::failed(reader
.readAttribute(prop
.modifier
)))
1180 return mlir::failure();
1181 return mlir::success();
1184 void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter
&writer
) {
1185 auto &prop
= getProperties();
1186 writer
.writeAttribute(prop
.dims
);
1188 auto maybeVersion
= writer
.getDialectVersion
<test::TestDialect
>();
1189 if (succeeded(maybeVersion
)) {
1190 // If version is less than 2.0, there is no additional attribute to write.
1191 const auto *version
=
1192 reinterpret_cast<const TestDialectVersion
*>(*maybeVersion
);
1193 if ((version
->major_
< 2)) {
1194 llvm::outs() << "downgrading op properties...\n";
1198 writer
.writeAttribute(prop
.modifier
);
1201 //===----------------------------------------------------------------------===//
1202 // TestOpWithVersionedProperties
1203 //===----------------------------------------------------------------------===//
1205 llvm::LogicalResult
TestOpWithVersionedProperties::readFromMlirBytecode(
1206 mlir::DialectBytecodeReader
&reader
, test::VersionedProperties
&prop
) {
1207 uint64_t value1
, value2
= 0;
1208 if (failed(reader
.readVarInt(value1
)))
1211 // Check if we have a version. If not, assume we are parsing the current
1213 auto maybeVersion
= reader
.getDialectVersion
<test::TestDialect
>();
1214 bool needToParseAnotherInt
= true;
1215 if (succeeded(maybeVersion
)) {
1216 // If version is less than 2.0, there is no additional attribute to parse.
1217 // We can materialize missing properties post parsing before verification.
1218 const auto *version
=
1219 reinterpret_cast<const TestDialectVersion
*>(*maybeVersion
);
1220 if ((version
->major_
< 2))
1221 needToParseAnotherInt
= false;
1223 if (needToParseAnotherInt
&& failed(reader
.readVarInt(value2
)))
1226 prop
.value1
= value1
;
1227 prop
.value2
= value2
;
1231 void TestOpWithVersionedProperties::writeToMlirBytecode(
1232 mlir::DialectBytecodeWriter
&writer
,
1233 const test::VersionedProperties
&prop
) {
1234 writer
.writeVarInt(prop
.value1
);
1235 writer
.writeVarInt(prop
.value2
);
1238 //===----------------------------------------------------------------------===//
1239 // TestMultiSlotAlloca
1240 //===----------------------------------------------------------------------===//
1242 llvm::SmallVector
<MemorySlot
> TestMultiSlotAlloca::getPromotableSlots() {
1243 SmallVector
<MemorySlot
> slots
;
1244 for (Value result
: getResults()) {
1245 slots
.push_back(MemorySlot
{
1246 result
, cast
<MemRefType
>(result
.getType()).getElementType()});
1251 Value
TestMultiSlotAlloca::getDefaultValue(const MemorySlot
&slot
,
1252 OpBuilder
&builder
) {
1253 return builder
.create
<TestOpConstant
>(getLoc(), slot
.elemType
,
1254 builder
.getI32IntegerAttr(42));
1257 void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot
&slot
,
1258 BlockArgument argument
,
1259 OpBuilder
&builder
) {
1260 // Not relevant for testing.
1263 /// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
1264 static std::optional
<TestMultiSlotAlloca
>
1265 createNewMultiAllocaWithoutSlot(const MemorySlot
&slot
, OpBuilder
&builder
,
1266 TestMultiSlotAlloca oldOp
) {
1268 if (oldOp
.getNumResults() == 1) {
1270 return std::nullopt
;
1273 SmallVector
<Type
> newTypes
;
1274 SmallVector
<Value
> remainingValues
;
1276 for (Value oldResult
: oldOp
.getResults()) {
1277 if (oldResult
== slot
.ptr
)
1279 remainingValues
.push_back(oldResult
);
1280 newTypes
.push_back(oldResult
.getType());
1283 OpBuilder::InsertionGuard
guard(builder
);
1284 builder
.setInsertionPoint(oldOp
);
1286 builder
.create
<TestMultiSlotAlloca
>(oldOp
->getLoc(), newTypes
);
1287 for (auto [oldResult
, newResult
] :
1288 llvm::zip_equal(remainingValues
, replacement
.getResults()))
1289 oldResult
.replaceAllUsesWith(newResult
);
1295 std::optional
<PromotableAllocationOpInterface
>
1296 TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot
&slot
,
1298 OpBuilder
&builder
) {
1299 if (defaultValue
&& defaultValue
.use_empty())
1300 defaultValue
.getDefiningOp()->erase();
1301 return createNewMultiAllocaWithoutSlot(slot
, builder
, *this);
1304 SmallVector
<DestructurableMemorySlot
>
1305 TestMultiSlotAlloca::getDestructurableSlots() {
1306 SmallVector
<DestructurableMemorySlot
> slots
;
1307 for (Value result
: getResults()) {
1308 auto memrefType
= cast
<MemRefType
>(result
.getType());
1309 auto destructurable
= dyn_cast
<DestructurableTypeInterface
>(memrefType
);
1310 if (!destructurable
)
1313 std::optional
<DenseMap
<Attribute
, Type
>> destructuredType
=
1314 destructurable
.getSubelementIndexMap();
1315 if (!destructuredType
)
1318 DestructurableMemorySlot
{{result
, memrefType
}, *destructuredType
});
1323 DenseMap
<Attribute
, MemorySlot
> TestMultiSlotAlloca::destructure(
1324 const DestructurableMemorySlot
&slot
,
1325 const SmallPtrSetImpl
<Attribute
> &usedIndices
, OpBuilder
&builder
,
1326 SmallVectorImpl
<DestructurableAllocationOpInterface
> &newAllocators
) {
1327 OpBuilder::InsertionGuard
guard(builder
);
1328 builder
.setInsertionPointAfter(*this);
1330 DenseMap
<Attribute
, MemorySlot
> slotMap
;
1332 for (Attribute usedIndex
: usedIndices
) {
1333 Type elemType
= slot
.subelementTypes
.lookup(usedIndex
);
1334 MemRefType elemPtr
= MemRefType::get({}, elemType
);
1335 auto subAlloca
= builder
.create
<TestMultiSlotAlloca
>(getLoc(), elemPtr
);
1336 newAllocators
.push_back(subAlloca
);
1337 slotMap
.try_emplace
<MemorySlot
>(usedIndex
,
1338 {subAlloca
.getResult(0), elemType
});
1344 std::optional
<DestructurableAllocationOpInterface
>
1345 TestMultiSlotAlloca::handleDestructuringComplete(
1346 const DestructurableMemorySlot
&slot
, OpBuilder
&builder
) {
1347 return createNewMultiAllocaWithoutSlot(slot
, builder
, *this);