[Clang][SME2] Fix PSEL builtin predicates (#77097)
[llvm-project.git] / mlir / unittests / Dialect / SCF / LoopLikeSCFOpsTest.cpp
blob6bc0fd6113b9bbf00c633d799719879024e39ca6
1 //===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/SCF/IR/SCF.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/OwningOpRef.h"
14 #include "gtest/gtest.h"
16 using namespace mlir;
17 using namespace mlir::scf;
19 //===----------------------------------------------------------------------===//
20 // Test Fixture
21 //===----------------------------------------------------------------------===//
23 class SCFLoopLikeTest : public ::testing::Test {
24 protected:
25 SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
26 context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
29 void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
30 std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
31 EXPECT_TRUE(maybeLb.has_value());
32 std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
33 EXPECT_TRUE(maybeUb.has_value());
34 std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
35 EXPECT_TRUE(maybeStep.has_value());
36 std::optional<OpFoldResult> maybeIndVar =
37 loopLikeOp.getSingleInductionVar();
38 EXPECT_TRUE(maybeIndVar.has_value());
41 void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
42 std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
43 EXPECT_FALSE(maybeLb.has_value());
44 std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
45 EXPECT_FALSE(maybeUb.has_value());
46 std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
47 EXPECT_FALSE(maybeStep.has_value());
48 std::optional<OpFoldResult> maybeIndVar =
49 loopLikeOp.getSingleInductionVar();
50 EXPECT_FALSE(maybeIndVar.has_value());
53 MLIRContext context;
54 OpBuilder b;
55 Location loc;
58 TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
59 OwningOpRef<arith::ConstantIndexOp> lb =
60 b.create<arith::ConstantIndexOp>(loc, 0);
61 OwningOpRef<arith::ConstantIndexOp> ub =
62 b.create<arith::ConstantIndexOp>(loc, 10);
63 OwningOpRef<arith::ConstantIndexOp> step =
64 b.create<arith::ConstantIndexOp>(loc, 2);
66 OwningOpRef<scf::ForOp> forOp =
67 b.create<scf::ForOp>(loc, lb.get(), ub.get(), step.get());
68 checkUnidimensional(forOp.get());
70 OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
71 loc, ArrayRef<OpFoldResult>(lb->getResult()),
72 ArrayRef<OpFoldResult>(ub->getResult()),
73 ArrayRef<OpFoldResult>(step->getResult()), ValueRange(), std::nullopt);
74 checkUnidimensional(forallOp.get());
76 OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
77 loc, ValueRange(lb->getResult()), ValueRange(ub->getResult()),
78 ValueRange(step->getResult()), ValueRange());
79 checkUnidimensional(parallelOp.get());
82 TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
83 OwningOpRef<arith::ConstantIndexOp> lb =
84 b.create<arith::ConstantIndexOp>(loc, 0);
85 OwningOpRef<arith::ConstantIndexOp> ub =
86 b.create<arith::ConstantIndexOp>(loc, 10);
87 OwningOpRef<arith::ConstantIndexOp> step =
88 b.create<arith::ConstantIndexOp>(loc, 2);
90 OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
91 loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
92 ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
93 ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
94 ValueRange(), std::nullopt);
95 checkMultidimensional(forallOp.get());
97 OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
98 loc, ValueRange({lb->getResult(), lb->getResult()}),
99 ValueRange({ub->getResult(), ub->getResult()}),
100 ValueRange({step->getResult(), step->getResult()}), ValueRange());
101 checkMultidimensional(parallelOp.get());