[AArch64] Add fpext and fpround costs (#119292)
[llvm-project.git] / mlir / test / Dialect / SPIRV / IR / matrix-ops.mlir
blob372fcc6e514b97421c73560c6af6b1fc8326ce50
1 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
3 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
4   // CHECK-LABEL: @matrix_times_scalar_1
5   spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
6     // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
7     %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32
8     spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
9   }
11   // CHECK-LABEL: @matrix_times_scalar_2
12   spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA> "None" {
13     // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
14     %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
15     spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
16   }
18   // CHECK-LABEL: @matrix_transpose_1
19   spirv.func @matrix_transpose_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>) -> !spirv.matrix<2 x vector<3xf32>> "None" {
20     // CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>>
21     %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>>
22     spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>>
23   }
25   // CHECK-LABEL: @matrix_transpose_2
26   spirv.func @matrix_transpose_2(%arg0 : !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None" {
27     // CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
28     %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
29     spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
30   }
32   // CHECK-LABEL: @matrix_times_matrix_1
33   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
34     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
35     %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
36     spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
37   }
39   // CHECK-LABEL: @matrix_times_matrix_2
40   spirv.func @matrix_times_matrix_2(%arg0: !spirv.matrix<3 x vector<2xf32>>, %arg1: !spirv.matrix<2 x vector<3xf32>>) -> !spirv.matrix<2 x vector<2xf32>> "None"{
41     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<2xf32>>
42     %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<2xf32>>
43     spirv.ReturnValue %result : !spirv.matrix<2 x vector<2xf32>>
44   }
47 // -----
49 func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) {
50   // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
51   %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16
52   return
55 // -----
57 func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) {
58   // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
59   %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64
60   return
63 // -----
65 func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
66    // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
67    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<3 x vector<3xf32>>
68    return
71 // -----
73 func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
74    // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
75    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<2 x vector<4xf32>>
76    return
79 // -----
81 func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
82    // expected-error @+1 {{input and output matrices must have the same component type}}
83    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<4 x vector<3xf16>>
84    return
87 // -----
89 func.func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
90    // expected-error @+1 {{right and result matrices must have equal columns' count}}
91    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<3 x vector<2xf32>>
92    return
95 // -----
97 func.func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
98    // expected-error @+1 {{left and result matrices must have equal rows' count}}
99    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<3xf32>>
100    return
103 // -----
105 func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<2xf32>>){
106    // expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
107    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<2xf32>> -> !spirv.matrix<2 x vector<2xf32>>
108    return
111 // -----
113 func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
114    // expected-error @+1 {{right and result matrices' component type must be the same}}
115    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf64>>
116    return
120 // -----
122 func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
123    // expected-error @+1 {{left and result matrices' component type must be the same}}
124    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
125    return