[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / utils / jupyter / mlir_opt_kernel / kernel.py
blobc0e4fc1db4c8a834528a5000be1d4ec5f398ff1b
1 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2 # See https://llvm.org/LICENSE.txt for license information.
3 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 from subprocess import Popen
6 import os
7 import subprocess
8 import tempfile
9 import traceback
10 from ipykernel.kernelbase import Kernel
12 __version__ = "0.0.1"
def _get_executable():
    """Find the mlir-opt executable.

    The program is read from the ``MLIR_OPT_EXECUTABLE`` environment variable
    and defaults to ``mlir-opt``. A value containing a directory component is
    checked as given; a bare name is looked up in every ``PATH`` entry.

    Returns:
        The first candidate that is an existing, executable regular file.

    Raises:
        OSError: If no executable candidate is found.
    """

    def is_exe(fpath):
        """Returns whether executable file."""
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
    path, name = os.path.split(program)
    # Attempt to get the executable
    if path:
        if is_exe(program):
            return program
    else:
        # PATH may be unset; treat that as "no directories to search" so the
        # caller gets the OSError below rather than a KeyError.
        search_path = os.environ.get("PATH")
        directories = [] if search_path is None else search_path.split(os.pathsep)
        for directory in directories:
            candidate = os.path.join(directory, name)
            if is_exe(candidate):
                return candidate
    raise OSError("mlir-opt not found, please see README")
class MlirOptKernel(Kernel):
    """Kernel using mlir-opt inside jupyter.

    The reproducer syntax (`// configuration:`) is used to run passes. The
    previous result can be referenced by using `_` (this variable is reset
    upon error). E.g.,

    ```mlir
    // configuration: --pass
    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
    ```

    ```mlir
    // configuration: --next-pass
    _
    ```
    """

    implementation = "mlir"
    implementation_version = __version__

    language_version = __version__
    language = "mlir"
    language_info = {
        "name": "mlir",
        "codemirror_mode": {"name": "mlir"},
        "mimetype": "text/x-mlir",
        "file_extension": ".mlir",
        "pygments_lexer": "text",
    }

    @property
    def banner(self):
        """Returns kernel banner."""
        # Just a placeholder.
        return "mlir-opt kernel %s" % __version__

    def __init__(self, **kwargs):
        """Initialize the kernel with no previous result and no cached executable."""
        Kernel.__init__(self, **kwargs)
        # Decoded stdout of the last successful run; None when there is none.
        self._ = None
        # Path to mlir-opt, resolved on first use by get_executable().
        self.executable = None
        # Mirrors the `silent` flag of the execution currently in progress.
        self.silent = False

    def get_executable(self):
        """Returns the mlir-opt executable path, resolving and caching it on first use."""
        if not self.executable:
            self.executable = _get_executable()
        return self.executable

    def process_output(self, output):
        """Reports regular command output as a stdout stream message unless silent."""
        if not self.silent:
            # Send standard output
            stream_content = {"name": "stdout", "text": output}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def process_error(self, output):
        """Reports error output as a stderr stream message unless silent."""
        if not self.silent:
            # Send standard error
            stream_content = {"name": "stderr", "text": output}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        """Execute user code using mlir-opt binary.

        The code is written to a temporary file that is handed to mlir-opt.
        On success the decoded stdout is stored in `self._`; on a non-zero
        exit code or an exception `self._` is reset to None.

        Returns:
            A reply dict whose "status" is "ok", "error" or "abort".
        """

        def ok_status():
            """Returns OK status."""
            return {
                "status": "ok",
                "execution_count": self.execution_count,
                "payload": [],
                "user_expressions": {},
            }

        def run(code):
            """Run the code by piping via filesystem.

            Returns:
                Tuple of (stdout bytes, stderr bytes, exit code), with the
                temporary file name replaced by `<<input>>` in both streams.
            """
            # Create the file before entering the try block so the cleanup
            # below never refers to an unbound name if creation fails.
            inputmlir = tempfile.NamedTemporaryFile(delete=False)
            try:
                command = [
                    # Specify input and output file to error out if also
                    # set as arg.
                    self.get_executable(),
                    "--color",
                    inputmlir.name,
                    "-o",
                    "-",
                ]
                # Simple handling of repeating last line.
                if code.endswith("\n_"):
                    if not self._:
                        raise NameError("No previous result set")
                    code = code[:-1] + self._
                inputmlir.write(code.encode("utf-8"))
                # Close (and thereby flush) before mlir-opt reads the file.
                inputmlir.close()
                pipe = Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                output, errors = pipe.communicate()
                exitcode = pipe.returncode
            finally:
                # Closing twice is harmless; this guarantees the handle is
                # released on early-exit paths before the file is removed.
                inputmlir.close()
                os.unlink(inputmlir.name)

            # Replace temporary filename with placeholder. This takes the very
            # remote chance where the full input filename (generated above)
            # overlaps with something in the dump unrelated to the file.
            fname = inputmlir.name.encode("utf-8")
            output = output.replace(fname, b"<<input>>")
            errors = errors.replace(fname, b"<<input>>")
            return output, errors, exitcode

        self.silent = silent
        if not code.strip():
            return ok_status()

        try:
            output, errors, exitcode = run(code)

            if exitcode:
                self._ = None
            else:
                self._ = output.decode("utf-8")
        except KeyboardInterrupt:
            return {"status": "abort", "execution_count": self.execution_count}
        except Exception as error:
            # Print traceback for local debugging.
            traceback.print_exc()
            self._ = None
            exitcode = 255
            errors = repr(error).encode("utf-8")

        if exitcode:
            content = {"ename": "", "evalue": str(exitcode), "traceback": []}

            self.send_response(self.iopub_socket, "error", content)
            self.process_error(errors.decode("utf-8"))

            content["execution_count"] = self.execution_count
            content["status"] = "error"
            return content

        if not silent:
            data = {}
            data["text/x-mlir"] = self._
            content = {
                "execution_count": self.execution_count,
                "data": data,
                "metadata": {},
            }
            self.send_response(self.iopub_socket, "execute_result", content)
            self.process_output(self._)
            self.process_error(errors.decode("utf-8"))
        return ok_status()