//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "GPUOpsLowering.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  Location loc = gpuFuncOp.getLoc();

  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
  if (encodeWorkgroupAttributionsAsArguments) {
    // Append one `llvm.ptr` argument per workgroup attribution to the
    // function signature.

    ArrayRef<BlockArgument> workgroupAttributions =
        gpuFuncOp.getWorkgroupAttributions();
    size_t numAttributions = workgroupAttributions.size();

    // Insert all arguments at the end.
    unsigned index = gpuFuncOp.getNumArguments();
    SmallVector<unsigned> argIndices(numAttributions, index);

    // New arguments will simply be `llvm.ptr` with the correct address space.
    Type workgroupPtrType =
        rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
    SmallVector<Type> argTypes(numAttributions, workgroupPtrType);

    // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
    std::array attrs{
        rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
                              rewriter.getUnitAttr()),
        rewriter.getNamedAttr(
            getDialect().getWorkgroupAttributionAttrHelper().getName(),
            rewriter.getUnitAttr()),
    };
    SmallVector<DictionaryAttr> argAttrs;
    for (BlockArgument attribution : workgroupAttributions) {
      auto attributionType = cast<MemRefType>(attribution.getType());
      IntegerAttr numElements =
          rewriter.getI64IntegerAttr(attributionType.getNumElements());
      Type llvmElementType =
          getTypeConverter()->convertType(attributionType.getElementType());
      if (!llvmElementType)
        return failure();
      TypeAttr type = TypeAttr::get(llvmElementType);
      attrs.back().setValue(
          rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
      argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
    }

    // Locations match the function location.
    SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());

    // Perform the signature modification.
    rewriter.modifyOpInPlace(
        gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
          static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
              argIndices, argTypes, argAttrs, argLocs);
        });
  } else {
    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
    for (auto [idx, attribution] :
         llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
      auto type = dyn_cast<MemRefType>(attribution.getType());
      assert(type && type.hasStaticShape() && "unexpected type in attribution");

      uint64_t numElements = type.getNumElements();

      auto elementType =
          cast<Type>(typeConverter->convertType(type.getElementType()));
      auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
      std::string name =
          std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
      uint64_t alignment = 0;
      if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
              gpuFuncOp.getWorkgroupAttributionAttr(
                  idx, LLVM::LLVMDialect::getAlignAttrName())))
        alignment = alignAttr.getInt();
      auto globalOp = rewriter.create<LLVM::GlobalOp>(
          gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
          LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
          workgroupAddrSpace);
      workgroupBuffers.push_back(globalOp);
    }
  }

  // Remap proper input types.
  TypeConverter::SignatureConversion signatureConversion(
      gpuFuncOp.front().getNumArguments());

  Type funcType = getTypeConverter()->convertFunctionSignature(
      gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
  if (!funcType) {
    return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
      diag << "failed to convert function signature type for: "
           << gpuFuncOp.getFunctionType();
    });
  }

  // Create the new function operation. Only copy those attributes that are
  // not specific to function modeling.
  SmallVector<NamedAttribute, 4> attributes;
  ArrayAttr argAttrs;
  for (const auto &attr : gpuFuncOp->getAttrs()) {
    if (attr.getName() == SymbolTable::getSymbolAttrName() ||
        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
        attr.getName() ==
            gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
        attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
        attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
        attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
        attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
      continue;
    if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
      argAttrs = gpuFuncOp.getArgAttrsAttr();
      continue;
    }
    attributes.push_back(attr);
  }

  DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
  DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
  // Ensure we don't lose information if the function is lowered before its
  // surrounding context.
  auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
  if (knownBlockSize)
    attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
                            knownBlockSize);
  if (knownGridSize)
    attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
                            knownGridSize);

  // Add a dialect-specific kernel attribute in addition to the GPU kernel
  // attribute. The former is necessary for further translation while the
  // latter is expected by gpu.launch_func.
  if (gpuFuncOp.isKernel()) {
    if (kernelAttributeName)
      attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
    // Set the dialect-specific block size attribute if there is one.
    if (kernelBlockSizeAttributeName && knownBlockSize) {
      attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
    }
  }
  LLVM::CConv callingConvention = gpuFuncOp.isKernel()
                                      ? kernelCallingConvention
                                      : nonKernelCallingConvention;
  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
      gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
      LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
      /*comdat=*/nullptr, attributes);

  {
    // Insert operations that correspond to converted workgroup and private
    // memory attributions to the body of the function. This must operate on
    // the original function, before the body region is inlined in the new
    // function, to maintain the relation between block arguments and the
    // parent operation that assigns their semantics.
    OpBuilder::InsertionGuard guard(rewriter);

    // Rewrite workgroup memory attributions to addresses of global buffers.
    rewriter.setInsertionPointToStart(&gpuFuncOp.front());
    unsigned numProperArguments = gpuFuncOp.getNumArguments();

    if (encodeWorkgroupAttributionsAsArguments) {
      // Build a MemRefDescriptor with each of the arguments added above.

      unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
      assert(numProperArguments >= numAttributions &&
             "expecting attributions to be encoded as arguments already");

      // Arguments encoding workgroup attributions will be in positions
      // [numProperArguments, numProperArguments + numAttributions).
      ArrayRef<BlockArgument> attributionArguments =
          gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
                                         numAttributions);
      for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
               gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
        auto [attribution, arg] = vals;
        auto type = cast<MemRefType>(attribution.getType());

        // Arguments are of llvm.ptr type and attributions are of memref type:
        // we need to wrap them in memref descriptors.
        Value descr = MemRefDescriptor::fromStaticShape(
            rewriter, loc, *getTypeConverter(), type, arg);

        // And remap the arguments.
        signatureConversion.remapInput(numProperArguments + idx, descr);
      }
    } else {
      for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
                                                  global.getAddrSpace());
        Value address = rewriter.create<LLVM::AddressOfOp>(
            loc, ptrType, global.getSymNameAttr());
        Value memory =
            rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
                                         address, ArrayRef<LLVM::GEPArg>{0, 0});

        // Build a memref descriptor pointing to the buffer to plug with the
        // existing memref infrastructure. This may use more registers than
        // otherwise necessary given that memref sizes are fixed, but we can
        // try and canonicalize that away later.
        Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
        auto type = cast<MemRefType>(attribution.getType());
        auto descr = MemRefDescriptor::fromStaticShape(
            rewriter, loc, *getTypeConverter(), type, memory);
        signatureConversion.remapInput(numProperArguments + idx, descr);
      }
    }

    // Rewrite private memory attributions to alloca'ed buffers.
    unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
    for (const auto [idx, attribution] :
         llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
      auto type = cast<MemRefType>(attribution.getType());
      assert(type && type.hasStaticShape() && "unexpected type in attribution");

      // Explicitly drop the memory space when lowering private memory
      // attributions since NVVM models it as `alloca`s in the default
      // memory space and does not support `alloca`s with addrspace(5).
      Type elementType = typeConverter->convertType(type.getElementType());
      auto ptrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
      Value numElements = rewriter.create<LLVM::ConstantOp>(
          gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
      uint64_t alignment = 0;
      if (auto alignAttr =
              dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
                  idx, LLVM::LLVMDialect::getAlignAttrName())))
        alignment = alignAttr.getInt();
      Value allocated = rewriter.create<LLVM::AllocaOp>(
          gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
      auto descr = MemRefDescriptor::fromStaticShape(
          rewriter, loc, *getTypeConverter(), type, allocated);
      signatureConversion.remapInput(
          numProperArguments + numWorkgroupAttributions + idx, descr);
    }
  }

  // Move the region to the new function and update the entry block signature.
  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
                              llvmFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
                                         &signatureConversion)))
    return failure();

  // Copy the argument attributes over to the new function, attaching
  // pointer-specific attributes such as noalias only to arguments that
  // lower to LLVM pointers.
  for (const auto [idx, argTy] :
       llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
    auto remapping = signatureConversion.getInputMapping(idx);
    NamedAttrList argAttr =
        argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
    auto copyAttribute = [&](StringRef attrName) {
      Attribute attr = argAttr.erase(attrName);
      if (!attr)
        return;
      for (size_t i = 0, e = remapping->size; i < e; ++i)
        llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
    };
    auto copyPointerAttribute = [&](StringRef attrName) {
      Attribute attr = argAttr.erase(attrName);

      if (!attr)
        return;
      if (remapping->size > 1 &&
          attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
        emitWarning(llvmFuncOp.getLoc(),
                    "cannot copy noalias with non-bare pointers");
        return;
      }
      for (size_t i = 0, e = remapping->size; i < e; ++i) {
        if (isa<LLVM::LLVMPointerType>(
                llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
          llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
        }
      }
    };

    if (argAttr.empty())
      continue;

    copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
    copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
    copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
    bool lowersToPointer = false;
    for (size_t i = 0, e = remapping->size; i < e; ++i) {
      lowersToPointer |= isa<LLVM::LLVMPointerType>(
          llvmFuncOp.getArgument(remapping->inputNo + i).getType());
    }

    if (lowersToPointer) {
      copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
    }
  }
  rewriter.eraseOp(gpuFuncOp);
  return success();
}
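
// Illustrative sketch (not verbatim compiler output; the exact attributes and
// address spaces depend on the target lowering): with attributions lowered to
// globals (the `else` branch above), a kernel such as
//
//   gpu.func @kernel() workgroup(%wg : memref<32xf32, 3>) kernel { ... }
//
// becomes roughly
//
//   llvm.mlir.global internal @__wg_kernel_0() : !llvm.array<32 x f32>
//   llvm.func @kernel() attributes {<target-specific kernel attr>} {
//     %addr = llvm.mlir.addressof @__wg_kernel_0 : !llvm.ptr<3>
//     %mem  = llvm.getelementptr %addr[0, 0] : ... -> !llvm.ptr<3>
//     // ... a memref descriptor built from %mem replaces uses of %wg ...
//   }
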
static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
  const char formatStringPrefix[] = "printfFormat_";
  // Get a unique global name.
  unsigned stringNumber = 0;
  SmallString<16> stringConstName;
  do {
    stringConstName.clear();
    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
  } while (moduleOp.lookupSymbol(stringConstName));
  return stringConstName;
}
/// Create a global that contains the given format string. If a global with
/// the same format string already exists in the module, return that global.
static LLVM::GlobalOp getOrCreateFormatStringConstant(
    OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
    StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
  llvm::SmallString<20> formatString(str);
  formatString.push_back('\0'); // Null-terminate for C.
  auto globalType =
      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
  StringAttr attr = b.getStringAttr(formatString);

  // Try to find an existing global.
  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
        globalOp.getValueAttr() == attr &&
        globalOp.getAlignment().value_or(0) == alignment &&
        globalOp.getAddrSpace() == addrSpace)
      return globalOp;

  // Not found: create a new global.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointToStart(moduleOp.getBody());
  SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
  return b.create<LLVM::GlobalOp>(loc, globalType,
                                  /*isConstant=*/true, LLVM::Linkage::Internal,
                                  name, attr, alignment, addrSpace);
}
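
// For illustration (a sketch, not exact output): requesting a constant for
// "Hello %d\n" emits, or finds, a global roughly like
//
//   llvm.mlir.global internal constant @printfFormat_0("Hello %d\0A\00")
//       : !llvm.array<10 x i8>
//
// The array is sized to the null-terminated string, hence the trailing \00.
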
template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
                                            ConversionPatternRewriter &rewriter,
                                            StringRef name,
                                            LLVM::LLVMFunctionType type) {
  LLVM::LLVMFuncOp ret;
  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());
    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
                                            LLVM::Linkage::External);
  }
  return ret;
}
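
// Typical usage, as in the printf lowerings below (a sketch; the exact
// function type depends on the callee being declared):
//
//   auto fn = getOrDefineFunction(
//       moduleOp, loc, rewriter, "printf",
//       LLVM::LLVMFunctionType::get(i32Ty, {ptrTy}, /*isVarArg=*/true));
//
// A second call with the same name returns the existing declaration instead
// of creating a duplicate.
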
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
  // Note: this is the gpu.module op, not the builtin module op that surrounds
  // it. This ensures that global constants and declarations are placed within
  // the device code, not the host code.
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto ocklBegin =
      getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
                          LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
  LLVM::LLVMFuncOp ocklAppendArgs;
  if (!adaptor.getArgs().empty()) {
    ocklAppendArgs = getOrDefineFunction(
        moduleOp, loc, rewriter, "__ockl_printf_append_args",
        LLVM::LLVMFunctionType::get(
            llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
                      llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
  }
  auto ocklAppendStringN = getOrDefineFunction(
      moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
      LLVM::LLVMFunctionType::get(
          llvmI64,
          {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));

  // Start the printf hostcall.
  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
  Value printfDesc = printfBeginCall.getResult();

  // Create the global op or find an existing one.
  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());

  // Get a pointer to the format string's first element and pass it along.
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
      loc,
      LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
      global.getSymNameAttr());
  Value stringStart =
      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
  Value stringLen = rewriter.create<LLVM::ConstantOp>(
      loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());

  Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);

  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
      loc, ocklAppendStringN,
      ValueRange{printfDesc, stringStart, stringLen,
                 adaptor.getArgs().empty() ? oneI32 : zeroI32});
  printfDesc = appendFormatCall.getResult();

  // __ockl_printf_append_args takes 7 values per append call.
  constexpr size_t argsPerAppend = 7;
  size_t nArgs = adaptor.getArgs().size();
  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
    size_t bound = std::min(group + argsPerAppend, nArgs);
    size_t numArgsThisCall = bound - group;

    SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
    arguments.push_back(printfDesc);
    arguments.push_back(
        rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
    for (size_t i = group; i < bound; ++i) {
      Value arg = adaptor.getArgs()[i];
      if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
        if (!floatType.isF64())
          arg = rewriter.create<LLVM::FPExtOp>(
              loc, typeConverter->convertType(rewriter.getF64Type()), arg);
        arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
      }
      if (arg.getType().getIntOrFloatBitWidth() != 64)
        arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);

      arguments.push_back(arg);
    }
    // Pad out to 7 arguments since the hostcall always needs 7.
    for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
      arguments.push_back(zeroI64);
    }

    auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
    arguments.push_back(isLast);
    auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
    printfDesc = call.getResult();
  }
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
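
// Illustrative sketch of the emitted OCKL hostcall chain for
// `gpu.printf "%d\n", %v : i32` (simplified; constant and extension ops
// omitted):
//
//   %d0 = llvm.call @__ockl_printf_begin(%c0_i64)
//   %d1 = llvm.call @__ockl_printf_append_string_n(%d0, %fmt, %len, %c0_i32)
//   %d2 = llvm.call @__ockl_printf_append_args(%d1, %c1_i32, %v_i64,
//             %c0_i64, %c0_i64, %c0_i64, %c0_i64, %c0_i64, %c0_i64, %c1_i32)
//
// %v_i64 is %v zero-extended to i64; unused argument slots are padded with
// zeros, and the trailing i32 flags the last append call.
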
LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
  mlir::Type ptrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

  // Note: this is the gpu.module op, not the builtin module op that surrounds
  // it. This ensures that global constants and declarations are placed within
  // the device code, not the host code.
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto printfType =
      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
                                  /*isVarArg=*/true);
  LLVM::LLVMFuncOp printfDecl =
      getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);

  // Create the global op or find an existing one.
  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
      addressSpace);

  // Get a pointer to the format string's first element.
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
      loc,
      LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
      global.getSymNameAttr());
  Value stringStart =
      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});

  // Construct the arguments and the function call.
  auto argsRange = adaptor.getArgs();
  SmallVector<Value, 4> printfArgs;
  printfArgs.reserve(argsRange.size() + 1);
  printfArgs.push_back(stringStart);
  printfArgs.append(argsRange.begin(), argsRange.end());

  rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
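
// Illustrative sketch: `gpu.printf "%d\n", %v : i32` becomes roughly
//
//   %fmt   = llvm.mlir.addressof @printfFormat_0 : !llvm.ptr
//   %start = llvm.getelementptr %fmt[0, 0] : ... -> !llvm.ptr
//   llvm.call @printf(%start, %v) vararg(!llvm.func<i32 (ptr, ...)>)
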
LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

  // Note: this is the gpu.module op, not the builtin module op that surrounds
  // it. This ensures that global constants and declarations are placed within
  // the device code, not the host code.
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto vprintfType =
      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
  LLVM::LLVMFuncOp vprintfDecl =
      getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);

  // Create the global op or find an existing one.
  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());

  // Get a pointer to the format string's first element.
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
  Value stringStart =
      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
  SmallVector<Type> types;
  SmallVector<Value> args;
  // Promote and pack the arguments into a stack allocation.
  for (Value arg : adaptor.getArgs()) {
    Type type = arg.getType();
    Value promotedArg = arg;
    assert(type.isIntOrFloat());
    if (isa<FloatType>(type)) {
      type = rewriter.getF64Type();
      promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
    }
    types.push_back(type);
    args.push_back(promotedArg);
  }
  Type structType =
      LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
  Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                                rewriter.getIndexAttr(1));
  Value tempAlloc =
      rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
                                      /*alignment=*/0);
  for (auto [index, arg] : llvm::enumerate(args)) {
    Value ptr = rewriter.create<LLVM::GEPOp>(
        loc, ptrType, structType, tempAlloc,
        ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
    rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
  }
  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};

  rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
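
// Illustrative sketch: for `gpu.printf "%d %f\n", %i, %f` the arguments are
// promoted (f32 -> f64) and packed into a stack struct, roughly
//
//   %buf = llvm.alloca %c1 x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr
//   // ... llvm.store of %i and the extended %f through GEPs into %buf ...
//   llvm.call @vprintf(%start, %buf) : (!llvm.ptr, !llvm.ptr) -> i32
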
/// Unrolls the op if it operates on vectors.
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                      ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &converter) {
  TypeRange operandTypes(operands);
  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
    return rewriter.notifyMatchFailure(op, "expected vector operand");
  }
  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
    return rewriter.notifyMatchFailure(op, "expected no region/successor");
  if (op->getNumResults() != 1)
    return rewriter.notifyMatchFailure(op, "expected single result");
  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
  if (!vectorType)
    return rewriter.notifyMatchFailure(op, "expected vector result");

  Location loc = op->getLoc();
  Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
  Type indexType = converter.convertType(rewriter.getIndexType());
  StringAttr name = op->getName().getIdentifier();
  Type elementType = vectorType.getElementType();

  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
    Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
    auto extractElement = [&](Value operand) -> Value {
      if (!isa<VectorType>(operand.getType()))
        return operand;
      return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
    };
    auto scalarOperands = llvm::map_to_vector(operands, extractElement);
    Operation *scalarOp =
        rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
    result = rewriter.create<LLVM::InsertElementOp>(
        loc, result, scalarOp->getResult(0), index);
  }

  rewriter.replaceOp(op, result);
  return success();
}
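
// Illustrative sketch: an elementwise op on vector<2xf32>, say %r = op %v,
// unrolls roughly into
//
//   %u  = llvm.mlir.undef : vector<2xf32>
//   %e0 = llvm.extractelement %v[%c0]    // %s0 = op %e0 : f32
//   %r0 = llvm.insertelement %s0, %u[%c0]
//   %e1 = llvm.extractelement %v[%c1]    // %s1 = op %e1 : f32
//   %r  = llvm.insertelement %s1, %r0[%c1]
//
// The same named op is re-created once per element with the scalar result
// type, so any op whose scalar form is legal after conversion qualifies.
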
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}
/// Generates a symbol with a 0-sized array type for dynamic shared memory
/// usage, or reuses an existing symbol.
LLVM::GlobalOp getDynamicSharedMemorySymbol(
    ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
    gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
    MemRefType memrefType, unsigned alignmentBit) {
  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();

  FailureOr<unsigned> addressSpace =
      typeConverter->getMemRefAddressSpace(memrefType);
  if (failed(addressSpace)) {
    op->emitError() << "conversion of memref memory space "
                    << memrefType.getMemorySpace()
                    << " to integer address space "
                       "failed. Consider adding memory space conversions.";
  }

  // Step 1. Collect the symbol names of the LLVM::GlobalOp ops. If any
  // existing LLVM::GlobalOp is already suitable for shared memory, return it.
  llvm::StringSet<> existingGlobalNames;
  for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
    existingGlobalNames.insert(globalOp.getSymName());
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
      if (globalOp.getAddrSpace() == addressSpace.value() &&
          arrayType.getNumElements() == 0 &&
          globalOp.getAlignment().value_or(0) == alignmentByte) {
        return globalOp;
      }
    }
  }

  // Step 2. Find a unique symbol name.
  unsigned uniquingCounter = 0;
  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
      "__dynamic_shmem_",
      [&](StringRef candidate) {
        return existingGlobalNames.contains(candidate);
      },
      uniquingCounter);

  // Step 3. Generate a global op.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());

  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
      typeConverter->convertType(memrefType.getElementType()), 0);

  return rewriter.create<LLVM::GlobalOp>(
      op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
      LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
      addressSpace.value());
}
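
// For illustration (a sketch; the uniqued name and numeric spaces depend on
// the input module and target): with an i8 element type, address space 3,
// and 16-byte alignment this returns a global roughly like
//
//   llvm.mlir.global internal @__dynamic_shmem__0()
//       {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
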
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MemRefType memrefType = op.getResultMemref().getType();
  Type elementType = typeConverter->convertType(memrefType.getElementType());

  // Step 1. Generate a memref<0xi8> type.
  MemRefLayoutAttrInterface layout = {};
  auto memrefType0sz =
      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());

  // Step 2. Generate, or find an existing, global symbol for the dynamic
  // shared memory with memref<0xi8> type.
  auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
      rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);

  // Step 3. Get the address of the global symbol.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
  Type baseType = basePtr->getResultTypes().front();

  // Step 4. Generate a GEP using offsets.
  SmallVector<LLVM::GEPArg> gepArgs = {0};
  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
                                                basePtr, gepArgs);
  // Step 5. Create a memref descriptor.
  SmallVector<Value> shape, strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
                           sizeBytes);
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);

  // Step 6. Replace the op with the memref descriptor.
  rewriter.replaceOp(op, {memRefDescriptor});
  return success();
}
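
// Illustrative sketch: `%m = gpu.dynamic_shared_memory : memref<?xi8, 3>`
// lowers roughly to
//
//   %base = llvm.mlir.addressof @__dynamic_shmem__0 : !llvm.ptr<3>
//   %ptr  = llvm.getelementptr %base[0] : ... -> !llvm.ptr<3>
//   // ... a memref descriptor wrapping %ptr replaces %m ...
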
LogicalResult GPUReturnOpLowering::matchAndRewrite(
    gpu::ReturnOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  unsigned numArguments = op.getNumOperands();
  SmallVector<Value, 4> updatedOperands;

  bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
  if (useBarePtrCallConv) {
    // For the bare-ptr calling convention, extract the underlying allocated
    // pointer to be returned from the memref descriptor.
    for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
      Type oldTy = std::get<0>(it).getType();
      Value newOperand = std::get<1>(it);
      if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
                                        cast<BaseMemRefType>(oldTy))) {
        MemRefDescriptor memrefDesc(newOperand);
        newOperand = memrefDesc.allocatedPtr(rewriter, loc);
      } else if (isa<UnrankedMemRefType>(oldTy)) {
        // Unranked memrefs are not supported in the bare pointer calling
        // convention.
        return failure();
      }
      updatedOperands.push_back(newOperand);
    }
  } else {
    updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
    (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
                                  updatedOperands,
                                  /*toDynamic=*/true);
  }

  // If the ReturnOp has 0 or 1 operands, create the new op and return
  // immediately.
  if (numArguments <= 1) {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
        op, TypeRange(), updatedOperands, op->getAttrs());
    return success();
  }

  // Otherwise, we need to pack the arguments into an LLVM struct type before
  // returning.
  auto packedType = getTypeConverter()->packFunctionResults(
      op.getOperandTypes(), useBarePtrCallConv);
  if (!packedType) {
    return rewriter.notifyMatchFailure(op, "could not convert result types");
  }

  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
    packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
  }
  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
                                              op->getAttrs());
  return success();
}
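
// Illustrative sketch: a multi-result `gpu.return %a, %b : i32, i32` packs
// its operands into a struct, roughly
//
//   %0 = llvm.mlir.undef : !llvm.struct<(i32, i32)>
//   %1 = llvm.insertvalue %a, %0[0] : !llvm.struct<(i32, i32)>
//   %2 = llvm.insertvalue %b, %1[1] : !llvm.struct<(i32, i32)>
//   llvm.return %2 : !llvm.struct<(i32, i32)>
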
void mlir::populateGpuMemorySpaceAttributeConversions(
    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
  typeConverter.addTypeAttributeConversion(
      [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
        gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
        unsigned addressSpace = mapping(memorySpace);
        return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
                                      addressSpace);
      });
}
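
// Typical usage from a target lowering (a sketch; the numeric spaces shown
// here are assumptions and are target-specific):
//
//   populateGpuMemorySpaceAttributeConversions(
//       typeConverter, [](gpu::AddressSpace space) -> unsigned {
//         switch (space) {
//         case gpu::AddressSpace::Global:
//           return 1;
//         case gpu::AddressSpace::Workgroup:
//           return 3;
//         case gpu::AddressSpace::Private:
//           return 5;
//         }
//         llvm_unreachable("unknown address space");
//       });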