6 from collections
.abc
import Callable
, Iterator
7 from contextlib
import AbstractContextManager
, contextmanager
8 from pathlib
import Path
11 from test_driver
.logger
import AbstractLogger
12 from test_driver
.machine
import Machine
, NixStartScript
, retry
13 from test_driver
.polling_condition
import PollingCondition
14 from test_driver
.vlan
import VLan
19 def get_tmp_dir() -> Path
:
20 """Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
21 Raises an exception in case the retrieved temporary directory is not writeable
22 See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
24 tmp_dir
= Path(tempfile
.gettempdir())
25 tmp_dir
.mkdir(mode
=0o700, exist_ok
=True)
26 if not tmp_dir
.is_dir():
27 raise NotADirectoryError(
28 f
"The directory defined by TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
30 if not os
.access(tmp_dir
, os
.W_OK
):
31 raise PermissionError(
32 f
"The directory defined by TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
def pythonize_name(name: str) -> str:
    """Replace characters that cannot appear in an ASCII Python identifier.

    A leading character that is not an ASCII letter or underscore, and any
    other character that is not an ASCII letter, digit or underscore, is
    replaced by a single ``_``. The result has the same length as ``name``.

    Args:
        name: Arbitrary name to convert.

    Returns:
        ``name`` with every offending character replaced by ``_``.
    """
    # The ranges are spelled A-Za-z on purpose: a single A-z range also
    # covers the ASCII punctuation that sits between "Z" and "a" (square
    # brackets, backslash, caret and backtick), which would then slip
    # through unreplaced and produce an invalid identifier.
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)
42 """A handle to the driver that sets up the environment
47 machines
: list[Machine
]
48 polling_conditions
: list[PollingCondition
]
50 race_timer
: threading
.Timer
51 logger
: AbstractLogger
55 start_scripts
: list[str],
59 logger
: AbstractLogger
,
60 keep_vm_state
: bool = False,
61 global_timeout
: int = 24 * 60 * 60 * 7,
64 self
.out_dir
= out_dir
65 self
.global_timeout
= global_timeout
66 self
.race_timer
= threading
.Timer(global_timeout
, self
.terminate_test
)
69 tmp_dir
= get_tmp_dir()
71 with self
.logger
.nested("start all VLans"):
72 vlans
= list(set(vlans
))
73 self
.vlans
= [VLan(nr
, tmp_dir
, self
.logger
) for nr
in vlans
]
75 def cmd(scripts
: list[str]) -> Iterator
[NixStartScript
]:
77 yield NixStartScript(s
)
79 self
.polling_conditions
= []
84 keep_vm_state
=keep_vm_state
,
85 name
=cmd
.machine_name
,
87 callbacks
=[self
.check_polling_conditions
],
91 for cmd
in cmd(start_scripts
)
94 def __enter__(self
) -> "Driver":
97 def __exit__(self
, *_
: Any
) -> None:
98 with self
.logger
.nested("cleanup"):
99 self
.race_timer
.cancel()
100 for machine
in self
.machines
:
103 except Exception as e
:
104 self
.logger
.error(f
"Error during cleanup of {machine.name}: {e}")
106 for vlan
in self
.vlans
:
109 except Exception as e
:
110 self
.logger
.error(f
"Error during cleanup of vlan{vlan.nr}: {e}")
112 def subtest(self
, name
: str) -> Iterator
[None]:
113 """Group logs under a given test name"""
114 with self
.logger
.subtest(name
):
117 except Exception as e
:
118 self
.logger
.error(f
'Test "{name}" failed with error: "{e}"')
121 def test_symbols(self
) -> dict[str, Any
]:
123 def subtest(name
: str) -> Iterator
[None]:
124 return self
.subtest(name
)
126 general_symbols
= dict(
127 start_all
=self
.start_all
,
128 test_script
=self
.test_script
,
129 machines
=self
.machines
,
134 create_machine
=self
.create_machine
,
136 run_tests
=self
.run_tests
,
137 join_all
=self
.join_all
,
139 serial_stdout_off
=self
.serial_stdout_off
,
140 serial_stdout_on
=self
.serial_stdout_on
,
141 polling_condition
=self
.polling_condition
,
142 Machine
=Machine
, # for typing
144 machine_symbols
= {pythonize_name(m
.name
): m
for m
in self
.machines
}
145 # If there's exactly one machine, make it available under the name
146 # "machine", even if it's not called that.
147 if len(self
.machines
) == 1:
148 (machine_symbols
["machine"],) = self
.machines
150 f
"vlan{v.nr}": self
.vlans
[idx
] for idx
, v
in enumerate(self
.vlans
)
153 "additionally exposed symbols:\n "
154 + ", ".join(map(lambda m
: m
.name
, self
.machines
))
156 + ", ".join(map(lambda v
: f
"vlan{v.nr}", self
.vlans
))
158 + ", ".join(list(general_symbols
.keys()))
160 return {**general_symbols
, **machine_symbols
, **vlan_symbols
}
162 def test_script(self
) -> None:
163 """Run the test script"""
164 with self
.logger
.nested("run the VM test script"):
165 symbols
= self
.test_symbols() # call eagerly
166 exec(self
.tests
, symbols
, None)
168 def run_tests(self
) -> None:
169 """Run the test script (for non-interactive test runs)"""
171 f
"Test will time out and terminate in {self.global_timeout} seconds"
173 self
.race_timer
.start()
175 # TODO: Collect coverage data
176 for machine
in self
.machines
:
178 machine
.execute("sync")
180 def start_all(self
) -> None:
181 """Start all machines"""
182 with self
.logger
.nested("start all VMs"):
183 for machine
in self
.machines
:
def join_all(self) -> None:
    """Wait for every machine to shut down, then cancel the race timer.

    Each entry of ``self.machines`` is waited on in order via its
    ``wait_for_shutdown()`` method, inside the "wait for all VMs to finish"
    logger section.
    """
    section = self.logger.nested("wait for all VMs to finish")
    with section:
        for vm in self.machines:
            vm.wait_for_shutdown()
        # Nothing is left running at this point, so the timer started for
        # the global timeout no longer needs to fire.
        timer = self.race_timer
        timer.cancel()
193 def terminate_test(self
) -> None:
194 # This will be usually running in another thread than
195 # the thread actually executing the test script.
196 with self
.logger
.nested("timeout reached; test terminating..."):
197 for machine
in self
.machines
:
199 # As we cannot `sys.exit` from another thread
200 # We can at least force the main thread to get SIGTERM'ed.
201 # This will prevent any user who caught all the exceptions
202 # to swallow them and prevent itself from terminating.
203 os
.kill(os
.getpid(), signal
.SIGTERM
)
209 name
: str |
None = None,
210 keep_vm_state
: bool = False,
212 tmp_dir
= get_tmp_dir()
214 cmd
= NixStartScript(start_command
)
215 name
= name
or cmd
.machine_name
219 out_dir
=self
.out_dir
,
222 keep_vm_state
=keep_vm_state
,
def serial_stdout_on(self) -> None:
    """Turn on printing of serial logs by calling ``print_serial_logs(True)`` on the logger."""
    log = self.logger
    log.print_serial_logs(True)
def serial_stdout_off(self) -> None:
    """Turn off printing of serial logs by calling ``print_serial_logs(False)`` on the logger."""
    log = self.logger
    log.print_serial_logs(False)
def check_polling_conditions(self) -> None:
    """Call ``maybe_raise()`` on each registered polling condition, in list order.

    Whatever ``maybe_raise()`` raises propagates to the caller and stops the
    iteration; later conditions are then not visited.
    """
    # Iterate the live list (not a copy) so the behaviour matches whatever
    # is registered in self.polling_conditions at the time of each step.
    for registered in self.polling_conditions:
        registered.maybe_raise()
236 def polling_condition(
238 fun_
: Callable |
None = None,
240 seconds_interval
: float = 2.0,
241 description
: str |
None = None,
242 ) -> Callable
[[Callable
], AbstractContextManager
] | AbstractContextManager
:
246 def __init__(self
, fun
: Callable
):
247 self
.condition
= PollingCondition(
254 def __enter__(self
) -> None:
255 driver
.polling_conditions
.append(self
.condition
)
257 def __exit__(self
, a
, b
, c
) -> None: # type: ignore
258 res
= driver
.polling_conditions
.pop()
259 assert res
is self
.condition
261 def wait(self
, timeout
: int = 900) -> None:
262 def condition(last
: bool) -> bool:
265 f
"Last chance for {self.condition.description}"
267 ret
= self
.condition
.check(force
=True)
268 if not ret
and not last
:
270 f
"({self.condition.description} failure not fatal yet)"
274 with driver
.logger
.nested(f
"waiting for {self.condition.description}"):
275 retry(condition
, timeout
=timeout
)