Merge #361424: refactor lib.packagesFromDirectoryRecursive (v2)
[NixPkgs.git] / nixos / lib / test-driver / test_driver / driver.py
blob 6061c1bc09b85afcd75180161544c83d0597aa75
1 import os
2 import re
3 import signal
4 import tempfile
5 import threading
6 from collections.abc import Callable, Iterator
7 from contextlib import AbstractContextManager, contextmanager
8 from pathlib import Path
9 from typing import Any
11 from test_driver.logger import AbstractLogger
12 from test_driver.machine import Machine, NixStartScript, retry
13 from test_driver.polling_condition import PollingCondition
14 from test_driver.vlan import VLan
# Unique module-level marker object. It is not referenced anywhere else in
# the visible part of this file -- presumably kept for importers; verify
# before removing.
SENTINEL = object()
def get_tmp_dir() -> Path:
    """Return the temporary directory reported by ``tempfile.gettempdir()``.

    The location is the one defined by TMPDIR, TEMP, TMP or CWD. It is
    created with mode 0o700 when it does not exist yet.
    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir

    Raises:
        NotADirectoryError: the resolved path is not a directory.
        PermissionError: the resolved directory is not writeable.
    """
    tmp_dir = Path(tempfile.gettempdir())
    tmp_dir.mkdir(mode=0o700, exist_ok=True)

    if tmp_dir.is_dir():
        if os.access(tmp_dir, os.W_OK):
            # Happy path: an existing, writeable directory.
            return tmp_dir
        raise PermissionError(
            f"The directory defined by TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )

    raise NotADirectoryError(
        f"The directory defined by TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
    )
def pythonize_name(name: str) -> str:
    """Map *name* onto the character set of an ASCII Python identifier.

    Every character that is not an ASCII letter, digit or underscore is
    replaced by an underscore; a leading digit is replaced as well, since an
    identifier may not start with one.

    Args:
        name: arbitrary name, e.g. a machine name containing dashes.

    Returns:
        The name with all offending characters replaced by ``_``.
    """
    # Spell the ranges out as A-Za-z: the single range A-z would also cover
    # the ASCII punctuation located between "Z" and "a" ([ \ ] ^ _ `), letting
    # those characters slip through unreplaced.
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)
class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests.

    The driver owns the VLans and machines built from the given start
    scripts, exposes helper symbols to the test script and releases
    everything again when used as a context manager.
    """

    # Source code of the test script; executed by `test_script`.
    tests: str
    # One VLan object per distinct VLAN number passed to the constructor.
    vlans: list[VLan]
    # Machines built from the start scripts passed to the constructor.
    machines: list[Machine]
    # Stack of currently active polling conditions (see `polling_condition`).
    polling_conditions: list[PollingCondition]
    # Number of seconds after which `race_timer` fires `terminate_test`.
    global_timeout: int
    # Timer enforcing `global_timeout`; started by `run_tests`.
    race_timer: threading.Timer
    logger: AbstractLogger

    def __init__(
        self,
        start_scripts: list[str],
        vlans: list[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
    ) -> None:
        """Create the VLans and machines for a test run.

        Args:
            start_scripts: one start script per machine.
            vlans: VLAN numbers to create; duplicates are ignored.
            tests: source code of the test script.
            out_dir: output directory handed to every machine.
            logger: logger used by the driver, the VLans and the machines.
            keep_vm_state: forwarded to every machine.
            global_timeout: seconds until the whole test is terminated
                (defaults to one week).
        """
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger

        tmp_dir = get_tmp_dir()

        with self.logger.nested("start all VLans"):
            # Going through a set drops duplicate VLAN numbers.
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in set(vlans)]

        self.polling_conditions = []

        # Lazily parse the start scripts so that each script is turned into a
        # NixStartScript right before its machine is constructed.
        start_commands = (NixStartScript(script) for script in start_scripts)
        self.machines = [
            Machine(
                start_command=start_command,
                keep_vm_state=keep_vm_state,
                name=start_command.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for start_command in start_commands
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Cancel the timeout timer and release all machines and VLans.

        Errors raised while cleaning up a single machine or VLan are logged
        and do not prevent the remaining ones from being cleaned up.
        """
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                try:
                    machine.release()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of {machine.name}: {e}")

            for vlan in self.vlans:
                try:
                    vlan.stop()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of vlan{vlan.nr}: {e}")

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name"""
        with self.logger.subtest(name):
            try:
                yield
            except Exception as e:
                self.logger.error(f'Test "{name}" failed with error: "{e}"')
                # Bare raise re-raises the active exception unchanged.
                raise

    def test_symbols(self) -> dict[str, Any]:
        """Build the namespace in which the test script is executed.

        Returns:
            A dict holding the general helper symbols, one entry per machine
            (keyed by its pythonized name) and one ``vlan<nr>`` entry per VLan.
        """

        # `self.subtest` is a plain generator function; wrapping its generator
        # here turns it into a context manager for the test script.
        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            return self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=self.logger,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        # NOTE(review): this lists the raw machine names, whereas the symbols
        # themselves are exposed under pythonize_name(name).
        print(
            "additionally exposed symbols:\n "
            + ", ".join(m.name for m in self.machines)
            + ",\n "
            + ", ".join(f"vlan{v.nr}" for v in self.vlans)
            + ",\n "
            + ", ".join(general_symbols)
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with self.logger.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # The symbols dict serves as the global namespace of the script.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with self.logger.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with self.logger.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
            self.race_timer.cancel()

    def terminate_test(self) -> None:
        """Release all machines and SIGTERM the current process.

        Invoked by `race_timer` once `global_timeout` has elapsed.
        """
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with self.logger.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                # A failing release must neither keep the remaining machines
                # alive nor prevent the SIGTERM below from being sent.
                try:
                    machine.release()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of {machine.name}: {e}")
            # As we cannot `sys.exit` from another thread
            # We can at least force the main thread to get SIGTERM'ed.
            # This will prevent any user who caught all the exceptions
            # to swallow them and prevent itself from terminating.
            os.kill(os.getpid(), signal.SIGTERM)

    def create_machine(
        self,
        start_command: str,
        *,
        name: str | None = None,
        keep_vm_state: bool = False,
    ) -> Machine:
        """Create an additional machine from a start command.

        Args:
            start_command: start script of the machine.
            name: machine name; defaults to the name derived from the start
                script.
            keep_vm_state: forwarded to the machine.

        Returns:
            The new machine. It is not added to `self.machines` and, unlike
            the machines built by the constructor, gets no polling-condition
            callback.
        """
        tmp_dir = get_tmp_dir()

        cmd = NixStartScript(start_command)
        name = name or cmd.machine_name

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
        )

    def serial_stdout_on(self) -> None:
        """Enable printing of the serial logs."""
        self.logger.print_serial_logs(True)

    def serial_stdout_off(self) -> None:
        """Disable printing of the serial logs."""
        self.logger.print_serial_logs(False)

    def check_polling_conditions(self) -> None:
        """Call `maybe_raise` on every currently active polling condition."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Callable | None = None,
        *,
        seconds_interval: float = 2.0,
        description: str | None = None,
    ) -> Callable[[Callable], AbstractContextManager] | AbstractContextManager:
        """Wrap a function into a polling condition context manager.

        Usable both directly (``polling_condition(fun)``) and as a decorator
        factory (``polling_condition(seconds_interval=...)``).

        Args:
            fun_: the function to poll; when omitted, the ``Poll`` class is
                returned so that it can be applied to a function later.
            seconds_interval: forwarded to `PollingCondition`.
            description: forwarded to `PollingCondition`.

        Returns:
            A ``Poll`` instance wrapping `fun_`, or the ``Poll`` class itself
            when `fun_` is None.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    driver.logger,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                # Active conditions form a stack on the driver.
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions must be left in reverse order of entry.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                """Retry the condition check until it holds or `timeout` expires."""

                def condition(last: bool) -> bool:
                    if last:
                        driver.logger.info(
                            f"Last chance for {self.condition.description}"
                        )
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        driver.logger.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with driver.logger.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)