[llvm-objcopy] Fix prints wrong path when dump-section output path doesn't exist...
[llvm-project.git] / mlir / test / python / dialects / memref.py
blob b91fdc367cf301629440a7b0364c649057933015
1 # RUN: %PYTHON %s | FileCheck %s
import numpy as np

import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
from mlir.dialects.memref import _infer_memref_subview_result_type
from mlir.ir import *


def run(f):
    """Print a FileCheck label for ``f``, call it, and return it unchanged.

    Applied as a decorator, so every test runs at import time and its output
    is preceded by a ``TEST: <name>`` line for ``CHECK-LABEL`` to anchor on.
    """
    label = f.__name__
    print("\nTEST:", label)
    f()
    return f


# CHECK-LABEL: TEST: testSubViewAccessors
@run
def testSubViewAccessors():
    """Check the source/offsets/sizes/strides/result accessors of SubViewOp."""
    context = Context()
    asm = r"""
    func.func @f1(%arg0: memref<?x?xf32>) {
      %0 = arith.constant 0 : index
      %1 = arith.constant 1 : index
      %2 = arith.constant 2 : index
      %3 = arith.constant 3 : index
      %4 = arith.constant 4 : index
      %5 = arith.constant 5 : index
      memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      return
    }
  """
    module = Module.parse(asm, context)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    # Six arith.constant ops precede the subview in the entry block.
    op = entry_block.operations[6]

    assert op.source == op.operands[0]
    for operand_range in (op.offsets, op.sizes, op.strides):
        assert len(operand_range) == 2
    assert op.result == op.results[0]

    # CHECK: SubViewOp
    print(type(op).__name__)

    # Offsets, then sizes, then strides, each yielding its two constants in order.
    # CHECK: constant 0
    # CHECK: constant 1
    # CHECK: constant 2
    # CHECK: constant 3
    # CHECK: constant 4
    # CHECK: constant 5
    for operand_range in (op.offsets, op.sizes, op.strides):
        for position in (0, 1):
            print(operand_range[position])


# CHECK-LABEL: TEST: testCustomBuidlers
@run
def testCustomBuidlers():
    """Build a memref.load with the Python LoadOp builder and verify the module."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.parse(
            r"""
    func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
      return
    }
  """
        )
        func_op = module.body.operations[0]
        entry = func_op.regions[0].blocks[0]
        args = func_op.arguments
        # Insert ahead of the existing `return` so the function body stays valid.
        with InsertionPoint.at_block_terminator(entry):
            memref.LoadOp(args[0], args[1:])

        # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
        # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
        print(module)
        assert module.operation.verify()


# CHECK-LABEL: TEST: testMemRefAttr
@run
def testMemRefAttr():
    """Emit a memref.global whose type comes from the extras type helpers."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            global_type = T.memref(16, T.i32())
            memref.global_("objFifo_in0", global_type)
        # CHECK: memref.global @objFifo_in0 : memref<16xi32>
        print(module)


# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
@run
def testSubViewOpInferReturnTypeSemantics():
    """Exercise result-type inference in the memref.subview Python builder.

    Covers:
    - static operands given as Python ints or foldable arith.constant values;
    - a non-constant offset operand, which yields a dynamic layout offset;
    - the ValueError paths of the inference helper and of the builder;
    - an explicit ``result_type`` for a source whose layout offset is dynamic.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
            print(x.owner)

            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
            assert y.owner.verify()
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(y.owner)

            # An arith.constant offset is folded into the static offsets.
            z = memref.subview(
                x,
                [arith.constant(T.index(), 1), 1],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(z.owner)

            z = memref.subview(
                x,
                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
            print(z.owner)

            # A non-constant (arith.addi) offset makes the inferred layout offset dynamic.
            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
            z = memref.subview(
                x,
                [s, 0],
                [3, 3],
                [1, 1],
            )
            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(z)

            try:
                _infer_memref_subview_result_type(
                    x.type,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: Only inferring from python or mlir integer constant is supported
                print(e)

            try:
                memref.subview(
                    x,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(e)

            # Strided layout with a dynamic offset. The offset sentinel comes from
            # get_dynamic_stride_or_offset(), the API meant for strides/offsets;
            # get_dynamic_size() is the shape-dimension sentinel. In current MLIR
            # both return the same value, so the printed `offset: ?` is unchanged.
            layout = StridedLayoutAttr.get(
                ShapedType.get_dynamic_stride_or_offset(), [10, 1]
            )
            x = memref.alloc(
                T.memref(
                    10,
                    10,
                    T.i32(),
                    layout=layout,
                ),
                [],
                [arith.constant(T.index(), 42)],
            )
            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
            print(x.owner)
            # A dynamic source offset cannot be inferred through, so the result type is explicit.
            y = memref.subview(
                x,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=layout),
            )
            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(y.owner)


# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    """Cross-check inferred subview strides and offsets against NumPy views.

    Each subview is paired with the equivalent NumPy slice of a zero array
    that has the same shape as the subview's source. The slice's byte strides
    and byte offset supply the golden layout.
    """

    def check_strides_offset(subview, np_view):
        """Assert that ``subview``'s strided layout matches ``np_view``.

        NumPy reports strides and data pointers in bytes, so both the strides
        and the offset from ``np_view.base`` are divided by the item size to
        get element counts comparable with the MLIR layout.
        """
        # Named `subview` rather than `memref` so the `memref` dialect module
        # is not shadowed inside this helper.
        layout = subview.type.layout
        itemsize = np_view.dtype.itemsize
        golden_strides = (np.array(np_view.strides) // itemsize).tolist()
        golden_offset = (np_view.ctypes.data - np_view.base.ctypes.data) // itemsize
        actual = (layout.strides, layout.offset)
        expected = (golden_strides, golden_offset)
        # Report both sides so a mismatch is diagnosable from the traceback.
        assert actual == expected, f"inferred {actual}, expected {expected}"

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            # Unit strides; size-1 slices along one or more axes.
            shape = (10, 22, 3, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])

            # fmt: off
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, ...])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 3, 44), (1, 1, 1, 1)), golden_mem[:, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 44), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 3, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 3, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 44), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
            # fmt: on

            # A full-extent, zero-offset, unit-stride subview gets the identity
            # layout, expressed as an AffineMapAttr rather than a strided layout.
            assert memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1)
            ).type.layout == AffineMapAttr.get(
                AffineMap.get(
                    4,
                    0,
                    [
                        AffineDimExpr.get(0),
                        AffineDimExpr.get(1),
                        AffineDimExpr.get(2),
                        AffineDimExpr.get(3),
                    ],
                )
            )

            # Non-unit subview strides.
            shape = (7, 22, 30, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # fmt: off
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 3, 44), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 44), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
            # fmt: on

            # 4x4 tiles of a 2-D source at zero and non-zero offsets.
            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            # fmt: off
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
            # fmt: on