# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
import subprocess
import tempfile
import traceback
from subprocess import Popen

from ipykernel.kernelbase import Kernel
def _get_executable():
    """Find the mlir-opt executable.

    The program is taken from the ``MLIR_OPT_EXECUTABLE`` environment variable
    and defaults to ``mlir-opt``. A value with a directory component is checked
    as given; a bare name is searched for in every directory listed on ``PATH``
    (or the platform default search path when ``PATH`` is unset).

    Returns:
        Path of the first executable file found.

    Raises:
        OSError: If no executable file could be located.
    """

    def is_exe(fpath):
        """Returns whether `fpath` is an existing file we may execute."""
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
    path, name = os.path.split(program)
    # Attempt to get the executable
    if path:
        # An explicit location was given: check it only, no PATH search.
        if is_exe(program):
            return program
    else:
        # Use os.defpath when PATH is unset so a missing variable ends in the
        # documented OSError below rather than a KeyError.
        search_path = os.environ.get("PATH", os.defpath)
        for directory in search_path.split(os.pathsep):
            candidate = os.path.join(directory, name)
            if is_exe(candidate):
                return candidate
    raise OSError("mlir-opt not found, please see README")
class MlirOptKernel(Kernel):
    """Kernel using mlir-opt inside jupyter.

    The reproducer syntax (`// configuration:`) is used to run passes. The
    previous result can be referenced by using `_` (this variable is reset
    upon error). `_` has to be the last line of the cell, on a line of its
    own. E.g.,

    ```mlir
    // configuration: --pass
    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
    ```

    ```mlir
    // configuration: --next-pass
    _
    ```
    """

    implementation = "mlir"
    implementation_version = __version__

    language_version = __version__
    language = "mlir"
    language_info = {
        "name": "mlir",
        "codemirror_mode": {"name": "mlir"},
        "mimetype": "text/x-mlir",
        "file_extension": ".mlir",
        "pygments_lexer": "text",
    }

    @property
    def banner(self):
        """Returns kernel banner."""
        return "mlir-opt kernel %s" % __version__

    def __init__(self, **kwargs):
        Kernel.__init__(self, **kwargs)
        # Text output of the last successful run, substituted for a trailing
        # `_` line; None when there is no usable previous result.
        self._ = None
        # Path of mlir-opt, resolved lazily by get_executable().
        self.executable = None
        # `silent` flag of the request currently being executed.
        self.silent = False

    def get_executable(self):
        """Returns the mlir-opt executable path, resolving it on first use."""
        if not self.executable:
            self.executable = _get_executable()
        return self.executable

    def _send_stream(self, name, text):
        """Publishes `text` on the IOPub stream `name` unless running silently."""
        if self.silent:
            return
        stream_content = {"name": name, "text": text}
        self.send_response(self.iopub_socket, "stream", stream_content)

    def process_output(self, output):
        """Reports regular command output on stdout."""
        self._send_stream("stdout", output)

    def process_error(self, output):
        """Reports error output on stderr."""
        self._send_stream("stderr", output)

    def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        """Execute user code using mlir-opt binary.

        The cell is written to a temporary file and run through mlir-opt. On
        success, mlir-opt's stdout becomes the `text/x-mlir` execute result
        and the new value of `_`. On failure, the exit code is reported as an
        error and `_` is cleared.

        Returns:
            The reply content dict; its ``status`` is ``"ok"``, ``"error"``, or
            ``"abort"`` (when interrupted by ``KeyboardInterrupt``).
        """

        def ok_status():
            """Returns OK status."""
            return {
                "status": "ok",
                "execution_count": self.execution_count,
                "payload": [],
                "user_expressions": {},
            }

        def run(code):
            """Run the code by piping it via the filesystem.

            Returns ``(stdout, stderr, exitcode)``. In both byte strings, the
            temporary input filename is replaced by ``<<input>>``.
            """
            executable = self.get_executable()
            # Simple handling of repeating last line.
            if code.endswith("\n_"):
                if not self._:
                    raise NameError("No previous result set")
                code = code[:-1] + self._

            # Create the file before the `try` so the cleanup below never sees
            # an unbound name. The `with` closes it (even if the write fails)
            # before mlir-opt opens it by name.
            inputmlir = tempfile.NamedTemporaryFile(delete=False)
            try:
                with inputmlir:
                    inputmlir.write(code.encode("utf-8"))
                # Specify input and output file explicitly so mlir-opt errors
                # out if they are also set as arguments.
                command = [executable, "--color", inputmlir.name, "-o", "-"]
                pipe = Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                output, errors = pipe.communicate()
                exitcode = pipe.returncode
            finally:
                os.unlink(inputmlir.name)

            # Replace temporary filename with placeholder. This takes the very
            # remote chance where the full input filename (generated above)
            # overlaps with something in the dump unrelated to the file.
            fname = inputmlir.name.encode("utf-8")
            output = output.replace(fname, b"<<input>>")
            errors = errors.replace(fname, b"<<input>>")
            return output, errors, exitcode

        self.silent = silent
        if not code.strip():
            return ok_status()

        try:
            output, errors, exitcode = run(code)
            # Only a successful run provides a new `_`; a failed run clears it.
            if exitcode:
                self._ = None
            else:
                self._ = output.decode("utf-8")
        except KeyboardInterrupt:
            return {"status": "abort", "execution_count": self.execution_count}
        except Exception as error:
            # Print traceback for local debugging.
            traceback.print_exc()
            self._ = None
            exitcode = 255
            errors = repr(error).encode("utf-8")

        # Decode leniently: this runs outside the `try` above, so a stray
        # non-UTF-8 byte in the diagnostics must not escape do_execute and
        # leave the request without a reply.
        error_text = errors.decode("utf-8", "replace")

        if exitcode:
            content = {"ename": "", "evalue": str(exitcode), "traceback": []}

            self.send_response(self.iopub_socket, "error", content)
            self.process_error(error_text)

            content["execution_count"] = self.execution_count
            content["status"] = "error"
            return content

        if not silent:
            content = {
                "execution_count": self.execution_count,
                "data": {"text/x-mlir": self._},
                "metadata": {},
            }
            self.send_response(self.iopub_socket, "execute_result", content)
            self.process_output(self._)
            self.process_error(error_text)
        return ok_status()