from typing import List, Optional
# Base class for an instruction. To implement a basic instruction that doesn't
# impact the control-flow, create a new class inheriting from this.
#
# An instruction is parsed from one line of SPIR-V assembly, either
#   %result = OpCode operand0 operand1 ...
# or, for instructions without an output register,
#   OpCode operand0 operand1 ...
# Tokens are separated by whitespace.
class Instruction:
    # Contains the name of the output register, if any.
    _result: Optional[str]
    # Contains the instruction opcode.
    _opcode: str
    # Contains all the instruction operands, except result and opcode.
    _operands: List[str]

    # Parses `line`. Raises ValueError when the line holds no tokens, or when
    # it has a "%result =" prefix but no opcode after it (previously both
    # cases surfaced as an unexplained IndexError).
    def __init__(self, line: str):
        # Raw text this instruction was parsed from.
        self.line = line
        tokens = line.split()
        if len(tokens) > 1 and tokens[1] == "=":
            if len(tokens) < 3:
                raise ValueError(f"Missing opcode after '=' in: {line!r}")
            self._result = tokens[0]
            self._opcode = tokens[2]
            # Slicing past the end yields [], so no length guard is needed.
            self._operands = tokens[3:]
        else:
            if not tokens:
                raise ValueError(f"Empty instruction line: {line!r}")
            self._result = None
            self._opcode = tokens[0]
            self._operands = tokens[1:]

    # Human-readable form, also used in the "Unimplemented instruction" error.
    def __str__(self):
        if self._result is None:
            # Six spaces: as wide as a short "{result:3} = " prefix, so the
            # opcodes of instructions with and without a result line up.
            return f"      {self._opcode} {self._operands}"
        return f"{self._result:3} = {self._opcode} {self._operands}"

    # Show the same assembly-like form in debuggers and container reprs.
    __repr__ = __str__

    # Returns the instruction opcode.
    def opcode(self) -> str:
        return self._opcode

    # Returns the instruction operands.
    def operands(self) -> List[str]:
        return self._operands

    # Returns the instruction output register. Calling this function is
    # only allowed if has_output_register() is true.
    def output_register(self) -> str:
        assert self._result is not None
        return self._result

    # Returns true if this instruction has an output register. False otherwise.
    def has_output_register(self) -> bool:
        return self._result is not None

    # This function is used to initialize state related to this instruction
    # before module execution begins. For example, global Input variables
    # can use this to store the lane ID into the register.
    # The base implementation does nothing.
    def static_execution(self, lane):
        pass

    # This function is called every time this instruction is executed by a
    # tangle. This function should not be directly overridden; instead see
    # _impl and _advance_ip.
    def runtime_execution(self, module, lane):
        self._impl(module, lane)
        self._advance_ip(module, lane)

    # This function needs to be overridden if your instruction can be executed.
    # It implements the logic of the instruction.
    # 'Static' instructions like OpConstant should not override this since
    # they are not supposed to be executed at runtime.
    def _impl(self, module, lane):
        raise RuntimeError(f"Unimplemented instruction {self}")

    # By default, IP is incremented to point to the next instruction.
    # If the instruction modifies IP (like OpBranch), this must be overridden.
    def _advance_ip(self, module, lane):
        lane.set_ip(lane.ip() + 1)
# Those are parsed, but never executed: none of them overrides _impl, so they
# only exist so the module can recognise their opcodes.


# Declares a shader entry point.
class OpEntryPoint(Instruction):
    pass


# Opens a function definition.
class OpFunction(Instruction):
    pass


# Closes a function definition.
class OpFunctionEnd(Instruction):
    pass


# Starts a basic block.
class OpLabel(Instruction):
    pass


# Declares a variable.
class OpVariable(Instruction):
    pass
# Attaches a debug name to a register. Operand 0 is the target register; the
# remaining operand(s) hold the double-quoted name.
class OpName(Instruction):
    # Returns the name with its surrounding double quotes removed.
    def name(self) -> str:
        parts = self._operands[1:]
        if not parts:
            # Keep the original failure mode for a missing name operand.
            raise IndexError("OpName has no name operand")
        # If the line was tokenized on whitespace, a quoted name containing
        # spaces (e.g. "my var") is spread over several operands. Re-join them
        # so the name is not cut at the first space. A name that arrived as a
        # single token is returned exactly as before. Runs of whitespace inside
        # the name come back as a single space.
        quoted = " ".join(parts)
        return quoted[1:-1]

    # Returns the register the name is attached to.
    def decoratedRegister(self) -> str:
        return self._operands[0]
