[mlir][int-range] Limit xor int range inference to i1 (#116968)
[llvm-project.git] / llvm / utils / spirv-sim / spirv-sim.py
blob428b0ca4eb796c50b610683fabbcaaee27788172
1 #!/usr/bin/env python3
3 from __future__ import annotations
4 from dataclasses import dataclass
5 from instructions import *
6 from typing import Any, Iterable, Callable, Optional, Tuple, List, Dict
7 import argparse
8 import fileinput
9 import inspect
10 import re
11 import sys
# Accepted format for --expects: a comma-separated list of (possibly negative)
# integers. Whitespace around values is allowed, so the documented example
# '1, 2, 3' is valid; main() strips each element before calling int().
RE_EXPECTS = re.compile(r"^\s*-?[0-9]+\s*(,\s*-?[0-9]+\s*)*$")
# Parse the SPIR-V instructions. Some instructions are ignored because they
# are not required to simulate this module.
# Instructions are to be implemented in instructions.py
def parseInstruction(i):
    """Map a raw instruction to an instance of its simulator class.

    `i` must provide `opcode()` and a `line` attribute. Opcodes listed in
    IGNORED produce None. Otherwise, the class named after the opcode is
    looked up in the `instructions` module and instantiated with `i.line`.

    Raises:
        RuntimeError: if `instructions` defines nothing under the opcode's
            name, or if what it defines is not a class.
    """
    IGNORED = frozenset(
        [
            "OpCapability",
            "OpMemoryModel",
            "OpExecutionMode",
            "OpExtension",
            "OpSource",
            "OpTypeInt",
            "OpTypeStruct",
            "OpTypeFloat",
            "OpTypeBool",
            "OpTypeVoid",
            "OpTypeFunction",
            "OpTypePointer",
            "OpTypeArray",
        ]
    )
    if i.opcode() in IGNORED:
        return None

    try:
        Type = getattr(sys.modules["instructions"], i.opcode())
    except AttributeError:
        # The AttributeError adds nothing beyond this message; don't chain it.
        raise RuntimeError(f"Unsupported instruction {i}") from None
    if not inspect.isclass(Type):
        raise RuntimeError(
            f"{i} instruction definition is not a class. Did you use 'def' instead of 'class'?"
        )
    return Type(i.line)
# Split a list of instructions into pieces. A new piece starts at every
# instruction of type splitType (the delimiter is the first instruction of
# the piece it opens). A delimiter at the very start does not create an empty
# leading piece, and two consecutive delimiters give a one-instruction piece
# followed by the next one. An empty input yields a single empty piece.
def splitInstructions(
    splitType: type, instructions: Iterable[Instruction]
) -> List[List[Instruction]]:
    pieces: List[List[Instruction]] = []
    current: List[Instruction] = []
    for inst in instructions:
        starts_new_piece = isinstance(inst, splitType) and current
        if starts_new_piece:
            pieces.append(current)
            current = []
        current.append(inst)
    # The last piece under construction is always emitted (possibly empty).
    pieces.append(current)
    return pieces
# Defines a BasicBlock in the simulator.
# Begins at an OpLabel, and ends with a control-flow instruction.
class BasicBlock:
    """Sequence of instructions headed by an OpLabel.

    The OpLabel itself is not stored: it only supplies the block's name (the
    register it defines). Indexing and len() cover the instructions that
    follow it.
    """

    def __init__(self, instructions) -> None:
        label = instructions[0]
        assert isinstance(label, OpLabel)
        self._name = label.output_register()
        # Everything after the label, in program order.
        self._instructions = instructions[1:]

    def name(self):
        """Return the register defined by this block's OpLabel."""
        return self._name

    def __getitem__(self, index: int) -> Instruction:
        """Return the instruction at `index` (0 = first one after the label)."""
        return self._instructions[index]

    def __len__(self):
        """Return how many instructions follow the leading OpLabel."""
        return len(self._instructions)

    def dump(self):
        """Print the block name, then one line per instruction."""
        print(f" {self._name}:")
        for entry in self._instructions:
            print(f" {entry}")
# Defines a Function in the simulator.
class Function:
    """A function delimited by OpFunction / OpFunctionEnd, split into blocks.

    OpVariable instructions are collected separately and removed from the
    instruction stream before it is cut into basic blocks at each OpLabel.
    """

    def __init__(self, instructions) -> None:
        header = instructions[0]
        assert isinstance(header, OpFunction)
        # Functions are identified by the register OpFunction defines.
        self._name: str = header.output_register()
        self._basic_blocks: List[BasicBlock] = []
        # Variables declared in this function.
        self._variables: List[OpVariable] = list(
            filter(lambda inst: isinstance(inst, OpVariable), instructions)
        )

        assert isinstance(instructions[-1], OpFunctionEnd)
        # Body between the header and OpFunctionEnd, without the variables.
        body = [
            inst for inst in instructions[1:-1] if not isinstance(inst, OpVariable)
        ]
        self._basic_blocks.extend(
            BasicBlock(piece) for piece in splitInstructions(OpLabel, body)
        )

    def name(self) -> str:
        """Return the register defined by this function's OpFunction."""
        return self._name

    def __getitem__(self, index: int) -> BasicBlock:
        """Return the basic block at position `index`."""
        return self._basic_blocks[index]

    def get_bb_index(self, name) -> int:
        """Return the position of the block named `name`, or -1 if absent."""
        return next(
            (
                position
                for position, block in enumerate(self._basic_blocks)
                if block.name() == name
            ),
            -1,
        )

    def dump(self):
        """Print this function's variables followed by its basic blocks."""
        print(" Variables:")
        for variable in self._variables:
            print(f" {variable}")
        print(" Blocks:")
        for block in self._basic_blocks:
            block.dump()
# Represents an instruction pointer in the simulator.
@dataclass
class InstructionPointer:
    """Position in the program: (function, block index, instruction index).

    Equality is the dataclass field comparison; the explicit __hash__ below
    uses the function's name together with both indices.
    """

    # The current function the IP points to.
    function: Function
    # The basic block index in function IP points to.
    basic_block: int
    # The instruction in basic_block IP points to.
    instruction_index: int

    def __str__(self):
        block = self.bb()
        current = block[self.instruction_index]
        return f"{block.name()}:{self.instruction_index} in {self.function.name()} | {current}"

    def __hash__(self):
        key = (self.function.name(), self.basic_block, self.instruction_index)
        return hash(key)

    def bb(self) -> BasicBlock:
        """Return the basic block this IP points into."""
        return self.function[self.basic_block]

    def instruction(self):
        """Return the instruction this IP points to."""
        return self.bb()[self.instruction_index]

    def __add__(self, value: int):
        """Return a new IP `value` instructions further in the same block.

        The target must remain inside the current basic block; moving past
        its end fails the assertion.
        """
        target = self.instruction_index + value
        assert target < len(self.bb())
        return InstructionPointer(self.function, self.basic_block, target)
# Defines a Lane in this simulator.
class Lane:
    """One invocation (thread) of a wave.

    Holds the lane's registers, its instruction pointer, its call stack and
    the basic block it is in / came from (reported by get_previous_bb_name).
    """

    # The registers known by this lane.
    _registers: Dict[str, Any]
    # The current IP of this lane.
    _ip: Optional[InstructionPointer]
    # Whether this lane is running (set to False when its outermost call returns).
    _running: bool
    # The wave this lane belongs to.
    _wave: Wave
    # The callstack of this lane. Each tuple represents 1 call.
    # The first element is the IP the function will return to (None for the
    # outermost call, made before the lane had an IP).
    # The second element is the callback to call to store the return value
    # into the correct register.
    _callstack: List[Tuple[Optional[InstructionPointer], Callable[[Any], None]]]
    # The last BB this lane was executing into.
    _previous_bb: Optional[BasicBlock]
    # The current BB this lane is executing into.
    _current_bb: Optional[BasicBlock]

    def __init__(self, wave: Wave, tid: int) -> None:
        self._registers = dict()
        self._ip = None
        self._running = True
        self._wave = wave
        self._callstack = []

        # The index of this lane in the wave.
        self._tid = tid
        self._previous_bb = None
        self._current_bb = None

    # Returns the lane/thread ID of this lane in its wave.
    def tid(self) -> int:
        return self._tid

    # Returns true if this lane is the first (lowest index) of the currently
    # active lanes of its wave.
    def is_first_active_lane(self) -> bool:
        return self._tid == self._wave.get_first_active_lane_index()

    # Broadcast value into the registers of all active lanes.
    def broadcast_register(self, register: str, value: Any) -> None:
        self._wave.broadcast_register(register, value)

    # Returns the IP this lane is currently at.
    def ip(self) -> InstructionPointer:
        assert self._ip is not None
        return self._ip

    # Returns true if this lane is running, false otherwise.
    # Running means not dead. An inactive lane is running.
    def running(self) -> bool:
        return self._running

    # Set the register at "name" to "value" in this lane.
    def set_register(self, name: str, value: Any) -> None:
        self._registers[name] = value

    # Get the value in register "name" in this lane.
    # If allow_undef is true, fetching an unknown register returns None instead
    # of raising KeyError.
    def get_register(self, name: str, allow_undef: bool = False) -> Optional[Any]:
        if allow_undef:
            return self._registers.get(name)
        return self._registers[name]

    # Move this lane to "ip", recording basic-block transitions.
    def set_ip(self, ip: InstructionPointer) -> None:
        # A transition happens when "ip" is in another block, or when it is the
        # first instruction of a block. Sequential advances (ip + n) produce a
        # non-zero index, so index 0 is presumed to mean a jump to a block
        # entry. This also catches a block branching to itself (single-block
        # loop), where the block object does not change but the predecessor
        # seen by e.g. OpPhi must become this block.
        if ip.bb() != self._current_bb or ip.instruction_index == 0:
            self._previous_bb = self._current_bb
            self._current_bb = ip.bb()
        self._ip = ip

    # Returns the name of the block this lane was in before the current one.
    # Fails (AttributeError on None) if no transition has been recorded yet.
    def get_previous_bb_name(self):
        return self._previous_bb.name()

    def handle_convergence_header(self, instruction):
        self._wave.handle_convergence_header(self, instruction)

    # Enter the function starting at "ip". The instruction following the
    # current IP becomes the return address; the returned value will be
    # stored in "output_register".
    def do_call(self, ip, output_register):
        return_ip = None if self._ip is None else self._ip + 1
        self._callstack.append(
            (return_ip, lambda value: self.set_register(output_register, value))
        )
        self.set_ip(ip)

    # Return "value" from the current function. Returning from the outermost
    # call stops the lane.
    def do_return(self, value):
        ip, callback = self._callstack.pop()

        callback(value)
        if not self._callstack:
            self._running = False
        else:
            self.set_ip(ip)
# Represents the SPIR-V module in the simulator.
class Module:
    """A parsed SPIR-V module: prolog, global values, functions, debug names."""

    # Function register -> Function.
    _functions: Dict[str, Function]
    # Instructions located before the first OpFunction.
    _prolog: List[Instruction]
    # Every OpVariable / OpConstant(-derived) instruction in the module.
    _globals: List[Instruction]
    # OpName pretty name -> register.
    _name2reg: Dict[str, str]
    # Register -> OpName pretty name.
    _reg2name: Dict[str, str]

    def __init__(self, instructions) -> None:
        prolog, *function_chunks = splitInstructions(OpFunction, instructions)

        self._prolog = prolog
        self._functions = {}
        # The whole instruction list is scanned, so OpVariables declared
        # inside functions end up here too.
        self._globals = [
            inst
            for inst in instructions
            if isinstance(inst, OpVariable) or issubclass(type(inst), OpConstant)
        ]

        # Bidirectional lookup between OpName names and registers.
        self._name2reg = {}
        self._reg2name = {}
        for inst in instructions:
            if not isinstance(inst, OpName):
                continue
            pretty = inst.name()
            register = inst.decoratedRegister()
            self._name2reg[pretty] = register
            self._reg2name[register] = pretty

        for chunk in function_chunks:
            fn = Function(chunk)
            assert fn.name() not in self._functions
            self._functions[fn.name()] = fn

    def getRegisterFromName(self, name):
        """Return the register named `name` by an OpName, or None.

        Assumes OpName names are unique (a later one overwrites an earlier one).
        """
        return self._name2reg.get(name)

    def getNameFromRegister(self, register):
        """Return the OpName name given to `register`, or None."""
        return self._reg2name.get(register)

    def initialize(self, lane):
        """Prepare `lane` before the wave starts executing.

        Runs static_execution of every global value, then of every
        OpDecorate found in the prolog (this is how builtins get set up).
        """
        for value in self._globals:
            value.static_execution(lane)

        decorations = [inst for inst in self._prolog if isinstance(inst, OpDecorate)]
        for decoration in decorations:
            decoration.static_execution(lane)

    def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None:
        """Run the instruction at `ip` on behalf of `lane`."""
        current = ip.instruction()
        current.runtime_execution(self, lane)

    def get_function_entry(self, register: str) -> InstructionPointer:
        """Return the IP of the first instruction of function `register`.

        Raises RuntimeError if no function defines `register`.
        """
        if register not in self._functions:
            raise RuntimeError(f"Function defining {register} not found.")
        return InstructionPointer(self._functions[register], 0, 0)

    def get_bb_entry(self, register: str) -> InstructionPointer:
        """Return the IP of the first instruction of the block `register`.

        Searches every function. Raises RuntimeError if no block has that name.
        """
        for fn in self._functions.values():
            position = fn.get_bb_index(register)
            if position != -1:
                return InstructionPointer(fn, position, 0)
        raise RuntimeError(f"Instruction defining {register} not found.")

    def get_function_names(self):
        """Return the OpName name of each function, in definition order.

        Functions without an OpName contribute None (not their register).
        """
        return [self.getNameFromRegister(register) for register in self._functions]

    def variables(self) -> Iterable:
        """Return the output registers of the module's global values."""
        return [value.output_register() for value in self._globals]

    def dump(self, function_name: Optional[str] = None):
        """Print the globals, then either every function or only `function_name`."""
        print("Module:")
        print(" globals:")
        for value in self._globals:
            print(f" {value}")

        if function_name is None:
            print(" functions:")
            for register, fn in self._functions.items():
                pretty = self.getNameFromRegister(register)
                print(f" Function {register} ({pretty})")
                fn.dump()
            return

        register = self.getRegisterFromName(function_name)
        print(f" function {register} ({function_name}):")
        if register is None:
            print(" error: cannot find function.")
        else:
            self._functions[register].dump()
# Defines a convergence requirement for the simulation:
# the set of lanes impacted by a merge instruction, the merge target and
# possibly the associated continue target.
@dataclass
class ConvergenceRequirement:
    # IP of the first instruction of the merge block named by the merge
    # instruction.
    mergeTarget: InstructionPointer
    # IP of the first instruction of the continue target, or None when the
    # merge instruction has no continue location.
    continueTarget: Optional[InstructionPointer]
    # Thread IDs (Lane.tid()) of the lanes that reached the merge instruction.
    impactedLanes: set[int]


# Scheduled work: maps an instruction pointer to the lanes waiting to execute
# the instruction it points to.
Task = Dict[InstructionPointer, List[Lane]]
# Defines a Lane group/Wave in the simulator.
class Wave:
    """A group of lanes executing one function.

    Work is organised as tasks: each task maps an instruction pointer to the
    lanes currently waiting at it. At each step, the first runnable task
    (respecting the convergence requirements registered by merge
    instructions) is removed and its instruction is executed for all of its
    lanes together.
    """

    # The module this wave will execute.
    _module: Module
    # The lanes this wave will be composed of.
    _lanes: List[Lane]
    # The instructions scheduled for execution.
    _tasks: Task
    # The actual requirements to comply with when executing instructions.
    # E.g: the set of lanes required to merge before executing the merge block.
    _convergence_requirements: List[ConvergenceRequirement]
    # The indices of the active lanes for the current executing instruction.
    _active_lane_indices: set[int]

    def __init__(self, module, wave_size: int) -> None:
        assert wave_size > 0
        self._module = module
        self._lanes = [Lane(self, i) for i in range(wave_size)]
        self._tasks = {}
        self._convergence_requirements = []
        self._active_lane_indices = set()

    # Returns True if the task at "ip" may be executed now. A task targeting a
    # merge (or continue) target must wait until every running, not-yet-merged
    # lane impacted by that requirement has reached it. "lanes" is unused.
    def _is_task_candidate(self, ip: InstructionPointer, lanes: List[Lane]):
        merged_lanes: set[int] = set()
        for lane in self._lanes:
            if not lane.running():
                merged_lanes.add(lane.tid())

        for requirement in self._convergence_requirements:
            # This task is not executing a merge or continue target.
            # Adding all lanes at those points into the ignore list.
            if requirement.mergeTarget != ip and requirement.continueTarget != ip:
                for tid in requirement.impactedLanes:
                    if self._lanes[tid].ip() == requirement.mergeTarget:
                        merged_lanes.add(tid)
                    if self._lanes[tid].ip() == requirement.continueTarget:
                        merged_lanes.add(tid)
                continue

            # This task is executing the current requirement continue/merge
            # target.
            for tid in requirement.impactedLanes:
                lane = self._lanes[tid]
                if not lane.running():
                    continue

                if lane.tid() in merged_lanes:
                    continue

                if ip == requirement.mergeTarget:
                    if lane.ip() != requirement.mergeTarget:
                        return False
                else:
                    if (
                        lane.ip() != requirement.mergeTarget
                        and lane.ip() != requirement.continueTarget
                    ):
                        return False
        return True

    # Removes and returns the first runnable task, in insertion order.
    # Raises RuntimeError when no task can run (e.g. all lanes are dead or a
    # convergence requirement can never be met).
    def _get_next_runnable_task(self) -> Tuple[InstructionPointer, List[Lane]]:
        candidate = None
        for ip, lanes in self._tasks.items():
            if not lanes:
                continue
            if self._is_task_candidate(ip, lanes):
                candidate = ip
                break

        if candidate is not None:
            # Delete the selected entry explicitly (not the loop variable),
            # after iteration is over.
            lanes = self._tasks.pop(candidate)
            return (candidate, lanes)
        raise RuntimeError("No task to execute. Deadlock?")

    # Handle an encountered merge instruction for the given lane: add the lane
    # to the requirement for this merge target, creating it if needed.
    def handle_convergence_header(self, lane: Lane, instruction: MergeInstruction):
        mergeTarget = self._module.get_bb_entry(instruction.merge_location())
        for requirement in self._convergence_requirements:
            if requirement.mergeTarget == mergeTarget:
                requirement.impactedLanes.add(lane.tid())
                return

        continueTarget = None
        if instruction.continue_location():
            continueTarget = self._module.get_bb_entry(instruction.continue_location())
        requirement = ConvergenceRequirement(mergeTarget, continueTarget, {lane.tid()})
        self._convergence_requirements.append(requirement)

    # Returns true if some instructions are scheduled for execution.
    def _has_tasks(self) -> bool:
        return len(self._tasks) > 0

    # Returns the index of the first active lane right now.
    # Raises ValueError if no lane is active.
    def get_first_active_lane_index(self) -> int:
        return min(self._active_lane_indices)

    # Broadcast the given value to all active lane registers.
    def broadcast_register(self, register: str, value: Any) -> None:
        for tid in self._active_lane_indices:
            self._lanes[tid].set_register(register, value)

    # Returns the entrypoint of the function associated with 'name'.
    # Calling this function with an invalid name is illegal.
    def _get_function_entry_from_name(self, name: str) -> InstructionPointer:
        register = self._module.getRegisterFromName(name)
        assert register is not None
        return self._module.get_function_entry(register)

    # Run the wave on the function 'function_name' until all lanes are dead.
    # If verbose is True, execution trace is printed.
    # Returns the value returned by the function for each lane.
    def run(self, function_name: str, verbose: bool = False) -> List[Any]:
        for t in self._lanes:
            self._module.initialize(t)

        entry_ip = self._get_function_entry_from_name(function_name)
        for t in self._lanes:
            t.do_call(entry_ip, "__shader_output__")

        # All lanes start at the same entry IP. Schedule a copy of the lane
        # list so the task list never aliases self._lanes.
        self._tasks[self._lanes[0].ip()] = list(self._lanes)
        while self._has_tasks():
            ip, lanes = self._get_next_runnable_task()
            self._active_lane_indices = {x.tid() for x in lanes}
            if verbose:
                print(
                    f"Executing with lanes {self._active_lane_indices}: {ip.instruction()}"
                )

            for lane in lanes:
                self._module.execute_one_instruction(lane, ip)
                if not lane.running():
                    continue
                # Group the lane with others waiting at the same next IP.
                self._tasks.setdefault(lane.ip(), []).append(lane)

            if verbose and ip.instruction().has_output_register():
                register = ip.instruction().output_register()
                print(
                    f" {register:3} = {[ x.get_register(register, allow_undef=True) for x in lanes ]}"
                )

        return [lane.get_register("__shader_output__") for lane in self._lanes]

    # Print the value of "register" for every lane.
    def dump_register(self, register: str) -> None:
        for lane in self._lanes:
            print(
                f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
            )
# Command-line interface. Arguments are parsed once at import time; main()
# reads the resulting `args` namespace.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="simulator"
)
# (flags, keyword options) for each argument, in the order they are added.
_ARGUMENT_SPECS = (
    (
        ("-i", "--input"),
        {"help": "Text SPIR-V to read from", "required": False, "default": "-"},
    ),
    (("-f", "--function"), {"help": "Function to execute"}),
    (("-w", "--wave"), {"help": "Wave size", "default": 32, "required": False}),
    (
        ("-e", "--expects"),
        {
            "help": "Expected results per lanes, expects a list of values. Ex: '1, 2, 3'."
        },
    ),
    (("-v", "--verbose"), {"help": "verbose", "action": "store_true"}),
)
for _flags, _options in _ARGUMENT_SPECS:
    parser.add_argument(*_flags, **_options)
