1 //===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a pass to rewrite sequential chains of
10 // `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
// operations.
13 //===----------------------------------------------------------------------===//
15 #include "PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
// NOTE(review): this listing appears to be missing lines of the original
// declaration. There is no `namespace {` matching the closing brace at the
// end, and no `public:` / `private:` access specifiers. The
// `collectInsertionChain` declaration has no return type, although its
// out-of-line definition returns `LogicalResult`. There is also no `};`
// closing the class before the namespace closer. Restore these from the
// upstream source before building.
25 /// Replaces sequential chains of `spirv::CompositeInsertOp` operation into
26 /// `spirv::CompositeConstructOp` operation if possible.
///
/// The base class is `SPIRVRewriteInsertsPassBase` instantiated with this
/// class itself. It is presumably the generated pass base pulled in through
/// "PassDetail.h" (confirm).
27 class RewriteInsertsPass
28 : public SPIRVRewriteInsertsPassBase
<RewriteInsertsPass
> {
// Pass entry point; overrides the base-class hook and is defined out of line.
30 void runOnOperation() override
;
33 /// Collects a sequential insertion chain by the given
34 /// `spirv::CompositeInsertOp` operation, if the given operation is the last
/// in the chain. Each collected insert is stored in `insertions` at the
/// position named by its single index. The caller tests the result with
/// `succeeded()`.
37 collectInsertionChain(spirv::CompositeInsertOp op
,
38 SmallVectorImpl
<spirv::CompositeInsertOp
> &insertions
);
41 } // anonymous namespace
// Rewrites insertion chains in the operation this pass runs on. The factory
// returns an OperationPass<spirv::ModuleOp>, so this is a spirv::ModuleOp.
// Phase 1 walks the operation and records every chain for which
// collectInsertionChain succeeds. Phase 2 replaces each recorded chain with a
// single spirv::CompositeConstructOp. Rewriting is deferred until the walk has
// finished, presumably so the IR is not mutated while it is being traversed.
// NOTE(review): this listing is missing lines:
//   - the `});` that closes the walk lambda before the `for` loop;
//   - the body of the reverse loop over `insertions` after `getOperation()`;
//   - the closing braces of both loops and of the function.
// Presumably the reverse loop erases inserts that no longer have uses.
// Confirm against upstream.
43 void RewriteInsertsPass::runOnOperation() {
// One entry per complete chain; each entry holds that chain's inserts.
44 SmallVector
<SmallVector
<spirv::CompositeInsertOp
, 4>, 4> workList
;
// Phase 1: try to collect a chain ending at every CompositeInsertOp visited.
45 getOperation().walk([this, &workList
](spirv::CompositeInsertOp op
) {
46 SmallVector
<spirv::CompositeInsertOp
, 4> insertions
;
47 if (succeeded(collectInsertionChain(op
, insertions
)))
48 workList
.push_back(insertions
);
// Phase 2: rewrite each recorded chain.
51 for (const auto &insertions
: workList
) {
// The last insert of the chain produces the value being replaced. The new
// construct op reuses that insert's result type and location.
52 auto lastCompositeInsertOp
= insertions
.back();
53 auto compositeType
= lastCompositeInsertOp
.getType();
54 auto location
= lastCompositeInsertOp
.getLoc();
// One operand per chain element, taken in the order stored in `insertions`.
56 SmallVector
<Value
, 4> operands
;
57 // Collect inserted objects.
58 for (auto insertionOp
: insertions
)
59 operands
.push_back(insertionOp
.object());
// Build the replacement with a builder anchored at the last insert.
61 OpBuilder
builder(lastCompositeInsertOp
);
62 auto compositeConstructOp
= builder
.create
<spirv::CompositeConstructOp
>(
63 location
, compositeType
, operands
);
// Redirect every user of the chain's final value to the new construct.
65 lastCompositeInsertOp
.replaceAllUsesWith(
66 compositeConstructOp
->getResult(0));
// Visit the chain back-to-front (last insert first). The remainder of this
// loop body is not present in this listing (see NOTE above).
69 for (auto insertOp
: llvm::reverse(insertions
)) {
70 auto *op
= insertOp
.getOperation();
// Tries to collect the chain of single-index CompositeInsertOps that ends at
// `op`. Only a chain whose last insert writes the final element is accepted
// (index + 1 == numElements). The chain is then followed backwards through
// each insert's `composite` operand. Every earlier link must also carry
// exactly one index, equal to the expected `index`. Collected ops are stored
// at the position named by their index, so `insertions` is first resized to
// the composite's element count.
// NOTE(review): this listing is missing lines:
//   - the declaration of `numElements` (the expression on the "84" line has
//     no left-hand side);
//   - the statement guarded by `if (index + 1 != numElements)`;
//   - the loop and termination logic around the `insertions[index] = op`
//     store;
//   - any null check on the getDefiningOp result;
//   - the update of `index` between links;
//   - the statement guarded by the final `if`;
//   - the closing braces.
// Restore from upstream.
77 LogicalResult
RewriteInsertsPass::collectInsertionChain(
78 spirv::CompositeInsertOp op
,
79 SmallVectorImpl
<spirv::CompositeInsertOp
> &insertions
) {
80 auto indicesArrayAttr
= op
.indices().cast
<ArrayAttr
>();
81 // TODO: handle nested composite object.
// Only inserts that address a single top-level element are considered.
82 if (indicesArrayAttr
.size() == 1) {
// Element count of the composite being inserted into.
// NOTE(review): confirm getNumElements() is valid for every CompositeType
// that can reach here (e.g. types without a compile-time element count).
84 op
.composite().getType().cast
<spirv::CompositeType
>().getNumElements();
// Position of the element written by this insert.
86 auto index
= indicesArrayAttr
[0].cast
<IntegerAttr
>().getInt();
87 // Need a last index to collect a sequential chain.
88 if (index
+ 1 != numElements
)
// One slot per composite element, filled as the chain is walked.
91 insertions
.resize(numElements
);
93 insertions
[index
] = op
;
// Step to the op that produced this insert's `composite` operand.
// getDefiningOp<...>() yields a null op when that producer is not a
// CompositeInsertOp. The corresponding check is not visible here; confirm.
98 op
= op
.composite().getDefiningOp
<spirv::CompositeInsertOp
>();
// The earlier link must also be a single-index insert at the expected index.
103 indicesArrayAttr
= op
.indices().cast
<ArrayAttr
>();
104 if ((indicesArrayAttr
.size() != 1) ||
105 (indicesArrayAttr
[0].cast
<IntegerAttr
>().getInt() != index
))
// Public factory for this pass. Returns a newly allocated RewriteInsertsPass,
// owned by the caller through std::unique_ptr and typed as a pass over
// spirv::ModuleOp.
// NOTE(review): the closing `}` of this function is not visible in this
// listing.
112 std::unique_ptr
<mlir::OperationPass
<spirv::ModuleOp
>>
113 mlir::spirv::createRewriteInsertsPass() {
114 return std::make_unique
<RewriteInsertsPass
>();