3 from __future__
import annotations
4 from dataclasses
import dataclass
5 from instructions
import *
6 from typing
import Any
, Iterable
, Callable
, Optional
, Tuple
, List
, Dict
13 RE_EXPECTS
= re
.compile(r
"^([0-9]+,)*[0-9]+$")
16 # Parse the SPIR-V instructions. Some instructions are ignored because
17 # not required to simulate this module.
18 # Instructions are to be implemented in instructions.py
19 def parseInstruction(i
):
37 if i
.opcode() in IGNORED
:
41 Type
= getattr(sys
.modules
["instructions"], i
.opcode())
42 except AttributeError:
43 raise RuntimeError(f
"Unsupported instruction {i}")
44 if not inspect
.isclass(Type
):
46 f
"{i} instruction definition is not a class. Did you used 'def' instead of 'class'?"
51 # Split a list of instructions into pieces. Pieces are delimited by instructions of the type splitType.
52 # The delimiter is the first instruction of the next piece.
53 # This function returns no empty pieces:
54 # - if 2 subsequent delimiters will mean 2 pieces. One with only the first delimiter, and the second
55 # with the delimiter and following instructions.
56 # - if the first instruction is a delimiter, the first piece will begin with this delimiter.
57 def splitInstructions(
58 splitType
: type, instructions
: Iterable
[Instruction
]
59 ) -> List
[List
[Instruction
]]:
60 blocks
: List
[List
[Instruction
]] = [[]]
61 for instruction
in instructions
:
62 if isinstance(instruction
, splitType
) and len(blocks
[-1]) > 0:
64 blocks
[-1].append(instruction
)
68 # Defines a BasicBlock in the simulator.
69 # Begins at an OpLabel, and ends with a control-flow instruction.
71 def __init__(self
, instructions
) -> None:
72 assert isinstance(instructions
[0], OpLabel
)
73 # The name of the basic block, which is the register of the leading
75 self
._name
= instructions
[0].output_register()
76 # The list of instructions belonging to this block.
77 self
._instructions
= instructions
[1:]
79 # Returns the name of this basic block.
83 # Returns the instruction at index in this basic block.
84 def __getitem__(self
, index
: int) -> Instruction
:
85 return self
._instructions
[index
]
87 # Returns the number of instructions in this basic block, excluding the
90 return len(self
._instructions
)
93 print(f
" {self._name}:")
94 for instruction
in self
._instructions
:
95 print(f
" {instruction}")
98 # Defines a Function in the simulator.
100 def __init__(self
, instructions
) -> None:
101 assert isinstance(instructions
[0], OpFunction
)
102 # The name of the function (name of the register returned by OpFunction).
103 self
._name
: str = instructions
[0].output_register()
104 # The list of basic blocks that belongs to this function.
105 self
._basic
_blocks
: List
[BasicBlock
] = []
106 # The variables local to this function.
107 self
._variables
: List
[OpVariable
] = [
108 x
for x
in instructions
if isinstance(x
, OpVariable
)
111 assert isinstance(instructions
[-1], OpFunctionEnd
)
112 body
= filter(lambda x
: not isinstance(x
, OpVariable
), instructions
[1:-1])
113 for block
in splitInstructions(OpLabel
, body
):
114 self
._basic
_blocks
.append(BasicBlock(block
))
116 # Returns the name of this function.
117 def name(self
) -> str:
120 # Returns the basic block at index in this function.
121 def __getitem__(self
, index
: int) -> BasicBlock
:
122 return self
._basic
_blocks
[index
]
124 # Returns the index of the basic block with the given name if found,
126 def get_bb_index(self
, name
) -> int:
127 for i
in range(len(self
._basic
_blocks
)):
128 if self
._basic
_blocks
[i
].name() == name
:
134 for var
in self
._variables
:
137 for bb
in self
._basic
_blocks
:
141 # Represents an instruction pointer in the simulator.
143 class InstructionPointer
:
144 # The current function the IP points to.
146 # The basic block index in function IP points to.
148 # The instruction in basic_block IP points to.
149 instruction_index
: int
152 bb
= self
.function
[self
.basic_block
]
153 i
= bb
[self
.instruction_index
]
154 return f
"{bb.name()}:{self.instruction_index} in {self.function.name()} | {i}"
157 return hash((self
.function
.name(), self
.basic_block
, self
.instruction_index
))
159 # Returns the basic block IP points to.
160 def bb(self
) -> BasicBlock
:
161 return self
.function
[self
.basic_block
]
163 # Returns the instruction IP points to.
164 def instruction(self
):
165 return self
.function
[self
.basic_block
][self
.instruction_index
]
167 # Increment IP by 1. This only works inside a basic-block boundary.
168 # Incrementing IP when at the boundary of a basic block will fail.
169 def __add__(self
, value
: int):
170 bb
= self
.function
[self
.basic_block
]
171 assert len(bb
) > self
.instruction_index
+ value
172 return InstructionPointer(
173 self
.function
, self
.basic_block
, self
.instruction_index
+ value
177 # Defines a Lane in this simulator.
179 # The registers known by this lane.
180 _registers
: Dict
[str, Any
]
181 # The current IP of this lane.
182 _ip
: Optional
[InstructionPointer
]
183 # If this lane running.
185 # The wave this lane belongs to.
187 # The callstack of this lane. Each tuple represents 1 call.
188 # The first element is the IP the function will return to.
189 # The second element is the callback to call to store the return value
190 # into the correct register.
191 _callstack
: List
[Tuple
[InstructionPointer
, Callable
[[Any
], None]]]
193 _previous_bb
: Optional
[BasicBlock
]
194 _current_bb
: Optional
[BasicBlock
]
196 def __init__(self
, wave
: Wave
, tid
: int) -> None:
197 self
._registers
= dict()
203 # The index of this lane in the wave.
205 # The last BB this lane was executing into.
206 self
._previous
_bb
= None
207 # The current BB this lane is executing into.
208 self
._current
_bb
= None
210 # Returns the lane/thread ID of this lane in its wave.
211 def tid(self
) -> int:
214 # Returns true is this lane if the first by index in the current active tangle.
215 def is_first_active_lane(self
) -> bool:
216 return self
._tid
== self
._wave
.get_first_active_lane_index()
218 # Broadcast value into the registers of all active lanes.
219 def broadcast_register(self
, register
: str, value
: Any
) -> None:
220 self
._wave
.broadcast_register(register
, value
)
222 # Returns the IP this lane is currently at.
223 def ip(self
) -> InstructionPointer
:
224 assert self
._ip
is not None
227 # Returns true if this lane is running, false otherwise.
228 # Running means not dead. An inactive lane is running.
229 def running(self
) -> bool:
232 # Set the register at "name" to "value" in this lane.
233 def set_register(self
, name
: str, value
: Any
) -> None:
234 self
._registers
[name
] = value
236 # Get the value in register "name" in this lane.
237 # If allow_undef is true, fetching an unknown register won't fail.
238 def get_register(self
, name
: str, allow_undef
: bool = False) -> Optional
[Any
]:
239 if allow_undef
and name
not in self
._registers
:
241 return self
._registers
[name
]
243 def set_ip(self
, ip
: InstructionPointer
) -> None:
244 if ip
.bb() != self
._current
_bb
:
245 self
._previous
_bb
= self
._current
_bb
246 self
._current
_bb
= ip
.bb()
249 def get_previous_bb_name(self
):
250 return self
._previous
_bb
.name()
252 def handle_convergence_header(self
, instruction
):
253 self
._wave
.handle_convergence_header(self
, instruction
)
255 def do_call(self
, ip
, output_register
):
256 return_ip
= None if self
._ip
is None else self
._ip
+ 1
257 self
._callstack
.append(
258 (return_ip
, lambda value
: self
.set_register(output_register
, value
))
262 def do_return(self
, value
):
263 ip
, callback
= self
._callstack
[-1]
264 self
._callstack
.pop()
267 if len(self
._callstack
) == 0:
268 self
._running
= False
273 # Represents the SPIR-V module in the simulator.
275 _functions
: Dict
[str, Function
]
276 _prolog
: List
[Instruction
]
277 _globals
: List
[Instruction
]
278 _name2reg
: Dict
[str, str]
279 _reg2name
: Dict
[str, str]
281 def __init__(self
, instructions
) -> None:
282 chunks
= splitInstructions(OpFunction
, instructions
)
284 # The instructions located outside of all functions.
285 self
._prolog
= chunks
[0]
286 # The functions in this module.
288 # Global variables in this module.
291 for x
in instructions
292 if isinstance(x
, OpVariable
) or issubclass(type(x
), OpConstant
)
295 # Helper dictionaries to get real names of registers, or registers by names.
298 for instruction
in instructions
:
299 if isinstance(instruction
, OpName
):
300 name
= instruction
.name()
301 reg
= instruction
.decoratedRegister()
302 self
._name
2reg
[name
] = reg
303 self
._reg
2name
[reg
] = name
305 for chunk
in chunks
[1:]:
306 function
= Function(chunk
)
307 assert function
.name() not in self
._functions
308 self
._functions
[function
.name()] = function
310 # Returns the register matching "name" if any, None otherwise.
311 # This assumes names are unique.
312 def getRegisterFromName(self
, name
):
313 if name
in self
._name
2reg
:
314 return self
._name
2reg
[name
]
317 # Returns the name given to "register" if any, None otherwise.
318 def getNameFromRegister(self
, register
):
319 if register
in self
._reg
2name
:
320 return self
._reg
2name
[register
]
323 # Initialize the module before wave execution begins.
324 # See Instruction::static_execution for more details.
325 def initialize(self
, lane
):
326 for instruction
in self
._globals
:
327 instruction
.static_execution(lane
)
329 # Initialize builtins
330 for instruction
in self
._prolog
:
331 if isinstance(instruction
, OpDecorate
):
332 instruction
.static_execution(lane
)
334 def execute_one_instruction(self
, lane
: Lane
, ip
: InstructionPointer
) -> None:
335 ip
.instruction().runtime_execution(self
, lane
)
337 # Returns the first valid IP for the function defined by the given register.
338 # Calling this with a register not returned by OpFunction is illegal.
339 def get_function_entry(self
, register
: str) -> InstructionPointer
:
340 if register
not in self
._functions
:
341 raise RuntimeError(f
"Function defining {register} not found.")
342 return InstructionPointer(self
._functions
[register
], 0, 0)
344 # Returns the first valid IP for the basic block defined by register.
345 # Calling this with a register not returned by an OpLabel is illegal.
346 def get_bb_entry(self
, register
: str) -> InstructionPointer
:
347 for name
, function
in self
._functions
.items():
348 index
= function
.get_bb_index(register
)
350 return InstructionPointer(function
, index
, 0)
351 raise RuntimeError(f
"Instruction defining {register} not found.")
353 # Returns the list of function names in this module.
354 # If an OpName exists for this function, returns the pretty name, else
355 # returns the register name.
356 def get_function_names(self
):
357 return [self
.getNameFromRegister(reg
) for reg
, func
in self
._functions
.items()]
359 # Returns the global variables defined in this module.
360 def variables(self
) -> Iterable
:
361 return [x
.output_register() for x
in self
._globals
]
363 def dump(self
, function_name
: Optional
[str] = None):
366 for instruction
in self
._globals
:
367 print(f
" {instruction}")
369 if function_name
is None:
371 for register
, function
in self
._functions
.items():
372 name
= self
.getNameFromRegister(register
)
373 print(f
" Function {register} ({name})")
377 register
= self
.getRegisterFromName(function_name
)
378 print(f
" function {register} ({function_name}):")
379 if register
is not None:
380 self
._functions
[register
].dump()
382 print(f
" error: cannot find function.")
385 # Defines a convergence requirement for the simulation:
386 # A list of lanes impacted by a merge and possibly the associated
389 class ConvergenceRequirement
:
390 mergeTarget
: InstructionPointer
391 continueTarget
: Optional
[InstructionPointer
]
392 impactedLanes
: set[int]
395 Task
= Dict
[InstructionPointer
, List
[Lane
]]
398 # Defines a Lane group/Wave in the simulator.
400 # The module this wave will execute.
402 # The lanes this wave will be composed of.
404 # The instructions scheduled for execution.
406 # The actual requirements to comply with when executing instructions.
407 # E.g: the set of lanes required to merge before executing the merge block.
408 _convergence_requirements
: List
[ConvergenceRequirement
]
409 # The indices of the active lanes for the current executing instruction.
410 _active_lane_indices
: set[int]
412 def __init__(self
, module
, wave_size
: int) -> None:
414 self
._module
= module
417 for i
in range(wave_size
):
418 self
._lanes
.append(Lane(self
, i
))
421 self
._convergence
_requirements
= []
422 # The indices of the active lanes for the current executing instruction.
423 self
._active
_lane
_indices
= set()
425 # Returns True if the given IP can be executed for the given list of lanes.
426 def _is_task_candidate(self
, ip
: InstructionPointer
, lanes
: List
[Lane
]):
427 merged_lanes
: set[int] = set()
428 for lane
in self
._lanes
:
429 if not lane
.running():
430 merged_lanes
.add(lane
.tid())
432 for requirement
in self
._convergence
_requirements
:
433 # This task is not executing a merge or continue target.
434 # Adding all lanes at those points into the ignore list.
435 if requirement
.mergeTarget
!= ip
and requirement
.continueTarget
!= ip
:
436 for tid
in requirement
.impactedLanes
:
437 if self
._lanes
[tid
].ip() == requirement
.mergeTarget
:
438 merged_lanes
.add(tid
)
439 if self
._lanes
[tid
].ip() == requirement
.continueTarget
:
440 merged_lanes
.add(tid
)
443 # This task is executing the current requirement continue/merge
445 for tid
in requirement
.impactedLanes
:
446 lane
= self
._lanes
[tid
]
447 if not lane
.running():
450 if lane
.tid() in merged_lanes
:
453 if ip
== requirement
.mergeTarget
:
454 if lane
.ip() != requirement
.mergeTarget
:
458 lane
.ip() != requirement
.mergeTarget
459 and lane
.ip() != requirement
.continueTarget
464 # Returns the next task we can schedule. This must always return a task.
465 # Calling this when all lanes are dead is invalid.
466 def _get_next_runnable_task(self
) -> Tuple
[InstructionPointer
, List
[Lane
]]:
468 for ip
, lanes
in self
._tasks
.items():
471 if self
._is
_task
_candidate
(ip
, lanes
):
476 lanes
= self
._tasks
[candidate
]
478 return (candidate
, lanes
)
479 raise RuntimeError("No task to execute. Deadlock?")
481 # Handle an encountered merge instruction for the given lane.
482 def handle_convergence_header(self
, lane
: Lane
, instruction
: MergeInstruction
):
483 mergeTarget
= self
._module
.get_bb_entry(instruction
.merge_location())
484 for requirement
in self
._convergence
_requirements
:
485 if requirement
.mergeTarget
== mergeTarget
:
486 requirement
.impactedLanes
.add(lane
.tid())
489 continueTarget
= None
490 if instruction
.continue_location():
491 continueTarget
= self
._module
.get_bb_entry(instruction
.continue_location())
492 requirement
= ConvergenceRequirement(
493 mergeTarget
, continueTarget
, set([lane
.tid()])
495 self
._convergence
_requirements
.append(requirement
)
497 # Returns true if some instructions are scheduled for execution.
498 def _has_tasks(self
) -> bool:
499 return len(self
._tasks
) > 0
501 # Returns the index of the first active lane right now.
502 def get_first_active_lane_index(self
) -> int:
503 return min(self
._active
_lane
_indices
)
505 # Broadcast the given value to all active lane registers.
506 def broadcast_register(self
, register
: str, value
: Any
) -> None:
507 for tid
in self
._active
_lane
_indices
:
508 self
._lanes
[tid
].set_register(register
, value
)
510 # Returns the entrypoint of the function associated with 'name'.
511 # Calling this function with an invalid name is illegal.
512 def _get_function_entry_from_name(self
, name
: str) -> InstructionPointer
:
513 register
= self
._module
.getRegisterFromName(name
)
514 assert register
is not None
515 return self
._module
.get_function_entry(register
)
517 # Run the wave on the function 'function_name' until all lanes are dead.
518 # If verbose is True, execution trace is printed.
519 # Returns the value returned by the function for each lane.
520 def run(self
, function_name
: str, verbose
: bool = False) -> List
[Any
]:
521 for t
in self
._lanes
:
522 self
._module
.initialize(t
)
524 entry_ip
= self
._get
_function
_entry
_from
_name
(function_name
)
525 assert entry_ip
is not None
526 for t
in self
._lanes
:
527 t
.do_call(entry_ip
, "__shader_output__")
529 self
._tasks
[self
._lanes
[0].ip()] = self
._lanes
530 while self
._has
_tasks
():
531 ip
, lanes
= self
._get
_next
_runnable
_task
()
532 self
._active
_lane
_indices
= set([x
.tid() for x
in lanes
])
535 f
"Executing with lanes {self._active_lane_indices}: {ip.instruction()}"
539 self
._module
.execute_one_instruction(lane
, ip
)
540 if not lane
.running():
543 if lane
.ip() in self
._tasks
:
544 self
._tasks
[lane
.ip()].append(lane
)
546 self
._tasks
[lane
.ip()] = [lane
]
548 if verbose
and ip
.instruction().has_output_register():
549 register
= ip
.instruction().output_register()
551 f
" {register:3} = {[ x.get_register(register, allow_undef=True) for x in lanes ]}"
555 for lane
in self
._lanes
:
556 output
.append(lane
.get_register("__shader_output__"))
559 def dump_register(self
, register
: str) -> None:
560 for lane
in self
._lanes
:
562 f
" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
566 parser
= argparse
.ArgumentParser(
567 description
="simulator", formatter_class
=argparse
.ArgumentDefaultsHelpFormatter
570 "-i", "--input", help="Text SPIR-V to read from", required
=False, default
="-"
572 parser
.add_argument("-f", "--function", help="Function to execute")
573 parser
.add_argument("-w", "--wave", help="Wave size", default
=32, required
=False)
577 help="Expected results per lanes, expects a list of values. Ex: '1, 2, 3'.",
579 parser
.add_argument("-v", "--verbose", help="verbose", action
="store_true")
580 args
= parser
.parse_args()
583 def load_instructions(filename
: str):
587 if filename
.strip() != "-":
589 with
open(filename
, "r") as f
:
590 lines
= f
.read().split("\n")
591 except Exception: # (FileNotFoundError, PermissionError):
594 lines
= sys
.stdin
.readlines()
596 # Remove leading/trailing whitespaces.
597 lines
= [x
.strip() for x
in lines
]
599 lines
= [x
for x
in filter(lambda x
: len(x
) != 0 and x
[0] != ";", lines
)]
602 for i
in [Instruction(x
) for x
in lines
]:
603 out
= parseInstruction(i
)
605 instructions
.append(out
)
610 if args
.expects
is None or not RE_EXPECTS
.match(args
.expects
):
611 print("Invalid format for --expects/-e flag.", file=sys
.stderr
)
613 if args
.function
is None:
614 print("Invalid format for --function/-f flag.", file=sys
.stderr
)
619 print("Invalid format for --wave/-w flag.", file=sys
.stderr
)
622 expected_results
= [int(x
.strip()) for x
in args
.expects
.split(",")]
623 wave_size
= int(args
.wave
)
624 if len(expected_results
) != wave_size
:
625 print("Wave size != expected result array size", file=sys
.stderr
)
628 instructions
= load_instructions(args
.input)
629 if len(instructions
) == 0:
630 print("Invalid input. Expected a text SPIR-V module.")
633 module
= Module(instructions
)
636 module
.dump(args
.function
)
638 function_names
= module
.get_function_names()
639 if args
.function
not in function_names
:
641 f
"'{args.function}' function not found. Known functions are:",
644 for name
in function_names
:
645 print(f
" - {name}", file=sys
.stderr
)
648 wave
= Wave(module
, wave_size
)
649 results
= wave
.run(args
.function
, verbose
=args
.verbose
)
651 if expected_results
!= results
:
652 print("Expected != Observed", file=sys
.stderr
)
653 print(f
"{expected_results} != {results}", file=sys
.stderr
)