1 # RUN: %PYTHON %s | FileCheck %s
3 import mlir
.dialects
.arith
as arith
4 import mlir
.dialects
.memref
as memref
5 import mlir
.extras
.types
as T
6 from mlir
.dialects
.memref
import _infer_memref_subview_result_type
11 print("\nTEST:", f
.__name
__)
16 # CHECK-LABEL: TEST: testSubViewAccessors
18 def testSubViewAccessors():
20 module
= Module
.parse(
22 func.func @f1(%arg0: memref<?x?xf32>) {
23 %0 = arith.constant 0 : index
24 %1 = arith.constant 1 : index
25 %2 = arith.constant 2 : index
26 %3 = arith.constant 3 : index
27 %4 = arith.constant 4 : index
28 %5 = arith.constant 5 : index
29 memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
35 func_body
= module
.body
.operations
[0].regions
[0].blocks
[0]
36 subview
= func_body
.operations
[6]
38 assert subview
.source
== subview
.operands
[0]
39 assert len(subview
.offsets
) == 2
40 assert len(subview
.sizes
) == 2
41 assert len(subview
.strides
) == 2
42 assert subview
.result
== subview
.results
[0]
45 print(type(subview
).__name
__)
48 print(subview
.offsets
[0])
50 print(subview
.offsets
[1])
52 print(subview
.sizes
[0])
54 print(subview
.sizes
[1])
56 print(subview
.strides
[0])
58 print(subview
.strides
[1])
61 # CHECK-LABEL: TEST: testCustomBuidlers
63 def testCustomBuidlers():
64 with
Context() as ctx
, Location
.unknown(ctx
):
65 module
= Module
.parse(
67 func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
72 f
= module
.body
.operations
[0]
73 func_body
= f
.regions
[0].blocks
[0]
74 with InsertionPoint
.at_block_terminator(func_body
):
75 memref
.LoadOp(f
.arguments
[0], f
.arguments
[1:])
77 # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
78 # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
80 assert module
.operation
.verify()
83 # CHECK-LABEL: TEST: testMemRefAttr
86 with
Context() as ctx
, Location
.unknown(ctx
):
87 module
= Module
.create()
88 with
InsertionPoint(module
.body
):
89 memref
.global_("objFifo_in0", T
.memref(16, T
.i32()))
90 # CHECK: memref.global @objFifo_in0 : memref<16xi32>
94 # CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
96 def testSubViewOpInferReturnTypeSemantics():
97 with
Context() as ctx
, Location
.unknown(ctx
):
98 module
= Module
.create()
99 with
InsertionPoint(module
.body
):
100 x
= memref
.alloc(T
.memref(10, 10, T
.i32()), [], [])
101 # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
104 y
= memref
.subview(x
, [1, 1], [3, 3], [1, 1])
105 assert y
.owner
.verify()
106 # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
111 [arith
.constant(T
.index(), 1), 1],
115 # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
120 [arith
.constant(T
.index(), 3), arith
.constant(T
.index(), 4)],
124 # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
127 s
= arith
.addi(arith
.constant(T
.index(), 3), arith
.constant(T
.index(), 4))
134 # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
138 _infer_memref_subview_result_type(
140 [arith
.constant(T
.index(), 3), arith
.constant(T
.index(), 4)],
141 [ShapedType
.get_dynamic_size(), 3],
144 except ValueError as e
:
145 # CHECK: Only inferring from python or mlir integer constant is supported
151 [arith
.constant(T
.index(), 3), arith
.constant(T
.index(), 4)],
152 [ShapedType
.get_dynamic_size(), 3],
155 except ValueError as e
:
156 # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
159 layout
= StridedLayoutAttr
.get(ShapedType
.get_dynamic_size(), [10, 1])
168 [arith
.constant(T
.index(), 42)],
170 # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
177 result_type
=T
.memref(3, 3, T
.i32(), layout
=layout
),
179 # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
183 # CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
185 def testSubViewOpInferReturnTypeExtensiveSlicing():
186 def check_strides_offset(memref
, np_view
):
187 layout
= memref
.type.layout
188 dtype_size_in_bytes
= np_view
.dtype
.itemsize
189 golden_strides
= (np
.array(np_view
.strides
) // dtype_size_in_bytes
).tolist()
191 np_view
.ctypes
.data
- np_view
.base
.ctypes
.data
192 ) // dtype_size_in_bytes
194 assert (layout
.strides
, layout
.offset
) == (golden_strides
, golden_offset
)
196 with
Context() as ctx
, Location
.unknown(ctx
):
197 module
= Module
.create()
198 with
InsertionPoint(module
.body
):
199 shape
= (10, 22, 3, 44)
200 golden_mem
= np
.zeros(shape
, dtype
=np
.int32
)
201 mem1
= memref
.alloc(T
.memref(*shape
, T
.i32()), [], [])
204 check_strides_offset(memref
.subview(mem1
, (1, 0, 0, 0), (1, 22, 3, 44), (1, 1, 1, 1)), golden_mem
[1:2, ...])
205 check_strides_offset(memref
.subview(mem1
, (0, 1, 0, 0), (10, 1, 3, 44), (1, 1, 1, 1)), golden_mem
[:, 1:2])
206 check_strides_offset(memref
.subview(mem1
, (0, 0, 1, 0), (10, 22, 1, 44), (1, 1, 1, 1)), golden_mem
[:, :, 1:2])
207 check_strides_offset(memref
.subview(mem1
, (0, 0, 0, 1), (10, 22, 3, 1), (1, 1, 1, 1)), golden_mem
[:, :, :, 1:2])
208 check_strides_offset(memref
.subview(mem1
, (0, 1, 0, 1), (10, 1, 3, 1), (1, 1, 1, 1)), golden_mem
[:, 1:2, :, 1:2])
209 check_strides_offset(memref
.subview(mem1
, (1, 0, 0, 1), (1, 22, 3, 1), (1, 1, 1, 1)), golden_mem
[1:2, :, :, 1:2])
210 check_strides_offset(memref
.subview(mem1
, (1, 1, 0, 0), (1, 1, 3, 44), (1, 1, 1, 1)), golden_mem
[1:2, 1:2, :, :])
211 check_strides_offset(memref
.subview(mem1
, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem
[:, :, 1:2, 1:2])
212 check_strides_offset(memref
.subview(mem1
, (0, 1, 1, 0), (10, 1, 1, 44), (1, 1, 1, 1)), golden_mem
[:, 1:2, 1:2, :])
213 check_strides_offset(memref
.subview(mem1
, (1, 0, 1, 0), (1, 22, 1, 44), (1, 1, 1, 1)), golden_mem
[1:2, :, 1:2, :])
214 check_strides_offset(memref
.subview(mem1
, (1, 1, 0, 1), (1, 1, 3, 1), (1, 1, 1, 1)), golden_mem
[1:2, 1:2, :, 1:2])
215 check_strides_offset(memref
.subview(mem1
, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem
[1:2, :, 1:2, 1:2])
216 check_strides_offset(memref
.subview(mem1
, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem
[:, 1:2, 1:2, 1:2])
217 check_strides_offset(memref
.subview(mem1
, (1, 1, 1, 0), (1, 1, 1, 44), (1, 1, 1, 1)), golden_mem
[1:2, 1:2, 1:2, :])
220 # default strides and offset means no stridedlayout attribute means affinemap layout
221 assert memref
.subview(
222 mem1
, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1)
223 ).type.layout
== AffineMapAttr
.get(
228 AffineDimExpr
.get(0),
229 AffineDimExpr
.get(1),
230 AffineDimExpr
.get(2),
231 AffineDimExpr
.get(3),
236 shape
= (7, 22, 30, 44)
237 golden_mem
= np
.zeros(shape
, dtype
=np
.int32
)
238 mem2
= memref
.alloc(T
.memref(*shape
, T
.i32()), [], [])
240 check_strides_offset(memref
.subview(mem2
, (0, 0, 0, 0), (7, 11, 3, 44), (1, 2, 1, 1)), golden_mem
[:, 0:22:2])
241 check_strides_offset(memref
.subview(mem2
, (0, 0, 0, 0), (7, 11, 11, 44), (1, 2, 30, 1)), golden_mem
[:, 0:22:2, 0:330:30])
242 check_strides_offset(memref
.subview(mem2
, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem
[:, 0:22:2, 0:330:30, 0:4400:400])
246 golden_mem
= np
.zeros(shape
, dtype
=np
.int32
)
248 mem3
= memref
.alloc(T
.memref(*shape
, T
.i32()), [], [])
249 check_strides_offset(memref
.subview(mem3
, (0, 0), (4, 4), (1, 1)), golden_mem
[0:4, 0:4])
250 check_strides_offset(memref
.subview(mem3
, (4, 4), (4, 4), (1, 1)), golden_mem
[4:8, 4:8])