# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Supports the PyTACO API with the MLIR-PyTACO implementation.

See http://tensor-compiler.org/ for the TACO tensor compiler.

This module exports the MLIR-PyTACO implementation through the language defined
by PyTACO. In particular, it defines the function and type aliases and constants
needed for the PyTACO API to support the execution of PyTACO programs using the
MLIR-PyTACO implementation.
"""

from . import mlir_pytaco
from . import mlir_pytaco_io

# Functions defined by the PyTACO API.
#
# Each PyTACO function name is bound directly to the MLIR-PyTACO callable (no
# wrapper), so e.g. calling `read(...)` here is the same as calling
# `mlir_pytaco_io.read(...)`.
get_index_vars = mlir_pytaco.get_index_vars
# `from_array` is looked up as an attribute of the `Tensor` class rather than
# as a module-level function of `mlir_pytaco`.
from_array = mlir_pytaco.Tensor.from_array
# The file I/O entry points come from the companion `mlir_pytaco_io` module.
read, write = mlir_pytaco_io.read, mlir_pytaco_io.write

# Classes defined by the PyTACO API.
#
# PyTACO spells these class names in lower_snake_case. Each name below is the
# MLIR-PyTACO class object itself (an alias, not a subclass or wrapper), so
# e.g. `isinstance(x, tensor)` and `isinstance(x, mlir_pytaco.Tensor)` agree.
dtype = mlir_pytaco.DType
mode_format, mode_ordering, mode_format_pack = (
    mlir_pytaco.ModeFormat,
    mlir_pytaco.ModeOrdering,
    mlir_pytaco.ModeFormatPack,
)
# NOTE: this binding shadows the builtin `format()` within this module; the
# name is kept because it is part of the PyTACO API surface.
format = mlir_pytaco.Format
index_var, tensor, index_expression, access = (
    mlir_pytaco.IndexVar,
    mlir_pytaco.Tensor,
    mlir_pytaco.IndexExpr,
    mlir_pytaco.Access,
)

# Data type constants defined by the PyTACO API.
#
# Each constant is a `mlir_pytaco.DType` instance wrapping the matching
# `mlir_pytaco.Type` member. All five are constructed once, in the order
# listed, when this module is imported.
int16, int32, int64, float32, float64 = (
    mlir_pytaco.DType(element_type)
    for element_type in (
        mlir_pytaco.Type.INT16,
        mlir_pytaco.Type.INT32,
        mlir_pytaco.Type.INT64,
        mlir_pytaco.Type.FLOAT32,
        mlir_pytaco.Type.FLOAT64,
    )
)

# Storage format constants defined by the PyTACO API.
#
# PyTACO gives every storage format constant two aliasing names, one lowercase
# and one capitalized. Both names in each pair refer to the same
# `mlir_pytaco.ModeFormat` attribute.
compressed = Compressed = mlir_pytaco.ModeFormat.COMPRESSED
dense = Dense = mlir_pytaco.ModeFormat.DENSE