[mlir][spirv] NFC: Shuffle code around to better follow convention
[llvm-project.git] / mlir / lib / Dialect / SPIRV / Transforms / RewriteInsertsPass.cpp
blobb89312fb8ae79da4256512866c4507d578e48766
1 //===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to rewrite sequential chains of
10 // `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
11 // operations.
13 //===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

#include <algorithm>
#include <utility>

using namespace mlir;
23 namespace {
25 /// Replaces sequential chains of `spirv::CompositeInsertOp` operation into
26 /// `spirv::CompositeConstructOp` operation if possible.
27 class RewriteInsertsPass
28 : public SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
29 public:
30 void runOnOperation() override;
32 private:
33 /// Collects a sequential insertion chain by the given
34 /// `spirv::CompositeInsertOp` operation, if the given operation is the last
35 /// in the chain.
36 LogicalResult
37 collectInsertionChain(spirv::CompositeInsertOp op,
38 SmallVectorImpl<spirv::CompositeInsertOp> &insertions);
41 } // anonymous namespace
43 void RewriteInsertsPass::runOnOperation() {
44 SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList;
45 getOperation().walk([this, &workList](spirv::CompositeInsertOp op) {
46 SmallVector<spirv::CompositeInsertOp, 4> insertions;
47 if (succeeded(collectInsertionChain(op, insertions)))
48 workList.push_back(insertions);
49 });
51 for (const auto &insertions : workList) {
52 auto lastCompositeInsertOp = insertions.back();
53 auto compositeType = lastCompositeInsertOp.getType();
54 auto location = lastCompositeInsertOp.getLoc();
56 SmallVector<Value, 4> operands;
57 // Collect inserted objects.
58 for (auto insertionOp : insertions)
59 operands.push_back(insertionOp.object());
61 OpBuilder builder(lastCompositeInsertOp);
62 auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
63 location, compositeType, operands);
65 lastCompositeInsertOp.replaceAllUsesWith(
66 compositeConstructOp->getResult(0));
68 // Erase ops.
69 for (auto insertOp : llvm::reverse(insertions)) {
70 auto *op = insertOp.getOperation();
71 if (op->use_empty())
72 insertOp.erase();
77 LogicalResult RewriteInsertsPass::collectInsertionChain(
78 spirv::CompositeInsertOp op,
79 SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
80 auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
81 // TODO: handle nested composite object.
82 if (indicesArrayAttr.size() == 1) {
83 auto numElements =
84 op.composite().getType().cast<spirv::CompositeType>().getNumElements();
86 auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
87 // Need a last index to collect a sequential chain.
88 if (index + 1 != numElements)
89 return failure();
91 insertions.resize(numElements);
92 while (true) {
93 insertions[index] = op;
95 if (index == 0)
96 return success();
98 op = op.composite().getDefiningOp<spirv::CompositeInsertOp>();
99 if (!op)
100 return failure();
102 --index;
103 indicesArrayAttr = op.indices().cast<ArrayAttr>();
104 if ((indicesArrayAttr.size() != 1) ||
105 (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
106 return failure();
109 return failure();
112 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
113 mlir::spirv::createRewriteInsertsPass() {
114 return std::make_unique<RewriteInsertsPass>();