del _flags, _options
args = parser.parse_args()
def load_instructions(filename: str):
    """Load and parse a textual SPIR-V module.

    Args:
        filename: path of the file to read, or "-" (surrounding whitespace
            ignored) to read from standard input. None yields no instructions.

    Returns:
        The simulator instructions produced by parseInstruction, with ignored
        opcodes dropped. An empty list is returned when `filename` is None or
        the file cannot be read/decoded. In the latter case the reason is
        printed on stderr so it is not mistaken for an empty module.
    """
    if filename is None:
        return []

    if filename.strip() == "-":
        lines = sys.stdin.readlines()
    else:
        try:
            with open(filename, "r") as f:
                lines = f.read().split("\n")
        except (OSError, UnicodeDecodeError) as err:
            print(f"error: cannot read '{filename}': {err}", file=sys.stderr)
            return []

    # Remove leading/trailing whitespace, then drop blank lines and full-line
    # ';' comments.
    lines = [x.strip() for x in lines]
    lines = [x for x in lines if x and not x.startswith(";")]

    raw = [Instruction(x) for x in lines]
    instructions = []
    for item in raw:
        parsed = parseInstruction(item)
        if parsed is not None:
            instructions.append(parsed)
    return instructions
def main():
    """Validate the CLI arguments, simulate the requested function, and check results.

    Exits with status 0 when every lane returned its expected value, and with
    status 1 on an argument error, an unreadable or empty module, an unknown
    function, or a result mismatch. All diagnostics go to stderr.
    """
    if args.expects is None or not RE_EXPECTS.match(args.expects):
        print("Invalid format for --expects/-e flag.", file=sys.stderr)
        sys.exit(1)
    if args.function is None:
        print("Invalid format for --function/-f flag.", file=sys.stderr)
        sys.exit(1)
    try:
        wave_size = int(args.wave)
    except ValueError:
        print("Invalid format for --wave/-w flag.", file=sys.stderr)
        sys.exit(1)

    expected_results = [int(x.strip()) for x in args.expects.split(",")]
    if len(expected_results) != wave_size:
        print("Wave size != expected result array size", file=sys.stderr)
        sys.exit(1)

    instructions = load_instructions(args.input)
    if not instructions:
        print("Invalid input. Expected a text SPIR-V module.", file=sys.stderr)
        sys.exit(1)

    module = Module(instructions)
    if args.verbose:
        module.dump()
        module.dump(args.function)

    function_names = module.get_function_names()
    if args.function not in function_names:
        print(
            f"'{args.function}' function not found. Known functions are:",
            file=sys.stderr,
        )
        for name in function_names:
            print(f" - {name}", file=sys.stderr)
        sys.exit(1)

    wave = Wave(module, wave_size)
    results = wave.run(args.function, verbose=args.verbose)

    if expected_results != results:
        print("Expected != Observed", file=sys.stderr)
        print(f"{expected_results} != {results}", file=sys.stderr)
        sys.exit(1)
    sys.exit(0)
# Only run the simulator when executed as a script, not when imported.
if __name__ == "__main__":
    main()