1 //===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 using namespace mlir::nvgpu
;
32 struct TestMmaSyncF32ToTF32Patterns
33 : public PassWrapper
<TestMmaSyncF32ToTF32Patterns
,
34 OperationPass
<func::FuncOp
>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns
)
37 StringRef
getArgument() const final
{
38 return "test-nvgpu-mmasync-f32-to-tf32-patterns";
40 StringRef
getDescription() const final
{
41 return "Test patterns to convert mma.sync on f32 with tf32 precision";
43 TestMmaSyncF32ToTF32Patterns() = default;
44 TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns
&pass
)
45 : PassWrapper(pass
) {}
47 Option
<std::string
> precision
{
50 "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
51 llvm::cl::init("tf32")};
53 MmaSyncF32Lowering tf32Precision
=
54 llvm::StringSwitch
<MmaSyncF32Lowering
>(precision
)
55 .Case("tf32", MmaSyncF32Lowering::TF32
)
56 .Case("tf32x3", MmaSyncF32Lowering::TF32x3
)
57 .Default(MmaSyncF32Lowering::Unkown
);
59 void runOnOperation() override
{
60 RewritePatternSet
patterns(&getContext());
62 populateMmaSyncF32ToTF32Patterns(patterns
, tf32Precision
);
63 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
));
71 void registerTestNVGPULowerings() {
72 PassRegistration
<TestMmaSyncF32ToTF32Patterns
>();