1 //===- NVVMToLLVM.cpp - NVVM to LLVM dialect conversion -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements the translation of NVVM ops that are not supported in LLVM
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
21 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/Support/raw_ostream.h"
30 #define DEBUG_TYPE "nvvm-to-llvm"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define DBGSNL() (llvm::dbgs() << "\n")
35 #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
36 #include "mlir/Conversion/Passes.h.inc"
45 : public OpInterfaceRewritePattern
<BasicPtxBuilderInterface
> {
46 using OpInterfaceRewritePattern
<
47 BasicPtxBuilderInterface
>::OpInterfaceRewritePattern
;
49 PtxLowering(MLIRContext
*context
, PatternBenefit benefit
= 2)
50 : OpInterfaceRewritePattern(context
, benefit
) {}
52 LogicalResult
matchAndRewrite(BasicPtxBuilderInterface op
,
53 PatternRewriter
&rewriter
) const override
{
54 if (op
.hasIntrinsic()) {
55 LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op
<< "\n");
59 SmallVector
<std::pair
<Value
, PTXRegisterMod
>> asmValues
;
60 LLVM_DEBUG(DBGS() << op
.getPtx() << "\n");
61 PtxBuilder
generator(op
, rewriter
);
63 op
.getAsmValues(rewriter
, asmValues
);
64 for (auto &[asmValue
, modifier
] : asmValues
) {
65 LLVM_DEBUG(DBGSNL() << asmValue
<< "\t Modifier : " << &modifier
);
66 generator
.insertValue(asmValue
, modifier
);
69 generator
.buildAndReplaceOp();
74 struct ConvertNVVMToLLVMPass
75 : public impl::ConvertNVVMToLLVMPassBase
<ConvertNVVMToLLVMPass
> {
78 void getDependentDialects(DialectRegistry
®istry
) const override
{
79 registry
.insert
<LLVM::LLVMDialect
, NVVM::NVVMDialect
>();
82 void runOnOperation() override
{
83 ConversionTarget
target(getContext());
84 target
.addLegalDialect
<::mlir::LLVM::LLVMDialect
>();
85 RewritePatternSet
pattern(&getContext());
86 mlir::populateNVVMToLLVMConversionPatterns(pattern
);
88 applyPartialConversion(getOperation(), target
, std::move(pattern
))))
93 /// Implement the interface to convert NVVM to LLVM.
94 struct NVVMToLLVMDialectInterface
: public ConvertToLLVMPatternInterface
{
95 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface
;
96 void loadDependentDialects(MLIRContext
*context
) const final
{
97 context
->loadDialect
<NVVMDialect
>();
100 /// Hook for derived dialect interface to provide conversion patterns
101 /// and mark dialect legal for the conversion target.
102 void populateConvertToLLVMConversionPatterns(
103 ConversionTarget
&target
, LLVMTypeConverter
&typeConverter
,
104 RewritePatternSet
&patterns
) const final
{
105 populateNVVMToLLVMConversionPatterns(patterns
);
111 void mlir::populateNVVMToLLVMConversionPatterns(RewritePatternSet
&patterns
) {
112 patterns
.add
<PtxLowering
>(patterns
.getContext());
115 void mlir::registerConvertNVVMToLLVMInterface(DialectRegistry
®istry
) {
116 registry
.addExtension(+[](MLIRContext
*ctx
, NVVMDialect
*dialect
) {
117 dialect
->addInterfaces
<NVVMToLLVMDialectInterface
>();