[Support] Remove unused includes (NFC) (#116752)
[llvm-project.git] / flang / runtime / transformational.cpp
blobab303bdef9b1d145e95b5c7d70811028cf809110
1 //===-- runtime/transformational.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
// Implements the transformational intrinsic functions of Fortran 2018 that
// rearrange or duplicate data without (much) regard to type. These are
// CSHIFT, EOSHIFT, PACK, RESHAPE, SPREAD, TRANSPOSE, and UNPACK.
//
// Many of these are defined in the 2018 standard with text that makes sense
// only if argument arrays have lower bounds of one. Rather than interpret
// these cases as implying a hidden constraint, these implementations
// work with arbitrary lower bounds. This may technically be an extension
// of the standard, but it is more likely to conform with its intent.
19 #include "flang/Runtime/transformational.h"
20 #include "copy.h"
21 #include "terminator.h"
22 #include "tools.h"
23 #include "flang/Common/float128.h"
24 #include "flang/Runtime/descriptor.h"
26 namespace Fortran::runtime {
28 // Utility for CSHIFT & EOSHIFT rank > 1 cases that determines the shift count
29 // for each of the vector sections of the result.
30 class ShiftControl {
31 public:
32 RT_API_ATTRS ShiftControl(const Descriptor &s, Terminator &t, int dim)
33 : shift_{s}, terminator_{t}, shiftRank_{s.rank()}, dim_{dim} {}
34 RT_API_ATTRS void Init(const Descriptor &source, const char *which) {
35 int rank{source.rank()};
36 RUNTIME_CHECK(terminator_, shiftRank_ == 0 || shiftRank_ == rank - 1);
37 auto catAndKind{shift_.type().GetCategoryAndKind()};
38 RUNTIME_CHECK(
39 terminator_, catAndKind && catAndKind->first == TypeCategory::Integer);
40 shiftElemLen_ = catAndKind->second;
41 if (shiftRank_ > 0) {
42 int k{0};
43 for (int j{0}; j < rank; ++j) {
44 if (j + 1 != dim_) {
45 const Dimension &shiftDim{shift_.GetDimension(k)};
46 lb_[k++] = shiftDim.LowerBound();
47 if (shiftDim.Extent() != source.GetDimension(j).Extent()) {
48 terminator_.Crash("%s: on dimension %d, SHIFT= has extent %jd but "
49 "ARRAY= has extent %jd",
50 which, k, static_cast<std::intmax_t>(shiftDim.Extent()),
51 static_cast<std::intmax_t>(source.GetDimension(j).Extent()));
55 } else if (auto count{GetInt64Safe(
56 shift_.OffsetElement<char>(), shiftElemLen_, terminator_)}) {
57 shiftCount_ = *count;
58 } else {
59 terminator_.Crash("%s: SHIFT= value exceeds 64 bits", which);
62 RT_API_ATTRS SubscriptValue GetShift(const SubscriptValue resultAt[]) const {
63 if (shiftRank_ > 0) {
64 SubscriptValue shiftAt[maxRank];
65 int k{0};
66 for (int j{0}; j < shiftRank_ + 1; ++j) {
67 if (j + 1 != dim_) {
68 shiftAt[k] = lb_[k] + resultAt[j] - 1;
69 ++k;
72 auto count{GetInt64Safe(
73 shift_.Element<char>(shiftAt), shiftElemLen_, terminator_)};
74 RUNTIME_CHECK(terminator_, count.has_value());
75 return *count;
76 } else {
77 return shiftCount_; // invariant count extracted in Init()
81 private:
82 const Descriptor &shift_;
83 Terminator &terminator_;
84 int shiftRank_;
85 int dim_;
86 SubscriptValue lb_[maxRank];
87 std::size_t shiftElemLen_;
88 SubscriptValue shiftCount_{};
91 // Fill an EOSHIFT result with default boundary values
92 static RT_API_ATTRS void DefaultInitialize(
93 const Descriptor &result, Terminator &terminator) {
94 auto catAndKind{result.type().GetCategoryAndKind()};
95 RUNTIME_CHECK(
96 terminator, catAndKind && catAndKind->first != TypeCategory::Derived);
97 std::size_t elementLen{result.ElementBytes()};
98 std::size_t bytes{result.Elements() * elementLen};
99 if (catAndKind->first == TypeCategory::Character) {
100 switch (int kind{catAndKind->second}) {
101 case 1:
102 Fortran::runtime::fill_n(result.OffsetElement<char>(), bytes, ' ');
103 break;
104 case 2:
105 Fortran::runtime::fill_n(result.OffsetElement<char16_t>(), bytes / 2,
106 static_cast<char16_t>(' '));
107 break;
108 case 4:
109 Fortran::runtime::fill_n(result.OffsetElement<char32_t>(), bytes / 4,
110 static_cast<char32_t>(' '));
111 break;
112 default:
113 terminator.Crash(
114 "not yet implemented: CHARACTER(KIND=%d) in EOSHIFT intrinsic", kind);
116 } else {
117 std::memset(result.raw().base_addr, 0, bytes);
121 static inline RT_API_ATTRS std::size_t AllocateResult(Descriptor &result,
122 const Descriptor &source, int rank, const SubscriptValue extent[],
123 Terminator &terminator, const char *function) {
124 std::size_t elementLen{source.ElementBytes()};
125 const DescriptorAddendum *sourceAddendum{source.Addendum()};
126 result.Establish(source.type(), elementLen, nullptr, rank, extent,
127 CFI_attribute_allocatable, sourceAddendum != nullptr);
128 if (sourceAddendum) {
129 *result.Addendum() = *sourceAddendum;
131 for (int j{0}; j < rank; ++j) {
132 result.GetDimension(j).SetBounds(1, extent[j]);
134 if (int stat{result.Allocate()}) {
135 terminator.Crash(
136 "%s: Could not allocate memory for result (stat=%d)", function, stat);
138 return elementLen;
141 template <TypeCategory CAT, int KIND>
142 static inline RT_API_ATTRS std::size_t AllocateBesselResult(Descriptor &result,
143 int32_t n1, int32_t n2, Terminator &terminator, const char *function) {
144 int rank{1};
145 SubscriptValue extent[maxRank];
146 for (int j{0}; j < maxRank; j++) {
147 extent[j] = 0;
149 if (n1 <= n2) {
150 extent[0] = n2 - n1 + 1;
153 std::size_t elementLen{Descriptor::BytesFor(CAT, KIND)};
154 result.Establish(TypeCode{CAT, KIND}, elementLen, nullptr, rank, extent,
155 CFI_attribute_allocatable, false);
156 for (int j{0}; j < rank; ++j) {
157 result.GetDimension(j).SetBounds(1, extent[j]);
159 if (int stat{result.Allocate()}) {
160 terminator.Crash(
161 "%s: Could not allocate memory for result (stat=%d)", function, stat);
163 return elementLen;
166 template <TypeCategory CAT, int KIND>
167 static inline RT_API_ATTRS void DoBesselJn(Descriptor &result, int32_t n1,
168 int32_t n2, CppTypeFor<CAT, KIND> x, CppTypeFor<CAT, KIND> bn2,
169 CppTypeFor<CAT, KIND> bn2_1, const char *sourceFile, int line) {
170 Terminator terminator{sourceFile, line};
171 AllocateBesselResult<CAT, KIND>(result, n1, n2, terminator, "BESSEL_JN");
173 // The standard requires that n1 and n2 be non-negative. However, some other
174 // compilers generate results even when n1 and/or n2 are negative. For now,
175 // we also do not enforce the non-negativity constraint.
176 if (n2 < n1) {
177 return;
180 SubscriptValue at[maxRank];
181 for (int j{0}; j < maxRank; ++j) {
182 at[j] = 0;
185 // if n2 >= n1, there will be at least one element in the result.
186 at[0] = n2 - n1 + 1;
187 *result.Element<CppTypeFor<CAT, KIND>>(at) = bn2;
189 if (n2 == n1) {
190 return;
193 at[0] = n2 - n1;
194 *result.Element<CppTypeFor<CAT, KIND>>(at) = bn2_1;
196 // Bessel functions of the first kind are stable for a backward recursion
197 // (see https://dlmf.nist.gov/10.74.iv and https://dlmf.nist.gov/10.6.E1).
199 // J(n-1, x) = (2.0 / x) * n * J(n, x) - J(n+1, x)
201 // which is equivalent to
203 // J(n, x) = (2.0 / x) * (n + 1) * J(n+1, x) - J(n+2, x)
205 CppTypeFor<CAT, KIND> bn_2 = bn2;
206 CppTypeFor<CAT, KIND> bn_1 = bn2_1;
207 CppTypeFor<CAT, KIND> twoOverX = 2.0 / x;
208 for (int n{n2 - 2}; n >= n1; --n) {
209 auto bn = twoOverX * (n + 1) * bn_1 - bn_2;
211 at[0] = n - n1 + 1;
212 *result.Element<CppTypeFor<CAT, KIND>>(at) = bn;
214 bn_2 = bn_1;
215 bn_1 = bn;
219 template <TypeCategory CAT, int KIND>
220 static inline RT_API_ATTRS void DoBesselJnX0(Descriptor &result, int32_t n1,
221 int32_t n2, const char *sourceFile, int line) {
222 Terminator terminator{sourceFile, line};
223 AllocateBesselResult<CAT, KIND>(result, n1, n2, terminator, "BESSEL_JN");
225 // The standard requires that n1 and n2 be non-negative. However, some other
226 // compilers generate results even when n1 and/or n2 are negative. For now,
227 // we also do not enforce the non-negativity constraint.
228 if (n2 < n1) {
229 return;
232 SubscriptValue at[maxRank];
233 for (int j{0}; j < maxRank; ++j) {
234 at[j] = 0;
237 // J(0, 0.0) = 1.0, when n == 0.
238 // J(n, 0.0) = 0.0, when n > 0.
239 at[0] = 1;
240 *result.Element<CppTypeFor<CAT, KIND>>(at) = (n1 == 0) ? 1.0 : 0.0;
241 for (int j{2}; j <= n2 - n1 + 1; ++j) {
242 at[0] = j;
243 *result.Element<CppTypeFor<CAT, KIND>>(at) = 0.0;
247 template <TypeCategory CAT, int KIND>
248 static inline RT_API_ATTRS void DoBesselYn(Descriptor &result, int32_t n1,
249 int32_t n2, CppTypeFor<CAT, KIND> x, CppTypeFor<CAT, KIND> bn1,
250 CppTypeFor<CAT, KIND> bn1_1, const char *sourceFile, int line) {
251 Terminator terminator{sourceFile, line};
252 AllocateBesselResult<CAT, KIND>(result, n1, n2, terminator, "BESSEL_YN");
254 // The standard requires that n1 and n2 be non-negative. However, some other
255 // compilers generate results even when n1 and/or n2 are negative. For now,
256 // we also do not enforce the non-negativity constraint.
257 if (n2 < n1) {
258 return;
261 SubscriptValue at[maxRank];
262 for (int j{0}; j < maxRank; ++j) {
263 at[j] = 0;
266 // if n2 >= n1, there will be at least one element in the result.
267 at[0] = 1;
268 *result.Element<CppTypeFor<CAT, KIND>>(at) = bn1;
270 if (n2 == n1) {
271 return;
274 at[0] = 2;
275 *result.Element<CppTypeFor<CAT, KIND>>(at) = bn1_1;
277 // Bessel functions of the second kind are stable for a forward recursion
278 // (see https://dlmf.nist.gov/10.74.iv and https://dlmf.nist.gov/10.6.E1).
280 // Y(n+1, x) = (2.0 / x) * n * Y(n, x) - Y(n-1, x)
282 // which is equivalent to
284 // Y(n, x) = (2.0 / x) * (n - 1) * Y(n-1, x) - Y(n-2, x)
286 CppTypeFor<CAT, KIND> bn_2 = bn1;
287 CppTypeFor<CAT, KIND> bn_1 = bn1_1;
288 CppTypeFor<CAT, KIND> twoOverX = 2.0 / x;
289 for (int n{n1 + 2}; n <= n2; ++n) {
290 auto bn = twoOverX * (n - 1) * bn_1 - bn_2;
292 at[0] = n - n1 + 1;
293 *result.Element<CppTypeFor<CAT, KIND>>(at) = bn;
295 bn_2 = bn_1;
296 bn_1 = bn;
300 template <TypeCategory CAT, int KIND>
301 static inline RT_API_ATTRS void DoBesselYnX0(Descriptor &result, int32_t n1,
302 int32_t n2, const char *sourceFile, int line) {
303 Terminator terminator{sourceFile, line};
304 AllocateBesselResult<CAT, KIND>(result, n1, n2, terminator, "BESSEL_YN");
306 // The standard requires that n1 and n2 be non-negative. However, some other
307 // compilers generate results even when n1 and/or n2 are negative. For now,
308 // we also do not enforce the non-negativity constraint.
309 if (n2 < n1) {
310 return;
313 SubscriptValue at[maxRank];
314 for (int j{0}; j < maxRank; ++j) {
315 at[j] = 0;
318 // Y(n, 0.0) = -Inf, when n >= 0
319 for (int j{1}; j <= n2 - n1 + 1; ++j) {
320 at[0] = j;
321 *result.Element<CppTypeFor<CAT, KIND>>(at) =
322 -std::numeric_limits<CppTypeFor<CAT, KIND>>::infinity();
326 extern "C" {
327 RT_EXT_API_GROUP_BEGIN
329 // BESSEL_JN
330 // TODO: REAL(2 & 3)
331 void RTDEF(BesselJn_4)(Descriptor &result, int32_t n1, int32_t n2,
332 CppTypeFor<TypeCategory::Real, 4> x, CppTypeFor<TypeCategory::Real, 4> bn2,
333 CppTypeFor<TypeCategory::Real, 4> bn2_1, const char *sourceFile, int line) {
334 DoBesselJn<TypeCategory::Real, 4>(
335 result, n1, n2, x, bn2, bn2_1, sourceFile, line);
338 void RTDEF(BesselJn_8)(Descriptor &result, int32_t n1, int32_t n2,
339 CppTypeFor<TypeCategory::Real, 8> x, CppTypeFor<TypeCategory::Real, 8> bn2,
340 CppTypeFor<TypeCategory::Real, 8> bn2_1, const char *sourceFile, int line) {
341 DoBesselJn<TypeCategory::Real, 8>(
342 result, n1, n2, x, bn2, bn2_1, sourceFile, line);
#if HAS_FLOAT80
// REAL(10) entry point; forwards every argument to DoBesselJn.
void RTDEF(BesselJn_10)(Descriptor &result, int32_t n1, int32_t n2,
    CppTypeFor<TypeCategory::Real, 10> x,
    CppTypeFor<TypeCategory::Real, 10> bn2,
    CppTypeFor<TypeCategory::Real, 10> bn2_1, const char *sourceFile,
    int line) {
  constexpr int kind{10};
  DoBesselJn<TypeCategory::Real, kind>(
      result, n1, n2, x, bn2, bn2_1, sourceFile, line);
}
#endif
#if HAS_LDBL128 || HAS_FLOAT128
// REAL(16) entry point; forwards every argument to DoBesselJn.
void RTDEF(BesselJn_16)(Descriptor &result, int32_t n1, int32_t n2,
    CppTypeFor<TypeCategory::Real, 16> x,
    CppTypeFor<TypeCategory::Real, 16> bn2,
    CppTypeFor<TypeCategory::Real, 16> bn2_1, const char *sourceFile,
    int line) {
  constexpr int kind{16};
  DoBesselJn<TypeCategory::Real, kind>(
      result, n1, n2, x, bn2, bn2_1, sourceFile, line);
}
#endif
367 // TODO: REAL(2 & 3)
368 void RTDEF(BesselJnX0_4)(Descriptor &result, int32_t n1, int32_t n2,
369 const char *sourceFile, int line) {
370 DoBesselJnX0<TypeCategory::Real, 4>(result, n1, n2, sourceFile, line);
373 void RTDEF(BesselJnX0_8)(Descriptor &result, int32_t n1, int32_t n2,
374 const char *sourceFile, int line) {
375 DoBesselJnX0<TypeCategory::Real, 8>(result, n1, n2, sourceFile, line);
#if HAS_FLOAT80
// REAL(10) entry point for BESSEL_JN at x == 0; forwards to DoBesselJnX0.
void RTDEF(BesselJnX0_10)(Descriptor &result, int32_t n1, int32_t n2,
    const char *sourceFile, int line) {
  constexpr int kind{10};
  DoBesselJnX0<TypeCategory::Real, kind>(result, n1, n2, sourceFile, line);
}
#endif
#if HAS_LDBL128 || HAS_FLOAT128
// REAL(16) entry point for BESSEL_JN at x == 0; forwards to DoBesselJnX0.
void RTDEF(BesselJnX0_16)(Descriptor &result, int32_t n1, int32_t n2,
    const char *sourceFile, int line) {
  constexpr int kind{16};
  DoBesselJnX0<TypeCategory::Real, kind>(result, n1, n2, sourceFile, line);
}
#endif
392 // BESSEL_YN
393 // TODO: REAL(2 & 3)
394 void RTDEF(BesselYn_4)(Descriptor &result, int32_t n1, int32_t n2,
395 CppTypeFor<TypeCategory::Real, 4> x, CppTypeFor<TypeCategory::Real, 4> bn1,
396 CppTypeFor<TypeCategory::Real, 4> bn1_1, const char *sourceFile, int line) {
397 DoBesselYn<TypeCategory::Real, 4>(
398 result, n1, n2, x, bn1, bn1_1, sourceFile, line);
401 void RTDEF(BesselYn_8)(Descriptor &result, int32_t n1, int32_t n2,
402 CppTypeFor<TypeCategory::Real, 8> x, CppTypeFor<TypeCategory::Real, 8> bn1,
403 CppTypeFor<TypeCategory::Real, 8> bn1_1, const char *sourceFile, int line) {
404 DoBesselYn<TypeCategory::Real, 8>(
405 result, n1, n2, x, bn1, bn1_1, sourceFile, line);
#if HAS_FLOAT80
// REAL(10) entry point; forwards every argument to DoBesselYn.
void RTDEF(BesselYn_10)(Descriptor &result, int32_t n1, int32_t n2,
    CppTypeFor<TypeCategory::Real, 10> x,
    CppTypeFor<TypeCategory::Real, 10> bn1,
    CppTypeFor<TypeCategory::Real, 10> bn1_1, const char *sourceFile,
    int line) {
  constexpr int kind{10};
  DoBesselYn<TypeCategory::Real, kind>(
      result, n1, n2, x, bn1, bn1_1, sourceFile, line);
}
#endif
#if HAS_LDBL128 || HAS_FLOAT128
// REAL(16) entry point; forwards every argument to DoBesselYn.
void RTDEF(BesselYn_16)(Descriptor &result, int32_t n1, int32_t n2,
    CppTypeFor<TypeCategory::Real, 16> x,
    CppTypeFor<TypeCategory::Real, 16> bn1,
    CppTypeFor<TypeCategory::Real, 16> bn1_1, const char *sourceFile,
    int line) {
  constexpr int kind{16};
  DoBesselYn<TypeCategory::Real, kind>(
      result, n1, n2, x, bn1, bn1_1, sourceFile, line);
}
#endif
430 // TODO: REAL(2 & 3)
431 void RTDEF(BesselYnX0_4)(Descriptor &result, int32_t n1, int32_t n2,
432 const char *sourceFile, int line) {
433 DoBesselYnX0<TypeCategory::Real, 4>(result, n1, n2, sourceFile, line);
436 void RTDEF(BesselYnX0_8)(Descriptor &result, int32_t n1, int32_t n2,
437 const char *sourceFile, int line) {
438 DoBesselYnX0<TypeCategory::Real, 8>(result, n1, n2, sourceFile, line);
#if HAS_FLOAT80
// REAL(10) entry point for BESSEL_YN at x == 0; forwards to DoBesselYnX0.
void RTDEF(BesselYnX0_10)(Descriptor &result, int32_t n1, int32_t n2,
    const char *sourceFile, int line) {
  constexpr int kind{10};
  DoBesselYnX0<TypeCategory::Real, kind>(result, n1, n2, sourceFile, line);
}
#endif
#if HAS_LDBL128 || HAS_FLOAT128
// REAL(16) entry point for BESSEL_YN at x == 0; forwards to DoBesselYnX0.
void RTDEF(BesselYnX0_16)(Descriptor &result, int32_t n1, int32_t n2,
    const char *sourceFile, int line) {
  constexpr int kind{16};
  DoBesselYnX0<TypeCategory::Real, kind>(result, n1, n2, sourceFile, line);
}
#endif
455 // CSHIFT where rank of ARRAY argument > 1
456 void RTDEF(Cshift)(Descriptor &result, const Descriptor &source,
457 const Descriptor &shift, int dim, const char *sourceFile, int line) {
458 Terminator terminator{sourceFile, line};
459 int rank{source.rank()};
460 RUNTIME_CHECK(terminator, rank > 1);
461 if (dim < 1 || dim > rank) {
462 terminator.Crash(
463 "CSHIFT: DIM=%d must be >= 1 and <= ARRAY= rank %d", dim, rank);
465 ShiftControl shiftControl{shift, terminator, dim};
466 shiftControl.Init(source, "CSHIFT");
467 SubscriptValue extent[maxRank];
468 source.GetShape(extent);
469 AllocateResult(result, source, rank, extent, terminator, "CSHIFT");
470 SubscriptValue resultAt[maxRank];
471 for (int j{0}; j < rank; ++j) {
472 resultAt[j] = 1;
474 SubscriptValue sourceLB[maxRank];
475 source.GetLowerBounds(sourceLB);
476 SubscriptValue dimExtent{extent[dim - 1]};
477 SubscriptValue dimLB{sourceLB[dim - 1]};
478 SubscriptValue &resDim{resultAt[dim - 1]};
479 for (std::size_t n{result.Elements()}; n > 0; n -= dimExtent) {
480 SubscriptValue shiftCount{shiftControl.GetShift(resultAt)};
481 SubscriptValue sourceAt[maxRank];
482 for (int j{0}; j < rank; ++j) {
483 sourceAt[j] = sourceLB[j] + resultAt[j] - 1;
485 SubscriptValue &sourceDim{sourceAt[dim - 1]};
486 sourceDim = dimLB + shiftCount % dimExtent;
487 if (sourceDim < dimLB) {
488 sourceDim += dimExtent;
490 for (resDim = 1; resDim <= dimExtent; ++resDim) {
491 CopyElement(result, resultAt, source, sourceAt, terminator);
492 if (++sourceDim == dimLB + dimExtent) {
493 sourceDim = dimLB;
496 result.IncrementSubscripts(resultAt);
500 // CSHIFT where rank of ARRAY argument == 1
501 void RTDEF(CshiftVector)(Descriptor &result, const Descriptor &source,
502 std::int64_t shift, const char *sourceFile, int line) {
503 Terminator terminator{sourceFile, line};
504 RUNTIME_CHECK(terminator, source.rank() == 1);
505 const Dimension &sourceDim{source.GetDimension(0)};
506 SubscriptValue extent{sourceDim.Extent()};
507 AllocateResult(result, source, 1, &extent, terminator, "CSHIFT");
508 SubscriptValue lb{sourceDim.LowerBound()};
509 for (SubscriptValue j{0}; j < extent; ++j) {
510 SubscriptValue resultAt{1 + j};
511 SubscriptValue sourceAt{
512 lb + static_cast<SubscriptValue>(j + shift) % extent};
513 if (sourceAt < lb) {
514 sourceAt += extent;
516 CopyElement(result, &resultAt, source, &sourceAt, terminator);
520 // EOSHIFT of rank > 1
521 void RTDEF(Eoshift)(Descriptor &result, const Descriptor &source,
522 const Descriptor &shift, const Descriptor *boundary, int dim,
523 const char *sourceFile, int line) {
524 Terminator terminator{sourceFile, line};
525 SubscriptValue extent[maxRank];
526 int rank{source.GetShape(extent)};
527 RUNTIME_CHECK(terminator, rank > 1);
528 if (dim < 1 || dim > rank) {
529 terminator.Crash(
530 "EOSHIFT: DIM=%d must be >= 1 and <= ARRAY= rank %d", dim, rank);
532 std::size_t elementLen{
533 AllocateResult(result, source, rank, extent, terminator, "EOSHIFT")};
534 int boundaryRank{-1};
535 if (boundary) {
536 boundaryRank = boundary->rank();
537 RUNTIME_CHECK(terminator, boundaryRank == 0 || boundaryRank == rank - 1);
538 RUNTIME_CHECK(terminator, boundary->type() == source.type());
539 if (boundary->ElementBytes() != elementLen) {
540 terminator.Crash("EOSHIFT: BOUNDARY= has element byte length %zd, but "
541 "ARRAY= has length %zd",
542 boundary->ElementBytes(), elementLen);
544 if (boundaryRank > 0) {
545 int k{0};
546 for (int j{0}; j < rank; ++j) {
547 if (j != dim - 1) {
548 if (boundary->GetDimension(k).Extent() != extent[j]) {
549 terminator.Crash("EOSHIFT: BOUNDARY= has extent %jd on dimension "
550 "%d but must conform with extent %jd of ARRAY=",
551 static_cast<std::intmax_t>(boundary->GetDimension(k).Extent()),
552 k + 1, static_cast<std::intmax_t>(extent[j]));
554 ++k;
559 ShiftControl shiftControl{shift, terminator, dim};
560 shiftControl.Init(source, "EOSHIFT");
561 SubscriptValue resultAt[maxRank];
562 for (int j{0}; j < rank; ++j) {
563 resultAt[j] = 1;
565 if (!boundary) {
566 DefaultInitialize(result, terminator);
568 SubscriptValue sourceLB[maxRank];
569 source.GetLowerBounds(sourceLB);
570 SubscriptValue boundaryAt[maxRank];
571 if (boundaryRank > 0) {
572 boundary->GetLowerBounds(boundaryAt);
574 SubscriptValue dimExtent{extent[dim - 1]};
575 SubscriptValue dimLB{sourceLB[dim - 1]};
576 SubscriptValue &resDim{resultAt[dim - 1]};
577 for (std::size_t n{result.Elements()}; n > 0; n -= dimExtent) {
578 SubscriptValue shiftCount{shiftControl.GetShift(resultAt)};
579 SubscriptValue sourceAt[maxRank];
580 for (int j{0}; j < rank; ++j) {
581 sourceAt[j] = sourceLB[j] + resultAt[j] - 1;
583 SubscriptValue &sourceDim{sourceAt[dim - 1]};
584 sourceDim = dimLB + shiftCount;
585 for (resDim = 1; resDim <= dimExtent; ++resDim) {
586 if (sourceDim >= dimLB && sourceDim < dimLB + dimExtent) {
587 CopyElement(result, resultAt, source, sourceAt, terminator);
588 } else if (boundary) {
589 CopyElement(result, resultAt, *boundary, boundaryAt, terminator);
591 ++sourceDim;
593 result.IncrementSubscripts(resultAt);
594 if (boundaryRank > 0) {
595 boundary->IncrementSubscripts(boundaryAt);
600 // EOSHIFT of vector
601 void RTDEF(EoshiftVector)(Descriptor &result, const Descriptor &source,
602 std::int64_t shift, const Descriptor *boundary, const char *sourceFile,
603 int line) {
604 Terminator terminator{sourceFile, line};
605 RUNTIME_CHECK(terminator, source.rank() == 1);
606 SubscriptValue extent{source.GetDimension(0).Extent()};
607 std::size_t elementLen{
608 AllocateResult(result, source, 1, &extent, terminator, "EOSHIFT")};
609 if (boundary) {
610 RUNTIME_CHECK(terminator, boundary->rank() == 0);
611 RUNTIME_CHECK(terminator, boundary->type() == source.type());
612 if (boundary->ElementBytes() != elementLen) {
613 terminator.Crash("EOSHIFT: BOUNDARY= has element byte length %zd but "
614 "ARRAY= has length %zd",
615 boundary->ElementBytes(), elementLen);
618 if (!boundary) {
619 DefaultInitialize(result, terminator);
621 SubscriptValue lb{source.GetDimension(0).LowerBound()};
622 for (SubscriptValue j{1}; j <= extent; ++j) {
623 SubscriptValue sourceAt{lb + j - 1 + static_cast<SubscriptValue>(shift)};
624 if (sourceAt >= lb && sourceAt < lb + extent) {
625 CopyElement(result, &j, source, &sourceAt, terminator);
626 } else if (boundary) {
627 CopyElement(result, &j, *boundary, 0, terminator);
632 // PACK
633 void RTDEF(Pack)(Descriptor &result, const Descriptor &source,
634 const Descriptor &mask, const Descriptor *vector, const char *sourceFile,
635 int line) {
636 Terminator terminator{sourceFile, line};
637 CheckConformability(source, mask, terminator, "PACK", "ARRAY=", "MASK=");
638 auto maskType{mask.type().GetCategoryAndKind()};
639 RUNTIME_CHECK(
640 terminator, maskType && maskType->first == TypeCategory::Logical);
641 SubscriptValue trues{0};
642 if (mask.rank() == 0) {
643 if (IsLogicalElementTrue(mask, nullptr)) {
644 trues = source.Elements();
646 } else {
647 SubscriptValue maskAt[maxRank];
648 mask.GetLowerBounds(maskAt);
649 for (std::size_t n{mask.Elements()}; n > 0; --n) {
650 if (IsLogicalElementTrue(mask, maskAt)) {
651 ++trues;
653 mask.IncrementSubscripts(maskAt);
656 SubscriptValue extent{trues};
657 if (vector) {
658 RUNTIME_CHECK(terminator, vector->rank() == 1);
659 RUNTIME_CHECK(terminator, source.type() == vector->type());
660 if (source.ElementBytes() != vector->ElementBytes()) {
661 terminator.Crash("PACK: ARRAY= has element byte length %zd, but VECTOR= "
662 "has length %zd",
663 source.ElementBytes(), vector->ElementBytes());
665 extent = vector->GetDimension(0).Extent();
666 if (extent < trues) {
667 terminator.Crash("PACK: VECTOR= has extent %jd but there are %jd MASK= "
668 "elements that are .TRUE.",
669 static_cast<std::intmax_t>(extent),
670 static_cast<std::intmax_t>(trues));
673 AllocateResult(result, source, 1, &extent, terminator, "PACK");
674 SubscriptValue sourceAt[maxRank], resultAt{1};
675 source.GetLowerBounds(sourceAt);
676 if (mask.rank() == 0) {
677 if (IsLogicalElementTrue(mask, nullptr)) {
678 for (SubscriptValue n{trues}; n > 0; --n) {
679 CopyElement(result, &resultAt, source, sourceAt, terminator);
680 ++resultAt;
681 source.IncrementSubscripts(sourceAt);
684 } else {
685 SubscriptValue maskAt[maxRank];
686 mask.GetLowerBounds(maskAt);
687 for (std::size_t n{source.Elements()}; n > 0; --n) {
688 if (IsLogicalElementTrue(mask, maskAt)) {
689 CopyElement(result, &resultAt, source, sourceAt, terminator);
690 ++resultAt;
692 source.IncrementSubscripts(sourceAt);
693 mask.IncrementSubscripts(maskAt);
696 if (vector) {
697 SubscriptValue vectorAt{
698 vector->GetDimension(0).LowerBound() + resultAt - 1};
699 for (; resultAt <= extent; ++resultAt, ++vectorAt) {
700 CopyElement(result, &resultAt, *vector, &vectorAt, terminator);
705 // RESHAPE
706 // F2018 16.9.163
707 void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
708 const Descriptor &shape, const Descriptor *pad, const Descriptor *order,
709 const char *sourceFile, int line) {
710 // Compute and check the rank of the result.
711 Terminator terminator{sourceFile, line};
712 RUNTIME_CHECK(terminator, shape.rank() == 1);
713 RUNTIME_CHECK(terminator, shape.type().IsInteger());
714 SubscriptValue resultRank{shape.GetDimension(0).Extent()};
715 if (resultRank < 0 || resultRank > static_cast<SubscriptValue>(maxRank)) {
716 terminator.Crash(
717 "RESHAPE: SHAPE= vector length %jd implies a bad result rank",
718 static_cast<std::intmax_t>(resultRank));
721 // Extract and check the shape of the result; compute its element count.
722 SubscriptValue resultExtent[maxRank];
723 std::size_t shapeElementBytes{shape.ElementBytes()};
724 std::size_t resultElements{1};
725 SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()};
726 for (int j{0}; j < resultRank; ++j, ++shapeSubscript) {
727 auto extent{GetInt64Safe(
728 shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator)};
729 if (!extent) {
730 terminator.Crash("RESHAPE: value of SHAPE(%d) exceeds 64 bits", j + 1);
731 } else if (*extent < 0) {
732 terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1,
733 static_cast<std::intmax_t>(*extent));
735 resultExtent[j] = *extent;
736 resultElements *= resultExtent[j];
739 // Check that there are sufficient elements in the SOURCE=, or that
740 // the optional PAD= argument is present and nonempty.
741 std::size_t elementBytes{source.ElementBytes()};
742 std::size_t sourceElements{source.Elements()};
743 std::size_t padElements{pad ? pad->Elements() : 0};
744 if (resultElements > sourceElements) {
745 if (padElements <= 0) {
746 terminator.Crash(
747 "RESHAPE: not enough elements, need %zd but only have %zd",
748 resultElements, sourceElements);
750 if (pad->ElementBytes() != elementBytes) {
751 terminator.Crash("RESHAPE: PAD= has element byte length %zd but SOURCE= "
752 "has length %zd",
753 pad->ElementBytes(), elementBytes);
757 // Extract and check the optional ORDER= argument, which must be a
758 // permutation of [1..resultRank].
759 int dimOrder[maxRank];
760 if (order) {
761 RUNTIME_CHECK(terminator, order->rank() == 1);
762 RUNTIME_CHECK(terminator, order->type().IsInteger());
763 if (order->GetDimension(0).Extent() != resultRank) {
764 terminator.Crash("RESHAPE: the extent of ORDER (%jd) must match the rank"
765 " of the SHAPE (%d)",
766 static_cast<std::intmax_t>(order->GetDimension(0).Extent()),
767 resultRank);
769 std::uint64_t values{0};
770 SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()};
771 std::size_t orderElementBytes{order->ElementBytes()};
772 for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) {
773 auto k{GetInt64Safe(order->Element<char>(&orderSubscript),
774 orderElementBytes, terminator)};
775 if (!k) {
776 terminator.Crash("RESHAPE: ORDER element value exceeds 64 bits");
777 } else if (*k < 1 || *k > resultRank || ((values >> *k) & 1)) {
778 terminator.Crash("RESHAPE: bad value for ORDER element (%jd)",
779 static_cast<std::intmax_t>(*k));
781 values |= std::uint64_t{1} << *k;
782 dimOrder[j] = *k - 1;
784 } else {
785 for (int j{0}; j < resultRank; ++j) {
786 dimOrder[j] = j;
790 // Allocate result descriptor
791 AllocateResult(
792 result, source, resultRank, resultExtent, terminator, "RESHAPE");
794 // Populate the result's elements.
795 SubscriptValue resultSubscript[maxRank];
796 result.GetLowerBounds(resultSubscript);
797 SubscriptValue sourceSubscript[maxRank];
798 source.GetLowerBounds(sourceSubscript);
799 std::size_t resultElement{0};
800 std::size_t elementsFromSource{std::min(resultElements, sourceElements)};
801 for (; resultElement < elementsFromSource; ++resultElement) {
802 CopyElement(result, resultSubscript, source, sourceSubscript, terminator);
803 source.IncrementSubscripts(sourceSubscript);
804 result.IncrementSubscripts(resultSubscript, dimOrder);
806 if (resultElement < resultElements) {
807 // Remaining elements come from the optional PAD= argument.
808 SubscriptValue padSubscript[maxRank];
809 pad->GetLowerBounds(padSubscript);
810 for (; resultElement < resultElements; ++resultElement) {
811 CopyElement(result, resultSubscript, *pad, padSubscript, terminator);
812 pad->IncrementSubscripts(padSubscript);
813 result.IncrementSubscripts(resultSubscript, dimOrder);
818 // SPREAD
819 void RTDEF(Spread)(Descriptor &result, const Descriptor &source, int dim,
820 std::int64_t ncopies, const char *sourceFile, int line) {
821 Terminator terminator{sourceFile, line};
822 int rank{source.rank() + 1};
823 RUNTIME_CHECK(terminator, rank <= maxRank);
824 if (dim < 1 || dim > rank) {
825 terminator.Crash("SPREAD: DIM=%d argument for rank-%d source array "
826 "must be greater than 1 and less than or equal to %d",
827 dim, rank - 1, rank);
829 ncopies = std::max<std::int64_t>(ncopies, 0);
830 SubscriptValue extent[maxRank];
831 int k{0};
832 for (int j{0}; j < rank; ++j) {
833 extent[j] = j == dim - 1 ? ncopies : source.GetDimension(k++).Extent();
835 AllocateResult(result, source, rank, extent, terminator, "SPREAD");
836 SubscriptValue resultAt[maxRank];
837 for (int j{0}; j < rank; ++j) {
838 resultAt[j] = 1;
840 SubscriptValue &resultDim{resultAt[dim - 1]};
841 SubscriptValue sourceAt[maxRank];
842 source.GetLowerBounds(sourceAt);
843 for (std::size_t n{result.Elements()}; n > 0; n -= ncopies) {
844 for (resultDim = 1; resultDim <= ncopies; ++resultDim) {
845 CopyElement(result, resultAt, source, sourceAt, terminator);
847 result.IncrementSubscripts(resultAt);
848 source.IncrementSubscripts(sourceAt);
852 // TRANSPOSE
853 void RTDEF(Transpose)(Descriptor &result, const Descriptor &matrix,
854 const char *sourceFile, int line) {
855 Terminator terminator{sourceFile, line};
856 RUNTIME_CHECK(terminator, matrix.rank() == 2);
857 SubscriptValue extent[2]{
858 matrix.GetDimension(1).Extent(), matrix.GetDimension(0).Extent()};
859 AllocateResult(result, matrix, 2, extent, terminator, "TRANSPOSE");
860 SubscriptValue resultAt[2]{1, 1};
861 SubscriptValue matrixLB[2];
862 matrix.GetLowerBounds(matrixLB);
863 for (std::size_t n{result.Elements()}; n-- > 0;
864 result.IncrementSubscripts(resultAt)) {
865 SubscriptValue matrixAt[2]{
866 matrixLB[0] + resultAt[1] - 1, matrixLB[1] + resultAt[0] - 1};
867 CopyElement(result, resultAt, matrix, matrixAt, terminator);
871 // UNPACK
872 void RTDEF(Unpack)(Descriptor &result, const Descriptor &vector,
873 const Descriptor &mask, const Descriptor &field, const char *sourceFile,
874 int line) {
875 Terminator terminator{sourceFile, line};
876 RUNTIME_CHECK(terminator, vector.rank() == 1);
877 int rank{mask.rank()};
878 RUNTIME_CHECK(terminator, rank > 0);
879 SubscriptValue extent[maxRank];
880 mask.GetShape(extent);
881 CheckConformability(mask, field, terminator, "UNPACK", "MASK=", "FIELD=");
882 std::size_t elementLen{
883 AllocateResult(result, field, rank, extent, terminator, "UNPACK")};
884 RUNTIME_CHECK(terminator, vector.type() == field.type());
885 if (vector.ElementBytes() != elementLen) {
886 terminator.Crash(
887 "UNPACK: VECTOR= has element byte length %zd but FIELD= has length %zd",
888 vector.ElementBytes(), elementLen);
890 SubscriptValue resultAt[maxRank], maskAt[maxRank], fieldAt[maxRank],
891 vectorAt{vector.GetDimension(0).LowerBound()};
892 for (int j{0}; j < rank; ++j) {
893 resultAt[j] = 1;
895 mask.GetLowerBounds(maskAt);
896 field.GetLowerBounds(fieldAt);
897 SubscriptValue vectorElements{vector.GetDimension(0).Extent()};
898 SubscriptValue vectorLeft{vectorElements};
899 for (std::size_t n{result.Elements()}; n-- > 0;) {
900 if (IsLogicalElementTrue(mask, maskAt)) {
901 if (vectorLeft-- == 0) {
902 terminator.Crash(
903 "UNPACK: VECTOR= argument has fewer elements (%d) than "
904 "MASK= has .TRUE. entries",
905 vectorElements);
907 CopyElement(result, resultAt, vector, &vectorAt, terminator);
908 ++vectorAt;
909 } else {
910 CopyElement(result, resultAt, field, fieldAt, terminator);
912 result.IncrementSubscripts(resultAt);
913 mask.IncrementSubscripts(maskAt);
914 field.IncrementSubscripts(fieldAt);
918 RT_EXT_API_GROUP_END
919 } // extern "C"
920 } // namespace Fortran::runtime