# The only decoration we use is the BuiltIn one, to initialize the values.
# LinkageAttributes is tolerated and ignored; any other decoration fails the
# assertion below.
class OpDecorate(Instruction):
    def static_execution(self, lane):
        decoration = self._operands[1]
        if decoration != "LinkageAttributes":
            # Operand 2 is only inspected when the decoration is BuiltIn.
            assert (
                decoration == "BuiltIn"
                and self._operands[2] == "SubgroupLocalInvocationId"
            )
            # Seed the decorated register (operand 0) with this lane's id.
            lane.set_register(self._operands[0], lane.tid())
# Constants are written into their result register once, during static
# execution, before the module runs.
class OpConstant(Instruction):
    def static_execution(self, lane):
        # Operand 0 is the result type; operand 1 is the literal, read as int.
        value = int(self._operands[1])
        lane.set_register(self._result, value)


# Boolean constant `true`.
class OpConstantTrue(OpConstant):
    def static_execution(self, lane):
        lane.set_register(self._result, True)


# Boolean constant `false`.
class OpConstantFalse(OpConstant):
    def static_execution(self, lane):
        lane.set_register(self._result, False)


# Composite constant: a list holding the values of the constituent registers
# (every operand after the result type).
class OpConstantComposite(OpConstant):
    def static_execution(self, lane):
        lane.set_register(
            self._result, [lane.get_register(op) for op in self._operands[1:]]
        )
# Control flow instructions
#
# These do nothing in _impl; all their work happens in _advance_ip, where the
# lane's instruction pointer is moved.


# Calls the function named by operand 1 (operand 0 is the result type).
# Only the callee operand is read here; any call arguments that follow it
# are not forwarded from this method.
class OpFunctionCall(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        callee_entry = module.get_function_entry(self._operands[1])
        # The result register is handed to the lane, presumably so the
        # matching return can fill it in. TODO confirm against Lane.do_call.
        lane.do_call(callee_entry, self._result)


# Returns from a function that produces no value.
class OpReturn(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        lane.do_return(None)


# Returns the value held in the register named by operand 0.
class OpReturnValue(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        returned = lane.get_register(self._operands[0])
        lane.do_return(returned)


# Unconditional jump to the block labelled by operand 0.
class OpBranch(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        target = module.get_bb_entry(self._operands[0])
        lane.set_ip(target)


# Two-way jump. Operand 0 is the condition register; operands 1 and 2 are the
# labels taken when it is truthy and falsy, respectively.
class OpBranchConditional(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        condition = lane.get_register(self._operands[0])
        label = self._operands[1] if condition else self._operands[2]
        lane.set_ip(module.get_bb_entry(label))


# Multi-way jump. Operands: selector register, default label, then
# (integer literal, label) pairs. The first literal equal to the selector wins;
# otherwise control goes to the default label.
class OpSwitch(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        selector = lane.get_register(self._operands[0])
        fallback = self._operands[1]
        for idx in range(2, len(self._operands), 2):
            literal = int(self._operands[idx])
            target = self._operands[idx + 1]
            if selector == literal:
                lane.set_ip(module.get_bb_entry(target))
                return
        lane.set_ip(module.get_bb_entry(fallback))


# Marks code that must never run; executing it is an error.
class OpUnreachable(Instruction):
    def _impl(self, module, lane):
        raise RuntimeError("This instruction should never be executed.")
# Convergence instructions


# Common base for the structured-control-flow merge headers.
class MergeInstruction(Instruction):
    # Label of the merge block (operand 0).
    def merge_location(self):
        return self._operands[0]

    # Label of the continue target (operand 1), or None when there are fewer
    # than three operands, as for OpSelectionMerge (merge label + control).
    def continue_location(self):
        if len(self._operands) >= 3:
            return self._operands[1]
        return None

    # Delegates to the lane's convergence handling for this header.
    def _impl(self, module, lane):
        lane.handle_convergence_header(self)


# Header of a structured loop.
class OpLoopMerge(MergeInstruction):
    pass


# Header of a structured selection.
class OpSelectionMerge(MergeInstruction):
    pass
# Reinterprets a value as another type. Only a result type spelled "%int"
# (operand 0) is supported; the source value (operand 1) is converted with int().
class OpBitcast(Instruction):
    def _impl(self, module, lane):
        # TODO: find out the type from the defining instruction.
        # This can only work for DXC.
        if self._operands[0] != "%int":
            raise RuntimeError("Unsupported OpBitcast operand")
        source = lane.get_register(self._operands[1])
        lane.set_register(self._result, int(source))
# Walks into the composite held in the base register (operand 1), indexing it
# once per remaining operand. Each index is itself read from a register.
# Python's dynamic typing keeps this simple: any indexable value works, so as
# long as the SPIR-V is legal this should be fine.
class OpAccessChain(Instruction):
    def _impl(self, module, lane):
        current = lane.get_register(self._operands[1])
        for index_register in self._operands[2:]:
            current = current[lane.get_register(index_register)]
        lane.set_register(self._result, current)
# Builds a composite as a list of the constituent register values (every
# operand after the result type).
class OpCompositeConstruct(Instruction):
    def _impl(self, module, lane):
        parts = [lane.get_register(op) for op in self._operands[1:]]
        lane.set_register(self._result, parts)


# Extracts a member from the composite in operand 1. The remaining operands
# are literal indices, applied one level at a time.
class OpCompositeExtract(Instruction):
    def _impl(self, module, lane):
        element = lane.get_register(self._operands[1])
        for literal in self._operands[2:]:
            element = element[int(literal)]
        lane.set_register(self._result, element)
# Stores the value in register operand 1 into the register named by the
# pointer operand (operand 0). The pointer id is used directly as a register
# name.
class OpStore(Instruction):
    def _impl(self, module, lane):
        pointer, source = self._operands[0], self._operands[1]
        lane.set_register(pointer, lane.get_register(source))


# Loads through the pointer in operand 1 (operand 0 is the result type) by
# copying that register's value into the result register.
class OpLoad(Instruction):
    def _impl(self, module, lane):
        loaded = lane.get_register(self._operands[1])
        lane.set_register(self._result, loaded)
# Integer arithmetic. Operand 0 is the result type; operands 1 and 2 name the
# source registers. The Python operator is applied as-is, with no masking to
# the result type's bit width.
# NOTE(review): SPIR-V integer arithmetic wraps modulo 2^N; confirm that
# unwrapped results are acceptable for the programs being simulated.
class OpIAdd(Instruction):
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs + rhs)


class OpISub(Instruction):
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs - rhs)


class OpIMul(Instruction):
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs * rhs)


# Boolean negation of the register in operand 1.
class OpLogicalNot(Instruction):
    def _impl(self, module, lane):
        operand = lane.get_register(self._operands[1])
        lane.set_register(self._result, not operand)
# Shared implementation of the "<" comparisons: compares the registers in
# operands 1 and 2.
class _LessThan(Instruction):
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs < rhs)


# Shared implementation of the ">" comparisons: compares the registers in
# operands 1 and 2.
class _GreaterThan(Instruction):
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs > rhs)


# Signedness is not modelled: the signed and unsigned variants below compare
# the stored Python values directly.
# NOTE(review): an unsigned compare involving a negative stored value differs
# from SPIR-V semantics; confirm such values cannot occur.
class OpSLessThan(_LessThan):
    pass


class OpULessThan(_LessThan):
    pass


class OpSGreaterThan(_GreaterThan):
    pass


class OpUGreaterThan(_GreaterThan):
    pass
# Integer equality of the registers in operands 1 and 2.
class OpIEqual(Instruction):
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs == rhs)


