[Clang] ensure mangled names are valid identifiers before being suggested in ifunc...
[llvm-project.git] / flang / lib / Evaluate / fold-matmul.h
blobc3d65a90409098b2e3a6b8897f4449df7eb0e1aa
1 //===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
10 #define FORTRAN_EVALUATE_FOLD_MATMUL_H_
12 #include "fold-implementation.h"
14 namespace Fortran::evaluate {
16 template <typename T>
17 static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
18 using Element = typename Constant<T>::Element;
19 auto args{funcRef.arguments()};
20 CHECK(args.size() == 2);
21 Folder<T> folder{context};
22 Constant<T> *ma{folder.Folding(args[0])};
23 Constant<T> *mb{folder.Folding(args[1])};
24 if (!ma || !mb) {
25 return Expr<T>{std::move(funcRef)};
27 CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 &&
28 mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2));
29 ConstantSubscript commonExtent{ma->shape().back()};
30 if (mb->shape().front() != commonExtent) {
31 context.messages().Say(
32 "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US,
33 commonExtent, mb->shape().front());
34 return MakeInvalidIntrinsic(std::move(funcRef));
36 ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]};
37 ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]};
38 std::vector<Element> elements;
39 elements.reserve(rows * columns);
40 bool overflow{false};
41 [[maybe_unused]] const auto &rounding{
42 context.targetCharacteristics().roundingMode()};
43 // result(j,k) = SUM(A(j,:) * B(:,k))
44 for (ConstantSubscript ci{0}; ci < columns; ++ci) {
45 for (ConstantSubscript ri{0}; ri < rows; ++ri) {
46 ConstantSubscripts aAt{ma->lbounds()};
47 if (ma->Rank() == 2) {
48 aAt[0] += ri;
50 ConstantSubscripts bAt{mb->lbounds()};
51 if (mb->Rank() == 2) {
52 bAt[1] += ci;
54 Element sum{};
55 [[maybe_unused]] Element correction{};
56 for (ConstantSubscript j{0}; j < commonExtent; ++j) {
57 Element aElt{ma->At(aAt)};
58 Element bElt{mb->At(bAt)};
59 if constexpr (T::category == TypeCategory::Real ||
60 T::category == TypeCategory::Complex) {
61 auto product{aElt.Multiply(bElt)};
62 overflow |= product.flags.test(RealFlag::Overflow);
63 if constexpr (useKahanSummation) {
64 auto next{product.value.Subtract(correction, rounding)};
65 overflow |= next.flags.test(RealFlag::Overflow);
66 auto added{sum.Add(next.value, rounding)};
67 overflow |= added.flags.test(RealFlag::Overflow);
68 correction = added.value.Subtract(sum, rounding)
69 .value.Subtract(next.value, rounding)
70 .value;
71 sum = std::move(added.value);
72 } else {
73 auto added{sum.Add(product.value)};
74 overflow |= added.flags.test(RealFlag::Overflow);
75 sum = std::move(added.value);
77 } else if constexpr (T::category == TypeCategory::Integer) {
78 // Don't use Kahan summation in numeric MATMUL folding;
79 // the runtime doesn't use it, and results should match.
80 auto product{aElt.MultiplySigned(bElt)};
81 overflow |= product.SignedMultiplicationOverflowed();
82 auto added{sum.AddSigned(product.lower)};
83 overflow |= added.overflow;
84 sum = std::move(added.value);
85 } else {
86 static_assert(T::category == TypeCategory::Logical);
87 sum = sum.OR(aElt.AND(bElt));
89 ++aAt.back();
90 ++bAt.front();
92 elements.push_back(sum);
95 if (overflow &&
96 context.languageFeatures().ShouldWarn(
97 common::UsageWarning::FoldingException)) {
98 context.messages().Say(common::UsageWarning::FoldingException,
99 "MATMUL of %s data overflowed during computation"_warn_en_US,
100 T::AsFortran());
102 ConstantSubscripts shape;
103 if (ma->Rank() == 2) {
104 shape.push_back(rows);
106 if (mb->Rank() == 2) {
107 shape.push_back(columns);
109 return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}};
111 } // namespace Fortran::evaluate
112 #endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_