//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
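/// For example (illustrative, names hypothetical): for `%v : tensor<4x?xf32>`,
/// dim 0 folds to the constant index attribute 4, while dim 1 materializes
/// `%d = tensor.dim %v, %c1 : tensor<4x?xf32>`.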
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
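/// Illustrative forms (operand names hypothetical, types elided):
///   %s = tensor.extract_slice %source[%off][%size][%stride] : ...
///   %s = memref.subview %source[%off][%size][%stride] : ...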
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//
using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
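/// For example (hypothetical types): with `inputTypes = {tensor<4xf32>}` and
/// `outputTypes = {tensor<4xi32>}`, the created block has the elemental
/// arguments `(%in: f32, %out: i32)`.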
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
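/// E.g. (illustrative): when `resultTensorTypes` is none and the outputs are
/// `(tensor<4xf32>, memref<4xf32>)`, the derived result types are just
/// `(tensor<4xf32>)`; memref outputs contribute no results.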
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
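/// The parsed textual form looks like (illustrative):
///   ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
///   outs(%c : tensor<4x16xf32>)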
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. They are helpers that build the unary, binary, and type conversion
// functions defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc
// for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr.
// This can be extended later if it becomes possible to fail construction of
// the region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of
// integer and floating point types, but they will not internally cast within
// bit widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or
// need to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//
namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }
  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }
  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }
  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
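///
/// Illustrative example (names and shapes hypothetical):
///   %0 = linalg.fill ins(%cst : f32)
///          outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///          : tensor<4x8xf32> into tensor<32xf32>
/// folds to
///   %2 = tensor.collapse_shape %init [[0, 1]]
///          : tensor<4x8xf32> into tensor<32xf32>
///   %1 = linalg.fill ins(%cst : f32)
///          outs(%2 : tensor<32xf32>) -> tensor<32xf32>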
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
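///
/// Sketch (illustrative, names hypothetical): padding a fill of %cst by %cst,
///   %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
///   %1 = tensor.pad %0 low[2] high[2] { ... tensor.yield %cst : f32 }
///          : tensor<4xf32> to tensor<8xf32>
/// becomes a fill of an empty tensor of the padded shape:
///   %e = tensor.empty() : tensor<8xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%e : tensor<8xf32>) -> tensor<8xf32>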
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
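///
/// Roughly (illustrative): when the pad value matches the value the insert
/// destination was filled with, and all insert_slice ops touching that
/// destination are provably disjoint,
///   tensor.insert_slice %pad into %fill[offsets]...
/// becomes an insert of the unpadded source at offsets shifted by the low
/// padding sizes.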
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
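///
/// E.g. (illustrative): `tensor.extract %0[%i, %j]`, where `%0` is produced by
/// `linalg.fill ins(%cst : f32) outs(...)`, is replaced by `%cst` directly.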
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if tensor input of tensor.extract op is the result of a linalg.fill
    // op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get scalar input operand of linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace tensor.extract op with scalar value used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
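///
/// Sketch (hypothetical shapes), where %cst is both filled and padded:
///   %0 = linalg.fill ins(%cst : f32)
///          outs(%s : tensor<16x16xf32>) -> tensor<16x16xf32>
///   %1 = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1]
///          inner_tiles = [8, 8] into %dest
///          : tensor<16x16xf32> -> tensor<2x2x8x8xf32>
/// folds to `linalg.fill ins(%cst : f32) outs(%dest : tensor<2x2x8x8xf32>)`.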
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}
/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};
/// Fold fill with copy.
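///
/// Sketch (illustrative): copying out of a fill is itself a fill,
///   %0 = linalg.fill ins(%cst : f32) outs(%a : tensor<8xf32>) -> tensor<8xf32>
///   %1 = linalg.copy ins(%0 : tensor<8xf32>) outs(%b : tensor<8xf32>) -> tensor<8xf32>
/// becomes `linalg.fill ins(%cst : f32) outs(%b : tensor<8xf32>)`.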
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
/// Fold fill with transpose.
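///
/// Sketch (illustrative): transposing a fill just fills the transposed init,
///   %0 = linalg.fill ins(%cst : f32) outs(%a : tensor<4x8xf32>) -> tensor<4x8xf32>
///   %1 = linalg.transpose ins(%0 : tensor<4x8xf32>)
///          outs(%init : tensor<8x4xf32>) permutation = [1, 0]
/// becomes `linalg.fill ins(%cst : f32) outs(%init : tensor<8x4xf32>)`.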
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
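///
/// E.g. (hypothetical shapes): for two fills of the same %cst,
///   %0 = tensor.concat dim(0) %fillA, %fillB
///          : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
/// becomes a concat of their outs operands followed by one fill:
///   %outs = tensor.concat dim(0) %outA, %outB
///             : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
///   %0 = linalg.fill ins(%cst : f32) outs(%outs : tensor<5x4xf32>) -> tensor<5x4xf32>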
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits, which must be parsed into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }
namespace {

/// Remove any linalg operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
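///
/// E.g. (illustrative, with #id = affine_map<(d0) -> (d0)>):
///   %0 = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///        ins(%arg0 : tensor<8xf32>) outs(%arg1 : tensor<8xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     linalg.yield %in : f32
///   } -> tensor<8xf32>
/// is replaced by %arg0 (inserting a tensor.cast or sparse conversion if the
/// types differ).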
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
// Retrieve the operation from the body, if it is the only one (except
// yield) and if it takes the same number of arguments as the body does.
// If the initFirst flag is enabled, we check that init takes the first
// position in the operands of the payload.
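// For example (illustrative), in the short form
//   %mapped = linalg.map { arith.addf } ins(%lhs, %rhs : tensor<8xf32>,
//               tensor<8xf32>) outs(%init : tensor<8xf32>)
// the body holds exactly one arith.addf followed by linalg.yield, so the
// payload op is found and the short form can be round-tripped.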
1470 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1471 if (body->getOperations().size() != 2)
1472 return nullptr;
1473 Operation &payload = body->getOperations().front();
1474 assert(isa<YieldOp>(body->getOperations().back()));
1476 if (payload.getNumOperands() == 0 ||
1477 payload.getNumOperands() != body->getNumArguments())
1478 return nullptr;
1479 if (initFirst) {
1480 // check init
1481 if (payload.getOperands().back() != body->getArgument(0))
1482 return nullptr;
1483 // check rest
1484 for (const auto &[operand, bbArg] :
1485 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1486 if (bbArg != operand)
1487 return nullptr;
1489 } else {
1490 for (const auto &[operand, bbArg] :
1491 llvm::zip(payload.getOperands(), body->getArguments())) {
1492 if (bbArg != operand)
1493 return nullptr;
1494 }
1495 }
1496 return &payload;
1497 }
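// For example (illustrative), a short-formable linalg.reduce body with
// initFirst == true looks like
//   ^bb0(%in: f32, %init: f32):
//     %0 = arith.addf %init, %in : f32
//     linalg.yield %0 : f32
// i.e. exactly one payload op plus the yield, with the init block argument
// leading the payload operands.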
1499 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1500 SmallVector<StringRef> elidedAttrs;
1501 std::string attrToElide;
1502 p << " { " << payloadOp->getName().getStringRef();
1503 for (const auto &attr : payloadOp->getAttrs()) {
1504 auto fastAttr =
1505 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1506 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1507 attrToElide = attr.getName().str();
1508 elidedAttrs.push_back(attrToElide);
1509 break;
1510 }
1511 }
1512 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1513 p << " }";
1514 }
1516 void MapOp::print(OpAsmPrinter &p) {
1517 Block *mapper = getBody();
1518 Operation *payloadOp = findPayloadOp(mapper);
1519 if (payloadOp) {
1520 printShortForm(p, payloadOp);
1521 }
1523 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1524 p.printOptionalAttrDict((*this)->getAttrs());
1526 if (!payloadOp) {
1527 // Print region if the payload op was not detected.
1528 p.increaseIndent();
1529 p.printNewline();
1530 p << "(";
1531 llvm::interleaveComma(mapper->getArguments(), p,
1532 [&](auto arg) { p.printRegionArgument(arg); });
1533 p << ") ";
1535 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1536 p.decreaseIndent();
1537 }
1538 }
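// Printed forms, schematically (value names are illustrative): when a payload
// op is recognized,
//   linalg.map { arith.addf } ins(%a, %b : ...) outs(%init : ...)
// otherwise the full region form is used:
//   linalg.map ins(%a, %b : ...) outs(%init : ...)
//     (%in0: f32, %in1: f32) {
//       ...
//       linalg.yield %0 : f32
//     }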
1540 LogicalResult MapOp::verify() {
1541 auto *bodyBlock = getBody();
1542 auto blockArgs = bodyBlock->getArguments();
1544 // Checks that the number of `inputs` matches the arity of the `mapper` region.
1545 if (getInputs().size() != blockArgs.size())
1546 return emitOpError() << "expects number of operands to match the arity of "
1547 "mapper, but got: "
1548 << getInputs().size() << " and " << blockArgs.size();
1550 // The parameters of mapper should all match the element type of inputs.
1551 for (const auto &[bbArgType, inputArg] :
1552 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1553 auto inputElemType =
1554 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1555 if (bbArgType != inputElemType) {
1556 return emitOpError() << "expected element type of input " << inputElemType
1557 << " to match bbArg type " << bbArgType;
1561 // The shape of each input must match the shape of the output.
1562 auto outputShape = getInit().getType().getShape();
1563 for (Type inputArgType : TypeRange{getInputs()}) {
1564 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1565 if (inputElemShape != outputShape) {
1566 return emitOpError() << "expected shape of input (" << inputElemShape
1567 << ") to match shape of output (" << outputShape
1568 << ")";
1572 return success();
1573 }
1575 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1576 int64_t rank = getInit().getType().getRank();
1577 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1578 }
1580 ArrayAttr MapOp::getIndexingMaps() {
1581 Builder builder(getContext());
1582 int64_t rank = getInit().getType().getRank();
1583 int64_t numIndexingMaps = getOperands().size();
1584 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1585 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1586 }
1588 void MapOp::getEffects(
1589 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1590 &effects) {
1591 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1592 }
1594 Speculation::Speculatability MapOp::getSpeculatability() {
1595 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1596 }
1598 //===----------------------------------------------------------------------===//
1599 // ReduceOp
1600 //===----------------------------------------------------------------------===//
1602 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1603 OpAsmSetValueNameFn setNameFn) {
1604 for (Value v : getRegionInputArgs())
1605 setNameFn(v, "in");
1606 for (Value v : getRegionOutputArgs())
1607 setNameFn(v, "init");
1608 }
1610 void ReduceOp::getAsmResultNames(
1611 function_ref<void(Value, StringRef)> setNameFn) {
1612 if (!getResults().empty())
1613 setNameFn(getResults().front(), "reduced");
1614 }
1616 void ReduceOp::build(
1617 OpBuilder &builder, OperationState &result, ValueRange inputs,
1618 ValueRange inits, ArrayRef<int64_t> dimensions,
1619 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1620 ArrayRef<NamedAttribute> attributes) {
1621 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1622 result.addAttributes(attributes);
1624 // Add output types for `RankedTensorType` output arguments.
1625 for (Value init : inits) {
1626 Type initType = init.getType();
1627 if (llvm::isa<RankedTensorType>(initType))
1628 result.addTypes(initType);
1629 }
1631 if (bodyBuild)
1632 buildGenericRegion(builder, result.location, *result.regions.front(),
1633 inputs, inits, bodyBuild);
1634 }
1636 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1637 int64_t inputRank =
1638 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1639 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1640 utils::IteratorType::parallel);
1641 for (int64_t reductionDim : getDimensions())
1642 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1643 return iteratorTypes;
1644 }
1646 ArrayAttr ReduceOp::getIndexingMaps() {
1647 int64_t inputRank =
1648 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1649 SmallVector<AffineMap> affineMaps(
1650 getNumDpsInputs(),
1651 AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1652 AffineMap resultMap =
1653 AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1654 .dropResults(getDimensions());
1655 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1656 affineMaps.push_back(resultMap);
1657 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1658 }
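// For example, a rank-3 input reduced over dimensions = [1] yields the maps
//   inputs: (d0, d1, d2) -> (d0, d1, d2)
//   inits:  (d0, d1, d2) -> (d0, d2)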
1660 void ReduceOp::getEffects(
1661 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1662 &effects) {
1663 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1664 }
1666 Speculation::Speculatability ReduceOp::getSpeculatability() {
1667 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1668 }
1670 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1671 NamedAttrList &attributes,
1672 StringRef attributeName) {
1673 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1674 return failure();
1676 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1677 return success();
1678 }
1680 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1681 std::optional<OperationName> payloadOpName;
1682 NamedAttrList payloadOpAttrs;
1683 if (succeeded(parser.parseOptionalLBrace())) {
1684 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1685 if (failed(operationName))
1686 return failure();
1687 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1688 return failure();
1689 payloadOpName = operationName.value();
1690 if (parser.parseRBrace())
1691 return failure();
1692 }
1694 if (parseDstStyleOp(
1695 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1696 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1697 }))
1698 return failure();
1700 if (payloadOpName.has_value()) {
1701 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1702 ArrayRef(result.operands), /*initFirst=*/true);
1703 } else {
1704 SmallVector<OpAsmParser::Argument> regionArgs;
1705 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1706 /*allowType=*/true, /*allowAttrs=*/true)) {
1707 return failure();
1708 }
1710 Region *body = result.addRegion();
1711 if (parser.parseRegion(*body, regionArgs))
1712 return failure();
1713 }
1715 return success();
1716 }
1718 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1719 ArrayRef<int64_t> attributeValue) {
1720 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1721 }
1723 void ReduceOp::print(OpAsmPrinter &p) {
1724 Block *mapper = getBody();
1725 Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1726 if (payloadOp) {
1727 printShortForm(p, payloadOp);
1728 }
1730 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1731 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1732 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1733 if (!payloadOp) {
1734 // Print region if the payload op was not detected.
1735 p.increaseIndent();
1736 p.printNewline();
1737 p << "(";
1738 llvm::interleaveComma(mapper->getArguments(), p,
1739 [&](auto arg) { p.printRegionArgument(arg); });
1740 p << ") ";
1742 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1743 p.decreaseIndent();
1744 }
1745 }
1747 LogicalResult ReduceOp::verify() {
1748 ArrayRef<int64_t> dimensionsRef = getDimensions();
1750 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1751 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1752 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1753 return emitOpError() << "expects all inputs to have the same shapes. "
1754 "Shape at input-index "
1755 << i
1756 << " is not equal to the shape at input-index 0.";
1759 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1760 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1761 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1762 return emitOpError() << "expects all outputs to have the same shapes. "
1763 "Shape at output-index "
1764 << i
1765 << " is not equal to the shape at output-index 0.";
1768 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1769 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1771 DenseSet<int64_t> dimensionsToReduce;
1772 for (int64_t dimension : dimensionsRef) {
1773 if (dimension < 0 || dimension >= inputType.getRank()) {
1774 return emitOpError()
1775 << "dimensions for reduction should be in the range [0, "
1776 << inputType.getRank() - 1 << "].";
1777 }
1778 dimensionsToReduce.insert(dimension);
1779 }
1781 auto inputDims = inputType.getShape();
1782 auto initDims = initType.getShape();
1784 // Input dimensions that will be left after the reduction.
1785 SmallVector<int64_t> reducedInputDims;
1786 for (const auto &en : llvm::enumerate(inputDims)) {
1787 if (!dimensionsToReduce.count(en.index()))
1788 reducedInputDims.push_back(en.value());
1789 }
1791 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1792 return emitOpError() << "number of dimensions after reduction "
1793 << reducedInputDims.size()
1794 << " doesn't match the init rank "
1795 << initType.getRank();
1796 }
1798 if (reducedInputDims != initDims)
1799 return emitOpError() << "init dimensions [" << initDims
1800 << "] doesn't match input dimensions after reduction ["
1801 << reducedInputDims << "]";
1803 Block *block = getBody();
1804 if (block->getNumArguments() != this->getNumOperands())
1805 return emitOpError()
1806 << "mismatching number of operands and block arguments";
1808 // Check that the first block arguments match the element type of the inputs.
1809 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1810 Type inputElementType =
1811 llvm::cast<ShapedType>(input.getType()).getElementType();
1812 if (inputElementType != bbArg.getType())
1813 return emitOpError()
1814 << "input element type " << inputElementType
1815 << " does not match corresponding block argument type "
1816 << bbArg.getType();
1817 }
1819 // Check that the last block arguments match the element type of the outputs.
1820 for (auto [output, bbArg] : llvm::zip(
1821 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1822 auto outputElementType =
1823 llvm::cast<ShapedType>(output.getType()).getElementType();
1824 if (outputElementType != bbArg.getType())
1825 return emitOpError()
1826 << "output element type " << outputElementType
1827 << " does not match corresponding block argument type "
1828 << bbArg.getType();
1829 }
1830 return success();
1831 }
1833 //===----------------------------------------------------------------------===//
1834 // TransposeOp
1835 //===----------------------------------------------------------------------===//
1837 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1838 Region &region, ValueRange inputs,
1839 ValueRange outputs) {
1840 buildGenericRegion(builder, loc, region, inputs, outputs,
1841 [](OpBuilder &b, Location loc, ValueRange args) {
1842 if (!args.empty())
1843 b.create<linalg::YieldOp>(loc, args[0]);
1844 });
1845 }
1847 void TransposeOp::build(::mlir::OpBuilder &builder,
1848 ::mlir::OperationState &result, Value input, Value init,
1849 DenseI64ArrayAttr permutation,
1850 ArrayRef<NamedAttribute> attributes) {
1851 result.addOperands(input);
1852 result.addOperands(init);
1853 result.addAttribute(getPermutationAttrName(result.name), permutation);
1854 result.addAttributes(attributes);
1856 // Add output types for `RankedTensorType` output arguments.
1857 Type initType = init.getType();
1858 if (llvm::isa<RankedTensorType>(initType))
1859 result.addTypes(initType);
1861 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1862 init);
1863 }
1865 void TransposeOp::build(::mlir::OpBuilder &builder,
1866 ::mlir::OperationState &result, Value input, Value init,
1867 ArrayRef<int64_t> permutation,
1868 ArrayRef<NamedAttribute> attributes) {
1869 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1870 attributes);
1871 }
1873 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1874 if (failed(parseDstStyleOp(
1875 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1876 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1877 })))
1878 return failure();
1880 OpBuilder builder(parser.getContext());
1881 buildIdentityRegion(builder, result.location, *result.addRegion(),
1882 /*inputs=*/result.operands,
1883 /*outputs=*/{});
1884 return success();
1885 }
1887 void TransposeOp::getAsmResultNames(
1888 function_ref<void(Value, StringRef)> setNameFn) {
1889 if (!getResults().empty())
1890 setNameFn(getResults().front(), "transposed");
1891 }
1893 void TransposeOp::print(OpAsmPrinter &p) {
1894 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1895 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1896 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1897 }
1899 LogicalResult TransposeOp::verify() {
1900 ArrayRef<int64_t> permutationRef = getPermutation();
1902 if (!isPermutationVector(permutationRef))
1903 return emitOpError("permutation is not valid");
1905 auto inputType = getInput().getType();
1906 auto initType = getInit().getType();
1908 int64_t rank = inputType.getRank();
1910 if (rank != initType.getRank())
1911 return emitOpError() << "input rank " << rank
1912 << " does not match init rank " << initType.getRank();
1914 if (rank != static_cast<int64_t>(permutationRef.size()))
1915 return emitOpError() << "size of permutation " << permutationRef.size()
1916 << " does not match the argument rank " << rank;
1918 auto inputDims = inputType.getShape();
1919 auto initDims = initType.getShape();
1921 for (int64_t i = 0; i < rank; ++i) {
1922 int64_t inputDim = inputDims[permutationRef[i]];
1923 int64_t initDim = initDims[i];
1925 if (inputDim != initDim) {
1926 return emitOpError() << "dim(result, " << i << ") = " << initDim
1927 << " doesn't match dim(input, permutation[" << i
1928 << "]) = " << inputDim;
1932 return success();
1933 }
1935 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1936 int64_t rank = getInit().getType().getRank();
1937 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1938 }
1940 ArrayAttr TransposeOp::getIndexingMaps() {
1941 Builder builder(getContext());
1942 int64_t rank = getInit().getType().getRank();
1943 return builder.getAffineMapArrayAttr(
1944 {inversePermutation(AffineMap::getPermutationMap(
1945 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1946 builder.getMultiDimIdentityMap(rank)});
1947 }
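// For example, with permutation = [1, 0] this returns
//   input: (d0, d1) -> (d1, d0)
//   init:  (d0, d1) -> (d0, d1)
// (a transposition is its own inverse, so the inverted permutation map equals
// the permutation map in this case).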
1949 void TransposeOp::getEffects(
1950 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1951 &effects) {
1952 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1953 }
1955 Speculation::Speculatability TransposeOp::getSpeculatability() {
1956 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1957 }
1959 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1960 SmallVectorImpl<OpFoldResult> &result) {
1961 // Only the tensor type is supported.
1962 if (!isa<TensorType>(getInput().getType()))
1963 return failure();
1965 // Empty permutation: a rank-0 transpose is a no-op.
1966 if (getPermutation().size() == 0) {
1967 result.push_back(getInput());
1968 return success();
1969 }
1970 // Identity permutation.
1971 if (isIdentityPermutation(getPermutation())) {
1972 result.push_back(getInput());
1973 return success();
1974 }
1976 return failure();
1977 }
1979 /// Fold transpose with transpose.
1980 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1981 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1983 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1984 PatternRewriter &rewriter) const override {
1985 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1986 if (!defTransposeOp)
1987 return failure();
1988 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1989 ArrayRef<int64_t> perms = transposeOp.getPermutation();
1990 SmallVector<int64_t> foldedPerms;
1991 foldedPerms.reserve(perms.size());
1992 for (int64_t perm : perms)
1993 foldedPerms.push_back(defPerms[perm]);
1995 rewriter.replaceOpWithNewOp<TransposeOp>(
1996 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1997 foldedPerms);
1998 return success();
1999 }
2000 };
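// Worked example (illustrative): a producer transpose with permutation
// [2, 0, 1] followed by a consumer with permutation [1, 2, 0] folds to
//   [defPerms[1], defPerms[2], defPerms[0]] = [0, 1, 2]
// i.e. the identity permutation, which TransposeOp::fold can then remove.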
2002 /// This pattern canonicalizes transpose ops by swapping the order of
2003 /// broadcast and transpose:
2004 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2005 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2006 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2008 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2009 PatternRewriter &rewriter) const override {
2010 Value input = transposeOp.getInput();
2011 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2012 if (!input.hasOneUse() || !broadcastOp)
2013 return failure();
2015 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2016 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2018 // Get new perms and new dimensions.
2019 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2020 SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2021 SmallVector<int64_t> resultDimensions;
2022 unsigned dimensionSize = dimensions.size();
2023 for (unsigned i = 0; i < dimensionSize; ++i)
2024 resultDimensions.push_back(invertPerm[dimensions[i]]);
2026 // Create transpose result.
2027 Value broadcastInput = broadcastOp.getInput();
2028 Location loc = transposeOp.getLoc();
2029 MLIRContext *ctx = transposeOp.getContext();
2030 SmallVector<OpFoldResult> dims;
2031 auto broadcastInputTy =
2032 mlir::cast<RankedTensorType>(broadcastInput.getType());
2033 unsigned inputRank = broadcastInputTy.getRank();
2034 for (unsigned i = 0; i < inputRank; ++i) {
2035 if (broadcastInputTy.isDynamicDim(i)) {
2036 dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2037 ->getResult(0));
2038 } else {
2039 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2040 broadcastInputTy.getDimSize(i)));
2041 }
2042 }
2043 SmallVector<OpFoldResult> transposeResultShapes =
2044 applyPermutation(dims, resultPerms);
2045 Value transposeInit = rewriter.create<tensor::EmptyOp>(
2046 transposeOp.getLoc(), transposeResultShapes,
2047 broadcastInputTy.getElementType());
2049 // Create broadcast(transpose(input)).
2050 Value transposeResult =
2051 rewriter
2052 .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2053 resultPerms)
2054 ->getResult(0);
2055 rewriter.replaceOpWithNewOp<BroadcastOp>(
2056 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2057 return success();
2058 }
2059 };
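// Illustrative rewrite (hypothetical shapes and value names):
//   %b = linalg.broadcast ins(%in : tensor<3x4xf32>)
//          outs(%e0 : tensor<3x5x4xf32>) dimensions = [1]
//   %t = linalg.transpose ins(%b : tensor<3x5x4xf32>)
//          outs(%e1 : tensor<4x3x5xf32>) permutation = [2, 0, 1]
// becomes
//   %t = linalg.transpose ins(%in : tensor<3x4xf32>)
//          outs(%e2 : tensor<4x3xf32>) permutation = [1, 0]
//   %b = linalg.broadcast ins(%t : tensor<4x3xf32>)
//          outs(%e1 : tensor<4x3x5xf32>) dimensions = [2]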
2061 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2062 MLIRContext *context) {
2063 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2064 }
2066 //===----------------------------------------------------------------------===//
2067 // BroadcastOp
2068 //===----------------------------------------------------------------------===//
2070 void BroadcastOp::build(::mlir::OpBuilder &builder,
2071 ::mlir::OperationState &result, Value input, Value init,
2072 DenseI64ArrayAttr dimensions,
2073 ArrayRef<NamedAttribute> attributes) {
2074 result.addOperands(input);
2075 result.addOperands(init);
2076 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2077 result.addAttributes(attributes);
2079 // Add output types for `RankedTensorType` output arguments.
2080 Type initType = init.getType();
2081 if (llvm::isa<RankedTensorType>(initType))
2082 result.addTypes(initType);
2084 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2085 init);
2086 }
2088 void BroadcastOp::build(::mlir::OpBuilder &builder,
2089 ::mlir::OperationState &result, Value input, Value init,
2090 ArrayRef<int64_t> dimensions,
2091 ArrayRef<NamedAttribute> attributes) {
2092 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2093 attributes);
2094 }
2096 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2097 if (failed(parseDstStyleOp(
2098 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2099 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2100 })))
2101 return failure();
2103 OpBuilder builder(parser.getContext());
2104 buildIdentityRegion(builder, result.location, *result.addRegion(),
2105 /*inputs=*/result.operands,
2106 /*outputs=*/{});
2107 return success();
2108 }
2110 void BroadcastOp::getAsmResultNames(
2111 function_ref<void(Value, StringRef)> setNameFn) {
2112 if (!getResults().empty())
2113 setNameFn(getResults().front(), "broadcasted");
2114 }
2116 void BroadcastOp::print(OpAsmPrinter &p) {
2117 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2118 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2119 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2120 }
2122 LogicalResult BroadcastOp::verify() {
2123 ArrayRef<int64_t> dimensionsRef = getDimensions();
2125 auto inputType = getInput().getType();
2126 auto initType = getInit().getType();
2128 int64_t inputRank = inputType.getRank();
2129 int64_t initRank = initType.getRank();
2131 auto inputShape = inputType.getShape();
2132 auto initShape = initType.getShape();
2134 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2135 return emitOpError() << "input rank plus added dimensions does not "
2136 "match init rank. input rank: "
2137 << inputRank
2138 << ", dimensions size: " << dimensionsRef.size()
2139 << ", init rank: " << initRank;
2141 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2142 if (dim < 0 || dim >= initRank)
2143 return emitOpError() << "dimension " << idx
2144 << " is out of range. expected range: [0, "
2145 << initRank - 1 << "], got: " << dim;
2146 }
2148 // Mapping from input dims to init dims.
2149 SmallVector<int64_t> dimMap;
2150 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2151 if (!llvm::is_contained(dimensionsRef, dim))
2152 dimMap.push_back(dim);
2153 }
2155 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2156 // This dimension is mapped from the input. Init and input dims should
2157 // match.
2158 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2159 return emitOpError() << "input dim " << inputDimIdx
2160 << " should match init dim " << initDimIdx
2161 << ". input: " << inputShape[inputDimIdx]
2162 << ", init: " << initShape[initDimIdx];
2165 return success();
2166 }
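// For example (illustrative): ins tensor<4x16xf32> with dimensions = [0, 2]
// requires a rank-4 init, e.g. tensor<2x4x8x16xf32>; dimMap is then [1, 3],
// so input dim 0 must match init dim 1 and input dim 1 must match init dim 3.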
2168 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2169 int64_t rank = getInit().getType().getRank();
2170 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2171 }
2173 ArrayAttr BroadcastOp::getIndexingMaps() {
2174 Builder builder(getContext());
2175 int64_t rank = getInit().getType().getRank();
2176 return builder.getAffineMapArrayAttr(
2177 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2178 builder.getMultiDimIdentityMap(rank)});
2179 }
2181 void BroadcastOp::getEffects(
2182 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2183 &effects) {
2184 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2185 }
2187 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2188 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2189 }
2191 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2192 MLIRContext *context) {
2193 results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2194 }
2196 //===----------------------------------------------------------------------===//
2197 // YieldOp
2198 //===----------------------------------------------------------------------===//
2200 void linalg::YieldOp::print(OpAsmPrinter &p) {
2201 if (getNumOperands() > 0)
2202 p << ' ' << getOperands();
2203 p.printOptionalAttrDict((*this)->getAttrs());
2204 if (getNumOperands() > 0)
2205 p << " : " << getOperandTypes();
2208 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2209 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2210 SmallVector<Type, 2> types;
2211 SMLoc loc = parser.getCurrentLocation();
2212 return failure(parser.parseOperandList(opInfo) ||
2213 parser.parseOptionalAttrDict(result.attributes) ||
2214 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2215 parser.resolveOperands(opInfo, types, loc, result.operands));
2216 }
2218 // Check that the number and types of the operands match the element types
2219 // of the LinalgOp interface's shaped operands.
2220 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2221 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2222 return op.emitOpError("expected number of yield values (")
2223 << op.getNumOperands()
2224 << ") to match the number of inits / outs operands of the enclosing "
2225 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2227 for (OpOperand &opOperand : op->getOpOperands()) {
2228 OpOperand *outputOperand =
2229 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2230 Type elementType = outputOperand->get().getType();
2231 if (isa<MemRefType, RankedTensorType>(elementType))
2232 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2233 if (opOperand.get().getType() != elementType)
2234 return op.emitOpError("type of yield operand ")
2235 << (opOperand.getOperandNumber() + 1) << " ("
2236 << opOperand.get().getType() << ") doesn't match "
2237 << "the element type of the enclosing linalg.generic op ("
2238 << elementType << ")";
2239 }
2240 return success();
2241 }
2243 LogicalResult linalg::YieldOp::verify() {
2244 auto *parentOp = (*this)->getParentOp();
2245 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2246 return emitOpError("expected single non-empty parent region");
2248 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2249 return verifyYield(*this, linalgOp);
2251 return emitOpError("expected parent op with LinalgOp interface");
2252 }
2254 //===----------------------------------------------------------------------===//
2255 // IndexOp
2256 //===----------------------------------------------------------------------===//
2258 LogicalResult IndexOp::verify() {
2259 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2260 if (!linalgOp)
2261 return emitOpError("expected parent op with LinalgOp interface");
2262 if (linalgOp.getNumLoops() <= getDim())
2263 return emitOpError("expected dim (")
2264 << getDim() << ") to be lower than the number of loops ("
2265 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2266 return success();
2267 }
2269 /////// Operations corresponding to library calls defined with Tablegen ////////
2271 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2273 #define GET_OP_CLASSES
2274 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2276 #define GET_OP_CLASSES
2277 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2279 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2280 unsigned rank,
2281 MLIRContext *context) {
2282 if (maybeMap)
2283 return *maybeMap;
2284 if (rank == 0)
2285 return AffineMap::get(context);
2286 return AffineMap::getMultiDimIdentityMap(rank, context);
2287 }
2289 SmallVector<AffineExpr, 4>
2290 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2291 MLIRContext *context) {
2292 SmallVector<AffineExpr, 4> res;
2293 res.reserve(num);
2294 for (unsigned i = 0; i < num; ++i)
2295 res.push_back(getAffineDimExpr(startIdx++, context));
2296 return res;
2297 }
2299 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2300 ArrayRef<AffineExpr> b) {
2301 auto rangeA = llvm::make_range(a.begin(), a.end());
2302 auto rangeB = llvm::make_range(b.begin(), b.end());
2303 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2304 return llvm::to_vector<4>(concatRanges);
2305 }
2307 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2308 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2309 ss << "view";
2310 for (auto size : memref.getShape())
2311 if (size < 0)
2312 ss << "sx";
2313 else
2314 ss << size << "x";
2315 if (failed(appendMangledType(ss, memref.getElementType())))
2316 return failure();
2317 if (auto as = memref.getMemorySpace()) {
2318 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2319 ss << "as" << attr.getInt();
2320 else
2321 return failure();
2322 }
2323 return success();
2324 }
2325 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2326 ss << "vector";
2327 llvm::interleave(
2328 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2329 if (failed(appendMangledType(ss, vec.getElementType())))
2330 return failure();
2331 return success();
2332 }
2333 if (t.isSignlessIntOrIndexOrFloat()) {
2334 ss << t;
2335 return success();
2336 }
2337 return failure();
2338 }
2340 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2341 assert(isa<LinalgOp>(op));
2342 std::string name(op->getName().getStringRef().str());
2343 std::string fun = "";
2344 for (NamedAttribute kv : op->getAttrs()) {
2345 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2346 fun = stringifyEnum(ufa.getValue()).str() + "_";
2347 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2348 fun = stringifyEnum(bfa.getValue()).str() + "_";
2349 }
2350 }
2351 name.reserve(128);
2352 std::replace(name.begin(), name.end(), '.', '_');
2353 llvm::raw_string_ostream ss(name);
2354 ss << "_" << fun;
2355 for (Type t : op->getOperandTypes()) {
2356 if (failed(appendMangledType(ss, t)))
2357 return std::string();
2358 ss << "_";
2360 name.pop_back();
2361 return name;
2362 }
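// For example (assuming no unary/binary fn attribute is set), linalg.matmul
// on three memref<?x?xf32> operands produces
//   "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32"
// since '.' is replaced by '_' and appendMangledType prints each dynamic size
// as "sx".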
2364 //===----------------------------------------------------------------------===//
2365 // Canonicalizers and Folders.
2366 //===----------------------------------------------------------------------===//
2368 namespace {
2369 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2370 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2372 LogicalResult matchAndRewrite(LinalgOp op,
2373 PatternRewriter &rewriter) const override {
2374 for (OpOperand &opOperand : op->getOpOperands()) {
2375 // Linalg "inputs" may be either tensor or memref type.
2376 // tensor<0xelt_type> is a convention that may not always mean
2377 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2378 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2379 if (!mt)
2380 continue;
2381 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2382 rewriter.eraseOp(op);
2383 return success();
2384 }
2385 }
2386 return failure();
2387 }
2388 };
2390 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has a
2391 /// result that is more static than the linalg op.
2392 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2393 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2395 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2396 PatternRewriter &rewriter) const override {
2397 if (!tensor::canFoldIntoProducerOp(castOp))
2398 return failure();
2400 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2401 if (!linalgOp)
2402 return failure();
2404 // The cast can be in a conditionally reachable region, in which case folding
2405 // will generate invalid code. Only conservatively fold ops in the same block
2406 // for now.
2407 if (castOp->getBlock() != linalgOp->getBlock())
2408 return failure();
2410 OpBuilder::InsertionGuard guard(rewriter);
2411 rewriter.setInsertionPoint(linalgOp);
2413 Location loc = linalgOp.getLoc();
2414 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2415 unsigned resultNumber = resultValue.getResultNumber();
2416 auto resultType =
2417 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2418 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2419 // going from a more dynamic shape to a less dynamic shape. If the producer
2420 // for this cast, i.e. producer of the out operand, is also an operation
2421 // that folds with tensor.cast consumer (like this pattern), the cast will
2422 // continue to propagate as far up the stack as it can go.
2423 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2424 Value newOperand =
2425 rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2426 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2427 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2428 linalgOp.getDpsInits().end());
2429 outputOperands[resultNumber] = newOperand;
2430 newOperands.append(outputOperands.begin(), outputOperands.end());
2432 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2433 linalgOp->result_type_end());
2434 resultTypes[resultNumber] = resultType;
2435 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2437 // Create a tensor.cast operation back to the original type.
2438 Value castBack = rewriter.create<tensor::CastOp>(
2439 loc, resultValue.getType(), newOp->getResult(resultNumber));
2441 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2442 results[resultNumber] = castBack;
2443 rewriter.replaceOp(linalgOp, results);
2444 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2445 return success();
2446 }
2447 };
2449 /// For each operand in `operands`, this function maps the static sizes of
2450 /// its dimensions to the corresponding affine dim expressions.
2451 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2452 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2453 for (OpOperand &opOperand : operands) {
2454 if (linalgOp.isScalar(&opOperand))
2455 continue;
2456 Value src = opOperand.get();
2457 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2458 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2460 // Get the `sourceShape` of the `sourceType`. If the operand is a result of a
2461 // `tensor.cast` operation and the source of the cast operation has a static
2462 // shape, then assign that static shape to `sourceShape`.
2463 auto *parentOp = src.getDefiningOp();
2464 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2465 if (parentOp) {
2466 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2467 Value castSource = castOp.getSource();
2468 auto castSourceType =
2469 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2470 if (castSourceType && castSourceType.hasStaticShape())
2471 sourceShape = castSourceType.getShape();
2472 }
2473 }
2475 // If a dimension of the source shape is static, map its affine dim
2476 // expression to the known static size.
2477 for (unsigned i = 0; i < sourceShape.size(); i++) {
2478 if (sourceType.isDynamicDim(i))
2479 continue;
2480 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2481 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2482 }
2483 }
2484 }
2486 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with the static
2487 /// sizes mapped in `affineExprToSize`. New operands are appended to
2488 /// `newOperands` and their result types are stored in `resultTypes`. If
2489 /// `opOperand` requires no change, then `changeNeeded` stays false and the
2490 /// same operand is added to the `newOperands` list.
2491 static void createNewOperandWithStaticSizes(
2492 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2493 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2494 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2495 bool &changeNeeded) {
2496 Value src = opOperand->get();
2497 newOperands.push_back(src);
2498 if (linalgOp.isScalar(opOperand))
2499 return;
2500 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2501 Type resultType = sourceType;
2502 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2503 resultTypes.push_back(resultType);
2504 return;
2505 }
2506 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2507 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2508 SmallVector<int64_t> newShape;
2509 // If the operand is updated with a new shape, `newOperandNeeded` will be
2510 // set to true.
2511 bool newOperandNeeded = false;
2512 for (unsigned i = 0; i < sourceShape.size(); i++) {
2513 int64_t dimShape = sourceShape[i];
2514 AffineExpr dimExpr = sourceMap.getResult(i);
2515 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2516 newShape.push_back(dimShape);
2517 continue;
2518 }
2519 // The dimension has a dynamic shape and the corresponding affine dim
2520 // expression is present in the map, so assign the size associated with
2521 // that affine dim expression to the dimension.
2522 newShape.push_back(affineExprToSize[dimExpr]);
2523 newOperandNeeded = true;
2524 }
2525 resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2526 if (newOperandNeeded) {
2527 changeNeeded = true;
2528 // Get the new operand value given its size and element type by
2529 // casting it.
2530 Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2531 unsigned index = opOperand->getOperandNumber();
2532 newOperands[index] = newOperand;
2533 }
2534 if (linalgOp.isDpsInit(opOperand))
2535 resultTypes.push_back(resultType);
2536 }
2538 /// Static shapes for the operands can be inferred if any one of the operands
2539 /// has a static shape. This can be done by referring to the affine dim
2540 /// expressions for the operand.
2541 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2542 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2544 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2545 PatternRewriter &rewriter) const override {
2546 if (!linalgOp.hasPureTensorSemantics())
2547 return failure();
2549 // Maps must be projected permutations.
2550 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2551 return !map.isProjectedPermutation();
2552 }))
2553 return failure();
2555 // Maps affine dim expressions to the static size of that dimension.
2556 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2557 Location loc = linalgOp.getLoc();
2559 // For each affine dim expression, check if the size is known. If it is,
2560 // add it to the map.
2561 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2563 SmallVector<Value> newOperands;
2564 SmallVector<Type> resultTypes;
2566 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2567 // change in their types.
2568 bool changeNeeded = false;
2569 newOperands.reserve(linalgOp->getNumOperands());
2570 resultTypes.reserve(linalgOp.getNumDpsInits());
2572 // Iterate over all the operands and update the static sizes.
2573 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2574 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2575 affineExprToSize, linalgOp, newOperands,
2576 resultTypes, changeNeeded);
2577 }
2579 // If the generic op already has all the required static information, no
2580 // canonicalization is needed.
2581 if (!changeNeeded)
2582 return failure();
2584 // Clone op.
2585 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2586 SmallVector<Value> replacements;
2587 replacements.reserve(newOp->getNumResults());
2588 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2589 Value newResult = std::get<1>(it);
2590 Value oldResult = std::get<0>(it);
2591 Type newType = newResult.getType();
2592 Type oldType = oldResult.getType();
2593 replacements.push_back(
2594 (newType != oldType)
2595 ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2596 : newResult);
2597 }
2598 rewriter.replaceOp(linalgOp, replacements);
2599 return success();
2600 }
2601 };
2603 } // namespace
2605 // All named ops canonicalizers and folders are auto-generated in the
2606 // .cpp.inc.
2608 //===----------------------------------------------------------------------===//
2609 // SoftmaxOp
2610 //===----------------------------------------------------------------------===//
2612 LogicalResult SoftmaxOp::verify() {
2613 ShapedType inputType = getInputOperandType();
2614 ShapedType outputType = getOutputOperandType();
2616 ArrayRef<int64_t> inputShape = inputType.getShape();
2617 ArrayRef<int64_t> outputShape = outputType.getShape();
2618 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2619 return emitOpError("incompatible output shape");
2621 int64_t inputRank = getInputOperandRank();
2622 int64_t dimension = getDimension();
2623 if ((dimension < 0) || (dimension >= inputRank))
2624 return emitOpError("incorrect dimension specified");
2626 return success();
2627 }
2629 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2630 int64_t operandRank = getInputOperandRank();
2631 SmallVector<Range> loopBounds(operandRank);
2632 Location loc = getLoc();
2633 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2634 Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2635 Value source = getInput();
2636 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2637 loopBounds[dim].offset = zero;
2638 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2639 loopBounds[dim].stride = one;
2640 }
2641 return loopBounds;
2642 }
2644 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2645 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2646 utils::IteratorType::parallel);
2647 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2648 return iteratorTypes;
2649 }
2651 FailureOr<TilingResult>
2652 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2653 ArrayRef<OpFoldResult> offsets,
2654 ArrayRef<OpFoldResult> sizes) {
2655 int64_t rank = getInputOperandRank();
2656 auto oneAttr = builder.getI64IntegerAttr(1);
2657 SmallVector<OpFoldResult> strides(rank, oneAttr);
2658 SmallVector<Value> tiledOperands;
2659 Operation *inputSlice =
2660 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2661 if (!inputSlice) {
2662 return emitOpError("failed to compute input slice");
2663 }
2664 tiledOperands.emplace_back(inputSlice->getResult(0));
2665 Operation *outputSlice =
2666 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2667 if (!outputSlice) {
2668 return emitOpError("failed to compute output slice");
2669 }
2670 tiledOperands.emplace_back(outputSlice->getResult(0));
2672 SmallVector<Type, 4> resultTypes;
2673 if (hasPureTensorSemantics())
2674 resultTypes.push_back(tiledOperands[1].getType());
2675 Operation *tiledOp =
2676 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2678 return TilingResult{
2679 {tiledOp},
2680 SmallVector<Value>(tiledOp->getResults()),
2681 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2682 }
2684 LogicalResult SoftmaxOp::getResultTilePosition(
2685 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2686 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2687 SmallVector<OpFoldResult> &resultSizes) {
2688 if (resultNumber == 0) {
2689 resultOffsets.assign(offsets.begin(), offsets.end());
2690 resultSizes.assign(sizes.begin(), sizes.end());
2691 return success();
2692 }
2693 return failure();
2694 }
2696 // cast(dynamic) -> static.
2697 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2698 return memref::foldMemRefCast(*this);
2699 }
2701 LogicalResult
2702 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2703 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2704 SmallVector<OpFoldResult> shapes;
2705 Location loc = getOperation()->getLoc();
2706 IRRewriter rewriter(b);
2707 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2708 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2709 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2710 if (!outputShapedType.isDynamicDim(dim)) {
2711 // Static dim: Return IntegerAttr.
2712 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2713 } else {
2714 // Dynamic dim: Return Value.
2715 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2716 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2717 }
2718 }
2719 reifiedReturnShapes.emplace_back(std::move(shapes));
2720 return success();
2721 }
2723 void SoftmaxOp::getEffects(
2724 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2725 &effects) {
2726 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2727 if (!llvm::isa<MemRefType>(operand.getType()))
2728 continue;
2729 effects.emplace_back(MemoryEffects::Read::get(),
2730 &getOperation()->getOpOperand(index), /*stage=*/0,
2731 /*effectOnFullRegion=*/true,
2732 SideEffects::DefaultResource::get());
2733 }
2735 for (OpOperand &operand : getDpsInitsMutable()) {
2736 if (!llvm::isa<MemRefType>(operand.get().getType()))
2737 continue;
2738 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2739 /*effectOnFullRegion=*/true,
2740 SideEffects::DefaultResource::get());
2741 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2742 /*effectOnFullRegion=*/true,
2743 SideEffects::DefaultResource::get());
2744 }
2745 }
2747 // Helper functions for softmax decomposition.
2748 // @{
2750 // Helper function to produce the iterator types (reduction or parallel) and
2751 // affine maps for the iterators used in the decomposition of softmax.
2752 // This method creates:
2753 // If allParallel == true:
2754 // - iterator type: {parallel, ..., parallel}
2755 // - affine maps:
2756 // -- identity with inputRank dimensions.
2757 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2758 // where N == inputRank - 1.
2759 //
2760 // If allParallel == false:
2761 // - iterator type at dim(i) == parallel for i != \p dim and
2762 // dim(dim) == reduction.
2763 // - affine map:
2764 // -- identity with inputRank dimensions.
2765 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2766 // where N == inputRank - 1.
2767 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2768 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2769 int64_t dim, bool allParallel = false) {
2770 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2771 utils::IteratorType::parallel);
2772 if (!allParallel)
2773 iteratorTypes[dim] = utils::IteratorType::reduction;
2774 MLIRContext *ctxt = builder.getContext();
2775 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2776 SmallVector<AffineExpr, 2> affineExprs;
2777 for (int i = 0; i < inputRank; i++) {
2778 if (i != dim)
2779 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2780 }
2781 auto reductionMap =
2782 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2783 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2784 return std::make_tuple(iteratorTypes, indexingMaps);
2785 }
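// For example, inputRank = 3 and dim = 1 with allParallel == false gives
// iterator types {parallel, reduction, parallel} and the maps
//   (d0, d1, d2) -> (d0, d1, d2) and (d0, d1, d2) -> (d0, d2).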
2787 // Helper function to produce a linalg.generic that computes a reduction on
2788 // dimension \p dim with the operation type \p T.
2789 template <typename T>
2790 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2791 int64_t dim) {
2792 auto inputType = cast<ShapedType>(input.getType());
2793 ArrayRef<int64_t> inputShape = inputType.getShape();
2794 int64_t inputRank = inputShape.size();
2795 auto [iteratorTypes, indexingMaps] =
2796 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2797 assert(indexingMaps.size() == 2 &&
2798 "We should have two maps: 1 for the input, 1 for the output");
2799 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2801 auto genericOp = builder.create<linalg::GenericOp>(
2802 loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2803 [&](OpBuilder &b, Location loc, ValueRange args) {
2804 Value result = b.create<T>(loc, args[0], args[1]);
2805 b.create<linalg::YieldOp>(loc, result);
2806 });
2807 return genericOp.getResult(0);
2808 }
2810 /// Produce a linalg generic that computes the second step of the softmax
2811 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2812 /// on dimension \p dim.
2813 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2814 Value max, Value output, int64_t dim) {
2815 auto inputType = cast<ShapedType>(input.getType());
2816 ArrayRef<int64_t> inputShape = inputType.getShape();
2817 int64_t inputRank = inputShape.size();
2818 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2819 builder, inputRank, dim, /*allParallel=*/true);
2820 assert(indexingMaps.size() == 2 && "We should have one map for each input");
2821 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2822 // Add the affine map for the output argument.
2823 indexingMaps.push_back(indexingMaps[0]);
2824 auto genericOp = builder.create<linalg::GenericOp>(
2825 loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2826 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2827 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2828 Value result = b.create<math::ExpOp>(loc, diff);
2829 b.create<linalg::YieldOp>(loc, result);
2830 });
2831 return genericOp.getResult(0);
2832 }
2834 /// Produce a linalg generic that computes the final step of the softmax
2835 /// decomposition.
2836 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2837 /// yield n / d
2838 /// }
2839 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2840 Value denominator, Value output, int64_t dim) {
2841 auto inputType = cast<ShapedType>(numerator.getType());
2842 ArrayRef<int64_t> inputShape = inputType.getShape();
2843 int64_t inputRank = inputShape.size();
2844 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2845 builder, inputRank, dim, /*allParallel=*/true);
2846 assert(indexingMaps.size() == 2 &&
2847 "We should have one map for each input (2)");
2848 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2849 // Add the affine map for the output tensor.
2850 indexingMaps.push_back(indexingMaps[0]);
2851 auto genericOp = builder.create<linalg::GenericOp>(
2852 loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2853 indexingMaps, iteratorTypes,
2854 [&](OpBuilder &b, Location loc, ValueRange args) {
2855 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2856 b.create<linalg::YieldOp>(loc, result);
2857 });
2858 return genericOp.getResult(0);
2859 }
2860 // @} End helper functions for softmax decomposition.
2862 /// Given an N-dimensional tensor x, this method converts
2863 /// softmax(x) to the following sequence of operations:
2864 ///
2865 /// 1. Compute the max of x along dimension d. This results
2866 /// in an N-1 dimensional tensor m.
2867 /// m = max(x, dim = d)
2868 ///
2869 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2870 /// an N-dimensional tensor z.
2871 /// z = exp(x - m)
2872 ///
2873 /// 3. Compute the sum of z along dimension d. This results in
2874 /// an N-1 dimensional tensor l.
2875 /// l = sum(z, dim = d)
2876 ///
2877 /// 4. Divide z and l. This gives the N-dimensional softmax.
2878 /// softmax = z / l
2879 ///
2880 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2881 OpBuilder::InsertionGuard guard(b);
2882 b.setInsertionPoint(*this);
2883 Location loc = getLoc();
2884 Value input = getInput();
2885 ShapedType inputType = getInputOperandType();
2886 Type elementType = inputType.getElementType();
2887 int64_t reductionDim = getDimension();
2888 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2889 Value output = getOutput();
2890 dims.erase(dims.begin() + reductionDim);
2891 // Step 1: Compute max along dim.
2892 Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2893 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2894 elementType, b, loc,
2895 /*useOnlyFiniteValue=*/true);
2896 Value neutralForMaxFInit =
2897 b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2898 .result();
2899 Value max =
2900 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2902 // Step 2: Subtract max from input and exponentiate.
2903 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2905 // Step 3: Compute sum along dim.
2906 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2907 b, loc, /*useOnlyFiniteValue=*/true);
2908 Value zeroInit =
2909 b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2910 Value denominator =
2911 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2913 // Step 4: Compute softmax.
2914 Value result =
2915 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2916 return SmallVector<Value>{result};
2917 }
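// Schematically (illustrative only, for tensor<2x16xf32> and dimension = 1):
//   %red     = tensor.empty : tensor<2xf32>
//   %maxInit = linalg.fill ins(-INF) outs(%red)
//   %max     = reduce<maxnumf>(%input, %maxInit) over dim 1
//   %num     = exp(%input - %max)          // step 2, all-parallel generic
//   %sumInit = linalg.fill ins(0.0) outs(%red)
//   %den     = reduce<addf>(%num, %sumInit) over dim 1
//   %softmax = %num / %den                 // step 4, all-parallel generic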
2919 //===----------------------------------------------------------------------===//
2920 // WinogradFilterTransformOp
2921 //===----------------------------------------------------------------------===//
2923 LogicalResult WinogradFilterTransformOp::verify() {
2924 auto filterType = cast<ShapedType>(getFilter().getType());
2925 ArrayRef<int64_t> filterShape = filterType.getShape();
2926 int64_t filterH = filterShape[getFilterHDim()];
2927 int64_t filterW = filterShape[getFilterWDim()];
2928 int64_t r = getR();
2929 int64_t m = getM();
2931 if (filterH != r && filterH != 1)
2932 return emitOpError("expected filter height to be either r or 1");
2933 if (filterW != r && filterW != 1)
2934 return emitOpError("expected filter width to be either r or 1");
2935 if (filterH == 1 && filterW == 1)
2936 return emitOpError("expected either filter height or width to be r");
2938 SmallVector<int64_t> expectedOutputShape;
2939 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2940 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2941 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2942 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2944 auto outputType = cast<ShapedType>(getOutput().getType());
2945 ArrayRef<int64_t> outputShape = outputType.getShape();
2946 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2947 return emitOpError("the output shape is not expected");
2948 }
2949 return success();
2950 }
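// For example (illustrative), with m = 4, r = 3 and a filter
// tensor<64x3x3x3xf32> laid out as (F, KH, KW, C), the expected output shape
// is (m + r - 1, m + r - 1, C, F) = tensor<6x6x3x64xf32>.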
2952 SmallVector<Range>
2953 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2954 Location loc = getLoc();
2955 IntegerAttr zeroAttr = builder.getIndexAttr(0);
2956 IntegerAttr oneAttr = builder.getIndexAttr(1);
2957 Value filter = getFilter();
2958 int64_t filterRank = getFilterOperandRank();
2959 SmallVector<Range> loopBounds(filterRank);
2960 for (unsigned dim = 0; dim < filterRank; ++dim) {
2961 loopBounds[dim].offset = zeroAttr;
2962 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2963 loopBounds[dim].stride = oneAttr;
2964 }
2965 return loopBounds;
2966 }
2968 SmallVector<utils::IteratorType>
2969 WinogradFilterTransformOp::getLoopIteratorTypes() {
2970 int64_t filterRank = getFilterOperandRank();
2971 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2972 utils::IteratorType::parallel);
2973 return iteratorTypes;
2974 }

LogicalResult WinogradFilterTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = filterH != 1 ? alpha : 1;
  int64_t alphaW = filterW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
  resultSizes.append(
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

  return success();
}

/// Implements tiling for winograd_filter_transform. The input of
/// winograd_filter_transform is (F, KH, KW, C). The output of
/// winograd_filter_transform is (alphaH, alphaW, C, F). Users can specify the
/// tile sizes of F and C. `offsets` are the values for the offsets of
/// F, KH, KW, C for one tile. `sizes` are the values for the sizes of
/// F, KH, KW, C for one tile.
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
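
// For instance (illustrative tile sizes only): tiling the m = 4, r = 3
// example above with C and F tile sizes of 1 extracts a tensor<1x3x3x1xf32>
// filter slice and a tensor<6x6x1x1xf32> output slice, then clones the op
// onto those slices.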

//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t tileSize = m + r - 1;
  bool leftTransform = inputH != 1;
  bool rightTransform = inputW != 1;

  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : 1;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : 1;
  }
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("unexpected output shape");
  }
  return success();
}
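
// Illustrative example (m = 4, r = 3, so alpha = 6 and the number of tiles
// per spatial dimension is (10 - (3 - 1)) / 4 = 2) for an N = 2, H = W = 10,
// C = 5 input:
//
//   %0 = linalg.winograd_input_transform m(4) r(3)
//          ins(%input : tensor<2x10x10x5xf32>)
//          outs(%init : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>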

SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  SmallVector<Range> loopBounds(outputRank);
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, C
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}

LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType inputType = getInputOperandType();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();
  int64_t alpha = m + r - 1;
  int64_t alphaH = inputH != 1 ? alpha : 1;
  int64_t alphaW = inputW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});

  return success();
}

/// Implements tiling for winograd_input_transform. The input of
/// winograd_input_transform is (N, H, W, C). The output of
/// winograd_input_transform is (alphaH, alphaW, tileH, tileW, N, C). Users can
/// specify the tile sizes of tileH, tileW, N, and C. `offsets` are the values
/// for the offsets of tileH, tileW, N, C for one tile. `sizes` are the values
/// for the sizes of tileH, tileW, N, C for one tile.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType inputType = getInputOperandType();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  int64_t m = getM();
  int64_t r = getR();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
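
  // For example (illustrative numbers only), with m = 4 and r = 3: a tile
  // offset of 1 maps to input offset 1 * 4 = 4, and a tile size of 2 maps to
  // an input extent of 2 * 4 + (3 - 1) = 10, covering the overlapping rows
  // or columns read by two adjacent 6x6 input tiles.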

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  OpFoldResult offsetH =
      inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  int64_t m = getM();
  int64_t r = getR();
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expected input height to equal the input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expected input width to equal the input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("unexpected output shape");
  }
  return success();
}
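
// Illustrative example (the inverse shape of the input-transform example:
// m = 4, r = 3, two 4x4 output tiles per spatial dimension, F = 2):
//
//   %0 = linalg.winograd_output_transform m(4) r(3)
//          ins(%value : tensor<6x6x2x2x2x2xf32>)
//          outs(%init : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>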

SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IntegerAttr zeroAttr = builder.getIndexAttr(0);
  IntegerAttr oneAttr = builder.getIndexAttr(1);
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  SmallVector<Range> loopBounds(valueRank);
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // alphaH, alphaW, tileH, tileW, N, F
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}

LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  int64_t m = getM();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  auto affineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);

  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileWDim()]);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileWDim()]);
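
  // Unlike the input transform, both offsets and sizes here scale by m alone
  // (no r - 1 overlap term), since each m x m output tile is disjoint: with
  // m = 4 (illustrative numbers only), a tile offset of 1 maps to output
  // row/column 4 and a tile size of 2 maps to an output extent of 8.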

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  OpFoldResult offsetH =
      valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  OpFoldResult sizeH =
      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);

  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
  resultSizes.append(
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
  return success();
}

/// Implements tiling for winograd_output_transform. The input of
/// winograd_output_transform is (alphaH, alphaW, tileH, tileW, N, F). The
/// output of winograd_output_transform is (N, H, W, F). Users can specify the
/// tile sizes of tileH, tileW, N, and F. `offsets` are the values for the
/// offsets of tileH, tileW, N, F for one tile. `sizes` are the values for the
/// sizes of tileH, tileW, N, F for one tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

/// Returns true if every result expression of \p explicitMap also appears
/// among the results of \p defaultMap, i.e. the explicit map only permutes
/// (or, for broadcasts, drops) the default result dimensions.
static bool isValidResultDimExprs(AffineMap explicitMap,
                                  AffineMap defaultMap) {
  auto explicitRange = explicitMap.getResults();
  auto defaultRange = defaultMap.getResults();
  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
}
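
// For example, given the default RHS map (d2, d1): an explicit map (d1, d2)
// is accepted (a transpose; same set of dim expressions), while (d0, d1) is
// rejected because d0 is not a result of the default map.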

/// Returns true if \p explicitMap is broadcast with respect to
/// \p defaultMap, i.e. it has fewer result expressions.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}

/// Verifies the broadcast and transpose semantics specified by the explicit
/// indexing map of \p matmulOp for the operand \p opIndex.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                  unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // Check general validity of indexing map results.
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return matmulOp->emitOpError()
           << "Unexpected dim expression in map result.";

  // Check if the requested broadcast is valid.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "Invalid broadcast requested, should be (d2).";
    }
    return success();
  }
  return success();
}
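
// For example (an illustrative sketch, assuming M = 5, N = 7, K = 3), a
// broadcast of the LHS from a 1-D vector over the rows of the result would
// be requested as:
//
//   linalg.matmul indexing_maps = [
//          affine_map<(d0, d1, d2) -> (d2)>,      // broadcast LHS
//          affine_map<(d0, d1, d2) -> (d2, d1)>,
//          affine_map<(d0, d1, d2) -> (d0, d1)>]
//       ins(%lhs, %rhs : tensor<3xf32>, tensor<3x7xf32>)
//       outs(%acc : tensor<5x7xf32>) -> tensor<5x7xf32>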

namespace mlir {
namespace linalg {

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//

/// Returns a list of AffineMaps with the typical matmul indexing
/// characteristics.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2);
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}

SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Checks whether the op has broadcast and/or transpose semantics. Returns
/// true if the user-defined indexing maps differ from the default maps.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
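
// The block built above is equivalent to the body of a linalg.generic
// computing acc + cast(a) * cast(b). For f32 operands, where the casts fold
// away, the generated region looks roughly like:
//
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32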

/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr exp = bcastMap.getResult(0);
  // The map is invalid unless its single result is the common (reduction)
  // dimension of the matmul.
  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();

    if (parser.parseLSquare())
      return failure();

    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}

void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);

  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
}

/// Verifies the user-defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of a pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

} // namespace linalg
} // namespace mlir