//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "GPUOpsLowering.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
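
// For orientation, a rough sketch of what GPUFuncOpLowering does (illustrative
// IR only; the exact output depends on the type converter and target options).
// A kernel with a workgroup attribution such as
//
//   gpu.func @kernel(%arg0: f32) workgroup(%wg: memref<32xf32, 3>) kernel {
//     ...
//   }
//
// is rewritten into an llvm.func in which %wg is backed either by an internal
// global buffer (the default path below) or, when
// encodeWorkgroupAttributionsAsArguments is set, by an extra trailing
// `llvm.ptr` argument annotated with `llvm.mlir.workgroup_attribution`.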
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  Location loc = gpuFuncOp.getLoc();

  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
  if (encodeWorkgroupAttributionsAsArguments) {
    // Append an `llvm.ptr` argument to the function signature to encode
    // workgroup attributions.

    ArrayRef<BlockArgument> workgroupAttributions =
        gpuFuncOp.getWorkgroupAttributions();
    size_t numAttributions = workgroupAttributions.size();

    // Insert all arguments at the end.
    unsigned index = gpuFuncOp.getNumArguments();
    SmallVector<unsigned> argIndices(numAttributions, index);

    // New arguments will simply be `llvm.ptr` with the correct address space.
    Type workgroupPtrType =
        rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
    SmallVector<Type> argTypes(numAttributions, workgroupPtrType);

    // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
    std::array attrs{
        rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
                              rewriter.getUnitAttr()),
        rewriter.getNamedAttr(
            getDialect().getWorkgroupAttributionAttrHelper().getName(),
            rewriter.getUnitAttr()),
    };
    SmallVector<DictionaryAttr> argAttrs;
    for (BlockArgument attribution : workgroupAttributions) {
      auto attributionType = cast<MemRefType>(attribution.getType());
      IntegerAttr numElements =
          rewriter.getI64IntegerAttr(attributionType.getNumElements());
      Type llvmElementType =
          getTypeConverter()->convertType(attributionType.getElementType());
      if (!llvmElementType)
        return failure();
      TypeAttr type = TypeAttr::get(llvmElementType);
      attrs.back().setValue(
          rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
      argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
    }

    // Locations match the function location.
    SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());

    // Perform the signature modification.
    rewriter.modifyOpInPlace(
        gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
          static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
              argIndices, argTypes, argAttrs, argLocs);
        });
  } else {
    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
    for (auto [idx, attribution] :
         llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
      auto type = dyn_cast<MemRefType>(attribution.getType());
      assert(type && type.hasStaticShape() && "unexpected type in attribution");

      uint64_t numElements = type.getNumElements();

      auto elementType =
          cast<Type>(typeConverter->convertType(type.getElementType()));
      auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
      std::string name =
          std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
      uint64_t alignment = 0;
      if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
              gpuFuncOp.getWorkgroupAttributionAttr(
                  idx, LLVM::LLVMDialect::getAlignAttrName())))
        alignment = alignAttr.getInt();
      auto globalOp = rewriter.create<LLVM::GlobalOp>(
          gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
          LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
          workgroupAddrSpace);
      workgroupBuffers.push_back(globalOp);
    }
  }

  // Remap proper input types.
  TypeConverter::SignatureConversion signatureConversion(
      gpuFuncOp.front().getNumArguments());

  Type funcType = getTypeConverter()->convertFunctionSignature(
      gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
  if (!funcType) {
    return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
      diag << "failed to convert function signature type for: "
           << gpuFuncOp.getFunctionType();
    });
  }

  // Create the new function operation. Only copy those attributes that are
  // not specific to function modeling.
  SmallVector<NamedAttribute, 4> attributes;
  ArrayAttr argAttrs;
  for (const auto &attr : gpuFuncOp->getAttrs()) {
    if (attr.getName() == SymbolTable::getSymbolAttrName() ||
        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
        attr.getName() ==
            gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
        attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
        attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
        attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
        attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
      continue;
    if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
      argAttrs = gpuFuncOp.getArgAttrsAttr();
      continue;
    }
    attributes.push_back(attr);
  }

  DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
  DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
  // Ensure we don't lose information if the function is lowered before its
  // surrounding context.
  auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
  if (knownBlockSize)
    attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
                            knownBlockSize);
  if (knownGridSize)
    attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
                            knownGridSize);

  // Add a dialect specific kernel attribute in addition to the GPU kernel
  // attribute. The former is necessary for further translation while the
  // latter is expected by gpu.launch_func.
  if (gpuFuncOp.isKernel()) {
    if (kernelAttributeName)
      attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
    // Set the dialect-specific block size attribute if there is one.
    if (kernelBlockSizeAttributeName && knownBlockSize) {
      attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
    }
  }
  LLVM::CConv callingConvention = gpuFuncOp.isKernel()
                                      ? kernelCallingConvention
                                      : nonKernelCallingConvention;
  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
      gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
      LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
      /*comdat=*/nullptr, attributes);
  {
    // Insert operations that correspond to converted workgroup and private
    // memory attributions to the body of the function. This must operate on
    // the original function, before the body region is inlined in the new
    // function, to maintain the relation between block arguments and the
    // parent operation that assigns their semantics.
    OpBuilder::InsertionGuard guard(rewriter);

    // Rewrite workgroup memory attributions to addresses of global buffers.
    rewriter.setInsertionPointToStart(&gpuFuncOp.front());
    unsigned numProperArguments = gpuFuncOp.getNumArguments();

    if (encodeWorkgroupAttributionsAsArguments) {
      // Build a MemRefDescriptor with each of the arguments added above.

      unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
      assert(numProperArguments >= numAttributions &&
             "Expecting attributions to be encoded as arguments already");

      // Arguments encoding workgroup attributions will be in positions
      // [numProperArguments, numProperArguments+numAttributions)
      ArrayRef<BlockArgument> attributionArguments =
          gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
                                         numAttributions);
      for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
               gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
        auto [attribution, arg] = vals;
        auto type = cast<MemRefType>(attribution.getType());

        // Arguments are of llvm.ptr type and attributions are of memref type:
        // we need to wrap them in memref descriptors.
        Value descr = MemRefDescriptor::fromStaticShape(
            rewriter, loc, *getTypeConverter(), type, arg);

        // And remap the arguments.
        signatureConversion.remapInput(numProperArguments + idx, descr);
      }
    } else {
      for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
                                                  global.getAddrSpace());
        Value address = rewriter.create<LLVM::AddressOfOp>(
            loc, ptrType, global.getSymNameAttr());
        Value memory =
            rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
                                         address, ArrayRef<LLVM::GEPArg>{0, 0});

        // Build a memref descriptor pointing to the buffer to plug with the
        // existing memref infrastructure. This may use more registers than
        // otherwise necessary given that memref sizes are fixed, but we can
        // try and canonicalize that away later.
        Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
        auto type = cast<MemRefType>(attribution.getType());
        auto descr = MemRefDescriptor::fromStaticShape(
            rewriter, loc, *getTypeConverter(), type, memory);
        signatureConversion.remapInput(numProperArguments + idx, descr);
      }
    }

    // Rewrite private memory attributions to alloca'ed buffers.
    unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
    for (const auto [idx, attribution] :
         llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
      auto type = cast<MemRefType>(attribution.getType());
      assert(type && type.hasStaticShape() && "unexpected type in attribution");

      // Explicitly drop the memory space when lowering private memory
      // attributions since NVVM models them as `alloca`s in the default
      // memory space and does not support `alloca`s with addrspace(5).
      Type elementType = typeConverter->convertType(type.getElementType());
      auto ptrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
      Value numElements = rewriter.create<LLVM::ConstantOp>(
          gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
      uint64_t alignment = 0;
      if (auto alignAttr =
              dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
                  idx, LLVM::LLVMDialect::getAlignAttrName())))
        alignment = alignAttr.getInt();
      Value allocated = rewriter.create<LLVM::AllocaOp>(
          gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
      auto descr = MemRefDescriptor::fromStaticShape(
          rewriter, loc, *getTypeConverter(), type, allocated);
      signatureConversion.remapInput(
          numProperArguments + numWorkgroupAttributions + idx, descr);
    }
  }
  // Move the region to the new function, update the entry block signature.
  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
                              llvmFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
                                         &signatureConversion)))
    return failure();

  // Get the memref type from the function arguments and set the noalias
  // attribute on pointer arguments.
  for (const auto [idx, argTy] :
       llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
    auto remapping = signatureConversion.getInputMapping(idx);
    NamedAttrList argAttr =
        argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
    auto copyAttribute = [&](StringRef attrName) {
      Attribute attr = argAttr.erase(attrName);
      if (!attr)
        return;
      for (size_t i = 0, e = remapping->size; i < e; ++i)
        llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
    };
    auto copyPointerAttribute = [&](StringRef attrName) {
      Attribute attr = argAttr.erase(attrName);

      if (!attr)
        return;
      if (remapping->size > 1 &&
          attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
        emitWarning(llvmFuncOp.getLoc(),
                    "cannot copy noalias with non-bare pointers");
        return;
      }
      for (size_t i = 0, e = remapping->size; i < e; ++i) {
        if (isa<LLVM::LLVMPointerType>(
                llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
          llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
        }
      }
    };

    copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
    copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
    copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
    bool lowersToPointer = false;
    for (size_t i = 0, e = remapping->size; i < e; ++i) {
      lowersToPointer |= isa<LLVM::LLVMPointerType>(
          llvmFuncOp.getArgument(remapping->inputNo + i).getType());
    }

    if (lowersToPointer) {
      copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
    }
  }
  rewriter.eraseOp(gpuFuncOp);
  return success();
}
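
// The gpu.printf lowerings below share two helpers: one that invents a fresh
// symbol name for a format-string global, and one that creates (or reuses)
// the global itself. For instance, lowering three distinct format strings in
// one gpu.module yields globals named printfFormat_0, printfFormat_1, and
// printfFormat_2.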
static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
  const char formatStringPrefix[] = "printfFormat_";
  // Get a unique global name.
  unsigned stringNumber = 0;
  SmallString<16> stringConstName;
  do {
    stringConstName.clear();
    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
  } while (moduleOp.lookupSymbol(stringConstName));
  return stringConstName;
}
/// Creates a global that contains the given format string. If a global with
/// the same format string already exists in the module, returns that global.
static LLVM::GlobalOp getOrCreateFormatStringConstant(
    OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
    StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
  llvm::SmallString<20> formatString(str);
  formatString.push_back('\0'); // Null terminate for C
  auto globalType =
      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
  StringAttr attr = b.getStringAttr(formatString);

  // Try to find an existing global.
  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
        globalOp.getValueAttr() == attr &&
        globalOp.getAlignment().value_or(0) == alignment &&
        globalOp.getAddrSpace() == addrSpace)
      return globalOp;

  // Not found: create a new global.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointToStart(moduleOp.getBody());
  SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
  return b.create<LLVM::GlobalOp>(loc, globalType,
                                  /*isConstant=*/true, LLVM::Linkage::Internal,
                                  name, attr, alignment, addrSpace);
}
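
// As an illustration, a format string "Hello %d\n" (9 characters plus the
// appended NUL) becomes something like:
//
//   llvm.mlir.global internal constant @printfFormat_0("Hello %d\0A\00")
//       : !llvm.array<10 x i8>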
template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
                                            ConversionPatternRewriter &rewriter,
                                            StringRef name,
                                            LLVM::LLVMFunctionType type) {
  LLVM::LLVMFuncOp ret;
  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());
    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
                                            LLVM::Linkage::External);
  }
  return ret;
}
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
  // This ensures that global constants and declarations are placed within
  // the device code, not the host code.
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto ocklBegin =
      getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
                          LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
  LLVM::LLVMFuncOp ocklAppendArgs;
  if (!adaptor.getArgs().empty()) {
    ocklAppendArgs = getOrDefineFunction(
        moduleOp, loc, rewriter, "__ockl_printf_append_args",
        LLVM::LLVMFunctionType::get(
            llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
                      llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
  }
  auto ocklAppendStringN = getOrDefineFunction(
      moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
      LLVM::LLVMFunctionType::get(
          llvmI64,
          {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));

  // Start the printf hostcall.
  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
  Value printfDesc = printfBeginCall.getResult();

  // Create the global op or find an existing one.
  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());

  // Get a pointer to the format string's first element and pass it to printf.
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
      loc,
      LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
      global.getSymNameAttr());
  Value stringStart =
      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
  Value stringLen = rewriter.create<LLVM::ConstantOp>(
      loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());

  Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);

  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
      loc, ocklAppendStringN,
      ValueRange{printfDesc, stringStart, stringLen,
                 adaptor.getArgs().empty() ? oneI32 : zeroI32});
  printfDesc = appendFormatCall.getResult();

  // __ockl_printf_append_args takes 7 values per append call.
  constexpr size_t argsPerAppend = 7;
  size_t nArgs = adaptor.getArgs().size();
  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
    size_t bound = std::min(group + argsPerAppend, nArgs);
    size_t numArgsThisCall = bound - group;

    SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
    arguments.push_back(printfDesc);
    arguments.push_back(
        rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
    for (size_t i = group; i < bound; ++i) {
      Value arg = adaptor.getArgs()[i];
      if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
        if (!floatType.isF64())
          arg = rewriter.create<LLVM::FPExtOp>(
              loc, typeConverter->convertType(rewriter.getF64Type()), arg);
        arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
      }
      if (arg.getType().getIntOrFloatBitWidth() != 64)
        arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);

      arguments.push_back(arg);
    }
    // Pad out to 7 arguments since the hostcall always needs 7.
    for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
      arguments.push_back(zeroI64);
    }

    auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
    arguments.push_back(isLast);
    auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
    printfDesc = call.getResult();
  }
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
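
// For reference, `gpu.printf "val: %d\n" %v : i32` lowers along the lines of
// (types elided; illustrative only):
//
//   %d0 = llvm.call @__ockl_printf_begin(%c0_i64)
//   %d1 = llvm.call @__ockl_printf_append_string_n(%d0, %fmt, %len, %c0_i32)
//   %d2 = llvm.call @__ockl_printf_append_args(%d1, %c1_i32, %v_i64,
//             %c0_i64, %c0_i64, %c0_i64, %c0_i64, %c0_i64, %c0_i64, %c1_i32)
//
// where %v_i64 is %v zero-extended to i64 and the trailing %c1_i32 marks the
// last append call.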
LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
  mlir::Type ptrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
  // This ensures that global constants and declarations are placed within
  // the device code, not the host code.
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto printfType =
      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
                                  /*isVarArg=*/true);
  LLVM::LLVMFuncOp printfDecl =
      getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);

  // Create the global op or find an existing one.
  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
      addressSpace);

  // Get a pointer to the format string's first element.
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
      loc,
      LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
      global.getSymNameAttr());
  Value stringStart =
      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});

  // Construct the arguments and the function call.
  auto argsRange = adaptor.getArgs();
  SmallVector<Value, 4> printfArgs;
  printfArgs.reserve(argsRange.size() + 1);
  printfArgs.push_back(stringStart);
  printfArgs.append(argsRange.begin(), argsRange.end());

  rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
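
// The net effect is a single variadic call to the device-side printf
// declaration, roughly:
//
//   %str = llvm.getelementptr %fmt[0, 0] ...
//   llvm.call @printf(%str, %arg0, %arg1, ...)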
LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
  // This ensures that global constants and declarations are placed within
  // the device code, not the host code.
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto vprintfType =
      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
  LLVM::LLVMFuncOp vprintfDecl =
      getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);

  // Create the global op or find an existing one.
  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());

  // Get a pointer to the format string's first element.
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
  Value stringStart =
      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
  SmallVector<Type> types;
  SmallVector<Value> args;
  // Promote and pack the arguments into a stack allocation.
  for (Value arg : adaptor.getArgs()) {
    Type type = arg.getType();
    Value promotedArg = arg;
    assert(type.isIntOrFloat());
    if (isa<FloatType>(type)) {
      type = rewriter.getF64Type();
      promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
    }
    types.push_back(type);
    args.push_back(promotedArg);
  }
  Type structType =
      LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
  Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                                rewriter.getIndexAttr(1));
  Value tempAlloc =
      rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
                                      /*alignment=*/0);
  for (auto [index, arg] : llvm::enumerate(args)) {
    Value ptr = rewriter.create<LLVM::GEPOp>(
        loc, ptrType, structType, tempAlloc,
        ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
    rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
  }
  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};

  rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
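
// vprintf receives the format string plus a single pointer to a packed
// argument struct, so the pattern above materializes, roughly:
//
//   %buf = llvm.alloca 1 x !llvm.struct<(i32, f64)>
//   llvm.store %arg0, %buf[0, 0]  ; via getelementptr
//   llvm.store %arg1, %buf[0, 1]
//   llvm.call @vprintf(%fmt, %buf)
//
// with f32 arguments promoted to f64 first.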
/// Unrolls op if it's operating on vectors.
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                      ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &converter) {
  TypeRange operandTypes(operands);
  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
    return rewriter.notifyMatchFailure(op, "expected vector operand");
  }
  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
    return rewriter.notifyMatchFailure(op, "expected no region/successor");
  if (op->getNumResults() != 1)
    return rewriter.notifyMatchFailure(op, "expected single result");
  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
  if (!vectorType)
    return rewriter.notifyMatchFailure(op, "expected vector result");

  Location loc = op->getLoc();
  Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
  Type indexType = converter.convertType(rewriter.getIndexType());
  StringAttr name = op->getName().getIdentifier();
  Type elementType = vectorType.getElementType();

  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
    Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
    auto extractElement = [&](Value operand) -> Value {
      if (!isa<VectorType>(operand.getType()))
        return operand;
      return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
    };
    auto scalarOperands = llvm::map_to_vector(operands, extractElement);
    Operation *scalarOp =
        rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
    result = rewriter.create<LLVM::InsertElementOp>(
        loc, result, scalarOp->getResult(0), index);
  }

  rewriter.replaceOp(op, result);
  return success();
}
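
// For instance, scalarizing `%r = math.exp %v : vector<4xf32>` emits four
// scalar `math.exp` ops: element i is read with llvm.extractelement, the
// scalar op is recreated with the original attributes, and the results are
// reassembled into %r with llvm.insertelement.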
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}
/// Generates a symbol with a zero-sized array type for dynamic shared memory
/// usage, or reuses an existing symbol.
LLVM::GlobalOp getDynamicSharedMemorySymbol(
    ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
    gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
    MemRefType memrefType, unsigned alignmentBit) {
  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();

  FailureOr<unsigned> addressSpace =
      typeConverter->getMemRefAddressSpace(memrefType);
  if (failed(addressSpace)) {
    op->emitError() << "conversion of memref memory space "
                    << memrefType.getMemorySpace()
                    << " to integer address space "
                       "failed. Consider adding memory space conversions.";
  }

  // Step 1. Collect the symbol names of the LLVM::GlobalOp ops; if any of
  // them is already suitable for shared memory, return it.
  llvm::StringSet<> existingGlobalNames;
  for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
    existingGlobalNames.insert(globalOp.getSymName());
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
      if (globalOp.getAddrSpace() == addressSpace.value() &&
          arrayType.getNumElements() == 0 &&
          globalOp.getAlignment().value_or(0) == alignmentByte) {
        return globalOp;
      }
    }
  }

  // Step 2. Find a unique symbol name.
  unsigned uniquingCounter = 0;
  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
      "__dynamic_shmem_",
      [&](StringRef candidate) {
        return existingGlobalNames.contains(candidate);
      },
      uniquingCounter);

  // Step 3. Generate a global op.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());

  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
      typeConverter->convertType(memrefType.getElementType()), 0);

  return rewriter.create<LLVM::GlobalOp>(
      op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
      LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
      addressSpace.value());
}
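
// An example of a global this helper produces (illustrative; the exact symbol
// name comes from the uniquing above and the address space from the type
// converter):
//
//   llvm.mlir.global internal @__dynamic_shmem__0()
//       {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>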
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MemRefType memrefType = op.getResultMemref().getType();
  Type elementType = typeConverter->convertType(memrefType.getElementType());

  // Step 1. Generate a memref<0xi8> type.
  MemRefLayoutAttrInterface layout = {};
  auto memrefType0sz =
      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());

  // Step 2. Generate a global symbol, or find an existing one, for the
  // dynamic shared memory with memref<0xi8> type.
  auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
      rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);

  // Step 3. Get the address of the global symbol.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
  Type baseType = basePtr->getResultTypes().front();

  // Step 4. Generate a GEP using offsets.
  SmallVector<LLVM::GEPArg> gepArgs = {0};
  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
                                                basePtr, gepArgs);
  // Step 5. Create a memref descriptor.
  SmallVector<Value> shape, strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
                           sizeBytes);
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);

  // Step 6. Replace the op with the memref descriptor.
  rewriter.replaceOp(op, {memRefDescriptor});
  return success();
}
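
// End-to-end, `%0 = gpu.dynamic_shared_memory : memref<?xi8, 3>` thus lowers
// to an llvm.mlir.addressof of the zero-sized global plus a getelementptr,
// wrapped in a memref descriptor with the static type memref<0xi8, 3>.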
LogicalResult GPUReturnOpLowering::matchAndRewrite(
    gpu::ReturnOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  unsigned numArguments = op.getNumOperands();
  SmallVector<Value, 4> updatedOperands;

  bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
  if (useBarePtrCallConv) {
    // For the bare-ptr calling convention, extract the allocated pointer to
    // be returned from the memref descriptor.
    for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
      Type oldTy = std::get<0>(it).getType();
      Value newOperand = std::get<1>(it);
      if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
                                        cast<BaseMemRefType>(oldTy))) {
        MemRefDescriptor memrefDesc(newOperand);
        newOperand = memrefDesc.allocatedPtr(rewriter, loc);
      } else if (isa<UnrankedMemRefType>(oldTy)) {
        // Unranked memrefs are not supported in the bare pointer calling
        // convention.
        return failure();
      }
      updatedOperands.push_back(newOperand);
    }
  } else {
    updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
    (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
                                  updatedOperands,
                                  /*toDynamic=*/true);
  }

  // If ReturnOp has 0 or 1 operand, create it and return immediately.
  if (numArguments <= 1) {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
        op, TypeRange(), updatedOperands, op->getAttrs());
    return success();
  }

  // Otherwise, we need to pack the arguments into an LLVM struct type before
  // returning.
  auto packedType = getTypeConverter()->packFunctionResults(
      op.getOperandTypes(), useBarePtrCallConv);
  if (!packedType) {
    return rewriter.notifyMatchFailure(op, "could not convert result types");
  }

  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
    packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
  }
  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
                                              op->getAttrs());
  return success();
}
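
// For example, `gpu.return %a, %b : f32, f32` in a function with two results
// lowers to an llvm.return of a single !llvm.struct<(f32, f32)> value
// assembled with two llvm.insertvalue ops.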
785 void mlir::populateGpuMemorySpaceAttributeConversions(
786 TypeConverter
&typeConverter
, const MemorySpaceMapping
&mapping
) {
787 typeConverter
.addTypeAttributeConversion(
788 [mapping
](BaseMemRefType type
, gpu::AddressSpaceAttr memorySpaceAttr
) {
789 gpu::AddressSpace memorySpace
= memorySpaceAttr
.getValue();
790 unsigned addressSpace
= mapping(memorySpace
);
791 return wrapNumericMemorySpace(memorySpaceAttr
.getContext(),