[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / flang / test / HLFIR / matmul-lowering.fir
blobfd76db265951617bcacbe123e995de45ee2473b3
1 // Test hlfir.matmul operation lowering to fir runtime call
2 // RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s
4 func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lhs"}, %arg1: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "rhs"}, %arg2: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "res"}) {
5   %0:2 = hlfir.declare %arg0 {uniq_name = "_QFmatmul1Elhs"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
6   %1:2 = hlfir.declare %arg2 {uniq_name = "_QFmatmul1Eres"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
7   %2:2 = hlfir.declare %arg1 {uniq_name = "_QFmatmul1Erhs"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
8   %3 = hlfir.matmul %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<?x?xi32>
9   hlfir.assign %3 to %1#0 : !hlfir.expr<?x?xi32>, !fir.box<!fir.array<?x?xi32>>
10   hlfir.destroy %3 : !hlfir.expr<?x?xi32>
11   return
13 // CHECK-LABEL: func.func @_QPmatmul1(
14 // CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lhs"}
15 // CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "rhs"}
16 // CHECK:           %[[ARG2:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "res"}
17 // CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
18 // CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
19 // CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
20 // CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
22 // CHECK-DAG:     %[[RET_BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xi32>>>
23 // CHECK-DAG:     %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap<!fir.array<?x?xi32>>
24 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
25 // CHECK-DAG:     %[[RET_SHAPE:.*]] = fir.shape %[[C0]], %[[C0]] : (index, index) -> !fir.shape<2>
26 // CHECK-DAG:     %[[RET_EMBOX:.*]] = fir.embox %[[RET_ADDR]](%[[RET_SHAPE]])
27 // CHECK-DAG:     fir.store %[[RET_EMBOX]] to %[[RET_BOX]]
29 // CHECK:         %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
30 // CHECK-DAG:     %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
31 // CHECK-DAG:     %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
32 // CHECK:         %[[NONE:.*]] = fir.call @_FortranAMatmulInteger4Integer4(%[[RET_ARG]], %[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) fastmath<contract>
34 // CHECK:         %[[RET:.*]] = fir.load %[[RET_BOX]]
35 // CHECK-DAG:     %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]]
36 // CHECK-DAG:     %[[ADDR:.*]] = fir.box_addr %[[RET]]
37 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
38 // TODO: fix alias analysis in hlfir.assign bufferization
39 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
40 // TODO: add shape information from original intrinsic op
41 // CHECK:         %[[ASEXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?x?xi32>>, i1) -> !hlfir.expr<?x?xi32>
42 // CHECK:         hlfir.assign %[[ASEXPR]] to %[[RES_VAR]]#0
43 // CHECK:         hlfir.destroy %[[ASEXPR]]
44 // CHECK-NEXT:    return
45 // CHECK-NEXT:  }
47 // nested matmuls leading to recursive pattern application
48 func.func @_QPtest(%arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"}, %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"}, %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
49   %c3 = arith.constant 3 : index
50   %c3_0 = arith.constant 3 : index
51   %0 = fir.shape %c3, %c3_0 : (index, index) -> !fir.shape<2>
52   %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
53   %c3_1 = arith.constant 3 : index
54   %c3_2 = arith.constant 3 : index
55   %2 = fir.shape %c3_1, %c3_2 : (index, index) -> !fir.shape<2>
56   %3:2 = hlfir.declare %arg1(%2) {uniq_name = "_QFtestEb"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
57   %c3_3 = arith.constant 3 : index
58   %c3_4 = arith.constant 3 : index
59   %4 = fir.shape %c3_3, %c3_4 : (index, index) -> !fir.shape<2>
60   %5:2 = hlfir.declare %arg2(%4) {uniq_name = "_QFtestEc"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
61   %c3_5 = arith.constant 3 : index
62   %c3_6 = arith.constant 3 : index
63   %6 = fir.shape %c3_5, %c3_6 : (index, index) -> !fir.shape<2>
64   %7:2 = hlfir.declare %arg3(%6) {uniq_name = "_QFtestEout"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
65   %8 = hlfir.matmul %1#0 %3#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
66   %9 = hlfir.matmul %8 %5#0 {fastmath = #arith.fastmath<contract>} : (!hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
67   hlfir.assign %9 to %7#0 : !hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>
68   hlfir.destroy %9 : !hlfir.expr<3x3xf32>
69   hlfir.destroy %8 : !hlfir.expr<3x3xf32>
70   return
72 // just check that we apply the patterns successfully. The details are checked above
73 // CHECK-LABEL: func.func @_QPtest(
74 // CHECK:         fir.call @_FortranAMatmulReal4Real4({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
75 // CHECK:         fir.call @_FortranAMatmulReal4Real4({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
76 // CHECK:         return
77 // CHECK-NEXT:  }