// RUN: mlir-opt %s -allow-unregistered-dialect -test-pdl-bytecode-pass -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
// 1-layer perceptron with split fwd/bwd operations
//===----------------------------------------------------------------------===//
// Match the forward fully-connected computation (a non-transposed tf.MatMul)
// and replace it with a single kernel.FcFwd op.
pdl.pattern : benefit(1) {
  %in_type = pdl.type
  %out_type = pdl.type
  %weight_type = pdl.type
  %rxact = pdl.operand : %in_type
  %weight = pdl.operand : %weight_type

  %attr0 = pdl.attribute = false
  %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)

  pdl.rewrite %op0 {
    %op1 = pdl.operation "kernel.FcFwd" (%rxact, %weight : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
    %val1 = pdl.result 0 of %op1 // txact
    pdl.replace %op0 with (%val1 : !pdl.value) // tf.MatMul
  }
}
// Match the backward weight-update chain
//   tf.MatMul(act^T, delta) -> tf.Mul(.., lr-const) -> tf.Sub(weight, ..)
// rooted at the tf.Sub, and fuse it into a single kernel.FcBwd op.
pdl.pattern : benefit(4) {
  %in_type = pdl.type
  %out_type = pdl.type
  %weight_type = pdl.type
  %const_type = pdl.type
  %rxact = pdl.operand : %in_type
  %rxdelta = pdl.operand : %out_type
  %weight = pdl.operand : %weight_type

  %attr0 = pdl.attribute = true
  %attr1 = pdl.attribute = false
  %op0 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%weight_type : !pdl.type)
  %val0 = pdl.result 0 of %op0
  %op1 = pdl.operation "tf.Const" -> (%const_type : !pdl.type)
  %val1 = pdl.result 0 of %op1
  %op2 = pdl.operation "tf.Mul" (%val0, %val1 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
  %val2 = pdl.result 0 of %op2
  %op3 = pdl.operation "tf.Sub" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)

  pdl.rewrite %op3 {
    %op4 = pdl.operation "kernel.FcBwd" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
    %val4 = pdl.result 0 of %op4 // weight_out
    pdl.replace %op3 with (%val4 : !pdl.value) // tf.Sub
    pdl.erase %op2 // tf.Mul
    pdl.erase %op1 // tf.Const
    pdl.erase %op0 // tf.MatMul
  }
}
// softmax_cross_entropy
// Multi-root pattern: match tf.SparseSoftmaxCrossEntropyWithLogits together
// with its loss consumer (tf.Mean) and its gradient consumer
// (tf.PreventGradient -> tf.Mul), and fuse everything into a single
// two-result kernel.SoftmaxCrossEntropy op.
pdl.pattern : benefit(6) {
  %in_type = pdl.type
  %label_type = pdl.type
  %loss_type = pdl.type
  %mean_loss_type = pdl.type
  %mean_const_type = pdl.type
  %mul_const_type = pdl.type
  %rxact = pdl.operand : %in_type
  %rxlabel = pdl.operand : %label_type

  %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
  %val0_0 = pdl.result 0 of %op0 // loss
  %val0_1 = pdl.result 1 of %op0 // gradient
  %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
  %val1 = pdl.result 0 of %op1
  %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
  %val2 = pdl.result 0 of %op2
  %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
  %val3 = pdl.result 0 of %op3
  %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
  %val4 = pdl.result 0 of %op4
  %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)

  pdl.rewrite { // roots: %op2, %op5
    %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
    %val6_0 = pdl.result 0 of %op6 // txloss
    %val6_1 = pdl.result 1 of %op6 // txdelta
    pdl.replace %op5 with (%val6_1 : !pdl.value) // tf.Mul
    pdl.erase %op4 // tf.Const
    pdl.erase %op3 // tf.PreventGradient
    pdl.replace %op2 with (%val6_0 : !pdl.value) // tf.Mean
    pdl.erase %op1 // tf.Const
    pdl.erase %op0 // tf.SparseSoftmaxCrossEntropyWithLogits
  }
}
// CHECK-LABEL: test.mlp_split
// CHECK: %[[FWD:.*]] = "kernel.FcFwd"(%arg0, %arg2) : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
// CHECK: %[[SM:.*]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FWD]], %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
// CHECK: %[[BWD:.*]] = "kernel.FcBwd"(%arg0, %[[SM]]#1, %arg2) : (tensor<2x20xf32>, tensor<2x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
// CHECK: return %[[SM]]#0, %[[BWD]] : tensor<f32>, tensor<20x10xf32>
// Payload IR: 1-layer MLP training step (fwd MatMul, softmax cross-entropy
// loss, and SGD weight update) that the patterns above rewrite into
// kernel.FcFwd / kernel.SoftmaxCrossEntropy / kernel.FcBwd.
module @ir attributes { test.mlp_split } {
  func.func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<20x10xf32>) -> (tensor<f32>, tensor<20x10xf32>) {
    %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
    %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
    %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
    %3 = "tf.MatMul"(%arg0, %arg2) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
    %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%3, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
    %4 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
    %5 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
    %6 = "tf.Mul"(%5, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
    %7 = "tf.MatMul"(%arg0, %6) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x10xf32>) -> tensor<20x10xf32>
    %8 = "tf.Mul"(%7, %1) : (tensor<20x10xf32>, tensor<f32>) -> tensor<20x10xf32>
    %9 = "tf.Sub"(%arg2, %8) : (tensor<20x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
    return %4, %9 : tensor<f32>, tensor<20x10xf32>
  }
}

