[Frontend] Remove unused includes (NFC) (#116927)
[llvm-project.git] / llvm / utils / spirv-sim / instructions.py
blob5e64a480a2be6bfbdbfc3f02e444cd226b68b09f
1 from typing import Optional, List
# Base class for an instruction. To implement a basic instruction that doesn't
# impact the control flow, create a new class inheriting from this one.
class Instruction:
    # Name of the output register, or None when there is no result.
    _result: Optional[str]
    # The instruction opcode (e.g. "OpIAdd").
    _opcode: str
    # Every operand except the result register and the opcode.
    _operands: List[str]

    # Parses one whitespace-separated line of disassembly. Two shapes are
    # recognized: "<result> = <opcode> <operands...>" and
    # "<opcode> <operands...>". The raw line is kept in self.line.
    def __init__(self, line: str):
        self.line = line
        words = line.split()
        if len(words) > 1 and words[1] == "=":
            self._result, opcode_at = words[0], 2
        else:
            self._result, opcode_at = None, 0
        self._opcode = words[opcode_at]
        self._operands = words[opcode_at + 1 :]

    # Debug rendering: "<result> = <opcode> [<operands>]", or a leading space
    # in place of the result prefix when there is no output register.
    def __str__(self):
        if self._result is not None:
            return f"{self._result:3} = {self._opcode} {self._operands}"
        return f" {self._opcode} {self._operands}"

    # The instruction opcode.
    def opcode(self) -> str:
        return self._opcode

    # The operands, excluding the result register and the opcode.
    def operands(self) -> List[str]:
        return self._operands

    # The output register. Only valid when has_output_register() is true;
    # otherwise the assertion below fails.
    def output_register(self) -> str:
        assert self._result is not None
        return self._result

    # True if this instruction writes an output register, False otherwise.
    def has_output_register(self) -> bool:
        return self._result is not None

    # Hook to initialize state tied to this instruction before module
    # execution begins; for example, global Input variables can use it to
    # store the lane ID into their register. The default does nothing.
    def static_execution(self, lane):
        return

    # Called every time a tangle executes this instruction. Do not override
    # this directly; override _impl and/or _advance_ip instead.
    def runtime_execution(self, module, lane):
        self._impl(module, lane)
        self._advance_ip(module, lane)

    # The instruction's runtime logic; every executable instruction must
    # override it. 'Static' instructions such as OpConstant should not, since
    # they are not supposed to run at runtime. This default always raises.
    def _impl(self, module, lane):
        raise RuntimeError(f"Unimplemented instruction {self}")

    # Default IP update: step to the next instruction. Instructions that
    # redirect control flow (such as OpBranch) must override this.
    def _advance_ip(self, module, lane):
        next_ip = lane.ip() + 1
        lane.set_ip(next_ip)
# The following instructions are parsed, but never executed: none of them
# overrides _impl, so running one would reach Instruction._impl and raise.
class OpEntryPoint(Instruction):
    """Parsed-only instruction; relies on the Instruction defaults."""


class OpFunction(Instruction):
    """Parsed-only instruction; relies on the Instruction defaults."""


class OpFunctionEnd(Instruction):
    """Parsed-only instruction; relies on the Instruction defaults."""


class OpLabel(Instruction):
    """Parsed-only instruction; relies on the Instruction defaults."""


class OpVariable(Instruction):
    """Parsed-only instruction; relies on the Instruction defaults."""
class OpName(Instruction):
    # Returns the debug name attached to the target register, without the
    # surrounding double quotes. The name is sliced out of the raw source
    # line (between the first and last '"') instead of being taken from the
    # whitespace-split operands. This keeps names that contain spaces whole;
    # the split-based form would return only the first word.
    def name(self) -> str:
        start = self.line.find('"')
        end = self.line.rfind('"')
        if start == -1 or end <= start:
            # No quoted span on the line: keep the previous token-based result.
            return self._operands[1][1:-1]
        return self.line[start + 1 : end]

    # Returns the register the name is attached to (the first operand).
    def decoratedRegister(self) -> str:
        return self._operands[0]
# The only decoration we use is the BuiltIn one, to initialize the values.
# LinkageAttributes is accepted and ignored; anything else trips the assert.
class OpDecorate(Instruction):
    def static_execution(self, lane):
        decoration = self._operands[1]
        if decoration == "LinkageAttributes":
            return
        # NOTE(review): this validation is an assert, so it disappears under
        # `python -O` and every decoration would then seed the register.
        assert (
            decoration == "BuiltIn"
            and self._operands[2] == "SubgroupLocalInvocationId"
        )
        # Seed the decorated register with this lane's tid().
        target = self._operands[0]
        lane.set_register(target, lane.tid())
# Constants
class OpConstant(Instruction):
    # Before execution starts, store the literal (the operand after the
    # result type) into the result register, parsed as an integer.
    def static_execution(self, lane):
        literal = self._operands[1]
        lane.set_register(self._result, int(literal))
class OpConstantTrue(OpConstant):
    # Before execution starts, the result register is seeded with True.
    def static_execution(self, lane):
        value = True
        lane.set_register(self._result, value)
class OpConstantFalse(OpConstant):
    # Before execution starts, the result register is seeded with False.
    def static_execution(self, lane):
        value = False
        lane.set_register(self._result, value)
class OpConstantComposite(OpConstant):
    # Collects the current values of the constituent registers (every operand
    # after the result type), in order, into a list stored in the result.
    def static_execution(self, lane):
        members = [lane.get_register(name) for name in self._operands[1:]]
        lane.set_register(self._result, members)
# Control flow instructions
class OpFunctionCall(Instruction):
    # Nothing is computed at the call site; the whole effect is the control
    # transfer performed in _advance_ip.
    def _impl(self, module, lane):
        return

    # Resolves the callee (second operand) to its entry via the module and
    # hands it, together with this call's result register name, to
    # lane.do_call. NOTE: operands after the callee are not forwarded here.
    def _advance_ip(self, module, lane):
        callee = self._operands[1]
        lane.do_call(module.get_function_entry(callee), self._result)
class OpReturn(Instruction):
    # Nothing to compute; returning is purely a control transfer.
    def _impl(self, module, lane):
        return

    # Hands control to lane.do_return with no value (void return).
    def _advance_ip(self, module, lane):
        no_value = None
        lane.do_return(no_value)
class OpReturnValue(Instruction):
    # Nothing to compute; returning is purely a control transfer.
    def _impl(self, module, lane):
        return

    # Reads the register named by the first operand and passes its value to
    # lane.do_return.
    def _advance_ip(self, module, lane):
        returned = lane.get_register(self._operands[0])
        lane.do_return(returned)
class OpBranch(Instruction):
    # An unconditional branch does no work of its own; see _advance_ip.
    def _impl(self, module, lane):
        pass

    # Moves the IP to the entry of the block named by the single operand
    # (the target label). The dead trailing `pass` that followed this call
    # has been removed.
    def _advance_ip(self, module, lane):
        lane.set_ip(module.get_bb_entry(self._operands[0]))
class OpBranchConditional(Instruction):
    # No work of its own; the branch happens in _advance_ip.
    def _impl(self, module, lane):
        return

    # Operands are (condition, true label, false label). The condition
    # register is tested for truthiness, and the IP moves to the entry of the
    # chosen block.
    def _advance_ip(self, module, lane):
        label_index = 1 if lane.get_register(self._operands[0]) else 2
        lane.set_ip(module.get_bb_entry(self._operands[label_index]))
class OpSwitch(Instruction):
    # No work of its own; the branch happens in _advance_ip.
    def _impl(self, module, lane):
        return

    # Operands are: selector register, default label, then (literal, label)
    # pairs. The first pair whose integer literal equals the selector value
    # wins; when none matches, control goes to the default label.
    def _advance_ip(self, module, lane):
        selector = lane.get_register(self._operands[0])
        destination = self._operands[1]
        for pos in range(2, len(self._operands), 2):
            literal, label = int(self._operands[pos]), self._operands[pos + 1]
            if selector == literal:
                destination = label
                break
        lane.set_ip(module.get_bb_entry(destination))
class OpUnreachable(Instruction):
    # Reaching this instruction at runtime is an error, so execution stops
    # with a RuntimeError.
    def _impl(self, module, lane):
        message = "This instruction should never be executed."
        raise RuntimeError(message)
# Convergence instructions
class MergeInstruction(Instruction):
    # Label of the merge block (the first operand).
    def merge_location(self):
        return self._operands[0]

    # Label of the continue target (second operand), or None when there are
    # fewer than three operands. OpSelectionMerge, for example, only carries a
    # merge block and a control mask.
    def continue_location(self):
        if len(self._operands) >= 3:
            return self._operands[1]
        return None

    # Executing a merge instruction passes it to the lane's convergence logic.
    def _impl(self, module, lane):
        lane.handle_convergence_header(self)
class OpLoopMerge(MergeInstruction):
    """Loop header merge; all behavior comes from MergeInstruction."""


class OpSelectionMerge(MergeInstruction):
    """Selection header merge; all behavior comes from MergeInstruction."""
# Other instructions
class OpBitcast(Instruction):
    # Only a bitcast whose result type is spelled "%int" is supported; the
    # source register's value is converted with int().
    def _impl(self, module, lane):
        # TODO: find out the type from the defining instruction.
        # This can only work for DXC.
        if self._operands[0] != "%int":
            raise RuntimeError("Unsupported OpBitcast operand")
        source = lane.get_register(self._operands[1])
        lane.set_register(self._result, int(source))
class OpAccessChain(Instruction):
    # Python's dynamic typing allows a simplification: starting from the base
    # register's value, successively subscript it with the value of each index
    # register. As long as the SPIR-V is legal, this should be fine.
    # Note: SPIR-V structs are stored as tuples. The result register receives
    # the selected element's value rather than a pointer.
    def _impl(self, module, lane):
        element = lane.get_register(self._operands[1])
        for index_register in self._operands[2:]:
            element = element[lane.get_register(index_register)]
        lane.set_register(self._result, element)
class OpCompositeConstruct(Instruction):
    # Gathers the values of the constituent registers (every operand after
    # the result type), in order, into a list stored in the result register.
    def _impl(self, module, lane):
        parts = [lane.get_register(name) for name in self._operands[1:]]
        lane.set_register(self._result, parts)
class OpCompositeExtract(Instruction):
    # Walks into the composite register (second operand) using each following
    # operand as a literal integer index, and stores the element reached.
    def _impl(self, module, lane):
        element = lane.get_register(self._operands[1])
        for literal in self._operands[2:]:
            element = element[int(literal)]
        lane.set_register(self._result, element)
class OpStore(Instruction):
    # Copies the value of the object register (second operand) into the
    # register named by the pointer operand (first operand).
    def _impl(self, module, lane):
        destination, source = self._operands[0], self._operands[1]
        lane.set_register(destination, lane.get_register(source))
class OpLoad(Instruction):
    # Copies the value of the pointer register (the operand after the result
    # type) into the result register.
    def _impl(self, module, lane):
        loaded = lane.get_register(self._operands[1])
        lane.set_register(self._result, loaded)
class OpIAdd(Instruction):
    # result = operand1 + operand2 (operand 0 is the result type).
    # NOTE(review): Python ints are unbounded, so no wrap-around is applied on
    # overflow; confirm this is acceptable for the simulated programs.
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs + rhs)
class OpISub(Instruction):
    # result = operand1 - operand2 (operand 0 is the result type).
    # NOTE(review): no wrap-around is applied on overflow.
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs - rhs)
class OpIMul(Instruction):
    # result = operand1 * operand2 (operand 0 is the result type).
    # NOTE(review): no wrap-around is applied on overflow.
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs * rhs)
class OpLogicalNot(Instruction):
    # result = not operand1 (Python truthiness of the register value).
    def _impl(self, module, lane):
        operand = lane.get_register(self._operands[1])
        lane.set_register(self._result, not operand)
