[mlir] Use StringRef::{starts,ends}_with (NFC)
[llvm-project.git] / mlir / lib / Dialect / LLVMIR / IR / LLVMDialect.cpp
blob458bf83eac17f8f34596f3d56ac601c3db6fb84b
1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "LLVMInlining.h"
16 #include "TypeDetail.h"
17 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
18 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
19 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/Interfaces/FunctionImplementation.h"
28 #include "llvm/ADT/SCCIterator.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/AsmParser/Parser.h"
31 #include "llvm/Bitcode/BitcodeReader.h"
32 #include "llvm/Bitcode/BitcodeWriter.h"
33 #include "llvm/IR/Attributes.h"
34 #include "llvm/IR/Function.h"
35 #include "llvm/IR/Type.h"
36 #include "llvm/Support/Error.h"
37 #include "llvm/Support/Mutex.h"
38 #include "llvm/Support/SourceMgr.h"
40 #include <numeric>
41 #include <optional>
43 using namespace mlir;
44 using namespace mlir::LLVM;
45 using mlir::LLVM::cconv::getMaxEnumValForCConv;
46 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
48 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
50 static constexpr const char kElemTypeAttrName[] = "elem_type";
52 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
53 SmallVector<NamedAttribute, 8> filteredAttrs(
54 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
55 if (attr.getName() == "fastmathFlags") {
56 auto defAttr =
57 FastmathFlagsAttr::get(attr.getValue().getContext(), {});
58 return defAttr != attr.getValue();
60 return true;
61 }));
62 return filteredAttrs;
65 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
66 NamedAttrList &result) {
67 return parser.parseOptionalAttrDict(result);
70 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
71 DictionaryAttr attrs) {
72 auto filteredAttrs = processFMFAttr(attrs.getValue());
73 if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
74 printer.printOptionalAttrDict(
75 filteredAttrs,
76 /*elidedAttrs=*/{iface.getIntegerOverflowAttrName()});
77 else
78 printer.printOptionalAttrDict(filteredAttrs);
81 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
82 /// fully defined llvm.func.
83 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
84 Operation *op,
85 SymbolTableCollection &symbolTable) {
86 StringRef name = symbol.getValue();
87 auto func =
88 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
89 if (!func)
90 return op->emitOpError("'")
91 << name << "' does not reference a valid LLVM function";
92 if (func.isExternal())
93 return op->emitOpError("'") << name << "' does not have a definition";
94 return success();
97 /// Returns a boolean type that has the same shape as `type`. It supports both
98 /// fixed size vectors as well as scalable vectors.
99 static Type getI1SameShape(Type type) {
100 Type i1Type = IntegerType::get(type.getContext(), 1);
101 if (LLVM::isCompatibleVectorType(type))
102 return LLVM::getVectorType(i1Type, LLVM::getVectorNumElements(type));
103 return i1Type;
106 // Parses one of the keywords provided in the list `keywords` and returns the
107 // position of the parsed keyword in the list. If none of the keywords from the
108 // list is parsed, returns -1.
109 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
110 ArrayRef<StringRef> keywords) {
111 for (const auto &en : llvm::enumerate(keywords)) {
112 if (succeeded(parser.parseOptionalKeyword(en.value())))
113 return en.index();
115 return -1;
118 namespace {
119 template <typename Ty>
120 struct EnumTraits {};
122 #define REGISTER_ENUM_TYPE(Ty) \
123 template <> \
124 struct EnumTraits<Ty> { \
125 static StringRef stringify(Ty value) { return stringify##Ty(value); } \
126 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
129 REGISTER_ENUM_TYPE(Linkage);
130 REGISTER_ENUM_TYPE(UnnamedAddr);
131 REGISTER_ENUM_TYPE(CConv);
132 REGISTER_ENUM_TYPE(Visibility);
133 } // namespace
135 /// Parse an enum from the keyword, or default to the provided default value.
136 /// The return type is the enum type by default, unless overridden with the
137 /// second template argument.
138 template <typename EnumTy, typename RetTy = EnumTy>
139 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
140 OperationState &result,
141 EnumTy defaultValue) {
142 SmallVector<StringRef, 10> names;
143 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
144 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
146 int index = parseOptionalKeywordAlternative(parser, names);
147 if (index == -1)
148 return static_cast<RetTy>(defaultValue);
149 return static_cast<RetTy>(index);
152 //===----------------------------------------------------------------------===//
153 // Printing, parsing, folding and builder for LLVM::CmpOp.
154 //===----------------------------------------------------------------------===//
156 void ICmpOp::print(OpAsmPrinter &p) {
157 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
158 << ", " << getOperand(1);
159 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
160 p << " : " << getLhs().getType();
163 void FCmpOp::print(OpAsmPrinter &p) {
164 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
165 << ", " << getOperand(1);
166 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
167 p << " : " << getLhs().getType();
170 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
171 // attribute-dict? `:` type
172 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
173 // attribute-dict? `:` type
174 template <typename CmpPredicateType>
175 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
176 StringAttr predicateAttr;
177 OpAsmParser::UnresolvedOperand lhs, rhs;
178 Type type;
179 SMLoc predicateLoc, trailingTypeLoc;
180 if (parser.getCurrentLocation(&predicateLoc) ||
181 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
182 parser.parseOperand(lhs) || parser.parseComma() ||
183 parser.parseOperand(rhs) ||
184 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
185 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
186 parser.resolveOperand(lhs, type, result.operands) ||
187 parser.resolveOperand(rhs, type, result.operands))
188 return failure();
190 // Replace the string attribute `predicate` with an integer attribute.
191 int64_t predicateValue = 0;
192 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
193 std::optional<ICmpPredicate> predicate =
194 symbolizeICmpPredicate(predicateAttr.getValue());
195 if (!predicate)
196 return parser.emitError(predicateLoc)
197 << "'" << predicateAttr.getValue()
198 << "' is an incorrect value of the 'predicate' attribute";
199 predicateValue = static_cast<int64_t>(*predicate);
200 } else {
201 std::optional<FCmpPredicate> predicate =
202 symbolizeFCmpPredicate(predicateAttr.getValue());
203 if (!predicate)
204 return parser.emitError(predicateLoc)
205 << "'" << predicateAttr.getValue()
206 << "' is an incorrect value of the 'predicate' attribute";
207 predicateValue = static_cast<int64_t>(*predicate);
210 result.attributes.set("predicate",
211 parser.getBuilder().getI64IntegerAttr(predicateValue));
213 // The result type is either i1 or a vector type <? x i1> if the inputs are
214 // vectors.
215 if (!isCompatibleType(type))
216 return parser.emitError(trailingTypeLoc,
217 "expected LLVM dialect-compatible type");
218 result.addTypes(getI1SameShape(type));
219 return success();
222 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
223 return parseCmpOp<ICmpPredicate>(parser, result);
226 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
227 return parseCmpOp<FCmpPredicate>(parser, result);
230 /// Returns a scalar or vector boolean attribute of the given type.
231 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
232 auto boolAttr = BoolAttr::get(ctx, value);
233 ShapedType shapedType = dyn_cast<ShapedType>(type);
234 if (!shapedType)
235 return boolAttr;
236 return DenseElementsAttr::get(shapedType, boolAttr);
239 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
240 if (getPredicate() != ICmpPredicate::eq &&
241 getPredicate() != ICmpPredicate::ne)
242 return {};
244 // cmpi(eq/ne, x, x) -> true/false
245 if (getLhs() == getRhs())
246 return getBoolAttribute(getType(), getContext(),
247 getPredicate() == ICmpPredicate::eq);
249 // cmpi(eq/ne, alloca, null) -> false/true
250 if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
251 return getBoolAttribute(getType(), getContext(),
252 getPredicate() == ICmpPredicate::ne);
254 // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
255 if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
256 Value lhs = getLhs();
257 Value rhs = getRhs();
258 getLhsMutable().assign(rhs);
259 getRhsMutable().assign(lhs);
260 return getResult();
263 return {};
266 //===----------------------------------------------------------------------===//
267 // Printing, parsing and verification for LLVM::AllocaOp.
268 //===----------------------------------------------------------------------===//
270 void AllocaOp::print(OpAsmPrinter &p) {
271 auto funcTy =
272 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
274 if (getInalloca())
275 p << " inalloca";
277 p << ' ' << getArraySize() << " x " << getElemType();
278 if (getAlignment() && *getAlignment() != 0)
279 p.printOptionalAttrDict((*this)->getAttrs(),
280 {kElemTypeAttrName, getInallocaAttrName()});
281 else
282 p.printOptionalAttrDict(
283 (*this)->getAttrs(),
284 {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
285 p << " : " << funcTy;
288 // <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
289 // attribute-dict? `:` type `,` type
290 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
291 OpAsmParser::UnresolvedOperand arraySize;
292 Type type, elemType;
293 SMLoc trailingTypeLoc;
295 if (succeeded(parser.parseOptionalKeyword("inalloca")))
296 result.addAttribute(getInallocaAttrName(result.name),
297 UnitAttr::get(parser.getContext()));
299 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
300 parser.parseType(elemType) ||
301 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
302 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
303 return failure();
305 std::optional<NamedAttribute> alignmentAttr =
306 result.attributes.getNamed("alignment");
307 if (alignmentAttr.has_value()) {
308 auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
309 if (!alignmentInt)
310 return parser.emitError(parser.getNameLoc(),
311 "expected integer alignment");
312 if (alignmentInt.getValue().isZero())
313 result.attributes.erase("alignment");
316 // Extract the result type from the trailing function type.
317 auto funcType = llvm::dyn_cast<FunctionType>(type);
318 if (!funcType || funcType.getNumInputs() != 1 ||
319 funcType.getNumResults() != 1)
320 return parser.emitError(
321 trailingTypeLoc,
322 "expected trailing function type with one argument and one result");
324 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
325 return failure();
327 Type resultType = funcType.getResult(0);
328 if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType))
329 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
331 result.addTypes({funcType.getResult(0)});
332 return success();
335 LogicalResult AllocaOp::verify() {
336 // Only certain target extension types can be used in 'alloca'.
337 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
338 targetExtType && !targetExtType.supportsMemOps())
339 return emitOpError()
340 << "this target extension type cannot be used in alloca";
342 return success();
345 Type AllocaOp::getResultPtrElementType() { return getElemType(); }
347 //===----------------------------------------------------------------------===//
348 // LLVM::BrOp
349 //===----------------------------------------------------------------------===//
351 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
352 assert(index == 0 && "invalid successor index");
353 return SuccessorOperands(getDestOperandsMutable());
356 //===----------------------------------------------------------------------===//
357 // LLVM::CondBrOp
358 //===----------------------------------------------------------------------===//
360 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
361 assert(index < getNumSuccessors() && "invalid successor index");
362 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
363 : getFalseDestOperandsMutable());
366 void CondBrOp::build(OpBuilder &builder, OperationState &result,
367 Value condition, Block *trueDest, ValueRange trueOperands,
368 Block *falseDest, ValueRange falseOperands,
369 std::optional<std::pair<uint32_t, uint32_t>> weights) {
370 DenseI32ArrayAttr weightsAttr;
371 if (weights)
372 weightsAttr =
373 builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
374 static_cast<int32_t>(weights->second)});
376 build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
377 /*loop_annotation=*/{}, trueDest, falseDest);
380 //===----------------------------------------------------------------------===//
381 // LLVM::SwitchOp
382 //===----------------------------------------------------------------------===//
384 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
385 Block *defaultDestination, ValueRange defaultOperands,
386 DenseIntElementsAttr caseValues,
387 BlockRange caseDestinations,
388 ArrayRef<ValueRange> caseOperands,
389 ArrayRef<int32_t> branchWeights) {
390 DenseI32ArrayAttr weightsAttr;
391 if (!branchWeights.empty())
392 weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
394 build(builder, result, value, defaultOperands, caseOperands, caseValues,
395 weightsAttr, defaultDestination, caseDestinations);
398 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
399 Block *defaultDestination, ValueRange defaultOperands,
400 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
401 ArrayRef<ValueRange> caseOperands,
402 ArrayRef<int32_t> branchWeights) {
403 DenseIntElementsAttr caseValuesAttr;
404 if (!caseValues.empty()) {
405 ShapedType caseValueType = VectorType::get(
406 static_cast<int64_t>(caseValues.size()), value.getType());
407 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
410 build(builder, result, value, defaultDestination, defaultOperands,
411 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
414 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
415 Block *defaultDestination, ValueRange defaultOperands,
416 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
417 ArrayRef<ValueRange> caseOperands,
418 ArrayRef<int32_t> branchWeights) {
419 DenseIntElementsAttr caseValuesAttr;
420 if (!caseValues.empty()) {
421 ShapedType caseValueType = VectorType::get(
422 static_cast<int64_t>(caseValues.size()), value.getType());
423 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
426 build(builder, result, value, defaultDestination, defaultOperands,
427 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
430 /// <cases> ::= `[` (case (`,` case )* )? `]`
431 /// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
432 static ParseResult parseSwitchOpCases(
433 OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
434 SmallVectorImpl<Block *> &caseDestinations,
435 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
436 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
437 if (failed(parser.parseLSquare()))
438 return failure();
439 if (succeeded(parser.parseOptionalRSquare()))
440 return success();
441 SmallVector<APInt> values;
442 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
443 auto parseCase = [&]() {
444 int64_t value = 0;
445 if (failed(parser.parseInteger(value)))
446 return failure();
447 values.push_back(APInt(bitWidth, value));
449 Block *destination;
450 SmallVector<OpAsmParser::UnresolvedOperand> operands;
451 SmallVector<Type> operandTypes;
452 if (parser.parseColon() || parser.parseSuccessor(destination))
453 return failure();
454 if (!parser.parseOptionalLParen()) {
455 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
456 /*allowResultNumber=*/false) ||
457 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
458 return failure();
460 caseDestinations.push_back(destination);
461 caseOperands.emplace_back(operands);
462 caseOperandTypes.emplace_back(operandTypes);
463 return success();
465 if (failed(parser.parseCommaSeparatedList(parseCase)))
466 return failure();
468 ShapedType caseValueType =
469 VectorType::get(static_cast<int64_t>(values.size()), flagType);
470 caseValues = DenseIntElementsAttr::get(caseValueType, values);
471 return parser.parseRSquare();
474 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
475 DenseIntElementsAttr caseValues,
476 SuccessorRange caseDestinations,
477 OperandRangeRange caseOperands,
478 const TypeRangeRange &caseOperandTypes) {
479 p << '[';
480 p.printNewline();
481 if (!caseValues) {
482 p << ']';
483 return;
486 size_t index = 0;
487 llvm::interleave(
488 llvm::zip(caseValues, caseDestinations),
489 [&](auto i) {
490 p << " ";
491 p << std::get<0>(i).getLimitedValue();
492 p << ": ";
493 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
495 [&] {
496 p << ',';
497 p.printNewline();
499 p.printNewline();
500 p << ']';
503 LogicalResult SwitchOp::verify() {
504 if ((!getCaseValues() && !getCaseDestinations().empty()) ||
505 (getCaseValues() &&
506 getCaseValues()->size() !=
507 static_cast<int64_t>(getCaseDestinations().size())))
508 return emitOpError("expects number of case values to match number of "
509 "case destinations");
510 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
511 return emitError("expects number of branch weights to match number of "
512 "successors: ")
513 << getBranchWeights()->size() << " vs " << getNumSuccessors();
514 if (getCaseValues() &&
515 getValue().getType() != getCaseValues()->getElementType())
516 return emitError("expects case value type to match condition value type");
517 return success();
520 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
521 assert(index < getNumSuccessors() && "invalid successor index");
522 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
523 : getCaseOperandsMutable(index - 1));
526 //===----------------------------------------------------------------------===//
527 // Code for LLVM::GEPOp.
528 //===----------------------------------------------------------------------===//
530 constexpr int32_t GEPOp::kDynamicIndex;
532 GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
533 return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
534 getDynamicIndices());
537 /// Returns the elemental type of any LLVM-compatible vector type or self.
538 static Type extractVectorElementType(Type type) {
539 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
540 return vectorType.getElementType();
541 if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type))
542 return scalableVectorType.getElementType();
543 if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type))
544 return fixedVectorType.getElementType();
545 return type;
548 /// Destructures the 'indices' parameter into 'rawConstantIndices' and
549 /// 'dynamicIndices', encoding the former in the process. In the process,
550 /// dynamic indices which are used to index into a structure type are converted
551 /// to constant indices when possible. To do this, the GEPs element type should
552 /// be passed as first parameter.
553 static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
554 SmallVectorImpl<int32_t> &rawConstantIndices,
555 SmallVectorImpl<Value> &dynamicIndices) {
556 for (const GEPArg &iter : indices) {
557 // If the thing we are currently indexing into is a struct we must turn
558 // any integer constants into constant indices. If this is not possible
559 // we don't do anything here. The verifier will catch it and emit a proper
560 // error. All other canonicalization is done in the fold method.
561 bool requiresConst = !rawConstantIndices.empty() &&
562 currType.isa_and_nonnull<LLVMStructType>();
563 if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
564 APInt intC;
565 if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
566 intC.isSignedIntN(kGEPConstantBitWidth)) {
567 rawConstantIndices.push_back(intC.getSExtValue());
568 } else {
569 rawConstantIndices.push_back(GEPOp::kDynamicIndex);
570 dynamicIndices.push_back(val);
572 } else {
573 rawConstantIndices.push_back(iter.get<GEPConstantIndex>());
576 // Skip for very first iteration of this loop. First index does not index
577 // within the aggregates, but is just a pointer offset.
578 if (rawConstantIndices.size() == 1 || !currType)
579 continue;
581 currType =
582 TypeSwitch<Type, Type>(currType)
583 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
584 LLVMArrayType>([](auto containerType) {
585 return containerType.getElementType();
587 .Case([&](LLVMStructType structType) -> Type {
588 int64_t memberIndex = rawConstantIndices.back();
589 if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
590 structType.getBody().size())
591 return structType.getBody()[memberIndex];
592 return nullptr;
594 .Default(Type(nullptr));
598 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
599 Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
600 bool inbounds, ArrayRef<NamedAttribute> attributes) {
601 SmallVector<int32_t> rawConstantIndices;
602 SmallVector<Value> dynamicIndices;
603 destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
605 result.addTypes(resultType);
606 result.addAttributes(attributes);
607 result.addAttribute(getRawConstantIndicesAttrName(result.name),
608 builder.getDenseI32ArrayAttr(rawConstantIndices));
609 if (inbounds) {
610 result.addAttribute(getInboundsAttrName(result.name),
611 builder.getUnitAttr());
613 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
614 result.addOperands(basePtr);
615 result.addOperands(dynamicIndices);
618 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
619 Type elementType, Value basePtr, ValueRange indices,
620 bool inbounds, ArrayRef<NamedAttribute> attributes) {
621 build(builder, result, resultType, elementType, basePtr,
622 SmallVector<GEPArg>(indices), inbounds, attributes);
625 static ParseResult
626 parseGEPIndices(OpAsmParser &parser,
627 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
628 DenseI32ArrayAttr &rawConstantIndices) {
629 SmallVector<int32_t> constantIndices;
631 auto idxParser = [&]() -> ParseResult {
632 int32_t constantIndex;
633 OptionalParseResult parsedInteger =
634 parser.parseOptionalInteger(constantIndex);
635 if (parsedInteger.has_value()) {
636 if (failed(parsedInteger.value()))
637 return failure();
638 constantIndices.push_back(constantIndex);
639 return success();
642 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
643 return parser.parseOperand(indices.emplace_back());
645 if (parser.parseCommaSeparatedList(idxParser))
646 return failure();
648 rawConstantIndices =
649 DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
650 return success();
653 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
654 OperandRange indices,
655 DenseI32ArrayAttr rawConstantIndices) {
656 llvm::interleaveComma(
657 GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
658 [&](PointerUnion<IntegerAttr, Value> cst) {
659 if (Value val = llvm::dyn_cast_if_present<Value>(cst))
660 printer.printOperand(val);
661 else
662 printer << cst.get<IntegerAttr>().getInt();
666 /// For the given `indices`, check if they comply with `baseGEPType`,
667 /// especially check against LLVMStructTypes nested within.
668 static LogicalResult
669 verifyStructIndices(Type baseGEPType, unsigned indexPos,
670 GEPIndicesAdaptor<ValueRange> indices,
671 function_ref<InFlightDiagnostic()> emitOpError) {
672 if (indexPos >= indices.size())
673 // Stop searching
674 return success();
676 return TypeSwitch<Type, LogicalResult>(baseGEPType)
677 .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
678 if (!indices[indexPos].is<IntegerAttr>())
679 return emitOpError() << "expected index " << indexPos
680 << " indexing a struct to be constant";
682 int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
683 ArrayRef<Type> elementTypes = structType.getBody();
684 if (gepIndex < 0 ||
685 static_cast<size_t>(gepIndex) >= elementTypes.size())
686 return emitOpError() << "index " << indexPos
687 << " indexing a struct is out of bounds";
689 // Instead of recursively going into every children types, we only
690 // dive into the one indexed by gepIndex.
691 return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
692 indices, emitOpError);
694 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
695 LLVMArrayType>([&](auto containerType) -> LogicalResult {
696 return verifyStructIndices(containerType.getElementType(), indexPos + 1,
697 indices, emitOpError);
699 .Default([&](auto otherType) -> LogicalResult {
700 return emitOpError()
701 << "type " << otherType << " cannot be indexed (index #"
702 << indexPos << ")";
706 /// Driver function around `verifyStructIndices`.
707 static LogicalResult
708 verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
709 function_ref<InFlightDiagnostic()> emitOpError) {
710 return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
713 LogicalResult LLVM::GEPOp::verify() {
714 if (static_cast<size_t>(
715 llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
716 getDynamicIndices().size())
717 return emitOpError("expected as many dynamic indices as specified in '")
718 << getRawConstantIndicesAttrName().getValue() << "'";
720 return verifyStructIndices(getElemType(), getIndices(),
721 [&] { return emitOpError(); });
724 Type GEPOp::getResultPtrElementType() {
725 // Set the initial type currently being used for indexing. This will be
726 // updated as the indices get walked over.
727 Type selectedType = getElemType();
729 // Follow the indexed elements in the gep.
730 auto indices = getIndices();
731 for (GEPIndicesAdaptor<ValueRange>::value_type index :
732 llvm::drop_begin(indices)) {
733 // GEPs can only index into aggregates which can be structs or arrays.
735 // The resulting type if indexing into an array type is always the element
736 // type, regardless of index.
737 if (auto arrayType = dyn_cast<LLVMArrayType>(selectedType)) {
738 selectedType = arrayType.getElementType();
739 continue;
742 // The GEP verifier ensures that any index into structs are static and
743 // that they refer to a field within the struct.
744 selectedType = cast<DestructurableTypeInterface>(selectedType)
745 .getTypeAtIndex(cast<IntegerAttr>(index));
748 // When there are no more indices, the type currently being used for indexing
749 // is the type of the value pointed at by the returned indexed pointer.
750 return selectedType;
753 //===----------------------------------------------------------------------===//
754 // LoadOp
755 //===----------------------------------------------------------------------===//
757 void LoadOp::getEffects(
758 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
759 &effects) {
760 effects.emplace_back(MemoryEffects::Read::get(), getAddr());
761 // Volatile operations can have target-specific read-write effects on
762 // memory besides the one referred to by the pointer operand.
763 // Similarly, atomic operations that are monotonic or stricter cause
764 // synchronization that from a language point-of-view, are arbitrary
765 // read-writes into memory.
766 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
767 getOrdering() != AtomicOrdering::unordered)) {
768 effects.emplace_back(MemoryEffects::Write::get());
769 effects.emplace_back(MemoryEffects::Read::get());
773 /// Returns true if the given type is supported by atomic operations. All
774 /// integer and float types with limited bit width are supported. Additionally,
775 /// depending on the operation pointers may be supported as well.
776 static bool isTypeCompatibleWithAtomicOp(Type type, bool isPointerTypeAllowed) {
777 if (llvm::isa<LLVMPointerType>(type))
778 return isPointerTypeAllowed;
780 std::optional<unsigned> bitWidth;
781 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
782 if (!isCompatibleFloatingPointType(type))
783 return false;
784 bitWidth = floatType.getWidth();
786 if (auto integerType = llvm::dyn_cast<IntegerType>(type))
787 bitWidth = integerType.getWidth();
788 // The type is neither an integer, float, or pointer type.
789 if (!bitWidth)
790 return false;
791 return *bitWidth == 8 || *bitWidth == 16 || *bitWidth == 32 ||
792 *bitWidth == 64;
795 /// Verifies the attributes and the type of atomic memory access operations.
796 template <typename OpTy>
797 LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
798 ArrayRef<AtomicOrdering> unsupportedOrderings) {
799 if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
800 if (!isTypeCompatibleWithAtomicOp(valueType,
801 /*isPointerTypeAllowed=*/true))
802 return memOp.emitOpError("unsupported type ")
803 << valueType << " for atomic access";
804 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
805 return memOp.emitOpError("unsupported ordering '")
806 << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
807 if (!memOp.getAlignment())
808 return memOp.emitOpError("expected alignment for atomic access");
809 return success();
811 if (memOp.getSyncscope())
812 return memOp.emitOpError(
813 "expected syncscope to be null for non-atomic access");
814 return success();
817 LogicalResult LoadOp::verify() {
818 Type valueType = getResult().getType();
819 return verifyAtomicMemOp(*this, valueType,
820 {AtomicOrdering::release, AtomicOrdering::acq_rel});
823 void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
824 Value addr, unsigned alignment, bool isVolatile,
825 bool isNonTemporal, AtomicOrdering ordering,
826 StringRef syncscope) {
827 build(builder, state, type, addr,
828 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
829 isNonTemporal, ordering,
830 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
831 /*access_groups=*/nullptr,
832 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
833 /*tbaa=*/nullptr);
836 //===----------------------------------------------------------------------===//
837 // StoreOp
838 //===----------------------------------------------------------------------===//
840 void StoreOp::getEffects(
841 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
842 &effects) {
843 effects.emplace_back(MemoryEffects::Write::get(), getAddr());
844 // Volatile operations can have target-specific read-write effects on
845 // memory besides the one referred to by the pointer operand.
846 // Similarly, atomic operations that are monotonic or stricter cause
847 // synchronization that from a language point-of-view, are arbitrary
848 // read-writes into memory.
849 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
850 getOrdering() != AtomicOrdering::unordered)) {
851 effects.emplace_back(MemoryEffects::Write::get());
852 effects.emplace_back(MemoryEffects::Read::get());
856 LogicalResult StoreOp::verify() {
857 Type valueType = getValue().getType();
858 return verifyAtomicMemOp(*this, valueType,
859 {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
862 void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
863 Value addr, unsigned alignment, bool isVolatile,
864 bool isNonTemporal, AtomicOrdering ordering,
865 StringRef syncscope) {
866 build(builder, state, value, addr,
867 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
868 isNonTemporal, ordering,
869 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
870 /*access_groups=*/nullptr,
871 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
874 //===----------------------------------------------------------------------===//
875 // CallOp
876 //===----------------------------------------------------------------------===//
878 /// Gets the MLIR Op-like result types of a LLVMFunctionType.
879 static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
880 SmallVector<Type, 1> results;
881 Type resultType = calleeType.getReturnType();
882 if (!isa<LLVM::LLVMVoidType>(resultType))
883 results.push_back(resultType);
884 return results;
887 /// Constructs a LLVMFunctionType from MLIR `results` and `args`.
888 static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
889 ValueRange args) {
890 Type resultType;
891 if (results.empty())
892 resultType = LLVMVoidType::get(context);
893 else
894 resultType = results.front();
895 return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
896 /*isVarArg=*/false);
899 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
900 StringRef callee, ValueRange args) {
901 build(builder, state, results, builder.getStringAttr(callee), args);
904 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
905 StringAttr callee, ValueRange args) {
906 build(builder, state, results, SymbolRefAttr::get(callee), args);
909 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
910 FlatSymbolRefAttr callee, ValueRange args) {
911 build(builder, state, results,
912 TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
913 callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
914 /*CConv=*/nullptr,
915 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
916 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
919 void CallOp::build(OpBuilder &builder, OperationState &state,
920 LLVMFunctionType calleeType, StringRef callee,
921 ValueRange args) {
922 build(builder, state, calleeType, builder.getStringAttr(callee), args);
925 void CallOp::build(OpBuilder &builder, OperationState &state,
926 LLVMFunctionType calleeType, StringAttr callee,
927 ValueRange args) {
928 build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
931 void CallOp::build(OpBuilder &builder, OperationState &state,
932 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
933 ValueRange args) {
934 build(builder, state, getCallOpResultTypes(calleeType),
935 TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
936 /*branch_weights=*/nullptr, /*CConv=*/nullptr,
937 /*access_groups=*/nullptr,
938 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
941 void CallOp::build(OpBuilder &builder, OperationState &state,
942 LLVMFunctionType calleeType, ValueRange args) {
943 build(builder, state, getCallOpResultTypes(calleeType),
944 TypeAttr::get(calleeType), /*callee=*/nullptr, args,
945 /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
946 /*CConv=*/nullptr,
947 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
948 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
951 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
952 ValueRange args) {
953 auto calleeType = func.getFunctionType();
954 build(builder, state, getCallOpResultTypes(calleeType),
955 TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
956 /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
957 /*CConv=*/nullptr,
958 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
959 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
962 CallInterfaceCallable CallOp::getCallableForCallee() {
963 // Direct call.
964 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
965 return calleeAttr;
966 // Indirect call, callee Value is the first operand.
967 return getOperand(0);
970 void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
971 // Direct call.
972 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
973 auto symRef = callee.get<SymbolRefAttr>();
974 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
976 // Indirect call, callee Value is the first operand.
977 return setOperand(0, callee.get<Value>());
980 Operation::operand_range CallOp::getArgOperands() {
981 return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
984 MutableOperandRange CallOp::getArgOperandsMutable() {
985 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
986 getCalleeOperands().size());
989 /// Verify that an inlinable callsite of a debug-info-bearing function in a
990 /// debug-info-bearing function has a debug location attached to it. This
991 /// mirrors an LLVM IR verifier.
992 static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
993 if (callee.isExternal())
994 return success();
995 auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
996 if (!parentFunc)
997 return success();
999 auto hasSubprogram = [](Operation *op) {
1000 return op->getLoc()
1001 ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1002 nullptr;
1004 if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1005 return success();
1006 bool containsLoc = !isa<UnknownLoc>(callOp->getLoc());
1007 if (!containsLoc)
1008 return callOp.emitError()
1009 << "inlinable function call in a function with a DISubprogram "
1010 "location must have a debug location";
1011 return success();
1014 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1015 if (getNumResults() > 1)
1016 return emitOpError("must have 0 or 1 result");
1018 // Type for the callee, we'll get it differently depending if it is a direct
1019 // or indirect call.
1020 Type fnType;
1022 bool isIndirect = false;
1024 // If this is an indirect call, the callee attribute is missing.
1025 FlatSymbolRefAttr calleeName = getCalleeAttr();
1026 if (!calleeName) {
1027 isIndirect = true;
1028 if (!getNumOperands())
1029 return emitOpError(
1030 "must have either a `callee` attribute or at least an operand");
1031 auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
1032 if (!ptrType)
1033 return emitOpError("indirect call expects a pointer as callee: ")
1034 << getOperand(0).getType();
1036 return success();
1037 } else {
1038 Operation *callee =
1039 symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
1040 if (!callee)
1041 return emitOpError()
1042 << "'" << calleeName.getValue()
1043 << "' does not reference a symbol in the current scope";
1044 auto fn = dyn_cast<LLVMFuncOp>(callee);
1045 if (!fn)
1046 return emitOpError() << "'" << calleeName.getValue()
1047 << "' does not reference a valid LLVM function";
1049 if (failed(verifyCallOpDebugInfo(*this, fn)))
1050 return failure();
1051 fnType = fn.getFunctionType();
1054 LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
1055 if (!funcType)
1056 return emitOpError("callee does not have a functional type: ") << fnType;
1058 if (funcType.isVarArg() && !getCalleeType())
1059 return emitOpError() << "missing callee type attribute for vararg call";
1061 // Verify that the operand and result types match the callee.
1063 if (!funcType.isVarArg() &&
1064 funcType.getNumParams() != (getNumOperands() - isIndirect))
1065 return emitOpError() << "incorrect number of operands ("
1066 << (getNumOperands() - isIndirect)
1067 << ") for callee (expecting: "
1068 << funcType.getNumParams() << ")";
1070 if (funcType.getNumParams() > (getNumOperands() - isIndirect))
1071 return emitOpError() << "incorrect number of operands ("
1072 << (getNumOperands() - isIndirect)
1073 << ") for varargs callee (expecting at least: "
1074 << funcType.getNumParams() << ")";
1076 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1077 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1078 return emitOpError() << "operand type mismatch for operand " << i << ": "
1079 << getOperand(i + isIndirect).getType()
1080 << " != " << funcType.getParamType(i);
1082 if (getNumResults() == 0 &&
1083 !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1084 return emitOpError() << "expected function call to produce a value";
1086 if (getNumResults() != 0 &&
1087 llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1088 return emitOpError()
1089 << "calling function with void result must not produce values";
1091 if (getNumResults() > 1)
1092 return emitOpError()
1093 << "expected LLVM function call to produce 0 or 1 result";
1095 if (getNumResults() && getResult().getType() != funcType.getReturnType())
1096 return emitOpError() << "result type mismatch: " << getResult().getType()
1097 << " != " << funcType.getReturnType();
1099 return success();
1102 void CallOp::print(OpAsmPrinter &p) {
1103 auto callee = getCallee();
1104 bool isDirect = callee.has_value();
1106 LLVMFunctionType calleeType;
1107 bool isVarArg = false;
1109 if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1110 calleeType = *optionalCalleeType;
1111 isVarArg = calleeType.isVarArg();
1114 p << ' ';
1116 // Print calling convention.
1117 if (getCConv() != LLVM::CConv::C)
1118 p << stringifyCConv(getCConv()) << ' ';
1120 // Print the direct callee if present as a function attribute, or an indirect
1121 // callee (first operand) otherwise.
1122 if (isDirect)
1123 p.printSymbolName(callee.value());
1124 else
1125 p << getOperand(0);
1127 auto args = getOperands().drop_front(isDirect ? 0 : 1);
1128 p << '(' << args << ')';
1130 if (isVarArg)
1131 p << " vararg(" << calleeType << ")";
1133 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1134 {getCConvAttrName(), "callee", "callee_type"});
1136 p << " : ";
1137 if (!isDirect)
1138 p << getOperand(0).getType() << ", ";
1140 // Reconstruct the function MLIR function type from operand and result types.
1141 p.printFunctionalType(args.getTypes(), getResultTypes());
1144 /// Parses the type of a call operation and resolves the operands if the parsing
1145 /// succeeds. Returns failure otherwise.
1146 static ParseResult parseCallTypeAndResolveOperands(
1147 OpAsmParser &parser, OperationState &result, bool isDirect,
1148 ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
1149 SMLoc trailingTypesLoc = parser.getCurrentLocation();
1150 SmallVector<Type> types;
1151 if (parser.parseColonTypeList(types))
1152 return failure();
1154 if (isDirect && types.size() != 1)
1155 return parser.emitError(trailingTypesLoc,
1156 "expected direct call to have 1 trailing type");
1157 if (!isDirect && types.size() != 2)
1158 return parser.emitError(trailingTypesLoc,
1159 "expected indirect call to have 2 trailing types");
1161 auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
1162 if (!funcType)
1163 return parser.emitError(trailingTypesLoc,
1164 "expected trailing function type");
1165 if (funcType.getNumResults() > 1)
1166 return parser.emitError(trailingTypesLoc,
1167 "expected function with 0 or 1 result");
1168 if (funcType.getNumResults() == 1 &&
1169 llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
1170 return parser.emitError(trailingTypesLoc,
1171 "expected a non-void result type");
1173 // The head element of the types list matches the callee type for
1174 // indirect calls, while the types list is emtpy for direct calls.
1175 // Append the function input types to resolve the call operation
1176 // operands.
1177 llvm::append_range(types, funcType.getInputs());
1178 if (parser.resolveOperands(operands, types, parser.getNameLoc(),
1179 result.operands))
1180 return failure();
1181 if (funcType.getNumResults() != 0)
1182 result.addTypes(funcType.getResults());
1184 return success();
1187 /// Parses an optional function pointer operand before the call argument list
1188 /// for indirect calls, or stops parsing at the function identifier otherwise.
1189 static ParseResult parseOptionalCallFuncPtr(
1190 OpAsmParser &parser,
1191 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) {
1192 OpAsmParser::UnresolvedOperand funcPtrOperand;
1193 OptionalParseResult parseResult = parser.parseOptionalOperand(funcPtrOperand);
1194 if (parseResult.has_value()) {
1195 if (failed(*parseResult))
1196 return *parseResult;
1197 operands.push_back(funcPtrOperand);
1199 return success();
1202 // <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
1203 // `(` ssa-use-list `)`
1204 // ( `vararg(` var-arg-func-type `)` )?
1205 // attribute-dict? `:` (type `,`)? function-type
1206 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1207 SymbolRefAttr funcAttr;
1208 TypeAttr calleeType;
1209 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1211 // Default to C Calling Convention if no keyword is provided.
1212 result.addAttribute(
1213 getCConvAttrName(result.name),
1214 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1215 parser, result, LLVM::CConv::C)));
1217 // Parse a function pointer for indirect calls.
1218 if (parseOptionalCallFuncPtr(parser, operands))
1219 return failure();
1220 bool isDirect = operands.empty();
1222 // Parse a function identifier for direct calls.
1223 if (isDirect)
1224 if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1225 return failure();
1227 // Parse the function arguments.
1228 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
1229 return failure();
1231 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1232 if (isVarArg) {
1233 if (parser.parseLParen().failed() ||
1234 parser.parseAttribute(calleeType, "callee_type", result.attributes)
1235 .failed() ||
1236 parser.parseRParen().failed())
1237 return failure();
1240 if (parser.parseOptionalAttrDict(result.attributes))
1241 return failure();
1243 // Parse the trailing type list and resolve the operands.
1244 return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
1247 LLVMFunctionType CallOp::getCalleeFunctionType() {
1248 if (getCalleeType())
1249 return *getCalleeType();
1250 else
1251 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1254 ///===---------------------------------------------------------------------===//
1255 /// LLVM::InvokeOp
1256 ///===---------------------------------------------------------------------===//
1258 void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1259 ValueRange ops, Block *normal, ValueRange normalOps,
1260 Block *unwind, ValueRange unwindOps) {
1261 auto calleeType = func.getFunctionType();
1262 build(builder, state, getCallOpResultTypes(calleeType),
1263 TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
1264 unwindOps, nullptr, nullptr, normal, unwind);
1267 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
1268 FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
1269 ValueRange normalOps, Block *unwind,
1270 ValueRange unwindOps) {
1271 build(builder, state, tys,
1272 TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
1273 ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
1276 void InvokeOp::build(OpBuilder &builder, OperationState &state,
1277 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1278 ValueRange ops, Block *normal, ValueRange normalOps,
1279 Block *unwind, ValueRange unwindOps) {
1280 build(builder, state, getCallOpResultTypes(calleeType),
1281 TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
1282 nullptr, normal, unwind);
1285 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
1286 assert(index < getNumSuccessors() && "invalid successor index");
1287 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1288 : getUnwindDestOperandsMutable());
1291 CallInterfaceCallable InvokeOp::getCallableForCallee() {
1292 // Direct call.
1293 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1294 return calleeAttr;
1295 // Indirect call, callee Value is the first operand.
1296 return getOperand(0);
1299 void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1300 // Direct call.
1301 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1302 auto symRef = callee.get<SymbolRefAttr>();
1303 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1305 // Indirect call, callee Value is the first operand.
1306 return setOperand(0, callee.get<Value>());
1309 Operation::operand_range InvokeOp::getArgOperands() {
1310 return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
1313 MutableOperandRange InvokeOp::getArgOperandsMutable() {
1314 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1315 getCalleeOperands().size());
1318 LogicalResult InvokeOp::verify() {
1319 if (getNumResults() > 1)
1320 return emitOpError("must have 0 or 1 result");
1322 Block *unwindDest = getUnwindDest();
1323 if (unwindDest->empty())
1324 return emitError("must have at least one operation in unwind destination");
1326 // In unwind destination, first operation must be LandingpadOp
1327 if (!isa<LandingpadOp>(unwindDest->front()))
1328 return emitError("first operation in unwind destination should be a "
1329 "llvm.landingpad operation");
1331 return success();
1334 void InvokeOp::print(OpAsmPrinter &p) {
1335 auto callee = getCallee();
1336 bool isDirect = callee.has_value();
1338 LLVMFunctionType calleeType;
1339 bool isVarArg = false;
1341 if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1342 calleeType = *optionalCalleeType;
1343 isVarArg = calleeType.isVarArg();
1346 p << ' ';
1348 // Print calling convention.
1349 if (getCConv() != LLVM::CConv::C)
1350 p << stringifyCConv(getCConv()) << ' ';
1352 // Either function name or pointer
1353 if (isDirect)
1354 p.printSymbolName(callee.value());
1355 else
1356 p << getOperand(0);
1358 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')';
1359 p << " to ";
1360 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
1361 p << " unwind ";
1362 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
1364 if (isVarArg)
1365 p << " vararg(" << calleeType << ")";
1367 p.printOptionalAttrDict((*this)->getAttrs(),
1368 {InvokeOp::getOperandSegmentSizeAttr(), "callee",
1369 "callee_type", InvokeOp::getCConvAttrName()});
1371 p << " : ";
1372 if (!isDirect)
1373 p << getOperand(0).getType() << ", ";
1374 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
1375 getResultTypes());
1378 // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
1379 // `(` ssa-use-list `)`
1380 // `to` bb-id (`[` ssa-use-and-type-list `]`)?
1381 // `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1382 // ( `vararg(` var-arg-func-type `)` )?
1383 // attribute-dict? `:` (type `,`)? function-type
1384 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1385 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
1386 SymbolRefAttr funcAttr;
1387 TypeAttr calleeType;
1388 Block *normalDest, *unwindDest;
1389 SmallVector<Value, 4> normalOperands, unwindOperands;
1390 Builder &builder = parser.getBuilder();
1392 // Default to C Calling Convention if no keyword is provided.
1393 result.addAttribute(
1394 getCConvAttrName(result.name),
1395 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1396 parser, result, LLVM::CConv::C)));
1398 // Parse a function pointer for indirect calls.
1399 if (parseOptionalCallFuncPtr(parser, operands))
1400 return failure();
1401 bool isDirect = operands.empty();
1403 // Parse a function identifier for direct calls.
1404 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
1405 return failure();
1407 // Parse the function arguments.
1408 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1409 parser.parseKeyword("to") ||
1410 parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
1411 parser.parseKeyword("unwind") ||
1412 parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
1413 return failure();
1415 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1416 if (isVarArg) {
1417 if (parser.parseLParen().failed() ||
1418 parser.parseAttribute(calleeType, "callee_type", result.attributes)
1419 .failed() ||
1420 parser.parseRParen().failed())
1421 return failure();
1424 if (parser.parseOptionalAttrDict(result.attributes))
1425 return failure();
1427 // Parse the trailing type list and resolve the function operands.
1428 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
1429 return failure();
1431 result.addSuccessors({normalDest, unwindDest});
1432 result.addOperands(normalOperands);
1433 result.addOperands(unwindOperands);
1435 result.addAttribute(InvokeOp::getOperandSegmentSizeAttr(),
1436 builder.getDenseI32ArrayAttr(
1437 {static_cast<int32_t>(operands.size()),
1438 static_cast<int32_t>(normalOperands.size()),
1439 static_cast<int32_t>(unwindOperands.size())}));
1440 return success();
1443 LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1444 if (getCalleeType())
1445 return *getCalleeType();
1446 else
1447 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1450 ///===----------------------------------------------------------------------===//
1451 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
1452 ///===----------------------------------------------------------------------===//
1454 LogicalResult LandingpadOp::verify() {
1455 Value value;
1456 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1457 if (!func.getPersonality())
1458 return emitError(
1459 "llvm.landingpad needs to be in a function with a personality");
1462 // Consistency of llvm.landingpad result types is checked in
1463 // LLVMFuncOp::verify().
1465 if (!getCleanup() && getOperands().empty())
1466 return emitError("landingpad instruction expects at least one clause or "
1467 "cleanup attribute");
1469 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1470 value = getOperand(idx);
1471 bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
1472 if (isFilter) {
1473 // FIXME: Verify filter clauses when arrays are appropriately handled
1474 } else {
1475 // catch - global addresses only.
1476 // Bitcast ops should have global addresses as their args.
1477 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1478 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1479 continue;
1480 return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1481 << "global addresses expected as operand to "
1482 "bitcast used in clauses for landingpad";
1484 // ZeroOp and AddressOfOp allowed
1485 if (value.getDefiningOp<ZeroOp>())
1486 continue;
1487 if (value.getDefiningOp<AddressOfOp>())
1488 continue;
1489 return emitError("clause #")
1490 << idx << " is not a known constant - null, addressof, bitcast";
1493 return success();
1496 void LandingpadOp::print(OpAsmPrinter &p) {
1497 p << (getCleanup() ? " cleanup " : " ");
1499 // Clauses
1500 for (auto value : getOperands()) {
1501 // Similar to llvm - if clause is an array type then it is filter
1502 // clause else catch clause
1503 bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
1504 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1505 << value.getType() << ") ";
1508 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1510 p << ": " << getType();
1513 // <operation> ::= `llvm.landingpad` `cleanup`?
1514 // ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1515 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1516 // Check for cleanup
1517 if (succeeded(parser.parseOptionalKeyword("cleanup")))
1518 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1520 // Parse clauses with types
1521 while (succeeded(parser.parseOptionalLParen()) &&
1522 (succeeded(parser.parseOptionalKeyword("filter")) ||
1523 succeeded(parser.parseOptionalKeyword("catch")))) {
1524 OpAsmParser::UnresolvedOperand operand;
1525 Type ty;
1526 if (parser.parseOperand(operand) || parser.parseColon() ||
1527 parser.parseType(ty) ||
1528 parser.resolveOperand(operand, ty, result.operands) ||
1529 parser.parseRParen())
1530 return failure();
1533 Type type;
1534 if (parser.parseColon() || parser.parseType(type))
1535 return failure();
1537 result.addTypes(type);
1538 return success();
1541 //===----------------------------------------------------------------------===//
1542 // ExtractValueOp
1543 //===----------------------------------------------------------------------===//
1545 /// Extract the type at `position` in the LLVM IR aggregate type
1546 /// `containerType`. Each element of `position` is an index into a nested
1547 /// aggregate type. Return the resulting type or emit an error.
1548 static Type getInsertExtractValueElementType(
1549 function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1550 ArrayRef<int64_t> position) {
1551 Type llvmType = containerType;
1552 if (!isCompatibleType(containerType)) {
1553 emitError("expected LLVM IR Dialect type, got ") << containerType;
1554 return {};
1557 // Infer the element type from the structure type: iteratively step inside the
1558 // type by taking the element type, indexed by the position attribute for
1559 // structures. Check the position index before accessing, it is supposed to
1560 // be in bounds.
1561 for (int64_t idx : position) {
1562 if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
1563 if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1564 emitError("position out of bounds: ") << idx;
1565 return {};
1567 llvmType = arrayType.getElementType();
1568 } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
1569 if (idx < 0 ||
1570 static_cast<unsigned>(idx) >= structType.getBody().size()) {
1571 emitError("position out of bounds: ") << idx;
1572 return {};
1574 llvmType = structType.getBody()[idx];
1575 } else {
1576 emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1577 return {};
1580 return llvmType;
1583 /// Extract the type at `position` in the wrapped LLVM IR aggregate type
1584 /// `containerType`.
1585 static Type getInsertExtractValueElementType(Type llvmType,
1586 ArrayRef<int64_t> position) {
1587 for (int64_t idx : position) {
1588 if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
1589 llvmType = structType.getBody()[idx];
1590 else
1591 llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
1593 return llvmType;
1596 OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
1597 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
1598 OpFoldResult result = {};
1599 while (insertValueOp) {
1600 if (getPosition() == insertValueOp.getPosition())
1601 return insertValueOp.getValue();
1602 unsigned min =
1603 std::min(getPosition().size(), insertValueOp.getPosition().size());
1604 // If one is fully prefix of the other, stop propagating back as it will
1605 // miss dependencies. For instance, %3 should not fold to %f0 in the
1606 // following example:
1607 // ```
1608 // %1 = llvm.insertvalue %f0, %0[0, 0] :
1609 // !llvm.array<4 x !llvm.array<4 x f32>>
1610 // %2 = llvm.insertvalue %arr, %1[0] :
1611 // !llvm.array<4 x !llvm.array<4 x f32>>
1612 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
1613 // ```
1614 if (getPosition().take_front(min) ==
1615 insertValueOp.getPosition().take_front(min))
1616 return result;
1618 // If neither a prefix, nor the exact position, we can extract out of the
1619 // value being inserted into. Moreover, we can try again if that operand
1620 // is itself an insertvalue expression.
1621 getContainerMutable().assign(insertValueOp.getContainer());
1622 result = getResult();
1623 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
1625 return result;
1628 LogicalResult ExtractValueOp::verify() {
1629 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1630 Type valueType = getInsertExtractValueElementType(
1631 emitError, getContainer().getType(), getPosition());
1632 if (!valueType)
1633 return failure();
1635 if (getRes().getType() != valueType)
1636 return emitOpError() << "Type mismatch: extracting from "
1637 << getContainer().getType() << " should produce "
1638 << valueType << " but this op returns "
1639 << getRes().getType();
1640 return success();
1643 void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
1644 Value container, ArrayRef<int64_t> position) {
1645 build(builder, state,
1646 getInsertExtractValueElementType(container.getType(), position),
1647 container, builder.getAttr<DenseI64ArrayAttr>(position));
1650 //===----------------------------------------------------------------------===//
1651 // InsertValueOp
1652 //===----------------------------------------------------------------------===//
1654 /// Infer the value type from the container type and position.
1655 static ParseResult
1656 parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,
1657 Type containerType,
1658 DenseI64ArrayAttr position) {
1659 valueType = getInsertExtractValueElementType(
1660 [&](StringRef msg) {
1661 return parser.emitError(parser.getCurrentLocation(), msg);
1663 containerType, position.asArrayRef());
1664 return success(!!valueType);
1667 /// Nothing to print for an inferred type.
1668 static void printInsertExtractValueElementType(AsmPrinter &printer,
1669 Operation *op, Type valueType,
1670 Type containerType,
1671 DenseI64ArrayAttr position) {}
1673 LogicalResult InsertValueOp::verify() {
1674 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1675 Type valueType = getInsertExtractValueElementType(
1676 emitError, getContainer().getType(), getPosition());
1677 if (!valueType)
1678 return failure();
1680 if (getValue().getType() != valueType)
1681 return emitOpError() << "Type mismatch: cannot insert "
1682 << getValue().getType() << " into "
1683 << getContainer().getType();
1685 return success();
1688 //===----------------------------------------------------------------------===//
1689 // ReturnOp
1690 //===----------------------------------------------------------------------===//
1692 LogicalResult ReturnOp::verify() {
1693 auto parent = (*this)->getParentOfType<LLVMFuncOp>();
1694 if (!parent)
1695 return success();
1697 Type expectedType = parent.getFunctionType().getReturnType();
1698 if (llvm::isa<LLVMVoidType>(expectedType)) {
1699 if (!getArg())
1700 return success();
1701 InFlightDiagnostic diag = emitOpError("expected no operands");
1702 diag.attachNote(parent->getLoc()) << "when returning from function";
1703 return diag;
1705 if (!getArg()) {
1706 if (llvm::isa<LLVMVoidType>(expectedType))
1707 return success();
1708 InFlightDiagnostic diag = emitOpError("expected 1 operand");
1709 diag.attachNote(parent->getLoc()) << "when returning from function";
1710 return diag;
1712 if (expectedType != getArg().getType()) {
1713 InFlightDiagnostic diag = emitOpError("mismatching result types");
1714 diag.attachNote(parent->getLoc()) << "when returning from function";
1715 return diag;
1717 return success();
1720 //===----------------------------------------------------------------------===//
1721 // Verifier for LLVM::AddressOfOp.
1722 //===----------------------------------------------------------------------===//
1724 static Operation *parentLLVMModule(Operation *op) {
1725 Operation *module = op->getParentOp();
1726 while (module && !satisfiesLLVMModule(module))
1727 module = module->getParentOp();
1728 assert(module && "unexpected operation outside of a module");
1729 return module;
1732 GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
1733 return dyn_cast_or_null<GlobalOp>(
1734 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
1737 LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
1738 return dyn_cast_or_null<LLVMFuncOp>(
1739 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
1742 LogicalResult
1743 AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1744 Operation *symbol =
1745 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
1747 auto global = dyn_cast_or_null<GlobalOp>(symbol);
1748 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
1750 if (!global && !function)
1751 return emitOpError(
1752 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1754 LLVMPointerType type = getType();
1755 if (global && global.getAddrSpace() != type.getAddressSpace())
1756 return emitOpError("pointer address space must match address space of the "
1757 "referenced global");
1759 return success();
1762 //===----------------------------------------------------------------------===//
1763 // Verifier for LLVM::ComdatOp.
1764 //===----------------------------------------------------------------------===//
1766 void ComdatOp::build(OpBuilder &builder, OperationState &result,
1767 StringRef symName) {
1768 result.addAttribute(getSymNameAttrName(result.name),
1769 builder.getStringAttr(symName));
1770 Region *body = result.addRegion();
1771 body->emplaceBlock();
1774 LogicalResult ComdatOp::verifyRegions() {
1775 Region &body = getBody();
1776 for (Operation &op : body.getOps())
1777 if (!isa<ComdatSelectorOp>(op))
1778 return op.emitError(
1779 "only comdat selector symbols can appear in a comdat region");
1781 return success();
1784 //===----------------------------------------------------------------------===//
1785 // Builder, printer and verifier for LLVM::GlobalOp.
1786 //===----------------------------------------------------------------------===//
1788 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1789 bool isConstant, Linkage linkage, StringRef name,
1790 Attribute value, uint64_t alignment, unsigned addrSpace,
1791 bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
1792 ArrayRef<NamedAttribute> attrs,
1793 DIGlobalVariableExpressionAttr dbgExpr) {
1794 result.addAttribute(getSymNameAttrName(result.name),
1795 builder.getStringAttr(name));
1796 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
1797 if (isConstant)
1798 result.addAttribute(getConstantAttrName(result.name),
1799 builder.getUnitAttr());
1800 if (value)
1801 result.addAttribute(getValueAttrName(result.name), value);
1802 if (dsoLocal)
1803 result.addAttribute(getDsoLocalAttrName(result.name),
1804 builder.getUnitAttr());
1805 if (threadLocal)
1806 result.addAttribute(getThreadLocal_AttrName(result.name),
1807 builder.getUnitAttr());
1808 if (comdat)
1809 result.addAttribute(getComdatAttrName(result.name), comdat);
1811 // Only add an alignment attribute if the "alignment" input
1812 // is different from 0. The value must also be a power of two, but
1813 // this is tested in GlobalOp::verify, not here.
1814 if (alignment != 0)
1815 result.addAttribute(getAlignmentAttrName(result.name),
1816 builder.getI64IntegerAttr(alignment));
1818 result.addAttribute(getLinkageAttrName(result.name),
1819 LinkageAttr::get(builder.getContext(), linkage));
1820 if (addrSpace != 0)
1821 result.addAttribute(getAddrSpaceAttrName(result.name),
1822 builder.getI32IntegerAttr(addrSpace));
1823 result.attributes.append(attrs.begin(), attrs.end());
1825 if (dbgExpr)
1826 result.addAttribute(getDbgExprAttrName(result.name), dbgExpr);
1828 result.addRegion();
1831 void GlobalOp::print(OpAsmPrinter &p) {
1832 p << ' ' << stringifyLinkage(getLinkage()) << ' ';
1833 StringRef visibility = stringifyVisibility(getVisibility_());
1834 if (!visibility.empty())
1835 p << visibility << ' ';
1836 if (getThreadLocal_())
1837 p << "thread_local ";
1838 if (auto unnamedAddr = getUnnamedAddr()) {
1839 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
1840 if (!str.empty())
1841 p << str << ' ';
1843 if (getConstant())
1844 p << "constant ";
1845 p.printSymbolName(getSymName());
1846 p << '(';
1847 if (auto value = getValueOrNull())
1848 p.printAttribute(value);
1849 p << ')';
1850 if (auto comdat = getComdat())
1851 p << " comdat(" << *comdat << ')';
1853 // Note that the alignment attribute is printed using the
1854 // default syntax here, even though it is an inherent attribute
1855 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
1856 p.printOptionalAttrDict((*this)->getAttrs(),
1857 {SymbolTable::getSymbolAttrName(),
1858 getGlobalTypeAttrName(), getConstantAttrName(),
1859 getValueAttrName(), getLinkageAttrName(),
1860 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
1861 getVisibility_AttrName(), getComdatAttrName(),
1862 getUnnamedAddrAttrName()});
1864 // Print the trailing type unless it's a string global.
1865 if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
1866 return;
1867 p << " : " << getType();
1869 Region &initializer = getInitializerRegion();
1870 if (!initializer.empty()) {
1871 p << ' ';
1872 p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1876 static LogicalResult verifyComdat(Operation *op,
1877 std::optional<SymbolRefAttr> attr) {
1878 if (!attr)
1879 return success();
1881 auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
1882 if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector))
1883 return op->emitError() << "expected comdat symbol";
1885 return success();
1888 // operation ::= `llvm.mlir.global` linkage? visibility?
1889 // (`unnamed_addr` | `local_unnamed_addr`)?
1890 // `thread_local`? `constant`? `@` identifier
1891 // `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
1892 // attribute-list? (`:` type)? region?
1894 // The type can be omitted for string attributes, in which case it will be
1895 // inferred from the value of the string as [strlen(value) x i8].
1896 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
1897 MLIRContext *ctx = parser.getContext();
1898 // Parse optional linkage, default to External.
1899 result.addAttribute(getLinkageAttrName(result.name),
1900 LLVM::LinkageAttr::get(
1901 ctx, parseOptionalLLVMKeyword<Linkage>(
1902 parser, result, LLVM::Linkage::External)));
1904 // Parse optional visibility, default to Default.
1905 result.addAttribute(getVisibility_AttrName(result.name),
1906 parser.getBuilder().getI64IntegerAttr(
1907 parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
1908 parser, result, LLVM::Visibility::Default)));
1910 // Parse optional UnnamedAddr, default to None.
1911 result.addAttribute(getUnnamedAddrAttrName(result.name),
1912 parser.getBuilder().getI64IntegerAttr(
1913 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
1914 parser, result, LLVM::UnnamedAddr::None)));
1916 if (succeeded(parser.parseOptionalKeyword("thread_local")))
1917 result.addAttribute(getThreadLocal_AttrName(result.name),
1918 parser.getBuilder().getUnitAttr());
1920 if (succeeded(parser.parseOptionalKeyword("constant")))
1921 result.addAttribute(getConstantAttrName(result.name),
1922 parser.getBuilder().getUnitAttr());
1924 StringAttr name;
1925 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
1926 result.attributes) ||
1927 parser.parseLParen())
1928 return failure();
1930 Attribute value;
1931 if (parser.parseOptionalRParen()) {
1932 if (parser.parseAttribute(value, getValueAttrName(result.name),
1933 result.attributes) ||
1934 parser.parseRParen())
1935 return failure();
1938 if (succeeded(parser.parseOptionalKeyword("comdat"))) {
1939 SymbolRefAttr comdat;
1940 if (parser.parseLParen() || parser.parseAttribute(comdat) ||
1941 parser.parseRParen())
1942 return failure();
1944 result.addAttribute(getComdatAttrName(result.name), comdat);
1947 SmallVector<Type, 1> types;
1948 if (parser.parseOptionalAttrDict(result.attributes) ||
1949 parser.parseOptionalColonTypeList(types))
1950 return failure();
1952 if (types.size() > 1)
1953 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1955 Region &initRegion = *result.addRegion();
1956 if (types.empty()) {
1957 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
1958 MLIRContext *context = parser.getContext();
1959 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
1960 strAttr.getValue().size());
1961 types.push_back(arrayType);
1962 } else {
1963 return parser.emitError(parser.getNameLoc(),
1964 "type can only be omitted for string globals");
1966 } else {
1967 OptionalParseResult parseResult =
1968 parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1969 /*argTypes=*/{});
1970 if (parseResult.has_value() && failed(*parseResult))
1971 return failure();
1974 result.addAttribute(getGlobalTypeAttrName(result.name),
1975 TypeAttr::get(types[0]));
1976 return success();
1979 static bool isZeroAttribute(Attribute value) {
1980 if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
1981 return intValue.getValue().isZero();
1982 if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
1983 return fpValue.getValue().isZero();
1984 if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value))
1985 return isZeroAttribute(splatValue.getSplatValue<Attribute>());
1986 if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
1987 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
1988 if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
1989 return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
1990 return false;
1993 LogicalResult GlobalOp::verify() {
1994 bool validType = isCompatibleOuterType(getType())
1995 ? !llvm::isa<LLVMVoidType, LLVMTokenType,
1996 LLVMMetadataType, LLVMLabelType>(getType())
1997 : llvm::isa<PointerElementTypeInterface>(getType());
1998 if (!validType)
1999 return emitOpError(
2000 "expects type to be a valid element type for an LLVM global");
2001 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
2002 return emitOpError("must appear at the module level");
2004 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
2005 auto type = llvm::dyn_cast<LLVMArrayType>(getType());
2006 IntegerType elementType =
2007 type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
2008 if (!elementType || elementType.getWidth() != 8 ||
2009 type.getNumElements() != strAttr.getValue().size())
2010 return emitOpError(
2011 "requires an i8 array type of the length equal to that of the string "
2012 "attribute");
2015 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2016 if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
2017 return emitOpError()
2018 << "this target extension type cannot be used in a global";
2020 if (Attribute value = getValueOrNull())
2021 return emitOpError() << "global with target extension type can only be "
2022 "initialized with zero-initializer";
2025 if (getLinkage() == Linkage::Common) {
2026 if (Attribute value = getValueOrNull()) {
2027 if (!isZeroAttribute(value)) {
2028 return emitOpError()
2029 << "expected zero value for '"
2030 << stringifyLinkage(Linkage::Common) << "' linkage";
2035 if (getLinkage() == Linkage::Appending) {
2036 if (!llvm::isa<LLVMArrayType>(getType())) {
2037 return emitOpError() << "expected array type for '"
2038 << stringifyLinkage(Linkage::Appending)
2039 << "' linkage";
2043 if (failed(verifyComdat(*this, getComdat())))
2044 return failure();
2046 std::optional<uint64_t> alignAttr = getAlignment();
2047 if (alignAttr.has_value()) {
2048 uint64_t value = alignAttr.value();
2049 if (!llvm::isPowerOf2_64(value))
2050 return emitError() << "alignment attribute is not a power of 2";
2053 return success();
2056 LogicalResult GlobalOp::verifyRegions() {
2057 if (Block *b = getInitializerBlock()) {
2058 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2059 if (ret.operand_type_begin() == ret.operand_type_end())
2060 return emitOpError("initializer region cannot return void");
2061 if (*ret.operand_type_begin() != getType())
2062 return emitOpError("initializer region type ")
2063 << *ret.operand_type_begin() << " does not match global type "
2064 << getType();
2066 for (Operation &op : *b) {
2067 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2068 if (!iface || !iface.hasNoEffect())
2069 return op.emitError()
2070 << "ops with side effects not allowed in global initializers";
2073 if (getValueOrNull())
2074 return emitOpError("cannot have both initializer value and region");
2077 return success();
2080 //===----------------------------------------------------------------------===//
2081 // LLVM::GlobalCtorsOp
2082 //===----------------------------------------------------------------------===//
2084 LogicalResult
2085 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2086 for (Attribute ctor : getCtors()) {
2087 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
2088 symbolTable)))
2089 return failure();
2091 return success();
2094 LogicalResult GlobalCtorsOp::verify() {
2095 if (getCtors().size() != getPriorities().size())
2096 return emitError(
2097 "mismatch between the number of ctors and the number of priorities");
2098 return success();
2101 //===----------------------------------------------------------------------===//
2102 // LLVM::GlobalDtorsOp
2103 //===----------------------------------------------------------------------===//
2105 LogicalResult
2106 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2107 for (Attribute dtor : getDtors()) {
2108 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
2109 symbolTable)))
2110 return failure();
2112 return success();
2115 LogicalResult GlobalDtorsOp::verify() {
2116 if (getDtors().size() != getPriorities().size())
2117 return emitError(
2118 "mismatch between the number of dtors and the number of priorities");
2119 return success();
2122 //===----------------------------------------------------------------------===//
2123 // ShuffleVectorOp
2124 //===----------------------------------------------------------------------===//
2126 void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2127 Value v2, DenseI32ArrayAttr mask,
2128 ArrayRef<NamedAttribute> attrs) {
2129 auto containerType = v1.getType();
2130 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
2131 mask.size(),
2132 LLVM::isScalableVectorType(containerType));
2133 build(builder, state, vType, v1, v2, mask);
2134 state.addAttributes(attrs);
2137 void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2138 Value v2, ArrayRef<int32_t> mask) {
2139 build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask));
2142 /// Build the result type of a shuffle vector operation.
2143 static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2144 Type &resType, DenseI32ArrayAttr mask) {
2145 if (!LLVM::isCompatibleVectorType(v1Type))
2146 return parser.emitError(parser.getCurrentLocation(),
2147 "expected an LLVM compatible vector type");
2148 resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
2149 LLVM::isScalableVectorType(v1Type));
2150 return success();
2153 /// Nothing to do when the result type is inferred.
2154 static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2155 Type resType, DenseI32ArrayAttr mask) {}
2157 LogicalResult ShuffleVectorOp::verify() {
2158 if (LLVM::isScalableVectorType(getV1().getType()) &&
2159 llvm::any_of(getMask(), [](int32_t v) { return v != 0; }))
2160 return emitOpError("expected a splat operation for scalable vectors");
2161 return success();
2164 //===----------------------------------------------------------------------===//
2165 // Implementations for LLVM::LLVMFuncOp.
2166 //===----------------------------------------------------------------------===//
2168 // Add the entry block to the function.
2169 Block *LLVMFuncOp::addEntryBlock() {
2170 assert(empty() && "function already has an entry block");
2172 auto *entry = new Block;
2173 push_back(entry);
2175 // FIXME: Allow passing in proper locations for the entry arguments.
2176 LLVMFunctionType type = getFunctionType();
2177 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
2178 entry->addArgument(type.getParamType(i), getLoc());
2179 return entry;
2182 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
2183 StringRef name, Type type, LLVM::Linkage linkage,
2184 bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
2185 ArrayRef<NamedAttribute> attrs,
2186 ArrayRef<DictionaryAttr> argAttrs,
2187 std::optional<uint64_t> functionEntryCount) {
2188 result.addRegion();
2189 result.addAttribute(SymbolTable::getSymbolAttrName(),
2190 builder.getStringAttr(name));
2191 result.addAttribute(getFunctionTypeAttrName(result.name),
2192 TypeAttr::get(type));
2193 result.addAttribute(getLinkageAttrName(result.name),
2194 LinkageAttr::get(builder.getContext(), linkage));
2195 result.addAttribute(getCConvAttrName(result.name),
2196 CConvAttr::get(builder.getContext(), cconv));
2197 result.attributes.append(attrs.begin(), attrs.end());
2198 if (dsoLocal)
2199 result.addAttribute(getDsoLocalAttrName(result.name),
2200 builder.getUnitAttr());
2201 if (comdat)
2202 result.addAttribute(getComdatAttrName(result.name), comdat);
2203 if (functionEntryCount)
2204 result.addAttribute(getFunctionEntryCountAttrName(result.name),
2205 builder.getI64IntegerAttr(functionEntryCount.value()));
2206 if (argAttrs.empty())
2207 return;
2209 assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
2210 "expected as many argument attribute lists as arguments");
2211 function_interface_impl::addArgAndResultAttrs(
2212 builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
2213 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2216 // Builds an LLVM function type from the given lists of input and output types.
2217 // Returns a null type if any of the types provided are non-LLVM types, or if
2218 // there is more than one output type.
2219 static Type
2220 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
2221 ArrayRef<Type> outputs,
2222 function_interface_impl::VariadicFlag variadicFlag) {
2223 Builder &b = parser.getBuilder();
2224 if (outputs.size() > 1) {
2225 parser.emitError(loc, "failed to construct function type: expected zero or "
2226 "one function result");
2227 return {};
2230 // Convert inputs to LLVM types, exit early on error.
2231 SmallVector<Type, 4> llvmInputs;
2232 for (auto t : inputs) {
2233 if (!isCompatibleType(t)) {
2234 parser.emitError(loc, "failed to construct function type: expected LLVM "
2235 "type for function arguments");
2236 return {};
2238 llvmInputs.push_back(t);
2241 // No output is denoted as "void" in LLVM type system.
2242 Type llvmOutput =
2243 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
2244 if (!isCompatibleType(llvmOutput)) {
2245 parser.emitError(loc, "failed to construct function type: expected LLVM "
2246 "type for function results")
2247 << llvmOutput;
2248 return {};
2250 return LLVMFunctionType::get(llvmOutput, llvmInputs,
2251 variadicFlag.isVariadic());
2254 // Parses an LLVM function.
2256 // operation ::= `llvm.func` linkage? cconv? function-signature
2257 // (`comdat(` symbol-ref-id `)`)?
2258 // function-attributes?
2259 // function-body
2261 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
2262 // Default to external linkage if no keyword is provided.
2263 result.addAttribute(
2264 getLinkageAttrName(result.name),
2265 LinkageAttr::get(parser.getContext(),
2266 parseOptionalLLVMKeyword<Linkage>(
2267 parser, result, LLVM::Linkage::External)));
2269 // Parse optional visibility, default to Default.
2270 result.addAttribute(getVisibility_AttrName(result.name),
2271 parser.getBuilder().getI64IntegerAttr(
2272 parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
2273 parser, result, LLVM::Visibility::Default)));
2275 // Parse optional UnnamedAddr, default to None.
2276 result.addAttribute(getUnnamedAddrAttrName(result.name),
2277 parser.getBuilder().getI64IntegerAttr(
2278 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
2279 parser, result, LLVM::UnnamedAddr::None)));
2281 // Default to C Calling Convention if no keyword is provided.
2282 result.addAttribute(
2283 getCConvAttrName(result.name),
2284 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
2285 parser, result, LLVM::CConv::C)));
2287 StringAttr nameAttr;
2288 SmallVector<OpAsmParser::Argument> entryArgs;
2289 SmallVector<DictionaryAttr> resultAttrs;
2290 SmallVector<Type> resultTypes;
2291 bool isVariadic;
2293 auto signatureLocation = parser.getCurrentLocation();
2294 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2295 result.attributes) ||
2296 function_interface_impl::parseFunctionSignature(
2297 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
2298 resultAttrs))
2299 return failure();
2301 SmallVector<Type> argTypes;
2302 for (auto &arg : entryArgs)
2303 argTypes.push_back(arg.type);
2304 auto type =
2305 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
2306 function_interface_impl::VariadicFlag(isVariadic));
2307 if (!type)
2308 return failure();
2309 result.addAttribute(getFunctionTypeAttrName(result.name),
2310 TypeAttr::get(type));
2312 if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
2313 int64_t minRange, maxRange;
2314 if (parser.parseLParen() || parser.parseInteger(minRange) ||
2315 parser.parseComma() || parser.parseInteger(maxRange) ||
2316 parser.parseRParen())
2317 return failure();
2318 auto intTy = IntegerType::get(parser.getContext(), 32);
2319 result.addAttribute(
2320 getVscaleRangeAttrName(result.name),
2321 LLVM::VScaleRangeAttr::get(parser.getContext(),
2322 IntegerAttr::get(intTy, minRange),
2323 IntegerAttr::get(intTy, maxRange)));
2325 // Parse the optional comdat selector.
2326 if (succeeded(parser.parseOptionalKeyword("comdat"))) {
2327 SymbolRefAttr comdat;
2328 if (parser.parseLParen() || parser.parseAttribute(comdat) ||
2329 parser.parseRParen())
2330 return failure();
2332 result.addAttribute(getComdatAttrName(result.name), comdat);
2335 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
2336 return failure();
2337 function_interface_impl::addArgAndResultAttrs(
2338 parser.getBuilder(), result, entryArgs, resultAttrs,
2339 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2341 auto *body = result.addRegion();
2342 OptionalParseResult parseResult =
2343 parser.parseOptionalRegion(*body, entryArgs);
2344 return failure(parseResult.has_value() && failed(*parseResult));
2347 // Print the LLVMFuncOp. Collects argument and result types and passes them to
2348 // helper functions. Drops "void" result since it cannot be parsed back. Skips
2349 // the external linkage since it is the default value.
2350 void LLVMFuncOp::print(OpAsmPrinter &p) {
2351 p << ' ';
2352 if (getLinkage() != LLVM::Linkage::External)
2353 p << stringifyLinkage(getLinkage()) << ' ';
2354 StringRef visibility = stringifyVisibility(getVisibility_());
2355 if (!visibility.empty())
2356 p << visibility << ' ';
2357 if (auto unnamedAddr = getUnnamedAddr()) {
2358 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2359 if (!str.empty())
2360 p << str << ' ';
2362 if (getCConv() != LLVM::CConv::C)
2363 p << stringifyCConv(getCConv()) << ' ';
2365 p.printSymbolName(getName());
2367 LLVMFunctionType fnType = getFunctionType();
2368 SmallVector<Type, 8> argTypes;
2369 SmallVector<Type, 1> resTypes;
2370 argTypes.reserve(fnType.getNumParams());
2371 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
2372 argTypes.push_back(fnType.getParamType(i));
2374 Type returnType = fnType.getReturnType();
2375 if (!llvm::isa<LLVMVoidType>(returnType))
2376 resTypes.push_back(returnType);
2378 function_interface_impl::printFunctionSignature(p, *this, argTypes,
2379 isVarArg(), resTypes);
2381 // Print vscale range if present
2382 if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
2383 p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
2384 << vscale->getMaxRange().getInt() << ')';
2386 // Print the optional comdat selector.
2387 if (auto comdat = getComdat())
2388 p << " comdat(" << *comdat << ')';
2390 function_interface_impl::printFunctionAttributes(
2391 p, *this,
2392 {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
2393 getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
2394 getComdatAttrName(), getUnnamedAddrAttrName(),
2395 getVscaleRangeAttrName()});
2397 // Print the body if this is not an external function.
2398 Region &body = getBody();
2399 if (!body.empty()) {
2400 p << ' ';
2401 p.printRegion(body, /*printEntryBlockArgs=*/false,
2402 /*printBlockTerminators=*/true);
2406 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2407 // - functions don't have 'common' linkage
2408 // - external functions have 'external' or 'extern_weak' linkage;
2409 // - vararg is (currently) only supported for external functions;
2410 LogicalResult LLVMFuncOp::verify() {
2411 if (getLinkage() == LLVM::Linkage::Common)
2412 return emitOpError() << "functions cannot have '"
2413 << stringifyLinkage(LLVM::Linkage::Common)
2414 << "' linkage";
2416 if (failed(verifyComdat(*this, getComdat())))
2417 return failure();
2419 if (isExternal()) {
2420 if (getLinkage() != LLVM::Linkage::External &&
2421 getLinkage() != LLVM::Linkage::ExternWeak)
2422 return emitOpError() << "external functions must have '"
2423 << stringifyLinkage(LLVM::Linkage::External)
2424 << "' or '"
2425 << stringifyLinkage(LLVM::Linkage::ExternWeak)
2426 << "' linkage";
2427 return success();
2430 Type landingpadResultTy;
2431 StringRef diagnosticMessage;
2432 bool isLandingpadTypeConsistent =
2433 !walk([&](Operation *op) {
2434 const auto checkType = [&](Type type, StringRef errorMessage) {
2435 if (!landingpadResultTy) {
2436 landingpadResultTy = type;
2437 return WalkResult::advance();
2439 if (landingpadResultTy != type) {
2440 diagnosticMessage = errorMessage;
2441 return WalkResult::interrupt();
2443 return WalkResult::advance();
2445 return TypeSwitch<Operation *, WalkResult>(op)
2446 .Case<LandingpadOp>([&](auto landingpad) {
2447 constexpr StringLiteral errorMessage =
2448 "'llvm.landingpad' should have a consistent result type "
2449 "inside a function";
2450 return checkType(landingpad.getType(), errorMessage);
2452 .Case<ResumeOp>([&](auto resume) {
2453 constexpr StringLiteral errorMessage =
2454 "'llvm.resume' should have a consistent input type inside a "
2455 "function";
2456 return checkType(resume.getValue().getType(), errorMessage);
2458 .Default([](auto) { return WalkResult::skip(); });
2459 }).wasInterrupted();
2460 if (!isLandingpadTypeConsistent) {
2461 assert(!diagnosticMessage.empty() &&
2462 "Expecting a non-empty diagnostic message");
2463 return emitError(diagnosticMessage);
2466 return success();
2469 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2470 /// - entry block arguments are of LLVM types.
2471 LogicalResult LLVMFuncOp::verifyRegions() {
2472 if (isExternal())
2473 return success();
2475 unsigned numArguments = getFunctionType().getNumParams();
2476 Block &entryBlock = front();
2477 for (unsigned i = 0; i < numArguments; ++i) {
2478 Type argType = entryBlock.getArgument(i).getType();
2479 if (!isCompatibleType(argType))
2480 return emitOpError("entry block argument #")
2481 << i << " is not of LLVM type";
2484 return success();
2487 Region *LLVMFuncOp::getCallableRegion() {
2488 if (isExternal())
2489 return nullptr;
2490 return &getBody();
2493 //===----------------------------------------------------------------------===//
2494 // ZeroOp.
2495 //===----------------------------------------------------------------------===//
2497 LogicalResult LLVM::ZeroOp::verify() {
2498 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
2499 if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
2500 return emitOpError()
2501 << "target extension type does not support zero-initializer";
2503 return success();
2506 //===----------------------------------------------------------------------===//
2507 // ConstantOp.
2508 //===----------------------------------------------------------------------===//
2510 LogicalResult LLVM::ConstantOp::verify() {
2511 if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
2512 auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
2513 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
2514 !arrayType.getElementType().isInteger(8)) {
2515 return emitOpError() << "expected array type of "
2516 << sAttr.getValue().size()
2517 << " i8 elements for the string constant";
2519 return success();
2521 if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
2522 if (structType.getBody().size() != 2 ||
2523 structType.getBody()[0] != structType.getBody()[1]) {
2524 return emitError() << "expected struct type with two elements of the "
2525 "same type, the type of a complex constant";
2528 auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
2529 if (!arrayAttr || arrayAttr.size() != 2) {
2530 return emitOpError() << "expected array attribute with two elements, "
2531 "representing a complex constant";
2533 auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
2534 auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
2535 if (!re || !im || re.getType() != im.getType()) {
2536 return emitOpError()
2537 << "expected array attribute with two elements of the same type";
2540 Type elementType = structType.getBody()[0];
2541 if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
2542 elementType)) {
2543 return emitError()
2544 << "expected struct element types to be floating point type or "
2545 "integer type";
2547 return success();
2549 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2550 return emitOpError() << "does not support target extension type.";
2552 if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
2553 return emitOpError()
2554 << "only supports integer, float, string or elements attributes";
2555 if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
2556 if (!llvm::isa<IntegerType>(getType()))
2557 return emitOpError() << "expected integer type";
2559 if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
2560 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
2561 unsigned floatWidth = APFloat::getSizeInBits(sem);
2562 if (auto floatTy = dyn_cast<FloatType>(getType())) {
2563 if (floatTy.getWidth() != floatWidth) {
2564 return emitOpError() << "expected float type of width " << floatWidth;
2567 // See the comment for getLLVMConstant for more details about why 8-bit
2568 // floats can be represented by integers.
2569 if (getType().isa<IntegerType>() && !getType().isInteger(floatWidth)) {
2570 return emitOpError() << "expected integer type of width " << floatWidth;
2573 if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
2574 if (!getType().isa<VectorType>() && !getType().isa<LLVM::LLVMArrayType>() &&
2575 !getType().isa<LLVM::LLVMFixedVectorType>() &&
2576 !getType().isa<LLVM::LLVMScalableVectorType>())
2577 return emitOpError() << "expected vector or array type";
2579 return success();
2582 bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
2583 // The value's type must be the same as the provided type.
2584 auto typedAttr = dyn_cast<TypedAttr>(value);
2585 if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
2586 return false;
2587 // The value's type must be an LLVM compatible type.
2588 if (!isCompatibleType(type))
2589 return false;
2590 // TODO: Add support for additional attributes kinds once needed.
2591 return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
2594 ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
2595 Type type, Location loc) {
2596 if (isBuildableWith(value, type))
2597 return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value));
2598 return nullptr;
2601 // Constant op constant-folds to its value.
2602 OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
2604 //===----------------------------------------------------------------------===//
2605 // AtomicRMWOp
2606 //===----------------------------------------------------------------------===//
2608 void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
2609 AtomicBinOp binOp, Value ptr, Value val,
2610 AtomicOrdering ordering, StringRef syncscope,
2611 unsigned alignment, bool isVolatile) {
2612 build(builder, state, val.getType(), binOp, ptr, val, ordering,
2613 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
2614 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
2615 /*access_groups=*/nullptr,
2616 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
2619 LogicalResult AtomicRMWOp::verify() {
2620 auto valType = getVal().getType();
2621 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
2622 getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
2623 if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2624 return emitOpError("expected LLVM IR floating point type");
2625 } else if (getBinOp() == AtomicBinOp::xchg) {
2626 if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/true))
2627 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2628 } else {
2629 auto intType = llvm::dyn_cast<IntegerType>(valType);
2630 unsigned intBitWidth = intType ? intType.getWidth() : 0;
2631 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2632 intBitWidth != 64)
2633 return emitOpError("expected LLVM IR integer type");
2636 if (static_cast<unsigned>(getOrdering()) <
2637 static_cast<unsigned>(AtomicOrdering::monotonic))
2638 return emitOpError() << "expected at least '"
2639 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2640 << "' ordering";
2642 return success();
2645 //===----------------------------------------------------------------------===//
2646 // AtomicCmpXchgOp
2647 //===----------------------------------------------------------------------===//
2649 /// Returns an LLVM struct type that contains a value type and a boolean type.
2650 static LLVMStructType getValAndBoolStructType(Type valType) {
2651 auto boolType = IntegerType::get(valType.getContext(), 1);
2652 return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType});
2655 void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
2656 Value ptr, Value cmp, Value val,
2657 AtomicOrdering successOrdering,
2658 AtomicOrdering failureOrdering, StringRef syncscope,
2659 unsigned alignment, bool isWeak, bool isVolatile) {
2660 build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
2661 successOrdering, failureOrdering,
2662 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
2663 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
2664 isVolatile, /*access_groups=*/nullptr,
2665 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
2668 LogicalResult AtomicCmpXchgOp::verify() {
2669 auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
2670 if (!ptrType)
2671 return emitOpError("expected LLVM IR pointer type for operand #0");
2672 auto valType = getVal().getType();
2673 if (!isTypeCompatibleWithAtomicOp(valType,
2674 /*isPointerTypeAllowed=*/true))
2675 return emitOpError("unexpected LLVM IR type");
2676 if (getSuccessOrdering() < AtomicOrdering::monotonic ||
2677 getFailureOrdering() < AtomicOrdering::monotonic)
2678 return emitOpError("ordering must be at least 'monotonic'");
2679 if (getFailureOrdering() == AtomicOrdering::release ||
2680 getFailureOrdering() == AtomicOrdering::acq_rel)
2681 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2682 return success();
2685 //===----------------------------------------------------------------------===//
2686 // FenceOp
2687 //===----------------------------------------------------------------------===//
2689 void FenceOp::build(OpBuilder &builder, OperationState &state,
2690 AtomicOrdering ordering, StringRef syncscope) {
2691 build(builder, state, ordering,
2692 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
2695 LogicalResult FenceOp::verify() {
2696 if (getOrdering() == AtomicOrdering::not_atomic ||
2697 getOrdering() == AtomicOrdering::unordered ||
2698 getOrdering() == AtomicOrdering::monotonic)
2699 return emitOpError("can be given only acquire, release, acq_rel, "
2700 "and seq_cst orderings");
2701 return success();
2704 //===----------------------------------------------------------------------===//
2705 // Verifier for extension ops
2706 //===----------------------------------------------------------------------===//
2708 /// Verifies that the given extension operation operates on consistent scalars
2709 /// or vectors, and that the target width is larger than the input width.
2710 template <class ExtOp>
2711 static LogicalResult verifyExtOp(ExtOp op) {
2712 IntegerType inputType, outputType;
2713 if (isCompatibleVectorType(op.getArg().getType())) {
2714 if (!isCompatibleVectorType(op.getResult().getType()))
2715 return op.emitError(
2716 "input type is a vector but output type is an integer");
2717 if (getVectorNumElements(op.getArg().getType()) !=
2718 getVectorNumElements(op.getResult().getType()))
2719 return op.emitError("input and output vectors are of incompatible shape");
2720 // Because this is a CastOp, the element of vectors is guaranteed to be an
2721 // integer.
2722 inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
2723 outputType =
2724 cast<IntegerType>(getVectorElementType(op.getResult().getType()));
2725 } else {
2726 // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
2727 // an integer.
2728 inputType = cast<IntegerType>(op.getArg().getType());
2729 outputType = dyn_cast<IntegerType>(op.getResult().getType());
2730 if (!outputType)
2731 return op.emitError(
2732 "input type is an integer but output type is a vector");
2735 if (outputType.getWidth() <= inputType.getWidth())
2736 return op.emitError("integer width of the output type is smaller or "
2737 "equal to the integer width of the input type");
2738 return success();
2741 //===----------------------------------------------------------------------===//
2742 // ZExtOp
2743 //===----------------------------------------------------------------------===//
2745 LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
2747 OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
2748 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
2749 if (!arg)
2750 return {};
2752 size_t targetSize = cast<IntegerType>(getType()).getWidth();
2753 return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
2756 //===----------------------------------------------------------------------===//
2757 // SExtOp
2758 //===----------------------------------------------------------------------===//
2760 LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
2762 //===----------------------------------------------------------------------===//
2763 // Folder and verifier for LLVM::BitcastOp
2764 //===----------------------------------------------------------------------===//
2766 OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
2767 // bitcast(x : T0, T0) -> x
2768 if (getArg().getType() == getType())
2769 return getArg();
2770 // bitcast(bitcast(x : T0, T1), T0) -> x
2771 if (auto prev = getArg().getDefiningOp<BitcastOp>())
2772 if (prev.getArg().getType() == getType())
2773 return prev.getArg();
2774 return {};
2777 LogicalResult LLVM::BitcastOp::verify() {
2778 auto resultType = llvm::dyn_cast<LLVMPointerType>(
2779 extractVectorElementType(getResult().getType()));
2780 auto sourceType = llvm::dyn_cast<LLVMPointerType>(
2781 extractVectorElementType(getArg().getType()));
2783 // If one of the types is a pointer (or vector of pointers), then
2784 // both source and result type have to be pointers.
2785 if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
2786 return emitOpError("can only cast pointers from and to pointers");
2788 if (!resultType)
2789 return success();
2791 auto isVector = [](Type type) {
2792 return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
2793 type);
2796 // Due to bitcast requiring both operands to be of the same size, it is not
2797 // possible for only one of the two to be a pointer of vectors.
2798 if (isVector(getResult().getType()) && !isVector(getArg().getType()))
2799 return emitOpError("cannot cast pointer to vector of pointers");
2801 if (!isVector(getResult().getType()) && isVector(getArg().getType()))
2802 return emitOpError("cannot cast vector of pointers to pointer");
2804 // Bitcast cannot cast between pointers of different address spaces.
2805 // 'llvm.addrspacecast' must be used for this purpose instead.
2806 if (resultType.getAddressSpace() != sourceType.getAddressSpace())
2807 return emitOpError("cannot cast pointers of different address spaces, "
2808 "use 'llvm.addrspacecast' instead");
2810 return success();
2813 //===----------------------------------------------------------------------===//
2814 // Folder for LLVM::AddrSpaceCastOp
2815 //===----------------------------------------------------------------------===//
2817 OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
2818 // addrcast(x : T0, T0) -> x
2819 if (getArg().getType() == getType())
2820 return getArg();
2821 // addrcast(addrcast(x : T0, T1), T0) -> x
2822 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
2823 if (prev.getArg().getType() == getType())
2824 return prev.getArg();
2825 return {};
2828 //===----------------------------------------------------------------------===//
2829 // Folder for LLVM::GEPOp
2830 //===----------------------------------------------------------------------===//
2832 OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
2833 GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
2834 adaptor.getDynamicIndices());
2836 // gep %x:T, 0 -> %x
2837 if (getBase().getType() == getType() && indices.size() == 1)
2838 if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
2839 if (integer.getValue().isZero())
2840 return getBase();
2842 // Canonicalize any dynamic indices of constant value to constant indices.
2843 bool changed = false;
2844 SmallVector<GEPArg> gepArgs;
2845 for (auto iter : llvm::enumerate(indices)) {
2846 auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
2847 // Constant indices can only be int32_t, so if integer does not fit we
2848 // are forced to keep it dynamic, despite being a constant.
2849 if (!indices.isDynamicIndex(iter.index()) || !integer ||
2850 !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
2852 PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
2853 if (Value val = llvm::dyn_cast_if_present<Value>(existing))
2854 gepArgs.emplace_back(val);
2855 else
2856 gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());
2858 continue;
2861 changed = true;
2862 gepArgs.emplace_back(integer.getInt());
2864 if (changed) {
2865 SmallVector<int32_t> rawConstantIndices;
2866 SmallVector<Value> dynamicIndices;
2867 destructureIndices(getElemType(), gepArgs, rawConstantIndices,
2868 dynamicIndices);
2870 getDynamicIndicesMutable().assign(dynamicIndices);
2871 setRawConstantIndices(rawConstantIndices);
2872 return Value{*this};
2875 return {};
2878 //===----------------------------------------------------------------------===//
2879 // ShlOp
2880 //===----------------------------------------------------------------------===//
2882 OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
2883 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
2884 if (!rhs)
2885 return {};
2887 if (rhs.getValue().getZExtValue() >=
2888 getLhs().getType().getIntOrFloatBitWidth())
2889 return {}; // TODO: Fold into poison.
2891 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
2892 if (!lhs)
2893 return {};
2895 return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
2898 //===----------------------------------------------------------------------===//
2899 // OrOp
2900 //===----------------------------------------------------------------------===//
2902 OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
2903 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
2904 if (!lhs)
2905 return {};
2907 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
2908 if (!rhs)
2909 return {};
2911 return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
2914 //===----------------------------------------------------------------------===//
2915 // CallIntrinsicOp
2916 //===----------------------------------------------------------------------===//
2918 LogicalResult CallIntrinsicOp::verify() {
2919 if (!getIntrin().starts_with("llvm."))
2920 return emitOpError() << "intrinsic name must start with 'llvm.'";
2921 return success();
2924 //===----------------------------------------------------------------------===//
2925 // OpAsmDialectInterface
2926 //===----------------------------------------------------------------------===//
2928 namespace {
2929 struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
2930 using OpAsmDialectInterface::OpAsmDialectInterface;
2932 AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2933 return TypeSwitch<Attribute, AliasResult>(attr)
2934 .Case<AccessGroupAttr, AliasScopeAttr, AliasScopeDomainAttr,
2935 DIBasicTypeAttr, DICompileUnitAttr, DICompositeTypeAttr,
2936 DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr,
2937 DIGlobalVariableExpressionAttr, DILabelAttr, DILexicalBlockAttr,
2938 DILexicalBlockFileAttr, DILocalVariableAttr, DIModuleAttr,
2939 DINamespaceAttr, DINullTypeAttr, DISubprogramAttr,
2940 DISubroutineTypeAttr, LoopAnnotationAttr, LoopVectorizeAttr,
2941 LoopInterleaveAttr, LoopUnrollAttr, LoopUnrollAndJamAttr,
2942 LoopLICMAttr, LoopDistributeAttr, LoopPipelineAttr,
2943 LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr, TBAATagAttr,
2944 TBAATypeDescriptorAttr>([&](auto attr) {
2945 os << decltype(attr)::getMnemonic();
2946 return AliasResult::OverridableAlias;
2948 .Default([](Attribute) { return AliasResult::NoAlias; });
2951 } // namespace
2953 //===----------------------------------------------------------------------===//
2954 // LinkerOptionsOp
2955 //===----------------------------------------------------------------------===//
2957 LogicalResult LinkerOptionsOp::verify() {
2958 if (mlir::Operation *parentOp = (*this)->getParentOp();
2959 parentOp && !satisfiesLLVMModule(parentOp))
2960 return emitOpError("must appear at the module level");
2961 return success();
2964 //===----------------------------------------------------------------------===//
2965 // LLVMDialect initialization, type parsing, and registration.
2966 //===----------------------------------------------------------------------===//
2968 void LLVMDialect::initialize() {
2969 registerAttributes();
2971 // clang-format off
2972 addTypes<LLVMVoidType,
2973 LLVMPPCFP128Type,
2974 LLVMX86MMXType,
2975 LLVMTokenType,
2976 LLVMLabelType,
2977 LLVMMetadataType,
2978 LLVMStructType>();
2979 // clang-format on
2980 registerTypes();
2982 addOperations<
2983 #define GET_OP_LIST
2984 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2986 #define GET_OP_LIST
2987 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
2988 >();
2990 // Support unknown operations because not all LLVM operations are registered.
2991 allowUnknownOperations();
2992 // clang-format off
2993 addInterfaces<LLVMOpAsmDialectInterface>();
2994 // clang-format on
2995 detail::addLLVMInlinerInterface(this);
2998 #define GET_OP_CLASSES
2999 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
3001 #define GET_OP_CLASSES
3002 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
3004 LogicalResult LLVMDialect::verifyDataLayoutString(
3005 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
3006 llvm::Expected<llvm::DataLayout> maybeDataLayout =
3007 llvm::DataLayout::parse(descr);
3008 if (maybeDataLayout)
3009 return success();
3011 std::string message;
3012 llvm::raw_string_ostream messageStream(message);
3013 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
3014 reportError("invalid data layout descriptor: " + messageStream.str());
3015 return failure();
3018 /// Verify LLVM dialect attributes.
3019 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
3020 NamedAttribute attr) {
3021 // If the data layout attribute is present, it must use the LLVM data layout
3022 // syntax. Try parsing it and report errors in case of failure. Users of this
3023 // attribute may assume it is well-formed and can pass it to the (asserting)
3024 // llvm::DataLayout constructor.
3025 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
3026 return success();
3027 if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
3028 return verifyDataLayoutString(
3029 stringAttr.getValue(),
3030 [op](const Twine &message) { op->emitOpError() << message.str(); });
3032 return op->emitOpError() << "expected '"
3033 << LLVM::LLVMDialect::getDataLayoutAttrName()
3034 << "' to be a string attributes";
3037 LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
3038 Type paramType,
3039 NamedAttribute paramAttr) {
3040 // LLVM attribute may be attached to a result of operation that has not been
3041 // converted to LLVM dialect yet, so the result may have a type with unknown
3042 // representation in LLVM dialect type space. In this case we cannot verify
3043 // whether the attribute may be
3044 bool verifyValueType = isCompatibleType(paramType);
3045 StringAttr name = paramAttr.getName();
3047 auto checkUnitAttrType = [&]() -> LogicalResult {
3048 if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
3049 return op->emitError() << name << " should be a unit attribute";
3050 return success();
3052 auto checkTypeAttrType = [&]() -> LogicalResult {
3053 if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
3054 return op->emitError() << name << " should be a type attribute";
3055 return success();
3057 auto checkIntegerAttrType = [&]() -> LogicalResult {
3058 if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
3059 return op->emitError() << name << " should be an integer attribute";
3060 return success();
3062 auto checkPointerType = [&]() -> LogicalResult {
3063 if (!llvm::isa<LLVMPointerType>(paramType))
3064 return op->emitError()
3065 << name << " attribute attached to non-pointer LLVM type";
3066 return success();
3068 auto checkIntegerType = [&]() -> LogicalResult {
3069 if (!llvm::isa<IntegerType>(paramType))
3070 return op->emitError()
3071 << name << " attribute attached to non-integer LLVM type";
3072 return success();
3074 auto checkPointerTypeMatches = [&]() -> LogicalResult {
3075 if (failed(checkPointerType()))
3076 return failure();
3078 return success();
3081 // Check a unit attribute that is attached to a pointer value.
3082 if (name == LLVMDialect::getNoAliasAttrName() ||
3083 name == LLVMDialect::getReadonlyAttrName() ||
3084 name == LLVMDialect::getReadnoneAttrName() ||
3085 name == LLVMDialect::getWriteOnlyAttrName() ||
3086 name == LLVMDialect::getNestAttrName() ||
3087 name == LLVMDialect::getNoCaptureAttrName() ||
3088 name == LLVMDialect::getNoFreeAttrName() ||
3089 name == LLVMDialect::getNonNullAttrName()) {
3090 if (failed(checkUnitAttrType()))
3091 return failure();
3092 if (verifyValueType && failed(checkPointerType()))
3093 return failure();
3094 return success();
3097 // Check a type attribute that is attached to a pointer value.
3098 if (name == LLVMDialect::getStructRetAttrName() ||
3099 name == LLVMDialect::getByValAttrName() ||
3100 name == LLVMDialect::getByRefAttrName() ||
3101 name == LLVMDialect::getInAllocaAttrName() ||
3102 name == LLVMDialect::getPreallocatedAttrName()) {
3103 if (failed(checkTypeAttrType()))
3104 return failure();
3105 if (verifyValueType && failed(checkPointerTypeMatches()))
3106 return failure();
3107 return success();
3110 // Check a unit attribute that is attached to an integer value.
3111 if (name == LLVMDialect::getSExtAttrName() ||
3112 name == LLVMDialect::getZExtAttrName()) {
3113 if (failed(checkUnitAttrType()))
3114 return failure();
3115 if (verifyValueType && failed(checkIntegerType()))
3116 return failure();
3117 return success();
3120 // Check an integer attribute that is attached to a pointer value.
3121 if (name == LLVMDialect::getAlignAttrName() ||
3122 name == LLVMDialect::getDereferenceableAttrName() ||
3123 name == LLVMDialect::getDereferenceableOrNullAttrName() ||
3124 name == LLVMDialect::getStackAlignmentAttrName()) {
3125 if (failed(checkIntegerAttrType()))
3126 return failure();
3127 if (verifyValueType && failed(checkPointerType()))
3128 return failure();
3129 return success();
3132 // Check a unit attribute that can be attached to arbitrary types.
3133 if (name == LLVMDialect::getNoUndefAttrName() ||
3134 name == LLVMDialect::getInRegAttrName() ||
3135 name == LLVMDialect::getReturnedAttrName())
3136 return checkUnitAttrType();
3138 return success();
3141 /// Verify LLVMIR function argument attributes.
3142 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
3143 unsigned regionIdx,
3144 unsigned argIdx,
3145 NamedAttribute argAttr) {
3146 auto funcOp = dyn_cast<FunctionOpInterface>(op);
3147 if (!funcOp)
3148 return success();
3149 Type argType = funcOp.getArgumentTypes()[argIdx];
3151 return verifyParameterAttribute(op, argType, argAttr);
3154 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
3155 unsigned regionIdx,
3156 unsigned resIdx,
3157 NamedAttribute resAttr) {
3158 auto funcOp = dyn_cast<FunctionOpInterface>(op);
3159 if (!funcOp)
3160 return success();
3161 Type resType = funcOp.getResultTypes()[resIdx];
3163 // Check to see if this function has a void return with a result attribute
3164 // to it. It isn't clear what semantics we would assign to that.
3165 if (llvm::isa<LLVMVoidType>(resType))
3166 return op->emitError() << "cannot attach result attributes to functions "
3167 "with a void return";
3169 // Check to see if this attribute is allowed as a result attribute. Only
3170 // explicitly forbidden LLVM attributes will cause an error.
3171 auto name = resAttr.getName();
3172 if (name == LLVMDialect::getAllocAlignAttrName() ||
3173 name == LLVMDialect::getAllocatedPointerAttrName() ||
3174 name == LLVMDialect::getByValAttrName() ||
3175 name == LLVMDialect::getByRefAttrName() ||
3176 name == LLVMDialect::getInAllocaAttrName() ||
3177 name == LLVMDialect::getNestAttrName() ||
3178 name == LLVMDialect::getNoCaptureAttrName() ||
3179 name == LLVMDialect::getNoFreeAttrName() ||
3180 name == LLVMDialect::getPreallocatedAttrName() ||
3181 name == LLVMDialect::getReadnoneAttrName() ||
3182 name == LLVMDialect::getReadonlyAttrName() ||
3183 name == LLVMDialect::getReturnedAttrName() ||
3184 name == LLVMDialect::getStackAlignmentAttrName() ||
3185 name == LLVMDialect::getStructRetAttrName() ||
3186 name == LLVMDialect::getWriteOnlyAttrName())
3187 return op->emitError() << name << " is not a valid result attribute";
3188 return verifyParameterAttribute(op, resType, resAttr);
3191 Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
3192 Type type, Location loc) {
3193 return LLVM::ConstantOp::materialize(builder, value, type, loc);
3196 //===----------------------------------------------------------------------===//
3197 // Utility functions.
3198 //===----------------------------------------------------------------------===//
3200 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
3201 StringRef name, StringRef value,
3202 LLVM::Linkage linkage) {
3203 assert(builder.getInsertionBlock() &&
3204 builder.getInsertionBlock()->getParentOp() &&
3205 "expected builder to point to a block constrained in an op");
3206 auto module =
3207 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
3208 assert(module && "builder points to an op outside of a module");
3210 // Create the global at the entry of the module.
3211 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
3212 MLIRContext *ctx = builder.getContext();
3213 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
3214 auto global = moduleBuilder.create<LLVM::GlobalOp>(
3215 loc, type, /*isConstant=*/true, linkage, name,
3216 builder.getStringAttr(value), /*alignment=*/0);
3218 LLVMPointerType ptrType = LLVMPointerType::get(ctx);
3219 // Get the pointer to the first character in the global string.
3220 Value globalPtr =
3221 builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr());
3222 return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr,
3223 ArrayRef<GEPArg>{0, 0});
3226 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
3227 return op->hasTrait<OpTrait::SymbolTable>() &&
3228 op->hasTrait<OpTrait::IsIsolatedFromAbove>();