//===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/ThreadPool.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iterator>
#include <memory>
#include <optional>

using namespace mlir;
40 //===----------------------------------------------------------------------===//
42 //===----------------------------------------------------------------------===//
44 MlirContext
mlirContextCreate() {
45 auto *context
= new MLIRContext
;
49 static inline MLIRContext::Threading
toThreadingEnum(bool threadingEnabled
) {
50 return threadingEnabled
? MLIRContext::Threading::ENABLED
51 : MLIRContext::Threading::DISABLED
;
54 MlirContext
mlirContextCreateWithThreading(bool threadingEnabled
) {
55 auto *context
= new MLIRContext(toThreadingEnum(threadingEnabled
));
59 MlirContext
mlirContextCreateWithRegistry(MlirDialectRegistry registry
,
60 bool threadingEnabled
) {
62 new MLIRContext(*unwrap(registry
), toThreadingEnum(threadingEnabled
));
66 bool mlirContextEqual(MlirContext ctx1
, MlirContext ctx2
) {
67 return unwrap(ctx1
) == unwrap(ctx2
);
70 void mlirContextDestroy(MlirContext context
) { delete unwrap(context
); }
72 void mlirContextSetAllowUnregisteredDialects(MlirContext context
, bool allow
) {
73 unwrap(context
)->allowUnregisteredDialects(allow
);
76 bool mlirContextGetAllowUnregisteredDialects(MlirContext context
) {
77 return unwrap(context
)->allowsUnregisteredDialects();
79 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context
) {
80 return static_cast<intptr_t>(unwrap(context
)->getAvailableDialects().size());
83 void mlirContextAppendDialectRegistry(MlirContext ctx
,
84 MlirDialectRegistry registry
) {
85 unwrap(ctx
)->appendDialectRegistry(*unwrap(registry
));
88 // TODO: expose a cheaper way than constructing + sorting a vector only to take
90 intptr_t mlirContextGetNumLoadedDialects(MlirContext context
) {
91 return static_cast<intptr_t>(unwrap(context
)->getLoadedDialects().size());
94 MlirDialect
mlirContextGetOrLoadDialect(MlirContext context
,
96 return wrap(unwrap(context
)->getOrLoadDialect(unwrap(name
)));
99 bool mlirContextIsRegisteredOperation(MlirContext context
, MlirStringRef name
) {
100 return unwrap(context
)->isOperationRegistered(unwrap(name
));
103 void mlirContextEnableMultithreading(MlirContext context
, bool enable
) {
104 return unwrap(context
)->enableMultithreading(enable
);
107 void mlirContextLoadAllAvailableDialects(MlirContext context
) {
108 unwrap(context
)->loadAllAvailableDialects();
111 void mlirContextSetThreadPool(MlirContext context
,
112 MlirLlvmThreadPool threadPool
) {
113 unwrap(context
)->setThreadPool(*unwrap(threadPool
));
116 //===----------------------------------------------------------------------===//
118 //===----------------------------------------------------------------------===//
120 MlirContext
mlirDialectGetContext(MlirDialect dialect
) {
121 return wrap(unwrap(dialect
)->getContext());
124 bool mlirDialectEqual(MlirDialect dialect1
, MlirDialect dialect2
) {
125 return unwrap(dialect1
) == unwrap(dialect2
);
128 MlirStringRef
mlirDialectGetNamespace(MlirDialect dialect
) {
129 return wrap(unwrap(dialect
)->getNamespace());
132 //===----------------------------------------------------------------------===//
133 // DialectRegistry API.
134 //===----------------------------------------------------------------------===//
136 MlirDialectRegistry
mlirDialectRegistryCreate() {
137 return wrap(new DialectRegistry());
140 void mlirDialectRegistryDestroy(MlirDialectRegistry registry
) {
141 delete unwrap(registry
);
144 //===----------------------------------------------------------------------===//
146 //===----------------------------------------------------------------------===//
148 MlirAsmState
mlirAsmStateCreateForOperation(MlirOperation op
,
149 MlirOpPrintingFlags flags
) {
150 return wrap(new AsmState(unwrap(op
), *unwrap(flags
)));
153 static Operation
*findParent(Operation
*op
, bool shouldUseLocalScope
) {
155 // If we are printing local scope, stop at the first operation that is
156 // isolated from above.
157 if (shouldUseLocalScope
&& op
->hasTrait
<OpTrait::IsIsolatedFromAbove
>())
160 // Otherwise, traverse up to the next parent.
161 Operation
*parentOp
= op
->getParentOp();
169 MlirAsmState
mlirAsmStateCreateForValue(MlirValue value
,
170 MlirOpPrintingFlags flags
) {
172 mlir::Value val
= unwrap(value
);
173 if (auto result
= llvm::dyn_cast
<OpResult
>(val
)) {
174 op
= result
.getOwner();
176 op
= llvm::cast
<BlockArgument
>(val
).getOwner()->getParentOp();
178 emitError(val
.getLoc()) << "<<UNKNOWN SSA VALUE>>";
182 op
= findParent(op
, unwrap(flags
)->shouldUseLocalScope());
183 return wrap(new AsmState(op
, *unwrap(flags
)));
186 /// Destroys printing flags created with mlirAsmStateCreate.
187 void mlirAsmStateDestroy(MlirAsmState state
) { delete unwrap(state
); }
189 //===----------------------------------------------------------------------===//
190 // Printing flags API.
191 //===----------------------------------------------------------------------===//
193 MlirOpPrintingFlags
mlirOpPrintingFlagsCreate() {
194 return wrap(new OpPrintingFlags());
197 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags
) {
198 delete unwrap(flags
);
201 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags
,
202 intptr_t largeElementLimit
) {
203 unwrap(flags
)->elideLargeElementsAttrs(largeElementLimit
);
206 void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags
,
207 intptr_t largeResourceLimit
) {
208 unwrap(flags
)->elideLargeResourceString(largeResourceLimit
);
211 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags
, bool enable
,
213 unwrap(flags
)->enableDebugInfo(enable
, /*prettyForm=*/prettyForm
);
216 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags
) {
217 unwrap(flags
)->printGenericOpForm();
220 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags
) {
221 unwrap(flags
)->useLocalScope();
224 void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags
) {
225 unwrap(flags
)->assumeVerified();
228 void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags
) {
229 unwrap(flags
)->skipRegions();
231 //===----------------------------------------------------------------------===//
232 // Bytecode printing flags API.
233 //===----------------------------------------------------------------------===//
235 MlirBytecodeWriterConfig
mlirBytecodeWriterConfigCreate() {
236 return wrap(new BytecodeWriterConfig());
239 void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config
) {
240 delete unwrap(config
);
243 void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags
,
245 unwrap(flags
)->setDesiredBytecodeVersion(version
);
248 //===----------------------------------------------------------------------===//
250 //===----------------------------------------------------------------------===//
252 MlirAttribute
mlirLocationGetAttribute(MlirLocation location
) {
253 return wrap(LocationAttr(unwrap(location
)));
256 MlirLocation
mlirLocationFromAttribute(MlirAttribute attribute
) {
257 return wrap(Location(llvm::cast
<LocationAttr
>(unwrap(attribute
))));
260 MlirLocation
mlirLocationFileLineColGet(MlirContext context
,
261 MlirStringRef filename
, unsigned line
,
263 return wrap(Location(
264 FileLineColLoc::get(unwrap(context
), unwrap(filename
), line
, col
)));
267 MlirLocation
mlirLocationCallSiteGet(MlirLocation callee
, MlirLocation caller
) {
268 return wrap(Location(CallSiteLoc::get(unwrap(callee
), unwrap(caller
))));
271 MlirLocation
mlirLocationFusedGet(MlirContext ctx
, intptr_t nLocations
,
272 MlirLocation
const *locations
,
273 MlirAttribute metadata
) {
274 SmallVector
<Location
, 4> locs
;
275 ArrayRef
<Location
> unwrappedLocs
= unwrapList(nLocations
, locations
, locs
);
276 return wrap(FusedLoc::get(unwrappedLocs
, unwrap(metadata
), unwrap(ctx
)));
279 MlirLocation
mlirLocationNameGet(MlirContext context
, MlirStringRef name
,
280 MlirLocation childLoc
) {
281 if (mlirLocationIsNull(childLoc
))
283 Location(NameLoc::get(StringAttr::get(unwrap(context
), unwrap(name
)))));
284 return wrap(Location(NameLoc::get(
285 StringAttr::get(unwrap(context
), unwrap(name
)), unwrap(childLoc
))));
288 MlirLocation
mlirLocationUnknownGet(MlirContext context
) {
289 return wrap(Location(UnknownLoc::get(unwrap(context
))));
292 bool mlirLocationEqual(MlirLocation l1
, MlirLocation l2
) {
293 return unwrap(l1
) == unwrap(l2
);
296 MlirContext
mlirLocationGetContext(MlirLocation location
) {
297 return wrap(unwrap(location
).getContext());
300 void mlirLocationPrint(MlirLocation location
, MlirStringCallback callback
,
302 detail::CallbackOstream
stream(callback
, userData
);
303 unwrap(location
).print(stream
);
306 //===----------------------------------------------------------------------===//
308 //===----------------------------------------------------------------------===//
310 MlirModule
mlirModuleCreateEmpty(MlirLocation location
) {
311 return wrap(ModuleOp::create(unwrap(location
)));
314 MlirModule
mlirModuleCreateParse(MlirContext context
, MlirStringRef module
) {
315 OwningOpRef
<ModuleOp
> owning
=
316 parseSourceString
<ModuleOp
>(unwrap(module
), unwrap(context
));
318 return MlirModule
{nullptr};
319 return MlirModule
{owning
.release().getOperation()};
322 MlirContext
mlirModuleGetContext(MlirModule module
) {
323 return wrap(unwrap(module
).getContext());
326 MlirBlock
mlirModuleGetBody(MlirModule module
) {
327 return wrap(unwrap(module
).getBody());
330 void mlirModuleDestroy(MlirModule module
) {
331 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
333 OwningOpRef
<ModuleOp
>(unwrap(module
));
336 MlirOperation
mlirModuleGetOperation(MlirModule module
) {
337 return wrap(unwrap(module
).getOperation());
340 MlirModule
mlirModuleFromOperation(MlirOperation op
) {
341 return wrap(dyn_cast
<ModuleOp
>(unwrap(op
)));
344 //===----------------------------------------------------------------------===//
345 // Operation state API.
346 //===----------------------------------------------------------------------===//
348 MlirOperationState
mlirOperationStateGet(MlirStringRef name
, MlirLocation loc
) {
349 MlirOperationState state
;
351 state
.location
= loc
;
353 state
.results
= nullptr;
355 state
.operands
= nullptr;
357 state
.regions
= nullptr;
358 state
.nSuccessors
= 0;
359 state
.successors
= nullptr;
360 state
.nAttributes
= 0;
361 state
.attributes
= nullptr;
362 state
.enableResultTypeInference
= false;
366 #define APPEND_ELEMS(type, sizeName, elemName) \
368 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \
369 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \
370 state->sizeName += n;
372 void mlirOperationStateAddResults(MlirOperationState
*state
, intptr_t n
,
373 MlirType
const *results
) {
374 APPEND_ELEMS(MlirType
, nResults
, results
);
377 void mlirOperationStateAddOperands(MlirOperationState
*state
, intptr_t n
,
378 MlirValue
const *operands
) {
379 APPEND_ELEMS(MlirValue
, nOperands
, operands
);
381 void mlirOperationStateAddOwnedRegions(MlirOperationState
*state
, intptr_t n
,
382 MlirRegion
const *regions
) {
383 APPEND_ELEMS(MlirRegion
, nRegions
, regions
);
385 void mlirOperationStateAddSuccessors(MlirOperationState
*state
, intptr_t n
,
386 MlirBlock
const *successors
) {
387 APPEND_ELEMS(MlirBlock
, nSuccessors
, successors
);
389 void mlirOperationStateAddAttributes(MlirOperationState
*state
, intptr_t n
,
390 MlirNamedAttribute
const *attributes
) {
391 APPEND_ELEMS(MlirNamedAttribute
, nAttributes
, attributes
);
394 void mlirOperationStateEnableResultTypeInference(MlirOperationState
*state
) {
395 state
->enableResultTypeInference
= true;
398 //===----------------------------------------------------------------------===//
400 //===----------------------------------------------------------------------===//
402 static LogicalResult
inferOperationTypes(OperationState
&state
) {
403 MLIRContext
*context
= state
.getContext();
404 std::optional
<RegisteredOperationName
> info
= state
.name
.getRegisteredInfo();
406 emitError(state
.location
)
407 << "type inference was requested for the operation " << state
.name
408 << ", but the operation was not registered; ensure that the dialect "
409 "containing the operation is linked into MLIR and registered with "
414 auto *inferInterface
= info
->getInterface
<InferTypeOpInterface
>();
415 if (!inferInterface
) {
416 emitError(state
.location
)
417 << "type inference was requested for the operation " << state
.name
418 << ", but the operation does not support type inference; result "
419 "types must be specified explicitly";
423 DictionaryAttr attributes
= state
.attributes
.getDictionary(context
);
424 OpaqueProperties properties
= state
.getRawProperties();
426 if (!properties
&& info
->getOpPropertyByteSize() > 0 && !attributes
.empty()) {
427 auto prop
= std::make_unique
<char[]>(info
->getOpPropertyByteSize());
428 properties
= OpaqueProperties(prop
.get());
430 auto emitError
= [&]() {
431 return mlir::emitError(state
.location
)
432 << " failed properties conversion while building "
433 << state
.name
.getStringRef() << " with `" << attributes
<< "`: ";
435 if (failed(info
->setOpPropertiesFromAttribute(state
.name
, properties
,
436 attributes
, emitError
)))
439 if (succeeded(inferInterface
->inferReturnTypes(
440 context
, state
.location
, state
.operands
, attributes
, properties
,
441 state
.regions
, state
.types
))) {
444 // Diagnostic emitted by interface.
448 if (succeeded(inferInterface
->inferReturnTypes(
449 context
, state
.location
, state
.operands
, attributes
, properties
,
450 state
.regions
, state
.types
)))
453 // Diagnostic emitted by interface.
457 MlirOperation
mlirOperationCreate(MlirOperationState
*state
) {
459 OperationState
cppState(unwrap(state
->location
), unwrap(state
->name
));
460 SmallVector
<Type
, 4> resultStorage
;
461 SmallVector
<Value
, 8> operandStorage
;
462 SmallVector
<Block
*, 2> successorStorage
;
463 cppState
.addTypes(unwrapList(state
->nResults
, state
->results
, resultStorage
));
464 cppState
.addOperands(
465 unwrapList(state
->nOperands
, state
->operands
, operandStorage
));
466 cppState
.addSuccessors(
467 unwrapList(state
->nSuccessors
, state
->successors
, successorStorage
));
469 cppState
.attributes
.reserve(state
->nAttributes
);
470 for (intptr_t i
= 0; i
< state
->nAttributes
; ++i
)
471 cppState
.addAttribute(unwrap(state
->attributes
[i
].name
),
472 unwrap(state
->attributes
[i
].attribute
));
474 for (intptr_t i
= 0; i
< state
->nRegions
; ++i
)
475 cppState
.addRegion(std::unique_ptr
<Region
>(unwrap(state
->regions
[i
])));
477 free(state
->results
);
478 free(state
->operands
);
479 free(state
->successors
);
480 free(state
->regions
);
481 free(state
->attributes
);
483 // Infer result types.
484 if (state
->enableResultTypeInference
) {
485 assert(cppState
.types
.empty() &&
486 "result type inference enabled and result types provided");
487 if (failed(inferOperationTypes(cppState
)))
491 return wrap(Operation::create(cppState
));
494 MlirOperation
mlirOperationCreateParse(MlirContext context
,
495 MlirStringRef sourceStr
,
496 MlirStringRef sourceName
) {
499 parseSourceString(unwrap(sourceStr
), unwrap(context
), unwrap(sourceName
))
503 MlirOperation
mlirOperationClone(MlirOperation op
) {
504 return wrap(unwrap(op
)->clone());
507 void mlirOperationDestroy(MlirOperation op
) { unwrap(op
)->erase(); }
509 void mlirOperationRemoveFromParent(MlirOperation op
) { unwrap(op
)->remove(); }
511 bool mlirOperationEqual(MlirOperation op
, MlirOperation other
) {
512 return unwrap(op
) == unwrap(other
);
515 MlirContext
mlirOperationGetContext(MlirOperation op
) {
516 return wrap(unwrap(op
)->getContext());
519 MlirLocation
mlirOperationGetLocation(MlirOperation op
) {
520 return wrap(unwrap(op
)->getLoc());
523 MlirTypeID
mlirOperationGetTypeID(MlirOperation op
) {
524 if (auto info
= unwrap(op
)->getRegisteredInfo())
525 return wrap(info
->getTypeID());
529 MlirIdentifier
mlirOperationGetName(MlirOperation op
) {
530 return wrap(unwrap(op
)->getName().getIdentifier());
533 MlirBlock
mlirOperationGetBlock(MlirOperation op
) {
534 return wrap(unwrap(op
)->getBlock());
537 MlirOperation
mlirOperationGetParentOperation(MlirOperation op
) {
538 return wrap(unwrap(op
)->getParentOp());
541 intptr_t mlirOperationGetNumRegions(MlirOperation op
) {
542 return static_cast<intptr_t>(unwrap(op
)->getNumRegions());
545 MlirRegion
mlirOperationGetRegion(MlirOperation op
, intptr_t pos
) {
546 return wrap(&unwrap(op
)->getRegion(static_cast<unsigned>(pos
)));
549 MlirRegion
mlirOperationGetFirstRegion(MlirOperation op
) {
550 Operation
*cppOp
= unwrap(op
);
551 if (cppOp
->getNumRegions() == 0)
552 return wrap(static_cast<Region
*>(nullptr));
553 return wrap(&cppOp
->getRegion(0));
556 MlirRegion
mlirRegionGetNextInOperation(MlirRegion region
) {
557 Region
*cppRegion
= unwrap(region
);
558 Operation
*parent
= cppRegion
->getParentOp();
559 intptr_t next
= cppRegion
->getRegionNumber() + 1;
560 if (parent
->getNumRegions() > next
)
561 return wrap(&parent
->getRegion(next
));
562 return wrap(static_cast<Region
*>(nullptr));
565 MlirOperation
mlirOperationGetNextInBlock(MlirOperation op
) {
566 return wrap(unwrap(op
)->getNextNode());
569 intptr_t mlirOperationGetNumOperands(MlirOperation op
) {
570 return static_cast<intptr_t>(unwrap(op
)->getNumOperands());
573 MlirValue
mlirOperationGetOperand(MlirOperation op
, intptr_t pos
) {
574 return wrap(unwrap(op
)->getOperand(static_cast<unsigned>(pos
)));
577 void mlirOperationSetOperand(MlirOperation op
, intptr_t pos
,
578 MlirValue newValue
) {
579 unwrap(op
)->setOperand(static_cast<unsigned>(pos
), unwrap(newValue
));
582 void mlirOperationSetOperands(MlirOperation op
, intptr_t nOperands
,
583 MlirValue
const *operands
) {
584 SmallVector
<Value
> ops
;
585 unwrap(op
)->setOperands(unwrapList(nOperands
, operands
, ops
));
588 intptr_t mlirOperationGetNumResults(MlirOperation op
) {
589 return static_cast<intptr_t>(unwrap(op
)->getNumResults());
592 MlirValue
mlirOperationGetResult(MlirOperation op
, intptr_t pos
) {
593 return wrap(unwrap(op
)->getResult(static_cast<unsigned>(pos
)));
596 intptr_t mlirOperationGetNumSuccessors(MlirOperation op
) {
597 return static_cast<intptr_t>(unwrap(op
)->getNumSuccessors());
600 MlirBlock
mlirOperationGetSuccessor(MlirOperation op
, intptr_t pos
) {
601 return wrap(unwrap(op
)->getSuccessor(static_cast<unsigned>(pos
)));
604 MLIR_CAPI_EXPORTED
bool
605 mlirOperationHasInherentAttributeByName(MlirOperation op
, MlirStringRef name
) {
606 std::optional
<Attribute
> attr
= unwrap(op
)->getInherentAttr(unwrap(name
));
607 return attr
.has_value();
610 MlirAttribute
mlirOperationGetInherentAttributeByName(MlirOperation op
,
611 MlirStringRef name
) {
612 std::optional
<Attribute
> attr
= unwrap(op
)->getInherentAttr(unwrap(name
));
613 if (attr
.has_value())
618 void mlirOperationSetInherentAttributeByName(MlirOperation op
,
620 MlirAttribute attr
) {
621 unwrap(op
)->setInherentAttr(
622 StringAttr::get(unwrap(op
)->getContext(), unwrap(name
)), unwrap(attr
));
625 intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op
) {
626 return static_cast<intptr_t>(
627 llvm::range_size(unwrap(op
)->getDiscardableAttrs()));
630 MlirNamedAttribute
mlirOperationGetDiscardableAttribute(MlirOperation op
,
632 NamedAttribute attr
=
633 *std::next(unwrap(op
)->getDiscardableAttrs().begin(), pos
);
634 return MlirNamedAttribute
{wrap(attr
.getName()), wrap(attr
.getValue())};
637 MlirAttribute
mlirOperationGetDiscardableAttributeByName(MlirOperation op
,
638 MlirStringRef name
) {
639 return wrap(unwrap(op
)->getDiscardableAttr(unwrap(name
)));
642 void mlirOperationSetDiscardableAttributeByName(MlirOperation op
,
644 MlirAttribute attr
) {
645 unwrap(op
)->setDiscardableAttr(unwrap(name
), unwrap(attr
));
648 bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op
,
649 MlirStringRef name
) {
650 return !!unwrap(op
)->removeDiscardableAttr(unwrap(name
));
653 void mlirOperationSetSuccessor(MlirOperation op
, intptr_t pos
,
655 unwrap(op
)->setSuccessor(unwrap(block
), static_cast<unsigned>(pos
));
658 intptr_t mlirOperationGetNumAttributes(MlirOperation op
) {
659 return static_cast<intptr_t>(unwrap(op
)->getAttrs().size());
662 MlirNamedAttribute
mlirOperationGetAttribute(MlirOperation op
, intptr_t pos
) {
663 NamedAttribute attr
= unwrap(op
)->getAttrs()[pos
];
664 return MlirNamedAttribute
{wrap(attr
.getName()), wrap(attr
.getValue())};
667 MlirAttribute
mlirOperationGetAttributeByName(MlirOperation op
,
668 MlirStringRef name
) {
669 return wrap(unwrap(op
)->getAttr(unwrap(name
)));
672 void mlirOperationSetAttributeByName(MlirOperation op
, MlirStringRef name
,
673 MlirAttribute attr
) {
674 unwrap(op
)->setAttr(unwrap(name
), unwrap(attr
));
677 bool mlirOperationRemoveAttributeByName(MlirOperation op
, MlirStringRef name
) {
678 return !!unwrap(op
)->removeAttr(unwrap(name
));
681 void mlirOperationPrint(MlirOperation op
, MlirStringCallback callback
,
683 detail::CallbackOstream
stream(callback
, userData
);
684 unwrap(op
)->print(stream
);
687 void mlirOperationPrintWithFlags(MlirOperation op
, MlirOpPrintingFlags flags
,
688 MlirStringCallback callback
, void *userData
) {
689 detail::CallbackOstream
stream(callback
, userData
);
690 unwrap(op
)->print(stream
, *unwrap(flags
));
693 void mlirOperationPrintWithState(MlirOperation op
, MlirAsmState state
,
694 MlirStringCallback callback
, void *userData
) {
695 detail::CallbackOstream
stream(callback
, userData
);
697 unwrap(op
)->print(stream
, *unwrap(state
));
698 unwrap(op
)->print(stream
);
701 void mlirOperationWriteBytecode(MlirOperation op
, MlirStringCallback callback
,
703 detail::CallbackOstream
stream(callback
, userData
);
704 // As no desired version is set, no failure can occur.
705 (void)writeBytecodeToFile(unwrap(op
), stream
);
708 MlirLogicalResult
mlirOperationWriteBytecodeWithConfig(
709 MlirOperation op
, MlirBytecodeWriterConfig config
,
710 MlirStringCallback callback
, void *userData
) {
711 detail::CallbackOstream
stream(callback
, userData
);
712 return wrap(writeBytecodeToFile(unwrap(op
), stream
, *unwrap(config
)));
715 void mlirOperationDump(MlirOperation op
) { return unwrap(op
)->dump(); }
717 bool mlirOperationVerify(MlirOperation op
) {
718 return succeeded(verify(unwrap(op
)));
721 void mlirOperationMoveAfter(MlirOperation op
, MlirOperation other
) {
722 return unwrap(op
)->moveAfter(unwrap(other
));
725 void mlirOperationMoveBefore(MlirOperation op
, MlirOperation other
) {
726 return unwrap(op
)->moveBefore(unwrap(other
));
729 static mlir::WalkResult
unwrap(MlirWalkResult result
) {
731 case MlirWalkResultAdvance
:
732 return mlir::WalkResult::advance();
734 case MlirWalkResultInterrupt
:
735 return mlir::WalkResult::interrupt();
737 case MlirWalkResultSkip
:
738 return mlir::WalkResult::skip();
740 llvm_unreachable("unknown result in WalkResult::unwrap");
743 void mlirOperationWalk(MlirOperation op
, MlirOperationWalkCallback callback
,
744 void *userData
, MlirWalkOrder walkOrder
) {
747 case MlirWalkPreOrder
:
748 unwrap(op
)->walk
<mlir::WalkOrder::PreOrder
>(
749 [callback
, userData
](Operation
*op
) {
750 return unwrap(callback(wrap(op
), userData
));
753 case MlirWalkPostOrder
:
754 unwrap(op
)->walk
<mlir::WalkOrder::PostOrder
>(
755 [callback
, userData
](Operation
*op
) {
756 return unwrap(callback(wrap(op
), userData
));
761 //===----------------------------------------------------------------------===//
763 //===----------------------------------------------------------------------===//
765 MlirRegion
mlirRegionCreate() { return wrap(new Region
); }
767 bool mlirRegionEqual(MlirRegion region
, MlirRegion other
) {
768 return unwrap(region
) == unwrap(other
);
771 MlirBlock
mlirRegionGetFirstBlock(MlirRegion region
) {
772 Region
*cppRegion
= unwrap(region
);
773 if (cppRegion
->empty())
774 return wrap(static_cast<Block
*>(nullptr));
775 return wrap(&cppRegion
->front());
778 void mlirRegionAppendOwnedBlock(MlirRegion region
, MlirBlock block
) {
779 unwrap(region
)->push_back(unwrap(block
));
782 void mlirRegionInsertOwnedBlock(MlirRegion region
, intptr_t pos
,
784 auto &blockList
= unwrap(region
)->getBlocks();
785 blockList
.insert(std::next(blockList
.begin(), pos
), unwrap(block
));
788 void mlirRegionInsertOwnedBlockAfter(MlirRegion region
, MlirBlock reference
,
790 Region
*cppRegion
= unwrap(region
);
791 if (mlirBlockIsNull(reference
)) {
792 cppRegion
->getBlocks().insert(cppRegion
->begin(), unwrap(block
));
796 assert(unwrap(reference
)->getParent() == unwrap(region
) &&
797 "expected reference block to belong to the region");
798 cppRegion
->getBlocks().insertAfter(Region::iterator(unwrap(reference
)),
802 void mlirRegionInsertOwnedBlockBefore(MlirRegion region
, MlirBlock reference
,
804 if (mlirBlockIsNull(reference
))
805 return mlirRegionAppendOwnedBlock(region
, block
);
807 assert(unwrap(reference
)->getParent() == unwrap(region
) &&
808 "expected reference block to belong to the region");
809 unwrap(region
)->getBlocks().insert(Region::iterator(unwrap(reference
)),
813 void mlirRegionDestroy(MlirRegion region
) {
814 delete static_cast<Region
*>(region
.ptr
);
817 void mlirRegionTakeBody(MlirRegion target
, MlirRegion source
) {
818 unwrap(target
)->takeBody(*unwrap(source
));
821 //===----------------------------------------------------------------------===//
823 //===----------------------------------------------------------------------===//
825 MlirBlock
mlirBlockCreate(intptr_t nArgs
, MlirType
const *args
,
826 MlirLocation
const *locs
) {
827 Block
*b
= new Block
;
828 for (intptr_t i
= 0; i
< nArgs
; ++i
)
829 b
->addArgument(unwrap(args
[i
]), unwrap(locs
[i
]));
833 bool mlirBlockEqual(MlirBlock block
, MlirBlock other
) {
834 return unwrap(block
) == unwrap(other
);
837 MlirOperation
mlirBlockGetParentOperation(MlirBlock block
) {
838 return wrap(unwrap(block
)->getParentOp());
841 MlirRegion
mlirBlockGetParentRegion(MlirBlock block
) {
842 return wrap(unwrap(block
)->getParent());
845 MlirBlock
mlirBlockGetNextInRegion(MlirBlock block
) {
846 return wrap(unwrap(block
)->getNextNode());
849 MlirOperation
mlirBlockGetFirstOperation(MlirBlock block
) {
850 Block
*cppBlock
= unwrap(block
);
851 if (cppBlock
->empty())
852 return wrap(static_cast<Operation
*>(nullptr));
853 return wrap(&cppBlock
->front());
856 MlirOperation
mlirBlockGetTerminator(MlirBlock block
) {
857 Block
*cppBlock
= unwrap(block
);
858 if (cppBlock
->empty())
859 return wrap(static_cast<Operation
*>(nullptr));
860 Operation
&back
= cppBlock
->back();
861 if (!back
.hasTrait
<OpTrait::IsTerminator
>())
862 return wrap(static_cast<Operation
*>(nullptr));
866 void mlirBlockAppendOwnedOperation(MlirBlock block
, MlirOperation operation
) {
867 unwrap(block
)->push_back(unwrap(operation
));
870 void mlirBlockInsertOwnedOperation(MlirBlock block
, intptr_t pos
,
871 MlirOperation operation
) {
872 auto &opList
= unwrap(block
)->getOperations();
873 opList
.insert(std::next(opList
.begin(), pos
), unwrap(operation
));
876 void mlirBlockInsertOwnedOperationAfter(MlirBlock block
,
877 MlirOperation reference
,
878 MlirOperation operation
) {
879 Block
*cppBlock
= unwrap(block
);
880 if (mlirOperationIsNull(reference
)) {
881 cppBlock
->getOperations().insert(cppBlock
->begin(), unwrap(operation
));
885 assert(unwrap(reference
)->getBlock() == unwrap(block
) &&
886 "expected reference operation to belong to the block");
887 cppBlock
->getOperations().insertAfter(Block::iterator(unwrap(reference
)),
891 void mlirBlockInsertOwnedOperationBefore(MlirBlock block
,
892 MlirOperation reference
,
893 MlirOperation operation
) {
894 if (mlirOperationIsNull(reference
))
895 return mlirBlockAppendOwnedOperation(block
, operation
);
897 assert(unwrap(reference
)->getBlock() == unwrap(block
) &&
898 "expected reference operation to belong to the block");
899 unwrap(block
)->getOperations().insert(Block::iterator(unwrap(reference
)),
903 void mlirBlockDestroy(MlirBlock block
) { delete unwrap(block
); }
905 void mlirBlockDetach(MlirBlock block
) {
906 Block
*b
= unwrap(block
);
907 b
->getParent()->getBlocks().remove(b
);
910 intptr_t mlirBlockGetNumArguments(MlirBlock block
) {
911 return static_cast<intptr_t>(unwrap(block
)->getNumArguments());
914 MlirValue
mlirBlockAddArgument(MlirBlock block
, MlirType type
,
916 return wrap(unwrap(block
)->addArgument(unwrap(type
), unwrap(loc
)));
919 void mlirBlockEraseArgument(MlirBlock block
, unsigned index
) {
920 return unwrap(block
)->eraseArgument(index
);
923 MlirValue
mlirBlockInsertArgument(MlirBlock block
, intptr_t pos
, MlirType type
,
925 return wrap(unwrap(block
)->insertArgument(pos
, unwrap(type
), unwrap(loc
)));
928 MlirValue
mlirBlockGetArgument(MlirBlock block
, intptr_t pos
) {
929 return wrap(unwrap(block
)->getArgument(static_cast<unsigned>(pos
)));
932 void mlirBlockPrint(MlirBlock block
, MlirStringCallback callback
,
934 detail::CallbackOstream
stream(callback
, userData
);
935 unwrap(block
)->print(stream
);
//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
942 bool mlirValueEqual(MlirValue value1
, MlirValue value2
) {
943 return unwrap(value1
) == unwrap(value2
);
946 bool mlirValueIsABlockArgument(MlirValue value
) {
947 return llvm::isa
<BlockArgument
>(unwrap(value
));
950 bool mlirValueIsAOpResult(MlirValue value
) {
951 return llvm::isa
<OpResult
>(unwrap(value
));
954 MlirBlock
mlirBlockArgumentGetOwner(MlirValue value
) {
955 return wrap(llvm::cast
<BlockArgument
>(unwrap(value
)).getOwner());
958 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value
) {
959 return static_cast<intptr_t>(
960 llvm::cast
<BlockArgument
>(unwrap(value
)).getArgNumber());
963 void mlirBlockArgumentSetType(MlirValue value
, MlirType type
) {
964 llvm::cast
<BlockArgument
>(unwrap(value
)).setType(unwrap(type
));
967 MlirOperation
mlirOpResultGetOwner(MlirValue value
) {
968 return wrap(llvm::cast
<OpResult
>(unwrap(value
)).getOwner());
971 intptr_t mlirOpResultGetResultNumber(MlirValue value
) {
972 return static_cast<intptr_t>(
973 llvm::cast
<OpResult
>(unwrap(value
)).getResultNumber());
976 MlirType
mlirValueGetType(MlirValue value
) {
977 return wrap(unwrap(value
).getType());
980 void mlirValueSetType(MlirValue value
, MlirType type
) {
981 unwrap(value
).setType(unwrap(type
));
984 void mlirValueDump(MlirValue value
) { unwrap(value
).dump(); }
986 void mlirValuePrint(MlirValue value
, MlirStringCallback callback
,
988 detail::CallbackOstream
stream(callback
, userData
);
989 unwrap(value
).print(stream
);
992 void mlirValuePrintAsOperand(MlirValue value
, MlirAsmState state
,
993 MlirStringCallback callback
, void *userData
) {
994 detail::CallbackOstream
stream(callback
, userData
);
995 Value cppValue
= unwrap(value
);
996 cppValue
.printAsOperand(stream
, *unwrap(state
));
999 MlirOpOperand
mlirValueGetFirstUse(MlirValue value
) {
1000 Value cppValue
= unwrap(value
);
1001 if (cppValue
.use_empty())
1004 OpOperand
*opOperand
= cppValue
.use_begin().getOperand();
1006 return wrap(opOperand
);
1009 void mlirValueReplaceAllUsesOfWith(MlirValue oldValue
, MlirValue newValue
) {
1010 unwrap(oldValue
).replaceAllUsesWith(unwrap(newValue
));
1013 void mlirValueReplaceAllUsesExcept(MlirValue oldValue
, MlirValue newValue
,
1014 intptr_t numExceptions
,
1015 MlirOperation
*exceptions
) {
1016 Value oldValueCpp
= unwrap(oldValue
);
1017 Value newValueCpp
= unwrap(newValue
);
1019 llvm::SmallPtrSet
<mlir::Operation
*, 4> exceptionSet
;
1020 for (intptr_t i
= 0; i
< numExceptions
; ++i
) {
1021 exceptionSet
.insert(unwrap(exceptions
[i
]));
1024 oldValueCpp
.replaceAllUsesExcept(newValueCpp
, exceptionSet
);
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
1031 bool mlirOpOperandIsNull(MlirOpOperand opOperand
) { return !opOperand
.ptr
; }
1033 MlirOperation
mlirOpOperandGetOwner(MlirOpOperand opOperand
) {
1034 return wrap(unwrap(opOperand
)->getOwner());
1037 MlirValue
mlirOpOperandGetValue(MlirOpOperand opOperand
) {
1038 return wrap(unwrap(opOperand
)->get());
1041 unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand
) {
1042 return unwrap(opOperand
)->getOperandNumber();
1045 MlirOpOperand
mlirOpOperandGetNextUse(MlirOpOperand opOperand
) {
1046 if (mlirOpOperandIsNull(opOperand
))
1049 OpOperand
*nextOpOperand
= static_cast<OpOperand
*>(
1050 unwrap(opOperand
)->getNextOperandUsingThisValue());
1055 return wrap(nextOpOperand
);
//===----------------------------------------------------------------------===//
// Type API.
//===----------------------------------------------------------------------===//
1062 MlirType
mlirTypeParseGet(MlirContext context
, MlirStringRef type
) {
1063 return wrap(mlir::parseType(unwrap(type
), unwrap(context
)));
1066 MlirContext
mlirTypeGetContext(MlirType type
) {
1067 return wrap(unwrap(type
).getContext());
1070 MlirTypeID
mlirTypeGetTypeID(MlirType type
) {
1071 return wrap(unwrap(type
).getTypeID());
1074 MlirDialect
mlirTypeGetDialect(MlirType type
) {
1075 return wrap(&unwrap(type
).getDialect());
1078 bool mlirTypeEqual(MlirType t1
, MlirType t2
) {
1079 return unwrap(t1
) == unwrap(t2
);
1082 void mlirTypePrint(MlirType type
, MlirStringCallback callback
, void *userData
) {
1083 detail::CallbackOstream
stream(callback
, userData
);
1084 unwrap(type
).print(stream
);
1087 void mlirTypeDump(MlirType type
) { unwrap(type
).dump(); }
//===----------------------------------------------------------------------===//
// Attribute API.
//===----------------------------------------------------------------------===//
1093 MlirAttribute
mlirAttributeParseGet(MlirContext context
, MlirStringRef attr
) {
1094 return wrap(mlir::parseAttribute(unwrap(attr
), unwrap(context
)));
1097 MlirContext
mlirAttributeGetContext(MlirAttribute attribute
) {
1098 return wrap(unwrap(attribute
).getContext());
1101 MlirType
mlirAttributeGetType(MlirAttribute attribute
) {
1102 Attribute attr
= unwrap(attribute
);
1103 if (auto typedAttr
= llvm::dyn_cast
<TypedAttr
>(attr
))
1104 return wrap(typedAttr
.getType());
1105 return wrap(NoneType::get(attr
.getContext()));
1108 MlirTypeID
mlirAttributeGetTypeID(MlirAttribute attr
) {
1109 return wrap(unwrap(attr
).getTypeID());
1112 MlirDialect
mlirAttributeGetDialect(MlirAttribute attr
) {
1113 return wrap(&unwrap(attr
).getDialect());
1116 bool mlirAttributeEqual(MlirAttribute a1
, MlirAttribute a2
) {
1117 return unwrap(a1
) == unwrap(a2
);
1120 void mlirAttributePrint(MlirAttribute attr
, MlirStringCallback callback
,
1122 detail::CallbackOstream
stream(callback
, userData
);
1123 unwrap(attr
).print(stream
);
1126 void mlirAttributeDump(MlirAttribute attr
) { unwrap(attr
).dump(); }
1128 MlirNamedAttribute
mlirNamedAttributeGet(MlirIdentifier name
,
1129 MlirAttribute attr
) {
1130 return MlirNamedAttribute
{name
, attr
};
//===----------------------------------------------------------------------===//
// Identifier API.
//===----------------------------------------------------------------------===//
1137 MlirIdentifier
mlirIdentifierGet(MlirContext context
, MlirStringRef str
) {
1138 return wrap(StringAttr::get(unwrap(context
), unwrap(str
)));
1141 MlirContext
mlirIdentifierGetContext(MlirIdentifier ident
) {
1142 return wrap(unwrap(ident
).getContext());
1145 bool mlirIdentifierEqual(MlirIdentifier ident
, MlirIdentifier other
) {
1146 return unwrap(ident
) == unwrap(other
);
1149 MlirStringRef
mlirIdentifierStr(MlirIdentifier ident
) {
1150 return wrap(unwrap(ident
).strref());
//===----------------------------------------------------------------------===//
// Symbol and SymbolTable API.
//===----------------------------------------------------------------------===//
1157 MlirStringRef
mlirSymbolTableGetSymbolAttributeName() {
1158 return wrap(SymbolTable::getSymbolAttrName());
1161 MlirStringRef
mlirSymbolTableGetVisibilityAttributeName() {
1162 return wrap(SymbolTable::getVisibilityAttrName());
1165 MlirSymbolTable
mlirSymbolTableCreate(MlirOperation operation
) {
1166 if (!unwrap(operation
)->hasTrait
<OpTrait::SymbolTable
>())
1167 return wrap(static_cast<SymbolTable
*>(nullptr));
1168 return wrap(new SymbolTable(unwrap(operation
)));
1171 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable
) {
1172 delete unwrap(symbolTable
);
1175 MlirOperation
mlirSymbolTableLookup(MlirSymbolTable symbolTable
,
1176 MlirStringRef name
) {
1177 return wrap(unwrap(symbolTable
)->lookup(StringRef(name
.data
, name
.length
)));
1180 MlirAttribute
mlirSymbolTableInsert(MlirSymbolTable symbolTable
,
1181 MlirOperation operation
) {
1182 return wrap((Attribute
)unwrap(symbolTable
)->insert(unwrap(operation
)));
1185 void mlirSymbolTableErase(MlirSymbolTable symbolTable
,
1186 MlirOperation operation
) {
1187 unwrap(symbolTable
)->erase(unwrap(operation
));
1190 MlirLogicalResult
mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol
,
1191 MlirStringRef newSymbol
,
1192 MlirOperation from
) {
1193 auto *cppFrom
= unwrap(from
);
1194 auto *context
= cppFrom
->getContext();
1195 auto oldSymbolAttr
= StringAttr::get(context
, unwrap(oldSymbol
));
1196 auto newSymbolAttr
= StringAttr::get(context
, unwrap(newSymbol
));
1197 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr
, newSymbolAttr
,
1201 void mlirSymbolTableWalkSymbolTables(MlirOperation from
, bool allSymUsesVisible
,
1202 void (*callback
)(MlirOperation
, bool,
1205 SymbolTable::walkSymbolTables(unwrap(from
), allSymUsesVisible
,
1206 [&](Operation
*foundOpCpp
, bool isVisible
) {
1207 callback(wrap(foundOpCpp
), isVisible
,