1 //===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
10 #define FORTRAN_EVALUATE_FOLD_MATMUL_H_
12 #include "fold-implementation.h"
14 namespace Fortran::evaluate
{
17 static Expr
<T
> FoldMatmul(FoldingContext
&context
, FunctionRef
<T
> &&funcRef
) {
18 using Element
= typename Constant
<T
>::Element
;
19 auto args
{funcRef
.arguments()};
20 CHECK(args
.size() == 2);
21 Folder
<T
> folder
{context
};
22 Constant
<T
> *ma
{folder
.Folding(args
[0])};
23 Constant
<T
> *mb
{folder
.Folding(args
[1])};
25 return Expr
<T
>{std::move(funcRef
)};
27 CHECK(ma
->Rank() >= 1 && ma
->Rank() <= 2 && mb
->Rank() >= 1 &&
28 mb
->Rank() <= 2 && (ma
->Rank() == 2 || mb
->Rank() == 2));
29 ConstantSubscript commonExtent
{ma
->shape().back()};
30 if (mb
->shape().front() != commonExtent
) {
31 context
.messages().Say(
32 "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US
,
33 commonExtent
, mb
->shape().front());
34 return MakeInvalidIntrinsic(std::move(funcRef
));
36 ConstantSubscript rows
{ma
->Rank() == 1 ? 1 : ma
->shape()[0]};
37 ConstantSubscript columns
{mb
->Rank() == 1 ? 1 : mb
->shape()[1]};
38 std::vector
<Element
> elements
;
39 elements
.reserve(rows
* columns
);
41 [[maybe_unused
]] const auto &rounding
{
42 context
.targetCharacteristics().roundingMode()};
43 // result(j,k) = SUM(A(j,:) * B(:,k))
44 for (ConstantSubscript ci
{0}; ci
< columns
; ++ci
) {
45 for (ConstantSubscript ri
{0}; ri
< rows
; ++ri
) {
46 ConstantSubscripts aAt
{ma
->lbounds()};
47 if (ma
->Rank() == 2) {
50 ConstantSubscripts bAt
{mb
->lbounds()};
51 if (mb
->Rank() == 2) {
55 [[maybe_unused
]] Element correction
{};
56 for (ConstantSubscript j
{0}; j
< commonExtent
; ++j
) {
57 Element aElt
{ma
->At(aAt
)};
58 Element bElt
{mb
->At(bAt
)};
59 if constexpr (T::category
== TypeCategory::Real
||
60 T::category
== TypeCategory::Complex
) {
62 auto product
{aElt
.Multiply(bElt
, rounding
)};
63 overflow
|= product
.flags
.test(RealFlag::Overflow
);
64 auto next
{correction
.Add(product
.value
, rounding
)};
65 overflow
|= next
.flags
.test(RealFlag::Overflow
);
66 auto added
{sum
.Add(next
.value
, rounding
)};
67 overflow
|= added
.flags
.test(RealFlag::Overflow
);
68 correction
= added
.value
.Subtract(sum
, rounding
)
69 .value
.Subtract(next
.value
, rounding
)
71 sum
= std::move(added
.value
);
72 } else if constexpr (T::category
== TypeCategory::Integer
) {
73 auto product
{aElt
.MultiplySigned(bElt
)};
74 overflow
|= product
.SignedMultiplicationOverflowed();
75 auto added
{sum
.AddSigned(product
.lower
)};
76 overflow
|= added
.overflow
;
77 sum
= std::move(added
.value
);
79 static_assert(T::category
== TypeCategory::Logical
);
80 sum
= sum
.OR(aElt
.AND(bElt
));
85 elements
.push_back(sum
);
89 context
.messages().Say(
90 "MATMUL of %s data overflowed during computation"_warn_en_US
,
93 ConstantSubscripts shape
;
94 if (ma
->Rank() == 2) {
95 shape
.push_back(rows
);
97 if (mb
->Rank() == 2) {
98 shape
.push_back(columns
);
100 return Expr
<T
>{Constant
<T
>{std::move(elements
), std::move(shape
)}};
102 } // namespace Fortran::evaluate
103 #endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_