1 //===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
10 #define FORTRAN_EVALUATE_FOLD_MATMUL_H_
12 #include "fold-implementation.h"
14 namespace Fortran::evaluate
{
17 static Expr
<T
> FoldMatmul(FoldingContext
&context
, FunctionRef
<T
> &&funcRef
) {
18 using Element
= typename Constant
<T
>::Element
;
19 auto args
{funcRef
.arguments()};
20 CHECK(args
.size() == 2);
21 Folder
<T
> folder
{context
};
22 Constant
<T
> *ma
{folder
.Folding(args
[0])};
23 Constant
<T
> *mb
{folder
.Folding(args
[1])};
25 return Expr
<T
>{std::move(funcRef
)};
27 CHECK(ma
->Rank() >= 1 && ma
->Rank() <= 2 && mb
->Rank() >= 1 &&
28 mb
->Rank() <= 2 && (ma
->Rank() == 2 || mb
->Rank() == 2));
29 ConstantSubscript commonExtent
{ma
->shape().back()};
30 if (mb
->shape().front() != commonExtent
) {
31 context
.messages().Say(
32 "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US
,
33 commonExtent
, mb
->shape().front());
34 return MakeInvalidIntrinsic(std::move(funcRef
));
36 ConstantSubscript rows
{ma
->Rank() == 1 ? 1 : ma
->shape()[0]};
37 ConstantSubscript columns
{mb
->Rank() == 1 ? 1 : mb
->shape()[1]};
38 std::vector
<Element
> elements
;
39 elements
.reserve(rows
* columns
);
41 [[maybe_unused
]] const auto &rounding
{
42 context
.targetCharacteristics().roundingMode()};
43 // result(j,k) = SUM(A(j,:) * B(:,k))
44 for (ConstantSubscript ci
{0}; ci
< columns
; ++ci
) {
45 for (ConstantSubscript ri
{0}; ri
< rows
; ++ri
) {
46 ConstantSubscripts aAt
{ma
->lbounds()};
47 if (ma
->Rank() == 2) {
50 ConstantSubscripts bAt
{mb
->lbounds()};
51 if (mb
->Rank() == 2) {
55 [[maybe_unused
]] Element correction
{};
56 for (ConstantSubscript j
{0}; j
< commonExtent
; ++j
) {
57 Element aElt
{ma
->At(aAt
)};
58 Element bElt
{mb
->At(bAt
)};
59 if constexpr (T::category
== TypeCategory::Real
||
60 T::category
== TypeCategory::Complex
) {
61 auto product
{aElt
.Multiply(bElt
)};
62 overflow
|= product
.flags
.test(RealFlag::Overflow
);
63 if constexpr (useKahanSummation
) {
64 auto next
{product
.value
.Subtract(correction
, rounding
)};
65 overflow
|= next
.flags
.test(RealFlag::Overflow
);
66 auto added
{sum
.Add(next
.value
, rounding
)};
67 overflow
|= added
.flags
.test(RealFlag::Overflow
);
68 correction
= added
.value
.Subtract(sum
, rounding
)
69 .value
.Subtract(next
.value
, rounding
)
71 sum
= std::move(added
.value
);
73 auto added
{sum
.Add(product
.value
)};
74 overflow
|= added
.flags
.test(RealFlag::Overflow
);
75 sum
= std::move(added
.value
);
77 } else if constexpr (T::category
== TypeCategory::Integer
) {
78 // Don't use Kahan summation in numeric MATMUL folding;
79 // the runtime doesn't use it, and results should match.
80 auto product
{aElt
.MultiplySigned(bElt
)};
81 overflow
|= product
.SignedMultiplicationOverflowed();
82 auto added
{sum
.AddSigned(product
.lower
)};
83 overflow
|= added
.overflow
;
84 sum
= std::move(added
.value
);
86 static_assert(T::category
== TypeCategory::Logical
);
87 sum
= sum
.OR(aElt
.AND(bElt
));
92 elements
.push_back(sum
);
96 context
.languageFeatures().ShouldWarn(
97 common::UsageWarning::FoldingException
)) {
98 context
.messages().Say(common::UsageWarning::FoldingException
,
99 "MATMUL of %s data overflowed during computation"_warn_en_US
,
102 ConstantSubscripts shape
;
103 if (ma
->Rank() == 2) {
104 shape
.push_back(rows
);
106 if (mb
->Rank() == 2) {
107 shape
.push_back(columns
);
109 return Expr
<T
>{Constant
<T
>{std::move(elements
), std::move(shape
)}};
111 } // namespace Fortran::evaluate
112 #endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_