//===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a test pass for the match reduction utility.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
21 void printReductionResult(Operation
*redRegionOp
, unsigned numOutput
,
23 ArrayRef
<Operation
*> combinerOps
) {
25 redRegionOp
->emitRemark("Reduction found in output #") << numOutput
<< "!";
26 redRegionOp
->emitRemark("Reduced Value: ") << reducedValue
;
27 for (Operation
*combOp
: combinerOps
)
28 redRegionOp
->emitRemark("Combiner Op: ") << *combOp
;
33 redRegionOp
->emitRemark("Reduction NOT found in output #")
37 struct TestMatchReductionPass
38 : public PassWrapper
<TestMatchReductionPass
,
39 InterfacePass
<FunctionOpInterface
>> {
40 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchReductionPass
)
42 StringRef
getArgument() const final
{ return "test-match-reduction"; }
43 StringRef
getDescription() const final
{
44 return "Test the match reduction utility.";
47 void runOnOperation() override
{
48 FunctionOpInterface func
= getOperation();
49 func
->emitRemark("Testing function");
51 func
.walk
<WalkOrder::PreOrder
>([](Operation
*op
) {
52 if (isa
<FunctionOpInterface
>(op
))
55 // Limit testing to ops with only one region.
56 if (op
->getNumRegions() != 1)
59 Region
®ion
= op
->getRegion(0);
60 if (!region
.hasOneBlock())
63 // We expect all the tested region ops to have 1 input by default. The
64 // remaining arguments are assumed to be outputs/reductions and there must
66 // TODO: Extend it to support more generic cases.
67 Block
®ionEntry
= region
.front();
68 auto args
= regionEntry
.getArguments();
72 auto outputs
= args
.drop_front();
73 for (int i
= 0, size
= outputs
.size(); i
< size
; ++i
) {
74 SmallVector
<Operation
*, 4> combinerOps
;
75 Value reducedValue
= matchReduction(outputs
, i
, combinerOps
);
76 printReductionResult(op
, i
, reducedValue
, combinerOps
);
86 void registerTestMatchReductionPass() {
87 PassRegistration
<TestMatchReductionPass
>();