1 //===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a pass to rewrite sequential chains of
10 // `spirv::CompositeInsertOp` operations into `spirv::CompositeConstructOp` operations.
13 //===----------------------------------------------------------------------===//
15 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
23 #define GEN_PASS_DEF_SPIRVREWRITEINSERTSPASS
24 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32 /// Rewrites sequential chains of `spirv::CompositeInsertOp` operations into
33 /// a single `spirv::CompositeConstructOp` operation, where possible.
34 class RewriteInsertsPass
35 : public spirv::impl::SPIRVRewriteInsertsPassBase
<RewriteInsertsPass
> {
37 void runOnOperation() override
;
40 /// Collects the sequential insertion chain ending at the given
41 /// `spirv::CompositeInsertOp` op, which must insert the composite's last element.
44 collectInsertionChain(spirv::CompositeInsertOp op
,
45 SmallVectorImpl
<spirv::CompositeInsertOp
> &insertions
);
50 void RewriteInsertsPass::runOnOperation() {
51 SmallVector
<SmallVector
<spirv::CompositeInsertOp
, 4>, 4> workList
;
52 getOperation().walk([this, &workList
](spirv::CompositeInsertOp op
) {
53 SmallVector
<spirv::CompositeInsertOp
, 4> insertions
;
54 if (succeeded(collectInsertionChain(op
, insertions
)))
55 workList
.push_back(insertions
);
58 for (const auto &insertions
: workList
) {
59 auto lastCompositeInsertOp
= insertions
.back();
60 auto compositeType
= lastCompositeInsertOp
.getType();
61 auto location
= lastCompositeInsertOp
.getLoc();
63 SmallVector
<Value
, 4> operands
;
64 // Collect inserted objects.
65 for (auto insertionOp
: insertions
)
66 operands
.push_back(insertionOp
.getObject());
68 OpBuilder
builder(lastCompositeInsertOp
);
69 auto compositeConstructOp
= builder
.create
<spirv::CompositeConstructOp
>(
70 location
, compositeType
, operands
);
72 lastCompositeInsertOp
.replaceAllUsesWith(
73 compositeConstructOp
->getResult(0));
76 for (auto insertOp
: llvm::reverse(insertions
)) {
77 auto *op
= insertOp
.getOperation();
84 LogicalResult
RewriteInsertsPass::collectInsertionChain(
85 spirv::CompositeInsertOp op
,
86 SmallVectorImpl
<spirv::CompositeInsertOp
> &insertions
) {
87 auto indicesArrayAttr
= cast
<ArrayAttr
>(op
.getIndices());
88 // TODO: handle nested composite object.
89 if (indicesArrayAttr
.size() == 1) {
90 auto numElements
= cast
<spirv::CompositeType
>(op
.getComposite().getType())
93 auto index
= cast
<IntegerAttr
>(indicesArrayAttr
[0]).getInt();
94 // Need a last index to collect a sequential chain.
95 if (index
+ 1 != numElements
)
98 insertions
.resize(numElements
);
100 insertions
[index
] = op
;
105 op
= op
.getComposite().getDefiningOp
<spirv::CompositeInsertOp
>();
110 indicesArrayAttr
= cast
<ArrayAttr
>(op
.getIndices());
111 if ((indicesArrayAttr
.size() != 1) ||
112 (cast
<IntegerAttr
>(indicesArrayAttr
[0]).getInt() != index
))