1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/IntegerSet.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include <cassert>

using namespace mlir;
22 MlirAttribute
mlirAttributeGetNull() { return {nullptr}; }
24 //===----------------------------------------------------------------------===//
25 // Location attribute.
26 //===----------------------------------------------------------------------===//
28 bool mlirAttributeIsALocation(MlirAttribute attr
) {
29 return llvm::isa
<LocationAttr
>(unwrap(attr
));
32 //===----------------------------------------------------------------------===//
33 // Affine map attribute.
34 //===----------------------------------------------------------------------===//
36 bool mlirAttributeIsAAffineMap(MlirAttribute attr
) {
37 return llvm::isa
<AffineMapAttr
>(unwrap(attr
));
40 MlirAttribute
mlirAffineMapAttrGet(MlirAffineMap map
) {
41 return wrap(AffineMapAttr::get(unwrap(map
)));
44 MlirAffineMap
mlirAffineMapAttrGetValue(MlirAttribute attr
) {
45 return wrap(llvm::cast
<AffineMapAttr
>(unwrap(attr
)).getValue());
48 MlirTypeID
mlirAffineMapAttrGetTypeID(void) {
49 return wrap(AffineMapAttr::getTypeID());
52 //===----------------------------------------------------------------------===//
54 //===----------------------------------------------------------------------===//
56 bool mlirAttributeIsAArray(MlirAttribute attr
) {
57 return llvm::isa
<ArrayAttr
>(unwrap(attr
));
60 MlirAttribute
mlirArrayAttrGet(MlirContext ctx
, intptr_t numElements
,
61 MlirAttribute
const *elements
) {
62 SmallVector
<Attribute
, 8> attrs
;
64 ArrayAttr::get(unwrap(ctx
), unwrapList(static_cast<size_t>(numElements
),
68 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr
) {
69 return static_cast<intptr_t>(llvm::cast
<ArrayAttr
>(unwrap(attr
)).size());
72 MlirAttribute
mlirArrayAttrGetElement(MlirAttribute attr
, intptr_t pos
) {
73 return wrap(llvm::cast
<ArrayAttr
>(unwrap(attr
)).getValue()[pos
]);
76 MlirTypeID
mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); }
78 //===----------------------------------------------------------------------===//
79 // Dictionary attribute.
80 //===----------------------------------------------------------------------===//
82 bool mlirAttributeIsADictionary(MlirAttribute attr
) {
83 return llvm::isa
<DictionaryAttr
>(unwrap(attr
));
86 MlirAttribute
mlirDictionaryAttrGet(MlirContext ctx
, intptr_t numElements
,
87 MlirNamedAttribute
const *elements
) {
88 SmallVector
<NamedAttribute
, 8> attributes
;
89 attributes
.reserve(numElements
);
90 for (intptr_t i
= 0; i
< numElements
; ++i
)
91 attributes
.emplace_back(unwrap(elements
[i
].name
),
92 unwrap(elements
[i
].attribute
));
93 return wrap(DictionaryAttr::get(unwrap(ctx
), attributes
));
96 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr
) {
97 return static_cast<intptr_t>(llvm::cast
<DictionaryAttr
>(unwrap(attr
)).size());
100 MlirNamedAttribute
mlirDictionaryAttrGetElement(MlirAttribute attr
,
102 NamedAttribute attribute
=
103 llvm::cast
<DictionaryAttr
>(unwrap(attr
)).getValue()[pos
];
104 return {wrap(attribute
.getName()), wrap(attribute
.getValue())};
107 MlirAttribute
mlirDictionaryAttrGetElementByName(MlirAttribute attr
,
108 MlirStringRef name
) {
109 return wrap(llvm::cast
<DictionaryAttr
>(unwrap(attr
)).get(unwrap(name
)));
112 MlirTypeID
mlirDictionaryAttrGetTypeID(void) {
113 return wrap(DictionaryAttr::getTypeID());
116 //===----------------------------------------------------------------------===//
117 // Floating point attribute.
118 //===----------------------------------------------------------------------===//
120 bool mlirAttributeIsAFloat(MlirAttribute attr
) {
121 return llvm::isa
<FloatAttr
>(unwrap(attr
));
124 MlirAttribute
mlirFloatAttrDoubleGet(MlirContext ctx
, MlirType type
,
126 return wrap(FloatAttr::get(unwrap(type
), value
));
129 MlirAttribute
mlirFloatAttrDoubleGetChecked(MlirLocation loc
, MlirType type
,
131 return wrap(FloatAttr::getChecked(unwrap(loc
), unwrap(type
), value
));
134 double mlirFloatAttrGetValueDouble(MlirAttribute attr
) {
135 return llvm::cast
<FloatAttr
>(unwrap(attr
)).getValueAsDouble();
138 MlirTypeID
mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); }
140 //===----------------------------------------------------------------------===//
141 // Integer attribute.
142 //===----------------------------------------------------------------------===//
144 bool mlirAttributeIsAInteger(MlirAttribute attr
) {
145 return llvm::isa
<IntegerAttr
>(unwrap(attr
));
148 MlirAttribute
mlirIntegerAttrGet(MlirType type
, int64_t value
) {
149 return wrap(IntegerAttr::get(unwrap(type
), value
));
152 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr
) {
153 return llvm::cast
<IntegerAttr
>(unwrap(attr
)).getInt();
156 int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr
) {
157 return llvm::cast
<IntegerAttr
>(unwrap(attr
)).getSInt();
160 uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr
) {
161 return llvm::cast
<IntegerAttr
>(unwrap(attr
)).getUInt();
164 MlirTypeID
mlirIntegerAttrGetTypeID(void) {
165 return wrap(IntegerAttr::getTypeID());
168 //===----------------------------------------------------------------------===//
170 //===----------------------------------------------------------------------===//
172 bool mlirAttributeIsABool(MlirAttribute attr
) {
173 return llvm::isa
<BoolAttr
>(unwrap(attr
));
176 MlirAttribute
mlirBoolAttrGet(MlirContext ctx
, int value
) {
177 return wrap(BoolAttr::get(unwrap(ctx
), value
));
180 bool mlirBoolAttrGetValue(MlirAttribute attr
) {
181 return llvm::cast
<BoolAttr
>(unwrap(attr
)).getValue();
184 //===----------------------------------------------------------------------===//
185 // Integer set attribute.
186 //===----------------------------------------------------------------------===//
188 bool mlirAttributeIsAIntegerSet(MlirAttribute attr
) {
189 return llvm::isa
<IntegerSetAttr
>(unwrap(attr
));
192 MlirTypeID
mlirIntegerSetAttrGetTypeID(void) {
193 return wrap(IntegerSetAttr::getTypeID());
196 MlirAttribute
mlirIntegerSetAttrGet(MlirIntegerSet set
) {
197 return wrap(IntegerSetAttr::get(unwrap(set
)));
200 MlirIntegerSet
mlirIntegerSetAttrGetValue(MlirAttribute attr
) {
201 return wrap(llvm::cast
<IntegerSetAttr
>(unwrap(attr
)).getValue());
204 //===----------------------------------------------------------------------===//
206 //===----------------------------------------------------------------------===//
208 bool mlirAttributeIsAOpaque(MlirAttribute attr
) {
209 return llvm::isa
<OpaqueAttr
>(unwrap(attr
));
212 MlirAttribute
mlirOpaqueAttrGet(MlirContext ctx
, MlirStringRef dialectNamespace
,
213 intptr_t dataLength
, const char *data
,
216 OpaqueAttr::get(StringAttr::get(unwrap(ctx
), unwrap(dialectNamespace
)),
217 StringRef(data
, dataLength
), unwrap(type
)));
220 MlirStringRef
mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr
) {
222 llvm::cast
<OpaqueAttr
>(unwrap(attr
)).getDialectNamespace().strref());
225 MlirStringRef
mlirOpaqueAttrGetData(MlirAttribute attr
) {
226 return wrap(llvm::cast
<OpaqueAttr
>(unwrap(attr
)).getAttrData());
229 MlirTypeID
mlirOpaqueAttrGetTypeID(void) {
230 return wrap(OpaqueAttr::getTypeID());
233 //===----------------------------------------------------------------------===//
235 //===----------------------------------------------------------------------===//
237 bool mlirAttributeIsAString(MlirAttribute attr
) {
238 return llvm::isa
<StringAttr
>(unwrap(attr
));
241 MlirAttribute
mlirStringAttrGet(MlirContext ctx
, MlirStringRef str
) {
242 return wrap((Attribute
)StringAttr::get(unwrap(ctx
), unwrap(str
)));
245 MlirAttribute
mlirStringAttrTypedGet(MlirType type
, MlirStringRef str
) {
246 return wrap((Attribute
)StringAttr::get(unwrap(str
), unwrap(type
)));
249 MlirStringRef
mlirStringAttrGetValue(MlirAttribute attr
) {
250 return wrap(llvm::cast
<StringAttr
>(unwrap(attr
)).getValue());
253 MlirTypeID
mlirStringAttrGetTypeID(void) {
254 return wrap(StringAttr::getTypeID());
257 //===----------------------------------------------------------------------===//
258 // SymbolRef attribute.
259 //===----------------------------------------------------------------------===//
261 bool mlirAttributeIsASymbolRef(MlirAttribute attr
) {
262 return llvm::isa
<SymbolRefAttr
>(unwrap(attr
));
265 MlirAttribute
mlirSymbolRefAttrGet(MlirContext ctx
, MlirStringRef symbol
,
266 intptr_t numReferences
,
267 MlirAttribute
const *references
) {
268 SmallVector
<FlatSymbolRefAttr
, 4> refs
;
269 refs
.reserve(numReferences
);
270 for (intptr_t i
= 0; i
< numReferences
; ++i
)
271 refs
.push_back(llvm::cast
<FlatSymbolRefAttr
>(unwrap(references
[i
])));
272 auto symbolAttr
= StringAttr::get(unwrap(ctx
), unwrap(symbol
));
273 return wrap(SymbolRefAttr::get(symbolAttr
, refs
));
276 MlirStringRef
mlirSymbolRefAttrGetRootReference(MlirAttribute attr
) {
278 llvm::cast
<SymbolRefAttr
>(unwrap(attr
)).getRootReference().getValue());
281 MlirStringRef
mlirSymbolRefAttrGetLeafReference(MlirAttribute attr
) {
283 llvm::cast
<SymbolRefAttr
>(unwrap(attr
)).getLeafReference().getValue());
286 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr
) {
287 return static_cast<intptr_t>(
288 llvm::cast
<SymbolRefAttr
>(unwrap(attr
)).getNestedReferences().size());
291 MlirAttribute
mlirSymbolRefAttrGetNestedReference(MlirAttribute attr
,
294 llvm::cast
<SymbolRefAttr
>(unwrap(attr
)).getNestedReferences()[pos
]);
297 MlirTypeID
mlirSymbolRefAttrGetTypeID(void) {
298 return wrap(SymbolRefAttr::getTypeID());
301 MlirAttribute
mlirDisctinctAttrCreate(MlirAttribute referencedAttr
) {
302 return wrap(mlir::DistinctAttr::create(unwrap(referencedAttr
)));
305 //===----------------------------------------------------------------------===//
306 // Flat SymbolRef attribute.
307 //===----------------------------------------------------------------------===//
309 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr
) {
310 return llvm::isa
<FlatSymbolRefAttr
>(unwrap(attr
));
313 MlirAttribute
mlirFlatSymbolRefAttrGet(MlirContext ctx
, MlirStringRef symbol
) {
314 return wrap(FlatSymbolRefAttr::get(unwrap(ctx
), unwrap(symbol
)));
317 MlirStringRef
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr
) {
318 return wrap(llvm::cast
<FlatSymbolRefAttr
>(unwrap(attr
)).getValue());
321 //===----------------------------------------------------------------------===//
323 //===----------------------------------------------------------------------===//
325 bool mlirAttributeIsAType(MlirAttribute attr
) {
326 return llvm::isa
<TypeAttr
>(unwrap(attr
));
329 MlirAttribute
mlirTypeAttrGet(MlirType type
) {
330 return wrap(TypeAttr::get(unwrap(type
)));
333 MlirType
mlirTypeAttrGetValue(MlirAttribute attr
) {
334 return wrap(llvm::cast
<TypeAttr
>(unwrap(attr
)).getValue());
337 MlirTypeID
mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); }
339 //===----------------------------------------------------------------------===//
341 //===----------------------------------------------------------------------===//
343 bool mlirAttributeIsAUnit(MlirAttribute attr
) {
344 return llvm::isa
<UnitAttr
>(unwrap(attr
));
347 MlirAttribute
mlirUnitAttrGet(MlirContext ctx
) {
348 return wrap(UnitAttr::get(unwrap(ctx
)));
351 MlirTypeID
mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); }
353 //===----------------------------------------------------------------------===//
354 // Elements attributes.
355 //===----------------------------------------------------------------------===//
357 bool mlirAttributeIsAElements(MlirAttribute attr
) {
358 return llvm::isa
<ElementsAttr
>(unwrap(attr
));
361 MlirAttribute
mlirElementsAttrGetValue(MlirAttribute attr
, intptr_t rank
,
363 return wrap(llvm::cast
<ElementsAttr
>(unwrap(attr
))
364 .getValues
<Attribute
>()[llvm::ArrayRef(idxs
, rank
)]);
367 bool mlirElementsAttrIsValidIndex(MlirAttribute attr
, intptr_t rank
,
369 return llvm::cast
<ElementsAttr
>(unwrap(attr
))
370 .isValidIndex(llvm::ArrayRef(idxs
, rank
));
373 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr
) {
374 return llvm::cast
<ElementsAttr
>(unwrap(attr
)).getNumElements();
377 //===----------------------------------------------------------------------===//
378 // Dense array attribute.
379 //===----------------------------------------------------------------------===//
381 MlirTypeID
mlirDenseArrayAttrGetTypeID() {
382 return wrap(DenseArrayAttr::getTypeID());
385 //===----------------------------------------------------------------------===//
387 //===----------------------------------------------------------------------===//
389 bool mlirAttributeIsADenseBoolArray(MlirAttribute attr
) {
390 return llvm::isa
<DenseBoolArrayAttr
>(unwrap(attr
));
392 bool mlirAttributeIsADenseI8Array(MlirAttribute attr
) {
393 return llvm::isa
<DenseI8ArrayAttr
>(unwrap(attr
));
395 bool mlirAttributeIsADenseI16Array(MlirAttribute attr
) {
396 return llvm::isa
<DenseI16ArrayAttr
>(unwrap(attr
));
398 bool mlirAttributeIsADenseI32Array(MlirAttribute attr
) {
399 return llvm::isa
<DenseI32ArrayAttr
>(unwrap(attr
));
401 bool mlirAttributeIsADenseI64Array(MlirAttribute attr
) {
402 return llvm::isa
<DenseI64ArrayAttr
>(unwrap(attr
));
404 bool mlirAttributeIsADenseF32Array(MlirAttribute attr
) {
405 return llvm::isa
<DenseF32ArrayAttr
>(unwrap(attr
));
407 bool mlirAttributeIsADenseF64Array(MlirAttribute attr
) {
408 return llvm::isa
<DenseF64ArrayAttr
>(unwrap(attr
));
411 //===----------------------------------------------------------------------===//
413 //===----------------------------------------------------------------------===//
415 MlirAttribute
mlirDenseBoolArrayGet(MlirContext ctx
, intptr_t size
,
417 SmallVector
<bool, 4> elements(values
, values
+ size
);
418 return wrap(DenseBoolArrayAttr::get(unwrap(ctx
), elements
));
420 MlirAttribute
mlirDenseI8ArrayGet(MlirContext ctx
, intptr_t size
,
421 int8_t const *values
) {
423 DenseI8ArrayAttr::get(unwrap(ctx
), ArrayRef
<int8_t>(values
, size
)));
425 MlirAttribute
mlirDenseI16ArrayGet(MlirContext ctx
, intptr_t size
,
426 int16_t const *values
) {
428 DenseI16ArrayAttr::get(unwrap(ctx
), ArrayRef
<int16_t>(values
, size
)));
430 MlirAttribute
mlirDenseI32ArrayGet(MlirContext ctx
, intptr_t size
,
431 int32_t const *values
) {
433 DenseI32ArrayAttr::get(unwrap(ctx
), ArrayRef
<int32_t>(values
, size
)));
435 MlirAttribute
mlirDenseI64ArrayGet(MlirContext ctx
, intptr_t size
,
436 int64_t const *values
) {
438 DenseI64ArrayAttr::get(unwrap(ctx
), ArrayRef
<int64_t>(values
, size
)));
440 MlirAttribute
mlirDenseF32ArrayGet(MlirContext ctx
, intptr_t size
,
441 float const *values
) {
443 DenseF32ArrayAttr::get(unwrap(ctx
), ArrayRef
<float>(values
, size
)));
445 MlirAttribute
mlirDenseF64ArrayGet(MlirContext ctx
, intptr_t size
,
446 double const *values
) {
448 DenseF64ArrayAttr::get(unwrap(ctx
), ArrayRef
<double>(values
, size
)));
451 //===----------------------------------------------------------------------===//
453 //===----------------------------------------------------------------------===//
455 intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr
) {
456 return llvm::cast
<DenseArrayAttr
>(unwrap(attr
)).size();
459 //===----------------------------------------------------------------------===//
460 // Indexed accessors.
461 //===----------------------------------------------------------------------===//
463 bool mlirDenseBoolArrayGetElement(MlirAttribute attr
, intptr_t pos
) {
464 return llvm::cast
<DenseBoolArrayAttr
>(unwrap(attr
))[pos
];
466 int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr
, intptr_t pos
) {
467 return llvm::cast
<DenseI8ArrayAttr
>(unwrap(attr
))[pos
];
469 int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr
, intptr_t pos
) {
470 return llvm::cast
<DenseI16ArrayAttr
>(unwrap(attr
))[pos
];
472 int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr
, intptr_t pos
) {
473 return llvm::cast
<DenseI32ArrayAttr
>(unwrap(attr
))[pos
];
475 int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr
, intptr_t pos
) {
476 return llvm::cast
<DenseI64ArrayAttr
>(unwrap(attr
))[pos
];
478 float mlirDenseF32ArrayGetElement(MlirAttribute attr
, intptr_t pos
) {
479 return llvm::cast
<DenseF32ArrayAttr
>(unwrap(attr
))[pos
];
481 double mlirDenseF64ArrayGetElement(MlirAttribute attr
, intptr_t pos
) {
482 return llvm::cast
<DenseF64ArrayAttr
>(unwrap(attr
))[pos
];
485 //===----------------------------------------------------------------------===//
486 // Dense elements attribute.
487 //===----------------------------------------------------------------------===//
489 //===----------------------------------------------------------------------===//
491 //===----------------------------------------------------------------------===//
493 bool mlirAttributeIsADenseElements(MlirAttribute attr
) {
494 return llvm::isa
<DenseElementsAttr
>(unwrap(attr
));
497 bool mlirAttributeIsADenseIntElements(MlirAttribute attr
) {
498 return llvm::isa
<DenseIntElementsAttr
>(unwrap(attr
));
501 bool mlirAttributeIsADenseFPElements(MlirAttribute attr
) {
502 return llvm::isa
<DenseFPElementsAttr
>(unwrap(attr
));
505 MlirTypeID
mlirDenseIntOrFPElementsAttrGetTypeID(void) {
506 return wrap(DenseIntOrFPElementsAttr::getTypeID());
509 //===----------------------------------------------------------------------===//
511 //===----------------------------------------------------------------------===//
513 MlirAttribute
mlirDenseElementsAttrGet(MlirType shapedType
,
514 intptr_t numElements
,
515 MlirAttribute
const *elements
) {
516 SmallVector
<Attribute
, 8> attributes
;
518 DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
519 unwrapList(numElements
, elements
, attributes
)));
522 MlirAttribute
mlirDenseElementsAttrRawBufferGet(MlirType shapedType
,
523 size_t rawBufferSize
,
524 const void *rawBuffer
) {
525 auto shapedTypeCpp
= llvm::cast
<ShapedType
>(unwrap(shapedType
));
526 ArrayRef
<char> rawBufferCpp(static_cast<const char *>(rawBuffer
),
528 bool isSplat
= false;
529 if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp
, rawBufferCpp
,
531 return mlirAttributeGetNull();
532 return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp
, rawBufferCpp
));
535 MlirAttribute
mlirDenseElementsAttrSplatGet(MlirType shapedType
,
536 MlirAttribute element
) {
537 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
540 MlirAttribute
mlirDenseElementsAttrBoolSplatGet(MlirType shapedType
,
542 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
545 MlirAttribute
mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType
,
547 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
550 MlirAttribute
mlirDenseElementsAttrInt8SplatGet(MlirType shapedType
,
552 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
555 MlirAttribute
mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType
,
557 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
560 MlirAttribute
mlirDenseElementsAttrInt32SplatGet(MlirType shapedType
,
562 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
565 MlirAttribute
mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType
,
567 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
570 MlirAttribute
mlirDenseElementsAttrInt64SplatGet(MlirType shapedType
,
572 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
575 MlirAttribute
mlirDenseElementsAttrFloatSplatGet(MlirType shapedType
,
577 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
580 MlirAttribute
mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType
,
582 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
586 MlirAttribute
mlirDenseElementsAttrBoolGet(MlirType shapedType
,
587 intptr_t numElements
,
588 const int *elements
) {
589 SmallVector
<bool, 8> values(elements
, elements
+ numElements
);
590 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
594 /// Creates a dense attribute with elements of the type deduced by templates.
595 template <typename T
>
596 static MlirAttribute
getDenseAttribute(MlirType shapedType
,
597 intptr_t numElements
,
599 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
600 llvm::ArrayRef(elements
, numElements
)));
603 MlirAttribute
mlirDenseElementsAttrUInt8Get(MlirType shapedType
,
604 intptr_t numElements
,
605 const uint8_t *elements
) {
606 return getDenseAttribute(shapedType
, numElements
, elements
);
608 MlirAttribute
mlirDenseElementsAttrInt8Get(MlirType shapedType
,
609 intptr_t numElements
,
610 const int8_t *elements
) {
611 return getDenseAttribute(shapedType
, numElements
, elements
);
613 MlirAttribute
mlirDenseElementsAttrUInt16Get(MlirType shapedType
,
614 intptr_t numElements
,
615 const uint16_t *elements
) {
616 return getDenseAttribute(shapedType
, numElements
, elements
);
618 MlirAttribute
mlirDenseElementsAttrInt16Get(MlirType shapedType
,
619 intptr_t numElements
,
620 const int16_t *elements
) {
621 return getDenseAttribute(shapedType
, numElements
, elements
);
623 MlirAttribute
mlirDenseElementsAttrUInt32Get(MlirType shapedType
,
624 intptr_t numElements
,
625 const uint32_t *elements
) {
626 return getDenseAttribute(shapedType
, numElements
, elements
);
628 MlirAttribute
mlirDenseElementsAttrInt32Get(MlirType shapedType
,
629 intptr_t numElements
,
630 const int32_t *elements
) {
631 return getDenseAttribute(shapedType
, numElements
, elements
);
633 MlirAttribute
mlirDenseElementsAttrUInt64Get(MlirType shapedType
,
634 intptr_t numElements
,
635 const uint64_t *elements
) {
636 return getDenseAttribute(shapedType
, numElements
, elements
);
638 MlirAttribute
mlirDenseElementsAttrInt64Get(MlirType shapedType
,
639 intptr_t numElements
,
640 const int64_t *elements
) {
641 return getDenseAttribute(shapedType
, numElements
, elements
);
643 MlirAttribute
mlirDenseElementsAttrFloatGet(MlirType shapedType
,
644 intptr_t numElements
,
645 const float *elements
) {
646 return getDenseAttribute(shapedType
, numElements
, elements
);
648 MlirAttribute
mlirDenseElementsAttrDoubleGet(MlirType shapedType
,
649 intptr_t numElements
,
650 const double *elements
) {
651 return getDenseAttribute(shapedType
, numElements
, elements
);
653 MlirAttribute
mlirDenseElementsAttrBFloat16Get(MlirType shapedType
,
654 intptr_t numElements
,
655 const uint16_t *elements
) {
656 size_t bufferSize
= numElements
* 2;
657 const void *buffer
= static_cast<const void *>(elements
);
658 return mlirDenseElementsAttrRawBufferGet(shapedType
, bufferSize
, buffer
);
660 MlirAttribute
mlirDenseElementsAttrFloat16Get(MlirType shapedType
,
661 intptr_t numElements
,
662 const uint16_t *elements
) {
663 size_t bufferSize
= numElements
* 2;
664 const void *buffer
= static_cast<const void *>(elements
);
665 return mlirDenseElementsAttrRawBufferGet(shapedType
, bufferSize
, buffer
);
668 MlirAttribute
mlirDenseElementsAttrStringGet(MlirType shapedType
,
669 intptr_t numElements
,
670 MlirStringRef
*strs
) {
671 SmallVector
<StringRef
, 8> values
;
672 values
.reserve(numElements
);
673 for (intptr_t i
= 0; i
< numElements
; ++i
)
674 values
.push_back(unwrap(strs
[i
]));
676 return wrap(DenseElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
680 MlirAttribute
mlirDenseElementsAttrReshapeGet(MlirAttribute attr
,
681 MlirType shapedType
) {
682 return wrap(llvm::cast
<DenseElementsAttr
>(unwrap(attr
))
683 .reshape(llvm::cast
<ShapedType
>(unwrap(shapedType
))));
686 //===----------------------------------------------------------------------===//
688 //===----------------------------------------------------------------------===//
690 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr
) {
691 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).isSplat();
694 MlirAttribute
mlirDenseElementsAttrGetSplatValue(MlirAttribute attr
) {
696 llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<Attribute
>());
698 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr
) {
699 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<bool>();
701 int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr
) {
702 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<int8_t>();
704 uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr
) {
705 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<uint8_t>();
707 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr
) {
708 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<int32_t>();
710 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr
) {
711 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<uint32_t>();
713 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr
) {
714 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<int64_t>();
716 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr
) {
717 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<uint64_t>();
719 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr
) {
720 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<float>();
722 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr
) {
723 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<double>();
725 MlirStringRef
mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr
) {
727 llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getSplatValue
<StringRef
>());
730 //===----------------------------------------------------------------------===//
731 // Indexed accessors.
732 //===----------------------------------------------------------------------===//
734 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr
, intptr_t pos
) {
735 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<bool>()[pos
];
737 int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr
, intptr_t pos
) {
738 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<int8_t>()[pos
];
740 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr
, intptr_t pos
) {
741 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<uint8_t>()[pos
];
743 int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr
, intptr_t pos
) {
744 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<int16_t>()[pos
];
746 uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr
, intptr_t pos
) {
747 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<uint16_t>()[pos
];
749 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr
, intptr_t pos
) {
750 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<int32_t>()[pos
];
752 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr
, intptr_t pos
) {
753 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<uint32_t>()[pos
];
755 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr
, intptr_t pos
) {
756 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<int64_t>()[pos
];
758 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr
, intptr_t pos
) {
759 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<uint64_t>()[pos
];
761 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr
, intptr_t pos
) {
762 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<float>()[pos
];
764 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr
, intptr_t pos
) {
765 return llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<double>()[pos
];
767 MlirStringRef
mlirDenseElementsAttrGetStringValue(MlirAttribute attr
,
770 llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getValues
<StringRef
>()[pos
]);
773 //===----------------------------------------------------------------------===//
774 // Raw data accessors.
775 //===----------------------------------------------------------------------===//
777 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr
) {
778 return static_cast<const void *>(
779 llvm::cast
<DenseElementsAttr
>(unwrap(attr
)).getRawData().data());
782 //===----------------------------------------------------------------------===//
783 // Resource blob attributes.
784 //===----------------------------------------------------------------------===//
786 bool mlirAttributeIsADenseResourceElements(MlirAttribute attr
) {
787 return llvm::isa
<DenseResourceElementsAttr
>(unwrap(attr
));
790 MlirAttribute
mlirUnmanagedDenseResourceElementsAttrGet(
791 MlirType shapedType
, MlirStringRef name
, void *data
, size_t dataLength
,
792 size_t dataAlignment
, bool dataIsMutable
,
793 void (*deleter
)(void *userData
, const void *data
, size_t size
,
796 AsmResourceBlob::DeleterFn cppDeleter
= {};
798 cppDeleter
= [deleter
, userData
](void *data
, size_t size
, size_t align
) {
799 deleter(userData
, data
, size
, align
);
802 AsmResourceBlob
blob(
803 llvm::ArrayRef(static_cast<const char *>(data
), dataLength
),
804 dataAlignment
, std::move(cppDeleter
), dataIsMutable
);
806 DenseResourceElementsAttr::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)),
807 unwrap(name
), std::move(blob
)));
810 template <typename U
, typename T
>
811 static MlirAttribute
getDenseResource(MlirType shapedType
, MlirStringRef name
,
812 intptr_t numElements
, const T
*elements
) {
813 return wrap(U::get(llvm::cast
<ShapedType
>(unwrap(shapedType
)), unwrap(name
),
814 UnmanagedAsmResourceBlob::allocateInferAlign(
815 llvm::ArrayRef(elements
, numElements
))));
818 MlirAttribute
mlirUnmanagedDenseBoolResourceElementsAttrGet(
819 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
820 const int *elements
) {
821 return getDenseResource
<DenseBoolResourceElementsAttr
>(shapedType
, name
,
822 numElements
, elements
);
824 MlirAttribute
mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
825 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
826 const uint8_t *elements
) {
827 return getDenseResource
<DenseUI8ResourceElementsAttr
>(shapedType
, name
,
828 numElements
, elements
);
830 MlirAttribute
mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
831 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
832 const uint16_t *elements
) {
833 return getDenseResource
<DenseUI16ResourceElementsAttr
>(shapedType
, name
,
834 numElements
, elements
);
836 MlirAttribute
mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
837 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
838 const uint32_t *elements
) {
839 return getDenseResource
<DenseUI32ResourceElementsAttr
>(shapedType
, name
,
840 numElements
, elements
);
842 MlirAttribute
mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
843 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
844 const uint64_t *elements
) {
845 return getDenseResource
<DenseUI64ResourceElementsAttr
>(shapedType
, name
,
846 numElements
, elements
);
848 MlirAttribute
mlirUnmanagedDenseInt8ResourceElementsAttrGet(
849 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
850 const int8_t *elements
) {
851 return getDenseResource
<DenseUI8ResourceElementsAttr
>(shapedType
, name
,
852 numElements
, elements
);
854 MlirAttribute
mlirUnmanagedDenseInt16ResourceElementsAttrGet(
855 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
856 const int16_t *elements
) {
857 return getDenseResource
<DenseUI16ResourceElementsAttr
>(shapedType
, name
,
858 numElements
, elements
);
860 MlirAttribute
mlirUnmanagedDenseInt32ResourceElementsAttrGet(
861 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
862 const int32_t *elements
) {
863 return getDenseResource
<DenseUI32ResourceElementsAttr
>(shapedType
, name
,
864 numElements
, elements
);
866 MlirAttribute
mlirUnmanagedDenseInt64ResourceElementsAttrGet(
867 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
868 const int64_t *elements
) {
869 return getDenseResource
<DenseUI64ResourceElementsAttr
>(shapedType
, name
,
870 numElements
, elements
);
872 MlirAttribute
mlirUnmanagedDenseFloatResourceElementsAttrGet(
873 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
874 const float *elements
) {
875 return getDenseResource
<DenseF32ResourceElementsAttr
>(shapedType
, name
,
876 numElements
, elements
);
878 MlirAttribute
mlirUnmanagedDenseDoubleResourceElementsAttrGet(
879 MlirType shapedType
, MlirStringRef name
, intptr_t numElements
,
880 const double *elements
) {
881 return getDenseResource
<DenseF64ResourceElementsAttr
>(shapedType
, name
,
882 numElements
, elements
);
884 template <typename U
, typename T
>
885 static T
getDenseResourceVal(MlirAttribute attr
, intptr_t pos
) {
886 return (*llvm::cast
<U
>(unwrap(attr
)).tryGetAsArrayRef())[pos
];
889 bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr
,
891 return getDenseResourceVal
<DenseBoolResourceElementsAttr
, uint8_t>(attr
, pos
);
893 uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr
,
895 return getDenseResourceVal
<DenseUI8ResourceElementsAttr
, uint8_t>(attr
, pos
);
897 uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr
,
899 return getDenseResourceVal
<DenseUI16ResourceElementsAttr
, uint16_t>(attr
,
902 uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr
,
904 return getDenseResourceVal
<DenseUI32ResourceElementsAttr
, uint32_t>(attr
,
907 uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr
,
909 return getDenseResourceVal
<DenseUI64ResourceElementsAttr
, uint64_t>(attr
,
912 int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr
,
914 return getDenseResourceVal
<DenseUI8ResourceElementsAttr
, int8_t>(attr
, pos
);
916 int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr
,
918 return getDenseResourceVal
<DenseUI16ResourceElementsAttr
, int16_t>(attr
, pos
);
920 int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr
,
922 return getDenseResourceVal
<DenseUI32ResourceElementsAttr
, int32_t>(attr
, pos
);
924 int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr
,
926 return getDenseResourceVal
<DenseUI64ResourceElementsAttr
, int64_t>(attr
, pos
);
928 float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr
,
930 return getDenseResourceVal
<DenseF32ResourceElementsAttr
, float>(attr
, pos
);
932 double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr
,
934 return getDenseResourceVal
<DenseF64ResourceElementsAttr
, double>(attr
, pos
);
937 //===----------------------------------------------------------------------===//
938 // Sparse elements attribute.
939 //===----------------------------------------------------------------------===//
941 bool mlirAttributeIsASparseElements(MlirAttribute attr
) {
942 return llvm::isa
<SparseElementsAttr
>(unwrap(attr
));
945 MlirAttribute
mlirSparseElementsAttribute(MlirType shapedType
,
946 MlirAttribute denseIndices
,
947 MlirAttribute denseValues
) {
948 return wrap(SparseElementsAttr::get(
949 llvm::cast
<ShapedType
>(unwrap(shapedType
)),
950 llvm::cast
<DenseElementsAttr
>(unwrap(denseIndices
)),
951 llvm::cast
<DenseElementsAttr
>(unwrap(denseValues
))));
954 MlirAttribute
mlirSparseElementsAttrGetIndices(MlirAttribute attr
) {
955 return wrap(llvm::cast
<SparseElementsAttr
>(unwrap(attr
)).getIndices());
958 MlirAttribute
mlirSparseElementsAttrGetValues(MlirAttribute attr
) {
959 return wrap(llvm::cast
<SparseElementsAttr
>(unwrap(attr
)).getValues());
962 MlirTypeID
mlirSparseElementsAttrGetTypeID(void) {
963 return wrap(SparseElementsAttr::getTypeID());
966 //===----------------------------------------------------------------------===//
967 // Strided layout attribute.
968 //===----------------------------------------------------------------------===//
970 bool mlirAttributeIsAStridedLayout(MlirAttribute attr
) {
971 return llvm::isa
<StridedLayoutAttr
>(unwrap(attr
));
974 MlirAttribute
mlirStridedLayoutAttrGet(MlirContext ctx
, int64_t offset
,
976 const int64_t *strides
) {
977 return wrap(StridedLayoutAttr::get(unwrap(ctx
), offset
,
978 ArrayRef
<int64_t>(strides
, numStrides
)));
981 int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr
) {
982 return llvm::cast
<StridedLayoutAttr
>(unwrap(attr
)).getOffset();
985 intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr
) {
986 return static_cast<intptr_t>(
987 llvm::cast
<StridedLayoutAttr
>(unwrap(attr
)).getStrides().size());
990 int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr
, intptr_t pos
) {
991 return llvm::cast
<StridedLayoutAttr
>(unwrap(attr
)).getStrides()[pos
];
994 MlirTypeID
mlirStridedLayoutAttrGetTypeID(void) {
995 return wrap(StridedLayoutAttr::getTypeID());