1 //===- TestLoopMapping.cpp --- Parametric loop mapping pass ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a pass to parametrically map scf.for loops to virtual
10 // processing element dimensions.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopUtils.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Pass/Pass.h"
21 using namespace mlir::affine
;
24 struct TestLoopMappingPass
25 : public PassWrapper
<TestLoopMappingPass
, OperationPass
<>> {
26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopMappingPass
)
28 StringRef
getArgument() const final
{
29 return "test-mapping-to-processing-elements";
31 StringRef
getDescription() const final
{
32 return "test mapping a single loop on a virtual processor grid";
34 explicit TestLoopMappingPass() = default;
36 void getDependentDialects(DialectRegistry
®istry
) const override
{
37 registry
.insert
<affine::AffineDialect
, scf::SCFDialect
>();
40 void runOnOperation() override
{
41 // SSA values for the transformation are created out of thin air by
42 // unregistered "new_processor_id_and_range" operations. This is enough to
43 // emulate mapping conditions.
44 SmallVector
<Value
, 8> processorIds
, numProcessors
;
45 getOperation()->walk([&processorIds
, &numProcessors
](Operation
*op
) {
46 if (op
->getName().getStringRef() != "new_processor_id_and_range")
48 processorIds
.push_back(op
->getResult(0));
49 numProcessors
.push_back(op
->getResult(1));
52 getOperation()->walk([&processorIds
, &numProcessors
](scf::ForOp op
) {
53 // Ignore nested loops.
54 if (op
->getParentRegion()->getParentOfType
<scf::ForOp
>())
56 mapLoopToProcessorIds(op
, processorIds
, numProcessors
);
64 void registerTestLoopMappingPass() { PassRegistration
<TestLoopMappingPass
>(); }