From c703b4645c79e889fd6a0f3f64f01f957d981aa4 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 2 Jan 2025 14:40:15 -0800 Subject: [PATCH] [mlir][py] Enable loading only specified dialects during creation. (#121421) Gives option post as global list as well as arg to control which dialects are loaded during context creation. This enables setting either a good base set or skipping in individual cases. --- mlir/python/mlir/_mlir_libs/__init__.py | 42 ++++++++++++++++++++++++++++++--- mlir/python/mlir/ir.py | 6 ++++- mlir/test/python/ir/dialects.py | 36 ++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index c5cb22c6dccb..d021dde05dd8 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]: # needs. _dialect_registry = None +_load_on_create_dialects = None def get_dialect_registry(): @@ -71,6 +72,21 @@ def get_dialect_registry(): return _dialect_registry +def append_load_on_create_dialect(dialect: str): + global _load_on_create_dialects + if _load_on_create_dialects is None: + _load_on_create_dialects = [dialect] + else: + _load_on_create_dialects.append(dialect) + + +def get_load_on_create_dialects(): + global _load_on_create_dialects + if _load_on_create_dialects is None: + _load_on_create_dialects = [] + return _load_on_create_dialects + + def _site_initialize(): import importlib import itertools @@ -132,15 +148,35 @@ def _site_initialize(): break class Context(ir._BaseContext): - def __init__(self, *args, **kwargs): + def __init__(self, load_on_create_dialects=None, *args, **kwargs): super().__init__(*args, **kwargs) self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) if not disable_multithreading: self.enable_multithreading(True) - if not disable_load_all_available_dialects: - self.load_all_available_dialects() + if load_on_create_dialects is not None: + logger.debug( + "Loading all dialects from load_on_create_dialects arg %r", + load_on_create_dialects, + ) + for dialect in load_on_create_dialects: + # This triggers loading the dialect into the context. + _ = self.dialects[dialect] + else: + if disable_load_all_available_dialects: + dialects = get_load_on_create_dialects() + if dialects: + logger.debug( + "Loading all dialects from global load_on_create_dialects %r", + dialects, + ) + for dialect in dialects: + # This triggers loading the dialect into the context. + _ = self.dialects[dialect] + else: + logger.debug("Loading all available dialects") + self.load_all_available_dialects() if init_module: logger.debug( "Registering translations from initializer %r", init_module diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 9a6ce462047a..6f37266d5bf3 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -5,7 +5,11 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug from ._mlir_libs._mlir import register_type_caster, register_value_caster -from ._mlir_libs import get_dialect_registry +from ._mlir_libs import ( + get_dialect_registry, + append_load_on_create_dialect, + get_load_on_create_dialects, +) # Convenience decorator for registering user-friendly Attribute builders. diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py index d59c6a6bc424..5a2ed684d298 100644 --- a/mlir/test/python/ir/dialects.py +++ b/mlir/test/python/ir/dialects.py @@ -121,3 +121,39 @@ def testAppendPrefixSearchPath(): sys.path.append(".") _cext.globals.append_dialect_search_prefix("custom_dialect") assert _cext.globals._check_dialect_module_loaded("custom") + + +# CHECK-LABEL: TEST: testDialectLoadOnCreate +@run +def testDialectLoadOnCreate(): + with Context(load_on_create_dialects=[]) as ctx: + ctx.emit_error_diagnostics = True + ctx.allow_unregistered_dialects = True + + def callback(d): + # CHECK: DIAGNOSTIC + # CHECK-SAME: op created with unregistered dialect + print(f"DIAGNOSTIC={d.message}") + return True + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + try: + op = Operation.create("arith.addi", loc=loc) + ctx.allow_unregistered_dialects = False + op.verify() + except MLIRError as e: + pass + + with Context(load_on_create_dialects=["func"]) as ctx: + loc = Location.unknown(ctx) + fn = Operation.create("func.func", loc=loc) + + # TODO: This may require an update if a site wide policy is set. + # CHECK: Load on create: [] + print(f"Load on create: {get_load_on_create_dialects()}") + append_load_on_create_dialect("func") + # CHECK: Load on create: + # CHECK-SAME: func + print(f"Load on create: {get_load_on_create_dialects()}") + print(get_load_on_create_dialects()) -- 2.11.4.GIT