//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
//===----------------------------------------------------------------------===//
// Locations.
//===----------------------------------------------------------------------===//
27 Location
Builder::getUnknownLoc() { return UnknownLoc::get(context
); }
29 Location
Builder::getFusedLoc(ArrayRef
<Location
> locs
, Attribute metadata
) {
30 return FusedLoc::get(locs
, metadata
, context
);
//===----------------------------------------------------------------------===//
// Types.
//===----------------------------------------------------------------------===//
37 FloatType
Builder::getFloat4E2M1FNType() {
38 return FloatType::getFloat4E2M1FN(context
);
41 FloatType
Builder::getFloat6E2M3FNType() {
42 return FloatType::getFloat6E2M3FN(context
);
45 FloatType
Builder::getFloat6E3M2FNType() {
46 return FloatType::getFloat6E3M2FN(context
);
49 FloatType
Builder::getFloat8E5M2Type() {
50 return FloatType::getFloat8E5M2(context
);
53 FloatType
Builder::getFloat8E4M3Type() {
54 return FloatType::getFloat8E4M3(context
);
57 FloatType
Builder::getFloat8E4M3FNType() {
58 return FloatType::getFloat8E4M3FN(context
);
61 FloatType
Builder::getFloat8E5M2FNUZType() {
62 return FloatType::getFloat8E5M2FNUZ(context
);
65 FloatType
Builder::getFloat8E4M3FNUZType() {
66 return FloatType::getFloat8E4M3FNUZ(context
);
69 FloatType
Builder::getFloat8E4M3B11FNUZType() {
70 return FloatType::getFloat8E4M3B11FNUZ(context
);
73 FloatType
Builder::getFloat8E3M4Type() {
74 return FloatType::getFloat8E3M4(context
);
77 FloatType
Builder::getFloat8E8M0FNUType() {
78 return FloatType::getFloat8E8M0FNU(context
);
81 FloatType
Builder::getBF16Type() { return FloatType::getBF16(context
); }
83 FloatType
Builder::getF16Type() { return FloatType::getF16(context
); }
85 FloatType
Builder::getTF32Type() { return FloatType::getTF32(context
); }
87 FloatType
Builder::getF32Type() { return FloatType::getF32(context
); }
89 FloatType
Builder::getF64Type() { return FloatType::getF64(context
); }
91 FloatType
Builder::getF80Type() { return FloatType::getF80(context
); }
93 FloatType
Builder::getF128Type() { return FloatType::getF128(context
); }
95 IndexType
Builder::getIndexType() { return IndexType::get(context
); }
97 IntegerType
Builder::getI1Type() { return IntegerType::get(context
, 1); }
99 IntegerType
Builder::getI2Type() { return IntegerType::get(context
, 2); }
101 IntegerType
Builder::getI4Type() { return IntegerType::get(context
, 4); }
103 IntegerType
Builder::getI8Type() { return IntegerType::get(context
, 8); }
105 IntegerType
Builder::getI16Type() { return IntegerType::get(context
, 16); }
107 IntegerType
Builder::getI32Type() { return IntegerType::get(context
, 32); }
109 IntegerType
Builder::getI64Type() { return IntegerType::get(context
, 64); }
111 IntegerType
Builder::getIntegerType(unsigned width
) {
112 return IntegerType::get(context
, width
);
115 IntegerType
Builder::getIntegerType(unsigned width
, bool isSigned
) {
116 return IntegerType::get(
117 context
, width
, isSigned
? IntegerType::Signed
: IntegerType::Unsigned
);
120 FunctionType
Builder::getFunctionType(TypeRange inputs
, TypeRange results
) {
121 return FunctionType::get(context
, inputs
, results
);
124 TupleType
Builder::getTupleType(TypeRange elementTypes
) {
125 return TupleType::get(context
, elementTypes
);
128 NoneType
Builder::getNoneType() { return NoneType::get(context
); }
//===----------------------------------------------------------------------===//
// Attributes.
//===----------------------------------------------------------------------===//
134 NamedAttribute
Builder::getNamedAttr(StringRef name
, Attribute val
) {
135 return NamedAttribute(getStringAttr(name
), val
);
138 UnitAttr
Builder::getUnitAttr() { return UnitAttr::get(context
); }
140 BoolAttr
Builder::getBoolAttr(bool value
) {
141 return BoolAttr::get(context
, value
);
144 DictionaryAttr
Builder::getDictionaryAttr(ArrayRef
<NamedAttribute
> value
) {
145 return DictionaryAttr::get(context
, value
);
148 IntegerAttr
Builder::getIndexAttr(int64_t value
) {
149 return IntegerAttr::get(getIndexType(), APInt(64, value
));
152 IntegerAttr
Builder::getI64IntegerAttr(int64_t value
) {
153 return IntegerAttr::get(getIntegerType(64), APInt(64, value
));
156 DenseIntElementsAttr
Builder::getBoolVectorAttr(ArrayRef
<bool> values
) {
157 return DenseIntElementsAttr::get(
158 VectorType::get(static_cast<int64_t>(values
.size()), getI1Type()),
162 DenseIntElementsAttr
Builder::getI32VectorAttr(ArrayRef
<int32_t> values
) {
163 return DenseIntElementsAttr::get(
164 VectorType::get(static_cast<int64_t>(values
.size()), getIntegerType(32)),
168 DenseIntElementsAttr
Builder::getI64VectorAttr(ArrayRef
<int64_t> values
) {
169 return DenseIntElementsAttr::get(
170 VectorType::get(static_cast<int64_t>(values
.size()), getIntegerType(64)),
174 DenseIntElementsAttr
Builder::getIndexVectorAttr(ArrayRef
<int64_t> values
) {
175 return DenseIntElementsAttr::get(
176 VectorType::get(static_cast<int64_t>(values
.size()), getIndexType()),
180 DenseFPElementsAttr
Builder::getF32VectorAttr(ArrayRef
<float> values
) {
181 return DenseFPElementsAttr::get(
182 VectorType::get(static_cast<float>(values
.size()), getF32Type()), values
);
185 DenseFPElementsAttr
Builder::getF64VectorAttr(ArrayRef
<double> values
) {
186 return DenseFPElementsAttr::get(
187 VectorType::get(static_cast<double>(values
.size()), getF64Type()),
191 DenseBoolArrayAttr
Builder::getDenseBoolArrayAttr(ArrayRef
<bool> values
) {
192 return DenseBoolArrayAttr::get(context
, values
);
195 DenseI8ArrayAttr
Builder::getDenseI8ArrayAttr(ArrayRef
<int8_t> values
) {
196 return DenseI8ArrayAttr::get(context
, values
);
199 DenseI16ArrayAttr
Builder::getDenseI16ArrayAttr(ArrayRef
<int16_t> values
) {
200 return DenseI16ArrayAttr::get(context
, values
);
203 DenseI32ArrayAttr
Builder::getDenseI32ArrayAttr(ArrayRef
<int32_t> values
) {
204 return DenseI32ArrayAttr::get(context
, values
);
207 DenseI64ArrayAttr
Builder::getDenseI64ArrayAttr(ArrayRef
<int64_t> values
) {
208 return DenseI64ArrayAttr::get(context
, values
);
211 DenseF32ArrayAttr
Builder::getDenseF32ArrayAttr(ArrayRef
<float> values
) {
212 return DenseF32ArrayAttr::get(context
, values
);
215 DenseF64ArrayAttr
Builder::getDenseF64ArrayAttr(ArrayRef
<double> values
) {
216 return DenseF64ArrayAttr::get(context
, values
);
219 DenseIntElementsAttr
Builder::getI32TensorAttr(ArrayRef
<int32_t> values
) {
220 return DenseIntElementsAttr::get(
221 RankedTensorType::get(static_cast<int64_t>(values
.size()),
226 DenseIntElementsAttr
Builder::getI64TensorAttr(ArrayRef
<int64_t> values
) {
227 return DenseIntElementsAttr::get(
228 RankedTensorType::get(static_cast<int64_t>(values
.size()),
233 DenseIntElementsAttr
Builder::getIndexTensorAttr(ArrayRef
<int64_t> values
) {
234 return DenseIntElementsAttr::get(
235 RankedTensorType::get(static_cast<int64_t>(values
.size()),
240 IntegerAttr
Builder::getI32IntegerAttr(int32_t value
) {
241 // The APInt always uses isSigned=true here because we accept the value
243 return IntegerAttr::get(getIntegerType(32),
244 APInt(32, value
, /*isSigned=*/true));
247 IntegerAttr
Builder::getSI32IntegerAttr(int32_t value
) {
248 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
249 APInt(32, value
, /*isSigned=*/true));
252 IntegerAttr
Builder::getUI32IntegerAttr(uint32_t value
) {
253 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
254 APInt(32, (uint64_t)value
, /*isSigned=*/false));
257 IntegerAttr
Builder::getI16IntegerAttr(int16_t value
) {
258 return IntegerAttr::get(getIntegerType(16), APInt(16, value
));
261 IntegerAttr
Builder::getI8IntegerAttr(int8_t value
) {
262 // The APInt always uses isSigned=true here because we accept the value
264 return IntegerAttr::get(getIntegerType(8),
265 APInt(8, value
, /*isSigned=*/true));
268 IntegerAttr
Builder::getIntegerAttr(Type type
, int64_t value
) {
270 return IntegerAttr::get(type
, APInt(64, value
));
271 // TODO: Avoid implicit trunc?
272 // See https://github.com/llvm/llvm-project/issues/112510.
273 return IntegerAttr::get(type
, APInt(type
.getIntOrFloatBitWidth(), value
,
274 type
.isSignedInteger(),
275 /*implicitTrunc=*/true));
278 IntegerAttr
Builder::getIntegerAttr(Type type
, const APInt
&value
) {
279 return IntegerAttr::get(type
, value
);
282 FloatAttr
Builder::getF64FloatAttr(double value
) {
283 return FloatAttr::get(getF64Type(), APFloat(value
));
286 FloatAttr
Builder::getF32FloatAttr(float value
) {
287 return FloatAttr::get(getF32Type(), APFloat(value
));
290 FloatAttr
Builder::getF16FloatAttr(float value
) {
291 return FloatAttr::get(getF16Type(), value
);
294 FloatAttr
Builder::getFloatAttr(Type type
, double value
) {
295 return FloatAttr::get(type
, value
);
298 FloatAttr
Builder::getFloatAttr(Type type
, const APFloat
&value
) {
299 return FloatAttr::get(type
, value
);
302 StringAttr
Builder::getStringAttr(const Twine
&bytes
) {
303 return StringAttr::get(context
, bytes
);
306 ArrayAttr
Builder::getArrayAttr(ArrayRef
<Attribute
> value
) {
307 return ArrayAttr::get(context
, value
);
310 ArrayAttr
Builder::getBoolArrayAttr(ArrayRef
<bool> values
) {
311 auto attrs
= llvm::map_to_vector
<8>(
312 values
, [this](bool v
) -> Attribute
{ return getBoolAttr(v
); });
313 return getArrayAttr(attrs
);
316 ArrayAttr
Builder::getI32ArrayAttr(ArrayRef
<int32_t> values
) {
317 auto attrs
= llvm::map_to_vector
<8>(
318 values
, [this](int32_t v
) -> Attribute
{ return getI32IntegerAttr(v
); });
319 return getArrayAttr(attrs
);
321 ArrayAttr
Builder::getI64ArrayAttr(ArrayRef
<int64_t> values
) {
322 auto attrs
= llvm::map_to_vector
<8>(
323 values
, [this](int64_t v
) -> Attribute
{ return getI64IntegerAttr(v
); });
324 return getArrayAttr(attrs
);
327 ArrayAttr
Builder::getIndexArrayAttr(ArrayRef
<int64_t> values
) {
328 auto attrs
= llvm::map_to_vector
<8>(values
, [this](int64_t v
) -> Attribute
{
329 return getIntegerAttr(IndexType::get(getContext()), v
);
331 return getArrayAttr(attrs
);
334 ArrayAttr
Builder::getF32ArrayAttr(ArrayRef
<float> values
) {
335 auto attrs
= llvm::map_to_vector
<8>(
336 values
, [this](float v
) -> Attribute
{ return getF32FloatAttr(v
); });
337 return getArrayAttr(attrs
);
340 ArrayAttr
Builder::getF64ArrayAttr(ArrayRef
<double> values
) {
341 auto attrs
= llvm::map_to_vector
<8>(
342 values
, [this](double v
) -> Attribute
{ return getF64FloatAttr(v
); });
343 return getArrayAttr(attrs
);
346 ArrayAttr
Builder::getStrArrayAttr(ArrayRef
<StringRef
> values
) {
347 auto attrs
= llvm::map_to_vector
<8>(
348 values
, [this](StringRef v
) -> Attribute
{ return getStringAttr(v
); });
349 return getArrayAttr(attrs
);
352 ArrayAttr
Builder::getTypeArrayAttr(TypeRange values
) {
353 auto attrs
= llvm::map_to_vector
<8>(
354 values
, [](Type v
) -> Attribute
{ return TypeAttr::get(v
); });
355 return getArrayAttr(attrs
);
358 ArrayAttr
Builder::getAffineMapArrayAttr(ArrayRef
<AffineMap
> values
) {
359 auto attrs
= llvm::map_to_vector
<8>(
360 values
, [](AffineMap v
) -> Attribute
{ return AffineMapAttr::get(v
); });
361 return getArrayAttr(attrs
);
364 TypedAttr
Builder::getZeroAttr(Type type
) {
365 if (llvm::isa
<FloatType
>(type
))
366 return getFloatAttr(type
, 0.0);
367 if (llvm::isa
<IndexType
>(type
))
368 return getIndexAttr(0);
369 if (llvm::dyn_cast
<IntegerType
>(type
))
370 return getIntegerAttr(type
,
371 APInt(llvm::cast
<IntegerType
>(type
).getWidth(), 0));
372 if (llvm::isa
<RankedTensorType
, VectorType
>(type
)) {
373 auto vtType
= llvm::cast
<ShapedType
>(type
);
374 auto element
= getZeroAttr(vtType
.getElementType());
377 return DenseElementsAttr::get(vtType
, element
);
382 TypedAttr
Builder::getOneAttr(Type type
) {
383 if (llvm::isa
<FloatType
>(type
))
384 return getFloatAttr(type
, 1.0);
385 if (llvm::isa
<IndexType
>(type
))
386 return getIndexAttr(1);
387 if (llvm::dyn_cast
<IntegerType
>(type
))
388 return getIntegerAttr(type
,
389 APInt(llvm::cast
<IntegerType
>(type
).getWidth(), 1));
390 if (llvm::isa
<RankedTensorType
, VectorType
>(type
)) {
391 auto vtType
= llvm::cast
<ShapedType
>(type
);
392 auto element
= getOneAttr(vtType
.getElementType());
395 return DenseElementsAttr::get(vtType
, element
);
//===----------------------------------------------------------------------===//
// Affine Expressions, Affine Maps, and Integer Sets.
//===----------------------------------------------------------------------===//
404 AffineExpr
Builder::getAffineDimExpr(unsigned position
) {
405 return mlir::getAffineDimExpr(position
, context
);
408 AffineExpr
Builder::getAffineSymbolExpr(unsigned position
) {
409 return mlir::getAffineSymbolExpr(position
, context
);
412 AffineExpr
Builder::getAffineConstantExpr(int64_t constant
) {
413 return mlir::getAffineConstantExpr(constant
, context
);
416 AffineMap
Builder::getEmptyAffineMap() { return AffineMap::get(context
); }
418 AffineMap
Builder::getConstantAffineMap(int64_t val
) {
419 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
420 getAffineConstantExpr(val
));
423 AffineMap
Builder::getDimIdentityMap() {
424 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
427 AffineMap
Builder::getMultiDimIdentityMap(unsigned rank
) {
428 SmallVector
<AffineExpr
, 4> dimExprs
;
429 dimExprs
.reserve(rank
);
430 for (unsigned i
= 0; i
< rank
; ++i
)
431 dimExprs
.push_back(getAffineDimExpr(i
));
432 return AffineMap::get(/*dimCount=*/rank
, /*symbolCount=*/0, dimExprs
,
436 AffineMap
Builder::getSymbolIdentityMap() {
437 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
438 getAffineSymbolExpr(0));
441 AffineMap
Builder::getSingleDimShiftAffineMap(int64_t shift
) {
442 // expr = d0 + shift.
443 auto expr
= getAffineDimExpr(0) + shift
;
444 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr
);
447 AffineMap
Builder::getShiftedAffineMap(AffineMap map
, int64_t shift
) {
448 SmallVector
<AffineExpr
, 4> shiftedResults
;
449 shiftedResults
.reserve(map
.getNumResults());
450 for (auto resultExpr
: map
.getResults())
451 shiftedResults
.push_back(resultExpr
+ shift
);
452 return AffineMap::get(map
.getNumDims(), map
.getNumSymbols(), shiftedResults
,
//===----------------------------------------------------------------------===//
// OpBuilder.
//===----------------------------------------------------------------------===//
460 /// Insert the given operation at the current insertion point and return it.
461 Operation
*OpBuilder::insert(Operation
*op
) {
463 block
->getOperations().insert(insertPoint
, op
);
465 listener
->notifyOperationInserted(op
, /*previous=*/{});
470 Block
*OpBuilder::createBlock(Region
*parent
, Region::iterator insertPt
,
471 TypeRange argTypes
, ArrayRef
<Location
> locs
) {
472 assert(parent
&& "expected valid parent region");
473 assert(argTypes
.size() == locs
.size() && "argument location mismatch");
474 if (insertPt
== Region::iterator())
475 insertPt
= parent
->end();
477 Block
*b
= new Block();
478 b
->addArguments(argTypes
, locs
);
479 parent
->getBlocks().insert(insertPt
, b
);
480 setInsertionPointToEnd(b
);
483 listener
->notifyBlockInserted(b
, /*previous=*/nullptr, /*previousIt=*/{});
487 /// Add new block with 'argTypes' arguments and set the insertion point to the
488 /// end of it. The block is placed before 'insertBefore'.
489 Block
*OpBuilder::createBlock(Block
*insertBefore
, TypeRange argTypes
,
490 ArrayRef
<Location
> locs
) {
491 assert(insertBefore
&& "expected valid insertion block");
492 return createBlock(insertBefore
->getParent(), Region::iterator(insertBefore
),
496 /// Create an operation given the fields represented as an OperationState.
497 Operation
*OpBuilder::create(const OperationState
&state
) {
498 return insert(Operation::create(state
));
501 /// Creates an operation with the given fields.
502 Operation
*OpBuilder::create(Location loc
, StringAttr opName
,
503 ValueRange operands
, TypeRange types
,
504 ArrayRef
<NamedAttribute
> attributes
,
505 BlockRange successors
,
506 MutableArrayRef
<std::unique_ptr
<Region
>> regions
) {
507 OperationState
state(loc
, opName
, operands
, types
, attributes
, successors
,
509 return create(state
);
512 LogicalResult
OpBuilder::tryFold(Operation
*op
,
513 SmallVectorImpl
<Value
> &results
) {
514 assert(results
.empty() && "expected empty results");
515 ResultRange opResults
= op
->getResults();
517 results
.reserve(opResults
.size());
518 auto cleanupFailure
= [&] {
523 // If this operation is already a constant, there is nothing to do.
524 if (matchPattern(op
, m_Constant()))
525 return cleanupFailure();
527 // Try to fold the operation.
528 SmallVector
<OpFoldResult
, 4> foldResults
;
529 if (failed(op
->fold(foldResults
)))
530 return cleanupFailure();
532 // An in-place fold does not require generation of any constants.
533 if (foldResults
.empty())
536 // A temporary builder used for creating constants during folding.
537 OpBuilder
cstBuilder(context
);
538 SmallVector
<Operation
*, 1> generatedConstants
;
540 // Populate the results with the folded results.
541 Dialect
*dialect
= op
->getDialect();
542 for (auto [foldResult
, expectedType
] :
543 llvm::zip_equal(foldResults
, opResults
.getTypes())) {
545 // Normal values get pushed back directly.
546 if (auto value
= llvm::dyn_cast_if_present
<Value
>(foldResult
)) {
547 results
.push_back(value
);
551 // Otherwise, try to materialize a constant operation.
553 return cleanupFailure();
555 // Ask the dialect to materialize a constant operation for this value.
556 Attribute attr
= foldResult
.get
<Attribute
>();
557 auto *constOp
= dialect
->materializeConstant(cstBuilder
, attr
, expectedType
,
560 // Erase any generated constants.
561 for (Operation
*cst
: generatedConstants
)
563 return cleanupFailure();
565 assert(matchPattern(constOp
, m_Constant()));
567 generatedConstants
.push_back(constOp
);
568 results
.push_back(constOp
->getResult(0));
571 // If we were successful, insert any generated constants.
572 for (Operation
*cst
: generatedConstants
)
578 /// Helper function that sends block insertion notifications for every block
579 /// that is directly nested in the given op.
580 static void notifyBlockInsertions(Operation
*op
,
581 OpBuilder::Listener
*listener
) {
582 for (Region
&r
: op
->getRegions())
583 for (Block
&b
: r
.getBlocks())
584 listener
->notifyBlockInserted(&b
, /*previous=*/nullptr,
588 Operation
*OpBuilder::clone(Operation
&op
, IRMapping
&mapper
) {
589 Operation
*newOp
= op
.clone(mapper
);
590 newOp
= insert(newOp
);
592 // The `insert` call above handles the notification for inserting `newOp`
593 // itself. But if `newOp` has any regions, we need to notify the listener
594 // about any ops that got inserted inside those regions as part of cloning.
596 // The `insert` call above notifies about op insertion, but not about block
598 notifyBlockInsertions(newOp
, listener
);
599 auto walkFn
= [&](Operation
*walkedOp
) {
600 listener
->notifyOperationInserted(walkedOp
, /*previous=*/{});
601 notifyBlockInsertions(walkedOp
, listener
);
603 for (Region
®ion
: newOp
->getRegions())
604 region
.walk
<WalkOrder::PreOrder
>(walkFn
);
610 Operation
*OpBuilder::clone(Operation
&op
) {
612 return clone(op
, mapper
);
615 void OpBuilder::cloneRegionBefore(Region
®ion
, Region
&parent
,
616 Region::iterator before
, IRMapping
&mapping
) {
617 region
.cloneInto(&parent
, before
, mapping
);
619 // Fast path: If no listener is attached, there is no more work to do.
623 // Notify about op/block insertion.
624 for (auto it
= mapping
.lookup(®ion
.front())->getIterator(); it
!= before
;
626 listener
->notifyBlockInserted(&*it
, /*previous=*/nullptr,
628 it
->walk
<WalkOrder::PreOrder
>([&](Operation
*walkedOp
) {
629 listener
->notifyOperationInserted(walkedOp
, /*previous=*/{});
630 notifyBlockInsertions(walkedOp
, listener
);
635 void OpBuilder::cloneRegionBefore(Region
®ion
, Region
&parent
,
636 Region::iterator before
) {
638 cloneRegionBefore(region
, parent
, before
, mapping
);
641 void OpBuilder::cloneRegionBefore(Region
®ion
, Block
*before
) {
642 cloneRegionBefore(region
, *before
->getParent(), before
->getIterator());