[AMDGPU] Make v8i16/v8f16 legal
[llvm-project.git] / mlir / test / Integration / Dialect / SparseTensor / taco / tools / mlir_pytaco.py
blobf64d34037eabd37af524e1d8231981b304909380
1 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2 # See https://llvm.org/LICENSE.txt for license information.
3 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 """Experimental MLIR-PyTACO with sparse tensor support.
7 See http://tensor-compiler.org/ for TACO tensor compiler.
9 This module implements the Python classes for PyTACO index notation. These
10 include classes for data types, tensor dimension formats (aka mode formats),
11 tensor dimension orderings (aka mode ordering), tensor storage formats, and
12 tensors.
14 The PyTACO API doesn't follow the naming conversion required by the style guide
15 for this module. As such, we first implement the supporting classes and routines
16 following the style guide, and then define the type aliases and constants to
17 support the PyTACO API in the pytaco_api module.
18 """
20 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
22 import abc
23 import ctypes
24 import dataclasses
25 import enum
26 import numpy as np
27 import functools
28 import operator
29 import os
30 import threading
32 # Import MLIR related modules.
33 from mlir import all_passes_registration # Register MLIR compiler passes.
34 from mlir import execution_engine
35 from mlir import ir
36 from mlir import runtime
37 from mlir.dialects import arith
38 from mlir.dialects import builtin
39 from mlir.dialects import linalg
40 from mlir.dialects import std
41 from mlir.dialects import sparse_tensor
42 from mlir.dialects.linalg.opdsl import lang
43 from mlir.passmanager import PassManager
45 from . import mlir_pytaco_utils as utils
47 # TACO naming prefixes.
48 _TACO_INDEX_PREFIX = "i"
49 _TACO_TENSOR_PREFIX = "A"
51 # Bitwidths for pointers and indices.
52 _POINTER_BIT_WIDTH = 0
53 _INDEX_BIT_WIDTH = 0
54 # The name for the environment variable that provides the full path for the
55 # supporting library.
56 _SUPPORTLIB_ENV_VAR = "SUPPORTLIB"
57 # The default supporting library if the environment variable is not provided.
58 _DEFAULT_SUPPORTLIB = "libmlir_c_runner_utils.so"
59 # The JIT compiler optimization level.
60 _OPT_LEVEL = 2
61 # The entry point to the JIT compiled program.
62 _ENTRY_NAME = "main"
64 # Type aliases for type annotation.
65 _BinaryOp = Callable[[Any, Any], Any]
66 _ExprVisitor = Callable[..., None]
67 _ExprInfoDict = Dict["IndexExpr", "_ExprInfo"]
68 _LogicalOp = Callable[[bool, bool], bool]
69 _ModeFormatOp = Callable[["ModeFormat", "ModeFormat"], "ModeFormat"]
70 _SubtreeLeafChecker = Optional[Callable[..., bool]]
73 class Type(enum.Enum):
74 """The data types supported by TACO.
76 We use numpy data types to implement the enum data types.
77 """
78 INT16 = np.int16
79 INT32 = np.int32
80 INT64 = np.int64
81 # numpy _ctype_from_dtype_scalar can't handle np.float16 yet.
82 FLOAT32 = np.float32
83 FLOAT64 = np.float64
86 # All floating point type enums.
87 _FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64)
88 # All integral type enums.
89 _INT_TYPES = (Type.INT16, Type.INT32, Type.INT64)
90 # Type alias for any numpy type used to implement the runtime support for the
91 # enum data types.
92 _AnyRuntimeType = Union[np.int16, np.int32, np.int64, np.float32, np.float64]
95 @dataclasses.dataclass(frozen=True)
96 class DType:
97 """The data type class.
99 We support the TACO API dtype class with an alias of this class.
101 The following methods are defined by the TACO API:
102 is_float: Returns whether the data type represents a floating point value.
103 is_int: Returns whether the data type represents an integral value.
105 Attributes:
106 kind: A Type enum representing the data type.
107 value: The numpy data type for the TACO data type.
109 kind: Type = Type.FLOAT64
111 def is_float(self) -> bool:
112 """Returns whether the data type represents a floating point value."""
113 return self.kind in _FLOAT_TYPES
115 def is_int(self) -> bool:
116 """Returns whether the data type represents an integral value."""
117 return self.kind in _INT_TYPES
119 @property
120 def value(self) -> _AnyRuntimeType:
121 """Returns the numpy dtype for the data type."""
122 return self.kind.value
125 def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
126 """Returns the MLIR type corresponding to the given TACO type."""
127 dtype_to_irtype = {
128 Type.INT16: ir.IntegerType.get_signless(16),
129 Type.INT32: ir.IntegerType.get_signless(32),
130 Type.INT64: ir.IntegerType.get_signless(64),
131 Type.FLOAT32: ir.F32Type.get(),
132 Type.FLOAT64: ir.F64Type.get()
134 return dtype_to_irtype[dtype.kind]
137 def _compile_mlir(module: ir.Module) -> ir.Module:
138 """Compiles an MLIR module and returns the compiled module."""
139 # TODO: Replace this with a pipeline implemented for
140 # https://github.com/llvm/llvm-project/issues/51751.
141 pipeline = (
142 f"sparsification,"
143 f"sparse-tensor-conversion,"
144 f"builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
145 f"convert-scf-to-std,"
146 f"func-bufferize,"
147 f"tensor-constant-bufferize,"
148 f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
149 f"convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},"
150 f"lower-affine,"
151 f"convert-memref-to-llvm,"
152 f"convert-std-to-llvm,"
153 f"reconcile-unrealized-casts")
154 PassManager.parse(pipeline).run(module)
155 return module
158 @functools.lru_cache()
159 def _get_support_lib_name() -> str:
160 """Returns the string for the supporting C shared library."""
161 return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
164 def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer:
165 """Returns the ctype pointer for the given numpy array."""
166 return ctypes.pointer(
167 ctypes.pointer(runtime.get_ranked_memref_descriptor(array)))
170 class ModeFormat(enum.Enum):
171 """The tensor dimension storage format class.
173 We support the TACO API mode_format class with an alias of this class.
175 In TACO, a tensor dimension is called a mode and the storage format for a
176 tensor dimension is called a mode format.
178 DENSE = sparse_tensor.DimLevelType.dense
179 COMPRESSED = sparse_tensor.DimLevelType.compressed
182 def _mode_format_operation(a: ModeFormat, b: ModeFormat,
183 op: _LogicalOp) -> ModeFormat:
184 """Implements the given operator on ModeFormat."""
185 return (ModeFormat.COMPRESSED
186 if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED) else
187 ModeFormat.DENSE)
190 def _mode_format_estimator(op: _BinaryOp) -> _ModeFormatOp:
191 """Produces a ModeFormat operator for the given binary operator.
193 The ModeFormat operator is used as a heuristic to derive the destination
194 dimension sparsity from the source dimension sparsity. In particular, if the
195 binary operator produces a disjunction of the zero values from its source
196 operands, such as the MUL operator, we return a ModeFormat operator that
197 uses operator.or_. That is, we estimate that a dimension for the MUL
198 operation result to be sparse if either of its source operands is sparse.
200 On the other hand, if the binary operator produces a conjunction of the
201 zero values from its source operands, such as the ADD operator, we return
202 a ModeFormat operator that uses operator.and_. In this case, we estimate
203 that a dimension for the ADD operation result to be sparse if both of its
204 source operands are sparse.
206 Args:
207 op: A _BinaryOp object representing a supporting operator on tensors.
209 Returns:
210 A ModeFormatOp for estimating the destination dimension sparsity from
211 the source dimension sparsity.
213 conjunction = functools.partial(_mode_format_operation, op=operator.and_)
214 disjunction = functools.partial(_mode_format_operation, op=operator.or_)
215 return conjunction if op(0, 1) != 0 else disjunction
218 def _all_instance_of(collection: Iterable, cls: Any) -> bool:
219 """Returns true if all elements of the iterable is an instance of cls."""
220 return all(isinstance(e, cls) for e in collection)
223 def _identity_ordering(rank: int) -> List[int]:
224 """Returns the identity ordering for tensor of given rank."""
225 return list(range(rank))
228 @dataclasses.dataclass(frozen=True)
229 class ModeOrdering:
230 """The tensor dimension ordering class.
232 We support the TACO API mode_ordering class with an alias of this class.
234 Attributes:
235 ordering: A list of integers representing the ordering of the tensor
236 dimensions.
238 ordering: List[int]
240 def __post_init__(self) -> None:
241 """Verifies the value in ordering.
243 Raises:
244 ValueError: If ordering is not a list of integers.
246 if (not isinstance(self.ordering, list) or
247 not _all_instance_of(self.ordering, int)):
248 raise ValueError("Ordering must be a list of integers: "
249 f"{self.ordering}")
250 # Check that ordering is a permutation of the dimension numbers.
251 if sorted(self.ordering) != _identity_ordering(self.rank()):
252 raise ValueError(f"Invalid ordering: {self.ordering} != "
253 f"permutation{_identity_ordering(self.rank())}.")
255 def rank(self) -> int:
256 """Returns the number of dimensions represented by the ordering."""
257 return len(self.ordering)
260 @dataclasses.dataclass(frozen=True)
261 class ModeFormatPack:
262 """The tensor dimension format class.
264 We support the TACO API mode_format_pack class with an alias of this class.
266 The storage format of a tensor contains one mode_format for each tensor
267 dimension.
269 Attributes:
270 formats: A list of ModeFormat representing the storage format for each of
271 the tensor dimension.
273 formats: List[ModeFormat]
275 def __post_init__(self) -> None:
276 """Verifies the value in formats.
278 Raises:
279 ValueError: If formats is not a list of ModeFormats.
281 if (not isinstance(self.formats, list) or
282 not _all_instance_of(self.formats, ModeFormat)):
283 raise ValueError("Formats must be a list of ModeFormat: "
284 f"{self.formats}")
286 def rank(self) -> int:
287 """Returns the number of dimensions represented by the format pack."""
288 return len(self.formats)
291 @dataclasses.dataclass
292 class Format:
293 """The tensor format class defined by the TACO API.
295 Attributes:
296 format_pack: A ModeFormatPack representing the storage format for the tensor
297 dimensions.
298 ordering: A ModeOrdering representing the tensor dimension ordering in the
299 storage.
301 format_pack: ModeFormatPack
302 ordering: Optional[ModeOrdering] = None
304 def __post_init__(self) -> None:
305 """Verifies and fixes up the values in format_pack and ordering.
307 Verifies and fixes up the values in format_pack and ordering to supports the
308 initializer syntax defined by the TACO API. If format_pack is a list of
309 ModeFormat, replaces it with ModeFormatPack constructed from the list. If
310 ordering is not provided, set ordering to the natural ordering for the rank
311 corresponding to format_pack.
313 Raises:
314 ValueError: If format_pack is not an instance of ModeFormatPack or if
315 ordering is not an instance of ModeOrdering.
317 if isinstance(self.format_pack, list):
318 if not _all_instance_of(self.format_pack, ModeFormat):
319 raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
320 self.format_pack = ModeFormatPack(self.format_pack)
321 if not isinstance(self.format_pack, ModeFormatPack):
322 raise ValueError(f"Expected ModeFormatpack: {self.format_pack}")
324 if self.ordering is None:
325 self.ordering = ModeOrdering(list(range(self.rank())))
326 if not isinstance(self.ordering, ModeOrdering):
327 raise ValueError(f"Expected ModeOrdering: {self.ordering}")
329 if self.format_pack.rank() != self.ordering.rank():
330 raise ValueError("Inconsistent ModeFormatPack and ModeOrdering: "
331 f"len({self.format_pack}) != "
332 f"len({self.ordering})")
334 def is_dense(self) -> bool:
335 """Returns true if all the Tensor dimensions have a dense format."""
336 return all([f == ModeFormat.DENSE for f in self.format_pack.formats])
338 def rank(self) -> int:
339 """Returns the number of dimensions represented by the format."""
340 return self.format_pack.rank()
342 def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
343 """Constructs the MLIR attributes for the tensor format."""
344 if self.is_dense():
345 return None
347 order = (
348 range(self.rank()) if
349 (self.ordering is None) else self.ordering.ordering)
350 mlir_storage_format = [f.value for f in self.format_pack.formats]
351 return sparse_tensor.EncodingAttr.get(mlir_storage_format,
352 ir.AffineMap.get_permutation(order),
353 _POINTER_BIT_WIDTH, _INDEX_BIT_WIDTH)
356 def _make_format(formats: List[ModeFormat],
357 ordering: Optional[List[int]] = None) -> Format:
358 """Constructs a format from a list of ModeFormat and an optional ordering.
360 Args:
361 formats: A list of ModeFormat, one for each dimension of a tensor.
362 ordering: An optional list of integer, for the ordering of the tensor
363 dimensions. When an ordering is not given, the identity ordering is used.
365 Returns:
366 A tensor format object.
368 Raises:
369 ValueError: If formats is not a list of ModeFormat or the length of formats
370 is not consistent with the len of ordering.
372 ordering = ordering or _identity_ordering(len(formats))
373 return Format(ModeFormatPack(formats), ModeOrdering(ordering))
376 class _AtomicCounter:
377 """An atomic counter."""
379 def __init__(self):
380 self._counter = 0
381 self._counter_lock = threading.Lock()
383 def increment(self) -> int:
384 """Increments the counter by one and returns the old value."""
385 old_value = self._counter
386 with self._counter_lock:
387 self._counter = self._counter + 1
388 return old_value
391 class IndexVar:
392 """The tensor index class.
394 We support the TACO API index_var class with an alias of this class.
396 An IndexVar object represents an index variable in tensor index notation.
398 Attributes:
399 name: A unique string name of the IndexVar.
401 _counter = _AtomicCounter()
403 def __init__(self):
404 id = self._counter.increment()
405 self._name = f"{_TACO_INDEX_PREFIX}{id}"
407 def __repr__(self) -> str:
408 return f"IndexVar(name={repr(self._name)})"
410 @property
411 def name(self) -> str:
412 """Returns the name of the IndexVar."""
413 return self._name
416 def get_index_vars(n: int) -> List[IndexVar]:
417 """Returns a list of n IndexVar.
419 This routine is defined by the TACO API.
421 Args:
422 n: An interger representing the number of IndexVar to get.
424 Returns:
425 A list of IndexVar.
427 Raises:
428 ValueError: if n is not a positive integer.
430 if not isinstance(n, int) or n <= 0:
431 raise ValueError(f"Expected an integer: {n}.")
432 # If lock contention ever becomes an issue, we could implement a bulk getter
433 # that returns a range by only claiming the lock once.
434 return [IndexVar() for i in range(n)]
437 def _mlir_symbols_from_index_vars(
438 index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.SymbolDef, ...]:
439 """Returns a tuple of MLIR symbols for the given tuple of index_var."""
440 return tuple(getattr(lang.S, i.name) for i in index_vars)
443 def _mlir_dimensions_from_index_vars(
444 index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.DimDef, ...]:
445 """Returns a tuple of MLIR dimensions for the given tuple of index_var."""
446 return tuple(getattr(lang.D, i.name) for i in index_vars)
449 def _mlir_tensor_type(
450 dtype: DType, shape: Tuple[int, ...],
451 attr: Optional[sparse_tensor.EncodingAttr]) -> ir.RankedTensorType:
452 """Returns an MLIR tensor type.
454 Args:
455 dtype: An DType object for the element data type of the tensor.
456 shape: A tuple of integer for the shape of the tensor.
457 attr: An optional MLIR sparse tensor attribute, only provided if the tensor
458 is a sparse tensor.
460 Returns:
461 An MLIR ranked tensor type.
463 ir_type = _mlir_type_from_taco_type(dtype)
464 return ir.RankedTensorType.get(shape, ir_type, attr)
467 def _verify_and_normalize_indices(indices) -> Tuple[IndexVar, ...]:
468 """Verifies and normalizes the indices for a tensor access.
470 Args:
471 indices: The index expression used to access a tensor, which could be any
472 Python object from user inputs.
474 Returns:
475 A tuple of IndexVar.
477 Raises:
478 ValueError: If indices is not an IndexVar or a tuple of IndexVar.
480 if isinstance(indices, IndexVar):
481 return (indices,)
482 elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
483 return indices
485 raise ValueError(f"Expected IndexVars: {indices}")
488 @dataclasses.dataclass(frozen=True)
489 class _StructOpInfo:
490 """Information for generating a structured op in the linalg dialect.
492 This information is associated with an expression node that serves as the
493 root for an expression subtree implemented with a structured op.
495 Attributes:
496 dst_indices: A tuple of IndexVar, representing the result dimensions of the
497 structured op. This is used to construct the temporary variable for the
498 tensor to hold the structured op result.
499 dst_dims: A tuple of int, representing the result shape of the structured
501 dst_dtype: A DType representing the data type of the structured op result.
502 dst_name: A string representing the name of the structured op result.
503 dst_format: A Format object representing the destination tensor format.
505 dst_indices: Tuple[IndexVar, ...]
506 dst_dims: Tuple[int, ...]
507 dst_dtype: DType
508 dst_name: str
509 dst_format: Format
511 def __post_init__(self) -> None:
512 """Verifies the integrity of the attribute values."""
513 assert len(self.dst_indices) == len(self.dst_dims)
514 assert self.dst_format is not None
516 def emit_tensor_init(self) -> ir.RankedTensorType:
517 """Returns an initialization for the destination tensor."""
518 if self.dst_format.is_dense():
519 # Initialize the dense tensor.
520 ir_type = _mlir_type_from_taco_type(self.dst_dtype)
521 tensor = linalg.InitTensorOp(self.dst_dims, ir_type).result
522 zero = arith.ConstantOp(ir_type, 0.0)
523 return linalg.FillOp(output=tensor, value=zero).results[0]
525 # Initialize the sparse tensor.
526 mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,
527 self.dst_format.mlir_tensor_attr())
528 index_type = ir.IndexType.get()
529 dims = [arith.ConstantOp(index_type, d).result for d in mlir_type.shape]
530 return sparse_tensor.InitOp(mlir_type, dims)
533 class _Stats:
534 """Information to describe how a tensor expression is implemented.
536 Currently, we only record the temporary tensors introduced for splitting the
537 original expression.
540 def __init__(self):
541 self._temps = []
543 def __repr__(self) -> str:
544 return f"_Stats({repr(self._temps)})"
546 def add_element(self, structop: _StructOpInfo):
547 """Adds a temporary tensor."""
548 self._temps.append(structop)
550 def get_total(self) -> int:
551 """Gets the total number of temporary tensors."""
552 return len(self._temps)
554 def _get_element(self, idx: int) -> _StructOpInfo:
555 """Gets the ith temporary tensor."""
556 assert idx < self.get_total()
557 return self._temps[idx]
559 def get_dimensions(self, idx: int) -> Tuple[int]:
560 """Gets the dimensions for the ith temporary tensor."""
561 return self._get_element(idx).dst_dims
563 def get_formats(self, idx: int) -> Tuple[ModeFormat]:
564 """Gets the ModeFormats for the ith temporary tensor."""
565 return tuple(self._get_element(idx).dst_format.format_pack.formats)
568 class Tensor:
569 """The tensor class.
571 We support the TACO API tensor class with an alias of this class.
573 This class is part of the TACO API with the following methods:
574 insert: Inserts a value to the given coordinate in the tensor.
575 to_array: Returns a numpy ndarray for the tensor.
577 TACO API also defines the following arrtibutes for the class:
578 dtype: A dtype object representing the data type of the tensor.
579 format: A format object representing the storage format of the tensor.
580 name: A string object representing the name of the tensor.
581 order: An integral rank of the tensor.
582 shape: A list of integers representing the shape of the tensor.
584 We currently ignore the tensor dimension ordering for dense tensor.
586 _counter = _AtomicCounter()
588 def _get_unique_name(self) -> str:
589 """Returns a unique name for creating a new Tensor."""
590 return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"
592 def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat],
593 Format]) -> None:
594 """Process the fmt argument for the Tensor constructor.
596 Args:
597 fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
598 this argument is a ModeFormat, uses this ModeFormat for all the tensor
599 dimensions. If this argument is a list of ModeFormat, the len of the
600 list should equal to the rank of the tensor. If this argument is a
601 format, uses it for the format of the tensor.
603 Raises:
604 ValueError: If fmt is not one of the expected type or is inconsistent
605 with the rank of the tensor. This is because fmt could be an users
606 input.
608 if isinstance(fmt, ModeFormat):
609 self._format = _make_format([fmt] * self.order)
610 elif isinstance(fmt, list):
611 if len(fmt) == self.order and isinstance(fmt[0], ModeFormat):
612 self._format = _make_format(fmt)
613 else:
614 raise ValueError("Inconsistent shape and format: "
615 f"{self._shape}, {fmt}.")
616 elif isinstance(fmt, Format):
617 if fmt.rank() != self.order:
618 raise ValueError("Inconsistent shape and format: "
619 f"{self._shape}, {fmt}.")
620 else:
621 self._format = fmt
622 else:
623 raise ValueError(f"Invalid format argument: {fmt}.")
625 def __init__(self,
626 value_or_shape: Optional[Union[List[int], Tuple[int, ...], float,
627 int]] = None,
628 fmt: Optional[Union[ModeFormat, List[ModeFormat],
629 Format]] = None,
630 dtype: Optional[DType] = None,
631 name: Optional[str] = None):
632 """The tensor constructor interface defined by TACO API.
634 Args:
635 value_or_shape: This argument is optional and can be int, float,
636 List[int], or Tuple[int, ...]. If this argument is an int or float,
637 creates a scalar tensor and initializes it with the value. If this
638 argument is a list or tuple of int, uses it as the shape to create a
639 tensor.
640 fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
641 this argument is a ModeFormat, uses this ModeFormat for all the tensor
642 dimensions. If this argument is a list of ModeFormat, the len of the
643 list should equal to the rank of the tensor. If this argument is a
644 format, uses it for the format of the tensor.
645 dtype: An object of dtype, representing the data type of the tensor.
646 name: A string name of the tensor. If a name is not given, creates a
647 unique name for the tensor.
649 Raises:
650 ValueError: If there is any inconsistency among the input arguments.
652 # Take care of the argument default values.
653 fmt = fmt or ModeFormat.COMPRESSED
654 dtype = dtype or DType(Type.FLOAT64)
655 self._name = name or self._get_unique_name()
657 self._dtype = dtype
658 # We currently use _coords and _values to host the sparse tensor value with
659 # COO format, and _dense_storage to host the dense tensor value. We haven't
660 # implement the conversion between the two storages yet. This will be
661 # improved in a follow up CL.
662 self._coords = []
663 self._values = []
664 self._dense_storage = None
665 self._stats = _Stats()
666 if value_or_shape is None or isinstance(value_or_shape, int) or isinstance(
667 value_or_shape, float):
668 # Create a scalar tensor and ignore the fmt parameter.
669 self._shape = []
670 self._format = _make_format([], [])
671 if value_or_shape is not None:
672 self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
673 elif (isinstance(value_or_shape, tuple) or isinstance(
674 value_or_shape, list)) and _all_instance_of(value_or_shape, int):
675 # Create a tensor with the specified shape and format.
676 self._shape = list(value_or_shape)
677 self._init_format(fmt)
678 else:
679 raise ValueError("Invalid first argument. "
680 "Must be a tuple or list for a shape or a single value"
681 f"if initializing a scalar tensor: {value_or_shape}.")
683 def __repr__(self) -> str:
684 value_str = (f"{repr(self._dense_storage)})" if self.is_dense() else
685 f"{repr(self._coords)} {repr(self._values)})")
686 return (f"Tensor(_name={repr(self._name)} "
687 f"_dtype={repr(self._dtype)} : ") + value_str
689 def insert(self, coords: List[int], val: Union[float, int]) -> None:
690 """Inserts a value to the given coordinate.
692 Args:
693 coords: A list of integer coordinates. The length of the list must be the
694 same as the rank of the tensor.
695 val: A value being inserted. It is either an integral or a floating point
696 value. This value will be converted to the data type of the tensor.
698 Raises:
699 ValueError: When there is any problem in the parameters.
701 if not isinstance(coords, list):
702 raise ValueError(f"Non list coordinate detected: {coords}.")
703 if not _all_instance_of(coords, int):
704 raise ValueError(f"Non integer coordinate detected: {coords}.")
705 if (len(coords) != self.order or
706 any([c < 0 or c >= self._shape[i] for i, c in enumerate(coords)])):
707 raise ValueError("Invalid coordinate for rank: "
708 f"{self.order}, {coords}.")
710 if not isinstance(val, int) and not isinstance(val, float):
711 raise ValueError(f"Value is neither int nor float: {val}.")
713 self._coords.append(tuple(coords))
714 self._values.append(self._dtype.value(val))
716 def is_dense(self) -> bool:
717 """Returns true if all the Tensor dimensions have a dense format."""
718 return self._format.is_dense()
720 def to_array(self) -> np.ndarray:
721 """Returns the numpy array for the Tensor.
723 This is currenly only implemented for dense Tensor.
725 if not self.is_dense():
726 raise ValueError("Conversion from non-dense Tensor "
727 "to numpy array not supported yet.")
728 return self._dense_storage
730 @staticmethod
731 def from_array(array: np.ndarray) -> "Tensor":
732 """Returns a dense tensor with the value copied from the input array.
734 We currently only support the conversion of float64 numpy arrays to Tensor.
736 Args:
737 array: The numpy array that provides the data type, shape and value for
738 the tensor.
740 Returns:
741 A Tensor object.
743 Raises:
744 ValueError if the data type of the numpy array is not float64.
746 if array.dtype != np.float64:
747 raise ValueError(f"Expected float64 value type: {array.dtype}.")
748 tensor = Tensor(array.shape, ModeFormat.DENSE)
749 tensor._dense_storage = np.copy(array)
750 return tensor
752 @staticmethod
753 def from_coo(
754 coordinates: List[Tuple[int, ...]],
755 values: List[_AnyRuntimeType],
756 fmt: Format,
757 dtype: DType,
758 ) -> "Tensor":
759 """Converts coordinates and values to a sparse tensor representation.
761 Args:
762 coordinates: A list of coordinates with non-zero values.
763 values: The non-zero values.
764 fmt: The tensor storage format.
765 dtype: The tensor element data type.
767 Returns:
768 A tensor with the given non-zero values and storage format. The shape of
769 the tensor has the minimum size for each dimension to make the given
770 coordinates valid.
772 assert (isinstance(coordinates, List) and
773 _all_instance_of(coordinates, Tuple))
774 assert (isinstance(values, List) and _all_instance_of(values, dtype.value))
775 assert isinstance(fmt, Format)
777 rank = fmt.rank()
778 assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)
780 # Find the maximum coordinate value for each dimension.
781 max_coordinate = list(map(max, zip(*coordinates)))
782 # The size of each dimension is one more that such a maximum coordinate
783 # value.
784 shape = [c + 1 for c in max_coordinate]
785 tensor = Tensor(shape, fmt)
786 tensor._coords = coordinates
787 tensor._values = values
789 return tensor
791 @property
792 def dtype(self) -> DType:
793 """Returns the data type for the Tensor."""
794 return self._dtype
796 @property
797 def format(self) -> Format:
798 """Returns the storage format for the Tensor."""
799 return self._format
801 @property
802 def name(self) -> str:
803 """Returns the name for the Tensor."""
804 return self._name
806 @property
807 def order(self) -> int:
808 """Returns the rank of the Tensor."""
809 return len(self._shape)
811 @property
812 def shape(self) -> List[int]:
813 """Returns the shape of the Tensor."""
814 return self._shape
816 def __getitem__(self, key) -> "Access":
817 """Verifies and processes a tensor access.
819 In the tensor index notation, a tensor access T[i, j] is represented as
820 retrieving a value with key (i, j) from the tensor object T in Python. This
821 routine verifies the key for the tensor access and returns a tensor access
822 object.
824 Args:
825 key: The key used to access the tensor, which could be any Python object
826 from user inputs.
828 Returns:
829 The corresponding tensor access object.
831 Raises:
832 ValueError: If key is not an IndexVar or a tuple of IndexVar.
834 indices = _verify_and_normalize_indices(key)
835 return Access(self, indices)
837 def __setitem__(self, key, value) -> None:
838 """Verifies and processes a tensor assignment.
840 In the tensor index notation, a tensor assignment "T[i, j] = ..." is
841 represented as setting a value for a tensor object T via key (i, j) in
842 Python. This routine verifies the key, evaluates the value, and assigns the
843 value to the tensor.
845 We only support assignment of dense tensor currently.
847 Args:
848 key: The key used to access the tensor, which could be any Python object
849 from user inputs.
850 value: The value assigned to the tensor, which could be any Python object
851 from user inputs.
853 Raises:
854 ValueError: If tensor is not a dense tensor, or the key is not an IndexVar
855 or a tuple of IndexVar, or the length of the indices is not the same as
856 the rank of the tensor.
858 indices = _verify_and_normalize_indices(key)
859 if len(indices) != self.order:
860 raise ValueError("Mismatch between indices and tensor rank: "
861 f"len({indices}) != {self.order}.")
863 result = value.evaluate(self, indices)
864 if self.is_dense():
865 assert isinstance(result, np.ndarray)
866 self._dense_storage = result
867 else:
868 assert _all_instance_of(result, np.ndarray) and len(result) == 2
869 assert (result[0].ndim, result[1].ndim) == (1, 2)
870 (self._values, self._coords) = result
872 def mlir_tensor_type(self) -> ir.RankedTensorType:
873 """Returns the MLIR type for the tensor."""
874 return _mlir_tensor_type(self._dtype, tuple(self._shape),
875 self._format.mlir_tensor_attr())
877 def dense_dst_ctype_pointer(self) -> ctypes.pointer:
878 """Returns the ctypes pointer for the pointer to an MemRefDescriptor.
880 For a dense tensor output, the MLIR compiler allocates the storage for
881 the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for
882 receiving the tensor.
884 assert self.is_dense()
885 mem_ref_desc = runtime.make_nd_memref_descriptor(
886 self.order, np.ctypeslib.as_ctypes_type(self.dtype.value))()
887 return ctypes.pointer(ctypes.pointer(mem_ref_desc))
889 def ctype_pointer(self) -> ctypes.pointer:
890 """Returns the ctypes pointer for the pointer to the input tensor."""
891 if self.is_dense():
892 if self._dense_storage is None:
893 self._dense_storage = np.zeros(self._shape, self._dtype.value)
894 return _ctype_pointer_from_array(self._dense_storage)
896 shape = np.array(self._shape, np.int64)
897 indices = np.array(self._coords, np.int64)
898 values = np.array(self._values, self._dtype.value)
899 ptr = utils.coo_tensor_to_sparse_tensor(_get_support_lib_name(), shape,
900 values, indices)
901 return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
903 def get_coordinates_and_values(
904 self) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
905 """Returns the coordinates and values for the non-zero elements."""
906 if not self.is_dense():
907 return (self._coords, self._values)
909 # Coordinates for non-zero elements, grouped by dimensions.
910 coords_by_dims = self._dense_storage.nonzero()
911 # Coordinates for non-zero elements, grouped by elements.
912 coords = np.transpose(coords_by_dims)
913 values = self._dense_storage[coords_by_dims]
914 return (coords, values)
916 def _record_stats(self, structop: "_StructOpInfo"):
917 """Collects information for temporary tensors."""
918 # Exclude user specified destination tensors.
919 if structop.dst_name == self.name:
920 return
922 self._stats.add_element(structop)
925 def _emit_operand(op_def: lang.LinalgOpDef, indices: Tuple[IndexVar, ...],
926 name: str, kind: lang.OperandKind) -> lang.OperandDef:
927 """Emits an operand for a tensor access in the current linalg operation.
929 Args:
930 op_def: A LinalgOpDef representing the current linalg dialect operation.
931 indices: A tuple of IndexVar used to access the tensor.
932 name: A unique string name of the tensor.
933 kind: An OperandKind for the operand.
935 Returns:
936 An OperandDef representing the operand.
938 dim_sym = _mlir_symbols_from_index_vars(indices)
939 opnd = lang.OperandDef(kind, lang.T, dim_sym)
940 op_def.add_operand(name, opnd)
941 return opnd
944 @dataclasses.dataclass(frozen=True)
945 class _DimInfo:
946 """Information for an operand dimension.
948 Attributes:
949 dim: An integer for the size of the dimension.
950 mode_format: A ModeFormat for the dimension sparsity.
952 dim: int
953 mode_format: ModeFormat
956 @dataclasses.dataclass()
957 class _ExprInfo:
958 """Expression information for validation and code generation.
960 Attributes:
961 src_indices: A tuple of IndexVar for the indices used by the tensors in the
962 expression tree.
963 dim_infos: A tuple of _DimInfo, representing the dimension information
964 corresponding to the src_indices.
965 reduce_indices: A set of IndexVar for the indices reduced by the expression.
966 acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
967 by the expression and its children.
968 structop_info: Information to support the code generation for a structured
969 op in the linalg dialect, if the corresponding expression node is the root
970 of a subtree for a structured op.
971 mlir_value: The MLIR value generated for the structured op.
973 src_indices: Tuple[IndexVar, ...]
974 dim_infos: Tuple[_DimInfo, ...]
975 reduce_indices: Optional[Set[IndexVar]] = None
976 acc_reduce_indices: Optional[Set[IndexVar]] = None
977 structop_info: Optional[_StructOpInfo] = None
978 mlir_value: Optional[ir.Value] = None
980 def __post_init__(self) -> None:
981 """Verifies and fix up attribute values.
983 Verifies the consistency of the attributes and modifies the default values
984 to support convenient initializer syntax.
986 assert len(self.src_indices) == len(self.dim_infos)
987 self.reduce_indices = self.reduce_indices or set()
988 self.acc_reduce_indices = self.acc_reduce_indices or set()
991 class IndexExpr(abc.ABC):
992 """The index notation base class.
994 We support the TACO API index_expression class with an alias of this class.
997 def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
998 """Verifies the RHS operand and returns a binary expression.
1000 Args:
1001 rhs: The RHS of the binary operation, which could be any Python object
1002 from user inputs.
1003 op: A _BinaryOp object representing the binary operator.
1005 Raises:
1006 ValueError: If rhs is not an IndexExpr.
1008 if not isinstance(rhs, IndexExpr):
1009 raise ValueError(f"Expected IndexExpr: {rhs}")
1010 return _BinaryExpr(op, self, rhs)
1012 def __add__(self, rhs) -> "_BinaryExpr":
1013 """Defines the operator +.
1015 Args:
1016 rhs: The value being added, which could be any Python object from user
1017 inputs.
1019 Returns:
1020 A _BinaryExpr object representing the operation.
1022 Raises:
1023 ValueError: If rhs is not an IndexExpr.
1025 return self._verify_operand_and_build_expr(rhs, operator.add)
1027 def __mul__(self, rhs) -> "_BinaryExpr":
1028 """Defines the operator *.
1030 Args:
1031 rhs: The value being multiplied, which could be any Python object from
1032 user inputs.
1034 Returns:
1035 A _BinaryExpr object representing the operation.
1037 Raises:
1038 ValueError: If rhs is not an IndexExpr.
1040 return self._verify_operand_and_build_expr(rhs, operator.mul)
1042 def __sub__(self, rhs) -> "_BinaryExpr":
1043 """Defines the operator -.
1045 Args:
1046 rhs: The value being subtracted, which could be any Python object from
1047 user inputs.
1049 Returns:
1050 A _BinaryExpr object representing the operation.
1052 Raises:
1053 ValueError: If rhs is not an IndexExpr.
1055 return self._verify_operand_and_build_expr(rhs, operator.sub)
1057 @abc.abstractmethod
1058 def _visit(self,
1059 func: _ExprVisitor,
1060 args,
1062 leaf_checker: _SubtreeLeafChecker = None) -> None:
1063 """A post-order visitor.
1065 Args:
1066 func: A callable applied to each node in the expression tree.
1067 args: The variable-length arguments passed to the callable. These
1068 arguments are grouped as an iterable and will be unpacked before passing
1069 to the callable. This is to enable the keyword argument only syntax
1070 after this argument.
1071 leaf_checker: A callable object to identify nodes that should be treated
1072 as leaf nodes to support partial tree visiting.
1074 pass
1076 @abc.abstractmethod
1077 def _emit_expression(
1078 self,
1079 expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
1080 expr_to_info: _ExprInfoDict,
1081 ) -> lang.ScalarExpression:
1082 """Emits MLIR for the expression tree.
1084 Args:
1085 expr_to_opnd: A dictionary for looking up structured op input operands for
1086 the input nodes of the structured op.
1087 expr_to_info: A dictionary for looking up code generation information for
1088 expressions.
1090 Returns:
1091 A linalg dialect ScalarExpression for the expression.
1093 pass
1095 @abc.abstractmethod
1096 def dtype(self) -> DType:
1097 """Returns the data type for the result of the expression."""
1098 pass
1100 def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
1101 """Emits a structured op in the linalg dialect for the expression tree.
1103 We define a DefineOpcallable in the domain specific language for the linalg
1104 dialect and execute the callable to generate the structured op. Self is the
1105 root of the expression tree for the structured op.
1107 Args:
1108 expr_to_info: A dictionary for looking up code generation information for
1109 expressions.
1111 op_info = expr_to_info[self].structop_info
1112 op_name = op_info.dst_name
1113 op_def = lang.LinalgOpDef(name=op_name)
1114 op_callable = lang.DefinedOpCallable(op_name, op_def)
1116 # Collect the input expression nodes for the structured op.
1117 expr_inputs = []
1118 self._visit(
1119 _gather_structured_op_input,
1120 (self, expr_to_info, expr_inputs),
1121 leaf_checker=_is_structured_op_leaf,
1124 # Create a linalg structured op operand for each input expression node and
1125 # build a dictionary for looking up the information.
1126 expr_to_input_opnd = {
1127 e: _emit_structured_op_input(e, expr_to_info, op_def)
1128 for e in expr_inputs
1131 # Emit the expression tree, which produces the value assigned to the
1132 # destination tensor.
1133 value = self._emit_expression(expr_to_input_opnd, expr_to_info)
1134 # Emit the structured op representation for the destination tensor.
1135 dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
1136 lang.OperandKind.OutputTensor)
1137 dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
1138 dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
1140 expr_info = expr_to_info[self]
1141 # If the structured op reduces some indices, explicitly represent the
1142 # reduction. This is done by generating a ReduceFn for the dimensions being
1143 # reduced in the linalg dialect and calling the function with the value
1144 # being reduced. We only support add reduction currently.
1145 if expr_info.reduce_indices:
1146 reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
1147 value = lang.ReduceFn.add[reduce_dims](value)
1149 # Emit the assignment as a comprehension in the linalg dialect.
1150 comp = lang.Comprehension((dst_use, value))
1151 op_def.comprehensions.append(comp)
1153 # The structured op in the linalg dialect requires an explicit
1154 # initialization for the destination tensor. Emit MLIR to initialize the
1155 # destination tensor.
1156 init = op_info.emit_tensor_init()
1158 # Collect MLIR values for the linalg input operands, with the assumption
1159 # that dictionary preserves the insertion order.
1160 args = [
1161 expr_to_info[expr].mlir_value
1162 for expr, opnd in expr_to_input_opnd.items()
1164 # Execute the DefineOpcallable object for the linalg dialect operation to
1165 # emit MLIR for the linalg structured op.
1166 expr_info.mlir_value = op_callable(*args, outs=[init])
1168 def _identify_structured_ops(
1169 self,
1170 expr_to_info: _ExprInfoDict,
1171 dst: Tensor,
1172 dst_indices: Tuple[IndexVar, ...],
1173 ) -> List["IndexExpr"]:
1174 """Returns expression nodes for the roots of the identified structured ops.
1176 A structured op in the linalg dialect only supports reduction performed on
1177 the whole expression. If the expression tree contains reduction that are
1178 performed on part of the expression tree, the expression tree needs to be
1179 implemented with multiple structured ops. This routine identifies all the
1180 expression nodes that contain reduction as the root of structured ops in the
1181 linalg dialect.
1183 Args:
1184 expr_to_info: A dictionary for looking up code generation information for
1185 expressions.
1186 dst: A destination Tensor that accepts the value of the expression tree.
1187 dst_indices: The indices used by the destination index expression.
1189 Returns:
1190 An ordered list of IndexExpr for the root expressions of the structured
1191 ops, where child expressions go before parent expressions that use their
1192 results.
1194 reduce_indices = tuple(
1195 set(expr_to_info[self].src_indices) - set(dst_indices))
1196 for reduce_index in reduce_indices:
1197 _mark_structured_op_root(self, reduce_index, expr_to_info)
1199 self._visit(_accumulate_reduce_indices, (expr_to_info,))
1200 structop_roots = []
1201 self._visit(_gather_structured_op, (expr_to_info, structop_roots))
1203 # Handle the root of the top level expression.
1204 if not structop_roots or structop_roots[-1] != self:
1205 # The top level expression is not a reduction. Add the top level
1206 # expression as a structured op root.
1207 structop_roots.append(self)
1209 # Use user specified information for the destination tensor to build an
1210 # _StructOpInfo for the top level expression.
1211 expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
1212 tuple(dst.shape),
1213 self.dtype(), dst.name,
1214 dst.format)
1216 return structop_roots
1218 def _validate_and_collect_expr_info(
1219 self,
1220 dst: Tensor,
1221 dst_indices: Tuple[IndexVar, ...],
1222 ) -> _ExprInfoDict:
1223 """Propagates expression information for validation.
1225 Propagates the indices used by child expression nodes to parent expression
1226 nodes. Also collects and validates the sizes for the dimensions
1227 corresponding to the indices.
1229 Args:
1230 dst: A destination Tensor that accepts the value of the expression tree.
1231 dst_indices: The indices used by the destination index expression.
1233 Raises:
1234 ValueError if there is any inconsistency in indices or dimensional
1235 values.
1237 Returns:
1238 A dictionary of (IndexExpr, _ExprInfo).
1240 expr_to_info = {}
1241 # Validate the expression tree and construct expression information.
1242 self._visit(_validate_and_collect_expr_info, (expr_to_info,))
1244 # Validate the destination dimension information.
1245 info = expr_to_info[self]
1246 index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
1247 for i, d, in zip(dst_indices, dst.shape):
1248 if i not in index_to_dim_info:
1249 raise ValueError("Destination IndexVar not used in the "
1250 f"source expression: {i}")
1251 else:
1252 if d != index_to_dim_info[i].dim:
1253 raise ValueError(f"Inconsistent destination dimension for {i}: "
1254 f"{d} vs {index_to_dim_info[i].dim}")
1256 return expr_to_info
1258 def _emit_assignment(
1259 self,
1260 module: ir.Module,
1261 dst: Tensor,
1262 dst_indices: Tuple[IndexVar, ...],
1263 expr_to_info: _ExprInfoDict,
1264 input_accesses: List["Access"],
1265 ) -> None:
1266 """Emits an MLIR function for assigning the expression to a tensor."""
1267 input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
1269 # Build the kernel for the operations.
1270 with ir.InsertionPoint(module.body):
1272 @builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
1273 def linalg_funcop(*args):
1274 # Set up the mapping from the Access nodes to their MLIR values.
1275 for e, mlir in zip(input_accesses, args):
1276 expr_to_info[e].mlir_value = mlir
1278 # Emit structured ops in the linalg dialect to implement the assignment.
1279 for structop_root in self._identify_structured_ops(
1280 expr_to_info, dst, dst_indices):
1281 structop_root._emit_structured_op(expr_to_info)
1282 dst._record_stats(expr_to_info[structop_root].structop_info)
1284 # The function returns the MLIR value of the root expression.
1285 return expr_to_info[self].mlir_value
1287 linalg_funcop.func_op.attributes[
1288 "llvm.emit_c_interface"] = ir.UnitAttr.get()
1290 def evaluate(
1291 self,
1292 dst: Tensor,
1293 dst_indices: Tuple[IndexVar, ...],
1294 ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
1295 """Evaluates tensor assignment dst[dst_indices] = expression.
1297 Args:
1298 dst: The destination tensor.
1299 dst_indices: The tuple of IndexVar used to access the destination tensor.
1301 Returns:
1302 The result of the dense tensor represented in numpy ndarray or the sparse
1303 tensor represented by two numpy ndarray for its non-zero values and
1304 indices.
1306 Raises:
1307 ValueError: If the expression is not proper or not supported.
1309 expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
1311 # Compute a list of input accesses.
1312 input_accesses = []
1313 self._visit(_gather_input_accesses_index_vars, (input_accesses,))
1315 support_lib = _get_support_lib_name()
1316 # Build and compile the module to produce the execution engine.
1317 with ir.Context(), ir.Location.unknown():
1318 module = ir.Module.create()
1319 self._emit_assignment(module, dst, dst_indices, expr_to_info,
1320 input_accesses)
1321 compiled_module = _compile_mlir(module)
1323 # We currently rely on an environment to pass in the full path of a
1324 # supporting library for the execution engine.
1325 engine = execution_engine.ExecutionEngine(
1326 compiled_module, opt_level=_OPT_LEVEL, shared_libs=[support_lib])
1328 # Gather the pointers for the input buffers.
1329 input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
1330 if dst.is_dense():
1331 # The pointer to receive dense output is the first argument to the
1332 # execution engine.
1333 arg_pointers = [dst.dense_dst_ctype_pointer()] + input_pointers
1334 else:
1335 # The pointer to receive sparse output is the last argument to the
1336 # execution engine. The pointer to receive a sparse tensor output is a
1337 # pointer to pointer of char.
1338 arg_pointers = input_pointers + [
1339 ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
1342 # Invoke the execution engine to run the module and return the result.
1343 engine.invoke(_ENTRY_NAME, *arg_pointers)
1345 if dst.is_dense():
1346 return runtime.ranked_memref_to_numpy(arg_pointers[0][0])
1348 # Check and return the sparse tensor output.
1349 rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
1350 support_lib,
1351 ctypes.cast(arg_pointers[-1][0], ctypes.c_void_p),
1352 np.float64,
1354 assert (np.equal(rank, dst.order)
1355 and np.array_equal(shape, np.array(dst.shape)) and
1356 np.equal(values.ndim, 1) and np.equal(values.shape[0], nse) and
1357 np.equal(indices.ndim, 2) and np.equal(indices.shape[0], nse) and
1358 np.equal(indices.shape[1], rank))
1359 return (values, indices)
1362 @dataclasses.dataclass(frozen=True)
1363 class Access(IndexExpr):
1364 """The tensor access class.
1366 We support the TACO API access class with an alias of this class.
1368 Attributes:
1369 tensor: A Tensor being accessed.
1370 indices: A tuple of IndexVar, representing the indices used to access the
1371 Tensor.
1373 tensor: Tensor
1374 indices: Tuple[IndexVar, ...]
1376 def __post_init__(self) -> None:
1377 """Verifies the tensor and indices for a tensor access.
1379 Raises:
1380 ValueError: If indices is not a list of IndexVar or the len of indices
1381 doesn't equal to the rank of the tensor.
1383 if (not isinstance(self.indices, tuple) or
1384 not _all_instance_of(self.indices, IndexVar)):
1385 raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
1386 if self.tensor.order != len(self.indices):
1387 raise ValueError("Invalid indices for rank: "
1388 f"str{self.tensor.order} != len({str(self.indices)}).")
1390 def _emit_expression(
1391 self,
1392 expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
1393 expr_to_info: _ExprInfoDict,
1394 ) -> lang.ScalarExpression:
1395 """Emits a linalg dialect TensorUse expression for the tensor access."""
1396 assert self in expr_to_opnd
1397 dims = _mlir_dimensions_from_index_vars(self.indices)
1398 return lang.TensorUse(expr_to_opnd[self], dims)
1400 def _visit(self,
1401 func: _ExprVisitor,
1402 args,
1404 leaf_checker: _SubtreeLeafChecker = None) -> None:
1405 if leaf_checker:
1406 assert leaf_checker(self, *args)
1407 func(self, *args)
1409 def dtype(self) -> DType:
1410 return self.tensor.dtype
1413 def _gather_input_accesses_index_vars(
1414 expr: IndexExpr,
1415 input_accesses: List[Access],
1416 ) -> None:
1417 """Collects Access nodes."""
1418 if isinstance(expr, Access) and expr not in input_accesses:
1419 input_accesses.append(expr)
1422 def _op_to_callable(op: _BinaryOp) -> lang.ArithFnType:
1423 """Returns the linalg dialect function object for the given operation."""
1424 op_to_callable = {
1425 operator.add: lang.ArithFn.add,
1426 operator.sub: lang.ArithFn.sub,
1427 operator.mul: lang.ArithFn.mul,
1429 return op_to_callable[op]
1432 @dataclasses.dataclass(frozen=True)
1433 class _BinaryExpr(IndexExpr):
1434 """The representation for a binary operation.
1436 Attributes:
1437 op: A _BinaryOp representing the binary operation.
1438 a: An IndexExpr representing the first operand of the operation.
1439 b: An IndexExpr representing the second operand of the operation.
1441 op: _BinaryOp
1442 a: IndexExpr
1443 b: IndexExpr
1445 def __post_init__(self) -> None:
1446 """Verifies that the operands being added are IndexExpr."""
1447 assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr)
1449 def _emit_expression(
1450 self,
1451 expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
1452 expr_to_info: _ExprInfoDict,
1453 ) -> lang.ScalarExpression:
1454 """Emits the expression tree and returns the expression."""
1455 # The current expression node is an internal node of the structured op.
1456 if self not in expr_to_opnd:
1457 a = self.a._emit_expression(expr_to_opnd, expr_to_info)
1458 b = self.b._emit_expression(expr_to_opnd, expr_to_info)
1459 return _op_to_callable(self.op)(a, b)
1461 # The current expression is a leaf node of the structured op. That is, it is
1462 # a temporary tensor generated by its child structured op.
1463 op_info = expr_to_info[self].structop_info
1464 assert op_info is not None
1465 dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
1466 return lang.TensorUse(expr_to_opnd[self], dims)
1468 def _visit(self,
1469 func: _ExprVisitor,
1470 args,
1472 leaf_checker: _SubtreeLeafChecker = None) -> None:
1473 """A post-order visitor."""
1474 if leaf_checker is None or not leaf_checker(self, *args):
1475 self.a._visit(func, args, leaf_checker=leaf_checker)
1476 self.b._visit(func, args, leaf_checker=leaf_checker)
1477 func(self, *args)
1479 def dtype(self) -> DType:
1480 """Returns the data type of the binary operation."""
1481 return self.a.dtype()
1484 def _validate_and_collect_dim_info(
1485 index_to_dim_info: Dict[IndexVar, _DimInfo],
1486 indices: Tuple[IndexVar, ...],
1487 dim_infos: Tuple[_DimInfo, ...],
1488 expr: _BinaryExpr,
1489 ) -> None:
1490 """Validates and collects the dimension information for an index notation.
1492 Validates (indices, dim_infos) against the information collected from other
1493 source operands and is represented by index_to_dim_info. In particular, we
1494 ensure that each IndexVar corresponds to only one dimension size. We also
1495 aggregate the new information represented in (indices, dim_infos) to
1496 index_to_dim_info.
1498 Args:
1499 index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the
1500 previous operands.
1501 indices: The IndexVars to be validated.
1502 dim_infos: The dimension information for the IndexVars to be validated.
1503 expr: The binary expression where (indices, dim_infos) is used.
1505 Raises:
1506 ValueError if there is any problem in the IndexVars or dimensional values.
1508 assert len(indices) == len(dim_infos)
1509 for i, d in zip(indices, dim_infos):
1510 if i not in index_to_dim_info:
1511 index_to_dim_info[i] = d
1512 else:
1513 if d.dim != index_to_dim_info[i].dim:
1514 raise ValueError(f"Inconsistent source dimension for {i}: "
1515 f"{d.dim} vs {index_to_dim_info[i].dim}")
1516 mode_format = _mode_format_estimator(expr.op)(
1517 index_to_dim_info[i].mode_format, d.mode_format)
1518 index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
1521 def _validate_and_collect_expr_info(
1522 expr: IndexExpr,
1523 expr_to_info: _ExprInfoDict,
1524 ) -> None:
1525 """Validates dimension information and constructs _ExprInfo.
1527 Validates that dimensional values for the same IndexVar are the same. Collects
1528 a list of IndexVar used by the expression and their corresponding dimensional
1529 values. Constructs an _ExprInfo object to record the information for the
1530 IndexExpr.
1532 This routine is passed to the post-order visitor as an _ExprVisitor object.
1534 Args:
1535 expr: The IndexExpr being validated.
1536 expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
1537 expression information.
1539 Raises:
1540 ValueError if there is any problem in the IndexVars or dimensional values.
1542 # Objects of class Access can be shared by different expressions. Avoid
1543 # processing Access objects multiple times by skipping the processing if expr
1544 # is already in the dictionary.
1545 if expr in expr_to_info:
1546 return
1548 if isinstance(expr, Access):
1549 src_indices = expr.indices
1550 src_dims = tuple(expr.tensor.shape)
1551 mode_formats = tuple(expr.tensor.format.format_pack.formats)
1552 assert len(src_dims) == len(mode_formats)
1553 dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
1554 else:
1555 assert isinstance(expr, _BinaryExpr)
1556 a_info = expr_to_info[expr.a]
1557 index_to_dim_info = {
1558 i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
1560 b_info = expr_to_info[expr.b]
1561 _validate_and_collect_dim_info(index_to_dim_info, b_info.src_indices,
1562 b_info.dim_infos, expr)
1563 # Here we rely on the fact that dictionaries keep the insertion order for
1564 # keys and values.
1565 src_indices = tuple(index_to_dim_info.keys())
1566 dim_infos = tuple(index_to_dim_info.values())
1568 expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)
1571 def _mark_structured_op_root(
1572 expr: IndexExpr,
1573 reduce_index: IndexVar,
1574 expr_to_info: _ExprInfoDict,
1575 ) -> None:
1576 """Identifies the root expression for a structured op in the linalg dialect.
1578 An linalg structured op can only perform reduction on the whole expression.
1579 For a TACO tensor algebra expression, the reduction on an IndexVar is done at
1580 the smallest expression that contains all the uses of the IndexVar. If such an
1581 expression is only part of the whole expression, we need to split this
1582 sub-expression tree out from its parent and implement the sub-expression as a
1583 structured op.
1585 This routine identifies the root expression node for performing a reduction on
1586 the given IndexVar. If the reduction of the given IndexVar should be performed
1587 on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices
1589 Args:
1590 expr: The root IndexExpr for the tensor algebra expression.
1591 reduce_index: The IndexVar which we want to find out the proper expression
1592 to perform a reduction.
1593 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
1595 assert (isinstance(expr, _BinaryExpr))
1596 a_info = expr_to_info[expr.a]
1597 b_info = expr_to_info[expr.b]
1598 expr_info = expr_to_info[expr]
1600 if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
1601 expr_info.reduce_indices.add(reduce_index)
1602 return
1604 if reduce_index in a_info.src_indices:
1605 _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
1606 elif reduce_index in b_info.src_indices:
1607 _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
1608 else:
1609 assert False, "Unreachable path"
1612 def _accumulate_reduce_indices(
1613 expr: IndexExpr,
1614 expr_to_info: _ExprInfoDict,
1615 ) -> None:
1616 """Propagates reduction indices from child expressions to parent expressions.
1618 This routine is passed to the post-order visitor as an _ExprVisitor object.
1620 Args:
1621 expr: The IndexExpr being visited.
1622 expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
1623 expression information.
1625 assert expr in expr_to_info
1626 expr_info = expr_to_info[expr]
1628 if isinstance(expr, _BinaryExpr):
1629 a_info = expr_to_info[expr.a]
1630 b_info = expr_to_info[expr.b]
1631 expr_info.acc_reduce_indices = (
1632 a_info.acc_reduce_indices | b_info.acc_reduce_indices
1633 | expr_info.reduce_indices)
1634 else:
1635 assert isinstance(expr, Access)
1638 def _gather_structured_op(
1639 expr: IndexExpr,
1640 expr_to_info: _ExprInfoDict,
1641 structop_roots: List[IndexExpr],
1642 ) -> None:
1643 """Adds structured op root expression information to structop_roots.
1645 This routine is passed to the post-order visitor as an _ExprVisitor object.
1647 Args:
1648 expr: The IndexExpr being visited.
1649 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
1650 structop_roots: The resulting list of IndexExpr that are the roots for
1651 linalg structured ops.
1653 if not expr_to_info[expr].reduce_indices:
1654 return
1656 # If the expression is the root for reducing some indices, collect the indices
1657 # and dimensions for the reduction result.
1658 dst_indices = []
1659 dst_dims = []
1660 mode_fmts = []
1661 for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos):
1662 if i not in expr_to_info[expr].acc_reduce_indices:
1663 dst_indices.append(i)
1664 dst_dims.append(d.dim)
1665 mode_fmts.append(d.mode_format)
1667 # Add the information to the dictionary.
1668 op_info = _StructOpInfo(
1669 tuple(dst_indices),
1670 tuple(dst_dims),
1671 expr.dtype(),
1672 f"temp{len(structop_roots)}",
1673 _make_format(mode_fmts),
1675 expr_to_info[expr].structop_info = op_info
1677 # Add the expression to the list of structured op roots.
1678 structop_roots.append(expr)
1681 def _is_structured_op_leaf(
1682 expr: IndexExpr,
1683 root: IndexExpr,
1684 expr_to_info: _ExprInfoDict,
1685 *unused_args,
1686 ) -> bool:
1687 """Returns true iff the expression is a leaf node for a structured op.
1689 The root of a structured op is a leaf of its parent structured op that uses
1690 its result. An expression node is a leaf node for the current structured op if
1691 it is an Access node or the root for a structured op that is not the current
1692 structured op.
1694 This routine is passed to the post-order visitor as a _SubtreeLeafChecker
1695 object. Because the post-order visitor pass the same parameters to both
1696 _SubtreeLeafChecker and _ExprVisitor, this routine may received unused
1697 parameters.
1699 Args:
1700 expr: The IndexExpr being visited.
1701 root: The root of the current structured op.
1702 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
1704 Returns:
1705 True if the current IndexExpr is a leaf for the current structured op.
1707 return (expr != root and
1708 expr_to_info[expr].structop_info is not None) or isinstance(
1709 expr, Access)
1712 def _gather_structured_op_input(
1713 expr: IndexExpr,
1714 root: IndexExpr,
1715 expr_to_info: _ExprInfoDict,
1716 structop_inputs: List[IndexExpr],
1717 ) -> None:
1718 """Adds the IndexExpr to structop_inputs if it is an input.
1720 If the current IndexExpr is an input for the current structured op, adds it to
1721 structop_inputs. The current IndexExpr is an input if it is an Access node or
1722 if it is the root for a structured op that is not the current structured op.
1724 This routine is passed to the post-order visitor as an _ExprVisitor object.
1726 Args:
1727 expr: The IndexExpr being visited.
1728 root: The root of the current structured op.
1729 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
1730 structop_inputs: The resulting list of IndexExpr that provide input to the
1731 current structured op.
1733 if (expr != root and expr not in structop_inputs) and (
1734 isinstance(expr, Access) or
1735 (expr in expr_to_info and expr_to_info[expr].structop_info)):
1736 structop_inputs.append(expr)
1739 def _emit_structured_op_input(
1740 expr: IndexExpr,
1741 expr_to_info: _ExprInfoDict,
1742 op_def: lang.LinalgOpDef,
1743 ) -> lang.OperandDef:
1744 """Emits OperandDef in the linalg dialect for the input IndexExpr.
1746 Args:
1747 expr: The input IndexExpr for the current structured op.
1748 expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
1749 op_def: The linalg operation for the current structured op.
1751 Returns:
1752 An OperandDef in the linalg dialect for the input IndexExpr.
1754 op_info = expr_to_info[expr].structop_info
1755 if op_info:
1756 # The input is a temporary tensor produced by another structured op.
1757 indices = op_info.dst_indices
1758 name = op_info.dst_name
1759 else:
1760 # The input is a user provided tensor.
1761 assert isinstance(expr, Access)
1762 indices = expr.indices
1763 name = expr.tensor.name
1765 dim_sym = _mlir_symbols_from_index_vars(indices)
1766 opnd = lang.OperandDef(lang.OperandKind.InputTensor, lang.T, dim_sym)
1767 op_def.add_operand(name, opnd)
1768 return opnd