normcap: fix on GNOME wayland when used via keybind or alt-f2 (#351763)
[NixPkgs.git] / nixos / lib / test-driver / test_driver / logger.py
blob564d39f4f055ca36366121b4eb57e8f408101da2
1 import atexit
2 import codecs
3 import os
4 import sys
5 import time
6 import unicodedata
7 from abc import ABC, abstractmethod
8 from collections.abc import Iterator
9 from contextlib import ExitStack, contextmanager
10 from pathlib import Path
11 from queue import Empty, Queue
12 from typing import Any
13 from xml.sax.saxutils import XMLGenerator
14 from xml.sax.xmlreader import AttributesImpl
16 from colorama import Fore, Style
17 from junit_xml import TestCase, TestSuite
class AbstractLogger(ABC):
    """Interface implemented by every test-driver logger.

    A logger handles plain lines (``log``), the ``info``/``warning``/``error``
    levels, two grouping context managers (``subtest`` and ``nested``) and
    serial console output (``log_serial`` / ``print_serial_logs``).
    """

    @abstractmethod
    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Emit a single line; ``attributes`` is extra key/value metadata."""

    @abstractmethod
    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Group the output produced inside the ``with`` block as subtest ``name``."""

    @abstractmethod
    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Group the output produced inside the ``with`` block under ``message``."""

    @abstractmethod
    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Emit an informational message."""

    @abstractmethod
    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Emit a warning message."""

    @abstractmethod
    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Emit an error message."""

    @abstractmethod
    def log_serial(self, message: str, machine: str) -> None:
        """Handle one line of serial console output coming from ``machine``."""

    @abstractmethod
    def print_serial_logs(self, enable: bool) -> None:
        """Turn the output of serial console lines on or off."""
class JunitXMLLogger(AbstractLogger):
    """Collect output per subtest and write it as a JUnit XML report.

    Output is buffered in memory, keyed by subtest name, and written to
    ``outfile`` by :meth:`close`, which is registered with :mod:`atexit`.
    """

    class TestCaseState:
        """Buffered stdout/stderr text and failure flag of one test case."""

        def __init__(self) -> None:
            self.stdout = ""
            self.stderr = ""
            self.failure = False

    def __init__(self, outfile: Path) -> None:
        # "main" receives everything logged outside of any subtest.
        self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
            "main": self.TestCaseState()
        }
        self.currentSubtest = "main"
        self.outfile: Path = outfile
        self._print_serial_logs = True
        atexit.register(self.close)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Append ``message`` to the current test case's stdout.

        ``attributes`` is accepted for interface compatibility and ignored.
        """
        self.tests[self.currentSubtest].stdout += message + os.linesep

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Route all output to test case ``name`` while the block runs."""
        old_test = self.currentSubtest
        self.tests.setdefault(name, self.TestCaseState())
        self.currentSubtest = name
        try:
            yield
        finally:
            # Restore even when the body raised (including SystemExit), so
            # output logged afterwards is not attributed to the failed subtest.
            self.currentSubtest = old_test

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Log ``message`` once; JUnit has no notion of nesting."""
        self.log(message)
        yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        # Forward like the other loggers do, so ``message=`` works as well.
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Append the message to stderr and mark the current test case failed."""
        message = args[0] if args else kwargs["message"]
        state = self.tests[self.currentSubtest]
        state.stderr += message + os.linesep
        state.failure = True

    def log_serial(self, message: str, machine: str) -> None:
        """Record a serial console line as ``"<machine> # <message>"`` if enabled."""
        if not self._print_serial_logs:
            return

        self.log(f"{machine} # {message}")

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def close(self) -> None:
        """Write all collected test cases to ``outfile`` as a single suite."""
        # Build the report first so a failure here does not leave a truncated file.
        test_cases = []
        for name, state in self.tests.items():
            tc = TestCase(name, stdout=state.stdout, stderr=state.stderr)
            if state.failure:
                tc.add_failure_info("test case failed")
            test_cases.append(tc)
        ts = TestSuite("NixOS integration test", test_cases)
        # Write UTF-8 (the XML default) explicitly instead of relying on the
        # process locale.
        with open(self.outfile, "w", encoding="utf-8") as f:
            f.write(TestSuite.to_xml_string([ts]))
