1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/ADT/SmallVectorExtras.h"
19 #include "llvm/Support/raw_ostream.h"
//===----------------------------------------------------------------------===//
// Locations.
//===----------------------------------------------------------------------===//
27 Location
Builder::getUnknownLoc() { return UnknownLoc::get(context
); }
29 Location
Builder::getFusedLoc(ArrayRef
<Location
> locs
, Attribute metadata
) {
30 return FusedLoc::get(locs
, metadata
, context
);
//===----------------------------------------------------------------------===//
// Types.
//===----------------------------------------------------------------------===//
37 FloatType
Builder::getBF16Type() { return BFloat16Type::get(context
); }
39 FloatType
Builder::getF16Type() { return Float16Type::get(context
); }
41 FloatType
Builder::getTF32Type() { return FloatTF32Type::get(context
); }
43 FloatType
Builder::getF32Type() { return Float32Type::get(context
); }
45 FloatType
Builder::getF64Type() { return Float64Type::get(context
); }
47 FloatType
Builder::getF80Type() { return Float80Type::get(context
); }
49 FloatType
Builder::getF128Type() { return Float128Type::get(context
); }
51 IndexType
Builder::getIndexType() { return IndexType::get(context
); }
53 IntegerType
Builder::getI1Type() { return IntegerType::get(context
, 1); }
55 IntegerType
Builder::getI2Type() { return IntegerType::get(context
, 2); }
57 IntegerType
Builder::getI4Type() { return IntegerType::get(context
, 4); }
59 IntegerType
Builder::getI8Type() { return IntegerType::get(context
, 8); }
61 IntegerType
Builder::getI16Type() { return IntegerType::get(context
, 16); }
63 IntegerType
Builder::getI32Type() { return IntegerType::get(context
, 32); }
65 IntegerType
Builder::getI64Type() { return IntegerType::get(context
, 64); }
67 IntegerType
Builder::getIntegerType(unsigned width
) {
68 return IntegerType::get(context
, width
);
71 IntegerType
Builder::getIntegerType(unsigned width
, bool isSigned
) {
72 return IntegerType::get(
73 context
, width
, isSigned
? IntegerType::Signed
: IntegerType::Unsigned
);
76 FunctionType
Builder::getFunctionType(TypeRange inputs
, TypeRange results
) {
77 return FunctionType::get(context
, inputs
, results
);
80 TupleType
Builder::getTupleType(TypeRange elementTypes
) {
81 return TupleType::get(context
, elementTypes
);
84 NoneType
Builder::getNoneType() { return NoneType::get(context
); }
//===----------------------------------------------------------------------===//
// Attributes.
//===----------------------------------------------------------------------===//
90 NamedAttribute
Builder::getNamedAttr(StringRef name
, Attribute val
) {
91 return NamedAttribute(name
, val
);
94 UnitAttr
Builder::getUnitAttr() { return UnitAttr::get(context
); }
96 BoolAttr
Builder::getBoolAttr(bool value
) {
97 return BoolAttr::get(context
, value
);
100 DictionaryAttr
Builder::getDictionaryAttr(ArrayRef
<NamedAttribute
> value
) {
101 return DictionaryAttr::get(context
, value
);
104 IntegerAttr
Builder::getIndexAttr(int64_t value
) {
105 return IntegerAttr::get(getIndexType(), APInt(64, value
));
108 IntegerAttr
Builder::getI64IntegerAttr(int64_t value
) {
109 return IntegerAttr::get(getIntegerType(64), APInt(64, value
));
112 DenseIntElementsAttr
Builder::getBoolVectorAttr(ArrayRef
<bool> values
) {
113 return DenseIntElementsAttr::get(
114 VectorType::get(static_cast<int64_t>(values
.size()), getI1Type()),
118 DenseIntElementsAttr
Builder::getI32VectorAttr(ArrayRef
<int32_t> values
) {
119 return DenseIntElementsAttr::get(
120 VectorType::get(static_cast<int64_t>(values
.size()), getIntegerType(32)),
124 DenseIntElementsAttr
Builder::getI64VectorAttr(ArrayRef
<int64_t> values
) {
125 return DenseIntElementsAttr::get(
126 VectorType::get(static_cast<int64_t>(values
.size()), getIntegerType(64)),
130 DenseIntElementsAttr
Builder::getIndexVectorAttr(ArrayRef
<int64_t> values
) {
131 return DenseIntElementsAttr::get(
132 VectorType::get(static_cast<int64_t>(values
.size()), getIndexType()),
136 DenseFPElementsAttr
Builder::getF32VectorAttr(ArrayRef
<float> values
) {
137 return DenseFPElementsAttr::get(
138 VectorType::get(static_cast<float>(values
.size()), getF32Type()), values
);
141 DenseFPElementsAttr
Builder::getF64VectorAttr(ArrayRef
<double> values
) {
142 return DenseFPElementsAttr::get(
143 VectorType::get(static_cast<double>(values
.size()), getF64Type()),
147 DenseBoolArrayAttr
Builder::getDenseBoolArrayAttr(ArrayRef
<bool> values
) {
148 return DenseBoolArrayAttr::get(context
, values
);
151 DenseI8ArrayAttr
Builder::getDenseI8ArrayAttr(ArrayRef
<int8_t> values
) {
152 return DenseI8ArrayAttr::get(context
, values
);
155 DenseI16ArrayAttr
Builder::getDenseI16ArrayAttr(ArrayRef
<int16_t> values
) {
156 return DenseI16ArrayAttr::get(context
, values
);
159 DenseI32ArrayAttr
Builder::getDenseI32ArrayAttr(ArrayRef
<int32_t> values
) {
160 return DenseI32ArrayAttr::get(context
, values
);
163 DenseI64ArrayAttr
Builder::getDenseI64ArrayAttr(ArrayRef
<int64_t> values
) {
164 return DenseI64ArrayAttr::get(context
, values
);
167 DenseF32ArrayAttr
Builder::getDenseF32ArrayAttr(ArrayRef
<float> values
) {
168 return DenseF32ArrayAttr::get(context
, values
);
171 DenseF64ArrayAttr
Builder::getDenseF64ArrayAttr(ArrayRef
<double> values
) {
172 return DenseF64ArrayAttr::get(context
, values
);
175 DenseIntElementsAttr
Builder::getI32TensorAttr(ArrayRef
<int32_t> values
) {
176 return DenseIntElementsAttr::get(
177 RankedTensorType::get(static_cast<int64_t>(values
.size()),
182 DenseIntElementsAttr
Builder::getI64TensorAttr(ArrayRef
<int64_t> values
) {
183 return DenseIntElementsAttr::get(
184 RankedTensorType::get(static_cast<int64_t>(values
.size()),
189 DenseIntElementsAttr
Builder::getIndexTensorAttr(ArrayRef
<int64_t> values
) {
190 return DenseIntElementsAttr::get(
191 RankedTensorType::get(static_cast<int64_t>(values
.size()),
196 IntegerAttr
Builder::getI32IntegerAttr(int32_t value
) {
197 // The APInt always uses isSigned=true here because we accept the value
199 return IntegerAttr::get(getIntegerType(32),
200 APInt(32, value
, /*isSigned=*/true));
203 IntegerAttr
Builder::getSI32IntegerAttr(int32_t value
) {
204 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
205 APInt(32, value
, /*isSigned=*/true));
208 IntegerAttr
Builder::getUI32IntegerAttr(uint32_t value
) {
209 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
210 APInt(32, (uint64_t)value
, /*isSigned=*/false));
213 IntegerAttr
Builder::getI16IntegerAttr(int16_t value
) {
214 return IntegerAttr::get(getIntegerType(16), APInt(16, value
));
217 IntegerAttr
Builder::getI8IntegerAttr(int8_t value
) {
218 // The APInt always uses isSigned=true here because we accept the value
220 return IntegerAttr::get(getIntegerType(8),
221 APInt(8, value
, /*isSigned=*/true));
224 IntegerAttr
Builder::getIntegerAttr(Type type
, int64_t value
) {
226 return IntegerAttr::get(type
, APInt(64, value
));
227 // TODO: Avoid implicit trunc?
228 // See https://github.com/llvm/llvm-project/issues/112510.
229 return IntegerAttr::get(type
, APInt(type
.getIntOrFloatBitWidth(), value
,
230 type
.isSignedInteger(),
231 /*implicitTrunc=*/true));
234 IntegerAttr
Builder::getIntegerAttr(Type type
, const APInt
&value
) {
235 return IntegerAttr::get(type
, value
);
238 FloatAttr
Builder::getF64FloatAttr(double value
) {
239 return FloatAttr::get(getF64Type(), APFloat(value
));
242 FloatAttr
Builder::getF32FloatAttr(float value
) {
243 return FloatAttr::get(getF32Type(), APFloat(value
));
246 FloatAttr
Builder::getF16FloatAttr(float value
) {
247 return FloatAttr::get(getF16Type(), value
);
250 FloatAttr
Builder::getFloatAttr(Type type
, double value
) {
251 return FloatAttr::get(type
, value
);
254 FloatAttr
Builder::getFloatAttr(Type type
, const APFloat
&value
) {
255 return FloatAttr::get(type
, value
);
258 StringAttr
Builder::getStringAttr(const Twine
&bytes
) {
259 return StringAttr::get(context
, bytes
);
262 ArrayAttr
Builder::getArrayAttr(ArrayRef
<Attribute
> value
) {
263 return ArrayAttr::get(context
, value
);
266 ArrayAttr
Builder::getBoolArrayAttr(ArrayRef
<bool> values
) {
267 auto attrs
= llvm::map_to_vector
<8>(
268 values
, [this](bool v
) -> Attribute
{ return getBoolAttr(v
); });
269 return getArrayAttr(attrs
);
272 ArrayAttr
Builder::getI32ArrayAttr(ArrayRef
<int32_t> values
) {
273 auto attrs
= llvm::map_to_vector
<8>(
274 values
, [this](int32_t v
) -> Attribute
{ return getI32IntegerAttr(v
); });
275 return getArrayAttr(attrs
);
277 ArrayAttr
Builder::getI64ArrayAttr(ArrayRef
<int64_t> values
) {
278 auto attrs
= llvm::map_to_vector
<8>(
279 values
, [this](int64_t v
) -> Attribute
{ return getI64IntegerAttr(v
); });
280 return getArrayAttr(attrs
);
283 ArrayAttr
Builder::getIndexArrayAttr(ArrayRef
<int64_t> values
) {
284 auto attrs
= llvm::map_to_vector
<8>(values
, [this](int64_t v
) -> Attribute
{
285 return getIntegerAttr(IndexType::get(getContext()), v
);
287 return getArrayAttr(attrs
);
290 ArrayAttr
Builder::getF32ArrayAttr(ArrayRef
<float> values
) {
291 auto attrs
= llvm::map_to_vector
<8>(
292 values
, [this](float v
) -> Attribute
{ return getF32FloatAttr(v
); });
293 return getArrayAttr(attrs
);
296 ArrayAttr
Builder::getF64ArrayAttr(ArrayRef
<double> values
) {
297 auto attrs
= llvm::map_to_vector
<8>(
298 values
, [this](double v
) -> Attribute
{ return getF64FloatAttr(v
); });
299 return getArrayAttr(attrs
);
302 ArrayAttr
Builder::getStrArrayAttr(ArrayRef
<StringRef
> values
) {
303 auto attrs
= llvm::map_to_vector
<8>(
304 values
, [this](StringRef v
) -> Attribute
{ return getStringAttr(v
); });
305 return getArrayAttr(attrs
);
308 ArrayAttr
Builder::getTypeArrayAttr(TypeRange values
) {
309 auto attrs
= llvm::map_to_vector
<8>(
310 values
, [](Type v
) -> Attribute
{ return TypeAttr::get(v
); });
311 return getArrayAttr(attrs
);
314 ArrayAttr
Builder::getAffineMapArrayAttr(ArrayRef
<AffineMap
> values
) {
315 auto attrs
= llvm::map_to_vector
<8>(
316 values
, [](AffineMap v
) -> Attribute
{ return AffineMapAttr::get(v
); });
317 return getArrayAttr(attrs
);
320 TypedAttr
Builder::getZeroAttr(Type type
) {
321 if (llvm::isa
<FloatType
>(type
))
322 return getFloatAttr(type
, 0.0);
323 if (llvm::isa
<IndexType
>(type
))
324 return getIndexAttr(0);
325 if (llvm::dyn_cast
<IntegerType
>(type
))
326 return getIntegerAttr(type
,
327 APInt(llvm::cast
<IntegerType
>(type
).getWidth(), 0));
328 if (llvm::isa
<RankedTensorType
, VectorType
>(type
)) {
329 auto vtType
= llvm::cast
<ShapedType
>(type
);
330 auto element
= getZeroAttr(vtType
.getElementType());
333 return DenseElementsAttr::get(vtType
, element
);
338 TypedAttr
Builder::getOneAttr(Type type
) {
339 if (llvm::isa
<FloatType
>(type
))
340 return getFloatAttr(type
, 1.0);
341 if (llvm::isa
<IndexType
>(type
))
342 return getIndexAttr(1);
343 if (llvm::dyn_cast
<IntegerType
>(type
))
344 return getIntegerAttr(type
,
345 APInt(llvm::cast
<IntegerType
>(type
).getWidth(), 1));
346 if (llvm::isa
<RankedTensorType
, VectorType
>(type
)) {
347 auto vtType
= llvm::cast
<ShapedType
>(type
);
348 auto element
= getOneAttr(vtType
.getElementType());
351 return DenseElementsAttr::get(vtType
, element
);
356 //===----------------------------------------------------------------------===//
357 // Affine Expressions, Affine Maps, and Integer Sets.
358 //===----------------------------------------------------------------------===//
360 AffineExpr
Builder::getAffineDimExpr(unsigned position
) {
361 return mlir::getAffineDimExpr(position
, context
);
364 AffineExpr
Builder::getAffineSymbolExpr(unsigned position
) {
365 return mlir::getAffineSymbolExpr(position
, context
);
368 AffineExpr
Builder::getAffineConstantExpr(int64_t constant
) {
369 return mlir::getAffineConstantExpr(constant
, context
);
372 AffineMap
Builder::getEmptyAffineMap() { return AffineMap::get(context
); }
374 AffineMap
Builder::getConstantAffineMap(int64_t val
) {
375 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
376 getAffineConstantExpr(val
));
379 AffineMap
Builder::getDimIdentityMap() {
380 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
383 AffineMap
Builder::getMultiDimIdentityMap(unsigned rank
) {
384 SmallVector
<AffineExpr
, 4> dimExprs
;
385 dimExprs
.reserve(rank
);
386 for (unsigned i
= 0; i
< rank
; ++i
)
387 dimExprs
.push_back(getAffineDimExpr(i
));
388 return AffineMap::get(/*dimCount=*/rank
, /*symbolCount=*/0, dimExprs
,
392 AffineMap
Builder::getSymbolIdentityMap() {
393 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
394 getAffineSymbolExpr(0));
397 AffineMap
Builder::getSingleDimShiftAffineMap(int64_t shift
) {
398 // expr = d0 + shift.
399 auto expr
= getAffineDimExpr(0) + shift
;
400 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr
);
403 AffineMap
Builder::getShiftedAffineMap(AffineMap map
, int64_t shift
) {
404 SmallVector
<AffineExpr
, 4> shiftedResults
;
405 shiftedResults
.reserve(map
.getNumResults());
406 for (auto resultExpr
: map
.getResults())
407 shiftedResults
.push_back(resultExpr
+ shift
);
408 return AffineMap::get(map
.getNumDims(), map
.getNumSymbols(), shiftedResults
,
//===----------------------------------------------------------------------===//
// OpBuilder
//===----------------------------------------------------------------------===//
416 /// Insert the given operation at the current insertion point and return it.
417 Operation
*OpBuilder::insert(Operation
*op
) {
419 block
->getOperations().insert(insertPoint
, op
);
421 listener
->notifyOperationInserted(op
, /*previous=*/{});
426 Block
*OpBuilder::createBlock(Region
*parent
, Region::iterator insertPt
,
427 TypeRange argTypes
, ArrayRef
<Location
> locs
) {
428 assert(parent
&& "expected valid parent region");
429 assert(argTypes
.size() == locs
.size() && "argument location mismatch");
430 if (insertPt
== Region::iterator())
431 insertPt
= parent
->end();
433 Block
*b
= new Block();
434 b
->addArguments(argTypes
, locs
);
435 parent
->getBlocks().insert(insertPt
, b
);
436 setInsertionPointToEnd(b
);
439 listener
->notifyBlockInserted(b
, /*previous=*/nullptr, /*previousIt=*/{});
443 /// Add new block with 'argTypes' arguments and set the insertion point to the
444 /// end of it. The block is placed before 'insertBefore'.
445 Block
*OpBuilder::createBlock(Block
*insertBefore
, TypeRange argTypes
,
446 ArrayRef
<Location
> locs
) {
447 assert(insertBefore
&& "expected valid insertion block");
448 return createBlock(insertBefore
->getParent(), Region::iterator(insertBefore
),
452 /// Create an operation given the fields represented as an OperationState.
453 Operation
*OpBuilder::create(const OperationState
&state
) {
454 return insert(Operation::create(state
));
457 /// Creates an operation with the given fields.
458 Operation
*OpBuilder::create(Location loc
, StringAttr opName
,
459 ValueRange operands
, TypeRange types
,
460 ArrayRef
<NamedAttribute
> attributes
,
461 BlockRange successors
,
462 MutableArrayRef
<std::unique_ptr
<Region
>> regions
) {
463 OperationState
state(loc
, opName
, operands
, types
, attributes
, successors
,
465 return create(state
);
468 LogicalResult
OpBuilder::tryFold(Operation
*op
,
469 SmallVectorImpl
<Value
> &results
) {
470 assert(results
.empty() && "expected empty results");
471 ResultRange opResults
= op
->getResults();
473 results
.reserve(opResults
.size());
474 auto cleanupFailure
= [&] {
479 // If this operation is already a constant, there is nothing to do.
480 if (matchPattern(op
, m_Constant()))
481 return cleanupFailure();
483 // Try to fold the operation.
484 SmallVector
<OpFoldResult
, 4> foldResults
;
485 if (failed(op
->fold(foldResults
)))
486 return cleanupFailure();
488 // An in-place fold does not require generation of any constants.
489 if (foldResults
.empty())
492 // A temporary builder used for creating constants during folding.
493 OpBuilder
cstBuilder(context
);
494 SmallVector
<Operation
*, 1> generatedConstants
;
496 // Populate the results with the folded results.
497 Dialect
*dialect
= op
->getDialect();
498 for (auto [foldResult
, expectedType
] :
499 llvm::zip_equal(foldResults
, opResults
.getTypes())) {
501 // Normal values get pushed back directly.
502 if (auto value
= llvm::dyn_cast_if_present
<Value
>(foldResult
)) {
503 results
.push_back(value
);
507 // Otherwise, try to materialize a constant operation.
509 return cleanupFailure();
511 // Ask the dialect to materialize a constant operation for this value.
512 Attribute attr
= cast
<Attribute
>(foldResult
);
513 auto *constOp
= dialect
->materializeConstant(cstBuilder
, attr
, expectedType
,
516 // Erase any generated constants.
517 for (Operation
*cst
: generatedConstants
)
519 return cleanupFailure();
521 assert(matchPattern(constOp
, m_Constant()));
523 generatedConstants
.push_back(constOp
);
524 results
.push_back(constOp
->getResult(0));
527 // If we were successful, insert any generated constants.
528 for (Operation
*cst
: generatedConstants
)
534 /// Helper function that sends block insertion notifications for every block
535 /// that is directly nested in the given op.
536 static void notifyBlockInsertions(Operation
*op
,
537 OpBuilder::Listener
*listener
) {
538 for (Region
&r
: op
->getRegions())
539 for (Block
&b
: r
.getBlocks())
540 listener
->notifyBlockInserted(&b
, /*previous=*/nullptr,
544 Operation
*OpBuilder::clone(Operation
&op
, IRMapping
&mapper
) {
545 Operation
*newOp
= op
.clone(mapper
);
546 newOp
= insert(newOp
);
548 // The `insert` call above handles the notification for inserting `newOp`
549 // itself. But if `newOp` has any regions, we need to notify the listener
550 // about any ops that got inserted inside those regions as part of cloning.
552 // The `insert` call above notifies about op insertion, but not about block
554 notifyBlockInsertions(newOp
, listener
);
555 auto walkFn
= [&](Operation
*walkedOp
) {
556 listener
->notifyOperationInserted(walkedOp
, /*previous=*/{});
557 notifyBlockInsertions(walkedOp
, listener
);
559 for (Region
®ion
: newOp
->getRegions())
560 region
.walk
<WalkOrder::PreOrder
>(walkFn
);
566 Operation
*OpBuilder::clone(Operation
&op
) {
568 return clone(op
, mapper
);
571 void OpBuilder::cloneRegionBefore(Region
®ion
, Region
&parent
,
572 Region::iterator before
, IRMapping
&mapping
) {
573 region
.cloneInto(&parent
, before
, mapping
);
575 // Fast path: If no listener is attached, there is no more work to do.
579 // Notify about op/block insertion.
580 for (auto it
= mapping
.lookup(®ion
.front())->getIterator(); it
!= before
;
582 listener
->notifyBlockInserted(&*it
, /*previous=*/nullptr,
584 it
->walk
<WalkOrder::PreOrder
>([&](Operation
*walkedOp
) {
585 listener
->notifyOperationInserted(walkedOp
, /*previous=*/{});
586 notifyBlockInsertions(walkedOp
, listener
);
591 void OpBuilder::cloneRegionBefore(Region
®ion
, Region
&parent
,
592 Region::iterator before
) {
594 cloneRegionBefore(region
, parent
, before
, mapping
);
597 void OpBuilder::cloneRegionBefore(Region
®ion
, Block
*before
) {
598 cloneRegionBefore(region
, *before
->getParent(), before
->getIterator());