[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / CAPI / IR / BuiltinAttributes.cpp
blob11d1ade552f5a2068dc65838d117b4c12ac41e38
1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir-c/BuiltinAttributes.h"
10 #include "mlir-c/Support.h"
11 #include "mlir/CAPI/AffineMap.h"
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/IntegerSet.h"
14 #include "mlir/CAPI/Support.h"
15 #include "mlir/IR/AsmState.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
20 using namespace mlir;
22 MlirAttribute mlirAttributeGetNull() { return {nullptr}; }
24 //===----------------------------------------------------------------------===//
25 // Location attribute.
26 //===----------------------------------------------------------------------===//
28 bool mlirAttributeIsALocation(MlirAttribute attr) {
29 return llvm::isa<LocationAttr>(unwrap(attr));
32 //===----------------------------------------------------------------------===//
33 // Affine map attribute.
34 //===----------------------------------------------------------------------===//
36 bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
37 return llvm::isa<AffineMapAttr>(unwrap(attr));
40 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
41 return wrap(AffineMapAttr::get(unwrap(map)));
44 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
45 return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue());
48 MlirTypeID mlirAffineMapAttrGetTypeID(void) {
49 return wrap(AffineMapAttr::getTypeID());
52 //===----------------------------------------------------------------------===//
53 // Array attribute.
54 //===----------------------------------------------------------------------===//
56 bool mlirAttributeIsAArray(MlirAttribute attr) {
57 return llvm::isa<ArrayAttr>(unwrap(attr));
60 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
61 MlirAttribute const *elements) {
62 SmallVector<Attribute, 8> attrs;
63 return wrap(
64 ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
65 elements, attrs)));
68 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
69 return static_cast<intptr_t>(llvm::cast<ArrayAttr>(unwrap(attr)).size());
72 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
73 return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]);
76 MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); }
78 //===----------------------------------------------------------------------===//
79 // Dictionary attribute.
80 //===----------------------------------------------------------------------===//
82 bool mlirAttributeIsADictionary(MlirAttribute attr) {
83 return llvm::isa<DictionaryAttr>(unwrap(attr));
86 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
87 MlirNamedAttribute const *elements) {
88 SmallVector<NamedAttribute, 8> attributes;
89 attributes.reserve(numElements);
90 for (intptr_t i = 0; i < numElements; ++i)
91 attributes.emplace_back(unwrap(elements[i].name),
92 unwrap(elements[i].attribute));
93 return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
96 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
97 return static_cast<intptr_t>(llvm::cast<DictionaryAttr>(unwrap(attr)).size());
100 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
101 intptr_t pos) {
102 NamedAttribute attribute =
103 llvm::cast<DictionaryAttr>(unwrap(attr)).getValue()[pos];
104 return {wrap(attribute.getName()), wrap(attribute.getValue())};
107 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
108 MlirStringRef name) {
109 return wrap(llvm::cast<DictionaryAttr>(unwrap(attr)).get(unwrap(name)));
112 MlirTypeID mlirDictionaryAttrGetTypeID(void) {
113 return wrap(DictionaryAttr::getTypeID());
116 //===----------------------------------------------------------------------===//
117 // Floating point attribute.
118 //===----------------------------------------------------------------------===//
120 bool mlirAttributeIsAFloat(MlirAttribute attr) {
121 return llvm::isa<FloatAttr>(unwrap(attr));
124 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
125 double value) {
126 return wrap(FloatAttr::get(unwrap(type), value));
129 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
130 double value) {
131 return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
134 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
135 return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble();
138 MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); }
140 //===----------------------------------------------------------------------===//
141 // Integer attribute.
142 //===----------------------------------------------------------------------===//
144 bool mlirAttributeIsAInteger(MlirAttribute attr) {
145 return llvm::isa<IntegerAttr>(unwrap(attr));
148 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
149 return wrap(IntegerAttr::get(unwrap(type), value));
152 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
153 return llvm::cast<IntegerAttr>(unwrap(attr)).getInt();
156 int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
157 return llvm::cast<IntegerAttr>(unwrap(attr)).getSInt();
160 uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
161 return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
164 MlirTypeID mlirIntegerAttrGetTypeID(void) {
165 return wrap(IntegerAttr::getTypeID());
168 //===----------------------------------------------------------------------===//
169 // Bool attribute.
170 //===----------------------------------------------------------------------===//
172 bool mlirAttributeIsABool(MlirAttribute attr) {
173 return llvm::isa<BoolAttr>(unwrap(attr));
176 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
177 return wrap(BoolAttr::get(unwrap(ctx), value));
180 bool mlirBoolAttrGetValue(MlirAttribute attr) {
181 return llvm::cast<BoolAttr>(unwrap(attr)).getValue();
184 //===----------------------------------------------------------------------===//
185 // Integer set attribute.
186 //===----------------------------------------------------------------------===//
188 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
189 return llvm::isa<IntegerSetAttr>(unwrap(attr));
192 MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
193 return wrap(IntegerSetAttr::getTypeID());
196 MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
197 return wrap(IntegerSetAttr::get(unwrap(set)));
200 MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
201 return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue());
204 //===----------------------------------------------------------------------===//
205 // Opaque attribute.
206 //===----------------------------------------------------------------------===//
208 bool mlirAttributeIsAOpaque(MlirAttribute attr) {
209 return llvm::isa<OpaqueAttr>(unwrap(attr));
212 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
213 intptr_t dataLength, const char *data,
214 MlirType type) {
215 return wrap(
216 OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
217 StringRef(data, dataLength), unwrap(type)));
220 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
221 return wrap(
222 llvm::cast<OpaqueAttr>(unwrap(attr)).getDialectNamespace().strref());
225 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
226 return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData());
229 MlirTypeID mlirOpaqueAttrGetTypeID(void) {
230 return wrap(OpaqueAttr::getTypeID());
233 //===----------------------------------------------------------------------===//
234 // String attribute.
235 //===----------------------------------------------------------------------===//
237 bool mlirAttributeIsAString(MlirAttribute attr) {
238 return llvm::isa<StringAttr>(unwrap(attr));
241 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
242 return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
245 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
246 return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
249 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
250 return wrap(llvm::cast<StringAttr>(unwrap(attr)).getValue());
253 MlirTypeID mlirStringAttrGetTypeID(void) {
254 return wrap(StringAttr::getTypeID());
257 //===----------------------------------------------------------------------===//
258 // SymbolRef attribute.
259 //===----------------------------------------------------------------------===//
261 bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
262 return llvm::isa<SymbolRefAttr>(unwrap(attr));
265 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
266 intptr_t numReferences,
267 MlirAttribute const *references) {
268 SmallVector<FlatSymbolRefAttr, 4> refs;
269 refs.reserve(numReferences);
270 for (intptr_t i = 0; i < numReferences; ++i)
271 refs.push_back(llvm::cast<FlatSymbolRefAttr>(unwrap(references[i])));
272 auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
273 return wrap(SymbolRefAttr::get(symbolAttr, refs));
276 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
277 return wrap(
278 llvm::cast<SymbolRefAttr>(unwrap(attr)).getRootReference().getValue());
281 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
282 return wrap(
283 llvm::cast<SymbolRefAttr>(unwrap(attr)).getLeafReference().getValue());
286 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
287 return static_cast<intptr_t>(
288 llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences().size());
291 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
292 intptr_t pos) {
293 return wrap(
294 llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]);
297 MlirTypeID mlirSymbolRefAttrGetTypeID(void) {
298 return wrap(SymbolRefAttr::getTypeID());
301 MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) {
302 return wrap(mlir::DistinctAttr::create(unwrap(referencedAttr)));
305 //===----------------------------------------------------------------------===//
306 // Flat SymbolRef attribute.
307 //===----------------------------------------------------------------------===//
309 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
310 return llvm::isa<FlatSymbolRefAttr>(unwrap(attr));
313 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
314 return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
317 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
318 return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
321 //===----------------------------------------------------------------------===//
322 // Type attribute.
323 //===----------------------------------------------------------------------===//
325 bool mlirAttributeIsAType(MlirAttribute attr) {
326 return llvm::isa<TypeAttr>(unwrap(attr));
329 MlirAttribute mlirTypeAttrGet(MlirType type) {
330 return wrap(TypeAttr::get(unwrap(type)));
333 MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
334 return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue());
337 MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); }
339 //===----------------------------------------------------------------------===//
340 // Unit attribute.
341 //===----------------------------------------------------------------------===//
343 bool mlirAttributeIsAUnit(MlirAttribute attr) {
344 return llvm::isa<UnitAttr>(unwrap(attr));
347 MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
348 return wrap(UnitAttr::get(unwrap(ctx)));
351 MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); }
353 //===----------------------------------------------------------------------===//
354 // Elements attributes.
355 //===----------------------------------------------------------------------===//
357 bool mlirAttributeIsAElements(MlirAttribute attr) {
358 return llvm::isa<ElementsAttr>(unwrap(attr));
361 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
362 uint64_t *idxs) {
363 return wrap(llvm::cast<ElementsAttr>(unwrap(attr))
364 .getValues<Attribute>()[llvm::ArrayRef(idxs, rank)]);
367 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
368 uint64_t *idxs) {
369 return llvm::cast<ElementsAttr>(unwrap(attr))
370 .isValidIndex(llvm::ArrayRef(idxs, rank));
373 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
374 return llvm::cast<ElementsAttr>(unwrap(attr)).getNumElements();
377 //===----------------------------------------------------------------------===//
378 // Dense array attribute.
379 //===----------------------------------------------------------------------===//
381 MlirTypeID mlirDenseArrayAttrGetTypeID() {
382 return wrap(DenseArrayAttr::getTypeID());
385 //===----------------------------------------------------------------------===//
386 // IsA support.
387 //===----------------------------------------------------------------------===//
389 bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
390 return llvm::isa<DenseBoolArrayAttr>(unwrap(attr));
392 bool mlirAttributeIsADenseI8Array(MlirAttribute attr) {
393 return llvm::isa<DenseI8ArrayAttr>(unwrap(attr));
395 bool mlirAttributeIsADenseI16Array(MlirAttribute attr) {
396 return llvm::isa<DenseI16ArrayAttr>(unwrap(attr));
398 bool mlirAttributeIsADenseI32Array(MlirAttribute attr) {
399 return llvm::isa<DenseI32ArrayAttr>(unwrap(attr));
401 bool mlirAttributeIsADenseI64Array(MlirAttribute attr) {
402 return llvm::isa<DenseI64ArrayAttr>(unwrap(attr));
404 bool mlirAttributeIsADenseF32Array(MlirAttribute attr) {
405 return llvm::isa<DenseF32ArrayAttr>(unwrap(attr));
407 bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
408 return llvm::isa<DenseF64ArrayAttr>(unwrap(attr));
411 //===----------------------------------------------------------------------===//
412 // Constructors.
413 //===----------------------------------------------------------------------===//
415 MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
416 int const *values) {
417 SmallVector<bool, 4> elements(values, values + size);
418 return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements));
420 MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size,
421 int8_t const *values) {
422 return wrap(
423 DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef<int8_t>(values, size)));
425 MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size,
426 int16_t const *values) {
427 return wrap(
428 DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef<int16_t>(values, size)));
430 MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size,
431 int32_t const *values) {
432 return wrap(
433 DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef<int32_t>(values, size)));
435 MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size,
436 int64_t const *values) {
437 return wrap(
438 DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef<int64_t>(values, size)));
440 MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size,
441 float const *values) {
442 return wrap(
443 DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef<float>(values, size)));
445 MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
446 double const *values) {
447 return wrap(
448 DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef<double>(values, size)));
451 //===----------------------------------------------------------------------===//
452 // Accessors.
453 //===----------------------------------------------------------------------===//
455 intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
456 return llvm::cast<DenseArrayAttr>(unwrap(attr)).size();
459 //===----------------------------------------------------------------------===//
460 // Indexed accessors.
461 //===----------------------------------------------------------------------===//
463 bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
464 return llvm::cast<DenseBoolArrayAttr>(unwrap(attr))[pos];
466 int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) {
467 return llvm::cast<DenseI8ArrayAttr>(unwrap(attr))[pos];
469 int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) {
470 return llvm::cast<DenseI16ArrayAttr>(unwrap(attr))[pos];
472 int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
473 return llvm::cast<DenseI32ArrayAttr>(unwrap(attr))[pos];
475 int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
476 return llvm::cast<DenseI64ArrayAttr>(unwrap(attr))[pos];
478 float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
479 return llvm::cast<DenseF32ArrayAttr>(unwrap(attr))[pos];
481 double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
482 return llvm::cast<DenseF64ArrayAttr>(unwrap(attr))[pos];
485 //===----------------------------------------------------------------------===//
486 // Dense elements attribute.
487 //===----------------------------------------------------------------------===//
489 //===----------------------------------------------------------------------===//
490 // IsA support.
491 //===----------------------------------------------------------------------===//
493 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
494 return llvm::isa<DenseElementsAttr>(unwrap(attr));
497 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
498 return llvm::isa<DenseIntElementsAttr>(unwrap(attr));
501 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
502 return llvm::isa<DenseFPElementsAttr>(unwrap(attr));
505 MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) {
506 return wrap(DenseIntOrFPElementsAttr::getTypeID());
509 //===----------------------------------------------------------------------===//
510 // Constructors.
511 //===----------------------------------------------------------------------===//
513 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
514 intptr_t numElements,
515 MlirAttribute const *elements) {
516 SmallVector<Attribute, 8> attributes;
517 return wrap(
518 DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
519 unwrapList(numElements, elements, attributes)));
522 MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
523 size_t rawBufferSize,
524 const void *rawBuffer) {
525 auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
526 ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
527 rawBufferSize);
528 bool isSplat = false;
529 if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
530 isSplat))
531 return mlirAttributeGetNull();
532 return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
535 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
536 MlirAttribute element) {
537 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
538 unwrap(element)));
540 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
541 bool element) {
542 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
543 element));
545 MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
546 uint8_t element) {
547 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
548 element));
550 MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
551 int8_t element) {
552 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
553 element));
555 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
556 uint32_t element) {
557 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
558 element));
560 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
561 int32_t element) {
562 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
563 element));
565 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
566 uint64_t element) {
567 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
568 element));
570 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
571 int64_t element) {
572 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
573 element));
575 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
576 float element) {
577 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
578 element));
580 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
581 double element) {
582 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
583 element));
586 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
587 intptr_t numElements,
588 const int *elements) {
589 SmallVector<bool, 8> values(elements, elements + numElements);
590 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
591 values));
594 /// Creates a dense attribute with elements of the type deduced by templates.
595 template <typename T>
596 static MlirAttribute getDenseAttribute(MlirType shapedType,
597 intptr_t numElements,
598 const T *elements) {
599 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
600 llvm::ArrayRef(elements, numElements)));
603 MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
604 intptr_t numElements,
605 const uint8_t *elements) {
606 return getDenseAttribute(shapedType, numElements, elements);
608 MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
609 intptr_t numElements,
610 const int8_t *elements) {
611 return getDenseAttribute(shapedType, numElements, elements);
613 MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
614 intptr_t numElements,
615 const uint16_t *elements) {
616 return getDenseAttribute(shapedType, numElements, elements);
618 MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
619 intptr_t numElements,
620 const int16_t *elements) {
621 return getDenseAttribute(shapedType, numElements, elements);
623 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
624 intptr_t numElements,
625 const uint32_t *elements) {
626 return getDenseAttribute(shapedType, numElements, elements);
628 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
629 intptr_t numElements,
630 const int32_t *elements) {
631 return getDenseAttribute(shapedType, numElements, elements);
633 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
634 intptr_t numElements,
635 const uint64_t *elements) {
636 return getDenseAttribute(shapedType, numElements, elements);
638 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
639 intptr_t numElements,
640 const int64_t *elements) {
641 return getDenseAttribute(shapedType, numElements, elements);
643 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
644 intptr_t numElements,
645 const float *elements) {
646 return getDenseAttribute(shapedType, numElements, elements);
648 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
649 intptr_t numElements,
650 const double *elements) {
651 return getDenseAttribute(shapedType, numElements, elements);
653 MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
654 intptr_t numElements,
655 const uint16_t *elements) {
656 size_t bufferSize = numElements * 2;
657 const void *buffer = static_cast<const void *>(elements);
658 return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
660 MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
661 intptr_t numElements,
662 const uint16_t *elements) {
663 size_t bufferSize = numElements * 2;
664 const void *buffer = static_cast<const void *>(elements);
665 return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
668 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
669 intptr_t numElements,
670 MlirStringRef *strs) {
671 SmallVector<StringRef, 8> values;
672 values.reserve(numElements);
673 for (intptr_t i = 0; i < numElements; ++i)
674 values.push_back(unwrap(strs[i]));
676 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
677 values));
680 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
681 MlirType shapedType) {
682 return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
683 .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
686 //===----------------------------------------------------------------------===//
687 // Splat accessors.
688 //===----------------------------------------------------------------------===//
690 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
691 return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
694 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
695 return wrap(
696 llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>());
698 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
699 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>();
701 int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
702 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int8_t>();
704 uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
705 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint8_t>();
707 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
708 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int32_t>();
710 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
711 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint32_t>();
713 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
714 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int64_t>();
716 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
717 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint64_t>();
719 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
720 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<float>();
722 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
723 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>();
725 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
726 return wrap(
727 llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>());
730 //===----------------------------------------------------------------------===//
731 // Indexed accessors.
732 //===----------------------------------------------------------------------===//
734 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
735 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<bool>()[pos];
737 int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
738 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int8_t>()[pos];
740 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
741 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint8_t>()[pos];
743 int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
744 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int16_t>()[pos];
746 uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
747 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint16_t>()[pos];
749 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
750 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int32_t>()[pos];
752 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
753 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint32_t>()[pos];
755 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
756 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int64_t>()[pos];
758 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
759 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
761 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
762 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<float>()[pos];
764 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
765 return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<double>()[pos];
767 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
768 intptr_t pos) {
769 return wrap(
770 llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]);
773 //===----------------------------------------------------------------------===//
774 // Raw data accessors.
775 //===----------------------------------------------------------------------===//
777 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
778 return static_cast<const void *>(
779 llvm::cast<DenseElementsAttr>(unwrap(attr)).getRawData().data());
782 //===----------------------------------------------------------------------===//
783 // Resource blob attributes.
784 //===----------------------------------------------------------------------===//
786 bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) {
787 return llvm::isa<DenseResourceElementsAttr>(unwrap(attr));
790 MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
791 MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
792 size_t dataAlignment, bool dataIsMutable,
793 void (*deleter)(void *userData, const void *data, size_t size,
794 size_t align),
795 void *userData) {
796 AsmResourceBlob::DeleterFn cppDeleter = {};
797 if (deleter) {
798 cppDeleter = [deleter, userData](void *data, size_t size, size_t align) {
799 deleter(userData, data, size, align);
802 AsmResourceBlob blob(
803 llvm::ArrayRef(static_cast<const char *>(data), dataLength),
804 dataAlignment, std::move(cppDeleter), dataIsMutable);
805 return wrap(
806 DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
807 unwrap(name), std::move(blob)));
810 template <typename U, typename T>
811 static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
812 intptr_t numElements, const T *elements) {
813 return wrap(U::get(llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name),
814 UnmanagedAsmResourceBlob::allocateInferAlign(
815 llvm::ArrayRef(elements, numElements))));
818 MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
819 MlirType shapedType, MlirStringRef name, intptr_t numElements,
820 const int *elements) {
821 return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name,
822 numElements, elements);
824 MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
825 MlirType shapedType, MlirStringRef name, intptr_t numElements,
826 const uint8_t *elements) {
827 return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
828 numElements, elements);
830 MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
831 MlirType shapedType, MlirStringRef name, intptr_t numElements,
832 const uint16_t *elements) {
833 return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
834 numElements, elements);
836 MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
837 MlirType shapedType, MlirStringRef name, intptr_t numElements,
838 const uint32_t *elements) {
839 return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
840 numElements, elements);
842 MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
843 MlirType shapedType, MlirStringRef name, intptr_t numElements,
844 const uint64_t *elements) {
845 return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
846 numElements, elements);
848 MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
849 MlirType shapedType, MlirStringRef name, intptr_t numElements,
850 const int8_t *elements) {
851 return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
852 numElements, elements);
854 MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
855 MlirType shapedType, MlirStringRef name, intptr_t numElements,
856 const int16_t *elements) {
857 return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
858 numElements, elements);
860 MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
861 MlirType shapedType, MlirStringRef name, intptr_t numElements,
862 const int32_t *elements) {
863 return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
864 numElements, elements);
866 MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
867 MlirType shapedType, MlirStringRef name, intptr_t numElements,
868 const int64_t *elements) {
869 return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
870 numElements, elements);
872 MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
873 MlirType shapedType, MlirStringRef name, intptr_t numElements,
874 const float *elements) {
875 return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name,
876 numElements, elements);
878 MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet(
879 MlirType shapedType, MlirStringRef name, intptr_t numElements,
880 const double *elements) {
881 return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name,
882 numElements, elements);
884 template <typename U, typename T>
885 static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) {
886 return (*llvm::cast<U>(unwrap(attr)).tryGetAsArrayRef())[pos];
889 bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr,
890 intptr_t pos) {
891 return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos);
893 uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr,
894 intptr_t pos) {
895 return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos);
897 uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr,
898 intptr_t pos) {
899 return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr,
900 pos);
902 uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr,
903 intptr_t pos) {
904 return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr,
905 pos);
907 uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr,
908 intptr_t pos) {
909 return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr,
910 pos);
912 int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr,
913 intptr_t pos) {
914 return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos);
916 int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr,
917 intptr_t pos) {
918 return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos);
920 int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr,
921 intptr_t pos) {
922 return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos);
924 int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr,
925 intptr_t pos) {
926 return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos);
928 float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr,
929 intptr_t pos) {
930 return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos);
932 double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr,
933 intptr_t pos) {
934 return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos);
937 //===----------------------------------------------------------------------===//
938 // Sparse elements attribute.
939 //===----------------------------------------------------------------------===//
941 bool mlirAttributeIsASparseElements(MlirAttribute attr) {
942 return llvm::isa<SparseElementsAttr>(unwrap(attr));
945 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
946 MlirAttribute denseIndices,
947 MlirAttribute denseValues) {
948 return wrap(SparseElementsAttr::get(
949 llvm::cast<ShapedType>(unwrap(shapedType)),
950 llvm::cast<DenseElementsAttr>(unwrap(denseIndices)),
951 llvm::cast<DenseElementsAttr>(unwrap(denseValues))));
954 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
955 return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getIndices());
958 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
959 return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues());
962 MlirTypeID mlirSparseElementsAttrGetTypeID(void) {
963 return wrap(SparseElementsAttr::getTypeID());
966 //===----------------------------------------------------------------------===//
967 // Strided layout attribute.
968 //===----------------------------------------------------------------------===//
970 bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
971 return llvm::isa<StridedLayoutAttr>(unwrap(attr));
974 MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
975 intptr_t numStrides,
976 const int64_t *strides) {
977 return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
978 ArrayRef<int64_t>(strides, numStrides)));
981 int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) {
982 return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getOffset();
985 intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
986 return static_cast<intptr_t>(
987 llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides().size());
990 int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
991 return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos];
994 MlirTypeID mlirStridedLayoutAttrGetTypeID(void) {
995 return wrap(StridedLayoutAttr::getTypeID());