# Integer inequality of the registers in operands 1 and 2.
class OpINotEqual(Instruction):
    def _impl(self, module, lane):
        lhs, rhs = (
            lane.get_register(self._operands[1]),
            lane.get_register(self._operands[2]),
        )
        lane.set_register(self._result, lhs != rhs)
# Selects a value based on the block control came from. Operands: result type,
# then (value register, parent block label) pairs.
class OpPhi(Instruction):
    def _impl(self, module, lane):
        came_from = lane.get_previous_bb_name()
        for idx in range(1, len(self._operands), 2):
            if self._operands[idx + 1] == came_from:
                incoming = lane.get_register(self._operands[idx])
                lane.set_register(self._result, incoming)
                return
        raise RuntimeError("previousBB not in the OpPhi _operands")
# Ternary select: operand 1 is the condition register; operand 2 is read when
# it is truthy, operand 3 otherwise.
class OpSelect(Instruction):
    def _impl(self, module, lane):
        chosen = 2 if lane.get_register(self._operands[1]) else 3
        lane.set_register(self._result, lane.get_register(self._operands[chosen]))
# Broadcasts the value in operand 2 from the first active lane. Operand 1 names
# the register holding the execution scope, which must be 3 (Subgroup in the
# SPIR-V Scope enumeration). Only the first active lane performs the broadcast.
class OpGroupNonUniformBroadcastFirst(Instruction):
    def _impl(self, module, lane):
        assert lane.get_register(self._operands[1]) == 3
        if not lane.is_first_active_lane():
            return
        value = lane.get_register(self._operands[2])
        lane.broadcast_register(self._result, value)


# Stores whether this lane is the first active one.
class OpGroupNonUniformElect(Instruction):
    def _impl(self, module, lane):
        elected = lane.is_first_active_lane()
        lane.set_register(self._result, elected)