[clang-tidy][doc] align the title style in clang-tidy/index.rst (#119938)
[llvm-project.git] / mlir / test / Integration / Dialect / PDL / CPU / multiroot.mlir
blobeb04ceb1ce45631c39560439baf4ea77449c3ddd
1 // RUN: mlir-opt %s  -allow-unregistered-dialect -test-pdl-bytecode-pass -split-input-file | FileCheck %s
3 // -----
5 //===----------------------------------------------------------------------===//
6 // 1-layer perceptron with split fwd/bwd operations
7 //===----------------------------------------------------------------------===//
9 module @patterns {
10   // fc_fwd
11   pdl.pattern : benefit(1) {
12     %in_type = pdl.type
13     %out_type = pdl.type
14     %weight_type = pdl.type
15     %rxact = pdl.operand : %in_type
16     %weight = pdl.operand : %weight_type
18     %attr0 = pdl.attribute = false
19     %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
21     pdl.rewrite %op0 {
22       %op1 = pdl.operation "kernel.FcFwd" (%rxact, %weight : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
23       %val1 = pdl.result 0 of %op1  // txact
24       pdl.replace %op0 with (%val1 : !pdl.value)  // tf.MatMul
25     }
26   }
28   // fc_bwd
29   pdl.pattern : benefit(4) {
30     %in_type = pdl.type
31     %out_type = pdl.type
32     %weight_type = pdl.type
33     %const_type = pdl.type
34     %rxact = pdl.operand : %in_type
35     %rxdelta = pdl.operand : %out_type
36     %weight = pdl.operand : %weight_type
38     %attr0 = pdl.attribute = true
39     %attr1 = pdl.attribute = false
40     %op0 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%weight_type : !pdl.type)
41     %val0 = pdl.result 0 of %op0
42     %op1 = pdl.operation "tf.Const" -> (%const_type : !pdl.type)
43     %val1 = pdl.result 0 of %op1
44     %op2 = pdl.operation "tf.Mul" (%val0, %val1 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
45     %val2 = pdl.result 0 of %op2
46     %op3 = pdl.operation "tf.Sub" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
48     pdl.rewrite %op3 {
49       %op4 = pdl.operation "kernel.FcBwd" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
50       %val4 = pdl.result 0 of %op4  // weight_out
51       pdl.replace %op3 with (%val4 : !pdl.value)  // tf.Sub
52       pdl.erase %op2  // tf.Mul
53       pdl.erase %op1  // tf.Const
54       pdl.erase %op0  // tf.MatMul
55     }
56   }
58   // softmax_cross_entropy
59   pdl.pattern : benefit(6) {
60     %in_type = pdl.type
61     %label_type = pdl.type
62     %loss_type = pdl.type
63     %mean_loss_type = pdl.type
64     %mean_const_type = pdl.type
65     %mul_const_type = pdl.type
66     %rxact = pdl.operand : %in_type
67     %rxlabel = pdl.operand : %label_type
69     %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
70     %val0_0 = pdl.result 0 of %op0  // loss
71     %val0_1 = pdl.result 1 of %op0  // gradient
72     %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
73     %val1 = pdl.result 0 of %op1
74     %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
75     %val2 = pdl.result 0 of %op2
76     %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
77     %val3 = pdl.result 0 of %op3
78     %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
79     %val4 = pdl.result 0 of %op4
80     %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)
82     pdl.rewrite {  // roots: %op2, %op5
83       %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
84       %val6_0 = pdl.result 0 of %op6  // txloss
85       %val6_1 = pdl.result 1 of %op6  // txdelta
86       pdl.replace %op5 with (%val6_1 : !pdl.value)  // tf.Mul
87       pdl.erase %op4  // tf.Const
88       pdl.erase %op3  // tf.PreventGradient
89       pdl.replace %op2 with (%val6_0 : !pdl.value)  // tf.Mean
90       pdl.erase %op1  // tf.Const
91       pdl.erase %op0  // tf.SparseSoftmaxCrossEntropyWithLogits
92     }
93   }
96 // CHECK-LABEL: test.mlp_split
97 // CHECK: %[[FWD:.*]] = "kernel.FcFwd"(%arg0, %arg2) : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
98 // CHECK: %[[SM:.*]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FWD]], %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
99 // CHECK: %[[BWD:.*]] = "kernel.FcBwd"(%arg0, %[[SM]]#1, %arg2) : (tensor<2x20xf32>, tensor<2x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
100 // CHECK: return %[[SM:.*]]#0, %[[BWD]] : tensor<f32>, tensor<20x10xf32>
101 module @ir attributes { test.mlp_split } {
102   func.func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<20x10xf32>) -> (tensor<f32>, tensor<20x10xf32>) {
103     %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
104     %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
105     %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
106     %3 = "tf.MatMul"(%arg0, %arg2) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
107     %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%3, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
108     %4 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
109     %5 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
110     %6 = "tf.Mul"(%5, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
111     %7 = "tf.MatMul"(%arg0, %6) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x10xf32>) -> tensor<20x10xf32>
112     %8 = "tf.Mul"(%7, %1) : (tensor<20x10xf32>, tensor<f32>) -> tensor<20x10xf32>
113     %9 = "tf.Sub"(%arg2, %8) : (tensor<20x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
114     return %4, %9 : tensor<f32>, tensor<20x10xf32>
115   }
118 // -----
120 //===----------------------------------------------------------------------===//
121 // 2-layer perceptron with fused fwd/bwd operations
122 //===----------------------------------------------------------------------===//
124 module @patterns {
126   // gradient descent
127   pdl.pattern : benefit(3) {
128     %const_type = pdl.type
129     %param_type = pdl.type
130     %param = pdl.operand : %param_type
131     %gradient = pdl.operand : %param_type
133     %attr0 = pdl.attribute
134     %op0 = pdl.operation "tf.Const" {"value" = %attr0} -> (%const_type : !pdl.type)
135     %val0 = pdl.result 0 of %op0
136     %op1 = pdl.operation "tf.Mul" (%gradient, %val0 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
137     %val1 = pdl.result 0 of %op1
138     %op2 = pdl.operation "tf.Sub" (%param, %val1 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
140     pdl.rewrite %op2 {
141       %op3 = pdl.operation "kernel.GD" (%param, %gradient : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
142       %val3 = pdl.result 0 of %op3
143       pdl.replace %op2 with (%val3 : !pdl.value)  // tf.Sub
144       pdl.erase %op1  // tf.Mul
145     }
146   }
148   // first FC
149   pdl.pattern : benefit(8) {
150     %in_type = pdl.type
151     %out_type = pdl.type
152     %weight_type = pdl.type
153     %bias_type = pdl.type
154     %rxact = pdl.operand : %in_type
155     %rxdelta = pdl.operand : %out_type
156     %weight = pdl.operand : %weight_type
157     %bias = pdl.operand : %bias_type
159     %attr0 = pdl.attribute = false
160     %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
161     %val0 = pdl.result 0 of %op0
162     %op1 = pdl.operation "tf.BiasAdd" (%val0, %bias : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
163     %val1 = pdl.result 0 of %op1
164     %op2 = pdl.operation "tf.Relu" (%val1 : !pdl.value) -> (%out_type : !pdl.type)
165     %val2 = pdl.result 0 of %op2
166     %op3 = pdl.operation "tf.ReluGrad" (%rxdelta, %val2 : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
167     %val3 = pdl.result 0 of %op3
168     %attr1 = pdl.attribute = true
169     %op4 = pdl.operation "tf.MatMul" (%rxact, %val3 : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
170     %val4 = pdl.result 0 of %op4
171     %op5 = pdl.operation "kernel.GD" (%weight, %val4 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
172     %op6 = pdl.operation "tf.BiasAddGrad" (%val3 : !pdl.value) -> (%bias_type : !pdl.type)
173     %val6 = pdl.result 0 of %op6
174     %op7 = pdl.operation "kernel.GD" (%bias, %val6 : !pdl.value, !pdl.value) -> (%bias_type : !pdl.type)
176     pdl.rewrite {  // roots: %op2, %op5, %op7
177       %op8 = pdl.operation "kernel.FcWithBias" (%rxact, %rxdelta, %weight, %bias : !pdl.value, !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %weight_type, %bias_type : !pdl.type, !pdl.type, !pdl.type)
178       %val8_0 = pdl.result 0 of %op8  // txact
179       %val8_1 = pdl.result 1 of %op8  // weight_out
180       %val8_2 = pdl.result 2 of %op8  // bias_out
181       pdl.replace %op7 with (%val8_2 : !pdl.value)  // kernel.GD
182       pdl.erase %op6  // tf.BiasAddGrad
183       pdl.replace %op5 with (%val8_1 : !pdl.value)  // kernel.GD
184       pdl.erase %op4  // tf.MatMul
185       pdl.erase %op3  // tf.ReluGrad
186       pdl.replace %op2 with (%val8_0 : !pdl.value)  // tf.Relu
187       pdl.erase %op1  // tf.BiasAdd
188       pdl.erase %op0  // tf.MatMul
189     }
190   }
192   // second FC
193   pdl.pattern : benefit(4) {
194     %in_type = pdl.type
195     %out_type = pdl.type
196     %weight_type = pdl.type
197     %rxact = pdl.operand : %in_type
198     %rxdelta = pdl.operand : %out_type
199     %weight = pdl.operand : %weight_type
201     %attr0 = pdl.attribute = false
202     %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
203     %attr1 = pdl.attribute = true
204     %op1 = pdl.operation "tf.MatMul" (%rxdelta, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%in_type : !pdl.type)
205     %op2 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
206     %val2 = pdl.result 0 of %op2
207     %op3 = pdl.operation "kernel.GD" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
209     pdl.rewrite {  // roots: %op0, %op1, %op3
210       %op4 = pdl.operation "kernel.Fc" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %in_type, %weight_type : !pdl.type, !pdl.type, !pdl.type)
211       %val4_0 = pdl.result 0 of %op4  // txact
212       %val4_1 = pdl.result 1 of %op4  // txdelta
213       %val4_2 = pdl.result 2 of %op4  // weight_out
214       pdl.replace %op3 with (%val4_2 : !pdl.value)  // Sgd
215       pdl.erase %op2  // tf.MatMul
216       pdl.replace %op1 with (%val4_1 : !pdl.value)  // tf.MatMul
217       pdl.replace %op0 with (%val4_0 : !pdl.value)  // tf.MatMul
218     }
219   }
221   // softmax_cross_entropy
222   pdl.pattern : benefit(6) {
223     %in_type = pdl.type
224     %label_type = pdl.type
225     %loss_type = pdl.type
226     %mean_loss_type = pdl.type
227     %mean_const_type = pdl.type
228     %mul_const_type = pdl.type
229     %rxact = pdl.operand : %in_type
230     %rxlabel = pdl.operand : %label_type
232     %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
233     %val0_0 = pdl.result 0 of %op0  // loss
234     %val0_1 = pdl.result 1 of %op0  // gradient
235     %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
236     %val1 = pdl.result 0 of %op1
237     %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
238     %val2 = pdl.result 0 of %op2
239     %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
240     %val3 = pdl.result 0 of %op3
241     %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
242     %val4 = pdl.result 0 of %op4
243     %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)
245     pdl.rewrite {  // roots: %op2, %op5
246       %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
247       %val6_0 = pdl.result 0 of %op6  // txloss
248       %val6_1 = pdl.result 1 of %op6  // txdelta
249       pdl.replace %op5 with (%val6_1 : !pdl.value)  // tf.Mul
250       pdl.erase %op4  // tf.Const
251       pdl.erase %op3  // tf.PreventGradient
252       pdl.replace %op2 with (%val6_0 : !pdl.value)  // tf.Mean
253       pdl.erase %op1  // tf.Const
254       pdl.erase %op0  // tf.SparseSoftmaxCrossEntropyWithLogits
255     }
256   }
259 // CHECK-LABEL: test.mlp_fused
260 // CHECK: %[[FC2:.*]]:3 = "kernel.Fc"(%[[FC1:.*]]#0, %[[SM:.*]]#1, %arg4) : (tensor<2x256xf32>, tensor<2x10xf32>, tensor<256x10xf32>) -> (tensor<2x10xf32>, tensor<2x256xf32>, tensor<256x10xf32>)
261 // CHECK: %[[SM]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FC2]]#0, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
262 // CHECK: %[[FC1]]:3 = "kernel.FcWithBias"(%arg0, %[[FC2]]#1, %arg3, %arg2) : (tensor<2x20xf32>, tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>) -> (tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>)
263 module @ir attributes { test.mlp_fused } {
264   func.func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> () { // tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>) {
265     // The replacement operations fuse forward and backward pass; therefore, the
266     // resulting graph is not a DAG. To address this, we wrap the operations in
267     // a graph region.
268     "test.graph_region"() ({
269       %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
270       %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
271       %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
272       %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32>
273       %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32>
274       %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32>
275       %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32>
276       %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
277       %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
278       %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
279       %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
280       %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32>
281       %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32>
282       %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32>
283       %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32>
284       %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32>
285       %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32>
286       %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32>
287       %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32>
288       %18 = "tf.Sub"(%arg2, %17) : (tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32>
289       %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32>
290       %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32>
291     }) : () -> ()
292     return
293   }