//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::linalg;
/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
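// For example (an illustrative sketch; the shapes and value names below are
// assumptions, not taken from any particular test):
//
//   %slice = tensor.extract_slice %src[0, 0] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>        // RankedTensorType source
//   %view  = memref.subview %src[0, 0] [4, 4] [1, 1]
//       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1]>>  // MemRefType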
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}
//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;
/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}
static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
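// For reference, the default matmul indexing maps are expected to be the
// canonical maps over an (m, n, k) iteration space (a sketch, assuming the
// default ODS definition of MatmulOp):
//
//   affine_map<(d0, d1, d2) -> (d0, d2)>   // A
//   affine_map<(d0, d1, d2) -> (d2, d1)>   // B
//   affine_map<(d0, d1, d2) -> (d0, d1)>   // C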
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
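// Both the parser above and the printer below handle the shared `ins`/`outs`
// clause of structured ops, e.g. (an illustrative example; the operand names
// and shapes are assumptions):
//
//   ins(%lhs, %rhs : tensor<4x8xf32>, tensor<8x16xf32>)
//   outs(%acc : tensor<4x16xf32>)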
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}
static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers build the unary, binary, and type conversion functions defined by the
// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
// class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//
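// For example, for an elementwise add over f32 operands, a regionBuilder
// conceptually emits a body like the following (an illustrative sketch, not
// actual generated code):
//
//   ^bb0(%in0: f32, %in1: f32, %out: f32):
//     %0 = arith.addf %in0, %in1 : f32
//     linalg.yield %0 : f32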
namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }
  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }
  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }
  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }
  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace
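// For instance (an illustrative example, assuming buffer semantics):
//
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
//
// copies a buffer onto itself and is simply erased; the tensor form is
// replaced by its input value instead.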
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
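// For example (an illustrative sketch; the shapes are assumptions):
//
//   %fill = linalg.fill ins(%cst : f32)
//       outs(%init : tensor<6x4xf32>) -> tensor<6x4xf32>
//   %r = tensor.collapse_shape %fill [[0, 1]]
//       : tensor<6x4xf32> into tensor<24xf32>
//
// is rewritten so the collapse_shape applies to %init and a new linalg.fill
// produces the tensor<24xf32> result directly.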
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
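// A short worked example of the disjointness arithmetic above (the values are
// assumptions for illustration): with offset 0, size 4, stride 1 the accessed
// range is [0, 0 + (4 - 1) * 1] = [0, 3], so a second insert starting at
// offset 4 cannot overlap it and the walk may continue past that
// insert_slice.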
/// Fold tensor.extract(linalg.fill(<input>)) into <input>
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if tensor input of tensor.extract op is the result of a linalg.fill
    // op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get scalar input operand of linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace tensor.extract op with scalar value used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};
/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
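// For example (an illustrative sketch; the shapes are assumptions):
// concatenating two tensors that are both filled with the same %cst,
//
//   %a = linalg.fill ins(%cst : f32) outs(%e0 : tensor<2x4xf32>)
//       -> tensor<2x4xf32>
//   %b = linalg.fill ins(%cst : f32) outs(%e1 : tensor<3x4xf32>)
//       -> tensor<3x4xf32>
//   %c = tensor.concat dim(0) %a, %b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// becomes a concat of the two outs tensors followed by a single fill.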
} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}
void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
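// The textual form parsed above looks like (an illustrative example; the map
// and shapes are assumptions):
//
//   #map = affine_map<(d0) -> (d0)>
//   %0 = linalg.generic {indexing_maps = [#map, #map],
//                        iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>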
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }
namespace {

/// Remove any linalg operation (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
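// For example (an illustrative sketch), a tensor-semantics identity
//
//   %0 = linalg.generic { /* parallel iterators, identity maps */ }
//       ins(%arg : tensor<8xf32>) outs(%init : tensor<8xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   } -> tensor<8xf32>
//
// is replaced by %arg directly (with a tensor.cast if only static/dynamic
// dimension information differs).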
} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
// Retrieve the operation from the body, if it is the only one (except
// yield) and if it gets the same amount of arguments as the body does.
// If initFirst flag is enabled, we check that init takes the first position in
// operands of payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // Check the init operand.
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // Check the remaining operands.
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}
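// Detecting the payload op enables the short print form, e.g. (an
// illustrative example; operand names and shapes are assumptions):
//
//   %mapped = linalg.map { arith.addf }
//       ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//       outs(%init : tensor<8xf32>)
//
// instead of spelling out the region with an explicit arith.addf and
// linalg.yield.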
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
  SmallVector<StringRef> elidedAttrs;
  p << " { " << payloadOp->getName().getStringRef();
  for (const auto &attr : payloadOp->getAttrs()) {
    if (auto fastAttr = dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
      if (fastAttr.getValue() == arith::FastMathFlags::none) {
        elidedAttrs.push_back(attr.getName());
      }
    }
    if (auto denormAttr = dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
      if (denormAttr.getValue() == arith::DenormalMode::ieee) {
        elidedAttrs.push_back(attr.getName());
      }
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
}
void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` match the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
1577 SmallVector
<utils::IteratorType
> MapOp::getIteratorTypesArray() {
1578 int64_t rank
= getInit().getType().getRank();
1579 return SmallVector
<utils::IteratorType
>(rank
, utils::IteratorType::parallel
);
1582 ArrayAttr
MapOp::getIndexingMaps() {
1583 Builder
builder(getContext());
1584 int64_t rank
= getInit().getType().getRank();
1585 int64_t numIndexingMaps
= getOperands().size();
1586 return builder
.getAffineMapArrayAttr(SmallVector
<AffineMap
>(
1587 numIndexingMaps
, builder
.getMultiDimIdentityMap(rank
)));
void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MapOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}
SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
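
// For example: a rank-3 input reduced over dimensions = [1] produces the
// input map (d0, d1, d2) -> (d0, d1, d2) and the init map
// (d0, d1, d2) -> (d0, d2), i.e. the reduced dimension is dropped from the
// init's map.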
void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability ReduceOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
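
// Illustrative input accepted by this parser (types elided): the short form
//   linalg.reduce { arith.addf } ins(%in : ...) outs(%init : ...)
//       dimensions = [1]
// builds the combiner region from the payload op, whereas omitting the
// leading `{ ... }` group requires an explicit region with block arguments.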
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}

void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
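
// A shape example that satisfies the checks above: ins tensor<4x8x16xf32>,
// dimensions = [1], init tensor<4x16xf32>. Dimension 1 is dropped from the
// input shape and the remaining dims [4, 16] must equal the init dims.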
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                Region &region, ValueRange inputs,
                                ValueRange outputs) {
  buildGenericRegion(builder, loc, region, inputs, outputs,
                     [](OpBuilder &b, Location loc, ValueRange args) {
                       b.create<linalg::YieldOp>(loc, args[0]);
                     });
}
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getPermutationAttrName(result.name), permutation);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
        attributes);
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
}

void TransposeOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}

LogicalResult TransposeOp::verify() {
  ArrayRef<int64_t> permutationRef = getPermutation();

  if (!isPermutationVector(permutationRef))
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
    }
  }

  return success();
}
SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {inversePermutation(AffineMap::getPermutationMap(
           llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}

void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability TransposeOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
  // Only the tensor type is supported.
  if (!isa<TensorType>(getInput().getType()))
    return failure();

  // Single dimension transpose.
  if (getPermutation().size() == 0) {
    result.push_back(getInput());
    return success();
  }
  // Identity permutation.
  if (isIdentityPermutation(getPermutation())) {
    result.push_back(getInput());
    return success();
  }

  return failure();
}
/// Fold transpose with transpose.
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
    if (!defTransposeOp)
      return failure();
    ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();
    SmallVector<int64_t> foldedPerms;
    foldedPerms.reserve(perms.size());
    for (int64_t perm : perms)
      foldedPerms.push_back(defPerms[perm]);

    rewriter.replaceOpWithNewOp<TransposeOp>(
        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
        foldedPerms);
    return success();
  }
};
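
// Worked example of the composition above: if the producer transpose uses
// defPerms = [1, 2, 0] and the consumer uses perms = [2, 0, 1], then
// foldedPerms[i] = defPerms[perms[i]] = [0, 1, 2], i.e. the pair folds to a
// single transpose with the identity permutation on the original input.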
/// This pattern canonicalizes transpose by swapping the order of
/// broadcast and transpose:
///   transpose(broadcast(input)) -> broadcast(transpose(input))
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    if (!input.hasOneUse() || !broadcastOp)
      return failure();

    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();

    // Get new perms and new dimensions.
    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
    SmallVector<int64_t> resultDimensions;
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);

    // Create transpose result.
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    MLIRContext *ctx = transposeOp.getContext();
    SmallVector<OpFoldResult> dims;
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
                           ->getResult(0));
      } else {
        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
                                        broadcastInputTy.getDimSize(i)));
      }
    }
    SmallVector<OpFoldResult> transposeResultShapes =
        applyPermutation(dims, resultPerms);
    Value transposeInit = rewriter.create<tensor::EmptyOp>(
        transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());

    // Create broadcast(transpose(input)).
    Value transposeResult =
        rewriter
            .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
                                 resultPerms)
            ->getResult(0);
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
}
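
// Sketch of the swap on IR (illustrative shapes, not from a test):
//   %b = linalg.broadcast ins(%x : tensor<4xf32>)
//          outs(%e0 : tensor<2x4xf32>) dimensions = [0]
//   %t = linalg.transpose ins(%b : tensor<2x4xf32>)
//          outs(%e1 : tensor<4x2xf32>) permutation = [1, 0]
// becomes a transpose of the smaller %x (here a no-op rank-1 transpose)
// followed by a broadcast with remapped dimensions = [1].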
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
        attributes);
}

ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
}

void BroadcastOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}
LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimension is mapped from the input. Init and input dims should
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}
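
// A shape example that satisfies the checks above: ins tensor<4x16xf32>,
// dimensions = [1], init tensor<4x8x16xf32>. Init dims {0, 2} map back to
// input dims {0, 1} and must match, while init dim 1 is freshly broadcast.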
SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr BroadcastOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
       builder.getMultiDimIdentityMap(rank)});
}
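
// For example: with a rank-3 init and dimensions = [1], the maps above are
// (d0, d1, d2) -> (d0, d2) for the input and the identity map for the init.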
void BroadcastOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability BroadcastOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

void linalg::YieldOp::print(OpAsmPrinter &p) {
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
}

ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  return failure(parser.parseOperandList(opInfo) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
                 parser.resolveOperands(opInfo, types, loc, result.operands));
}

// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}

LogicalResult linalg::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
    return verifyYield(*this, linalgOp);

  return emitOpError("expected parent op with LinalgOp interface");
}
//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//

LogicalResult IndexOp::verify() {
  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
  if (!linalgOp)
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
  return success();
}
/////// Operations corresponding to library calls defined with Tablegen ////////

#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
                                             unsigned rank,
                                             MLIRContext *context) {
  if (maybeMap)
    return *maybeMap;
  if (rank == 0)
    return AffineMap::get(context);
  return AffineMap::getMultiDimIdentityMap(rank, context);
}

SmallVector<AffineExpr, 4>
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                 MLIRContext *context) {
  SmallVector<AffineExpr, 4> res;
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}

SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                ArrayRef<AffineExpr> b) {
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}

static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    if (failed(appendMangledType(ss, vec.getElementType())))
      return failure();
    return success();
  }
  if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
    return success();
  }
  return failure();
}

std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  std::string fun = "";
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  }
  std::replace(name.begin(), name.end(), '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  for (Type t : op->getOperandTypes()) {
    if (failed(appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
  name.pop_back();
  return name;
}
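
// Illustrative mangling (a sketch, assuming the common matmul-on-memref
// case): linalg.matmul with three memref<?x?xf32> operands yields
//   linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32
// where each `view...` token encodes one operand and `s` marks a dynamic dim.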
//===----------------------------------------------------------------------===//
// Canonicalizers and Folders.
//===----------------------------------------------------------------------===//

namespace {
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases we see memref<...x0x...>.
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // Cast can be in a conditionally reachable region, in which case folding
    // will generate invalid code. Only conservatively fold ops in the same
    // block for now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = rewriter.create<tensor::CastOp>(
        loc, resultValue.getType(), newOp->getResult(resultNumber));

    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};
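
// Sketch of the fold on IR (illustrative shapes, not from a test):
//   %0 = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
// becomes a generic whose init is first cast to tensor<4x8xf32>; the new op
// produces the static type directly, and a cast back to tensor<?x?xf32>
// replaces any remaining uses of %0.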
/// For each operand in `operands` this function maps the static sizes of
/// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // `tensor.cast` operation and source of the cast operation has a static
    // shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}
/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types are stored in `resultTypes`. If `opOperand` requires no
/// change then `changeNeeded` is false and the same operand is added to the
/// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If the operand is updated with a new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
/// Static shapes for the operands can be inferred if any one of the operands
/// has a static shape. This can be done by referring to the affine dim
/// expressions for the operand.
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();

    // Maps affine dim expressions to the static size of that dimension.
    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
    Location loc = linalgOp.getLoc();

    // For each of the affine dim expression, check if the size is known. If
    // known add that in the map.
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }

    // If the generic op has all the required static information, no
    // canonicalization needed.
    if (!changeNeeded)
      return failure();

    // Clone op.
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    SmallVector<Value> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      Type newType = newResult.getType();
      Type oldType = oldResult.getType();
      replacements.push_back(
          (newType != oldType)
              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
              : newResult);
    }
    rewriter.replaceOp(linalgOp, replacements);
    return success();
  }
};
} // namespace
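
// Sketch of the inference (illustrative): in a generic with identity maps on
// %a : tensor<4x8xf32> and %b : tensor<?x?xf32>, populateMap records d0 = 4
// and d1 = 8 from %a, so %b is cast to tensor<4x8xf32> before the op is
// cloned with the more static operand and result types.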
// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.

//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//
LogicalResult SoftmaxOp::verify() {
  ShapedType inputType = getInputOperandType();
  ShapedType outputType = getOutputOperandType();

  ArrayRef<int64_t> inputShape = inputType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(inputShape, outputShape)))
    return emitOpError("incompatible output shape");

  int64_t inputRank = getInputOperandRank();
  int64_t dimension = getDimension();
  if ((dimension < 0) || (dimension >= inputRank))
    return emitOpError("incorrect dimension specified");

  return success();
}

SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getInputOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getInput();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}
FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputOperandRank();
  auto oneAttr = builder.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> strides(rank, oneAttr);
  SmallVector<Value> tiledOperands;
  Operation *inputSlice =
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  if (!inputSlice)
    return emitOpError("failed to compute input slice");

  tiledOperands.emplace_back(inputSlice->getResult(0));
  Operation *outputSlice =
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
  if (!outputSlice)
    return emitOpError("failed to compute output slice");

  tiledOperands.emplace_back(outputSlice->getResult(0));

  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}

LogicalResult SoftmaxOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber == 0) {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
  return failure();
}

// cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

LogicalResult
SoftmaxOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  SmallVector<OpFoldResult> shapes;
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(dim)) {
      // Static dim: Return IntegerAttr.
      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
    } else {
      // Dynamic dim: Return Value.
      OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
      shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
    }
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
void SoftmaxOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(),
                         &getOperation()->getOpOperand(index), /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}
// Helper functions for softmax decomposition.
// @{

// Helper function to produce the iterator types (reduction or parallel) and
// affine maps for the iterators used in the decomposition of softmax.
// This method creates:
// If allParallel == true:
//  - iterator type: {parallel, ..., parallel}
//  - affine maps:
//  -- identity with inputRank dimensions.
//  -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
//     where N == inputRank.
//
// If allParallel == false:
//  - iterator type at dim(i) == parallel for i != \p dim and
//    dim(dim) == reduction.
//  - affine maps:
//  -- identity with inputRank dimensions.
//  -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
//     where N == inputRank.
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
                                    int64_t dim, bool allParallel = false) {
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  if (!allParallel)
    iteratorTypes[dim] = utils::IteratorType::reduction;
  MLIRContext *ctxt = builder.getContext();
  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
  SmallVector<AffineExpr, 2> affineExprs;
  for (int i = 0; i < inputRank; i++) {
    if (i != dim)
      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
  }
  AffineMap reductionMap =
      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
  return std::make_tuple(iteratorTypes, indexingMaps);
}
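
// For example: inputRank = 3, dim = 1 yields the identity map
// (d0, d1, d2) -> (d0, d1, d2) plus the reduction map
// (d0, d1, d2) -> (d0, d2), with iterator types
// {parallel, reduction, parallel} unless allParallel is set.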
// Helper function to produce a linalg.generic that computes a reduction on
// dimension \p dim with the operation type \p T.
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
                    int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<T>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}
/// Produce a linalg generic that computes the second step of the softmax
/// decomposition: res = exp(input - max), where \p max is the max of \p input
/// on dimension \p dim.
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
                              Value max, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  // Add the affine map for the output argument.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
        Value result = b.create<math::ExpOp>(loc, diff);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}
/// Produce a linalg generic that computes the final step of the softmax
/// decomposition.
/// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
///   yield  n / d
/// }
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
                        Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  // Add the affine map for the output tensor.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
      indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}
// @} End helper functions for softmax decomposition.
/// Given an N-dimensional tensor x, this method converts
/// softmax(x) to the following sequence of operations:
///
/// 1. Compute the max of x along dimension d. This results
///    in an N-1 dimensional tensor m.
///    m = max(x, dim = d)
///
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
///    an N dimensional tensor z.
///    z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    an N-1 dimensional tensor l.
///    l = sum(z, dim = d)
///
/// 4. Divide z and l. This gives the N-dimensional softmax.
///    softmax = z / l
///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max =
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
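
// Sketch of the emitted decomposition for %x : tensor<2x16xf32>, dim = 1
// (illustrative value names; the fill constants come from
// arith::getIdentityValue with useOnlyFiniteValue = true):
//   %empty   = tensor.empty() : tensor<2xf32>
//   %maxInit = linalg.fill ins(<finite lowest f32>) outs(%empty)
//   %max     = reduce<maxnumf>(%x, %maxInit)          // step 1
//   %num     = exp(%x - broadcast(%max))              // step 2
//   %sumInit = linalg.fill ins(0.0) outs(%empty)
//   %den     = reduce<addf>(%num, %sumInit)           // step 3
//   %res     = %num / broadcast(%den)                 // step 4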
//===----------------------------------------------------------------------===//
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradFilterTransformOp::verify() {
  auto filterType = cast<ShapedType>(getFilter().getType());
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  int64_t r = getR();
  int64_t m = getM();

  if (filterH != r && filterH != 1)
    return emitOpError("expect filter height either equals to r or 1");
  if (filterW != r && filterW != 1)
    return emitOpError("expect filter width either equals to r or 1");
  if (filterH == 1 && filterW == 1)
    return emitOpError("expect either filter height or width equals to r");

  SmallVector<int64_t> expectedOutputShape;
  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
  expectedOutputShape.push_back(filterShape[getFilterFDim()]);

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
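
// Numeric example for the checks above: with m = 4 and r = 3, a filter with
// H = W = r = 3 transforms to an (m + r - 1) x (m + r - 1) = 6x6 tile, so the
// expected output shape is (6, 6, C, F).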
SmallVector<Range>
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value filter = getFilter();
  int64_t filterRank = getFilterOperandRank();
  SmallVector<Range> loopBounds(filterRank);
  for (unsigned dim = 0; dim < filterRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradFilterTransformOp::getLoopIteratorTypes() {
  int64_t filterRank = getFilterOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = filterH != 1 ? alpha : 1;
  int64_t alphaW = filterW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
  resultSizes.append(
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

  return success();
}
/// Implement tiling for winograd_filter_transform.
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
3056 //===----------------------------------------------------------------------===//
3057 // WinogradInputTransformOp
3058 //===----------------------------------------------------------------------===//
3060 LogicalResult
WinogradInputTransformOp::verify() {
3061 auto inputType
= cast
<ShapedType
>(getInput().getType());
3062 ArrayRef
<int64_t> inputShape
= inputType
.getShape();
3063 int64_t inputH
= inputShape
[getInputHDim()];
3064 int64_t inputW
= inputShape
[getInputWDim()];
3067 int64_t tileSize
= m
+ r
- 1;
3068 bool leftTransform
= inputH
!= 1;
3069 bool rightTransform
= inputW
!= 1;
3071 SmallVector
<int64_t> expectedOutputShape(6, inputH
);
3072 if (ShapedType::isDynamic(inputH
)) {
3073 expectedOutputShape
[getOutputAlphaHDim()] = tileSize
;
3074 expectedOutputShape
[getOutputTileHDim()] = ShapedType::kDynamic
;
3076 expectedOutputShape
[getOutputAlphaHDim()] = leftTransform
? tileSize
: 1;
3077 expectedOutputShape
[getOutputTileHDim()] =
3078 leftTransform
? (inputH
- (r
- 1)) / m
: 1;
3080 if (ShapedType::isDynamic(inputW
)) {
3081 expectedOutputShape
[getOutputAlphaWDim()] = tileSize
;
3082 expectedOutputShape
[getOutputTileWDim()] = ShapedType::kDynamic
;
3084 expectedOutputShape
[getOutputAlphaWDim()] = rightTransform
? tileSize
: 1;
3085 expectedOutputShape
[getOutputTileWDim()] =
3086 rightTransform
? (inputW
- (r
- 1)) / m
: 1;
3088 expectedOutputShape
[getOutputNDim()] = inputShape
[getInputNDim()];
3089 expectedOutputShape
[getOutputCDim()] = inputShape
[getInputCDim()];
3091 auto outputType
= cast
<ShapedType
>(getOutput().getType());
3092 ArrayRef
<int64_t> outputShape
= outputType
.getShape();
3093 if (failed(verifyCompatibleShape(expectedOutputShape
, outputShape
))) {
3094 return emitOpError("the output shape is not expected");
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  SmallVector<Range> loopBounds(outputRank);
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, C
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}
SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType inputType = getInputOperandType();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = inputH != 1 ? alpha : 1;
  int64_t alphaW = inputW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});
  return success();
}
/// Implement tiling for winograd_input_transform
/// The input of winograd_input_transform is (N, H, W, C).
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
/// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one
/// tile.
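/// For example (illustrative numbers only): with m = 4 and r = 3, a tile
/// offset of 2 along tileH maps to input offset 2 * 4 = 8, and a tile size
/// of 2 maps to an input extent of 2 * 4 + (3 - 1) = 10, matching the affine
/// maps constructed below.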
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType inputType = getInputOperandType();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  OpFoldResult offsetH =
      inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  int64_t m = getM();
  int64_t r = getR();
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
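// A worked example of the shape check above (illustrative numbers only, not
// mandated by the op): with m = 4, r = 3, and a static 6x6x2x2x1x32 value
// (alphaH = alphaW = m + r - 1 = 6), the expected output shape is
// (1, 4 * 2, 4 * 2, 32) = (1, 8, 8, 32).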
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  SmallVector<Range> loopBounds(valueRank);
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, F
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}
SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  int64_t m = getM();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto affineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);

  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileWDim()]);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileWDim()]);

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  OpFoldResult offsetH =
      valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  OpFoldResult sizeH =
      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);

  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
  resultSizes.append(
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
  return success();
}
/// Implement tiling for winograd_output_transform
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
/// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
/// for the sizes of tileH, tileW, N, F for one tile.
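/// For example (illustrative numbers only): with m = 4, a tile offset of 2
/// along tileH maps to output row offset 2 * 4 = 8, and a tile size of 2
/// maps to an output extent of 2 * 4 = 8 rows.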
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
/// Returns true if every result AffineExpr of \p explicitMap is also a
/// result of \p defaultMap.
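/// For example, for the default LHS matmul map (d0, d1, d2) -> (d0, d2), the
/// explicit map (d2, d0) (a transpose) passes this check, while (d0, d1)
/// fails because d1 is not a result of the default map.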
static bool isValidResultDimExprs(AffineMap explicitMap,
                                  AffineMap defaultMap) {
  auto explicitRange = explicitMap.getResults();
  auto defaultRange = defaultMap.getResults();
  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
}
/// Returns true if \p explicitMap is broadcasted with respect to
/// \p defaultMap, i.e. it has fewer results.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}
/// Verifies the broadcast and transpose semantic specified by the explicit
/// indexing map for the MatmulOp \p op for each operand specified by \p
/// opIndex.
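/// For example (illustrative only), an LHS broadcast map
/// (d0, d1, d2) -> (d2) is accepted, while (d0, d1, d2) -> (d0) is rejected:
/// a single-result broadcast map must reference the reduction dimension d2.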
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                  unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // Check general validity of indexing map results.
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return matmulOp->emitOpError()
           << "Unexpected dim expression in map result.";

  // Check if the requested broadcast is valid.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "Invalid broadcast requested, should be (d2).";
    }
    return success();
  }
  return success();
}
//===----------------------------------------------------------------------===//
// MatMulOp
//===----------------------------------------------------------------------===//
/// Returns a list of AffineMap with the typical matmul indexing
/// characteristic.
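/// The returned maps are:
///   LHS: (d0, d1, d2) -> (d0, d2)
///   RHS: (d0, d1, d2) -> (d2, d1)
///   OUT: (d0, d1, d2) -> (d0, d1)
/// with d0 = M, d1 = N, and d2 = K, the reduction dimension.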
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2);
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}
unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }
/// Check if the op has broadcast and/or transpose semantic. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}
/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
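/// In scalar form the region computes
///   out = out + cast(lhs) * cast(rhs)
/// where `cast` is the TypeFn selected by the optional "cast" attribute
/// (TypeFn::cast_signed by default).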
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs) {
  assert(3 > 0 && block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 (>=0) args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr exp = bcastMap.getResult(0);
  // The map is invalid if the common (reduction) dimension of the matmul is
  // not referenced.
  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
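/// A sketch of the optional `indexing_maps` syntax accepted by the parser
/// below (illustrative only; the maps and shapes are made up):
///
///   linalg.matmul indexing_maps = [
///       affine_map<(d0, d1, d2) -> (d2, d0)>,  // transposed LHS
///       affine_map<(d0, d1, d2) -> (d2, d1)>,
///       affine_map<(d0, d1, d2) -> (d0, d1)>]
///       ins(%lhs, %rhs : tensor<5x3xf32>, tensor<5x7xf32>)
///       outs(%init : tensor<3x7xf32>) -> tensor<3x7xf32>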
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();

    if (parser.parseLSquare())
      return failure();

    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}
void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);

  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
}
/// Verify the user defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

} // namespace linalg