[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / CAPI / IR / IR.cpp
blob24dc8854048532ccd3539a2c480239d1e0a9256e
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
12 #include "mlir/AsmParser/AsmParser.h"
13 #include "mlir/Bytecode/BytecodeWriter.h"
14 #include "mlir/CAPI/IR.h"
15 #include "mlir/CAPI/Support.h"
16 #include "mlir/CAPI/Utils.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/IR/Verifier.h"
28 #include "mlir/IR/Visitors.h"
29 #include "mlir/Interfaces/InferTypeOpInterface.h"
30 #include "mlir/Parser/Parser.h"
31 #include "llvm/ADT/SmallPtrSet.h"
32 #include "llvm/Support/ThreadPool.h"
34 #include <cstddef>
35 #include <memory>
36 #include <optional>
38 using namespace mlir;
40 //===----------------------------------------------------------------------===//
41 // Context API.
42 //===----------------------------------------------------------------------===//
44 MlirContext mlirContextCreate() {
45 auto *context = new MLIRContext;
46 return wrap(context);
49 static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) {
50 return threadingEnabled ? MLIRContext::Threading::ENABLED
51 : MLIRContext::Threading::DISABLED;
54 MlirContext mlirContextCreateWithThreading(bool threadingEnabled) {
55 auto *context = new MLIRContext(toThreadingEnum(threadingEnabled));
56 return wrap(context);
59 MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry,
60 bool threadingEnabled) {
61 auto *context =
62 new MLIRContext(*unwrap(registry), toThreadingEnum(threadingEnabled));
63 return wrap(context);
66 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
67 return unwrap(ctx1) == unwrap(ctx2);
70 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
72 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
73 unwrap(context)->allowUnregisteredDialects(allow);
76 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
77 return unwrap(context)->allowsUnregisteredDialects();
79 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
80 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
83 void mlirContextAppendDialectRegistry(MlirContext ctx,
84 MlirDialectRegistry registry) {
85 unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
88 // TODO: expose a cheaper way than constructing + sorting a vector only to take
89 // its size.
90 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
91 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
94 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
95 MlirStringRef name) {
96 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
99 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
100 return unwrap(context)->isOperationRegistered(unwrap(name));
103 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
104 return unwrap(context)->enableMultithreading(enable);
107 void mlirContextLoadAllAvailableDialects(MlirContext context) {
108 unwrap(context)->loadAllAvailableDialects();
111 void mlirContextSetThreadPool(MlirContext context,
112 MlirLlvmThreadPool threadPool) {
113 unwrap(context)->setThreadPool(*unwrap(threadPool));
116 //===----------------------------------------------------------------------===//
117 // Dialect API.
118 //===----------------------------------------------------------------------===//
120 MlirContext mlirDialectGetContext(MlirDialect dialect) {
121 return wrap(unwrap(dialect)->getContext());
124 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
125 return unwrap(dialect1) == unwrap(dialect2);
128 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
129 return wrap(unwrap(dialect)->getNamespace());
132 //===----------------------------------------------------------------------===//
133 // DialectRegistry API.
134 //===----------------------------------------------------------------------===//
136 MlirDialectRegistry mlirDialectRegistryCreate() {
137 return wrap(new DialectRegistry());
140 void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
141 delete unwrap(registry);
144 //===----------------------------------------------------------------------===//
145 // AsmState API.
146 //===----------------------------------------------------------------------===//
148 MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op,
149 MlirOpPrintingFlags flags) {
150 return wrap(new AsmState(unwrap(op), *unwrap(flags)));
153 static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
154 do {
155 // If we are printing local scope, stop at the first operation that is
156 // isolated from above.
157 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
158 break;
160 // Otherwise, traverse up to the next parent.
161 Operation *parentOp = op->getParentOp();
162 if (!parentOp)
163 break;
164 op = parentOp;
165 } while (true);
166 return op;
169 MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
170 MlirOpPrintingFlags flags) {
171 Operation *op;
172 mlir::Value val = unwrap(value);
173 if (auto result = llvm::dyn_cast<OpResult>(val)) {
174 op = result.getOwner();
175 } else {
176 op = llvm::cast<BlockArgument>(val).getOwner()->getParentOp();
177 if (!op) {
178 emitError(val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
179 return {nullptr};
182 op = findParent(op, unwrap(flags)->shouldUseLocalScope());
183 return wrap(new AsmState(op, *unwrap(flags)));
186 /// Destroys printing flags created with mlirAsmStateCreate.
187 void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); }
189 //===----------------------------------------------------------------------===//
190 // Printing flags API.
191 //===----------------------------------------------------------------------===//
193 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
194 return wrap(new OpPrintingFlags());
197 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
198 delete unwrap(flags);
201 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
202 intptr_t largeElementLimit) {
203 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
206 void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags,
207 intptr_t largeResourceLimit) {
208 unwrap(flags)->elideLargeResourceString(largeResourceLimit);
211 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable,
212 bool prettyForm) {
213 unwrap(flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm);
216 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
217 unwrap(flags)->printGenericOpForm();
220 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
221 unwrap(flags)->useLocalScope();
224 void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
225 unwrap(flags)->assumeVerified();
228 void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) {
229 unwrap(flags)->skipRegions();
231 //===----------------------------------------------------------------------===//
232 // Bytecode printing flags API.
233 //===----------------------------------------------------------------------===//
235 MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() {
236 return wrap(new BytecodeWriterConfig());
239 void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) {
240 delete unwrap(config);
243 void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags,
244 int64_t version) {
245 unwrap(flags)->setDesiredBytecodeVersion(version);
248 //===----------------------------------------------------------------------===//
249 // Location API.
250 //===----------------------------------------------------------------------===//
252 MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
253 return wrap(LocationAttr(unwrap(location)));
256 MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
257 return wrap(Location(llvm::cast<LocationAttr>(unwrap(attribute))));
260 MlirLocation mlirLocationFileLineColGet(MlirContext context,
261 MlirStringRef filename, unsigned line,
262 unsigned col) {
263 return wrap(Location(
264 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col)));
267 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
268 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
271 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
272 MlirLocation const *locations,
273 MlirAttribute metadata) {
274 SmallVector<Location, 4> locs;
275 ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs);
276 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
279 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
280 MlirLocation childLoc) {
281 if (mlirLocationIsNull(childLoc))
282 return wrap(
283 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name)))));
284 return wrap(Location(NameLoc::get(
285 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
288 MlirLocation mlirLocationUnknownGet(MlirContext context) {
289 return wrap(Location(UnknownLoc::get(unwrap(context))));
292 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
293 return unwrap(l1) == unwrap(l2);
296 MlirContext mlirLocationGetContext(MlirLocation location) {
297 return wrap(unwrap(location).getContext());
300 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
301 void *userData) {
302 detail::CallbackOstream stream(callback, userData);
303 unwrap(location).print(stream);
306 //===----------------------------------------------------------------------===//
307 // Module API.
308 //===----------------------------------------------------------------------===//
310 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
311 return wrap(ModuleOp::create(unwrap(location)));
314 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
315 OwningOpRef<ModuleOp> owning =
316 parseSourceString<ModuleOp>(unwrap(module), unwrap(context));
317 if (!owning)
318 return MlirModule{nullptr};
319 return MlirModule{owning.release().getOperation()};
322 MlirContext mlirModuleGetContext(MlirModule module) {
323 return wrap(unwrap(module).getContext());
326 MlirBlock mlirModuleGetBody(MlirModule module) {
327 return wrap(unwrap(module).getBody());
330 void mlirModuleDestroy(MlirModule module) {
331 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
332 // called.
333 OwningOpRef<ModuleOp>(unwrap(module));
336 MlirOperation mlirModuleGetOperation(MlirModule module) {
337 return wrap(unwrap(module).getOperation());
340 MlirModule mlirModuleFromOperation(MlirOperation op) {
341 return wrap(dyn_cast<ModuleOp>(unwrap(op)));
344 //===----------------------------------------------------------------------===//
345 // Operation state API.
346 //===----------------------------------------------------------------------===//
348 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
349 MlirOperationState state;
350 state.name = name;
351 state.location = loc;
352 state.nResults = 0;
353 state.results = nullptr;
354 state.nOperands = 0;
355 state.operands = nullptr;
356 state.nRegions = 0;
357 state.regions = nullptr;
358 state.nSuccessors = 0;
359 state.successors = nullptr;
360 state.nAttributes = 0;
361 state.attributes = nullptr;
362 state.enableResultTypeInference = false;
363 return state;
// Appends the `n` elements pointed to by `elemName` to the realloc-managed
// array `state->elemName`, growing `state->sizeName` accordingly. Expects
// `state` and `n` to be in scope at the expansion site. Appending zero
// elements is a no-op, which avoids realloc(NULL, 0) and memcpy with a null
// source (UB even for size 0). The do/while makes the macro behave as a single
// statement.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n > 0) {                                                               \
      state->elemName = static_cast<type *>(                                   \
          realloc(state->elemName, (state->sizeName + n) * sizeof(type)));     \
      memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));   \
      state->sizeName += n;                                                    \
    }                                                                          \
  } while (false)
372 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
373 MlirType const *results) {
374 APPEND_ELEMS(MlirType, nResults, results);
377 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
378 MlirValue const *operands) {
379 APPEND_ELEMS(MlirValue, nOperands, operands);
381 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
382 MlirRegion const *regions) {
383 APPEND_ELEMS(MlirRegion, nRegions, regions);
385 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
386 MlirBlock const *successors) {
387 APPEND_ELEMS(MlirBlock, nSuccessors, successors);
389 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
390 MlirNamedAttribute const *attributes) {
391 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
394 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
395 state->enableResultTypeInference = true;
398 //===----------------------------------------------------------------------===//
399 // Operation API.
400 //===----------------------------------------------------------------------===//
402 static LogicalResult inferOperationTypes(OperationState &state) {
403 MLIRContext *context = state.getContext();
404 std::optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
405 if (!info) {
406 emitError(state.location)
407 << "type inference was requested for the operation " << state.name
408 << ", but the operation was not registered; ensure that the dialect "
409 "containing the operation is linked into MLIR and registered with "
410 "the context";
411 return failure();
414 auto *inferInterface = info->getInterface<InferTypeOpInterface>();
415 if (!inferInterface) {
416 emitError(state.location)
417 << "type inference was requested for the operation " << state.name
418 << ", but the operation does not support type inference; result "
419 "types must be specified explicitly";
420 return failure();
423 DictionaryAttr attributes = state.attributes.getDictionary(context);
424 OpaqueProperties properties = state.getRawProperties();
426 if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) {
427 auto prop = std::make_unique<char[]>(info->getOpPropertyByteSize());
428 properties = OpaqueProperties(prop.get());
429 if (properties) {
430 auto emitError = [&]() {
431 return mlir::emitError(state.location)
432 << " failed properties conversion while building "
433 << state.name.getStringRef() << " with `" << attributes << "`: ";
435 if (failed(info->setOpPropertiesFromAttribute(state.name, properties,
436 attributes, emitError)))
437 return failure();
439 if (succeeded(inferInterface->inferReturnTypes(
440 context, state.location, state.operands, attributes, properties,
441 state.regions, state.types))) {
442 return success();
444 // Diagnostic emitted by interface.
445 return failure();
448 if (succeeded(inferInterface->inferReturnTypes(
449 context, state.location, state.operands, attributes, properties,
450 state.regions, state.types)))
451 return success();
453 // Diagnostic emitted by interface.
454 return failure();
457 MlirOperation mlirOperationCreate(MlirOperationState *state) {
458 assert(state);
459 OperationState cppState(unwrap(state->location), unwrap(state->name));
460 SmallVector<Type, 4> resultStorage;
461 SmallVector<Value, 8> operandStorage;
462 SmallVector<Block *, 2> successorStorage;
463 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
464 cppState.addOperands(
465 unwrapList(state->nOperands, state->operands, operandStorage));
466 cppState.addSuccessors(
467 unwrapList(state->nSuccessors, state->successors, successorStorage));
469 cppState.attributes.reserve(state->nAttributes);
470 for (intptr_t i = 0; i < state->nAttributes; ++i)
471 cppState.addAttribute(unwrap(state->attributes[i].name),
472 unwrap(state->attributes[i].attribute));
474 for (intptr_t i = 0; i < state->nRegions; ++i)
475 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
477 free(state->results);
478 free(state->operands);
479 free(state->successors);
480 free(state->regions);
481 free(state->attributes);
483 // Infer result types.
484 if (state->enableResultTypeInference) {
485 assert(cppState.types.empty() &&
486 "result type inference enabled and result types provided");
487 if (failed(inferOperationTypes(cppState)))
488 return {nullptr};
491 return wrap(Operation::create(cppState));
494 MlirOperation mlirOperationCreateParse(MlirContext context,
495 MlirStringRef sourceStr,
496 MlirStringRef sourceName) {
498 return wrap(
499 parseSourceString(unwrap(sourceStr), unwrap(context), unwrap(sourceName))
500 .release());
503 MlirOperation mlirOperationClone(MlirOperation op) {
504 return wrap(unwrap(op)->clone());
507 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
509 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
511 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
512 return unwrap(op) == unwrap(other);
515 MlirContext mlirOperationGetContext(MlirOperation op) {
516 return wrap(unwrap(op)->getContext());
519 MlirLocation mlirOperationGetLocation(MlirOperation op) {
520 return wrap(unwrap(op)->getLoc());
523 MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
524 if (auto info = unwrap(op)->getRegisteredInfo())
525 return wrap(info->getTypeID());
526 return {nullptr};
529 MlirIdentifier mlirOperationGetName(MlirOperation op) {
530 return wrap(unwrap(op)->getName().getIdentifier());
533 MlirBlock mlirOperationGetBlock(MlirOperation op) {
534 return wrap(unwrap(op)->getBlock());
537 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
538 return wrap(unwrap(op)->getParentOp());
541 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
542 return static_cast<intptr_t>(unwrap(op)->getNumRegions());
545 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
546 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
549 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
550 Operation *cppOp = unwrap(op);
551 if (cppOp->getNumRegions() == 0)
552 return wrap(static_cast<Region *>(nullptr));
553 return wrap(&cppOp->getRegion(0));
556 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
557 Region *cppRegion = unwrap(region);
558 Operation *parent = cppRegion->getParentOp();
559 intptr_t next = cppRegion->getRegionNumber() + 1;
560 if (parent->getNumRegions() > next)
561 return wrap(&parent->getRegion(next));
562 return wrap(static_cast<Region *>(nullptr));
565 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
566 return wrap(unwrap(op)->getNextNode());
569 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
570 return static_cast<intptr_t>(unwrap(op)->getNumOperands());
573 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
574 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
577 void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
578 MlirValue newValue) {
579 unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
582 void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands,
583 MlirValue const *operands) {
584 SmallVector<Value> ops;
585 unwrap(op)->setOperands(unwrapList(nOperands, operands, ops));
588 intptr_t mlirOperationGetNumResults(MlirOperation op) {
589 return static_cast<intptr_t>(unwrap(op)->getNumResults());
592 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
593 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
596 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
597 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
600 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
601 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
604 MLIR_CAPI_EXPORTED bool
605 mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
606 std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
607 return attr.has_value();
610 MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
611 MlirStringRef name) {
612 std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
613 if (attr.has_value())
614 return wrap(*attr);
615 return {};
618 void mlirOperationSetInherentAttributeByName(MlirOperation op,
619 MlirStringRef name,
620 MlirAttribute attr) {
621 unwrap(op)->setInherentAttr(
622 StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
625 intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
626 return static_cast<intptr_t>(
627 llvm::range_size(unwrap(op)->getDiscardableAttrs()));
630 MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
631 intptr_t pos) {
632 NamedAttribute attr =
633 *std::next(unwrap(op)->getDiscardableAttrs().begin(), pos);
634 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
637 MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
638 MlirStringRef name) {
639 return wrap(unwrap(op)->getDiscardableAttr(unwrap(name)));
642 void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
643 MlirStringRef name,
644 MlirAttribute attr) {
645 unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr));
648 bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
649 MlirStringRef name) {
650 return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
653 void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos,
654 MlirBlock block) {
655 unwrap(op)->setSuccessor(unwrap(block), static_cast<unsigned>(pos));
658 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
659 return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
662 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
663 NamedAttribute attr = unwrap(op)->getAttrs()[pos];
664 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
667 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
668 MlirStringRef name) {
669 return wrap(unwrap(op)->getAttr(unwrap(name)));
672 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
673 MlirAttribute attr) {
674 unwrap(op)->setAttr(unwrap(name), unwrap(attr));
677 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
678 return !!unwrap(op)->removeAttr(unwrap(name));
681 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
682 void *userData) {
683 detail::CallbackOstream stream(callback, userData);
684 unwrap(op)->print(stream);
687 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
688 MlirStringCallback callback, void *userData) {
689 detail::CallbackOstream stream(callback, userData);
690 unwrap(op)->print(stream, *unwrap(flags));
693 void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state,
694 MlirStringCallback callback, void *userData) {
695 detail::CallbackOstream stream(callback, userData);
696 if (state.ptr)
697 unwrap(op)->print(stream, *unwrap(state));
698 unwrap(op)->print(stream);
701 void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
702 void *userData) {
703 detail::CallbackOstream stream(callback, userData);
704 // As no desired version is set, no failure can occur.
705 (void)writeBytecodeToFile(unwrap(op), stream);
708 MlirLogicalResult mlirOperationWriteBytecodeWithConfig(
709 MlirOperation op, MlirBytecodeWriterConfig config,
710 MlirStringCallback callback, void *userData) {
711 detail::CallbackOstream stream(callback, userData);
712 return wrap(writeBytecodeToFile(unwrap(op), stream, *unwrap(config)));
715 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
717 bool mlirOperationVerify(MlirOperation op) {
718 return succeeded(verify(unwrap(op)));
721 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
722 return unwrap(op)->moveAfter(unwrap(other));
725 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
726 return unwrap(op)->moveBefore(unwrap(other));
729 static mlir::WalkResult unwrap(MlirWalkResult result) {
730 switch (result) {
731 case MlirWalkResultAdvance:
732 return mlir::WalkResult::advance();
734 case MlirWalkResultInterrupt:
735 return mlir::WalkResult::interrupt();
737 case MlirWalkResultSkip:
738 return mlir::WalkResult::skip();
740 llvm_unreachable("unknown result in WalkResult::unwrap");
743 void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
744 void *userData, MlirWalkOrder walkOrder) {
745 switch (walkOrder) {
747 case MlirWalkPreOrder:
748 unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
749 [callback, userData](Operation *op) {
750 return unwrap(callback(wrap(op), userData));
752 break;
753 case MlirWalkPostOrder:
754 unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
755 [callback, userData](Operation *op) {
756 return unwrap(callback(wrap(op), userData));
761 //===----------------------------------------------------------------------===//
762 // Region API.
763 //===----------------------------------------------------------------------===//
765 MlirRegion mlirRegionCreate() { return wrap(new Region); }
767 bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
768 return unwrap(region) == unwrap(other);
771 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
772 Region *cppRegion = unwrap(region);
773 if (cppRegion->empty())
774 return wrap(static_cast<Block *>(nullptr));
775 return wrap(&cppRegion->front());
778 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
779 unwrap(region)->push_back(unwrap(block));
782 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
783 MlirBlock block) {
784 auto &blockList = unwrap(region)->getBlocks();
785 blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
788 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
789 MlirBlock block) {
790 Region *cppRegion = unwrap(region);
791 if (mlirBlockIsNull(reference)) {
792 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
793 return;
796 assert(unwrap(reference)->getParent() == unwrap(region) &&
797 "expected reference block to belong to the region");
798 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
799 unwrap(block));
802 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
803 MlirBlock block) {
804 if (mlirBlockIsNull(reference))
805 return mlirRegionAppendOwnedBlock(region, block);
807 assert(unwrap(reference)->getParent() == unwrap(region) &&
808 "expected reference block to belong to the region");
809 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
810 unwrap(block));
813 void mlirRegionDestroy(MlirRegion region) {
814 delete static_cast<Region *>(region.ptr);
817 void mlirRegionTakeBody(MlirRegion target, MlirRegion source) {
818 unwrap(target)->takeBody(*unwrap(source));
821 //===----------------------------------------------------------------------===//
822 // Block API.
823 //===----------------------------------------------------------------------===//
825 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
826 MlirLocation const *locs) {
827 Block *b = new Block;
828 for (intptr_t i = 0; i < nArgs; ++i)
829 b->addArgument(unwrap(args[i]), unwrap(locs[i]));
830 return wrap(b);
833 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
834 return unwrap(block) == unwrap(other);
837 MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
838 return wrap(unwrap(block)->getParentOp());
841 MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
842 return wrap(unwrap(block)->getParent());
845 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
846 return wrap(unwrap(block)->getNextNode());
849 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
850 Block *cppBlock = unwrap(block);
851 if (cppBlock->empty())
852 return wrap(static_cast<Operation *>(nullptr));
853 return wrap(&cppBlock->front());
856 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
857 Block *cppBlock = unwrap(block);
858 if (cppBlock->empty())
859 return wrap(static_cast<Operation *>(nullptr));
860 Operation &back = cppBlock->back();
861 if (!back.hasTrait<OpTrait::IsTerminator>())
862 return wrap(static_cast<Operation *>(nullptr));
863 return wrap(&back);
866 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
867 unwrap(block)->push_back(unwrap(operation));
870 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
871 MlirOperation operation) {
872 auto &opList = unwrap(block)->getOperations();
873 opList.insert(std::next(opList.begin(), pos), unwrap(operation));
876 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
877 MlirOperation reference,
878 MlirOperation operation) {
879 Block *cppBlock = unwrap(block);
880 if (mlirOperationIsNull(reference)) {
881 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
882 return;
885 assert(unwrap(reference)->getBlock() == unwrap(block) &&
886 "expected reference operation to belong to the block");
887 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
888 unwrap(operation));
891 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
892 MlirOperation reference,
893 MlirOperation operation) {
894 if (mlirOperationIsNull(reference))
895 return mlirBlockAppendOwnedOperation(block, operation);
897 assert(unwrap(reference)->getBlock() == unwrap(block) &&
898 "expected reference operation to belong to the block");
899 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
900 unwrap(operation));
903 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
905 void mlirBlockDetach(MlirBlock block) {
906 Block *b = unwrap(block);
907 b->getParent()->getBlocks().remove(b);
910 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
911 return static_cast<intptr_t>(unwrap(block)->getNumArguments());
914 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
915 MlirLocation loc) {
916 return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc)));
919 void mlirBlockEraseArgument(MlirBlock block, unsigned index) {
920 return unwrap(block)->eraseArgument(index);
923 MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type,
924 MlirLocation loc) {
925 return wrap(unwrap(block)->insertArgument(pos, unwrap(type), unwrap(loc)));
928 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
929 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
932 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
933 void *userData) {
934 detail::CallbackOstream stream(callback, userData);
935 unwrap(block)->print(stream);
938 //===----------------------------------------------------------------------===//
939 // Value API.
940 //===----------------------------------------------------------------------===//
942 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
943 return unwrap(value1) == unwrap(value2);
946 bool mlirValueIsABlockArgument(MlirValue value) {
947 return llvm::isa<BlockArgument>(unwrap(value));
950 bool mlirValueIsAOpResult(MlirValue value) {
951 return llvm::isa<OpResult>(unwrap(value));
954 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
955 return wrap(llvm::cast<BlockArgument>(unwrap(value)).getOwner());
958 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
959 return static_cast<intptr_t>(
960 llvm::cast<BlockArgument>(unwrap(value)).getArgNumber());
963 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
964 llvm::cast<BlockArgument>(unwrap(value)).setType(unwrap(type));
967 MlirOperation mlirOpResultGetOwner(MlirValue value) {
968 return wrap(llvm::cast<OpResult>(unwrap(value)).getOwner());
971 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
972 return static_cast<intptr_t>(
973 llvm::cast<OpResult>(unwrap(value)).getResultNumber());
976 MlirType mlirValueGetType(MlirValue value) {
977 return wrap(unwrap(value).getType());
980 void mlirValueSetType(MlirValue value, MlirType type) {
981 unwrap(value).setType(unwrap(type));
984 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
986 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
987 void *userData) {
988 detail::CallbackOstream stream(callback, userData);
989 unwrap(value).print(stream);
992 void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state,
993 MlirStringCallback callback, void *userData) {
994 detail::CallbackOstream stream(callback, userData);
995 Value cppValue = unwrap(value);
996 cppValue.printAsOperand(stream, *unwrap(state));
999 MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
1000 Value cppValue = unwrap(value);
1001 if (cppValue.use_empty())
1002 return {};
1004 OpOperand *opOperand = cppValue.use_begin().getOperand();
1006 return wrap(opOperand);
1009 void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
1010 unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
1013 void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
1014 intptr_t numExceptions,
1015 MlirOperation *exceptions) {
1016 Value oldValueCpp = unwrap(oldValue);
1017 Value newValueCpp = unwrap(newValue);
1019 llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
1020 for (intptr_t i = 0; i < numExceptions; ++i) {
1021 exceptionSet.insert(unwrap(exceptions[i]));
1024 oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
1027 //===----------------------------------------------------------------------===//
1028 // OpOperand API.
1029 //===----------------------------------------------------------------------===//
1031 bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; }
1033 MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) {
1034 return wrap(unwrap(opOperand)->getOwner());
1037 MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) {
1038 return wrap(unwrap(opOperand)->get());
1041 unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) {
1042 return unwrap(opOperand)->getOperandNumber();
1045 MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) {
1046 if (mlirOpOperandIsNull(opOperand))
1047 return {};
1049 OpOperand *nextOpOperand = static_cast<OpOperand *>(
1050 unwrap(opOperand)->getNextOperandUsingThisValue());
1052 if (!nextOpOperand)
1053 return {};
1055 return wrap(nextOpOperand);
1058 //===----------------------------------------------------------------------===//
1059 // Type API.
1060 //===----------------------------------------------------------------------===//
1062 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
1063 return wrap(mlir::parseType(unwrap(type), unwrap(context)));
1066 MlirContext mlirTypeGetContext(MlirType type) {
1067 return wrap(unwrap(type).getContext());
1070 MlirTypeID mlirTypeGetTypeID(MlirType type) {
1071 return wrap(unwrap(type).getTypeID());
1074 MlirDialect mlirTypeGetDialect(MlirType type) {
1075 return wrap(&unwrap(type).getDialect());
1078 bool mlirTypeEqual(MlirType t1, MlirType t2) {
1079 return unwrap(t1) == unwrap(t2);
1082 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
1083 detail::CallbackOstream stream(callback, userData);
1084 unwrap(type).print(stream);
1087 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
1089 //===----------------------------------------------------------------------===//
1090 // Attribute API.
1091 //===----------------------------------------------------------------------===//
1093 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
1094 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
1097 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
1098 return wrap(unwrap(attribute).getContext());
1101 MlirType mlirAttributeGetType(MlirAttribute attribute) {
1102 Attribute attr = unwrap(attribute);
1103 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
1104 return wrap(typedAttr.getType());
1105 return wrap(NoneType::get(attr.getContext()));
1108 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
1109 return wrap(unwrap(attr).getTypeID());
1112 MlirDialect mlirAttributeGetDialect(MlirAttribute attr) {
1113 return wrap(&unwrap(attr).getDialect());
1116 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
1117 return unwrap(a1) == unwrap(a2);
1120 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
1121 void *userData) {
1122 detail::CallbackOstream stream(callback, userData);
1123 unwrap(attr).print(stream);
1126 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
1128 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
1129 MlirAttribute attr) {
1130 return MlirNamedAttribute{name, attr};
1133 //===----------------------------------------------------------------------===//
1134 // Identifier API.
1135 //===----------------------------------------------------------------------===//
1137 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
1138 return wrap(StringAttr::get(unwrap(context), unwrap(str)));
1141 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
1142 return wrap(unwrap(ident).getContext());
1145 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
1146 return unwrap(ident) == unwrap(other);
1149 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
1150 return wrap(unwrap(ident).strref());
1153 //===----------------------------------------------------------------------===//
1154 // Symbol and SymbolTable API.
1155 //===----------------------------------------------------------------------===//
1157 MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
1158 return wrap(SymbolTable::getSymbolAttrName());
1161 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
1162 return wrap(SymbolTable::getVisibilityAttrName());
1165 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
1166 if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
1167 return wrap(static_cast<SymbolTable *>(nullptr));
1168 return wrap(new SymbolTable(unwrap(operation)));
1171 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
1172 delete unwrap(symbolTable);
1175 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
1176 MlirStringRef name) {
1177 return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length)));
1180 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
1181 MlirOperation operation) {
1182 return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
1185 void mlirSymbolTableErase(MlirSymbolTable symbolTable,
1186 MlirOperation operation) {
1187 unwrap(symbolTable)->erase(unwrap(operation));
1190 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
1191 MlirStringRef newSymbol,
1192 MlirOperation from) {
1193 auto *cppFrom = unwrap(from);
1194 auto *context = cppFrom->getContext();
1195 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
1196 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
1197 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
1198 unwrap(from)));
1201 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
1202 void (*callback)(MlirOperation, bool,
1203 void *userData),
1204 void *userData) {
1205 SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible,
1206 [&](Operation *foundOpCpp, bool isVisible) {
1207 callback(wrap(foundOpCpp), isVisible,
1208 userData);