//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::linalg;
/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}
//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;
/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}
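
// For illustration (hypothetical types, not tied to a specific op): given
// inputs tensor<4x8xf32> and tensor<8x16xf32> and an output tensor<4x16xf32>,
// the block created above receives three f32 arguments, and `regionBuilder`
// is expected to emit the scalar payload (e.g. arith.mulf/arith.addf for a
// matmul) terminated by a linalg.yield of the computed value.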
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}
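
// Note: when no explicit `resultTensorTypes` are supplied, only the ranked
// tensor outputs become op results; memref outputs contribute none, matching
// destination-passing style on buffers.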
static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
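
// For reference, the clause parsed above covers the operand portion of forms
// like the following (illustrative example, types chosen arbitrarily):
//
//   %0 = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x4xf32>)
//                      outs(%c : tensor<4x4xf32>) -> tensor<4x4xf32>
//
// The trailing `-> ...` result list is handled separately by
// parseNamedStructuredOpResults.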
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}
static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//
namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }
  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }
  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }
  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }
  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
635 //===----------------------------------------------------------------------===//
639 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
641 /// For such op chains, we can create new linalg.fill ops with the result
642 /// type of the tensor.expand/collapse_shape op.
643 template <typename TensorReshapeOp
>
644 struct FoldFillWithTensorReshape
: OpRewritePattern
<TensorReshapeOp
> {
645 using OpRewritePattern
<TensorReshapeOp
>::OpRewritePattern
;
646 LogicalResult
matchAndRewrite(TensorReshapeOp reshapeOp
,
647 PatternRewriter
&rewriter
) const override
{
648 auto oldFill
= reshapeOp
.getSrc().template getDefiningOp
<FillOp
>();
652 Location loc
= oldFill
.getLoc();
653 TensorReshapeOp newInit
;
654 if constexpr (std::is_same
<TensorReshapeOp
, tensor::ExpandShapeOp
>::value
) {
656 newInit
= rewriter
.create
<TensorReshapeOp
>(
657 loc
, reshapeOp
.getResultType(), oldFill
.output(),
658 reshapeOp
.getReassociation(), reshapeOp
.getOutputShape(),
659 reshapeOp
.getStaticOutputShape());
661 newInit
= rewriter
.create
<TensorReshapeOp
>(loc
, reshapeOp
.getResultType(),
663 reshapeOp
.getReassociation());
665 rewriter
.replaceOpWithNewOp
<FillOp
>(reshapeOp
, ValueRange
{oldFill
.value()},
666 ValueRange
{newInit
});
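
// Illustrative IR-level example of the fold above (reassociation/output_shape
// operands abbreviated):
//
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<16xf32>)
//   %e = tensor.expand_shape %f [[0, 1]] ... : tensor<16xf32> into tensor<4x4xf32>
//
// becomes a tensor.expand_shape of %init followed by a linalg.fill that
// directly produces tensor<4x4xf32>.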
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
/// Fold tensor.extract(linalg.fill(<input>)) into <input>
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if tensor input of tensor.extract op is the result of a linalg.fill
    // op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get scalar input operand of linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace tensor.extract op with scalar value used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};
/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};

} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//
static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}
void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
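
// For reference, the custom form parsed above looks like the following
// (illustrative example):
//
//   linalg.generic {indexing_maps = [#map, #map],
//                   iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>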
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }
namespace {

/// Remove any linalg operation (on tensors) that is just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains only a yield operation, with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//
static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}
void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
// Retrieve the operation from the body, if it is the only one (except
// yield) and if it gets the same amount of arguments as the body does.
// If initFirst flag is enabled, we check that init takes the first position in
// operands of payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // Check that init is the first payload operand.
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // Check the remaining operands.
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
  SmallVector<StringRef> elidedAttrs;
  std::string attrToElide;
  p << " { " << payloadOp->getName().getStringRef();
  for (const auto &attr : payloadOp->getAttrs()) {
    auto fastAttr =
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
}
void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks that the number of `inputs` matches the arity of the `mapper`
  // region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
<utils::IteratorType
> MapOp::getIteratorTypesArray() {
1576 int64_t rank
= getInit().getType().getRank();
1577 return SmallVector
<utils::IteratorType
>(rank
, utils::IteratorType::parallel
);
1580 ArrayAttr
MapOp::getIndexingMaps() {
1581 Builder
builder(getContext());
1582 int64_t rank
= getInit().getType().getRank();
1583 int64_t numIndexingMaps
= getOperands().size();
1584 return builder
.getAffineMapArrayAttr(SmallVector
<AffineMap
>(
1585 numIndexingMaps
, builder
.getMultiDimIdentityMap(rank
)));
1588 void MapOp::getEffects(
1589 SmallVectorImpl
<SideEffects::EffectInstance
<MemoryEffects::Effect
>>
1591 getGenericEffectsImpl(effects
, cast
<LinalgOp
>(getOperation()));
1594 Speculation::Speculatability
MapOp::getSpeculatability() {
1595 return getGenericSpeculatabilityImpl(cast
<LinalgOp
>(getOperation()));
//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}
SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}

void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability ReduceOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
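// Illustrative example (hypothetical shapes, not from the original source):
// for a rank-3 input reduced over dimensions = [1], getIndexingMaps above
// yields
//   input: (d0, d1, d2) -> (d0, d1, d2)
//   init:  (d0, d1, d2) -> (d0, d2)
// and getIteratorTypesArray yields [parallel, reduction, parallel].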
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}

void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});

  // Print region if the payload op was not detected.
  if (!payloadOp) {
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
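// Illustrative note (hypothetical values, not from the original source): with
// a recognized payload op the printer above produces the short form, e.g.:
//
//   %reduced = linalg.reduce { arith.addf }
//                ins(%input : tensor<16x32xf32>)
//                outs(%init : tensor<16xf32>)
//                dimensions = [1]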
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                Region &region, ValueRange inputs,
                                ValueRange outputs) {
  buildGenericRegion(builder, loc, region, inputs, outputs,
                     [](OpBuilder &b, Location loc, ValueRange args) {
                       b.create<linalg::YieldOp>(loc, args[0]);
                     });
}
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getPermutationAttrName(result.name), permutation);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
        attributes);
}
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}

void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
}
void TransposeOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
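// Illustrative example (hypothetical shapes, not from the original source):
//
//   %transposed = linalg.transpose
//                   ins(%input : tensor<16x64xf32>)
//                   outs(%init : tensor<64x16xf32>)
//                   permutation = [1, 0]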
LogicalResult TransposeOp::verify() {
  ArrayRef<int64_t> permutationRef = getPermutation();

  if (!isPermutationVector(permutationRef))
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
    }
  }

  return success();
}
SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {inversePermutation(AffineMap::getPermutationMap(
           llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}

void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability TransposeOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
  // Only the tensor type is supported.
  if (!isa<TensorType>(getInput().getType()))
    return failure();

  // Single dimension transpose.
  if (getPermutation().size() == 0) {
    result.push_back(getInput());
    return success();
  }
  // Identity permutation.
  if (isIdentityPermutation(getPermutation())) {
    result.push_back(getInput());
    return success();
  }

  return failure();
}
/// Fold transpose with transpose.
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
    if (!defTransposeOp)
      return failure();
    ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();
    SmallVector<int64_t> foldedPerms;
    foldedPerms.reserve(perms.size());
    for (int64_t perm : perms)
      foldedPerms.push_back(defPerms[perm]);

    rewriter.replaceOpWithNewOp<TransposeOp>(
        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
        foldedPerms);
    return success();
  }
};
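// Illustrative example (hypothetical permutations, not from the original
// source): for transpose(transpose(%x, [2, 0, 1]), [1, 2, 0]) the composition
// above gives foldedPerms[i] = defPerms[perms[i]] = [0, 1, 2], i.e. the two
// transposes cancel and the pair becomes a single identity transpose of %x.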
/// This pattern canonicalizes transpose by swapping the order of
/// broadcast and transpose:
///   transpose(broadcast(input)) -> broadcast(transpose(input))
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    if (!input.hasOneUse() || !broadcastOp)
      return failure();

    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();

    // Get new perms and new dimensions.
    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
    SmallVector<int64_t> resultDimensions;
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);

    // Create transpose result.
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    MLIRContext *ctx = transposeOp.getContext();
    SmallVector<OpFoldResult> dims;
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
                           ->getResult(0));
      } else {
        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
                                        broadcastInputTy.getDimSize(i)));
      }
    }
    SmallVector<OpFoldResult> transposeResultShapes =
        applyPermutation(dims, resultPerms);
    Value transposeInit = rewriter.create<tensor::EmptyOp>(
        transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());

    // Create broadcast(transpose(input)).
    Value transposeResult =
        rewriter
            .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
                                 resultPerms)
            ->getResult(0);
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
    return success();
  }
};
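// Illustrative example (hypothetical shapes, not from the original source):
// with
//   %b = linalg.broadcast ins(%x : tensor<16x32xf32>)
//          outs(%i0 : tensor<16x32x64xf32>) dimensions = [2]
//   %t = linalg.transpose ins(%b : tensor<16x32x64xf32>)
//          outs(%i1 : tensor<64x16x32xf32>) permutation = [2, 0, 1]
// the pattern above rewrites to a transpose of %x with permutation [0, 1]
// followed by a broadcast with dimensions = [0], moving the transpose onto
// the smaller, pre-broadcast tensor.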
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
}
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
        attributes);
}
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
}
void BroadcastOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}
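// Illustrative example (hypothetical shapes, not from the original source):
//
//   %bcast = linalg.broadcast
//              ins(%input : tensor<16xf32>)
//              outs(%init : tensor<16x64xf32>)
//              dimensions = [1]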
LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimension is mapped from the input. Init and input dims should
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}
SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr BroadcastOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
       builder.getMultiDimIdentityMap(rank)});
}

void BroadcastOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability BroadcastOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

void linalg::YieldOp::print(OpAsmPrinter &p) {
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
}

ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  return failure(parser.parseOperandList(opInfo) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
                 parser.resolveOperands(opInfo, types, loc, result.operands));
}
// Check that the operand number and types match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}

LogicalResult linalg::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
    return verifyYield(*this, linalgOp);

  return emitOpError("expected parent op with LinalgOp interface");
}
//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//

LogicalResult IndexOp::verify() {
  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
  if (!linalgOp)
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
  return success();
}
/////// Operations corresponding to library calls defined with Tablegen ////////

#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
                                             unsigned rank,
                                             MLIRContext *context) {
  if (maybeMap)
    return *maybeMap;
  if (rank == 0)
    return AffineMap::get(context);
  return AffineMap::getMultiDimIdentityMap(rank, context);
}

SmallVector<AffineExpr, 4>
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                 MLIRContext *context) {
  SmallVector<AffineExpr, 4> res;
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}

SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                ArrayRef<AffineExpr> b) {
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    if (failed(appendMangledType(ss, vec.getElementType())))
      return failure();
    return success();
  }
  if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
    return success();
  }
  return failure();
}

std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  std::string fun = "";
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  }
  name.reserve(128);
  std::replace(name.begin(), name.end(), '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  for (Type t : op->getOperandTypes()) {
    if (failed(appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
  std::string res = ss.str();
  res.pop_back();
  return res;
}
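// Illustrative example (hypothetical op, not from the original source): a
// linalg.matmul whose three operands are memref<?x?xf32> would mangle each
// operand as "viewsxsxf32" ("sx" per dynamic dim, then the element type),
// producing a library call name like
// "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32".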
//===----------------------------------------------------------------------===//
// Canonicalizers and Folders.
//===----------------------------------------------------------------------===//

struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases we see memref<...x0x...>.
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // The cast can be in a conditionally reachable region, in which case
    // folding will generate invalid code. Only conservatively fold ops in the
    // same block for now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = rewriter.create<tensor::CastOp>(
        loc, resultValue.getType(), newOp->getResult(resultNumber));

    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};
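// Illustrative example (hypothetical IR, not from the original source): the
// pattern above rewrites
//
//   %1 = linalg.generic ... outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x8xf32>
//
// into a generic whose init is first cast to the static type:
//
//   %c = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
//   %1 = linalg.generic ... outs(%c : tensor<4x8xf32>) -> tensor<4x8xf32>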
/// For each operand in `operands`, this function maps the static sizes of
/// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // `tensor.cast` operation and source of the cast operation has a static
    // shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}
/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types are stored in `resultTypes`. If `opOperand` requires no
/// change then `changeNeeded` is false and the same operand is added to the
/// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If operand is updated with new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
/// Static shapes for the operands can be inferred if any one of the operands
/// has a static shape. This can be done by referring to the affine dim
/// expressions for the operand.
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();

    // Maps affine dim expressions to the static size of that dimension.
    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
    Location loc = linalgOp.getLoc();

    // For each of the affine dim expression, check if the size is known. If
    // known add that in the map.
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }

    // If the generic op has all the required static information, no
    // canonicalization needed.
    if (!changeNeeded)
      return failure();

    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    SmallVector<Value> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      Type newType = newResult.getType();
      Type oldType = oldResult.getType();
      replacements.push_back(
          (newType != oldType)
              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
              : newResult);
    }
    rewriter.replaceOp(linalgOp, replacements);

    return success();
  }
};
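// Illustrative example (hypothetical IR, not from the original source): given
//
//   %0 = linalg.add ins(%a, %b : tensor<4x8xf32>, tensor<?x?xf32>)
//                   outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
//
// the pattern above infers the dynamic dimensions from %a's static type,
// inserts tensor.cast ops so the op computes on tensor<4x8xf32> throughout,
// and casts the result back to tensor<?x?xf32> for the original uses.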
// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.
//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

LogicalResult SoftmaxOp::verify() {
  ShapedType inputType = getInputOperandType();
  ShapedType outputType = getOutputOperandType();

  ArrayRef<int64_t> inputShape = inputType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(inputShape, outputShape)))
    return emitOpError("incompatible output shape");

  int64_t inputRank = getInputOperandRank();
  int64_t dimension = getDimension();
  if ((dimension < 0) || (dimension >= inputRank))
    return emitOpError("incorrect dimension specified");

  return success();
}
SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getInputOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getInput();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}
SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}
FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputOperandRank();
  auto oneAttr = builder.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> strides(rank, oneAttr);
  SmallVector<Value> tiledOperands;
  Operation *inputSlice =
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  if (!inputSlice) {
    return emitOpError("failed to compute input slice");
  }
  tiledOperands.emplace_back(inputSlice->getResult(0));
  Operation *outputSlice =
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
  if (!outputSlice) {
    return emitOpError("failed to compute output slice");
  }
  tiledOperands.emplace_back(outputSlice->getResult(0));

  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
LogicalResult SoftmaxOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber == 0) {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
  return failure();
}
// cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
LogicalResult
SoftmaxOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  SmallVector<OpFoldResult> shapes;
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(dim)) {
      // Static dim: Return IntegerAttr.
      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
    } else {
      // Dynamic dim: Return Value.
      OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
      shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
    }
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
void SoftmaxOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(),
                         &getOperation()->getOpOperand(index), /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}
// Helper functions for softmax decomposition.
// @{

// Helper function to produce the iterator types (reduction or parallel) and
// affine maps for the iterators used in the decomposition of softmax.
// This method creates:
// If allParallel == true:
//  - iterator type: {parallel, ..., parallel}
//  - affine maps:
//  -- identity with inputRank dimensions.
//  -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
//     where N == inputRank.
//
// If allParallel == false:
//  - iterator type at dim(i) == parallel for i != \p dim and
//    dim(dim) == reduction.
//  - affine maps:
//  -- identity with inputRank dimensions.
//  -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
//     where N == inputRank.
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
                                    int64_t dim, bool allParallel = false) {
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  if (!allParallel)
    iteratorTypes[dim] = utils::IteratorType::reduction;
  MLIRContext *ctxt = builder.getContext();
  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
  SmallVector<AffineExpr, 2> affineExprs;
  for (int i = 0; i < inputRank; i++) {
    if (i != dim)
      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
  }
  auto reductionMap =
      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
  return std::make_tuple(iteratorTypes, indexingMaps);
}
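// Illustrative example (hypothetical rank/dim, not from the original source):
// for inputRank = 3 and dim = 1, the helper above returns iterator types
// [parallel, reduction, parallel] (all parallel if allParallel == true) and
// the maps
//   identity:  (d0, d1, d2) -> (d0, d1, d2)
//   reduction: (d0, d1, d2) -> (d0, d2)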
// Helper function to produce a linalg.generic that computes a reduction on
// dimension \p dim with the operation type \p T.
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
                    int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<T>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}
/// Produce a linalg generic that computes the second step of the softmax
/// decomposition: res = exp(input - max), where \p max is the max of \p input
/// on dimension \p dim.
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
                              Value max, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  // Add the affine map for the output argument.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
        Value result = b.create<math::ExpOp>(loc, diff);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}
/// Produce a linalg generic that computes the final step of the softmax
/// decomposition.
/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
///   yield n / d
/// }
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
                        Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  // Add the affine map for the output tensor.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
      indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}
// @} End helper functions for softmax decomposition.

/// Given an N-dimensional tensor x, this method converts
/// softmax(x) to the following sequence of operations:
///
/// 1. Compute the max of x along dimension d. This results
///    in a N-1 dimensional tensor m.
///    m = max(x, dim = d)
///
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
///    a N dimensional tensor z.
///    z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    a N-1 dimensional tensor l.
///    l = sum(z, dim = d)
///
/// 4. Divide z and l. This gives the N-dimensional softmax.
///    softmax = z / l
///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max =
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
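// Numerical sanity check (hypothetical values, not from the original source):
// for x = [1, 2, 3] along d = 0, m = 3, z = exp([-2, -1, 0]), l = sum(z), and
// z / l reproduces the usual numerically stable softmax, since exp(x - m)
// cannot overflow for large x.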
//===----------------------------------------------------------------------===//
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradFilterTransformOp::verify() {
  auto filterType = cast<ShapedType>(getFilter().getType());
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  int64_t r = getR();
  int64_t m = getM();

  if (filterH != r && filterH != 1)
    return emitOpError("expect filter height either equals to r or 1");
  if (filterW != r && filterW != 1)
    return emitOpError("expect filter width either equals to r or 1");
  if (filterH == 1 && filterW == 1)
    return emitOpError("expect either filter height or width equals to r");

  SmallVector<int64_t> expectedOutputShape;
  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
  expectedOutputShape.push_back(filterShape[getFilterFDim()]);

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
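// Illustrative example (hypothetical parameters, not from the original
// source): for m = 4 and r = 3, alpha = m + r - 1 = 6, so a (F, 3, 3, C)
// filter verifies against an expected (6, 6, C, F) output shape.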
SmallVector<Range>
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value filter = getFilter();
  int64_t filterRank = getFilterOperandRank();
  SmallVector<Range> loopBounds(filterRank);
  for (unsigned dim = 0; dim < filterRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradFilterTransformOp::getLoopIteratorTypes() {
  int64_t filterRank = getFilterOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = filterH != 1 ? alpha : 1;
  int64_t alphaW = filterW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
  resultSizes.append(
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

  return success();
}
/// Implement tiling for winograd_filter_transform.
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t tileSize = m + r - 1;
  bool leftTransform = inputH != 1;
  bool rightTransform = inputW != 1;

  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : 1;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : 1;
  }
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
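// Illustrative example (hypothetical parameters, not from the original
// source): for m = 4, r = 3 and a static (N, 10, 10, C) input, tileSize = 6
// and each spatial dimension yields (10 - (3 - 1)) / 4 = 2 tiles, so the
// expected output shape is (6, 6, 2, 2, N, C).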
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  SmallVector<Range> loopBounds(outputRank);
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, C
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType inputType = getInputOperandType();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = inputH != 1 ? alpha : 1;
  int64_t alphaW = inputW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
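
  // Result dims are (alphaH, alphaW, tileH, tileW, N, C): a result tile
  // always spans the full alphaH x alphaW transform window, so those offsets
  // are 0 and those sizes are the (possibly collapsed-to-1) window extents.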
  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});

  return success();
}
/// Implement tiling for winograd_input_transform.
/// The input of winograd_input_transform is (N, H, W, C).
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
/// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one
/// tile.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType inputType = getInputOperandType();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
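
  // Map tile-space offsets and sizes back to input coordinates: tile i
  // starts at input element i * m, and n consecutive tiles read
  // n * m + (r - 1) input elements, since adjacent windows overlap by r - 1.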
  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  OpFoldResult offsetH =
      inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  int64_t m = getM();
  int64_t r = getR();
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;
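
  // The inverse transform turns each (m + r - 1) x (m + r - 1) tile back
  // into an m x m output block, so a static output dimension must equal
  // m times the tile count along each transformed dimension.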
  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}

SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  SmallVector<Range> loopBounds(valueRank);
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, F
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}
SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  int64_t m = getM();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto affineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
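
  // In the (N, H, W, F) output space a tile index maps to coordinate
  // index * m, and n tiles cover exactly n * m elements; unlike the input
  // side, output windows do not overlap, so one map serves offsets and sizes.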
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileWDim()]);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileWDim()]);

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  OpFoldResult offsetH =
      valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  OpFoldResult sizeH =
      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);

  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
  resultSizes.append(
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
  return success();
}
/// Implement tiling for winograd_output_transform.
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
/// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
/// for the sizes of tileH, tileW, N, F for one tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
/// Returns true if every result AffineExpr of \p explicitMap also appears
/// among the results of \p defaultMap.
static bool isValidResultDimExprs(AffineMap explicitMap, AffineMap defaultMap) {
  auto explicitRange = explicitMap.getResults();
  auto defaultRange = defaultMap.getResults();
  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
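  // After merging defaultSet into explicitSet, equality holds iff every
  // result expression of the explicit map already appears in the default
  // map, i.e. the explicit results form a subset of the default results.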
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
}
/// Returns true if the \p explicitMap is broadcasted with respect to the
/// \p defaultMap, i.e. it has fewer result dimensions.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}
/// Verifies the broadcast and transpose semantic specified by the explicit
/// indexing map for the MatmulOp \p matmulOp for the operand specified by
/// \p opIndex.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                  unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // Check general validity of indexing map results.
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return matmulOp->emitOpError()
           << "Unexpected dim expression in map result.";

  // Check if the requested broadcast is valid.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "Invalid broadcast requested, should be (d2).";
    }
  }
  return success();
}

namespace mlir {
namespace linalg {
//===----------------------------------------------------------------------===//
// MatMulOp
//===----------------------------------------------------------------------===//
/// Returns a list of AffineMap with the typical matmul indexing
/// characteristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2);
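  // (d0, d1, d2) = (M, N, K): the LHS is indexed by (M, K), the RHS by
  // (K, N), and the accumulator by (M, N), i.e. C(m, n) += A(m, k) * B(k, n).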
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }
std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }
/// Check if the op has broadcast and/or transpose semantic. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}
/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }
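
  // Emit the scalar body: out = out + cast(lhs) * cast(rhs), where the cast
  // converts both inputs to the accumulator's element type.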
  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
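/// For example (illustrative), affine_map<(d0, d1, d2) -> (d2)> broadcasts a
/// 1-D operand along the shared reduction dimension d2 and is accepted here,
/// while affine_map<(d0, d1, d2) -> (d0)> is rejected.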
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr exp = bcastMap.getResult(0);
  // The map is invalid if the common (reduction) dimension of the matmul is
  // not found in its result.
  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
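
/// Example (illustrative) of the custom assembly handled below, using an
/// explicit `indexing_maps` clause to transpose the LHS:
///
///   linalg.matmul indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
///                                  affine_map<(d0, d1, d2) -> (d2, d1)>,
///                                  affine_map<(d0, d1, d2) -> (d0, d1)>]
///       ins(%a, %b : memref<5x3xf32>, memref<5x7xf32>)
///       outs(%c : memref<3x7xf32>)
///
/// When the clause is omitted, the default maps from getDefaultIndexingMaps()
/// are attached, so plain matmuls round-trip unchanged.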
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();

    if (parser.parseLSquare())
      return failure();

    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}
void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);

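  // Print the indexing maps only when they differ from the defaults.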
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
}
/// Verify the user-defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

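  // Only the two input operands (opIndex 0 = LHS, opIndex 1 = RHS) are
  // checked for extended broadcast/transpose semantics.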
  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

} // namespace linalg
} // namespace mlir