class CompositeLogger(AbstractLogger):
    """Logger that forwards every call to each of its child loggers, in order."""

    def __init__(self, logger_list: list[AbstractLogger]) -> None:
        # The list object is kept as-is (not copied); add_logger() appends to it.
        self.logger_list = logger_list

    def add_logger(self, logger: AbstractLogger) -> None:
        """Register another child logger that receives all subsequent calls."""
        self.logger_list.append(logger)

    def _broadcast(self, method: str, /, *args: Any, **kwargs: Any) -> None:
        """Invoke ``method`` with the given arguments on every child logger."""
        for child in self.logger_list:
            getattr(child, method)(*args, **kwargs)

    @contextmanager
    def _enter_all(
        self, method: str, label: str, attributes: dict[str, str]
    ) -> Iterator[None]:
        """Enter each child's ``method`` context manager for one ``with`` block.

        Children are entered in list order and exited in reverse order.
        """
        with ExitStack() as stack:
            for child in self.logger_list:
                stack.enter_context(getattr(child, method)(label, attributes))
            yield

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        self._broadcast("log", message, attributes)

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        with self._enter_all("subtest", name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        with self._enter_all("nested", message, attributes):
            yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self._broadcast("info", *args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self._broadcast("warning", *args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Forward the error to every child, then exit the process with status 1."""
        self._broadcast("error", *args, **kwargs)
        sys.exit(1)

    def print_serial_logs(self, enable: bool) -> None:
        self._broadcast("print_serial_logs", enable)

    def log_serial(self, message: str, machine: str) -> None:
        self._broadcast("log_serial", message, machine)
class TerminalLogger(AbstractLogger):
    """Human-readable logger that prints every message to stderr."""

    def __init__(self) -> None:
        self._print_serial_logs = True

    def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
        """Return ``message``, prefixed with ``"<machine>: "`` when a machine is set."""
        if "machine" not in attributes:
            return message
        return f"{attributes['machine']}: {message}"

    @staticmethod
    def _eprint(*args: object, **kwargs: Any) -> None:
        print(*args, file=sys.stderr, **kwargs)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        line = self.maybe_prefix(message, attributes)
        self._eprint(line)

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Print the subtest as a highlighted ``nested`` section."""
        title = "subtest: " + name
        with self.nested(title, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Print a bright green heading, run the block, then report its duration.

        The duration line is only printed when the block completes normally.
        """
        heading = Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL
        self._eprint(self.maybe_prefix(heading, attributes))

        started = time.time()
        yield
        elapsed = time.time() - started
        self.log(f"(finished: {message}, in {elapsed:.2f} seconds)")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        """Print a dimmed ``"<machine> # <message>"`` line unless serial output is off."""
        if self._print_serial_logs:
            self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
class XMLLogger(AbstractLogger):
    """Logger that writes a structured XML log file.

    Regular messages become ``<line>`` elements; ``nested``/``subtest`` blocks
    become ``<nest>`` elements with a ``<head>``. Serial console lines are
    queued by :meth:`log_serial` and written when the queue is drained: on
    the next :meth:`log`, around nested blocks, and on :meth:`close`.
    """

    def __init__(self, outfile: str) -> None:
        # Plain binary handle; XMLGenerator wraps it and encodes as UTF-8.
        # (codecs.open() without an encoding returns exactly this and is
        # deprecated in recent Python versions.)
        self.logfile_handle = open(outfile, "wb")
        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
        # Pending serial lines; presumably fed from other threads, which is
        # why a thread-safe Queue is used -- confirm against callers.
        self.queue: Queue[dict[str, str]] = Queue()

        self._print_serial_logs = True

        self.xml.startDocument()
        self.xml.startElement("logfile", attrs=AttributesImpl({}))

    def close(self) -> None:
        """Flush pending serial lines, finish the document and close the file."""
        # Serial output that arrived after the last logged line would
        # otherwise be lost.
        self.drain_log_queue()
        self.xml.endElement("logfile")
        self.xml.endDocument()
        self.logfile_handle.close()

    def sanitise(self, message: str) -> str:
        """Drop every character whose Unicode general category is ``C*``.

        That is control, format, surrogate, private-use and unassigned code points.
        """
        return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")

    def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
        """Return ``message``, prefixed with ``"<machine>: "`` when a machine is set."""
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    def log_line(self, message: str, attributes: dict[str, str]) -> None:
        """Write one ``<line>`` element carrying ``attributes`` as XML attributes."""
        self.xml.startElement("line", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("line")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Flush queued serial lines, then write ``message`` as a ``<line>``.

        NOTE: unlike serial lines, ``message`` is not sanitised.
        """
        self.drain_log_queue()
        self.log_line(message, attributes)

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        """Queue a serial console line; it is written on the next drain."""
        if not self._print_serial_logs:
            return

        self.enqueue({"msg": message, "machine": machine, "type": "serial"})

    def enqueue(self, item: dict[str, str]) -> None:
        self.queue.put(item)

    def drain_log_queue(self) -> None:
        """Write all queued items as ``<line>`` elements without blocking.

        The ``"msg"`` entry becomes the sanitised text; the remaining keys
        become the element's attributes.
        """
        try:
            while True:
                item = self.queue.get_nowait()
                msg = self.sanitise(item["msg"])
                del item["msg"]
                self.log_line(msg, item)
        except Empty:
            pass

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Write the subtest as a ``nested`` section titled ``"subtest: <name>"``."""
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Wrap the output of the ``with`` block in a ``<nest>`` element.

        The ``(finished: ...)`` timing line is written only on normal
        completion. The element is always closed, even if the body raises.
        """
        self.xml.startElement("nest", attrs=AttributesImpl({}))
        self.xml.startElement("head", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("head")

        tic = time.time()
        self.drain_log_queue()
        try:
            yield
            self.drain_log_queue()
            toc = time.time()
            self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
        finally:
            # Without this, an exception in the body leaves <nest> open and
            # the final </logfile> produces a malformed document. Serial lines
            # queued by a failing body also belong inside this element.
            self.drain_log_queue()
            self.xml.endElement("nest")