//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the folders and canonicalization patterns for SPIR-V ops.
//
//===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//
27 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
28 /// or splat vector bool constant.
29 static Optional
<bool> getScalarOrSplatBoolAttr(Attribute boolAttr
) {
33 auto type
= boolAttr
.getType();
34 if (type
.isInteger(1)) {
35 auto attr
= boolAttr
.cast
<BoolAttr
>();
36 return attr
.getValue();
38 if (auto vecType
= type
.cast
<VectorType
>()) {
39 if (vecType
.getElementType().isInteger(1))
40 if (auto attr
= boolAttr
.dyn_cast
<SplatElementsAttr
>())
41 return attr
.getSplatValue
<bool>();
46 // Extracts an element from the given `composite` by following the given
47 // `indices`. Returns a null Attribute if error happens.
48 static Attribute
extractCompositeElement(Attribute composite
,
49 ArrayRef
<unsigned> indices
) {
50 // Check that given composite is a constant.
53 // Return composite itself if we reach the end of the index chain.
57 if (auto vector
= composite
.dyn_cast
<ElementsAttr
>()) {
58 assert(indices
.size() == 1 && "must have exactly one index for a vector");
59 return vector
.getValue({indices
[0]});
62 if (auto array
= composite
.dyn_cast
<ArrayAttr
>()) {
63 assert(!indices
.empty() && "must have at least one index for an array");
64 return extractCompositeElement(array
.getValue()[indices
[0]],
65 indices
.drop_front());
71 //===----------------------------------------------------------------------===//
72 // TableGen'erated canonicalizers
73 //===----------------------------------------------------------------------===//
76 #include "SPIRVCanonicalization.inc"
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
85 /// Combines chained `spirv::AccessChainOp` operations into one
86 /// `spirv::AccessChainOp` operation.
87 struct CombineChainedAccessChain
88 : public OpRewritePattern
<spirv::AccessChainOp
> {
89 using OpRewritePattern
<spirv::AccessChainOp
>::OpRewritePattern
;
91 LogicalResult
matchAndRewrite(spirv::AccessChainOp accessChainOp
,
92 PatternRewriter
&rewriter
) const override
{
93 auto parentAccessChainOp
= dyn_cast_or_null
<spirv::AccessChainOp
>(
94 accessChainOp
.base_ptr().getDefiningOp());
96 if (!parentAccessChainOp
) {
101 SmallVector
<Value
, 4> indices(parentAccessChainOp
.indices());
102 indices
.append(accessChainOp
.indices().begin(),
103 accessChainOp
.indices().end());
105 rewriter
.replaceOpWithNewOp
<spirv::AccessChainOp
>(
106 accessChainOp
, parentAccessChainOp
.base_ptr(), indices
);
111 } // end anonymous namespace
113 void spirv::AccessChainOp::getCanonicalizationPatterns(
114 OwningRewritePatternList
&results
, MLIRContext
*context
) {
115 results
.insert
<CombineChainedAccessChain
>(context
);
//===----------------------------------------------------------------------===//
// spv.BitcastOp
//===----------------------------------------------------------------------===//
122 void spirv::BitcastOp::getCanonicalizationPatterns(
123 OwningRewritePatternList
&results
, MLIRContext
*context
) {
124 results
.insert
<ConvertChainedBitcast
>(context
);
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
131 OpFoldResult
spirv::CompositeExtractOp::fold(ArrayRef
<Attribute
> operands
) {
132 assert(operands
.size() == 1 && "spv.CompositeExtract expects one operand");
134 llvm::to_vector
<8>(llvm::map_range(indices(), [](Attribute attr
) {
135 return static_cast<unsigned>(attr
.cast
<IntegerAttr
>().getInt());
137 return extractCompositeElement(operands
[0], indexVector
);
//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
144 OpFoldResult
spirv::ConstantOp::fold(ArrayRef
<Attribute
> operands
) {
145 assert(operands
.empty() && "spv.constant has no operands");
//===----------------------------------------------------------------------===//
// spv.IAdd
//===----------------------------------------------------------------------===//
153 OpFoldResult
spirv::IAddOp::fold(ArrayRef
<Attribute
> operands
) {
154 assert(operands
.size() == 2 && "spv.IAdd expects two operands");
156 if (matchPattern(operand2(), m_Zero()))
159 // According to the SPIR-V spec:
161 // The resulting value will equal the low-order N bits of the correct result
162 // R, where N is the component width and R is computed with enough precision
163 // to avoid overflow and underflow.
164 return constFoldBinaryOp
<IntegerAttr
>(operands
,
165 [](APInt a
, APInt b
) { return a
+ b
; });
//===----------------------------------------------------------------------===//
// spv.IMul
//===----------------------------------------------------------------------===//
172 OpFoldResult
spirv::IMulOp::fold(ArrayRef
<Attribute
> operands
) {
173 assert(operands
.size() == 2 && "spv.IMul expects two operands");
175 if (matchPattern(operand2(), m_Zero()))
178 if (matchPattern(operand2(), m_One()))
181 // According to the SPIR-V spec:
183 // The resulting value will equal the low-order N bits of the correct result
184 // R, where N is the component width and R is computed with enough precision
185 // to avoid overflow and underflow.
186 return constFoldBinaryOp
<IntegerAttr
>(operands
,
187 [](APInt a
, APInt b
) { return a
* b
; });
//===----------------------------------------------------------------------===//
// spv.ISub
//===----------------------------------------------------------------------===//
194 OpFoldResult
spirv::ISubOp::fold(ArrayRef
<Attribute
> operands
) {
196 if (operand1() == operand2())
197 return Builder(getContext()).getIntegerAttr(getType(), 0);
199 // According to the SPIR-V spec:
201 // The resulting value will equal the low-order N bits of the correct result
202 // R, where N is the component width and R is computed with enough precision
203 // to avoid overflow and underflow.
204 return constFoldBinaryOp
<IntegerAttr
>(operands
,
205 [](APInt a
, APInt b
) { return a
- b
; });
//===----------------------------------------------------------------------===//
// spv.LogicalAnd
//===----------------------------------------------------------------------===//
212 OpFoldResult
spirv::LogicalAndOp::fold(ArrayRef
<Attribute
> operands
) {
213 assert(operands
.size() == 2 && "spv.LogicalAnd should take two operands");
215 if (Optional
<bool> rhs
= getScalarOrSplatBoolAttr(operands
.back())) {
220 // x && false = false
222 return operands
.back();
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
232 void spirv::LogicalNotOp::getCanonicalizationPatterns(
233 OwningRewritePatternList
&results
, MLIRContext
*context
) {
234 results
.insert
<ConvertLogicalNotOfIEqual
, ConvertLogicalNotOfINotEqual
,
235 ConvertLogicalNotOfLogicalEqual
,
236 ConvertLogicalNotOfLogicalNotEqual
>(context
);
//===----------------------------------------------------------------------===//
// spv.LogicalOr
//===----------------------------------------------------------------------===//
243 OpFoldResult
spirv::LogicalOrOp::fold(ArrayRef
<Attribute
> operands
) {
244 assert(operands
.size() == 2 && "spv.LogicalOr should take two operands");
246 if (auto rhs
= getScalarOrSplatBoolAttr(operands
.back())) {
249 return operands
.back();
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//
264 // Blocks from the given `spv.selection` operation must satisfy the following
267 // +-----------------------------------------------+
269 // | spv.BranchConditionalOp %cond, ^case0, ^case1 |
270 // +-----------------------------------------------+
275 // +------------------------+ +------------------------+
276 // | case #0 | | case #1 |
277 // | spv.Store %ptr %value0 | | spv.Store %ptr %value1 |
278 // | spv.Branch ^merge | | spv.Branch ^merge |
279 // +------------------------+ +------------------------+
289 struct ConvertSelectionOpToSelect
290 : public OpRewritePattern
<spirv::SelectionOp
> {
291 using OpRewritePattern
<spirv::SelectionOp
>::OpRewritePattern
;
293 LogicalResult
matchAndRewrite(spirv::SelectionOp selectionOp
,
294 PatternRewriter
&rewriter
) const override
{
295 auto *op
= selectionOp
.getOperation();
296 auto &body
= op
->getRegion(0);
297 // Verifier allows an empty region for `spv.selection`.
302 // Check that region consists of 4 blocks:
303 // header block, `true` block, `false` block and merge block.
304 if (std::distance(body
.begin(), body
.end()) != 4) {
308 auto *headerBlock
= selectionOp
.getHeaderBlock();
309 if (!onlyContainsBranchConditionalOp(headerBlock
)) {
313 auto brConditionalOp
=
314 cast
<spirv::BranchConditionalOp
>(headerBlock
->front());
316 auto *trueBlock
= brConditionalOp
.getSuccessor(0);
317 auto *falseBlock
= brConditionalOp
.getSuccessor(1);
318 auto *mergeBlock
= selectionOp
.getMergeBlock();
320 if (failed(canCanonicalizeSelection(trueBlock
, falseBlock
, mergeBlock
)))
323 auto trueValue
= getSrcValue(trueBlock
);
324 auto falseValue
= getSrcValue(falseBlock
);
325 auto ptrValue
= getDstPtr(trueBlock
);
326 auto storeOpAttributes
=
327 cast
<spirv::StoreOp
>(trueBlock
->front())->getAttrs();
329 auto selectOp
= rewriter
.create
<spirv::SelectOp
>(
330 selectionOp
.getLoc(), trueValue
.getType(), brConditionalOp
.condition(),
331 trueValue
, falseValue
);
332 rewriter
.create
<spirv::StoreOp
>(selectOp
.getLoc(), ptrValue
,
333 selectOp
.getResult(), storeOpAttributes
);
335 // `spv.selection` is not needed anymore.
336 rewriter
.eraseOp(op
);
341 // Checks that given blocks follow the following rules:
342 // 1. Each conditional block consists of two operations, the first operation
343 // is a `spv.Store` and the last operation is a `spv.Branch`.
344 // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
345 // 3. A control flow goes into the given merge block from the given
346 // conditional blocks.
347 LogicalResult
canCanonicalizeSelection(Block
*trueBlock
, Block
*falseBlock
,
348 Block
*mergeBlock
) const;
350 bool onlyContainsBranchConditionalOp(Block
*block
) const {
351 return std::next(block
->begin()) == block
->end() &&
352 isa
<spirv::BranchConditionalOp
>(block
->front());
355 bool isSameAttrList(spirv::StoreOp lhs
, spirv::StoreOp rhs
) const {
356 return lhs
->getAttrDictionary() == rhs
->getAttrDictionary();
359 // Returns a source value for the given block.
360 Value
getSrcValue(Block
*block
) const {
361 auto storeOp
= cast
<spirv::StoreOp
>(block
->front());
362 return storeOp
.value();
365 // Returns a destination value for the given block.
366 Value
getDstPtr(Block
*block
) const {
367 auto storeOp
= cast
<spirv::StoreOp
>(block
->front());
368 return storeOp
.ptr();
372 LogicalResult
ConvertSelectionOpToSelect::canCanonicalizeSelection(
373 Block
*trueBlock
, Block
*falseBlock
, Block
*mergeBlock
) const {
374 // Each block must consists of 2 operations.
375 if ((std::distance(trueBlock
->begin(), trueBlock
->end()) != 2) ||
376 (std::distance(falseBlock
->begin(), falseBlock
->end()) != 2)) {
380 auto trueBrStoreOp
= dyn_cast
<spirv::StoreOp
>(trueBlock
->front());
381 auto trueBrBranchOp
=
382 dyn_cast
<spirv::BranchOp
>(*std::next(trueBlock
->begin()));
383 auto falseBrStoreOp
= dyn_cast
<spirv::StoreOp
>(falseBlock
->front());
384 auto falseBrBranchOp
=
385 dyn_cast
<spirv::BranchOp
>(*std::next(falseBlock
->begin()));
387 if (!trueBrStoreOp
|| !trueBrBranchOp
|| !falseBrStoreOp
||
392 // Checks that given type is valid for `spv.SelectOp`.
393 // According to SPIR-V spec:
394 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
395 // Starting with version 1.4, Result Type can additionally be a composite type
396 // other than a vector."
397 bool isScalarOrVector
= trueBrStoreOp
.value()
399 .cast
<spirv::SPIRVType
>()
402 // Check that each `spv.Store` uses the same pointer, memory access
403 // attributes and a valid type of the value.
404 if ((trueBrStoreOp
.ptr() != falseBrStoreOp
.ptr()) ||
405 !isSameAttrList(trueBrStoreOp
, falseBrStoreOp
) || !isScalarOrVector
) {
409 if ((trueBrBranchOp
->getSuccessor(0) != mergeBlock
) ||
410 (falseBrBranchOp
->getSuccessor(0) != mergeBlock
)) {
416 } // end anonymous namespace
418 void spirv::SelectionOp::getCanonicalizationPatterns(
419 OwningRewritePatternList
&results
, MLIRContext
*context
) {
420 results
.insert
<ConvertSelectionOpToSelect
>(context
);