1 //===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/SCF/IR/SCF.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/OwningOpRef.h"
14 #include "gtest/gtest.h"
17 using namespace mlir::scf
;
//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//
23 class SCFLoopLikeTest
: public ::testing::Test
{
25 SCFLoopLikeTest() : b(&context
), loc(UnknownLoc::get(&context
)) {
26 context
.loadDialect
<arith::ArithDialect
, scf::SCFDialect
>();
29 void checkUnidimensional(LoopLikeOpInterface loopLikeOp
) {
30 std::optional
<OpFoldResult
> maybeLb
= loopLikeOp
.getSingleLowerBound();
31 EXPECT_TRUE(maybeLb
.has_value());
32 std::optional
<OpFoldResult
> maybeUb
= loopLikeOp
.getSingleUpperBound();
33 EXPECT_TRUE(maybeUb
.has_value());
34 std::optional
<OpFoldResult
> maybeStep
= loopLikeOp
.getSingleStep();
35 EXPECT_TRUE(maybeStep
.has_value());
36 std::optional
<OpFoldResult
> maybeIndVar
=
37 loopLikeOp
.getSingleInductionVar();
38 EXPECT_TRUE(maybeIndVar
.has_value());
41 void checkMultidimensional(LoopLikeOpInterface loopLikeOp
) {
42 std::optional
<OpFoldResult
> maybeLb
= loopLikeOp
.getSingleLowerBound();
43 EXPECT_FALSE(maybeLb
.has_value());
44 std::optional
<OpFoldResult
> maybeUb
= loopLikeOp
.getSingleUpperBound();
45 EXPECT_FALSE(maybeUb
.has_value());
46 std::optional
<OpFoldResult
> maybeStep
= loopLikeOp
.getSingleStep();
47 EXPECT_FALSE(maybeStep
.has_value());
48 std::optional
<OpFoldResult
> maybeIndVar
=
49 loopLikeOp
.getSingleInductionVar();
50 EXPECT_FALSE(maybeIndVar
.has_value());
58 TEST_F(SCFLoopLikeTest
, queryUnidimensionalLooplikes
) {
59 OwningOpRef
<arith::ConstantIndexOp
> lb
=
60 b
.create
<arith::ConstantIndexOp
>(loc
, 0);
61 OwningOpRef
<arith::ConstantIndexOp
> ub
=
62 b
.create
<arith::ConstantIndexOp
>(loc
, 10);
63 OwningOpRef
<arith::ConstantIndexOp
> step
=
64 b
.create
<arith::ConstantIndexOp
>(loc
, 2);
66 OwningOpRef
<scf::ForOp
> forOp
=
67 b
.create
<scf::ForOp
>(loc
, lb
.get(), ub
.get(), step
.get());
68 checkUnidimensional(forOp
.get());
70 OwningOpRef
<scf::ForallOp
> forallOp
= b
.create
<scf::ForallOp
>(
71 loc
, ArrayRef
<OpFoldResult
>(lb
->getResult()),
72 ArrayRef
<OpFoldResult
>(ub
->getResult()),
73 ArrayRef
<OpFoldResult
>(step
->getResult()), ValueRange(), std::nullopt
);
74 checkUnidimensional(forallOp
.get());
76 OwningOpRef
<scf::ParallelOp
> parallelOp
= b
.create
<scf::ParallelOp
>(
77 loc
, ValueRange(lb
->getResult()), ValueRange(ub
->getResult()),
78 ValueRange(step
->getResult()), ValueRange());
79 checkUnidimensional(parallelOp
.get());
82 TEST_F(SCFLoopLikeTest
, queryMultidimensionalLooplikes
) {
83 OwningOpRef
<arith::ConstantIndexOp
> lb
=
84 b
.create
<arith::ConstantIndexOp
>(loc
, 0);
85 OwningOpRef
<arith::ConstantIndexOp
> ub
=
86 b
.create
<arith::ConstantIndexOp
>(loc
, 10);
87 OwningOpRef
<arith::ConstantIndexOp
> step
=
88 b
.create
<arith::ConstantIndexOp
>(loc
, 2);
90 OwningOpRef
<scf::ForallOp
> forallOp
= b
.create
<scf::ForallOp
>(
91 loc
, ArrayRef
<OpFoldResult
>({lb
->getResult(), lb
->getResult()}),
92 ArrayRef
<OpFoldResult
>({ub
->getResult(), ub
->getResult()}),
93 ArrayRef
<OpFoldResult
>({step
->getResult(), step
->getResult()}),
94 ValueRange(), std::nullopt
);
95 checkMultidimensional(forallOp
.get());
97 OwningOpRef
<scf::ParallelOp
> parallelOp
= b
.create
<scf::ParallelOp
>(
98 loc
, ValueRange({lb
->getResult(), lb
->getResult()}),
99 ValueRange({ub
->getResult(), ub
->getResult()}),
100 ValueRange({step
->getResult(), step
->getResult()}), ValueRange());
101 checkMultidimensional(parallelOp
.get());