[mlir][spirv] NFC: Shuffle code around to better follow convention
[llvm-project.git] / mlir / lib / Dialect / SPIRV / IR / SPIRVCanonicalization.cpp
blob7f268ca929646a5bd5042f2b969b880dfa29a341
1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
21 using namespace mlir;
23 //===----------------------------------------------------------------------===//
24 // Common utility functions
25 //===----------------------------------------------------------------------===//
27 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
28 /// or splat vector bool constant.
29 static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
30 if (!boolAttr)
31 return llvm::None;
33 auto type = boolAttr.getType();
34 if (type.isInteger(1)) {
35 auto attr = boolAttr.cast<BoolAttr>();
36 return attr.getValue();
38 if (auto vecType = type.cast<VectorType>()) {
39 if (vecType.getElementType().isInteger(1))
40 if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
41 return attr.getSplatValue<bool>();
43 return llvm::None;
46 // Extracts an element from the given `composite` by following the given
47 // `indices`. Returns a null Attribute if error happens.
48 static Attribute extractCompositeElement(Attribute composite,
49 ArrayRef<unsigned> indices) {
50 // Check that given composite is a constant.
51 if (!composite)
52 return {};
53 // Return composite itself if we reach the end of the index chain.
54 if (indices.empty())
55 return composite;
57 if (auto vector = composite.dyn_cast<ElementsAttr>()) {
58 assert(indices.size() == 1 && "must have exactly one index for a vector");
59 return vector.getValue({indices[0]});
62 if (auto array = composite.dyn_cast<ArrayAttr>()) {
63 assert(!indices.empty() && "must have at least one index for an array");
64 return extractCompositeElement(array.getValue()[indices[0]],
65 indices.drop_front());
68 return {};
71 //===----------------------------------------------------------------------===//
72 // TableGen'erated canonicalizers
73 //===----------------------------------------------------------------------===//
75 namespace {
76 #include "SPIRVCanonicalization.inc"
79 //===----------------------------------------------------------------------===//
80 // spv.AccessChainOp
81 //===----------------------------------------------------------------------===//
83 namespace {
85 /// Combines chained `spirv::AccessChainOp` operations into one
86 /// `spirv::AccessChainOp` operation.
87 struct CombineChainedAccessChain
88 : public OpRewritePattern<spirv::AccessChainOp> {
89 using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
91 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
92 PatternRewriter &rewriter) const override {
93 auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
94 accessChainOp.base_ptr().getDefiningOp());
96 if (!parentAccessChainOp) {
97 return failure();
100 // Combine indices.
101 SmallVector<Value, 4> indices(parentAccessChainOp.indices());
102 indices.append(accessChainOp.indices().begin(),
103 accessChainOp.indices().end());
105 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
106 accessChainOp, parentAccessChainOp.base_ptr(), indices);
108 return success();
111 } // end anonymous namespace
113 void spirv::AccessChainOp::getCanonicalizationPatterns(
114 OwningRewritePatternList &results, MLIRContext *context) {
115 results.insert<CombineChainedAccessChain>(context);
118 //===----------------------------------------------------------------------===//
119 // spv.BitcastOp
120 //===----------------------------------------------------------------------===//
/// Registers the canonicalization patterns for spv.Bitcast.
/// NOTE(review): `ConvertChainedBitcast` is not defined in this file; it
/// presumably comes from the TableGen'erated "SPIRVCanonicalization.inc" —
/// confirm.
void spirv::BitcastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertChainedBitcast>(context);
}
127 //===----------------------------------------------------------------------===//
128 // spv.CompositeExtractOp
129 //===----------------------------------------------------------------------===//
131 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
132 assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
133 auto indexVector =
134 llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
135 return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
136 }));
137 return extractCompositeElement(operands[0], indexVector);
140 //===----------------------------------------------------------------------===//
141 // spv.constant
142 //===----------------------------------------------------------------------===//
/// Folds spv.constant to its `value` attribute. spv.constant takes no
/// operands, so `operands` is expected to be empty.
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "spv.constant has no operands");
  return value();
}
149 //===----------------------------------------------------------------------===//
150 // spv.IAdd
151 //===----------------------------------------------------------------------===//
153 OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
154 assert(operands.size() == 2 && "spv.IAdd expects two operands");
155 // x + 0 = x
156 if (matchPattern(operand2(), m_Zero()))
157 return operand1();
159 // According to the SPIR-V spec:
161 // The resulting value will equal the low-order N bits of the correct result
162 // R, where N is the component width and R is computed with enough precision
163 // to avoid overflow and underflow.
164 return constFoldBinaryOp<IntegerAttr>(operands,
165 [](APInt a, APInt b) { return a + b; });
168 //===----------------------------------------------------------------------===//
169 // spv.IMul
170 //===----------------------------------------------------------------------===//
172 OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
173 assert(operands.size() == 2 && "spv.IMul expects two operands");
174 // x * 0 == 0
175 if (matchPattern(operand2(), m_Zero()))
176 return operand2();
177 // x * 1 = x
178 if (matchPattern(operand2(), m_One()))
179 return operand1();
181 // According to the SPIR-V spec:
183 // The resulting value will equal the low-order N bits of the correct result
184 // R, where N is the component width and R is computed with enough precision
185 // to avoid overflow and underflow.
186 return constFoldBinaryOp<IntegerAttr>(operands,
187 [](APInt a, APInt b) { return a * b; });
190 //===----------------------------------------------------------------------===//
191 // spv.ISub
192 //===----------------------------------------------------------------------===//
194 OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
195 // x - x = 0
196 if (operand1() == operand2())
197 return Builder(getContext()).getIntegerAttr(getType(), 0);
199 // According to the SPIR-V spec:
201 // The resulting value will equal the low-order N bits of the correct result
202 // R, where N is the component width and R is computed with enough precision
203 // to avoid overflow and underflow.
204 return constFoldBinaryOp<IntegerAttr>(operands,
205 [](APInt a, APInt b) { return a - b; });
208 //===----------------------------------------------------------------------===//
209 // spv.LogicalAnd
210 //===----------------------------------------------------------------------===//
212 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
213 assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
215 if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
216 // x && true = x
217 if (rhs.getValue())
218 return operand1();
220 // x && false = false
221 if (!rhs.getValue())
222 return operands.back();
225 return Attribute();
228 //===----------------------------------------------------------------------===//
229 // spv.LogicalNot
230 //===----------------------------------------------------------------------===//
/// Registers the canonicalization patterns for spv.LogicalNot.
/// NOTE(review): the four patterns are not defined in this file; they
/// presumably come from the TableGen'erated "SPIRVCanonicalization.inc".
/// Judging only by their names, they rewrite a logical-not applied to the
/// result of spv.IEqual / spv.INotEqual / spv.LogicalEqual /
/// spv.LogicalNotEqual — confirm against the .td source.
void spirv::LogicalNotOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
                 ConvertLogicalNotOfLogicalEqual,
                 ConvertLogicalNotOfLogicalNotEqual>(context);
}
239 //===----------------------------------------------------------------------===//
240 // spv.LogicalOr
241 //===----------------------------------------------------------------------===//
243 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
244 assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
246 if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
247 if (rhs.getValue())
248 // x || true = true
249 return operands.back();
251 // x || false = x
252 if (!rhs.getValue())
253 return operand1();
256 return Attribute();
259 //===----------------------------------------------------------------------===//
260 // spv.selection
261 //===----------------------------------------------------------------------===//
263 namespace {
264 // Blocks from the given `spv.selection` operation must satisfy the following
265 // layout:
267 // +-----------------------------------------------+
268 // | header block |
269 // | spv.BranchConditionalOp %cond, ^case0, ^case1 |
270 // +-----------------------------------------------+
271 // / \
272 // ...
275 // +------------------------+ +------------------------+
276 // | case #0 | | case #1 |
277 // | spv.Store %ptr %value0 | | spv.Store %ptr %value1 |
278 // | spv.Branch ^merge | | spv.Branch ^merge |
279 // +------------------------+ +------------------------+
282 // ...
283 // \ /
284 // v
285 // +-------------+
286 // | merge block |
287 // +-------------+
289 struct ConvertSelectionOpToSelect
290 : public OpRewritePattern<spirv::SelectionOp> {
291 using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
293 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
294 PatternRewriter &rewriter) const override {
295 auto *op = selectionOp.getOperation();
296 auto &body = op->getRegion(0);
297 // Verifier allows an empty region for `spv.selection`.
298 if (body.empty()) {
299 return failure();
302 // Check that region consists of 4 blocks:
303 // header block, `true` block, `false` block and merge block.
304 if (std::distance(body.begin(), body.end()) != 4) {
305 return failure();
308 auto *headerBlock = selectionOp.getHeaderBlock();
309 if (!onlyContainsBranchConditionalOp(headerBlock)) {
310 return failure();
313 auto brConditionalOp =
314 cast<spirv::BranchConditionalOp>(headerBlock->front());
316 auto *trueBlock = brConditionalOp.getSuccessor(0);
317 auto *falseBlock = brConditionalOp.getSuccessor(1);
318 auto *mergeBlock = selectionOp.getMergeBlock();
320 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
321 return failure();
323 auto trueValue = getSrcValue(trueBlock);
324 auto falseValue = getSrcValue(falseBlock);
325 auto ptrValue = getDstPtr(trueBlock);
326 auto storeOpAttributes =
327 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
329 auto selectOp = rewriter.create<spirv::SelectOp>(
330 selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
331 trueValue, falseValue);
332 rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
333 selectOp.getResult(), storeOpAttributes);
335 // `spv.selection` is not needed anymore.
336 rewriter.eraseOp(op);
337 return success();
340 private:
341 // Checks that given blocks follow the following rules:
342 // 1. Each conditional block consists of two operations, the first operation
343 // is a `spv.Store` and the last operation is a `spv.Branch`.
344 // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
345 // 3. A control flow goes into the given merge block from the given
346 // conditional blocks.
347 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
348 Block *mergeBlock) const;
350 bool onlyContainsBranchConditionalOp(Block *block) const {
351 return std::next(block->begin()) == block->end() &&
352 isa<spirv::BranchConditionalOp>(block->front());
355 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
356 return lhs->getAttrDictionary() == rhs->getAttrDictionary();
359 // Returns a source value for the given block.
360 Value getSrcValue(Block *block) const {
361 auto storeOp = cast<spirv::StoreOp>(block->front());
362 return storeOp.value();
365 // Returns a destination value for the given block.
366 Value getDstPtr(Block *block) const {
367 auto storeOp = cast<spirv::StoreOp>(block->front());
368 return storeOp.ptr();
372 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
373 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
374 // Each block must consists of 2 operations.
375 if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
376 (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
377 return failure();
380 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
381 auto trueBrBranchOp =
382 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
383 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
384 auto falseBrBranchOp =
385 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
387 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
388 !falseBrBranchOp) {
389 return failure();
392 // Checks that given type is valid for `spv.SelectOp`.
393 // According to SPIR-V spec:
394 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
395 // Starting with version 1.4, Result Type can additionally be a composite type
396 // other than a vector."
397 bool isScalarOrVector = trueBrStoreOp.value()
398 .getType()
399 .cast<spirv::SPIRVType>()
400 .isScalarOrVector();
402 // Check that each `spv.Store` uses the same pointer, memory access
403 // attributes and a valid type of the value.
404 if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
405 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
406 return failure();
409 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
410 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
411 return failure();
414 return success();
416 } // end anonymous namespace
418 void spirv::SelectionOp::getCanonicalizationPatterns(
419 OwningRewritePatternList &results, MLIRContext *context) {
420 results.insert<ConvertSelectionOpToSelect>(context);