//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
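
// For illustration, the `ins`/`outs` clauses parsed above look like this in
// the textual assembly (a sketch; operand names and types are made up):
//
//   %res = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
//                        outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>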

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
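
// Sketch of the rewrite above (shapes illustrative):
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<6xf32>) -> tensor<6xf32>
//   %0 = tensor.expand_shape %fill [[0, 1]] output_shape [2, 3]
//       : tensor<6xf32> into tensor<2x3xf32>
// becomes
//   %1 = tensor.expand_shape %init [[0, 1]] output_shape [2, 3]
//       : tensor<6xf32> into tensor<2x3xf32>
//   %0 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x3xf32>) -> tensor<2x3xf32>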

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
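
// Sketch of the fold above (shapes illustrative):
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
//   %pad = tensor.pad %fill low[1] high[1] {
//   ^bb0(%i: index):
//     tensor.yield %cst : f32
//   } : tensor<4xf32> to tensor<6xf32>
// becomes a fill of a fresh tensor.empty of the padded shape:
//   %empty = tensor.empty() : tensor<6xf32>
//   %pad = linalg.fill ins(%cst : f32) outs(%empty : tensor<6xf32>) -> tensor<6xf32>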

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
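
// Rough sketch of the rewrite above (pseudo-IR; offsets and shapes made up):
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8x8xf32>) -> tensor<8x8xf32>
//   %pad = tensor.pad %src low[%l0, %l1] high[...] { tensor.yield %cst : f32 } ...
//   %res = tensor.insert_slice %pad into %fill[%o0, %o1] ...
// becomes
//   %res = tensor.insert_slice %src into %fill[%o0 + %l0, %o1 + %l1] ...
// because the area covered only by padding already holds %cst in the filled
// destination.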

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if tensor input of tensor.extract op is the result of a linalg.fill
    // op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get scalar input operand of linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace tensor.extract op with scalar value used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
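
// Sketch (illustrative):
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32>
//   %v = tensor.extract %fill[%i, %j] : tensor<4x4xf32>
// folds so that all uses of %v become %cst directly.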

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}
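
// Sketch (pseudo-IR; pack attributes elided): packing a filled tensor just
// rearranges the same scalar, so the chain can fill the pack destination
// directly:
//   %fill = linalg.fill ins(%cst : f32) outs(%src : tensor<8x8xf32>) -> tensor<8x8xf32>
//   %pack = tensor.pack %fill ... into %dest : tensor<8x8xf32> -> tensor<2x2x4x4xf32>
// becomes
//   %pack = linalg.fill ins(%cst : f32) outs(%dest : tensor<2x2x4x4xf32>) -> tensor<2x2x4x4xf32>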

/// Wrapper pattern that applies foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
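
// Sketch of the two cases above (pseudo-IR):
//   copy(fill(%cst, %t0) -> %t1)  ==>  fill(%cst, %t1)
//   copy(%src -> fill(%cst, %t1)) ==>  copy(%src -> %t1)
// Either the fill feeds the copy's input (so the copy can fill directly), or
// it is immediately overwritten by the copy and can be dropped.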

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
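
// Sketch (shapes illustrative):
//   %a = linalg.fill ins(%cst : f32) outs(%i0 : tensor<2x4xf32>) -> tensor<2x4xf32>
//   %b = linalg.fill ins(%cst : f32) outs(%i1 : tensor<3x4xf32>) -> tensor<3x4xf32>
//   %c = tensor.concat dim(0) %a, %b : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
// becomes a concat of the outs followed by a single fill:
//   %o = tensor.concat dim(0) %i0, %i1 : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//   %c = linalg.fill ins(%cst : f32) outs(%o : tensor<5x4xf32>) -> tensor<5x4xf32>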

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
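
// Sketch of an identity op that this pattern erases (illustrative):
//
//   #id = affine_map<(d0) -> (d0)>
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%arg0 : tensor<4xf32>) outs(%init : tensor<4xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   } -> tensor<4xf32>
//
// All uses of %0 are replaced by %arg0 (through a tensor.cast or
// sparse_tensor.convert when the types differ).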

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (except
// yield) and if it takes the same number of arguments as the body does.
// If initFirst flag is enabled, we check that init takes the first position in
// operands of payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // Check init.
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // Check the remaining operands.
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}
1499 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1500 SmallVector<StringRef> elidedAttrs;
1501 p << " { " << payloadOp->getName().getStringRef();
1502 for (const auto &attr : payloadOp->getAttrs()) {
1503 if (auto fastAttr = dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
1504 if (fastAttr.getValue() == arith::FastMathFlags::none) {
1505 elidedAttrs.push_back(attr.getName());
1508 if (auto denormAttr = dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
1509 if (denormAttr.getValue() == arith::DenormalMode::ieee) {
1510 elidedAttrs.push_back(attr.getName());
1514 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1515 p << " }";
1518 void MapOp::print(OpAsmPrinter &p) {
1519 Block *mapper = getBody();
1520 Operation *payloadOp = findPayloadOp(mapper);
1521 if (payloadOp) {
1522 printShortForm(p, payloadOp);
1525 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1526 p.printOptionalAttrDict((*this)->getAttrs());
1528 if (!payloadOp) {
1529 // Print region if the payload op was not detected.
1530 p.increaseIndent();
1531 p.printNewline();
1532 p << "(";
1533 llvm::interleaveComma(mapper->getArguments(), p,
1534 [&](auto arg) { p.printRegionArgument(arg); });
1535 p << ") ";
1537 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1538 p.decreaseIndent();
1542 LogicalResult MapOp::verify() {
1543 auto *bodyBlock = getBody();
1544 auto blockArgs = bodyBlock->getArguments();
1546 // Checks if the number of `inputs` matches the arity of the `mapper` region.
1547 if (getInputs().size() != blockArgs.size())
1548 return emitOpError() << "expects number of operands to match the arity of "
1549 "mapper, but got: "
1550 << getInputs().size() << " and " << blockArgs.size();
1552 // The parameters of the mapper should all match the element types of inputs.
1553 for (const auto &[bbArgType, inputArg] :
1554 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1555 auto inputElemType =
1556 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1557 if (bbArgType != inputElemType) {
1558 return emitOpError() << "expected element type of input " << inputElemType
1559 << " to match bbArg type " << bbArgType;
1563 // The shape of each input must match the shape of the output.
1564 auto outputShape = getInit().getType().getShape();
1565 for (Type inputArgType : TypeRange{getInputs()}) {
1566 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1567 if (inputElemShape != outputShape) {
1568 return emitOpError() << "expected shape of input (" << inputElemShape
1569 << ") to match shape of output (" << outputShape
1570 << ")";
1574 return success();
1577 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1578 int64_t rank = getInit().getType().getRank();
1579 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1582 ArrayAttr MapOp::getIndexingMaps() {
1583 Builder builder(getContext());
1584 int64_t rank = getInit().getType().getRank();
1585 int64_t numIndexingMaps = getOperands().size();
1586 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1587 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1590 void MapOp::getEffects(
1591 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1592 &effects) {
1593 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1596 Speculation::Speculatability MapOp::getSpeculatability() {
1597 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1600 //===----------------------------------------------------------------------===//
1601 // ReduceOp
1602 //===----------------------------------------------------------------------===//
1604 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1605 OpAsmSetValueNameFn setNameFn) {
1606 for (Value v : getRegionInputArgs())
1607 setNameFn(v, "in");
1608 for (Value v : getRegionOutputArgs())
1609 setNameFn(v, "init");
1612 void ReduceOp::getAsmResultNames(
1613 function_ref<void(Value, StringRef)> setNameFn) {
1614 if (!getResults().empty())
1615 setNameFn(getResults().front(), "reduced");
1618 void ReduceOp::build(
1619 OpBuilder &builder, OperationState &result, ValueRange inputs,
1620 ValueRange inits, ArrayRef<int64_t> dimensions,
1621 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1622 ArrayRef<NamedAttribute> attributes) {
1623 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1624 result.addAttributes(attributes);
1626 // Add output types for `RankedTensorType` output arguments.
1627 for (Value init : inits) {
1628 Type initType = init.getType();
1629 if (llvm::isa<RankedTensorType>(initType))
1630 result.addTypes(initType);
1633 if (bodyBuild)
1634 buildGenericRegion(builder, result.location, *result.regions.front(),
1635 inputs, inits, bodyBuild);
1638 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1639 int64_t inputRank =
1640 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1641 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1642 utils::IteratorType::parallel);
1643 for (int64_t reductionDim : getDimensions())
1644 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1645 return iteratorTypes;
1648 ArrayAttr ReduceOp::getIndexingMaps() {
1649 int64_t inputRank =
1650 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1651 SmallVector<AffineMap> affineMaps(
1652 getNumDpsInputs(),
1653 AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1654 AffineMap resultMap =
1655 AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1656 .dropResults(getDimensions());
1657 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1658 affineMaps.push_back(resultMap);
1659 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
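// For example (illustrative shapes): with a rank-3 input and
// dimensions = [1], every input keeps the identity map
//   (d0, d1, d2) -> (d0, d1, d2)
// and every init uses the identity map with the reduced dimension dropped:
//   (d0, d1, d2) -> (d0, d2)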
1662 void ReduceOp::getEffects(
1663 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1664 &effects) {
1665 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1668 Speculation::Speculatability ReduceOp::getSpeculatability() {
1669 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1672 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1673 NamedAttrList &attributes,
1674 StringRef attributeName) {
1675 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1676 return failure();
1678 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1679 return success();
1682 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1683 std::optional<OperationName> payloadOpName;
1684 NamedAttrList payloadOpAttrs;
1685 if (succeeded(parser.parseOptionalLBrace())) {
1686 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1687 if (failed(operationName))
1688 return failure();
1689 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1690 return failure();
1691 payloadOpName = operationName.value();
1692 if (parser.parseRBrace())
1693 return failure();
1696 if (parseDstStyleOp(
1697 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1698 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1700 return failure();
1702 if (payloadOpName.has_value()) {
1703 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1704 ArrayRef(result.operands), /*initFirst=*/true);
1705 } else {
1706 SmallVector<OpAsmParser::Argument> regionArgs;
1707 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1708 /*allowType=*/true, /*allowAttrs=*/true)) {
1709 return failure();
1712 Region *body = result.addRegion();
1713 if (parser.parseRegion(*body, regionArgs))
1714 return failure();
1717 return success();
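// For illustration only (representative names and types): the short form
// accepted above looks like
//   %reduced = linalg.reduce { arith.addf }
//                ins(%in : tensor<16x32xf32>)
//                outs(%init : tensor<16xf32>)
//                dimensions = [1]
// and, because initFirst is true here, the implied payload is
// arith.addf(%init_elem, %in_elem).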
1720 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1721 ArrayRef<int64_t> attributeValue) {
1722 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1725 void ReduceOp::print(OpAsmPrinter &p) {
1726 Block *mapper = getBody();
1727 Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1728 if (payloadOp) {
1729 printShortForm(p, payloadOp);
1732 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1733 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1734 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1735 if (!payloadOp) {
1736 // Print region if the payload op was not detected.
1737 p.increaseIndent();
1738 p.printNewline();
1739 p << "(";
1740 llvm::interleaveComma(mapper->getArguments(), p,
1741 [&](auto arg) { p.printRegionArgument(arg); });
1742 p << ") ";
1744 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1745 p.decreaseIndent();
1749 LogicalResult ReduceOp::verify() {
1750 ArrayRef<int64_t> dimensionsRef = getDimensions();
1752 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1753 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1754 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1755 return emitOpError() << "expects all inputs to have the same shape. "
1756 "Shape at input-index "
1757 << i
1758 << " is not equal to the shape at input-index 0.";
1761 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1762 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1763 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1764 return emitOpError() << "expects all outputs to have the same shape. "
1765 "Shape at output-index "
1766 << i
1767 << " is not equal to the shape at output-index 0.";
1770 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1771 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1773 DenseSet<int64_t> dimensionsToReduce;
1774 for (int64_t dimension : dimensionsRef) {
1775 if (dimension < 0 || dimension >= inputType.getRank()) {
1776 return emitOpError()
1777 << "dimensions for reduction should be in the range [0, "
1778 << inputType.getRank() - 1 << "].";
1780 dimensionsToReduce.insert(dimension);
1783 auto inputDims = inputType.getShape();
1784 auto initDims = initType.getShape();
1786 // Input dimensions that will be left after the reduction.
1787 SmallVector<int64_t> reducedInputDims;
1788 for (const auto &en : llvm::enumerate(inputDims)) {
1789 if (!dimensionsToReduce.count(en.index()))
1790 reducedInputDims.push_back(en.value());
1793 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1794 return emitOpError() << "number of dimensions after reduction "
1795 << reducedInputDims.size()
1796 << " doesn't match the init rank "
1797 << initType.getRank();
1800 if (reducedInputDims != initDims)
1801 return emitOpError() << "init dimensions [" << initDims
1802 << "] doesn't match input dimensions after reduction ["
1803 << reducedInputDims << "]";
1805 Block *block = getBody();
1806 if (block->getNumArguments() != this->getNumOperands())
1807 return emitOpError()
1808 << "mismatching number of operands and block arguments";
1810 // Check that the first block arguments match the element type of the inputs.
1811 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1812 Type inputElementType =
1813 llvm::cast<ShapedType>(input.getType()).getElementType();
1814 if (inputElementType != bbArg.getType())
1815 return emitOpError()
1816 << "input element type " << inputElementType
1817 << " does not match corresponding block argument type "
1818 << bbArg.getType();
1821 // Check that the last block arguments match the element type of the outputs.
1822 for (auto [output, bbArg] : llvm::zip(
1823 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1824 auto outputElementType =
1825 llvm::cast<ShapedType>(output.getType()).getElementType();
1826 if (outputElementType != bbArg.getType())
1827 return emitOpError()
1828 << "output element type " << outputElementType
1829 << " does not match corresponding block argument type "
1830 << bbArg.getType();
1832 return success();
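// Concrete instance of the shape rules checked above (illustrative):
// reducing dimension 1 of a tensor<16x32x64xf32> input leaves dimensions
// [16, 64], so the init must have type tensor<16x64xf32>.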
1835 //===----------------------------------------------------------------------===//
1836 // TransposeOp
1837 //===----------------------------------------------------------------------===//
1839 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1840 Region &region, ValueRange inputs,
1841 ValueRange outputs) {
1842 buildGenericRegion(builder, loc, region, inputs, outputs,
1843 [](OpBuilder &b, Location loc, ValueRange args) {
1844 if (!args.empty())
1845 b.create<linalg::YieldOp>(loc, args[0]);
1849 void TransposeOp::build(::mlir::OpBuilder &builder,
1850 ::mlir::OperationState &result, Value input, Value init,
1851 DenseI64ArrayAttr permutation,
1852 ArrayRef<NamedAttribute> attributes) {
1853 result.addOperands(input);
1854 result.addOperands(init);
1855 result.addAttribute(getPermutationAttrName(result.name), permutation);
1856 result.addAttributes(attributes);
1858 // Add output types for `RankedTensorType` output arguments.
1859 Type initType = init.getType();
1860 if (llvm::isa<RankedTensorType>(initType))
1861 result.addTypes(initType);
1863 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1864 init);
1867 void TransposeOp::build(::mlir::OpBuilder &builder,
1868 ::mlir::OperationState &result, Value input, Value init,
1869 ArrayRef<int64_t> permutation,
1870 ArrayRef<NamedAttribute> attributes) {
1871 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1872 attributes);
1875 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1876 if (failed(parseDstStyleOp(
1877 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1878 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1879 })))
1880 return failure();
1882 OpBuilder builder(parser.getContext());
1883 buildIdentityRegion(builder, result.location, *result.addRegion(),
1884 /*inputs=*/result.operands,
1885 /*outputs=*/{});
1886 return success();
1889 void TransposeOp::getAsmResultNames(
1890 function_ref<void(Value, StringRef)> setNameFn) {
1891 if (!getResults().empty())
1892 setNameFn(getResults().front(), "transposed");
1895 void TransposeOp::print(OpAsmPrinter &p) {
1896 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1897 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1898 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1901 LogicalResult TransposeOp::verify() {
1902 ArrayRef<int64_t> permutationRef = getPermutation();
1904 if (!isPermutationVector(permutationRef))
1905 return emitOpError("permutation is not valid");
1907 auto inputType = getInput().getType();
1908 auto initType = getInit().getType();
1910 int64_t rank = inputType.getRank();
1912 if (rank != initType.getRank())
1913 return emitOpError() << "input rank " << rank
1914 << " does not match init rank " << initType.getRank();
1916 if (rank != static_cast<int64_t>(permutationRef.size()))
1917 return emitOpError() << "size of permutation " << permutationRef.size()
1918 << " does not match the argument rank " << rank;
1920 auto inputDims = inputType.getShape();
1921 auto initDims = initType.getShape();
1923 for (int64_t i = 0; i < rank; ++i) {
1924 int64_t inputDim = inputDims[permutationRef[i]];
1925 int64_t initDim = initDims[i];
1927 if (inputDim != initDim) {
1928 return emitOpError() << "dim(result, " << i << ") = " << initDim
1929 << " doesn't match dim(input, permutation[" << i
1930 << "]) = " << inputDim;
1934 return success();
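// Example of the dimension check above (illustrative types): for an input
// tensor<2x3xf32> and permutation = [1, 0], dim(init, 0) must equal
// dim(input, 1) = 3 and dim(init, 1) must equal dim(input, 0) = 2, so the
// init must have type tensor<3x2xf32>.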
1937 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1938 int64_t rank = getInit().getType().getRank();
1939 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1942 ArrayAttr TransposeOp::getIndexingMaps() {
1943 Builder builder(getContext());
1944 int64_t rank = getInit().getType().getRank();
1945 return builder.getAffineMapArrayAttr(
1946 {inversePermutation(AffineMap::getPermutationMap(
1947 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1948 builder.getMultiDimIdentityMap(rank)});
1951 void TransposeOp::getEffects(
1952 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1953 &effects) {
1954 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1957 Speculation::Speculatability TransposeOp::getSpeculatability() {
1958 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1961 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1962 SmallVectorImpl<OpFoldResult> &result) {
1963 // Only the tensor type is supported.
1964 if (!isa<TensorType>(getInput().getType()))
1965 return failure();
1967 // Rank-0 transpose (empty permutation) is a no-op.
1968 if (getPermutation().size() == 0) {
1969 result.push_back(getInput());
1970 return success();
1972 // Identity permutation.
1973 if (isIdentityPermutation(getPermutation())) {
1974 result.push_back(getInput());
1975 return success();
1978 return failure();
1981 /// Fold transpose with transpose.
1982 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1983 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1985 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1986 PatternRewriter &rewriter) const override {
1987 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1988 if (!defTransposeOp)
1989 return failure();
1990 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1991 ArrayRef<int64_t> perms = transposeOp.getPermutation();
1992 SmallVector<int64_t> foldedPerms;
1993 foldedPerms.reserve(perms.size());
1994 for (int64_t perm : perms)
1995 foldedPerms.push_back(defPerms[perm]);
1997 rewriter.replaceOpWithNewOp<TransposeOp>(
1998 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1999 foldedPerms);
2000 return success();
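// Worked example (illustrative): if the producing transpose uses
// defPerms = [1, 0, 2] and the consuming transpose uses perms = [2, 1, 0],
// then foldedPerms = [defPerms[2], defPerms[1], defPerms[0]] = [2, 0, 1];
// transposing tensor<2x3x4xf32> twice collapses into a single transpose
// yielding tensor<4x2x3xf32>.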
2004 /// This pattern canonicalizes transpose by swapping the order of
2005 /// broadcast and transpose:
2006 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2007 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2008 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2010 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2011 PatternRewriter &rewriter) const override {
2012 Value input = transposeOp.getInput();
2013 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2014 if (!input.hasOneUse() || !broadcastOp)
2015 return failure();
2017 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2018 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2020 // Get new perms and new dimensions.
2021 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2022 SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2023 SmallVector<int64_t> resultDimensions;
2024 unsigned dimensionSize = dimensions.size();
2025 for (unsigned i = 0; i < dimensionSize; ++i)
2026 resultDimensions.push_back(invertPerm[dimensions[i]]);
2028 // Create transpose result.
2029 Value broadcastInput = broadcastOp.getInput();
2030 Location loc = transposeOp.getLoc();
2031 MLIRContext *ctx = transposeOp.getContext();
2032 SmallVector<OpFoldResult> dims;
2033 auto broadcastInputTy =
2034 mlir::cast<RankedTensorType>(broadcastInput.getType());
2035 unsigned inputRank = broadcastInputTy.getRank();
2036 for (unsigned i = 0; i < inputRank; ++i) {
2037 if (broadcastInputTy.isDynamicDim(i)) {
2038 dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2039 ->getResult(0));
2040 } else {
2041 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2042 broadcastInputTy.getDimSize(i)));
2045 SmallVector<OpFoldResult> transposeResultShapes =
2046 applyPermutation(dims, resultPerms);
2047 Value transposeInit = rewriter.create<tensor::EmptyOp>(
2048 transposeOp.getLoc(), transposeResultShapes,
2049 broadcastInputTy.getElementType());
2051 // Create broadcast(transpose(input)).
2052 Value transposeResult =
2053 rewriter
2054 .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2055 resultPerms)
2056 ->getResult(0);
2057 rewriter.replaceOpWithNewOp<BroadcastOp>(
2058 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2059 return success();
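// Worked example (illustrative): broadcasting %in : tensor<16xf32> with
// dimensions = [0] yields tensor<8x16xf32>, and transposing that with
// permutation = [1, 0] yields tensor<16x8xf32>. After the swap, the now
// trivial transpose of %in uses resultPerms = [0] and the broadcast uses
// resultDimensions = [1], producing the same tensor<16x8xf32> result.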
2063 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2064 MLIRContext *context) {
2065 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2068 //===----------------------------------------------------------------------===//
2069 // BroadcastOp
2070 //===----------------------------------------------------------------------===//
2072 void BroadcastOp::build(::mlir::OpBuilder &builder,
2073 ::mlir::OperationState &result, Value input, Value init,
2074 DenseI64ArrayAttr dimensions,
2075 ArrayRef<NamedAttribute> attributes) {
2076 result.addOperands(input);
2077 result.addOperands(init);
2078 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2079 result.addAttributes(attributes);
2081 // Add output types for `RankedTensorType` output arguments.
2082 Type initType = init.getType();
2083 if (llvm::isa<RankedTensorType>(initType))
2084 result.addTypes(initType);
2086 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2087 init);
2090 void BroadcastOp::build(::mlir::OpBuilder &builder,
2091 ::mlir::OperationState &result, Value input, Value init,
2092 ArrayRef<int64_t> dimensions,
2093 ArrayRef<NamedAttribute> attributes) {
2094 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2095 attributes);
2098 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2099 if (failed(parseDstStyleOp(
2100 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2101 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2102 })))
2103 return failure();
2105 OpBuilder builder(parser.getContext());
2106 buildIdentityRegion(builder, result.location, *result.addRegion(),
2107 /*inputs=*/result.operands,
2108 /*outputs=*/{});
2109 return success();
2112 void BroadcastOp::getAsmResultNames(
2113 function_ref<void(Value, StringRef)> setNameFn) {
2114 if (!getResults().empty())
2115 setNameFn(getResults().front(), "broadcasted");
2118 void BroadcastOp::print(OpAsmPrinter &p) {
2119 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2120 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2121 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2124 LogicalResult BroadcastOp::verify() {
2125 ArrayRef<int64_t> dimensionsRef = getDimensions();
2127 auto inputType = getInput().getType();
2128 auto initType = getInit().getType();
2130 int64_t inputRank = inputType.getRank();
2131 int64_t initRank = initType.getRank();
2133 auto inputShape = inputType.getShape();
2134 auto initShape = initType.getShape();
2136 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2137 return emitOpError() << "input rank plus added dimensions does not "
2138 "match init rank. input rank: "
2139 << inputRank
2140 << ", dimensions size: " << dimensionsRef.size()
2141 << ", init rank: " << initRank;
2143 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2144 if (dim < 0 || dim >= initRank)
2145 return emitOpError() << "dimension " << idx
2146 << " is out of range. expected range: [0, "
2147 << initRank - 1 << "], got: " << dim;
2150 // Mapping from input dims to init dims.
2151 SmallVector<int64_t> dimMap;
2152 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2153 if (!llvm::is_contained(dimensionsRef, dim))
2154 dimMap.push_back(dim);
2157 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2158 // This dimension is mapped from the input. Init and input dims should
2159 // match.
2160 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2161 return emitOpError() << "input dim " << inputDimIdx
2162 << " should match init dim " << initDimIdx
2163 << ". input: " << inputShape[inputDimIdx]
2164 << ", init: " << initShape[initDimIdx];
2167 return success();
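// Example of the mapping above (illustrative shapes): for an input
// tensor<4x16xf32> and dimensions = [1], dimMap is [0, 2], so init dims 0
// and 2 must match input dims 0 and 1; tensor<4x8x16xf32> is a valid init.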
2170 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2171 int64_t rank = getInit().getType().getRank();
2172 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2175 ArrayAttr BroadcastOp::getIndexingMaps() {
2176 Builder builder(getContext());
2177 int64_t rank = getInit().getType().getRank();
2178 return builder.getAffineMapArrayAttr(
2179 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2180 builder.getMultiDimIdentityMap(rank)});
2183 void BroadcastOp::getEffects(
2184 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2185 &effects) {
2186 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2189 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2190 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2193 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2194 MLIRContext *context) {
2195 results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2198 //===----------------------------------------------------------------------===//
2199 // YieldOp
2200 //===----------------------------------------------------------------------===//
2202 void linalg::YieldOp::print(OpAsmPrinter &p) {
2203 if (getNumOperands() > 0)
2204 p << ' ' << getOperands();
2205 p.printOptionalAttrDict((*this)->getAttrs());
2206 if (getNumOperands() > 0)
2207 p << " : " << getOperandTypes();
2210 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2211 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2212 SmallVector<Type, 2> types;
2213 SMLoc loc = parser.getCurrentLocation();
2214 return failure(parser.parseOperandList(opInfo) ||
2215 parser.parseOptionalAttrDict(result.attributes) ||
2216 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2217 parser.resolveOperands(opInfo, types, loc, result.operands));
2220 // Check that the operand number and types match the element types of the
2221 // LinalgOp interface's shaped operands.
2222 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2223 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2224 return op.emitOpError("expected number of yield values (")
2225 << op.getNumOperands()
2226 << ") to match the number of inits / outs operands of the enclosing "
2227 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2229 for (OpOperand &opOperand : op->getOpOperands()) {
2230 OpOperand *outputOperand =
2231 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2232 Type elementType = outputOperand->get().getType();
2233 if (isa<MemRefType, RankedTensorType>(elementType))
2234 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2235 if (opOperand.get().getType() != elementType)
2236 return op.emitOpError("type of yield operand ")
2237 << (opOperand.getOperandNumber() + 1) << " ("
2238 << opOperand.get().getType() << ") doesn't match "
2239 << "the element type of the enclosing linalg.generic op ("
2240 << elementType << ")";
2242 return success();
2245 LogicalResult linalg::YieldOp::verify() {
2246 auto *parentOp = (*this)->getParentOp();
2247 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2248 return emitOpError("expected single non-empty parent region");
2250 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2251 return verifyYield(*this, linalgOp);
2253 return emitOpError("expected parent op with LinalgOp interface");
2256 //===----------------------------------------------------------------------===//
2257 // IndexOp
2258 //===----------------------------------------------------------------------===//
2260 LogicalResult IndexOp::verify() {
2261 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2262 if (!linalgOp)
2263 return emitOpError("expected parent op with LinalgOp interface");
2264 if (linalgOp.getNumLoops() <= getDim())
2265 return emitOpError("expected dim (")
2266 << getDim() << ") to be lower than the number of loops ("
2267 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2268 return success();
2271 /////// Operations corresponding to library calls defined with Tablegen ////////
2273 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2275 #define GET_OP_CLASSES
2276 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2278 #define GET_OP_CLASSES
2279 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2281 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2282 unsigned rank,
2283 MLIRContext *context) {
2284 if (maybeMap)
2285 return *maybeMap;
2286 if (rank == 0)
2287 return AffineMap::get(context);
2288 return AffineMap::getMultiDimIdentityMap(rank, context);
2291 SmallVector<AffineExpr, 4>
2292 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2293 MLIRContext *context) {
2294 SmallVector<AffineExpr, 4> res;
2295 res.reserve(num);
2296 for (unsigned i = 0; i < num; ++i)
2297 res.push_back(getAffineDimExpr(startIdx++, context));
2298 return res;
2301 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2302 ArrayRef<AffineExpr> b) {
2303 auto rangeA = llvm::make_range(a.begin(), a.end());
2304 auto rangeB = llvm::make_range(b.begin(), b.end());
2305 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2306 return llvm::to_vector<4>(concatRanges);
2309 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2310 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2311 ss << "view";
2312 for (auto size : memref.getShape())
2313 if (size < 0)
2314 ss << "sx";
2315 else
2316 ss << size << "x";
2317 if (failed(appendMangledType(ss, memref.getElementType())))
2318 return failure();
2319 if (auto as = memref.getMemorySpace()) {
2320 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2321 ss << "as" << attr.getInt();
2322 else
2323 return failure();
2325 return success();
2327 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2328 ss << "vector";
2329 llvm::interleave(
2330 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2331 if (failed(appendMangledType(ss, vec.getElementType())))
2332 return failure();
2333 return success();
2335 if (t.isSignlessIntOrIndexOrFloat()) {
2336 ss << t;
2337 return success();
2339 return failure();
2342 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2343 assert(isa<LinalgOp>(op));
2344 std::string name(op->getName().getStringRef().str());
2345 std::string fun = "";
2346 for (NamedAttribute kv : op->getAttrs()) {
2347 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2348 fun = stringifyEnum(ufa.getValue()).str() + "_";
2349 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2350 fun = stringifyEnum(bfa.getValue()).str() + "_";
2353 name.reserve(128);
2354 std::replace(name.begin(), name.end(), '.', '_');
2355 llvm::raw_string_ostream ss(name);
2356 ss << "_" << fun;
2357 for (Type t : op->getOperandTypes()) {
2358 if (failed(appendMangledType(ss, t)))
2359 return std::string();
2360 ss << "_";
2362 name.pop_back();
2363 return name;
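// Illustrative output of the mangling above (representative types, not from
// a specific test): for linalg.matmul on memref<4x8xf32>, memref<8x16xf32>
// and memref<4x16xf32>, the generated name is
//   linalg_matmul_view4x8xf32_view8x16xf32_view4x16xf32
// with any dynamic size rendered as "s", e.g. memref<?x8xf32> -> viewsx8xf32.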
2366 //===----------------------------------------------------------------------===//
2367 // Canonicalizers and Folders.
2368 //===----------------------------------------------------------------------===//
2370 namespace {
2371 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2372 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2374 LogicalResult matchAndRewrite(LinalgOp op,
2375 PatternRewriter &rewriter) const override {
2376 for (OpOperand &opOperand : op->getOpOperands()) {
2377 // Linalg "inputs" may be either tensor or memref type.
2378 // tensor<0xelt_type> is a convention that may not always mean
2379 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2380 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2381 if (!mt)
2382 continue;
2383 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2384 rewriter.eraseOp(op);
2385 return success();
2388 return failure();
2392 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2393 /// result that is more static than the linalg op.
2394 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2395 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2397 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2398 PatternRewriter &rewriter) const override {
2399 if (!tensor::canFoldIntoProducerOp(castOp))
2400 return failure();
2402 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2403 if (!linalgOp)
2404 return failure();
2406 // The cast can be in a conditionally reachable region, in which case folding
2407 // will generate invalid code. Only conservatively fold ops in the same block
2408 // for now.
2409 if (castOp->getBlock() != linalgOp->getBlock())
2410 return failure();
2412 OpBuilder::InsertionGuard guard(rewriter);
2413 rewriter.setInsertionPoint(linalgOp);
2415 Location loc = linalgOp.getLoc();
2416 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2417 unsigned resultNumber = resultValue.getResultNumber();
2418 auto resultType =
2419 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2420 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2421 // going from a more dynamic shape to a less dynamic shape. If the producer
2422 // for this cast, i.e. producer of the out operand, is also an operation
2423 // that folds with tensor.cast consumer (like this pattern), the cast will
2424 // continue to propagate as far up the stack as it can go.
2425 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2426 Value newOperand =
2427 rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2428 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2429 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2430 linalgOp.getDpsInits().end());
2431 outputOperands[resultNumber] = newOperand;
2432 newOperands.append(outputOperands.begin(), outputOperands.end());
2434 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2435 linalgOp->result_type_end());
2436 resultTypes[resultNumber] = resultType;
2437 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2439 // Create a tensor.cast operation back to the original type.
2440 Value castBack = rewriter.create<tensor::CastOp>(
2441 loc, resultValue.getType(), newOp->getResult(resultNumber));
2443 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2444 results[resultNumber] = castBack;
2445 rewriter.replaceOp(linalgOp, results);
2446 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2447 return success();
2451 /// For each operand in `operands`, this function maps the static sizes of its
2452 /// dimensions to their affine dim expressions.
2453 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2454 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2455 for (OpOperand &opOperand : operands) {
2456 if (linalgOp.isScalar(&opOperand))
2457 continue;
2458 Value src = opOperand.get();
2459 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2460 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2462 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2463 // `tensor.cast` operation and source of the cast operation has a static
2464 // shape, then assign it to the `sourceShape`.
2465 auto *parentOp = src.getDefiningOp();
2466 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2467 if (parentOp) {
2468 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2469 Value castSource = castOp.getSource();
2470 auto castSourceType =
2471 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2472 if (castSourceType && castSourceType.hasStaticShape())
2473 sourceShape = castSourceType.getShape();
2477 // If a dimension of the source shape is static, map the corresponding
2478 // affine dim expression to the known static size.
2479 for (unsigned i = 0; i < sourceShape.size(); i++) {
2480 if (sourceType.isDynamicDim(i))
2481 continue;
2482 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2483 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2488 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
2489 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2490 /// their result types are stored in `resultTypes`. If `opOperand` requires no
2491 /// change then `changeNeeded` is left untouched and the same operand is added
2492 /// to the `newOperands` list.
2493 static void createNewOperandWithStaticSizes(
2494 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2495 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2496 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2497 bool &changeNeeded) {
2498 Value src = opOperand->get();
2499 newOperands.push_back(src);
2500 if (linalgOp.isScalar(opOperand))
2501 return;
2502 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2503 Type resultType = sourceType;
2504 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2505 resultTypes.push_back(resultType);
2506 return;
2508 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2509 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2510 SmallVector<int64_t> newShape;
2511 // If operand is updated with new shape, `newOperandNeeded` will be
2512 // true.
2513 bool newOperandNeeded = false;
2514 for (unsigned i = 0; i < sourceShape.size(); i++) {
2515 int64_t dimShape = sourceShape[i];
2516 AffineExpr dimExpr = sourceMap.getResult(i);
2517 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2518 newShape.push_back(dimShape);
2519 continue;
2521 // Dimension has a dynamic shape and corresponding affine dim
2522 // expression is present in the map. So assign the size for the
2523 // given affine dim expression to the dimension.
2524 newShape.push_back(affineExprToSize[dimExpr]);
2525 newOperandNeeded = true;
2527 resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2528 if (newOperandNeeded) {
2529 changeNeeded = true;
2530 // Get the new operand value given its size and element type by
2531 // casting it.
2532 Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2533 unsigned index = opOperand->getOperandNumber();
2534 newOperands[index] = newOperand;
2536 if (linalgOp.isDpsInit(opOperand))
2537 resultTypes.push_back(resultType);
2540 /// Static shapes for the operands can be inferred if any one of the operands
2541 /// has a static shape. This can be done by referring to the affine dim
2542 /// expressions for the operand.
2543 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2544 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2546 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2547 PatternRewriter &rewriter) const override {
2548 if (!linalgOp.hasPureTensorSemantics())
2549 return failure();
2551 // Maps must be projected permutations.
2552 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2553 return !map.isProjectedPermutation();
2555 return failure();
2557 // Maps affine dim expressions to the static size of that dimension.
2558 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2559 Location loc = linalgOp.getLoc();
2561 // For each of the affine dim expression, check if the size is known. If
2562 // known add that in the map.
2563 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2565 SmallVector<Value> newOperands;
2566 SmallVector<Type> resultTypes;
2568 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2569 // change in their types.
2570 bool changeNeeded = false;
2571 newOperands.reserve(linalgOp->getNumOperands());
2572 resultTypes.reserve(linalgOp.getNumDpsInits());
2574 // Iterate over all the operands and update the static sizes.
2575 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2576 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2577 affineExprToSize, linalgOp, newOperands,
2578 resultTypes, changeNeeded);
2581 // If the generic op has all the required static information, no
2582 // canonicalization needed.
2583 if (!changeNeeded)
2584 return failure();
2586 // Clone op.
2587 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2588 SmallVector<Value> replacements;
2589 replacements.reserve(newOp->getNumResults());
2590 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2591 Value newResult = std::get<1>(it);
2592 Value oldResult = std::get<0>(it);
2593 Type newType = newResult.getType();
2594 Type oldType = oldResult.getType();
2595 replacements.push_back(
2596 (newType != oldType)
2597 ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2598 : newResult);
2600 rewriter.replaceOp(linalgOp, replacements);
2601 return success();
2605 } // namespace
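// As an illustration of InferStaticShapeOfOperands (assumed, simplified
// example): for a linalg.generic with identity indexing maps on operands of
// type tensor<8x?xf32> and tensor<?x16xf32>, populateMap records d0 = 8 and
// d1 = 16, so both operands are tensor.cast to tensor<8x16xf32> and the op
// is cloned with fully static types; a cast back restores the original
// result type for the users.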
2607 // All named ops' canonicalizers and folders are auto-generated in the
2608 // .cpp.inc.
2610 //===----------------------------------------------------------------------===//
2611 // SoftmaxOp
2612 //===----------------------------------------------------------------------===//
2614 LogicalResult SoftmaxOp::verify() {
2615 ShapedType inputType = getInputOperandType();
2616 ShapedType outputType = getOutputOperandType();
2618 ArrayRef<int64_t> inputShape = inputType.getShape();
2619 ArrayRef<int64_t> outputShape = outputType.getShape();
2620 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2621 return emitOpError("incompatible output shape");
2623 int64_t inputRank = getInputOperandRank();
2624 int64_t dimension = getDimension();
2625 if ((dimension < 0) || (dimension >= inputRank))
2626 return emitOpError("incorrect dimension specified");
2628 return success();
2631 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2632 int64_t operandRank = getInputOperandRank();
2633 SmallVector<Range> loopBounds(operandRank);
2634 Location loc = getLoc();
2635 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2636 Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2637 Value source = getInput();
2638 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2639 loopBounds[dim].offset = zero;
2640 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2641 loopBounds[dim].stride = one;
2643 return loopBounds;
2646 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2647 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2648 utils::IteratorType::parallel);
2649 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2650 return iteratorTypes;
2653 FailureOr<TilingResult>
2654 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2655 ArrayRef<OpFoldResult> offsets,
2656 ArrayRef<OpFoldResult> sizes) {
2657 int64_t rank = getInputOperandRank();
2658 auto oneAttr = builder.getI64IntegerAttr(1);
2659 SmallVector<OpFoldResult> strides(rank, oneAttr);
2660 SmallVector<Value> tiledOperands;
2661 Operation *inputSlice =
2662 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2663 if (!inputSlice) {
2664 return emitOpError("failed to compute input slice");
2666 tiledOperands.emplace_back(inputSlice->getResult(0));
2667 Operation *outputSlice =
2668 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2669 if (!outputSlice) {
2670 return emitOpError("failed to compute output slice");
2672 tiledOperands.emplace_back(outputSlice->getResult(0));
2674 SmallVector<Type, 4> resultTypes;
2675 if (hasPureTensorSemantics())
2676 resultTypes.push_back(tiledOperands[1].getType());
2677 Operation *tiledOp =
2678 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2680 return TilingResult{
2681 {tiledOp},
2682 SmallVector<Value>(tiledOp->getResults()),
2683 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2686 LogicalResult SoftmaxOp::getResultTilePosition(
2687 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2688 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2689 SmallVector<OpFoldResult> &resultSizes) {
2690 if (resultNumber == 0) {
2691 resultOffsets.assign(offsets.begin(), offsets.end());
2692 resultSizes.assign(sizes.begin(), sizes.end());
2693 return success();
2695 return failure();
2698 // cast(dynamic) -> static.
2699 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2700 return memref::foldMemRefCast(*this);
2703 LogicalResult
2704 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2705 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2706 SmallVector<OpFoldResult> shapes;
2707 Location loc = getOperation()->getLoc();
2708 IRRewriter rewriter(b);
2709 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2710 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2711 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2712 if (!outputShapedType.isDynamicDim(dim)) {
2713 // Static dim: Return IntegerAttr.
2714 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2715 } else {
2716 // Dynamic dim: Return Value.
2717 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2718 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2721 reifiedReturnShapes.emplace_back(std::move(shapes));
2722 return success();
2725 void SoftmaxOp::getEffects(
2726 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2727 &effects) {
2728 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2729 if (!llvm::isa<MemRefType>(operand.getType()))
2730 continue;
2731 effects.emplace_back(MemoryEffects::Read::get(),
2732 &getOperation()->getOpOperand(index), /*stage=*/0,
2733 /*effectOnFullRegion=*/true,
2734 SideEffects::DefaultResource::get());
2737 for (OpOperand &operand : getDpsInitsMutable()) {
2738 if (!llvm::isa<MemRefType>(operand.get().getType()))
2739 continue;
2740 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2741 /*effectOnFullRegion=*/true,
2742 SideEffects::DefaultResource::get());
2743 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2744 /*effectOnFullRegion=*/true,
2745 SideEffects::DefaultResource::get());
2749 // Helper functions for softmax decomposition.
2750 // @{
2752 // Helper function to produce the iterator types (reduction or parallel) and
2753 // affine maps for the iterators used in the decomposition of softmax.
2754 // This method creates:
2755 // If allParallel == true:
2756 // - iterator type: {parallel, ..., parallel}
2757 // - affine maps:
2758 // -- identity with inputRank dimensions.
2759 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2760 // where N == inputRank.
2762 // If allParallel == false:
2763 // - iterator type at dim(i) == parallel for i != \p dim and
2764 // dim(dim) == reduction.
2765 // - affine map:
2766 // -- identity with inputRank dimensions.
2767 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2768 // where N == inputRank.
2769 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2770 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2771 int64_t dim, bool allParallel = false) {
2772 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2773 utils::IteratorType::parallel);
2774 if (!allParallel)
2775 iteratorTypes[dim] = utils::IteratorType::reduction;
2776 MLIRContext *ctxt = builder.getContext();
2777 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2778 SmallVector<AffineExpr, 2> affineExprs;
2779 for (int i = 0; i < inputRank; i++) {
2780 if (i != dim)
2781 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2783 auto reductionMap =
2784 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2785 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2786 return std::make_tuple(iteratorTypes, indexingMaps);
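// For example (illustrative): with inputRank = 3 and dim = 1, the iterator
// types are {parallel, reduction, parallel} (or all parallel when
// allParallel is set) and the two maps are
//   (d0, d1, d2) -> (d0, d1, d2)   (identity)
//   (d0, d1, d2) -> (d0, d2)       (dim 1 dropped)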
2789 // Helper function to produce a linalg.generic that computes a reduction on
2790 // dimension \p dim with the operation type \p T.
2791 template <typename T>
2792 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2793 int64_t dim) {
2794 auto inputType = cast<ShapedType>(input.getType());
2795 ArrayRef<int64_t> inputShape = inputType.getShape();
2796 int64_t inputRank = inputShape.size();
2797 auto [iteratorTypes, indexingMaps] =
2798 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2799 assert(indexingMaps.size() == 2 &&
2800 "We should have two maps: 1 for the input, 1 for the output");
2801 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2803 auto genericOp = builder.create<linalg::GenericOp>(
2804 loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2805 [&](OpBuilder &b, Location loc, ValueRange args) {
2806 Value result = b.create<T>(loc, args[0], args[1]);
2807 b.create<linalg::YieldOp>(loc, result);
2809 return genericOp.getResult(0);
2812 /// Produce a linalg generic that computes the second step of the softmax
2813 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2814 /// on dimension \p dim.
2815 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2816 Value max, Value output, int64_t dim) {
2817 auto inputType = cast<ShapedType>(input.getType());
2818 ArrayRef<int64_t> inputShape = inputType.getShape();
2819 int64_t inputRank = inputShape.size();
2820 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2821 builder, inputRank, dim, /*allParallel=*/true);
2822 assert(indexingMaps.size() == 2 && "We should have one map for each input");
2823 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2824 // Add the affine map for the output argument.
2825 indexingMaps.push_back(indexingMaps[0]);
2826 auto genericOp = builder.create<linalg::GenericOp>(
2827 loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2828 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2829 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2830 Value result = b.create<math::ExpOp>(loc, diff);
2831 b.create<linalg::YieldOp>(loc, result);
2833 return genericOp.getResult(0);
2836 /// Produce a linalg generic that computes the final step of the softmax
2837 /// decomposition.
2838 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2839 /// yield n / d
2840 /// }
2841 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2842 Value denominator, Value output, int64_t dim) {
2843 auto inputType = cast<ShapedType>(numerator.getType());
2844 ArrayRef<int64_t> inputShape = inputType.getShape();
2845 int64_t inputRank = inputShape.size();
2846 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2847 builder, inputRank, dim, /*allParallel=*/true);
2848 assert(indexingMaps.size() == 2 &&
2849 "We should have one map for each input (2)");
2850 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2851 // Add the affine map for the output tensor.
2852 indexingMaps.push_back(indexingMaps[0]);
2853 auto genericOp = builder.create<linalg::GenericOp>(
2854 loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2855 indexingMaps, iteratorTypes,
2856 [&](OpBuilder &b, Location loc, ValueRange args) {
2857 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2858 b.create<linalg::YieldOp>(loc, result);
2860 return genericOp.getResult(0);
2862 // @} End helper functions for softmax decomposition.
2864 /// Given an N-dimensional tensor x, this method converts
2865 /// softmax(x) to the following sequence of operations:
2867 /// 1. Compute the max of x along dimension d. This results
2868 /// in an N-1 dimensional tensor m.
2869 /// m = max(x, dim = d)
2871 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2872 /// an N-dimensional tensor z.
2873 /// z = exp(x - m)
2875 /// 3. Compute the sum of z along dimension d. This results in
2876 /// an N-1 dimensional tensor l.
2877 /// l = sum(z, dim = d)
2879 /// 4. Divide z by l. This gives the N-dimensional softmax.
2880 /// softmax = z / l
2882 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2883 OpBuilder::InsertionGuard guard(b);
2884 b.setInsertionPoint(*this);
2885 Location loc = getLoc();
2886 Value input = getInput();
2887 ShapedType inputType = getInputOperandType();
2888 Type elementType = inputType.getElementType();
2889 int64_t reductionDim = getDimension();
2890 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2891 Value output = getOutput();
2892 dims.erase(dims.begin() + reductionDim);
2893 // Step 1: Compute max along dim.
2894 Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2895 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2896 elementType, b, loc,
2897 /*useOnlyFiniteValue=*/true);
2898 Value neutralForMaxFInit =
2899 b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2900 .result();
2901 Value max =
2902 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2904 // Step 2: Subtract max from input and exponentiate.
2905 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2907 // Step 3: Compute sum along dim.
2908 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2909 b, loc, /*useOnlyFiniteValue=*/true);
2910 Value zeroInit =
2911 b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2912 Value denominator =
2913 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2915 // Step 4: Compute softmax.
2916 Value result =
2917 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2918 return SmallVector<Value>{result};
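// Numeric sketch of the four steps (illustrative values): for x = [1.0, 2.0]
// and d = 0, m = 2.0, z = [exp(-1.0), exp(0.0)] ~= [0.3679, 1.0],
// l ~= 1.3679, and softmax(x) = z / l ~= [0.2689, 0.7311].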
2921 //===----------------------------------------------------------------------===//
2922 // WinogradFilterTransformOp
2923 //===----------------------------------------------------------------------===//
2925 LogicalResult WinogradFilterTransformOp::verify() {
2926 auto filterType = cast<ShapedType>(getFilter().getType());
2927 ArrayRef<int64_t> filterShape = filterType.getShape();
2928 int64_t filterH = filterShape[getFilterHDim()];
2929 int64_t filterW = filterShape[getFilterWDim()];
2930 int64_t r = getR();
2931 int64_t m = getM();
2933 if (filterH != r && filterH != 1)
2934 return emitOpError("expect filter height to equal either r or 1");
2935 if (filterW != r && filterW != 1)
2936 return emitOpError("expect filter width to equal either r or 1");
2937 if (filterH == 1 && filterW == 1)
2938 return emitOpError("expect either filter height or width to equal r");
2940 SmallVector<int64_t> expectedOutputShape;
2941 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2942 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2943 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2944 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2946 auto outputType = cast<ShapedType>(getOutput().getType());
2947 ArrayRef<int64_t> outputShape = outputType.getShape();
2948 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2949 return emitOpError("unexpected output shape");
2951 return success();
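// Shape example for the checks above (illustrative): with m = 4 and r = 3
// (so alpha = m + r - 1 = 6), a filter tensor<64x3x3x32xf32> laid out as
// (F, H, W, C) transforms into an output of shape (6, 6, 32, 64).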
2954 SmallVector<Range>
2955 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2956 Location loc = getLoc();
2957 IntegerAttr zeroAttr = builder.getIndexAttr(0);
2958 IntegerAttr oneAttr = builder.getIndexAttr(1);
2959 Value filter = getFilter();
2960 int64_t filterRank = getFilterOperandRank();
2961 SmallVector<Range> loopBounds(filterRank);
2962 for (unsigned dim = 0; dim < filterRank; ++dim) {
2963 loopBounds[dim].offset = zeroAttr;
2964 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2965 loopBounds[dim].stride = oneAttr;
2967 return loopBounds;
2970 SmallVector<utils::IteratorType>
2971 WinogradFilterTransformOp::getLoopIteratorTypes() {
2972 int64_t filterRank = getFilterOperandRank();
2973 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2974 utils::IteratorType::parallel);
2975 return iteratorTypes;

LogicalResult WinogradFilterTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = filterH != 1 ? alpha : 1;
  int64_t alphaW = filterW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
  resultSizes.append(
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

  return success();
}

/// Implement tiling for winograd_filter_transform.
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
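/// For example (a sketch, assuming m = 4 and r = 3, so alpha = m + r - 1 = 6):
/// tiling a (128, 3, 3, 64) filter with tile sizes 32 for F and 16 for C
/// extracts a (32, 3, 3, 16) slice of the filter and computes a
/// (6, 6, 16, 32) slice of the (alphaH, alphaW, C, F) output.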
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t tileSize = m + r - 1;
  bool leftTransform = inputH != 1;
  bool rightTransform = inputW != 1;
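
  // For example (a sketch, assuming m = 4 and r = 3): the transform reads
  // overlapping tileSize x tileSize = 6 x 6 input windows stepping by m = 4,
  // so a static input height of 10 yields (10 - (3 - 1)) / 4 = 2 tiles along
  // H, and similarly along W.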

  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : 1;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : 1;
  }
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("unexpected output shape");
  }
  return success();
}

SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  SmallVector<Range> loopBounds(outputRank);
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, C
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}

LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType inputType = getInputOperandType();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = inputH != 1 ? alpha : 1;
  int64_t alphaW = inputW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});

  return success();
}

/// Implement tiling for winograd_input_transform.
/// The input of winograd_input_transform is (N, H, W, C).
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
/// C). Users can specify the tile sizes of tileH, tileW, N, and C.
/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one
/// tile.
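/// For example (a sketch, assuming m = 4 and r = 3): a tile covering 2 x 2
/// Winograd tiles reads a 2 * 4 + (3 - 1) = 10 x 10 input window (adjacent
/// windows overlap by r - 1 = 2 rows/columns) and produces a
/// (6, 6, 2, 2, N, C) slice of the output.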
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType inputType = getInputOperandType();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  OpFoldResult offsetH =
      inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  int64_t m = getM();
  int64_t r = getR();
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expected input height to equal the input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expected input width to equal the input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("unexpected output shape");
  }
  return success();
}

SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  SmallVector<Range> loopBounds(valueRank);
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, F
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}

LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  int64_t m = getM();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto affineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);

  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileWDim()]);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileWDim()]);

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  OpFoldResult offsetH =
      valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  OpFoldResult sizeH =
      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);

  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
  resultSizes.append(
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
  return success();
}

/// Implement tiling for winograd_output_transform.
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
/// F). The output of winograd_output_transform is (N, H, W, F).
/// Users can specify the tile sizes of tileH, tileW, N, and F.
/// `offsets` are the values for the offsets of tileH, tileW, N, F for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, F for one
/// tile.
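/// For example (a sketch, assuming m = 4): each Winograd tile contributes an
/// m x m = 4 x 4 block of the final (N, H, W, F) output, so a slice covering
/// 2 x 2 tiles writes an 8 x 8 H x W region whose offset is the tile offset
/// scaled by m.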
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

/// Returns true if all the result AffineExprs of \p explicitMap also appear in
/// \p defaultMap, i.e. the explicit map only uses (a permutation or subset of)
/// the default result expressions.
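/// For example, against the default LHS map (d0, d2): (d2, d0) is accepted (a
/// transpose) and (d2) is accepted (a broadcast), while (d1, d2) is rejected
/// because d1 is not a result of the default map.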
static bool isValidResultDimExprs(AffineMap explicitMap,
                                  AffineMap defaultMap) {
  auto explicitRange = explicitMap.getResults();
  auto defaultRange = defaultMap.getResults();
  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
}

/// Returns true if \p explicitMap is a broadcast with respect to
/// \p defaultMap, i.e. it has fewer result dimensions.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}

/// Verifies the broadcast and transpose semantics specified by the explicit
/// indexing map of the MatmulOp \p matmulOp for the operand at index
/// \p opIndex.
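/// For example (a sketch): for the LHS (opIndex == 0), whose default map is
/// (d0, d2), the map affine_map<(d0, d1, d2) -> (d2, d0)> requests a
/// transposed A, and affine_map<(d0, d1, d2) -> (d2)> requests a broadcast of
/// A along d0; both pass this verification.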
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                  unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // Check general validity of indexing map results.
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return matmulOp->emitOpError()
           << "unexpected dim expression in map result";

  // Check if the requested broadcast is valid.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "invalid broadcast requested, should be (d2)";
    }
    return success();
  }
  return success();
}

namespace mlir {
namespace linalg {

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//

/// Returns a list of AffineMaps with the typical matmul indexing
/// characteristics.
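/// In MLIR notation the three maps built below are:
///   affine_map<(d0, d1, d2) -> (d0, d2)>  // A
///   affine_map<(d0, d1, d2) -> (d2, d1)>  // B
///   affine_map<(d0, d1, d2) -> (d0, d1)>  // C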
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2);
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}

SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Checks if the op has broadcast and/or transpose semantics. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
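/// A sketch of the region this builds for f32 operands (block argument names
/// are illustrative; the `cast` attribute defaults to cast_signed and is a
/// no-op when the types already match):
///   ^bb0(%a: f32, %b: f32, %c: f32):
///     %0 = arith.mulf %a, %b : f32
///     %1 = arith.addf %c, %0 : f32
///     linalg.yield %1 : f32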
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}

/// Returns true if the given broadcast map \p bcastMap is valid for this op.
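/// For example, affine_map<(d0, d1, d2) -> (d2)> is valid for either operand
/// (d2 is the shared reduction dimension), whereas
/// affine_map<(d0, d1, d2) -> (d0)> is not.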
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr exp = bcastMap.getResult(0);
  // The map is invalid if it does not refer to the common (reduction)
  // dimension of the matmul.
  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();

    if (parser.parseLSquare())
      return failure();

    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}
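
// A sketch of the custom assembly handled by the parser above and the printer
// below (the maps and value names are hypothetical user-supplied ones):
//   linalg.matmul indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
//                                  affine_map<(d0, d1, d2) -> (d2, d1)>,
//                                  affine_map<(d0, d1, d2) -> (d0, d1)>]
//       ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
//       outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>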
void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);

  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
}

/// Verify the user-defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

} // namespace linalg
} // namespace mlir