//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
11 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
13 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21 #include "mlir/Pass/Pass.h"
24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
31 /// A pattern that converts the region arguments in a single-region OpenMP
32 /// operation to the LLVM dialect. The body of the region is not modified and is
33 /// expected to either be processed by the conversion infrastructure or already
34 /// contain ops compatible with LLVM dialect types.
35 template <typename OpType
>
36 struct RegionOpConversion
: public ConvertOpToLLVMPattern
<OpType
> {
37 using ConvertOpToLLVMPattern
<OpType
>::ConvertOpToLLVMPattern
;
40 matchAndRewrite(OpType curOp
, typename
OpType::Adaptor adaptor
,
41 ConversionPatternRewriter
&rewriter
) const override
{
42 auto newOp
= rewriter
.create
<OpType
>(
43 curOp
.getLoc(), TypeRange(), adaptor
.getOperands(), curOp
->getAttrs());
44 rewriter
.inlineRegionBefore(curOp
.getRegion(), newOp
.getRegion(),
45 newOp
.getRegion().end());
46 if (failed(rewriter
.convertRegionTypes(&newOp
.getRegion(),
47 *this->getTypeConverter())))
50 rewriter
.eraseOp(curOp
);
56 struct RegionLessOpWithVarOperandsConversion
57 : public ConvertOpToLLVMPattern
<T
> {
58 using ConvertOpToLLVMPattern
<T
>::ConvertOpToLLVMPattern
;
60 matchAndRewrite(T curOp
, typename
T::Adaptor adaptor
,
61 ConversionPatternRewriter
&rewriter
) const override
{
62 const TypeConverter
*converter
= ConvertToLLVMPattern::getTypeConverter();
63 SmallVector
<Type
> resTypes
;
64 if (failed(converter
->convertTypes(curOp
->getResultTypes(), resTypes
)))
66 SmallVector
<Value
> convertedOperands
;
67 assert(curOp
.getNumVariableOperands() ==
68 curOp
.getOperation()->getNumOperands() &&
69 "unexpected non-variable operands");
70 for (unsigned idx
= 0; idx
< curOp
.getNumVariableOperands(); ++idx
) {
71 Value originalVariableOperand
= curOp
.getVariableOperand(idx
);
72 if (!originalVariableOperand
)
74 if (isa
<MemRefType
>(originalVariableOperand
.getType())) {
75 // TODO: Support memref type in variable operands
76 return rewriter
.notifyMatchFailure(curOp
,
77 "memref is not supported yet");
79 convertedOperands
.emplace_back(adaptor
.getOperands()[idx
]);
82 rewriter
.replaceOpWithNewOp
<T
>(curOp
, resTypes
, convertedOperands
,
89 struct RegionOpWithVarOperandsConversion
: public ConvertOpToLLVMPattern
<T
> {
90 using ConvertOpToLLVMPattern
<T
>::ConvertOpToLLVMPattern
;
92 matchAndRewrite(T curOp
, typename
T::Adaptor adaptor
,
93 ConversionPatternRewriter
&rewriter
) const override
{
94 const TypeConverter
*converter
= ConvertToLLVMPattern::getTypeConverter();
95 SmallVector
<Type
> resTypes
;
96 if (failed(converter
->convertTypes(curOp
->getResultTypes(), resTypes
)))
98 SmallVector
<Value
> convertedOperands
;
99 assert(curOp
.getNumVariableOperands() ==
100 curOp
.getOperation()->getNumOperands() &&
101 "unexpected non-variable operands");
102 for (unsigned idx
= 0; idx
< curOp
.getNumVariableOperands(); ++idx
) {
103 Value originalVariableOperand
= curOp
.getVariableOperand(idx
);
104 if (!originalVariableOperand
)
106 if (isa
<MemRefType
>(originalVariableOperand
.getType())) {
107 // TODO: Support memref type in variable operands
108 return rewriter
.notifyMatchFailure(curOp
,
109 "memref is not supported yet");
111 convertedOperands
.emplace_back(adaptor
.getOperands()[idx
]);
113 auto newOp
= rewriter
.create
<T
>(curOp
.getLoc(), resTypes
, convertedOperands
,
115 rewriter
.inlineRegionBefore(curOp
.getRegion(), newOp
.getRegion(),
116 newOp
.getRegion().end());
117 if (failed(rewriter
.convertRegionTypes(&newOp
.getRegion(),
118 *this->getTypeConverter())))
121 rewriter
.eraseOp(curOp
);
126 template <typename T
>
127 struct RegionLessOpConversion
: public ConvertOpToLLVMPattern
<T
> {
128 using ConvertOpToLLVMPattern
<T
>::ConvertOpToLLVMPattern
;
130 matchAndRewrite(T curOp
, typename
T::Adaptor adaptor
,
131 ConversionPatternRewriter
&rewriter
) const override
{
132 const TypeConverter
*converter
= ConvertToLLVMPattern::getTypeConverter();
133 SmallVector
<Type
> resTypes
;
134 if (failed(converter
->convertTypes(curOp
->getResultTypes(), resTypes
)))
137 rewriter
.replaceOpWithNewOp
<T
>(curOp
, resTypes
, adaptor
.getOperands(),
143 struct AtomicReadOpConversion
144 : public ConvertOpToLLVMPattern
<omp::AtomicReadOp
> {
145 using ConvertOpToLLVMPattern
<omp::AtomicReadOp
>::ConvertOpToLLVMPattern
;
147 matchAndRewrite(omp::AtomicReadOp curOp
, OpAdaptor adaptor
,
148 ConversionPatternRewriter
&rewriter
) const override
{
149 const TypeConverter
*converter
= ConvertToLLVMPattern::getTypeConverter();
150 Type curElementType
= curOp
.getElementType();
151 auto newOp
= rewriter
.create
<omp::AtomicReadOp
>(
152 curOp
.getLoc(), TypeRange(), adaptor
.getOperands(), curOp
->getAttrs());
153 TypeAttr typeAttr
= TypeAttr::get(converter
->convertType(curElementType
));
154 newOp
.setElementTypeAttr(typeAttr
);
155 rewriter
.eraseOp(curOp
);
160 struct MapInfoOpConversion
: public ConvertOpToLLVMPattern
<omp::MapInfoOp
> {
161 using ConvertOpToLLVMPattern
<omp::MapInfoOp
>::ConvertOpToLLVMPattern
;
163 matchAndRewrite(omp::MapInfoOp curOp
, OpAdaptor adaptor
,
164 ConversionPatternRewriter
&rewriter
) const override
{
165 const TypeConverter
*converter
= ConvertToLLVMPattern::getTypeConverter();
167 SmallVector
<Type
> resTypes
;
168 if (failed(converter
->convertTypes(curOp
->getResultTypes(), resTypes
)))
171 // Copy attributes of the curOp except for the typeAttr which should
173 SmallVector
<NamedAttribute
> newAttrs
;
174 for (NamedAttribute attr
: curOp
->getAttrs()) {
175 if (auto typeAttr
= dyn_cast
<TypeAttr
>(attr
.getValue())) {
176 Type newAttr
= converter
->convertType(typeAttr
.getValue());
177 newAttrs
.emplace_back(attr
.getName(), TypeAttr::get(newAttr
));
179 newAttrs
.push_back(attr
);
183 rewriter
.replaceOpWithNewOp
<omp::MapInfoOp
>(
184 curOp
, resTypes
, adaptor
.getOperands(), newAttrs
);
189 template <typename OpType
>
190 struct MultiRegionOpConversion
: public ConvertOpToLLVMPattern
<OpType
> {
191 using ConvertOpToLLVMPattern
<OpType
>::ConvertOpToLLVMPattern
;
193 void forwardOpAttrs(OpType curOp
, OpType newOp
) const {}
196 matchAndRewrite(OpType curOp
, typename
OpType::Adaptor adaptor
,
197 ConversionPatternRewriter
&rewriter
) const override
{
198 auto newOp
= rewriter
.create
<OpType
>(
199 curOp
.getLoc(), TypeRange(), curOp
.getSymNameAttr(),
200 TypeAttr::get(this->getTypeConverter()->convertType(
201 curOp
.getTypeAttr().getValue())));
202 forwardOpAttrs(curOp
, newOp
);
204 for (unsigned idx
= 0; idx
< curOp
.getNumRegions(); idx
++) {
205 rewriter
.inlineRegionBefore(curOp
.getRegion(idx
), newOp
.getRegion(idx
),
206 newOp
.getRegion(idx
).end());
207 if (failed(rewriter
.convertRegionTypes(&newOp
.getRegion(idx
),
208 *this->getTypeConverter())))
212 rewriter
.eraseOp(curOp
);
218 void MultiRegionOpConversion
<omp::PrivateClauseOp
>::forwardOpAttrs(
219 omp::PrivateClauseOp curOp
, omp::PrivateClauseOp newOp
) const {
220 newOp
.setDataSharingType(curOp
.getDataSharingType());
224 void mlir::configureOpenMPToLLVMConversionLegality(
225 ConversionTarget
&target
, const LLVMTypeConverter
&typeConverter
) {
226 target
.addDynamicallyLegalOp
<
227 omp::AtomicReadOp
, omp::AtomicWriteOp
, omp::CancellationPointOp
,
228 omp::CancelOp
, omp::CriticalDeclareOp
, omp::FlushOp
, omp::MapBoundsOp
,
229 omp::MapInfoOp
, omp::OrderedOp
, omp::TargetEnterDataOp
,
230 omp::TargetExitDataOp
, omp::TargetUpdateOp
, omp::ThreadprivateOp
,
231 omp::YieldOp
>([&](Operation
*op
) {
232 return typeConverter
.isLegal(op
->getOperandTypes()) &&
233 typeConverter
.isLegal(op
->getResultTypes());
235 target
.addDynamicallyLegalOp
<
236 omp::AtomicUpdateOp
, omp::CriticalOp
, omp::DeclareReductionOp
,
237 omp::DistributeOp
, omp::LoopNestOp
, omp::LoopOp
, omp::MasterOp
,
238 omp::OrderedRegionOp
, omp::ParallelOp
, omp::PrivateClauseOp
,
239 omp::SectionOp
, omp::SectionsOp
, omp::SimdOp
, omp::SingleOp
,
240 omp::TargetDataOp
, omp::TargetOp
, omp::TaskgroupOp
, omp::TaskloopOp
,
241 omp::TaskOp
, omp::TeamsOp
, omp::WsloopOp
>([&](Operation
*op
) {
242 return std::all_of(op
->getRegions().begin(), op
->getRegions().end(),
243 [&](Region
®ion
) {
244 return typeConverter
.isLegal(®ion
);
246 typeConverter
.isLegal(op
->getOperandTypes()) &&
247 typeConverter
.isLegal(op
->getResultTypes());
251 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter
&converter
,
252 RewritePatternSet
&patterns
) {
253 // This type is allowed when converting OpenMP to LLVM Dialect, it carries
254 // bounds information for map clauses and the operation and type are
255 // discarded on lowering to LLVM-IR from the OpenMP dialect.
256 converter
.addConversion(
257 [&](omp::MapBoundsType type
) -> Type
{ return type
; });
260 AtomicReadOpConversion
, MapInfoOpConversion
,
261 MultiRegionOpConversion
<omp::DeclareReductionOp
>,
262 MultiRegionOpConversion
<omp::PrivateClauseOp
>,
263 RegionLessOpConversion
<omp::CancellationPointOp
>,
264 RegionLessOpConversion
<omp::CancelOp
>,
265 RegionLessOpConversion
<omp::CriticalDeclareOp
>,
266 RegionLessOpConversion
<omp::OrderedOp
>,
267 RegionLessOpConversion
<omp::TargetEnterDataOp
>,
268 RegionLessOpConversion
<omp::TargetExitDataOp
>,
269 RegionLessOpConversion
<omp::TargetUpdateOp
>,
270 RegionLessOpConversion
<omp::YieldOp
>,
271 RegionLessOpWithVarOperandsConversion
<omp::AtomicWriteOp
>,
272 RegionLessOpWithVarOperandsConversion
<omp::FlushOp
>,
273 RegionLessOpWithVarOperandsConversion
<omp::MapBoundsOp
>,
274 RegionLessOpWithVarOperandsConversion
<omp::ThreadprivateOp
>,
275 RegionOpConversion
<omp::AtomicCaptureOp
>,
276 RegionOpConversion
<omp::CriticalOp
>,
277 RegionOpConversion
<omp::DistributeOp
>,
278 RegionOpConversion
<omp::LoopNestOp
>, RegionOpConversion
<omp::LoopOp
>,
279 RegionOpConversion
<omp::MaskedOp
>, RegionOpConversion
<omp::MasterOp
>,
280 RegionOpConversion
<omp::OrderedRegionOp
>,
281 RegionOpConversion
<omp::ParallelOp
>, RegionOpConversion
<omp::SectionOp
>,
282 RegionOpConversion
<omp::SectionsOp
>, RegionOpConversion
<omp::SimdOp
>,
283 RegionOpConversion
<omp::SingleOp
>, RegionOpConversion
<omp::TargetDataOp
>,
284 RegionOpConversion
<omp::TargetOp
>, RegionOpConversion
<omp::TaskgroupOp
>,
285 RegionOpConversion
<omp::TaskloopOp
>, RegionOpConversion
<omp::TaskOp
>,
286 RegionOpConversion
<omp::TeamsOp
>, RegionOpConversion
<omp::WsloopOp
>,
287 RegionOpWithVarOperandsConversion
<omp::AtomicUpdateOp
>>(converter
);
291 struct ConvertOpenMPToLLVMPass
292 : public impl::ConvertOpenMPToLLVMPassBase
<ConvertOpenMPToLLVMPass
> {
295 void runOnOperation() override
;
299 void ConvertOpenMPToLLVMPass::runOnOperation() {
300 auto module
= getOperation();
302 // Convert to OpenMP operations with LLVM IR dialect
303 RewritePatternSet
patterns(&getContext());
304 LLVMTypeConverter
converter(&getContext());
305 arith::populateArithToLLVMConversionPatterns(converter
, patterns
);
306 cf::populateControlFlowToLLVMConversionPatterns(converter
, patterns
);
307 populateFinalizeMemRefToLLVMConversionPatterns(converter
, patterns
);
308 populateFuncToLLVMConversionPatterns(converter
, patterns
);
309 populateOpenMPToLLVMConversionPatterns(converter
, patterns
);
311 LLVMConversionTarget
target(getContext());
312 target
.addLegalOp
<omp::BarrierOp
, omp::FlushOp
, omp::TaskwaitOp
,
313 omp::TaskyieldOp
, omp::TerminatorOp
>();
314 configureOpenMPToLLVMConversionLegality(target
, converter
);
315 if (failed(applyPartialConversion(module
, target
, std::move(patterns
))))
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
323 /// Implement the interface to convert OpenMP to LLVM.
324 struct OpenMPToLLVMDialectInterface
: public ConvertToLLVMPatternInterface
{
325 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface
;
326 void loadDependentDialects(MLIRContext
*context
) const final
{
327 context
->loadDialect
<LLVM::LLVMDialect
>();
330 /// Hook for derived dialect interface to provide conversion patterns
331 /// and mark dialect legal for the conversion target.
332 void populateConvertToLLVMConversionPatterns(
333 ConversionTarget
&target
, LLVMTypeConverter
&typeConverter
,
334 RewritePatternSet
&patterns
) const final
{
335 configureOpenMPToLLVMConversionLegality(target
, typeConverter
);
336 populateOpenMPToLLVMConversionPatterns(typeConverter
, patterns
);
341 void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry
®istry
) {
342 registry
.addExtension(+[](MLIRContext
*ctx
, omp::OpenMPDialect
*dialect
) {
343 dialect
->addInterfaces
<OpenMPToLLVMDialectInterface
>();