// -----
//===----------------------------------------------------------------------===//
// 2-layer perceptron with fused fwd/bwd operations
//===----------------------------------------------------------------------===//
// Match a plain SGD update "param - const * gradient"
// (tf.Const -> tf.Mul -> tf.Sub) rooted at the tf.Sub, and replace it with a
// single kernel.GD op. The tf.Const is intentionally NOT erased: the
// learning-rate constant is shared by several update chains in the payload.
pdl.pattern : benefit(3) {
  %const_type = pdl.type
  %param_type = pdl.type
  %param = pdl.operand : %param_type
  %gradient = pdl.operand : %param_type

  %attr0 = pdl.attribute
  %op0 = pdl.operation "tf.Const" {"value" = %attr0} -> (%const_type : !pdl.type)
  %val0 = pdl.result 0 of %op0
  %op1 = pdl.operation "tf.Mul" (%gradient, %val0 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
  %val1 = pdl.result 0 of %op1
  %op2 = pdl.operation "tf.Sub" (%param, %val1 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)

  pdl.rewrite %op2 {
    %op3 = pdl.operation "kernel.GD" (%param, %gradient : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
    %val3 = pdl.result 0 of %op3
    pdl.replace %op2 with (%val3 : !pdl.value) // tf.Sub
    pdl.erase %op1 // tf.Mul
  }
}
// Multi-root pattern: match the fwd MatMul+BiasAdd+Relu together with the
// bwd ReluGrad/MatMul/BiasAddGrad and both kernel.GD updates, and fuse the
// whole layer into a single three-result kernel.FcWithBias op.
pdl.pattern : benefit(8) {
  %in_type = pdl.type
  %out_type = pdl.type
  %weight_type = pdl.type
  %bias_type = pdl.type
  %rxact = pdl.operand : %in_type
  %rxdelta = pdl.operand : %out_type
  %weight = pdl.operand : %weight_type
  %bias = pdl.operand : %bias_type

  %attr0 = pdl.attribute = false
  %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
  %val0 = pdl.result 0 of %op0
  %op1 = pdl.operation "tf.BiasAdd" (%val0, %bias : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
  %val1 = pdl.result 0 of %op1
  %op2 = pdl.operation "tf.Relu" (%val1 : !pdl.value) -> (%out_type : !pdl.type)
  %val2 = pdl.result 0 of %op2
  %op3 = pdl.operation "tf.ReluGrad" (%rxdelta, %val2 : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
  %val3 = pdl.result 0 of %op3
  %attr1 = pdl.attribute = true
  %op4 = pdl.operation "tf.MatMul" (%rxact, %val3 : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
  %val4 = pdl.result 0 of %op4
  %op5 = pdl.operation "kernel.GD" (%weight, %val4 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
  %op6 = pdl.operation "tf.BiasAddGrad" (%val3 : !pdl.value) -> (%bias_type : !pdl.type)
  %val6 = pdl.result 0 of %op6
  %op7 = pdl.operation "kernel.GD" (%bias, %val6 : !pdl.value, !pdl.value) -> (%bias_type : !pdl.type)

  pdl.rewrite { // roots: %op2, %op5, %op7
    %op8 = pdl.operation "kernel.FcWithBias" (%rxact, %rxdelta, %weight, %bias : !pdl.value, !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %weight_type, %bias_type : !pdl.type, !pdl.type, !pdl.type)
    %val8_0 = pdl.result 0 of %op8 // txact
    %val8_1 = pdl.result 1 of %op8 // weight_out
    %val8_2 = pdl.result 2 of %op8 // bias_out
    pdl.replace %op7 with (%val8_2 : !pdl.value) // kernel.GD
    pdl.erase %op6 // tf.BiasAddGrad
    pdl.replace %op5 with (%val8_1 : !pdl.value) // kernel.GD
    pdl.erase %op4 // tf.MatMul
    pdl.erase %op3 // tf.ReluGrad
    pdl.replace %op2 with (%val8_0 : !pdl.value) // tf.Relu
    pdl.erase %op1 // tf.BiasAdd
    pdl.erase %op0 // tf.MatMul
  }
}
// Multi-root pattern: match a bias-free layer (fwd MatMul, bwd input-gradient
// MatMul, and the weight-gradient MatMul feeding kernel.GD) and fuse it into a
// single three-result kernel.Fc op.
pdl.pattern : benefit(4) {
  %in_type = pdl.type
  %out_type = pdl.type
  %weight_type = pdl.type
  %rxact = pdl.operand : %in_type
  %rxdelta = pdl.operand : %out_type
  %weight = pdl.operand : %weight_type

  %attr0 = pdl.attribute = false
  %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
  %attr1 = pdl.attribute = true
  %op1 = pdl.operation "tf.MatMul" (%rxdelta, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%in_type : !pdl.type)
  %op2 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
  %val2 = pdl.result 0 of %op2
  %op3 = pdl.operation "kernel.GD" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)

  pdl.rewrite { // roots: %op0, %op1, %op3
    %op4 = pdl.operation "kernel.Fc" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %in_type, %weight_type : !pdl.type, !pdl.type, !pdl.type)
    %val4_0 = pdl.result 0 of %op4 // txact
    %val4_1 = pdl.result 1 of %op4 // txdelta
    %val4_2 = pdl.result 2 of %op4 // weight_out
    pdl.replace %op3 with (%val4_2 : !pdl.value) // Sgd
    pdl.erase %op2 // tf.MatMul
    pdl.replace %op1 with (%val4_1 : !pdl.value) // tf.MatMul
    pdl.replace %op0 with (%val4_0 : !pdl.value) // tf.MatMul
  }
}
// softmax_cross_entropy
// Multi-root pattern (same as the mlp_split variant): fuse
// tf.SparseSoftmaxCrossEntropyWithLogits with its tf.Mean loss consumer and
// its tf.PreventGradient -> tf.Mul gradient consumer into a single
// two-result kernel.SoftmaxCrossEntropy op.
pdl.pattern : benefit(6) {
  %in_type = pdl.type
  %label_type = pdl.type
  %loss_type = pdl.type
  %mean_loss_type = pdl.type
  %mean_const_type = pdl.type
  %mul_const_type = pdl.type
  %rxact = pdl.operand : %in_type
  %rxlabel = pdl.operand : %label_type

  %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
  %val0_0 = pdl.result 0 of %op0 // loss
  %val0_1 = pdl.result 1 of %op0 // gradient
  %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
  %val1 = pdl.result 0 of %op1
  %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
  %val2 = pdl.result 0 of %op2
  %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
  %val3 = pdl.result 0 of %op3
  %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
  %val4 = pdl.result 0 of %op4
  %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)

  pdl.rewrite { // roots: %op2, %op5
    %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
    %val6_0 = pdl.result 0 of %op6 // txloss
    %val6_1 = pdl.result 1 of %op6 // txdelta
    pdl.replace %op5 with (%val6_1 : !pdl.value) // tf.Mul
    pdl.erase %op4 // tf.Const
    pdl.erase %op3 // tf.PreventGradient
    pdl.replace %op2 with (%val6_0 : !pdl.value) // tf.Mean
    pdl.erase %op1 // tf.Const
    pdl.erase %op0 // tf.SparseSoftmaxCrossEntropyWithLogits
  }
}
// CHECK-LABEL: test.mlp_fused
// CHECK: %[[FC2:.*]]:3 = "kernel.Fc"(%[[FC1:.*]]#0, %[[SM:.*]]#1, %arg4) : (tensor<2x256xf32>, tensor<2x10xf32>, tensor<256x10xf32>) -> (tensor<2x10xf32>, tensor<2x256xf32>, tensor<256x10xf32>)
// CHECK: %[[SM]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FC2]]#0, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
// CHECK: %[[FC1]]:3 = "kernel.FcWithBias"(%arg0, %[[FC2]]#1, %arg3, %arg2) : (tensor<2x20xf32>, tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>) -> (tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>)
263 module @ir attributes { test.mlp_fused } {
264 func.func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> () { // tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>) {
265 // The replacement operations fuse forward and backward pass; therefore, the
266 // resulting graph is not a DAG. To address this, we wrap the operations in
268 "test.graph_region"() ({
269 %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
270 %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
271 %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
272 %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32>
273 %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32>
274 %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32>
275 %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32>
276 %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
277 %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
278 %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
279 %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
280 %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32>
281 %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32>
282 %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32>
283 %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32>
284 %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32>
285 %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32>
286 %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32>
287 %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32>
288 %18 = "tf.Sub"(%arg2, %17) : (tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32>
289 %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32>
290 %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32>