[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Conversion / OpenMPToLLVM / OpenMPToLLVM.cpp
blob58fd3d565fce50d7cb89cda120962e125d48f1ec
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
11 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
13 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21 #include "mlir/Pass/Pass.h"
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
28 using namespace mlir;
30 namespace {
31 /// A pattern that converts the region arguments in a single-region OpenMP
32 /// operation to the LLVM dialect. The body of the region is not modified and is
33 /// expected to either be processed by the conversion infrastructure or already
34 /// contain ops compatible with LLVM dialect types.
35 template <typename OpType>
36 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
37 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
39 LogicalResult
40 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
41 ConversionPatternRewriter &rewriter) const override {
42 auto newOp = rewriter.create<OpType>(
43 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
44 rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
45 newOp.getRegion().end());
46 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
47 *this->getTypeConverter())))
48 return failure();
50 rewriter.eraseOp(curOp);
51 return success();
55 template <typename T>
56 struct RegionLessOpWithVarOperandsConversion
57 : public ConvertOpToLLVMPattern<T> {
58 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
59 LogicalResult
60 matchAndRewrite(T curOp, typename T::Adaptor adaptor,
61 ConversionPatternRewriter &rewriter) const override {
62 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
63 SmallVector<Type> resTypes;
64 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
65 return failure();
66 SmallVector<Value> convertedOperands;
67 assert(curOp.getNumVariableOperands() ==
68 curOp.getOperation()->getNumOperands() &&
69 "unexpected non-variable operands");
70 for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
71 Value originalVariableOperand = curOp.getVariableOperand(idx);
72 if (!originalVariableOperand)
73 return failure();
74 if (isa<MemRefType>(originalVariableOperand.getType())) {
75 // TODO: Support memref type in variable operands
76 return rewriter.notifyMatchFailure(curOp,
77 "memref is not supported yet");
79 convertedOperands.emplace_back(adaptor.getOperands()[idx]);
82 rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
83 curOp->getAttrs());
84 return success();
88 template <typename T>
89 struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
90 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
91 LogicalResult
92 matchAndRewrite(T curOp, typename T::Adaptor adaptor,
93 ConversionPatternRewriter &rewriter) const override {
94 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
95 SmallVector<Type> resTypes;
96 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
97 return failure();
98 SmallVector<Value> convertedOperands;
99 assert(curOp.getNumVariableOperands() ==
100 curOp.getOperation()->getNumOperands() &&
101 "unexpected non-variable operands");
102 for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
103 Value originalVariableOperand = curOp.getVariableOperand(idx);
104 if (!originalVariableOperand)
105 return failure();
106 if (isa<MemRefType>(originalVariableOperand.getType())) {
107 // TODO: Support memref type in variable operands
108 return rewriter.notifyMatchFailure(curOp,
109 "memref is not supported yet");
111 convertedOperands.emplace_back(adaptor.getOperands()[idx]);
113 auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
114 curOp->getAttrs());
115 rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
116 newOp.getRegion().end());
117 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
118 *this->getTypeConverter())))
119 return failure();
121 rewriter.eraseOp(curOp);
122 return success();
126 template <typename T>
127 struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
128 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
129 LogicalResult
130 matchAndRewrite(T curOp, typename T::Adaptor adaptor,
131 ConversionPatternRewriter &rewriter) const override {
132 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
133 SmallVector<Type> resTypes;
134 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
135 return failure();
137 rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
138 curOp->getAttrs());
139 return success();
143 struct AtomicReadOpConversion
144 : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
145 using ConvertOpToLLVMPattern<omp::AtomicReadOp>::ConvertOpToLLVMPattern;
146 LogicalResult
147 matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
148 ConversionPatternRewriter &rewriter) const override {
149 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
150 Type curElementType = curOp.getElementType();
151 auto newOp = rewriter.create<omp::AtomicReadOp>(
152 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
153 TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
154 newOp.setElementTypeAttr(typeAttr);
155 rewriter.eraseOp(curOp);
156 return success();
160 struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
161 using ConvertOpToLLVMPattern<omp::MapInfoOp>::ConvertOpToLLVMPattern;
162 LogicalResult
163 matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter) const override {
165 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
167 SmallVector<Type> resTypes;
168 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
169 return failure();
171 // Copy attributes of the curOp except for the typeAttr which should
172 // be converted
173 SmallVector<NamedAttribute> newAttrs;
174 for (NamedAttribute attr : curOp->getAttrs()) {
175 if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
176 Type newAttr = converter->convertType(typeAttr.getValue());
177 newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
178 } else {
179 newAttrs.push_back(attr);
183 rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
184 curOp, resTypes, adaptor.getOperands(), newAttrs);
185 return success();
189 template <typename OpType>
190 struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
191 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
193 void forwardOpAttrs(OpType curOp, OpType newOp) const {}
195 LogicalResult
196 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
197 ConversionPatternRewriter &rewriter) const override {
198 auto newOp = rewriter.create<OpType>(
199 curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
200 TypeAttr::get(this->getTypeConverter()->convertType(
201 curOp.getTypeAttr().getValue())));
202 forwardOpAttrs(curOp, newOp);
204 for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
205 rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
206 newOp.getRegion(idx).end());
207 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
208 *this->getTypeConverter())))
209 return failure();
212 rewriter.eraseOp(curOp);
213 return success();
217 template <>
218 void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
219 omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
220 newOp.setDataSharingType(curOp.getDataSharingType());
222 } // namespace
224 void mlir::configureOpenMPToLLVMConversionLegality(
225 ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
226 target.addDynamicallyLegalOp<
227 omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
228 omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
229 omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
230 omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
231 omp::YieldOp>([&](Operation *op) {
232 return typeConverter.isLegal(op->getOperandTypes()) &&
233 typeConverter.isLegal(op->getResultTypes());
235 target.addDynamicallyLegalOp<
236 omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp,
237 omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp,
238 omp::OrderedRegionOp, omp::ParallelOp, omp::PrivateClauseOp,
239 omp::SectionOp, omp::SectionsOp, omp::SimdOp, omp::SingleOp,
240 omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp, omp::TaskloopOp,
241 omp::TaskOp, omp::TeamsOp, omp::WsloopOp>([&](Operation *op) {
242 return std::all_of(op->getRegions().begin(), op->getRegions().end(),
243 [&](Region &region) {
244 return typeConverter.isLegal(&region);
245 }) &&
246 typeConverter.isLegal(op->getOperandTypes()) &&
247 typeConverter.isLegal(op->getResultTypes());
251 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
252 RewritePatternSet &patterns) {
253 // This type is allowed when converting OpenMP to LLVM Dialect, it carries
254 // bounds information for map clauses and the operation and type are
255 // discarded on lowering to LLVM-IR from the OpenMP dialect.
256 converter.addConversion(
257 [&](omp::MapBoundsType type) -> Type { return type; });
259 patterns.add<
260 AtomicReadOpConversion, MapInfoOpConversion,
261 MultiRegionOpConversion<omp::DeclareReductionOp>,
262 MultiRegionOpConversion<omp::PrivateClauseOp>,
263 RegionLessOpConversion<omp::CancellationPointOp>,
264 RegionLessOpConversion<omp::CancelOp>,
265 RegionLessOpConversion<omp::CriticalDeclareOp>,
266 RegionLessOpConversion<omp::OrderedOp>,
267 RegionLessOpConversion<omp::TargetEnterDataOp>,
268 RegionLessOpConversion<omp::TargetExitDataOp>,
269 RegionLessOpConversion<omp::TargetUpdateOp>,
270 RegionLessOpConversion<omp::YieldOp>,
271 RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
272 RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
273 RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
274 RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
275 RegionOpConversion<omp::AtomicCaptureOp>,
276 RegionOpConversion<omp::CriticalOp>,
277 RegionOpConversion<omp::DistributeOp>,
278 RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
279 RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
280 RegionOpConversion<omp::OrderedRegionOp>,
281 RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
282 RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
283 RegionOpConversion<omp::SingleOp>, RegionOpConversion<omp::TargetDataOp>,
284 RegionOpConversion<omp::TargetOp>, RegionOpConversion<omp::TaskgroupOp>,
285 RegionOpConversion<omp::TaskloopOp>, RegionOpConversion<omp::TaskOp>,
286 RegionOpConversion<omp::TeamsOp>, RegionOpConversion<omp::WsloopOp>,
287 RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter);
290 namespace {
291 struct ConvertOpenMPToLLVMPass
292 : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
293 using Base::Base;
295 void runOnOperation() override;
297 } // namespace
299 void ConvertOpenMPToLLVMPass::runOnOperation() {
300 auto module = getOperation();
302 // Convert to OpenMP operations with LLVM IR dialect
303 RewritePatternSet patterns(&getContext());
304 LLVMTypeConverter converter(&getContext());
305 arith::populateArithToLLVMConversionPatterns(converter, patterns);
306 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
307 populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
308 populateFuncToLLVMConversionPatterns(converter, patterns);
309 populateOpenMPToLLVMConversionPatterns(converter, patterns);
311 LLVMConversionTarget target(getContext());
312 target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
313 omp::TaskyieldOp, omp::TerminatorOp>();
314 configureOpenMPToLLVMConversionLegality(target, converter);
315 if (failed(applyPartialConversion(module, target, std::move(patterns))))
316 signalPassFailure();
319 //===----------------------------------------------------------------------===//
320 // ConvertToLLVMPatternInterface implementation
321 //===----------------------------------------------------------------------===//
322 namespace {
323 /// Implement the interface to convert OpenMP to LLVM.
324 struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
325 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
326 void loadDependentDialects(MLIRContext *context) const final {
327 context->loadDialect<LLVM::LLVMDialect>();
330 /// Hook for derived dialect interface to provide conversion patterns
331 /// and mark dialect legal for the conversion target.
332 void populateConvertToLLVMConversionPatterns(
333 ConversionTarget &target, LLVMTypeConverter &typeConverter,
334 RewritePatternSet &patterns) const final {
335 configureOpenMPToLLVMConversionLegality(target, typeConverter);
336 populateOpenMPToLLVMConversionPatterns(typeConverter, patterns);
339 } // namespace
341 void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry &registry) {
342 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
343 dialect->addInterfaces<OpenMPToLLVMDialectInterface>();