# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Experimental MLIR-PyTACO with sparse tensor support.

See http://tensor-compiler.org/ for the TACO tensor compiler.

This module implements the Python classes for PyTACO index notation. These
include classes for data types, tensor dimension formats (aka mode formats),
tensor dimension orderings (aka mode orderings), tensor storage formats, and
tensors.

The PyTACO API doesn't follow the naming convention required by the style guide
for this module. As such, we first implement the supporting classes and routines
following the style guide, and then define the type aliases and constants to
support the PyTACO API in the pytaco_api module.
"""

from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import abc
import ctypes
import dataclasses
import enum
import functools
import numpy as np
import operator
import os
import threading

# Import MLIR related modules.
from mlir import all_passes_registration  # Register MLIR compiler passes.
from mlir import execution_engine
from mlir import ir
from mlir import runtime
from mlir.dialects import arith
from mlir.dialects import builtin
from mlir.dialects import linalg
from mlir.dialects import std
from mlir.dialects import sparse_tensor
from mlir.dialects.linalg.opdsl import lang
from mlir.passmanager import PassManager

from . import mlir_pytaco_utils as utils

# TACO naming prefixes.
_TACO_INDEX_PREFIX = "i"
_TACO_TENSOR_PREFIX = "A"

# Bitwidths for pointers and indices.
_POINTER_BIT_WIDTH = 0
_INDEX_BIT_WIDTH = 0

# The name for the environment variable that provides the full path for the
# supporting library.
_SUPPORTLIB_ENV_VAR = "SUPPORTLIB"
# The default supporting library if the environment variable is not provided.
_DEFAULT_SUPPORTLIB = "libmlir_c_runner_utils.so"
# The JIT compiler optimization level.
_OPT_LEVEL = 2
# The entry point to the JIT compiled program.
_ENTRY_NAME = "main"

# Type aliases for type annotation.
_BinaryOp = Callable[[Any, Any], Any]
_ExprVisitor = Callable[..., None]
_ExprInfoDict = Dict["IndexExpr", "_ExprInfo"]
_LogicalOp = Callable[[bool, bool], bool]
_ModeFormatOp = Callable[["ModeFormat", "ModeFormat"], "ModeFormat"]
_SubtreeLeafChecker = Optional[Callable[..., bool]]


class Type(enum.Enum):
  """The data types supported by TACO.

  We use numpy data types to implement the enum data types.
  """
  INT16 = np.int16
  INT32 = np.int32
  INT64 = np.int64
  # numpy _ctype_from_dtype_scalar can't handle np.float16 yet.
  FLOAT32 = np.float32
  FLOAT64 = np.float64


# All floating point type enums.
_FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64)
# All integral type enums.
_INT_TYPES = (Type.INT16, Type.INT32, Type.INT64)
# Type alias for any numpy type used to implement the runtime support for the
# enum data types.
_AnyRuntimeType = Union[np.int16, np.int32, np.int64, np.float32, np.float64]


@dataclasses.dataclass(frozen=True)
class DType:
  """The data type class.

  We support the TACO API dtype class with an alias of this class.

  The following methods are defined by the TACO API:
    is_float: Returns whether the data type represents a floating point value.
    is_int: Returns whether the data type represents an integral value.

  Attributes:
    kind: A Type enum representing the data type.
    value: The numpy data type for the TACO data type.
  """
  kind: Type = Type.FLOAT64

  def is_float(self) -> bool:
    """Returns whether the data type represents a floating point value."""
    return self.kind in _FLOAT_TYPES

  def is_int(self) -> bool:
    """Returns whether the data type represents an integral value."""
    return self.kind in _INT_TYPES

  @property
  def value(self) -> _AnyRuntimeType:
    """Returns the numpy dtype for the data type."""
    return self.kind.value
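

# A usage sketch for DType (illustrative values only, using just the names
# defined above):
#
#   DType(Type.INT32).is_int()      # -> True
#   DType(Type.FLOAT32).is_float()  # -> True
#   DType().value                   # -> np.float64, the default numpy dtype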


def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
  """Returns the MLIR type corresponding to the given TACO type."""
  dtype_to_irtype = {
      Type.INT16: ir.IntegerType.get_signless(16),
      Type.INT32: ir.IntegerType.get_signless(32),
      Type.INT64: ir.IntegerType.get_signless(64),
      Type.FLOAT32: ir.F32Type.get(),
      Type.FLOAT64: ir.F64Type.get()
  }
  return dtype_to_irtype[dtype.kind]


def _compile_mlir(module: ir.Module) -> ir.Module:
  """Compiles an MLIR module and returns the compiled module."""
  # TODO: Replace this with a pipeline implemented for
  # https://github.com/llvm/llvm-project/issues/51751.
  pipeline = (
      f"sparsification,"
      f"sparse-tensor-conversion,"
      f"builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
      f"convert-scf-to-std,"
      f"func-bufferize,"
      f"tensor-constant-bufferize,"
      f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
      f"convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},"
      f"lower-affine,"
      f"convert-memref-to-llvm,"
      f"convert-std-to-llvm,"
      f"reconcile-unrealized-casts")
  PassManager.parse(pipeline).run(module)
  return module


@functools.lru_cache()
def _get_support_lib_name() -> str:
  """Returns the string for the supporting C shared library."""
  return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
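

# A usage note (the path below is illustrative): the supporting library is
# resolved once per process, e.g.
#
#   SUPPORTLIB=/path/to/libmlir_c_runner_utils.so python app.py
#
# and falls back to _DEFAULT_SUPPORTLIB when the variable is unset.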


def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer:
  """Returns the ctype pointer for the given numpy array."""
  return ctypes.pointer(
      ctypes.pointer(runtime.get_ranked_memref_descriptor(array)))


class ModeFormat(enum.Enum):
  """The tensor dimension storage format class.

  We support the TACO API mode_format class with an alias of this class.

  In TACO, a tensor dimension is called a mode and the storage format for a
  tensor dimension is called a mode format.
  """
  DENSE = sparse_tensor.DimLevelType.dense
  COMPRESSED = sparse_tensor.DimLevelType.compressed


def _mode_format_operation(a: ModeFormat, b: ModeFormat,
                           op: _LogicalOp) -> ModeFormat:
  """Implements the given operator on ModeFormat."""
  return (ModeFormat.COMPRESSED
          if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED) else
          ModeFormat.DENSE)


def _mode_format_estimator(op: _BinaryOp) -> _ModeFormatOp:
  """Produces a ModeFormat operator for the given binary operator.

  The ModeFormat operator is used as a heuristic to derive the destination
  dimension sparsity from the source dimension sparsity. In particular, if the
  binary operator produces a disjunction of the zero values from its source
  operands, such as the MUL operator, we return a ModeFormat operator that
  uses operator.or_. That is, we estimate that a dimension of the MUL
  operation result is sparse if either of its source operands is sparse.

  On the other hand, if the binary operator produces a conjunction of the
  zero values from its source operands, such as the ADD operator, we return
  a ModeFormat operator that uses operator.and_. In this case, we estimate
  that a dimension of the ADD operation result is sparse only if both of its
  source operands are sparse.

  Args:
    op: A _BinaryOp object representing a supported operator on tensors.

  Returns:
    A _ModeFormatOp for estimating the destination dimension sparsity from
    the source dimension sparsity.
  """
  conjunction = functools.partial(_mode_format_operation, op=operator.and_)
  disjunction = functools.partial(_mode_format_operation, op=operator.or_)
  return conjunction if op(0, 1) != 0 else disjunction
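

# A usage sketch for the estimator (illustrative only): multiplication keeps a
# dimension sparse if either operand dimension is sparse, while addition
# requires both to be sparse.
#
#   est_mul = _mode_format_estimator(operator.mul)
#   est_mul(ModeFormat.COMPRESSED, ModeFormat.DENSE)  # -> ModeFormat.COMPRESSED
#   est_add = _mode_format_estimator(operator.add)
#   est_add(ModeFormat.COMPRESSED, ModeFormat.DENSE)  # -> ModeFormat.DENSE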


def _all_instance_of(collection: Iterable, cls: Any) -> bool:
  """Returns true if all elements of the iterable are instances of cls."""
  return all(isinstance(e, cls) for e in collection)


def _identity_ordering(rank: int) -> List[int]:
  """Returns the identity ordering for a tensor of the given rank."""
  return list(range(rank))


@dataclasses.dataclass(frozen=True)
class ModeOrdering:
  """The tensor dimension ordering class.

  We support the TACO API mode_ordering class with an alias of this class.

  Attributes:
    ordering: A list of integers representing the ordering of the tensor
      dimensions.
  """
  ordering: List[int]

  def __post_init__(self) -> None:
    """Verifies the value in ordering.

    Raises:
      ValueError: If ordering is not a list of integers.
    """
    if (not isinstance(self.ordering, list) or
        not _all_instance_of(self.ordering, int)):
      raise ValueError("Ordering must be a list of integers: "
                       f"{self.ordering}")
    # Check that ordering is a permutation of the dimension numbers.
    if sorted(self.ordering) != _identity_ordering(self.rank()):
      raise ValueError(f"Invalid ordering: {self.ordering} != "
                       f"permutation{_identity_ordering(self.rank())}.")

  def rank(self) -> int:
    """Returns the number of dimensions represented by the ordering."""
    return len(self.ordering)


@dataclasses.dataclass(frozen=True)
class ModeFormatPack:
  """The tensor dimension format class.

  We support the TACO API mode_format_pack class with an alias of this class.

  The storage format of a tensor contains one mode_format for each tensor
  dimension.

  Attributes:
    formats: A list of ModeFormat representing the storage format for each of
      the tensor dimensions.
  """
  formats: List[ModeFormat]

  def __post_init__(self) -> None:
    """Verifies the value in formats.

    Raises:
      ValueError: If formats is not a list of ModeFormat.
    """
    if (not isinstance(self.formats, list) or
        not _all_instance_of(self.formats, ModeFormat)):
      raise ValueError("Formats must be a list of ModeFormat: "
                       f"{self.formats}")

  def rank(self) -> int:
    """Returns the number of dimensions represented by the format pack."""
    return len(self.formats)


@dataclasses.dataclass
class Format:
  """The tensor format class defined by the TACO API.

  Attributes:
    format_pack: A ModeFormatPack representing the storage format for the
      tensor dimensions.
    ordering: A ModeOrdering representing the tensor dimension ordering in
      the storage.
  """
  format_pack: ModeFormatPack
  ordering: Optional[ModeOrdering] = None

  def __post_init__(self) -> None:
    """Verifies and fixes up the values in format_pack and ordering.

    Verifies and fixes up the values in format_pack and ordering to support
    the initializer syntax defined by the TACO API. If format_pack is a list
    of ModeFormat, replaces it with a ModeFormatPack constructed from the
    list. If ordering is not provided, sets ordering to the natural ordering
    for the rank corresponding to format_pack.

    Raises:
      ValueError: If format_pack is not an instance of ModeFormatPack or if
        ordering is not an instance of ModeOrdering.
    """
    if isinstance(self.format_pack, list):
      if not _all_instance_of(self.format_pack, ModeFormat):
        raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
      self.format_pack = ModeFormatPack(self.format_pack)
    if not isinstance(self.format_pack, ModeFormatPack):
      raise ValueError(f"Expected ModeFormatPack: {self.format_pack}")

    if self.ordering is None:
      self.ordering = ModeOrdering(list(range(self.rank())))
    if not isinstance(self.ordering, ModeOrdering):
      raise ValueError(f"Expected ModeOrdering: {self.ordering}")

    if self.format_pack.rank() != self.ordering.rank():
      raise ValueError("Inconsistent ModeFormatPack and ModeOrdering: "
                       f"len({self.format_pack}) != "
                       f"len({self.ordering})")

  def is_dense(self) -> bool:
    """Returns true if all the Tensor dimensions have a dense format."""
    return all(f == ModeFormat.DENSE for f in self.format_pack.formats)

  def rank(self) -> int:
    """Returns the number of dimensions represented by the format."""
    return self.format_pack.rank()

  def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
    """Constructs the MLIR attributes for the tensor format."""
    if self.is_dense():
      return None

    order = (
        range(self.rank()) if
        (self.ordering is None) else self.ordering.ordering)
    mlir_storage_format = [f.value for f in self.format_pack.formats]
    return sparse_tensor.EncodingAttr.get(mlir_storage_format,
                                          ir.AffineMap.get_permutation(order),
                                          _POINTER_BIT_WIDTH, _INDEX_BIT_WIDTH)


def _make_format(formats: List[ModeFormat],
                 ordering: Optional[List[int]] = None) -> Format:
  """Constructs a format from a list of ModeFormat and an optional ordering.

  Args:
    formats: A list of ModeFormat, one for each dimension of a tensor.
    ordering: An optional list of integers for the ordering of the tensor
      dimensions. When an ordering is not given, the identity ordering is used.

  Returns:
    A tensor format object.

  Raises:
    ValueError: If formats is not a list of ModeFormat or the length of formats
      is not consistent with the length of ordering.
  """
  ordering = ordering or _identity_ordering(len(formats))
  return Format(ModeFormatPack(formats), ModeOrdering(ordering))
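

# A usage sketch (illustrative only): a CSR-like 2D format stores the first
# dimension dense and the second compressed, with the identity ordering.
#
#   csr = _make_format([ModeFormat.DENSE, ModeFormat.COMPRESSED])
#   csr.rank()      # -> 2
#   csr.is_dense()  # -> False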


class _AtomicCounter:
  """An atomic counter."""

  def __init__(self):
    self._counter = 0
    self._counter_lock = threading.Lock()

  def increment(self) -> int:
    """Increments the counter by one and returns the old value."""
    with self._counter_lock:
      old_value = self._counter
      self._counter = old_value + 1
    return old_value


class IndexVar:
  """The tensor index class.

  We support the TACO API index_var class with an alias of this class.

  An IndexVar object represents an index variable in tensor index notation.

  Attributes:
    name: A unique string name of the IndexVar.
  """
  _counter = _AtomicCounter()

  def __init__(self):
    id = self._counter.increment()
    self._name = f"{_TACO_INDEX_PREFIX}{id}"

  def __repr__(self) -> str:
    return f"IndexVar(name={repr(self._name)})"

  @property
  def name(self) -> str:
    """Returns the name of the IndexVar."""
    return self._name


def get_index_vars(n: int) -> List[IndexVar]:
  """Returns a list of n IndexVar.

  This routine is defined by the TACO API.

  Args:
    n: An integer representing the number of IndexVar to get.

  Returns:
    A list of n IndexVar.

  Raises:
    ValueError: If n is not a positive integer.
  """
  if not isinstance(n, int) or n <= 0:
    raise ValueError(f"Expected a positive integer: {n}.")
  # If lock contention ever becomes an issue, we could implement a bulk getter
  # that returns a range by only claiming the lock once.
  return [IndexVar() for _ in range(n)]
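

# A usage sketch (illustrative only): index variables receive unique names
# i0, i1, ... and are used to spell tensor index notation.
#
#   i, j = get_index_vars(2)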


def _mlir_symbols_from_index_vars(
    index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.SymbolDef, ...]:
  """Returns a tuple of MLIR symbols for the given tuple of index_var."""
  return tuple(getattr(lang.S, i.name) for i in index_vars)


def _mlir_dimensions_from_index_vars(
    index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.DimDef, ...]:
  """Returns a tuple of MLIR dimensions for the given tuple of index_var."""
  return tuple(getattr(lang.D, i.name) for i in index_vars)


def _mlir_tensor_type(
    dtype: DType, shape: Tuple[int, ...],
    attr: Optional[sparse_tensor.EncodingAttr]) -> ir.RankedTensorType:
  """Returns an MLIR tensor type.

  Args:
    dtype: A DType object for the element data type of the tensor.
    shape: A tuple of integers for the shape of the tensor.
    attr: An optional MLIR sparse tensor attribute, only provided if the tensor
      is a sparse tensor.

  Returns:
    An MLIR ranked tensor type.
  """
  ir_type = _mlir_type_from_taco_type(dtype)
  return ir.RankedTensorType.get(shape, ir_type, attr)


def _verify_and_normalize_indices(indices) -> Tuple[IndexVar, ...]:
  """Verifies and normalizes the indices for a tensor access.

  Args:
    indices: The index expression used to access a tensor, which could be any
      Python object from user inputs.

  Returns:
    A tuple of IndexVar.

  Raises:
    ValueError: If indices is not an IndexVar or a tuple of IndexVar.
  """
  if isinstance(indices, IndexVar):
    return (indices,)
  elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
    return indices

  raise ValueError(f"Expected IndexVars: {indices}")


@dataclasses.dataclass(frozen=True)
class _StructOpInfo:
  """Information for generating a structured op in the linalg dialect.

  This information is associated with an expression node that serves as the
  root for an expression subtree implemented with a structured op.

  Attributes:
    dst_indices: A tuple of IndexVar, representing the result dimensions of the
      structured op. This is used to construct the temporary variable for the
      tensor to hold the structured op result.
    dst_dims: A tuple of int, representing the result shape of the structured
      op.
    dst_dtype: A DType representing the data type of the structured op result.
    dst_name: A string representing the name of the structured op result.
    dst_format: A Format object representing the destination tensor format.
  """
  dst_indices: Tuple[IndexVar, ...]
  dst_dims: Tuple[int, ...]
  dst_dtype: DType
  dst_name: str
  dst_format: Format

  def __post_init__(self) -> None:
    """Verifies the integrity of the attribute values."""
    assert len(self.dst_indices) == len(self.dst_dims)
    assert self.dst_format is not None

  def emit_tensor_init(self) -> ir.RankedTensorType:
    """Returns an initialization for the destination tensor."""
    if self.dst_format.is_dense():
      # Initialize the dense tensor.
      ir_type = _mlir_type_from_taco_type(self.dst_dtype)
      tensor = linalg.InitTensorOp(self.dst_dims, ir_type).result
      zero = arith.ConstantOp(ir_type, 0.0)
      return linalg.FillOp(output=tensor, value=zero).results[0]

    # Initialize the sparse tensor.
    mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,
                                  self.dst_format.mlir_tensor_attr())
    index_type = ir.IndexType.get()
    dims = [arith.ConstantOp(index_type, d).result for d in mlir_type.shape]
    return sparse_tensor.InitOp(mlir_type, dims)


class _Stats:
  """Information to describe how a tensor expression is implemented.

  Currently, we only record the temporary tensors introduced for splitting the
  original expression.
  """

  def __init__(self):
    self._temps = []

  def __repr__(self) -> str:
    return f"_Stats({repr(self._temps)})"

  def add_element(self, structop: _StructOpInfo):
    """Adds a temporary tensor."""
    self._temps.append(structop)

  def get_total(self) -> int:
    """Gets the total number of temporary tensors."""
    return len(self._temps)

  def _get_element(self, idx: int) -> _StructOpInfo:
    """Gets the ith temporary tensor."""
    assert idx < self.get_total()
    return self._temps[idx]

  def get_dimensions(self, idx: int) -> Tuple[int]:
    """Gets the dimensions for the ith temporary tensor."""
    return self._get_element(idx).dst_dims

  def get_formats(self, idx: int) -> Tuple[ModeFormat]:
    """Gets the ModeFormats for the ith temporary tensor."""
    return tuple(self._get_element(idx).dst_format.format_pack.formats)


class Tensor:
  """The tensor class.

  We support the TACO API tensor class with an alias of this class.

  This class is part of the TACO API with the following methods:
    insert: Inserts a value to the given coordinate in the tensor.
    to_array: Returns a numpy ndarray for the tensor.

  TACO API also defines the following attributes for the class:
    dtype: A dtype object representing the data type of the tensor.
    format: A format object representing the storage format of the tensor.
    name: A string object representing the name of the tensor.
    order: An integral rank of the tensor.
    shape: A list of integers representing the shape of the tensor.

  We currently ignore the tensor dimension ordering for dense tensors.
  """
  _counter = _AtomicCounter()

  def _get_unique_name(self) -> str:
    """Returns a unique name for creating a new Tensor."""
    return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"

  def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat],
                                    Format]) -> None:
    """Processes the fmt argument for the Tensor constructor.

    Args:
      fmt: This argument can be a ModeFormat, List[ModeFormat], or Format. If
        this argument is a ModeFormat, uses this ModeFormat for all the tensor
        dimensions. If this argument is a list of ModeFormat, the length of the
        list should equal the rank of the tensor. If this argument is a
        Format, uses it for the format of the tensor.

    Raises:
      ValueError: If fmt is not one of the expected types or is inconsistent
        with the rank of the tensor. This is because fmt could be a user
        input.
    """
    if isinstance(fmt, ModeFormat):
      self._format = _make_format([fmt] * self.order)
    elif isinstance(fmt, list):
      if len(fmt) == self.order and isinstance(fmt[0], ModeFormat):
        self._format = _make_format(fmt)
      else:
        raise ValueError("Inconsistent shape and format: "
                         f"{self._shape}, {fmt}.")
    elif isinstance(fmt, Format):
      if fmt.rank() != self.order:
        raise ValueError("Inconsistent shape and format: "
                         f"{self._shape}, {fmt}.")
      else:
        self._format = fmt
    else:
      raise ValueError(f"Invalid format argument: {fmt}.")

  def __init__(self,
               value_or_shape: Optional[Union[List[int], Tuple[int, ...],
                                              float, int]] = None,
               fmt: Optional[Union[ModeFormat, List[ModeFormat],
                                   Format]] = None,
               dtype: Optional[DType] = None,
               name: Optional[str] = None):
    """The tensor constructor interface defined by the TACO API.

    Args:
      value_or_shape: This argument is optional and can be int, float,
        List[int], or Tuple[int, ...]. If this argument is an int or float,
        creates a scalar tensor and initializes it with the value. If this
        argument is a list or tuple of int, uses it as the shape to create a
        tensor.
      fmt: This argument can be a ModeFormat, List[ModeFormat], or Format. If
        this argument is a ModeFormat, uses this ModeFormat for all the tensor
        dimensions. If this argument is a list of ModeFormat, the length of the
        list should equal the rank of the tensor. If this argument is a
        Format, uses it for the format of the tensor.
      dtype: An object of dtype, representing the data type of the tensor.
      name: A string name of the tensor. If a name is not given, creates a
        unique name for the tensor.

    Raises:
      ValueError: If there is any inconsistency among the input arguments.
    """
    # Take care of the argument default values.
    fmt = fmt or ModeFormat.COMPRESSED
    dtype = dtype or DType(Type.FLOAT64)
    self._name = name or self._get_unique_name()
    self._dtype = dtype

    # We currently use _coords and _values to host the sparse tensor value with
    # COO format, and _dense_storage to host the dense tensor value. We haven't
    # implemented the conversion between the two storages yet. This will be
    # improved in a follow up CL.
    self._coords = []
    self._values = []
    self._dense_storage = None
    self._stats = _Stats()
    if value_or_shape is None or isinstance(value_or_shape, int) or isinstance(
        value_or_shape, float):
      # Create a scalar tensor and ignore the fmt parameter.
      self._shape = []
      self._format = _make_format([], [])
      if value_or_shape is not None:
        self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
    elif (isinstance(value_or_shape, tuple) or isinstance(
        value_or_shape, list)) and _all_instance_of(value_or_shape, int):
      # Create a tensor with the specified shape and format.
      self._shape = list(value_or_shape)
      self._init_format(fmt)
    else:
      raise ValueError("Invalid first argument. "
                       "Must be a tuple or list for a shape or a single value "
                       f"if initializing a scalar tensor: {value_or_shape}.")

  def __repr__(self) -> str:
    value_str = (f"{repr(self._dense_storage)})" if self.is_dense() else
                 f"{repr(self._coords)} {repr(self._values)})")
    return (f"Tensor(_name={repr(self._name)} "
            f"_dtype={repr(self._dtype)} : ") + value_str

  def insert(self, coords: List[int], val: Union[float, int]) -> None:
    """Inserts a value to the given coordinate.

    Args:
      coords: A list of integer coordinates. The length of the list must be the
        same as the rank of the tensor.
      val: A value being inserted. It is either an integral or a floating point
        value. This value will be converted to the data type of the tensor.

    Raises:
      ValueError: When there is any problem in the parameters.
    """
    if not isinstance(coords, list):
      raise ValueError(f"Non list coordinate detected: {coords}.")
    if not _all_instance_of(coords, int):
      raise ValueError(f"Non integer coordinate detected: {coords}.")
    if (len(coords) != self.order or
        any(c < 0 or c >= self._shape[i] for i, c in enumerate(coords))):
      raise ValueError("Invalid coordinate for rank: "
                       f"{self.order}, {coords}.")

    if not isinstance(val, int) and not isinstance(val, float):
      raise ValueError(f"Value is neither int nor float: {val}.")

    self._coords.append(tuple(coords))
    self._values.append(self._dtype.value(val))
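
  # A usage sketch (illustrative only): build a 2x3 tensor with the default
  # compressed format and stage one non-zero value in its COO storage.
  #
  #   t = Tensor([2, 3])
  #   t.insert([0, 1], 3.0)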

  def is_dense(self) -> bool:
    """Returns true if all the Tensor dimensions have a dense format."""
    return self._format.is_dense()

  def to_array(self) -> np.ndarray:
    """Returns the numpy array for the Tensor.

    This is currently only implemented for dense Tensors.
    """
    if not self.is_dense():
      raise ValueError("Conversion from non-dense Tensor "
                       "to numpy array not supported yet.")
    return self._dense_storage

  @staticmethod
  def from_array(array: np.ndarray) -> "Tensor":
    """Returns a dense tensor with the value copied from the input array.

    We currently only support the conversion of float64 numpy arrays to Tensor.

    Args:
      array: The numpy array that provides the data type, shape and value for
        the tensor.

    Returns:
      A Tensor object.

    Raises:
      ValueError: If the data type of the numpy array is not float64.
    """
    if array.dtype != np.float64:
      raise ValueError(f"Expected float64 value type: {array.dtype}.")
    tensor = Tensor(array.shape, ModeFormat.DENSE)
    tensor._dense_storage = np.copy(array)
    return tensor
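
  # A usage sketch (illustrative only): round-trip a float64 numpy array.
  #
  #   a = np.zeros((2, 3), np.float64)
  #   t = Tensor.from_array(a)
  #   t.to_array()  # -> the copied 2x3 array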

  @staticmethod
  def from_coo(
      coordinates: List[Tuple[int, ...]],
      values: List[_AnyRuntimeType],
      fmt: Format,
      dtype: DType,
  ) -> "Tensor":
    """Converts coordinates and values to a sparse tensor representation.

    Args:
      coordinates: A list of coordinates with non-zero values.
      values: The non-zero values.
      fmt: The tensor storage format.
      dtype: The tensor element data type.

    Returns:
      A tensor with the given non-zero values and storage format. The shape of
      the tensor has the minimum size for each dimension to make the given
      coordinates valid.
    """
    assert (isinstance(coordinates, List) and
            _all_instance_of(coordinates, Tuple))
    assert (isinstance(values, List) and _all_instance_of(values, dtype.value))
    assert isinstance(fmt, Format)

    rank = fmt.rank()
    assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)

    # Find the maximum coordinate value for each dimension.
    max_coordinate = list(map(max, zip(*coordinates)))
    # The size of each dimension is one more than such a maximum coordinate
    # value.
    shape = [c + 1 for c in max_coordinate]
    tensor = Tensor(shape, fmt)
    tensor._coords = coordinates
    tensor._values = values

    return tensor

  @property
  def dtype(self) -> DType:
    """Returns the data type for the Tensor."""
    return self._dtype

  @property
  def format(self) -> Format:
    """Returns the storage format for the Tensor."""
    return self._format

  @property
  def name(self) -> str:
    """Returns the name for the Tensor."""
    return self._name

  @property
  def order(self) -> int:
    """Returns the rank of the Tensor."""
    return len(self._shape)

  @property
  def shape(self) -> List[int]:
    """Returns the shape of the Tensor."""
    return self._shape

  def __getitem__(self, key) -> "Access":
    """Verifies and processes a tensor access.

    In the tensor index notation, a tensor access T[i, j] is represented as
    retrieving a value with key (i, j) from the tensor object T in Python. This
    routine verifies the key for the tensor access and returns a tensor access
    object.

    Args:
      key: The key used to access the tensor, which could be any Python object
        from user inputs.

    Returns:
      The corresponding tensor access object.

    Raises:
      ValueError: If key is not an IndexVar or a tuple of IndexVar.
    """
    indices = _verify_and_normalize_indices(key)
    return Access(self, indices)

  def __setitem__(self, key, value) -> None:
    """Verifies and processes a tensor assignment.

    In the tensor index notation, a tensor assignment "T[i, j] = ..." is
    represented as setting a value for a tensor object T via key (i, j) in
    Python. This routine verifies the key, evaluates the value, and assigns the
    value to the tensor.

    We only support assignment of dense tensors currently.

    Args:
      key: The key used to access the tensor, which could be any Python object
        from user inputs.
      value: The value assigned to the tensor, which could be any Python object
        from user inputs.

    Raises:
      ValueError: If the tensor is not a dense tensor, or the key is not an
        IndexVar or a tuple of IndexVar, or the length of the indices is not
        the same as the rank of the tensor.
    """
    indices = _verify_and_normalize_indices(key)
    if len(indices) != self.order:
      raise ValueError("Mismatch between indices and tensor rank: "
                       f"len({indices}) != {self.order}.")

    result = value.evaluate(self, indices)
    if self.is_dense():
      assert isinstance(result, np.ndarray)
      self._dense_storage = result
    else:
      assert _all_instance_of(result, np.ndarray) and len(result) == 2
      assert (result[0].ndim, result[1].ndim) == (1, 2)
      (self._values, self._coords) = result
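
  # A usage sketch (illustrative only; running the assignment requires the
  # supporting runtime library, see _SUPPORTLIB_ENV_VAR):
  #
  #   i, j = get_index_vars(2)
  #   A = Tensor.from_array(np.ones((2, 2), np.float64))
  #   B = Tensor([2, 2], [ModeFormat.DENSE, ModeFormat.DENSE])
  #   B[i, j] = A[i, j] + A[i, j]   # __setitem__ evaluates the expression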

  def mlir_tensor_type(self) -> ir.RankedTensorType:
    """Returns the MLIR type for the tensor."""
    return _mlir_tensor_type(self._dtype, tuple(self._shape),
                             self._format.mlir_tensor_attr())

  def dense_dst_ctype_pointer(self) -> ctypes.pointer:
    """Returns the ctypes pointer for the pointer to a MemRefDescriptor.

    For a dense tensor output, the MLIR compiler allocates the storage for
    the tensor. This routine returns the pointer to an MLIR MemRefDescriptor
    for receiving the tensor.
    """
    assert self.is_dense()
    mem_ref_desc = runtime.make_nd_memref_descriptor(
        self.order, np.ctypeslib.as_ctypes_type(self.dtype.value))()
    return ctypes.pointer(ctypes.pointer(mem_ref_desc))

  def ctype_pointer(self) -> ctypes.pointer:
    """Returns the ctypes pointer for the pointer to the input tensor."""
    if self.is_dense():
      if self._dense_storage is None:
        self._dense_storage = np.zeros(self._shape, self._dtype.value)
      return _ctype_pointer_from_array(self._dense_storage)

    shape = np.array(self._shape, np.int64)
    indices = np.array(self._coords, np.int64)
    values = np.array(self._values, self._dtype.value)
    ptr = utils.coo_tensor_to_sparse_tensor(_get_support_lib_name(), shape,
                                            values, indices)
    return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))

  def get_coordinates_and_values(
      self) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
    """Returns the coordinates and values for the non-zero elements."""
    if not self.is_dense():
      return (self._coords, self._values)

    # Coordinates for non-zero elements, grouped by dimensions.
    coords_by_dims = self._dense_storage.nonzero()
    # Coordinates for non-zero elements, grouped by elements.
    coords = np.transpose(coords_by_dims)
    values = self._dense_storage[coords_by_dims]
    return (coords, values)
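
  # A usage sketch (illustrative only): for a dense tensor, the non-zero
  # coordinates are recovered from the numpy storage.
  #
  #   t = Tensor.from_array(np.array([[0.0, 2.0]], np.float64))
  #   t.get_coordinates_and_values()  # -> (array([[0, 1]]), array([2.]))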

  def _record_stats(self, structop: "_StructOpInfo"):
    """Collects information for temporary tensors."""
    # Exclude user specified destination tensors.
    if structop.dst_name == self.name:
      return

    self._stats.add_element(structop)


def _emit_operand(op_def: lang.LinalgOpDef, indices: Tuple[IndexVar, ...],
                  name: str, kind: lang.OperandKind) -> lang.OperandDef:
  """Emits an operand for a tensor access in the current linalg operation.

  Args:
    op_def: A LinalgOpDef representing the current linalg dialect operation.
    indices: A tuple of IndexVar used to access the tensor.
    name: A unique string name of the tensor.
    kind: An OperandKind for the operand.

  Returns:
    An OperandDef representing the operand.
  """
  dim_sym = _mlir_symbols_from_index_vars(indices)
  opnd = lang.OperandDef(kind, lang.T, dim_sym)
  op_def.add_operand(name, opnd)
  return opnd


@dataclasses.dataclass(frozen=True)
class _DimInfo:
  """Information for an operand dimension.

  Attributes:
    dim: An integer for the size of the dimension.
    mode_format: A ModeFormat for the dimension sparsity.
  """
  dim: int
  mode_format: ModeFormat


@dataclasses.dataclass()
class _ExprInfo:
  """Expression information for validation and code generation.

  Attributes:
    src_indices: A tuple of IndexVar for the indices used by the tensors in the
      expression.
    dim_infos: A tuple of _DimInfo, representing the dimension information
      corresponding to the src_indices.
    reduce_indices: A set of IndexVar for the indices reduced by the expression.
    acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
      by the expression and its children.
    structop_info: Information to support the code generation for a structured
      op in the linalg dialect, if the corresponding expression node is the root
      of a subtree for a structured op.
    mlir_value: The MLIR value generated for the structured op.
  """
  src_indices: Tuple[IndexVar, ...]
  dim_infos: Tuple[_DimInfo, ...]
  reduce_indices: Optional[Set[IndexVar]] = None
  acc_reduce_indices: Optional[Set[IndexVar]] = None
  structop_info: Optional[_StructOpInfo] = None
  mlir_value: Optional[ir.Value] = None

  def __post_init__(self) -> None:
    """Verifies and fixes up attribute values.

    Verifies the consistency of the attributes and modifies the default values
    to support convenient initializer syntax.
    """
    assert len(self.src_indices) == len(self.dim_infos)
    self.reduce_indices = self.reduce_indices or set()
    self.acc_reduce_indices = self.acc_reduce_indices or set()


class IndexExpr(abc.ABC):
  """The index notation base class.

  We support the TACO API index_expression class with an alias of this class.
  """

  def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
    """Verifies the RHS operand and returns a binary expression.

    Args:
      rhs: The RHS of the binary operation, which could be any Python object
        from user inputs.
      op: A _BinaryOp object representing the binary operator.

    Raises:
      ValueError: If rhs is not an IndexExpr.
    """
    if not isinstance(rhs, IndexExpr):
      raise ValueError(f"Expected IndexExpr: {rhs}")
    return _BinaryExpr(op, self, rhs)

  def __add__(self, rhs) -> "_BinaryExpr":
    """Defines the operator +.

    Args:
      rhs: The value being added, which could be any Python object from user
        inputs.

    Returns:
      A _BinaryExpr object representing the operation.

    Raises:
      ValueError: If rhs is not an IndexExpr.
    """
    return self._verify_operand_and_build_expr(rhs, operator.add)

  def __mul__(self, rhs) -> "_BinaryExpr":
    """Defines the operator *.

    Args:
      rhs: The value being multiplied, which could be any Python object from
        user inputs.

    Returns:
      A _BinaryExpr object representing the operation.

    Raises:
      ValueError: If rhs is not an IndexExpr.
    """
    return self._verify_operand_and_build_expr(rhs, operator.mul)

  def __sub__(self, rhs) -> "_BinaryExpr":
    """Defines the operator -.

    Args:
      rhs: The value being subtracted, which could be any Python object from
        user inputs.

    Returns:
      A _BinaryExpr object representing the operation.

    Raises:
      ValueError: If rhs is not an IndexExpr.
    """
    return self._verify_operand_and_build_expr(rhs, operator.sub)

  @abc.abstractmethod
  def _visit(self,
             func: _ExprVisitor,
             args,
             *,
             leaf_checker: _SubtreeLeafChecker = None) -> None:
    """A post-order visitor.

    Args:
      func: A callable applied to each node in the expression tree.
      args: The variable-length arguments passed to the callable. These
        arguments are grouped as an iterable and will be unpacked before
        passing to the callable. This is to enable the keyword argument only
        syntax after this argument.
      leaf_checker: A callable object to identify nodes that should be treated
        as leaf nodes to support partial tree visiting.
    """

  @abc.abstractmethod
  def _emit_expression(
      self,
      expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
      expr_to_info: _ExprInfoDict,
  ) -> lang.ScalarExpression:
    """Emits MLIR for the expression tree.

    Args:
      expr_to_opnd: A dictionary for looking up structured op input operands
        for the input nodes of the structured op.
      expr_to_info: A dictionary for looking up code generation information
        for expressions.

    Returns:
      A linalg dialect ScalarExpression for the expression.
    """

  @abc.abstractmethod
  def dtype(self) -> DType:
    """Returns the data type for the result of the expression."""

  def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
    """Emits a structured op in the linalg dialect for the expression tree.

    We define a DefinedOpCallable in the domain specific language for the
    linalg dialect and execute the callable to generate the structured op.
    Self is the root of the expression tree for the structured op.

    Args:
      expr_to_info: A dictionary for looking up code generation information
        for expressions.
    """
    op_info = expr_to_info[self].structop_info
    op_name = op_info.dst_name
    op_def = lang.LinalgOpDef(name=op_name)
    op_callable = lang.DefinedOpCallable(op_name, op_def)

    # Collect the input expression nodes for the structured op.
    expr_inputs = []
    self._visit(
        _gather_structured_op_input,
        (self, expr_to_info, expr_inputs),
        leaf_checker=_is_structured_op_leaf,
    )

    # Create a linalg structured op operand for each input expression node and
    # build a dictionary for looking up the information.
    expr_to_input_opnd = {
        e: _emit_structured_op_input(e, expr_to_info, op_def)
        for e in expr_inputs
    }

    # Emit the expression tree, which produces the value assigned to the
    # destination tensor.
    value = self._emit_expression(expr_to_input_opnd, expr_to_info)
    # Emit the structured op representation for the destination tensor.
    dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
                             lang.OperandKind.OutputTensor)
    dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
    dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)

    expr_info = expr_to_info[self]
    # If the structured op reduces some indices, explicitly represent the
    # reduction. This is done by generating a ReduceFn for the dimensions being
    # reduced in the linalg dialect and calling the function with the value
    # being reduced. We only support add reduction currently.
    if expr_info.reduce_indices:
      reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
      value = lang.ReduceFn.add[reduce_dims](value)

    # Emit the assignment as a comprehension in the linalg dialect.
    comp = lang.Comprehension((dst_use, value))
    op_def.comprehensions.append(comp)

    # The structured op in the linalg dialect requires an explicit
    # initialization for the destination tensor. Emit MLIR to initialize the
    # destination tensor.
    init = op_info.emit_tensor_init()

    # Collect MLIR values for the linalg input operands, with the assumption
    # that the dictionary preserves the insertion order.
    args = [
        expr_to_info[expr].mlir_value
        for expr, opnd in expr_to_input_opnd.items()
    ]
    # Execute the DefinedOpCallable object for the linalg dialect operation to
    # emit MLIR for the linalg structured op.
    expr_info.mlir_value = op_callable(*args, outs=[init])

  def _identify_structured_ops(
      self,
      expr_to_info: _ExprInfoDict,
      dst: "Tensor",
      dst_indices: Tuple[IndexVar, ...],
  ) -> List["IndexExpr"]:
    """Returns expression nodes for the roots of the identified structured ops.

    A structured op in the linalg dialect only supports reduction performed on
    the whole expression. If the expression tree contains reductions that are
    performed on part of the expression tree, the expression tree needs to be
    implemented with multiple structured ops. This routine identifies all the
    expression nodes that contain reduction as the roots of structured ops in
    the expression tree.

    Args:
      expr_to_info: A dictionary for looking up code generation information
        for expressions.
      dst: A destination Tensor that accepts the value of the expression tree.
      dst_indices: The indices used by the destination index expression.

    Returns:
      An ordered list of IndexExpr for the root expressions of the structured
      ops, where child expressions go before parent expressions that use their
      results.
    """
    reduce_indices = tuple(
        set(expr_to_info[self].src_indices) - set(dst_indices))
    for reduce_index in reduce_indices:
      _mark_structured_op_root(self, reduce_index, expr_to_info)

    self._visit(_accumulate_reduce_indices, (expr_to_info,))

    structop_roots = []
    self._visit(_gather_structured_op, (expr_to_info, structop_roots))

    # Handle the root of the top level expression.
    if not structop_roots or structop_roots[-1] != self:
      # The top level expression is not a reduction. Add the top level
      # expression as a structured op root.
      structop_roots.append(self)

    # Use user specified information for the destination tensor to build an
    # _StructOpInfo for the top level expression.
    expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
                                                     tuple(dst.shape),
                                                     self.dtype(), dst.name,
                                                     dst.format)

    return structop_roots

  def _validate_and_collect_expr_info(
      self,
      dst: "Tensor",
      dst_indices: Tuple[IndexVar, ...],
  ) -> _ExprInfoDict:
    """Propagates expression information for validation.

    Propagates the indices used by child expression nodes to parent expression
    nodes. Also collects and validates the sizes for the dimensions
    corresponding to the indices.

    Args:
      dst: A destination Tensor that accepts the value of the expression tree.
      dst_indices: The indices used by the destination index expression.

    Raises:
      ValueError: If there is any inconsistency in indices or dimensional
        values.

    Returns:
      A dictionary of (IndexExpr, _ExprInfo).
    """
    expr_to_info = {}
    # Validate the expression tree and construct expression information.
    self._visit(_validate_and_collect_expr_info, (expr_to_info,))

    # Validate the destination dimension information.
    info = expr_to_info[self]
    index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
    for i, d in zip(dst_indices, dst.shape):
      if i not in index_to_dim_info:
        raise ValueError("Destination IndexVar not used in the "
                         f"source expression: {i}")
      if d != index_to_dim_info[i].dim:
        raise ValueError(f"Inconsistent destination dimension for {i}: "
                         f"{d} vs {index_to_dim_info[i].dim}")

    return expr_to_info

  def _emit_assignment(
      self,
      module: ir.Module,
      dst: "Tensor",
      dst_indices: Tuple[IndexVar, ...],
      expr_to_info: _ExprInfoDict,
      input_accesses: List["Access"],
  ) -> None:
    """Emits an MLIR function for assigning the expression to a tensor."""
    input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]

    # Build the kernel for the operations.
    with ir.InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
      def linalg_funcop(*args):
        # Set up the mapping from the Access nodes to their MLIR values.
        for e, mlir in zip(input_accesses, args):
          expr_to_info[e].mlir_value = mlir

        # Emit structured ops in the linalg dialect to implement the
        # assignment.
        for structop_root in self._identify_structured_ops(
            expr_to_info, dst, dst_indices):
          structop_root._emit_structured_op(expr_to_info)
          dst._record_stats(expr_to_info[structop_root].structop_info)

        # The function returns the MLIR value of the root expression.
        return expr_to_info[self].mlir_value

      linalg_funcop.func_op.attributes[
          "llvm.emit_c_interface"] = ir.UnitAttr.get()

  def evaluate(
      self,
      dst: "Tensor",
      dst_indices: Tuple[IndexVar, ...],
  ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Evaluates the tensor assignment dst[dst_indices] = expression.

    Args:
      dst: The destination tensor.
      dst_indices: The tuple of IndexVar used to access the destination tensor.

    Returns:
      The result of the dense tensor represented in a numpy ndarray, or the
      sparse tensor represented by two numpy ndarray for its non-zero values
      and indices.

    Raises:
      ValueError: If the expression is not proper or not supported.
    """
    expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)

    # Compute a list of input accesses.
    input_accesses = []
    self._visit(_gather_input_accesses_index_vars, (input_accesses,))

    support_lib = _get_support_lib_name()
    # Build and compile the module to produce the execution engine.
    with ir.Context(), ir.Location.unknown():
      module = ir.Module.create()
      self._emit_assignment(module, dst, dst_indices, expr_to_info,
                            input_accesses)
      compiled_module = _compile_mlir(module)

      # We currently rely on an environment variable to pass in the full path
      # of a supporting library for the execution engine.
      engine = execution_engine.ExecutionEngine(
          compiled_module, opt_level=_OPT_LEVEL, shared_libs=[support_lib])

      # Gather the pointers for the input buffers.
      input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
      if dst.is_dense():
        # The pointer to receive dense output is the first argument to the
        # executed function.
        arg_pointers = [dst.dense_dst_ctype_pointer()] + input_pointers
      else:
        # The pointer to receive sparse output is the last argument to the
        # execution engine. The pointer to receive a sparse tensor output is a
        # pointer to pointer of char.
        arg_pointers = input_pointers + [
            ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
        ]

      # Invoke the execution engine to run the module and return the result.
      engine.invoke(_ENTRY_NAME, *arg_pointers)

      if dst.is_dense():
        return runtime.ranked_memref_to_numpy(arg_pointers[0][0])

      # Check and return the sparse tensor output.
      rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
          support_lib, ctypes.cast(arg_pointers[-1][0], ctypes.c_void_p),
          np.float64)
      assert (np.equal(rank, dst.order) and
              np.array_equal(shape, np.array(dst.shape)) and
              np.equal(values.ndim, 1) and np.equal(values.shape[0], nse) and
              np.equal(indices.ndim, 2) and np.equal(indices.shape[0], nse) and
              np.equal(indices.shape[1], rank))
      return (values, indices)
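

# An end-to-end sketch of evaluate (illustrative only; requires the MLIR
# Python bindings and the supporting runtime library):
#
#   i, j, k = get_index_vars(3)
#   A = Tensor.from_array(np.ones((2, 2), np.float64))
#   B = Tensor.from_array(np.ones((2, 2), np.float64))
#   C = Tensor([2, 2], [ModeFormat.DENSE, ModeFormat.DENSE])
#   C[i, j] = A[i, k] * B[k, j]  # JIT-compiles and runs a matmul kernel
#   C.to_array()                 # -> a 2x2 array filled with 2.0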


@dataclasses.dataclass(frozen=True)
class Access(IndexExpr):
  """The tensor access class.

  We support the TACO API access class with an alias of this class.

  Attributes:
    tensor: A Tensor being accessed.
    indices: A tuple of IndexVar, representing the indices used to access the
      Tensor.
  """
  tensor: Tensor
  indices: Tuple[IndexVar, ...]

  def __post_init__(self) -> None:
    """Verifies the tensor and indices for a tensor access.

    Raises:
      ValueError: If indices is not a tuple of IndexVar or the length of
        indices doesn't equal the rank of the tensor.
    """
    if (not isinstance(self.indices, tuple) or
        not _all_instance_of(self.indices, IndexVar)):
      raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
    if self.tensor.order != len(self.indices):
      raise ValueError("Invalid indices for rank: "
                       f"{self.tensor.order} != len({str(self.indices)}).")

  def _emit_expression(
      self,
      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
      expr_to_info: _ExprInfoDict,
  ) -> lang.ScalarExpression:
    """Emits a linalg dialect TensorUse expression for the tensor access."""
    assert self in expr_to_opnd
    dims = _mlir_dimensions_from_index_vars(self.indices)
    return lang.TensorUse(expr_to_opnd[self], dims)

  def _visit(self,
             func: _ExprVisitor,
             args,
             *,
             leaf_checker: _SubtreeLeafChecker = None) -> None:
    if leaf_checker:
      assert leaf_checker(self, *args)
    func(self, *args)

  def dtype(self) -> DType:
    return self.tensor.dtype


def _gather_input_accesses_index_vars(
    expr: IndexExpr,
    input_accesses: List[Access],
) -> None:
  """Collects Access nodes."""
  if isinstance(expr, Access) and expr not in input_accesses:
    input_accesses.append(expr)


def _op_to_callable(op: _BinaryOp) -> lang.ArithFnType:
  """Returns the linalg dialect function object for the given operator."""
  op_to_callable = {
      operator.add: lang.ArithFn.add,
      operator.sub: lang.ArithFn.sub,
      operator.mul: lang.ArithFn.mul,
  }
  return op_to_callable[op]


@dataclasses.dataclass(frozen=True)
class _BinaryExpr(IndexExpr):
  """The representation for a binary operation.

  Attributes:
    op: A _BinaryOp representing the binary operation.
    a: An IndexExpr representing the first operand of the operation.
    b: An IndexExpr representing the second operand of the operation.
  """
  op: _BinaryOp
  a: IndexExpr
  b: IndexExpr

  def __post_init__(self) -> None:
    """Verifies that the operands are IndexExpr."""
    assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr)

  def _emit_expression(
      self,
      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
      expr_to_info: _ExprInfoDict,
  ) -> lang.ScalarExpression:
    """Emits the expression tree and returns the expression."""
    # The current expression node is an internal node of the structured op.
    if self not in expr_to_opnd:
      a = self.a._emit_expression(expr_to_opnd, expr_to_info)
      b = self.b._emit_expression(expr_to_opnd, expr_to_info)
      return _op_to_callable(self.op)(a, b)

    # The current expression is a leaf node of the structured op. That is, it
    # is a temporary tensor generated by its child structured op.
    op_info = expr_to_info[self].structop_info
    assert op_info is not None
    dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
    return lang.TensorUse(expr_to_opnd[self], dims)

  def _visit(self,
             func: _ExprVisitor,
             args,
             *,
             leaf_checker: _SubtreeLeafChecker = None) -> None:
    """A post-order visitor."""
    if leaf_checker is None or not leaf_checker(self, *args):
      self.a._visit(func, args, leaf_checker=leaf_checker)
      self.b._visit(func, args, leaf_checker=leaf_checker)
    func(self, *args)

  def dtype(self) -> DType:
    """Returns the data type of the binary operation."""
    return self.a.dtype()


def _validate_and_collect_dim_info(
    index_to_dim_info: Dict[IndexVar, _DimInfo],
    indices: Tuple[IndexVar, ...],
    dim_infos: Tuple[_DimInfo, ...],
    expr: "_BinaryExpr",
) -> None:
  """Validates and collects the dimension information for an index notation.

  Validates (indices, dim_infos) against the information collected from other
  source operands and represented by index_to_dim_info. In particular, we
  ensure that each IndexVar corresponds to only one dimension size. We also
  aggregate the new information represented in (indices, dim_infos) into
  index_to_dim_info.

  Args:
    index_to_dim_info: A dictionary of (IndexVar, _DimInfo) collected from the
      other source operands.
    indices: The IndexVars to be validated.
    dim_infos: The dimension information for the IndexVars to be validated.
    expr: The binary expression where (indices, dim_infos) is used.

  Raises:
    ValueError: If there is any problem in the IndexVars or dimensional values.
  """
  assert len(indices) == len(dim_infos)
  for i, d in zip(indices, dim_infos):
    if i not in index_to_dim_info:
      index_to_dim_info[i] = d
    else:
      if d.dim != index_to_dim_info[i].dim:
        raise ValueError(f"Inconsistent source dimension for {i}: "
                         f"{d.dim} vs {index_to_dim_info[i].dim}")
      mode_format = _mode_format_estimator(expr.op)(
          index_to_dim_info[i].mode_format, d.mode_format)
      index_to_dim_info[i] = _DimInfo(d.dim, mode_format)


def _validate_and_collect_expr_info(
    expr: IndexExpr,
    expr_to_info: _ExprInfoDict,
) -> None:
  """Validates dimension information and constructs _ExprInfo.

  Validates that dimensional values for the same IndexVar are the same. Collects
  a list of IndexVar used by the expression and their corresponding dimensional
  values. Constructs an _ExprInfo object to record the information for the
  IndexExpr.

  This routine is passed to the post-order visitor as an _ExprVisitor object.

  Args:
    expr: The IndexExpr being validated.
    expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
      expression information.

  Raises:
    ValueError: If there is any problem in the IndexVars or dimensional values.
  """
  # Objects of class Access can be shared by different expressions. Avoid
  # processing Access objects multiple times by skipping the processing if expr
  # is already in the dictionary.
  if expr in expr_to_info:
    return

  if isinstance(expr, Access):
    src_indices = expr.indices
    src_dims = tuple(expr.tensor.shape)
    mode_formats = tuple(expr.tensor.format.format_pack.formats)
    assert len(src_dims) == len(mode_formats)
    dim_infos = tuple(_DimInfo(d, m) for d, m in zip(src_dims, mode_formats))
  else:
    assert isinstance(expr, _BinaryExpr)
    a_info = expr_to_info[expr.a]
    index_to_dim_info = {
        i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
    }
    b_info = expr_to_info[expr.b]
    _validate_and_collect_dim_info(index_to_dim_info, b_info.src_indices,
                                   b_info.dim_infos, expr)
    # Here we rely on the fact that dictionaries keep the insertion order for
    # keys and values.
    src_indices = tuple(index_to_dim_info.keys())
    dim_infos = tuple(index_to_dim_info.values())

  expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)


def _mark_structured_op_root(
    expr: IndexExpr,
    reduce_index: IndexVar,
    expr_to_info: _ExprInfoDict,
) -> None:
  """Identifies the root expression for a structured op in the linalg dialect.

  A linalg structured op can only perform a reduction on the whole expression.
  For a TACO tensor algebra expression, the reduction on an IndexVar is done at
  the smallest expression that contains all the uses of the IndexVar. If such
  an expression is only part of the whole expression, we need to split this
  sub-expression tree out from its parent and implement the sub-expression as a
  standalone structured op.

  This routine identifies the root expression node for performing a reduction
  on the given IndexVar. If the reduction of the given IndexVar should be
  performed on expression X, then the IndexVar is added to
  expr_to_info[X].reduce_indices.

  Args:
    expr: The root IndexExpr for the tensor algebra expression.
    reduce_index: The IndexVar for which we want to find the proper expression
      to perform a reduction.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
  """
  assert isinstance(expr, _BinaryExpr)
  a_info = expr_to_info[expr.a]
  b_info = expr_to_info[expr.b]
  expr_info = expr_to_info[expr]

  if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
    expr_info.reduce_indices.add(reduce_index)
    return

  if reduce_index in a_info.src_indices:
    _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
  elif reduce_index in b_info.src_indices:
    _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
  else:
    assert False, "Unreachable path"
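

# A worked sketch (illustrative only): for A[i] = B[i, j] * C[j] + D[i], the
# IndexVar j is used by both operands of the multiplication but not by D[i],
# so the multiplication node becomes the reduction root for j and is split
# out as its own structured op whose temporary result feeds the addition.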


def _accumulate_reduce_indices(
    expr: IndexExpr,
    expr_to_info: _ExprInfoDict,
) -> None:
  """Propagates reduction indices from child expressions to parent expressions.

  This routine is passed to the post-order visitor as an _ExprVisitor object.

  Args:
    expr: The IndexExpr being visited.
    expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
      expression information.
  """
  assert expr in expr_to_info
  expr_info = expr_to_info[expr]

  if isinstance(expr, _BinaryExpr):
    a_info = expr_to_info[expr.a]
    b_info = expr_to_info[expr.b]
    expr_info.acc_reduce_indices = (
        a_info.acc_reduce_indices | b_info.acc_reduce_indices
        | expr_info.reduce_indices)
  else:
    assert isinstance(expr, Access)


def _gather_structured_op(
    expr: IndexExpr,
    expr_to_info: _ExprInfoDict,
    structop_roots: List[IndexExpr],
) -> None:
  """Adds structured op root expression information to structop_roots.

  This routine is passed to the post-order visitor as an _ExprVisitor object.

  Args:
    expr: The IndexExpr being visited.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
    structop_roots: The resulting list of IndexExpr that are the roots for
      linalg structured ops.
  """
  if not expr_to_info[expr].reduce_indices:
    return

  # If the expression is the root for reducing some indices, collect the
  # indices and dimensions for the reduction result.
  dst_indices = []
  dst_dims = []
  mode_fmts = []
  for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos):
    if i not in expr_to_info[expr].acc_reduce_indices:
      dst_indices.append(i)
      dst_dims.append(d.dim)
      mode_fmts.append(d.mode_format)

  # Add the information to the dictionary.
  op_info = _StructOpInfo(
      tuple(dst_indices),
      tuple(dst_dims),
      expr.dtype(),
      f"temp{len(structop_roots)}",
      _make_format(mode_fmts),
  )
  expr_to_info[expr].structop_info = op_info

  # Add the expression to the list of structured op roots.
  structop_roots.append(expr)


def _is_structured_op_leaf(
    expr: IndexExpr,
    root: IndexExpr,
    expr_to_info: _ExprInfoDict,
    *unused_args,
) -> bool:
  """Returns true iff the expression is a leaf node for a structured op.

  The root of a structured op is a leaf of its parent structured op that uses
  its result. An expression node is a leaf node for the current structured op
  if it is an Access node or the root for a structured op that is not the
  current structured op.

  This routine is passed to the post-order visitor as a _SubtreeLeafChecker
  object. Because the post-order visitor passes the same parameters to both
  _SubtreeLeafChecker and _ExprVisitor, this routine may receive unused
  parameters.

  Args:
    expr: The IndexExpr being visited.
    root: The root of the current structured op.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.

  Returns:
    True if the current IndexExpr is a leaf for the current structured op.
  """
  return (expr != root and
          expr_to_info[expr].structop_info is not None) or isinstance(
              expr, Access)


def _gather_structured_op_input(
    expr: IndexExpr,
    root: IndexExpr,
    expr_to_info: _ExprInfoDict,
    structop_inputs: List[IndexExpr],
) -> None:
  """Adds the IndexExpr to structop_inputs if it is an input.

  If the current IndexExpr is an input for the current structured op, adds it
  to structop_inputs. The current IndexExpr is an input if it is an Access node
  or if it is the root for a structured op that is not the current structured
  op.

  This routine is passed to the post-order visitor as an _ExprVisitor object.

  Args:
    expr: The IndexExpr being visited.
    root: The root of the current structured op.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
    structop_inputs: The resulting list of IndexExpr that provide input to the
      current structured op.
  """
  if (expr != root and expr not in structop_inputs) and (
      isinstance(expr, Access) or
      (expr in expr_to_info and expr_to_info[expr].structop_info)):
    structop_inputs.append(expr)


def _emit_structured_op_input(
    expr: IndexExpr,
    expr_to_info: _ExprInfoDict,
    op_def: lang.LinalgOpDef,
) -> lang.OperandDef:
  """Emits an OperandDef in the linalg dialect for the input IndexExpr.

  Args:
    expr: The input IndexExpr for the current structured op.
    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
    op_def: The linalg operation for the current structured op.

  Returns:
    An OperandDef in the linalg dialect for the input IndexExpr.
  """
  op_info = expr_to_info[expr].structop_info
  if op_info:
    # The input is a temporary tensor produced by another structured op.
    indices = op_info.dst_indices
    name = op_info.dst_name
  else:
    # The input is a user provided tensor.
    assert isinstance(expr, Access)
    indices = expr.indices
    name = expr.tensor.name

  dim_sym = _mlir_symbols_from_index_vars(indices)
  opnd = lang.OperandDef(lang.OperandKind.InputTensor, lang.T, dim_sym)
  op_def.add_operand(name, opnd)
  return opnd