class _LessThan(Instruction):
    # Shared "<" comparison: result = operand1 < operand2. No signed/unsigned
    # reinterpretation is done; Python's ordinary comparison is used.
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs < rhs)
class _GreaterThan(Instruction):
    # Shared ">" comparison: result = operand1 > operand2. No signed/unsigned
    # reinterpretation is done; Python's ordinary comparison is used.
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs > rhs)
class OpSLessThan(_LessThan):
    """Signed less-than; shares the _LessThan implementation."""


class OpULessThan(_LessThan):
    """Unsigned less-than; shares the _LessThan implementation unchanged."""


class OpSGreaterThan(_GreaterThan):
    """Signed greater-than; shares the _GreaterThan implementation."""


class OpUGreaterThan(_GreaterThan):
    """Unsigned greater-than; shares the _GreaterThan implementation unchanged."""
class OpIEqual(Instruction):
    # result = operand1 == operand2 (operand 0 is the result type).
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs == rhs)
class OpINotEqual(Instruction):
    # result = operand1 != operand2 (operand 0 is the result type).
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs != rhs)
class OpPhi(Instruction):
    # After the result type, operands come in (value register, parent label)
    # pairs. The value paired with the block the lane arrived from is copied
    # into the result; if no pair matches, a RuntimeError is raised.
    def _impl(self, module, lane):
        predecessor = lane.get_previous_bb_name()
        for pos in range(1, len(self._operands), 2):
            if self._operands[pos + 1] == predecessor:
                incoming = lane.get_register(self._operands[pos])
                lane.set_register(self._result, incoming)
                return
        raise RuntimeError("previousBB not in the OpPhi _operands")
class OpSelect(Instruction):
    # Operands are (type, condition, value-if-true, value-if-false). Only the
    # register on the chosen side is read.
    def _impl(self, module, lane):
        if lane.get_register(self._operands[1]):
            picked = lane.get_register(self._operands[2])
        else:
            picked = lane.get_register(self._operands[3])
        lane.set_register(self._result, picked)
# Wave intrinsics
class OpGroupNonUniformBroadcastFirst(Instruction):
    # Operands are (type, scope, value). Only scope value 3 (Subgroup in the
    # SPIR-V Scope enumeration) is accepted. Only the first active lane calls
    # broadcast_register with its value; other lanes do nothing here.
    def _impl(self, module, lane):
        assert lane.get_register(self._operands[1]) == 3
        if not lane.is_first_active_lane():
            return
        lane.broadcast_register(self._result, lane.get_register(self._operands[2]))
class OpGroupNonUniformElect(Instruction):
    # The result register receives lane.is_first_active_lane().
    # NOTE: the scope operand is not inspected here.
    def _impl(self, module, lane):
        elected = lane.is_first_active_lane()
        lane.set_register(self._result, elected)