1 /*===--------- amxtf32transposeintrin.h - AMX-TF32 and AMX-TRANSPOSE --------===
3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 * See https://llvm.org/LICENSE.txt for license information.
5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 *===------------------------------------------------------------------------===
11 "Never use <amxtf32tranposeintrin.h> directly; include <immintrin.h> instead."
12 #endif // __IMMINTRIN_H
14 #ifndef __AMX_TF32TRANSPOSEINTRIN_H
15 #define __AMX_TF32TRANSPOSEINTRIN_H
18 #define __DEFAULT_FN_ATTRS_TF32_TRANSPOSE \
19 __attribute__((__always_inline__, __nodebug__, \
20 __target__("amx-tf32,amx-transpose")))
/// \code
/// void _tile_tmmultf32ps(constexpr int srcdst, constexpr int a,
///                        constexpr int b);
/// \endcode
///
/// Multiplies the transpose of tile \a a by tile \a b and adds the product to
/// tile \a srcdst. Each fp32 input is used with its lower 13 bits zeroed.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction.
///
/// \param srcdst
///    The destination tile. Max size is 1024 Bytes.
/// \param a
///    The 1st source tile. Max size is 1024 Bytes.
/// \param b
///    The 2nd source tile. Max size is 1024 Bytes.
///
/// \code{.operation}
/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) {
///   dword[12:0] := 0
///   dword[31:13] := x[31:13]
///   return dword
/// }
///
/// DEFINE silence_snan_fp32(x[31:0]) {
///   IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0)
///     x.fraction[22] := 1
///   FI
///   return x
/// }
///
/// elements_dest := srcdst.colsb/4
///
/// FOR m := 0 TO (srcdst.rows-1)
///   tmp := 0
///   FOR k := 0 TO (a.rows-1)
///     FOR n := 0 TO (elements_dest-1)
///       a1e := silence_snan_fp32(a.row[k].fp32[m])
///       a2e := silence_snan_fp32(b.row[k].fp32[n])
///       s1e := zero_lower_mantissa_bits_fp32(a1e)
///       s2e := zero_lower_mantissa_bits_fp32(a2e)
///       tmp.fp32[n] += s1e * s2e
///     ENDFOR
///   ENDFOR
///
///   FOR n := 0 TO (elements_dest-1)
///     tmp.fp32[n] += srcdst.row[m].fp32[n]
///   ENDFOR
///   write_row_and_zero(srcdst, m, tmp, srcdst.colsb)
/// ENDFOR
///
/// zero_upper_rows(srcdst, srcdst.rows)
/// zero_tileconfig_start()
/// \endcode
#define _tile_tmmultf32ps(srcdst, a, b)                                        \
  __builtin_ia32_ttmmultf32ps((srcdst), (a), (b))
76 // dst = m x n (srcdest), src1 = k x m, src2 = k x n
77 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32_TRANSPOSE
78 _tile_tmmultf32ps_internal(unsigned short m
, unsigned short n
, unsigned short k
,
79 _tile1024i dst
, _tile1024i src1
, _tile1024i src2
) {
80 return __builtin_ia32_ttmmultf32ps_internal(m
, n
, k
, dst
, src1
, src2
);
83 /// Compute transpose and do Matrix Multiplication of src0 and src1, and then do
84 /// Matrix Plus with dst. All the calculation is base on float32 but with the
85 /// lower 13-bit set to 0.
87 /// \headerfile <immintrin.h>
89 /// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction.
92 /// The destination tile. Max size is 1024 Bytes.
94 /// The 1st source tile. Max size is 1024 Bytes.
96 /// The 2nd source tile. Max size is 1024 Bytes.
97 __DEFAULT_FN_ATTRS_TF32_TRANSPOSE
98 static void __tile_tmmultf32ps(__tile1024i
*dst
, __tile1024i src0
,
100 dst
->tile
= _tile_tmmultf32ps_internal(src0
.row
, src1
.col
, src0
.col
,
101 dst
->tile
, src0
.tile
, src1
.tile
);
105 #endif // __AMX_TF32TRANSPOSEINTRIN_H