//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/ADT/SmallVectorExtras.h"
19 #include "llvm/Support/raw_ostream.h"
//===----------------------------------------------------------------------===//
// Locations.
//===----------------------------------------------------------------------===//
27 Location
Builder::getUnknownLoc() { return UnknownLoc::get(context
); }
29 Location
Builder::getFusedLoc(ArrayRef
<Location
> locs
, Attribute metadata
) {
30 return FusedLoc::get(locs
, metadata
, context
);
//===----------------------------------------------------------------------===//
// Types.
//===----------------------------------------------------------------------===//
37 FloatType
Builder::getFloat8E5M2Type() {
38 return FloatType::getFloat8E5M2(context
);
41 FloatType
Builder::getFloat8E4M3FNType() {
42 return FloatType::getFloat8E4M3FN(context
);
45 FloatType
Builder::getFloat8E5M2FNUZType() {
46 return FloatType::getFloat8E5M2FNUZ(context
);
49 FloatType
Builder::getFloat8E4M3FNUZType() {
50 return FloatType::getFloat8E4M3FNUZ(context
);
53 FloatType
Builder::getFloat8E4M3B11FNUZType() {
54 return FloatType::getFloat8E4M3B11FNUZ(context
);
57 FloatType
Builder::getBF16Type() { return FloatType::getBF16(context
); }
59 FloatType
Builder::getF16Type() { return FloatType::getF16(context
); }
61 FloatType
Builder::getTF32Type() { return FloatType::getTF32(context
); }
63 FloatType
Builder::getF32Type() { return FloatType::getF32(context
); }
65 FloatType
Builder::getF64Type() { return FloatType::getF64(context
); }
67 FloatType
Builder::getF80Type() { return FloatType::getF80(context
); }
69 FloatType
Builder::getF128Type() { return FloatType::getF128(context
); }
71 IndexType
Builder::getIndexType() { return IndexType::get(context
); }
73 IntegerType
Builder::getI1Type() { return IntegerType::get(context
, 1); }
75 IntegerType
Builder::getI2Type() { return IntegerType::get(context
, 2); }
77 IntegerType
Builder::getI4Type() { return IntegerType::get(context
, 4); }
79 IntegerType
Builder::getI8Type() { return IntegerType::get(context
, 8); }
81 IntegerType
Builder::getI16Type() { return IntegerType::get(context
, 16); }
83 IntegerType
Builder::getI32Type() { return IntegerType::get(context
, 32); }
85 IntegerType
Builder::getI64Type() { return IntegerType::get(context
, 64); }
87 IntegerType
Builder::getIntegerType(unsigned width
) {
88 return IntegerType::get(context
, width
);
91 IntegerType
Builder::getIntegerType(unsigned width
, bool isSigned
) {
92 return IntegerType::get(
93 context
, width
, isSigned
? IntegerType::Signed
: IntegerType::Unsigned
);
96 FunctionType
Builder::getFunctionType(TypeRange inputs
, TypeRange results
) {
97 return FunctionType::get(context
, inputs
, results
);
100 TupleType
Builder::getTupleType(TypeRange elementTypes
) {
101 return TupleType::get(context
, elementTypes
);
104 NoneType
Builder::getNoneType() { return NoneType::get(context
); }
//===----------------------------------------------------------------------===//
// Attributes.
//===----------------------------------------------------------------------===//
110 NamedAttribute
Builder::getNamedAttr(StringRef name
, Attribute val
) {
111 return NamedAttribute(getStringAttr(name
), val
);
114 UnitAttr
Builder::getUnitAttr() { return UnitAttr::get(context
); }
116 BoolAttr
Builder::getBoolAttr(bool value
) {
117 return BoolAttr::get(context
, value
);
120 DictionaryAttr
Builder::getDictionaryAttr(ArrayRef
<NamedAttribute
> value
) {
121 return DictionaryAttr::get(context
, value
);
124 IntegerAttr
Builder::getIndexAttr(int64_t value
) {
125 return IntegerAttr::get(getIndexType(), APInt(64, value
));
128 IntegerAttr
Builder::getI64IntegerAttr(int64_t value
) {
129 return IntegerAttr::get(getIntegerType(64), APInt(64, value
));
132 DenseIntElementsAttr
Builder::getBoolVectorAttr(ArrayRef
<bool> values
) {
133 return DenseIntElementsAttr::get(
134 VectorType::get(static_cast<int64_t>(values
.size()), getI1Type()),
138 DenseIntElementsAttr
Builder::getI32VectorAttr(ArrayRef
<int32_t> values
) {
139 return DenseIntElementsAttr::get(
140 VectorType::get(static_cast<int64_t>(values
.size()), getIntegerType(32)),
144 DenseIntElementsAttr
Builder::getI64VectorAttr(ArrayRef
<int64_t> values
) {
145 return DenseIntElementsAttr::get(
146 VectorType::get(static_cast<int64_t>(values
.size()), getIntegerType(64)),
150 DenseIntElementsAttr
Builder::getIndexVectorAttr(ArrayRef
<int64_t> values
) {
151 return DenseIntElementsAttr::get(
152 VectorType::get(static_cast<int64_t>(values
.size()), getIndexType()),
156 DenseFPElementsAttr
Builder::getF32VectorAttr(ArrayRef
<float> values
) {
157 return DenseFPElementsAttr::get(
158 VectorType::get(static_cast<float>(values
.size()), getF32Type()), values
);
161 DenseFPElementsAttr
Builder::getF64VectorAttr(ArrayRef
<double> values
) {
162 return DenseFPElementsAttr::get(
163 VectorType::get(static_cast<double>(values
.size()), getF64Type()),
167 DenseBoolArrayAttr
Builder::getDenseBoolArrayAttr(ArrayRef
<bool> values
) {
168 return DenseBoolArrayAttr::get(context
, values
);
171 DenseI8ArrayAttr
Builder::getDenseI8ArrayAttr(ArrayRef
<int8_t> values
) {
172 return DenseI8ArrayAttr::get(context
, values
);
175 DenseI16ArrayAttr
Builder::getDenseI16ArrayAttr(ArrayRef
<int16_t> values
) {
176 return DenseI16ArrayAttr::get(context
, values
);
179 DenseI32ArrayAttr
Builder::getDenseI32ArrayAttr(ArrayRef
<int32_t> values
) {
180 return DenseI32ArrayAttr::get(context
, values
);
183 DenseI64ArrayAttr
Builder::getDenseI64ArrayAttr(ArrayRef
<int64_t> values
) {
184 return DenseI64ArrayAttr::get(context
, values
);
187 DenseF32ArrayAttr
Builder::getDenseF32ArrayAttr(ArrayRef
<float> values
) {
188 return DenseF32ArrayAttr::get(context
, values
);
191 DenseF64ArrayAttr
Builder::getDenseF64ArrayAttr(ArrayRef
<double> values
) {
192 return DenseF64ArrayAttr::get(context
, values
);
195 DenseIntElementsAttr
Builder::getI32TensorAttr(ArrayRef
<int32_t> values
) {
196 return DenseIntElementsAttr::get(
197 RankedTensorType::get(static_cast<int64_t>(values
.size()),
202 DenseIntElementsAttr
Builder::getI64TensorAttr(ArrayRef
<int64_t> values
) {
203 return DenseIntElementsAttr::get(
204 RankedTensorType::get(static_cast<int64_t>(values
.size()),
209 DenseIntElementsAttr
Builder::getIndexTensorAttr(ArrayRef
<int64_t> values
) {
210 return DenseIntElementsAttr::get(
211 RankedTensorType::get(static_cast<int64_t>(values
.size()),
216 IntegerAttr
Builder::getI32IntegerAttr(int32_t value
) {
217 return IntegerAttr::get(getIntegerType(32), APInt(32, value
));
220 IntegerAttr
Builder::getSI32IntegerAttr(int32_t value
) {
221 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
222 APInt(32, value
, /*isSigned=*/true));
225 IntegerAttr
Builder::getUI32IntegerAttr(uint32_t value
) {
226 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
227 APInt(32, (uint64_t)value
, /*isSigned=*/false));
230 IntegerAttr
Builder::getI16IntegerAttr(int16_t value
) {
231 return IntegerAttr::get(getIntegerType(16), APInt(16, value
));
234 IntegerAttr
Builder::getI8IntegerAttr(int8_t value
) {
235 return IntegerAttr::get(getIntegerType(8), APInt(8, value
));
238 IntegerAttr
Builder::getIntegerAttr(Type type
, int64_t value
) {
240 return IntegerAttr::get(type
, APInt(64, value
));
241 return IntegerAttr::get(
242 type
, APInt(type
.getIntOrFloatBitWidth(), value
, type
.isSignedInteger()));
245 IntegerAttr
Builder::getIntegerAttr(Type type
, const APInt
&value
) {
246 return IntegerAttr::get(type
, value
);
249 FloatAttr
Builder::getF64FloatAttr(double value
) {
250 return FloatAttr::get(getF64Type(), APFloat(value
));
253 FloatAttr
Builder::getF32FloatAttr(float value
) {
254 return FloatAttr::get(getF32Type(), APFloat(value
));
257 FloatAttr
Builder::getF16FloatAttr(float value
) {
258 return FloatAttr::get(getF16Type(), value
);
261 FloatAttr
Builder::getFloatAttr(Type type
, double value
) {
262 return FloatAttr::get(type
, value
);
265 FloatAttr
Builder::getFloatAttr(Type type
, const APFloat
&value
) {
266 return FloatAttr::get(type
, value
);
269 StringAttr
Builder::getStringAttr(const Twine
&bytes
) {
270 return StringAttr::get(context
, bytes
);
273 ArrayAttr
Builder::getArrayAttr(ArrayRef
<Attribute
> value
) {
274 return ArrayAttr::get(context
, value
);
277 ArrayAttr
Builder::getBoolArrayAttr(ArrayRef
<bool> values
) {
278 auto attrs
= llvm::map_to_vector
<8>(
279 values
, [this](bool v
) -> Attribute
{ return getBoolAttr(v
); });
280 return getArrayAttr(attrs
);
283 ArrayAttr
Builder::getI32ArrayAttr(ArrayRef
<int32_t> values
) {
284 auto attrs
= llvm::map_to_vector
<8>(
285 values
, [this](int32_t v
) -> Attribute
{ return getI32IntegerAttr(v
); });
286 return getArrayAttr(attrs
);
288 ArrayAttr
Builder::getI64ArrayAttr(ArrayRef
<int64_t> values
) {
289 auto attrs
= llvm::map_to_vector
<8>(
290 values
, [this](int64_t v
) -> Attribute
{ return getI64IntegerAttr(v
); });
291 return getArrayAttr(attrs
);
294 ArrayAttr
Builder::getIndexArrayAttr(ArrayRef
<int64_t> values
) {
295 auto attrs
= llvm::map_to_vector
<8>(values
, [this](int64_t v
) -> Attribute
{
296 return getIntegerAttr(IndexType::get(getContext()), v
);
298 return getArrayAttr(attrs
);
301 ArrayAttr
Builder::getF32ArrayAttr(ArrayRef
<float> values
) {
302 auto attrs
= llvm::map_to_vector
<8>(
303 values
, [this](float v
) -> Attribute
{ return getF32FloatAttr(v
); });
304 return getArrayAttr(attrs
);
307 ArrayAttr
Builder::getF64ArrayAttr(ArrayRef
<double> values
) {
308 auto attrs
= llvm::map_to_vector
<8>(
309 values
, [this](double v
) -> Attribute
{ return getF64FloatAttr(v
); });
310 return getArrayAttr(attrs
);
313 ArrayAttr
Builder::getStrArrayAttr(ArrayRef
<StringRef
> values
) {
314 auto attrs
= llvm::map_to_vector
<8>(
315 values
, [this](StringRef v
) -> Attribute
{ return getStringAttr(v
); });
316 return getArrayAttr(attrs
);
319 ArrayAttr
Builder::getTypeArrayAttr(TypeRange values
) {
320 auto attrs
= llvm::map_to_vector
<8>(
321 values
, [](Type v
) -> Attribute
{ return TypeAttr::get(v
); });
322 return getArrayAttr(attrs
);
325 ArrayAttr
Builder::getAffineMapArrayAttr(ArrayRef
<AffineMap
> values
) {
326 auto attrs
= llvm::map_to_vector
<8>(
327 values
, [](AffineMap v
) -> Attribute
{ return AffineMapAttr::get(v
); });
328 return getArrayAttr(attrs
);
331 TypedAttr
Builder::getZeroAttr(Type type
) {
332 if (llvm::isa
<FloatType
>(type
))
333 return getFloatAttr(type
, 0.0);
334 if (llvm::isa
<IndexType
>(type
))
335 return getIndexAttr(0);
336 if (llvm::dyn_cast
<IntegerType
>(type
))
337 return getIntegerAttr(type
,
338 APInt(llvm::cast
<IntegerType
>(type
).getWidth(), 0));
339 if (llvm::isa
<RankedTensorType
, VectorType
>(type
)) {
340 auto vtType
= llvm::cast
<ShapedType
>(type
);
341 auto element
= getZeroAttr(vtType
.getElementType());
344 return DenseElementsAttr::get(vtType
, element
);
//===----------------------------------------------------------------------===//
// Affine Expressions, Affine Maps, and Integer Sets.
//===----------------------------------------------------------------------===//
353 AffineExpr
Builder::getAffineDimExpr(unsigned position
) {
354 return mlir::getAffineDimExpr(position
, context
);
357 AffineExpr
Builder::getAffineSymbolExpr(unsigned position
) {
358 return mlir::getAffineSymbolExpr(position
, context
);
361 AffineExpr
Builder::getAffineConstantExpr(int64_t constant
) {
362 return mlir::getAffineConstantExpr(constant
, context
);
365 AffineMap
Builder::getEmptyAffineMap() { return AffineMap::get(context
); }
367 AffineMap
Builder::getConstantAffineMap(int64_t val
) {
368 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
369 getAffineConstantExpr(val
));
372 AffineMap
Builder::getDimIdentityMap() {
373 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
376 AffineMap
Builder::getMultiDimIdentityMap(unsigned rank
) {
377 SmallVector
<AffineExpr
, 4> dimExprs
;
378 dimExprs
.reserve(rank
);
379 for (unsigned i
= 0; i
< rank
; ++i
)
380 dimExprs
.push_back(getAffineDimExpr(i
));
381 return AffineMap::get(/*dimCount=*/rank
, /*symbolCount=*/0, dimExprs
,
385 AffineMap
Builder::getSymbolIdentityMap() {
386 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
387 getAffineSymbolExpr(0));
390 AffineMap
Builder::getSingleDimShiftAffineMap(int64_t shift
) {
391 // expr = d0 + shift.
392 auto expr
= getAffineDimExpr(0) + shift
;
393 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr
);
396 AffineMap
Builder::getShiftedAffineMap(AffineMap map
, int64_t shift
) {
397 SmallVector
<AffineExpr
, 4> shiftedResults
;
398 shiftedResults
.reserve(map
.getNumResults());
399 for (auto resultExpr
: map
.getResults())
400 shiftedResults
.push_back(resultExpr
+ shift
);
401 return AffineMap::get(map
.getNumDims(), map
.getNumSymbols(), shiftedResults
,
//===----------------------------------------------------------------------===//
// OpBuilder
//===----------------------------------------------------------------------===//
409 /// Insert the given operation at the current insertion point and return it.
410 Operation
*OpBuilder::insert(Operation
*op
) {
412 block
->getOperations().insert(insertPoint
, op
);
415 listener
->notifyOperationInserted(op
);
419 Block
*OpBuilder::createBlock(Region
*parent
, Region::iterator insertPt
,
420 TypeRange argTypes
, ArrayRef
<Location
> locs
) {
421 assert(parent
&& "expected valid parent region");
422 assert(argTypes
.size() == locs
.size() && "argument location mismatch");
423 if (insertPt
== Region::iterator())
424 insertPt
= parent
->end();
426 Block
*b
= new Block();
427 b
->addArguments(argTypes
, locs
);
428 parent
->getBlocks().insert(insertPt
, b
);
429 setInsertionPointToEnd(b
);
432 listener
->notifyBlockCreated(b
);
436 /// Add new block with 'argTypes' arguments and set the insertion point to the
437 /// end of it. The block is placed before 'insertBefore'.
438 Block
*OpBuilder::createBlock(Block
*insertBefore
, TypeRange argTypes
,
439 ArrayRef
<Location
> locs
) {
440 assert(insertBefore
&& "expected valid insertion block");
441 return createBlock(insertBefore
->getParent(), Region::iterator(insertBefore
),
445 /// Create an operation given the fields represented as an OperationState.
446 Operation
*OpBuilder::create(const OperationState
&state
) {
447 return insert(Operation::create(state
));
450 /// Creates an operation with the given fields.
451 Operation
*OpBuilder::create(Location loc
, StringAttr opName
,
452 ValueRange operands
, TypeRange types
,
453 ArrayRef
<NamedAttribute
> attributes
,
454 BlockRange successors
,
455 MutableArrayRef
<std::unique_ptr
<Region
>> regions
) {
456 OperationState
state(loc
, opName
, operands
, types
, attributes
, successors
,
458 return create(state
);
461 /// Attempts to fold the given operation and places new results within
462 /// 'results'. Returns success if the operation was folded, failure otherwise.
463 /// Note: This function does not erase the operation on a successful fold.
464 LogicalResult
OpBuilder::tryFold(Operation
*op
,
465 SmallVectorImpl
<Value
> &results
) {
466 ResultRange opResults
= op
->getResults();
468 results
.reserve(opResults
.size());
469 auto cleanupFailure
= [&] {
470 results
.assign(opResults
.begin(), opResults
.end());
474 // If this operation is already a constant, there is nothing to do.
475 if (matchPattern(op
, m_Constant()))
476 return cleanupFailure();
478 // Try to fold the operation.
479 SmallVector
<OpFoldResult
, 4> foldResults
;
480 if (failed(op
->fold(foldResults
)) || foldResults
.empty())
481 return cleanupFailure();
483 // A temporary builder used for creating constants during folding.
484 OpBuilder
cstBuilder(context
);
485 SmallVector
<Operation
*, 1> generatedConstants
;
487 // Populate the results with the folded results.
488 Dialect
*dialect
= op
->getDialect();
489 for (auto it
: llvm::zip(foldResults
, opResults
.getTypes())) {
490 Type expectedType
= std::get
<1>(it
);
492 // Normal values get pushed back directly.
493 if (auto value
= llvm::dyn_cast_if_present
<Value
>(std::get
<0>(it
))) {
494 if (value
.getType() != expectedType
)
495 return cleanupFailure();
497 results
.push_back(value
);
501 // Otherwise, try to materialize a constant operation.
503 return cleanupFailure();
505 // Ask the dialect to materialize a constant operation for this value.
506 Attribute attr
= std::get
<0>(it
).get
<Attribute
>();
507 auto *constOp
= dialect
->materializeConstant(cstBuilder
, attr
, expectedType
,
510 // Erase any generated constants.
511 for (Operation
*cst
: generatedConstants
)
513 return cleanupFailure();
515 assert(matchPattern(constOp
, m_Constant()));
517 generatedConstants
.push_back(constOp
);
518 results
.push_back(constOp
->getResult(0));
521 // If we were successful, insert any generated constants.
522 for (Operation
*cst
: generatedConstants
)
528 Operation
*OpBuilder::clone(Operation
&op
, IRMapping
&mapper
) {
529 Operation
*newOp
= op
.clone(mapper
);
530 // The `insert` call below handles the notification for inserting `newOp`
531 // itself. But if `newOp` has any regions, we need to notify the listener
532 // about any ops that got inserted inside those regions as part of cloning.
534 auto walkFn
= [&](Operation
*walkedOp
) {
535 listener
->notifyOperationInserted(walkedOp
);
537 for (Region
®ion
: newOp
->getRegions())
540 return insert(newOp
);
543 Operation
*OpBuilder::clone(Operation
&op
) {
545 return clone(op
, mapper
);