1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements patterns to convert SCF dialect to SPIR-V dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/Support/FormatVariadic.h"
24 //===----------------------------------------------------------------------===//
26 //===----------------------------------------------------------------------===//
29 struct ScfToSPIRVContextImpl
{
30 // Map between the spirv region control flow operation (spirv.mlir.loop or
31 // spirv.mlir.selection) to the VariableOp created to store the region
32 // results. The order of the VariableOp matches the order of the results.
33 DenseMap
<Operation
*, SmallVector
<spirv::VariableOp
, 8>> outputVars
;
37 /// We use ScfToSPIRVContext to store information about the lowering of the scf
38 /// region that need to be used later on. When we lower scf.for/scf.if we create
39 /// VariableOp to store the results. We need to keep track of the VariableOp
40 /// created as we need to insert stores into them when lowering Yield. Those
41 /// StoreOp cannot be created earlier as they may use a different type than
43 ScfToSPIRVContext::ScfToSPIRVContext() {
44 impl
= std::make_unique
<::ScfToSPIRVContextImpl
>();
47 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
51 //===----------------------------------------------------------------------===//
53 //===----------------------------------------------------------------------===//
55 /// Replaces SCF op outputs with SPIR-V variable loads.
56 /// We create VariableOp to handle the results value of the control flow region.
57 /// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
58 /// after the loop we load the value from the allocation and use it as the SCF
60 template <typename ScfOp
, typename OpTy
>
61 void replaceSCFOutputValue(ScfOp scfOp
, OpTy newOp
,
62 ConversionPatternRewriter
&rewriter
,
63 ScfToSPIRVContextImpl
*scfToSPIRVContext
,
64 ArrayRef
<Type
> returnTypes
) {
66 Location loc
= scfOp
.getLoc();
67 auto &allocas
= scfToSPIRVContext
->outputVars
[newOp
];
68 // Clearing the allocas is necessary in case a dialect conversion path failed
69 // previously, and this is the second attempt of this conversion.
71 SmallVector
<Value
, 8> resultValue
;
72 for (Type convertedType
: returnTypes
) {
74 spirv::PointerType::get(convertedType
, spirv::StorageClass::Function
);
75 rewriter
.setInsertionPoint(newOp
);
76 auto alloc
= rewriter
.create
<spirv::VariableOp
>(
77 loc
, pointerType
, spirv::StorageClass::Function
,
78 /*initializer=*/nullptr);
79 allocas
.push_back(alloc
);
80 rewriter
.setInsertionPointAfter(newOp
);
81 Value loadResult
= rewriter
.create
<spirv::LoadOp
>(loc
, alloc
);
82 resultValue
.push_back(loadResult
);
84 rewriter
.replaceOp(scfOp
, resultValue
);
87 Region::iterator
getBlockIt(Region
®ion
, unsigned index
) {
88 return std::next(region
.begin(), index
);
91 //===----------------------------------------------------------------------===//
92 // Conversion Patterns
93 //===----------------------------------------------------------------------===//
95 /// Common class for all vector to GPU patterns.
96 template <typename OpTy
>
97 class SCFToSPIRVPattern
: public OpConversionPattern
<OpTy
> {
99 SCFToSPIRVPattern(MLIRContext
*context
, const SPIRVTypeConverter
&converter
,
100 ScfToSPIRVContextImpl
*scfToSPIRVContext
)
101 : OpConversionPattern
<OpTy
>::OpConversionPattern(converter
, context
),
102 scfToSPIRVContext(scfToSPIRVContext
), typeConverter(converter
) {}
105 ScfToSPIRVContextImpl
*scfToSPIRVContext
;
106 // FIXME: We explicitly keep a reference of the type converter here instead of
107 // passing it to OpConversionPattern during construction. This effectively
108 // bypasses the conversion framework's automation on type conversion. This is
109 // needed right now because the conversion framework will unconditionally
110 // legalize all types used by SCF ops upon discovering them, for example, the
111 // types of loop carried values. We use SPIR-V variables for those loop
112 // carried values. Depending on the available capabilities, the SPIR-V
113 // variable can be different, for example, cooperative matrix or normal
114 // variable. We'd like to detach the conversion of the loop carried values
115 // from the SCF ops (which is mainly a region). So we need to "mark" types
116 // used by SCF ops as legal, if to use the conversion framework for type
117 // conversion. There isn't a straightforward way to do that yet, as when
118 // converting types, ops aren't taken into consideration. Therefore, we just
119 // bypass the framework's type conversion for now.
120 const SPIRVTypeConverter
&typeConverter
;
123 //===----------------------------------------------------------------------===//
125 //===----------------------------------------------------------------------===//
127 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
128 struct ForOpConversion final
: SCFToSPIRVPattern
<scf::ForOp
> {
129 using SCFToSPIRVPattern::SCFToSPIRVPattern
;
132 matchAndRewrite(scf::ForOp forOp
, OpAdaptor adaptor
,
133 ConversionPatternRewriter
&rewriter
) const override
{
134 // scf::ForOp can be lowered to the structured control flow represented by
135 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
136 // latch and the merge block the exit block. The resulting spirv::LoopOp has
137 // a single back edge from the continue to header block, and a single exit
138 // from header to merge.
139 auto loc
= forOp
.getLoc();
140 auto loopOp
= rewriter
.create
<spirv::LoopOp
>(loc
, spirv::LoopControl::None
);
141 loopOp
.addEntryAndMergeBlock(rewriter
);
143 OpBuilder::InsertionGuard
guard(rewriter
);
144 // Create the block for the header.
145 Block
*header
= rewriter
.createBlock(&loopOp
.getBody(),
146 getBlockIt(loopOp
.getBody(), 1));
147 rewriter
.setInsertionPointAfter(loopOp
);
149 // Create the new induction variable to use.
150 Value adapLowerBound
= adaptor
.getLowerBound();
151 BlockArgument newIndVar
=
152 header
->addArgument(adapLowerBound
.getType(), adapLowerBound
.getLoc());
153 for (Value arg
: adaptor
.getInitArgs())
154 header
->addArgument(arg
.getType(), arg
.getLoc());
155 Block
*body
= forOp
.getBody();
157 // Apply signature conversion to the body of the forOp. It has a single
158 // block, with argument which is the induction variable. That has to be
159 // replaced with the new induction variable.
160 TypeConverter::SignatureConversion
signatureConverter(
161 body
->getNumArguments());
162 signatureConverter
.remapInput(0, newIndVar
);
163 for (unsigned i
= 1, e
= body
->getNumArguments(); i
< e
; i
++)
164 signatureConverter
.remapInput(i
, header
->getArgument(i
));
165 body
= rewriter
.applySignatureConversion(&forOp
.getRegion().front(),
168 // Move the blocks from the forOp into the loopOp. This is the body of the
170 rewriter
.inlineRegionBefore(forOp
->getRegion(0), loopOp
.getBody(),
171 getBlockIt(loopOp
.getBody(), 2));
173 SmallVector
<Value
, 8> args(1, adaptor
.getLowerBound());
174 args
.append(adaptor
.getInitArgs().begin(), adaptor
.getInitArgs().end());
175 // Branch into it from the entry.
176 rewriter
.setInsertionPointToEnd(&(loopOp
.getBody().front()));
177 rewriter
.create
<spirv::BranchOp
>(loc
, header
, args
);
179 // Generate the rest of the loop header.
180 rewriter
.setInsertionPointToEnd(header
);
181 auto *mergeBlock
= loopOp
.getMergeBlock();
182 auto cmpOp
= rewriter
.create
<spirv::SLessThanOp
>(
183 loc
, rewriter
.getI1Type(), newIndVar
, adaptor
.getUpperBound());
185 rewriter
.create
<spirv::BranchConditionalOp
>(
186 loc
, cmpOp
, body
, ArrayRef
<Value
>(), mergeBlock
, ArrayRef
<Value
>());
188 // Generate instructions to increment the step of the induction variable and
189 // branch to the header.
190 Block
*continueBlock
= loopOp
.getContinueBlock();
191 rewriter
.setInsertionPointToEnd(continueBlock
);
193 // Add the step to the induction variable and branch to the header.
194 Value updatedIndVar
= rewriter
.create
<spirv::IAddOp
>(
195 loc
, newIndVar
.getType(), newIndVar
, adaptor
.getStep());
196 rewriter
.create
<spirv::BranchOp
>(loc
, header
, updatedIndVar
);
198 // Infer the return types from the init operands. Vector type may get
199 // converted to CooperativeMatrix or to Vector type, to avoid having complex
200 // extra logic to figure out the right type we just infer it from the Init
202 SmallVector
<Type
, 8> initTypes
;
203 for (auto arg
: adaptor
.getInitArgs())
204 initTypes
.push_back(arg
.getType());
205 replaceSCFOutputValue(forOp
, loopOp
, rewriter
, scfToSPIRVContext
,
211 //===----------------------------------------------------------------------===//
213 //===----------------------------------------------------------------------===//
215 /// Pattern to convert a scf::IfOp within kernel functions into
216 /// spirv::SelectionOp.
217 struct IfOpConversion
: SCFToSPIRVPattern
<scf::IfOp
> {
218 using SCFToSPIRVPattern::SCFToSPIRVPattern
;
221 matchAndRewrite(scf::IfOp ifOp
, OpAdaptor adaptor
,
222 ConversionPatternRewriter
&rewriter
) const override
{
223 // When lowering `scf::IfOp` we explicitly create a selection header block
224 // before the control flow diverges and a merge block where control flow
225 // subsequently converges.
226 auto loc
= ifOp
.getLoc();
228 // Create `spirv.selection` operation, selection header block and merge
231 rewriter
.create
<spirv::SelectionOp
>(loc
, spirv::SelectionControl::None
);
232 auto *mergeBlock
= rewriter
.createBlock(&selectionOp
.getBody(),
233 selectionOp
.getBody().end());
234 rewriter
.create
<spirv::MergeOp
>(loc
);
236 OpBuilder::InsertionGuard
guard(rewriter
);
237 auto *selectionHeaderBlock
=
238 rewriter
.createBlock(&selectionOp
.getBody().front());
240 // Inline `then` region before the merge block and branch to it.
241 auto &thenRegion
= ifOp
.getThenRegion();
242 auto *thenBlock
= &thenRegion
.front();
243 rewriter
.setInsertionPointToEnd(&thenRegion
.back());
244 rewriter
.create
<spirv::BranchOp
>(loc
, mergeBlock
);
245 rewriter
.inlineRegionBefore(thenRegion
, mergeBlock
);
247 auto *elseBlock
= mergeBlock
;
248 // If `else` region is not empty, inline that region before the merge block
250 if (!ifOp
.getElseRegion().empty()) {
251 auto &elseRegion
= ifOp
.getElseRegion();
252 elseBlock
= &elseRegion
.front();
253 rewriter
.setInsertionPointToEnd(&elseRegion
.back());
254 rewriter
.create
<spirv::BranchOp
>(loc
, mergeBlock
);
255 rewriter
.inlineRegionBefore(elseRegion
, mergeBlock
);
258 // Create a `spirv.BranchConditional` operation for selection header block.
259 rewriter
.setInsertionPointToEnd(selectionHeaderBlock
);
260 rewriter
.create
<spirv::BranchConditionalOp
>(loc
, adaptor
.getCondition(),
261 thenBlock
, ArrayRef
<Value
>(),
262 elseBlock
, ArrayRef
<Value
>());
264 SmallVector
<Type
, 8> returnTypes
;
265 for (auto result
: ifOp
.getResults()) {
266 auto convertedType
= typeConverter
.convertType(result
.getType());
268 return rewriter
.notifyMatchFailure(
270 llvm::formatv("failed to convert type '{0}'", result
.getType()));
272 returnTypes
.push_back(convertedType
);
274 replaceSCFOutputValue(ifOp
, selectionOp
, rewriter
, scfToSPIRVContext
,
280 //===----------------------------------------------------------------------===//
282 //===----------------------------------------------------------------------===//
284 struct TerminatorOpConversion final
: SCFToSPIRVPattern
<scf::YieldOp
> {
286 using SCFToSPIRVPattern::SCFToSPIRVPattern
;
289 matchAndRewrite(scf::YieldOp terminatorOp
, OpAdaptor adaptor
,
290 ConversionPatternRewriter
&rewriter
) const override
{
291 ValueRange operands
= adaptor
.getOperands();
293 Operation
*parent
= terminatorOp
->getParentOp();
295 // TODO: Implement conversion for the remaining `scf` ops.
296 if (parent
->getDialect()->getNamespace() ==
297 scf::SCFDialect::getDialectNamespace() &&
298 !isa
<scf::IfOp
, scf::ForOp
, scf::WhileOp
>(parent
))
299 return rewriter
.notifyMatchFailure(
301 llvm::formatv("conversion not supported for parent op: '{0}'",
304 // If the region return values, store each value into the associated
305 // VariableOp created during lowering of the parent region.
306 if (!operands
.empty()) {
307 auto &allocas
= scfToSPIRVContext
->outputVars
[parent
];
308 if (allocas
.size() != operands
.size())
311 auto loc
= terminatorOp
.getLoc();
312 for (unsigned i
= 0, e
= operands
.size(); i
< e
; i
++)
313 rewriter
.create
<spirv::StoreOp
>(loc
, allocas
[i
], operands
[i
]);
314 if (isa
<spirv::LoopOp
>(parent
)) {
315 // For loops we also need to update the branch jumping back to the
317 auto br
= cast
<spirv::BranchOp
>(
318 rewriter
.getInsertionBlock()->getTerminator());
319 SmallVector
<Value
, 8> args(br
.getBlockArguments());
320 args
.append(operands
.begin(), operands
.end());
321 rewriter
.setInsertionPoint(br
);
322 rewriter
.create
<spirv::BranchOp
>(terminatorOp
.getLoc(), br
.getTarget(),
324 rewriter
.eraseOp(br
);
327 rewriter
.eraseOp(terminatorOp
);
332 //===----------------------------------------------------------------------===//
334 //===----------------------------------------------------------------------===//
336 struct WhileOpConversion final
: SCFToSPIRVPattern
<scf::WhileOp
> {
337 using SCFToSPIRVPattern::SCFToSPIRVPattern
;
340 matchAndRewrite(scf::WhileOp whileOp
, OpAdaptor adaptor
,
341 ConversionPatternRewriter
&rewriter
) const override
{
342 auto loc
= whileOp
.getLoc();
343 auto loopOp
= rewriter
.create
<spirv::LoopOp
>(loc
, spirv::LoopControl::None
);
344 loopOp
.addEntryAndMergeBlock(rewriter
);
346 Region
&beforeRegion
= whileOp
.getBefore();
347 Region
&afterRegion
= whileOp
.getAfter();
349 if (failed(rewriter
.convertRegionTypes(&beforeRegion
, typeConverter
)) ||
350 failed(rewriter
.convertRegionTypes(&afterRegion
, typeConverter
)))
351 return rewriter
.notifyMatchFailure(whileOp
,
352 "Failed to convert region types");
354 OpBuilder::InsertionGuard
guard(rewriter
);
356 Block
&entryBlock
= *loopOp
.getEntryBlock();
357 Block
&beforeBlock
= beforeRegion
.front();
358 Block
&afterBlock
= afterRegion
.front();
359 Block
&mergeBlock
= *loopOp
.getMergeBlock();
361 auto cond
= cast
<scf::ConditionOp
>(beforeBlock
.getTerminator());
362 SmallVector
<Value
> condArgs
;
363 if (failed(rewriter
.getRemappedValues(cond
.getArgs(), condArgs
)))
366 Value conditionVal
= rewriter
.getRemappedValue(cond
.getCondition());
370 auto yield
= cast
<scf::YieldOp
>(afterBlock
.getTerminator());
371 SmallVector
<Value
> yieldArgs
;
372 if (failed(rewriter
.getRemappedValues(yield
.getResults(), yieldArgs
)))
375 // Move the while before block as the initial loop header block.
376 rewriter
.inlineRegionBefore(beforeRegion
, loopOp
.getBody(),
377 getBlockIt(loopOp
.getBody(), 1));
379 // Move the while after block as the initial loop body block.
380 rewriter
.inlineRegionBefore(afterRegion
, loopOp
.getBody(),
381 getBlockIt(loopOp
.getBody(), 2));
383 // Jump from the loop entry block to the loop header block.
384 rewriter
.setInsertionPointToEnd(&entryBlock
);
385 rewriter
.create
<spirv::BranchOp
>(loc
, &beforeBlock
, adaptor
.getInits());
387 auto condLoc
= cond
.getLoc();
389 SmallVector
<Value
> resultValues(condArgs
.size());
391 // For other SCF ops, the scf.yield op yields the value for the whole SCF
392 // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
393 // local variables. But for the scf.while op, the scf.yield op yields a
394 // value for the before region, which may not matching the whole op's
395 // result. Instead, the scf.condition op returns values matching the whole
396 // op's results. So we need to create/load/store variables according to
398 for (const auto &it
: llvm::enumerate(condArgs
)) {
399 auto res
= it
.value();
402 spirv::PointerType::get(res
.getType(), spirv::StorageClass::Function
);
404 // Create local variables before the scf.while op.
405 rewriter
.setInsertionPoint(loopOp
);
406 auto alloc
= rewriter
.create
<spirv::VariableOp
>(
407 condLoc
, pointerType
, spirv::StorageClass::Function
,
408 /*initializer=*/nullptr);
410 // Load the final result values after the scf.while op.
411 rewriter
.setInsertionPointAfter(loopOp
);
412 auto loadResult
= rewriter
.create
<spirv::LoadOp
>(condLoc
, alloc
);
413 resultValues
[i
] = loadResult
;
415 // Store the current iteration's result value.
416 rewriter
.setInsertionPointToEnd(&beforeBlock
);
417 rewriter
.create
<spirv::StoreOp
>(condLoc
, alloc
, res
);
420 rewriter
.setInsertionPointToEnd(&beforeBlock
);
421 rewriter
.replaceOpWithNewOp
<spirv::BranchConditionalOp
>(
422 cond
, conditionVal
, &afterBlock
, condArgs
, &mergeBlock
, std::nullopt
);
424 // Convert the scf.yield op to a branch back to the header block.
425 rewriter
.setInsertionPointToEnd(&afterBlock
);
426 rewriter
.replaceOpWithNewOp
<spirv::BranchOp
>(yield
, &beforeBlock
,
429 rewriter
.replaceOp(whileOp
, resultValues
);
435 //===----------------------------------------------------------------------===//
437 //===----------------------------------------------------------------------===//
439 void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter
&typeConverter
,
440 ScfToSPIRVContext
&scfToSPIRVContext
,
441 RewritePatternSet
&patterns
) {
442 patterns
.add
<ForOpConversion
, IfOpConversion
, TerminatorOpConversion
,
443 WhileOpConversion
>(patterns
.getContext(), typeConverter
,
444 scfToSPIRVContext
.getImpl());