Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / NVGPU / TestNVGPUTransforms.cpp
blob 8ca29257b812033669a431e12f99ee913286815f
1 //===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <type_traits>
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 using namespace mlir;
28 using namespace mlir::nvgpu;
30 namespace {
32 struct TestMmaSyncF32ToTF32Patterns
33 : public PassWrapper<TestMmaSyncF32ToTF32Patterns,
34 OperationPass<func::FuncOp>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns)
37 StringRef getArgument() const final {
38 return "test-nvgpu-mmasync-f32-to-tf32-patterns";
40 StringRef getDescription() const final {
41 return "Test patterns to convert mma.sync on f32 with tf32 precision";
43 TestMmaSyncF32ToTF32Patterns() = default;
44 TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass)
45 : PassWrapper(pass) {}
47 Option<std::string> precision{
48 *this, "precision",
49 llvm::cl::desc(
50 "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
51 llvm::cl::init("tf32")};
53 MmaSyncF32Lowering tf32Precision =
54 llvm::StringSwitch<MmaSyncF32Lowering>(precision)
55 .Case("tf32", MmaSyncF32Lowering::TF32)
56 .Case("tf32x3", MmaSyncF32Lowering::TF32x3)
57 .Default(MmaSyncF32Lowering::Unkown);
59 void runOnOperation() override {
60 RewritePatternSet patterns(&getContext());
62 populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
63 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
67 } // namespace
69 namespace mlir {
70 namespace test {
71 void registerTestNVGPULowerings() {
72 PassRegistration<TestMmaSyncF32ToTF32Patterns>();
75 } // namespace test
76 } // namespace mlir