# Source code for mendevi.cst.model

"""All the available models."""

import functools
import importlib
import importlib.util
import inspect
import sys
from collections.abc import Iterator

from mendevi.models.base import Model
from mendevi.utils import get_project_root


def _yield_all_models() -> Iterator[tuple[str, type[Model]]]:
    """Discover the Model subclasses in mendevi/models/*.py.

    Every ``*.py`` file of the ``models`` directory under the project root,
    except ``__init__.py`` and ``base.py``, is made available as the module
    ``mendevi.models.<stem>``. Each class bound in that module's namespace
    that is a strict subclass of :class:`Model` is yielded.

    A module that is already present in :data:`sys.modules` is reused rather
    than executed a second time. Re-executing it would create new class
    objects that are not identical to the ones other code already imported.
    That breaks ``isinstance`` checks and makes pickle fail with
    "it's not the same object".

    Yields
    ------
    name : str
        The name under which the class is bound in its module.
    cls : type[Model]
        The model class itself (not an instance).

    Raises
    ------
    Exception
        Whatever a model module raises while being executed; the partially
        initialised module is removed from :data:`sys.modules` first.
    """
    # Sorted so the discovery order (and thus which entry wins when two
    # modules bind the same name) does not depend on the filesystem.
    for path in sorted((get_project_root() / "models").glob("*.py")):
        if path.stem in {"__init__", "base"}:  # skip the package init and the base module
            continue
        modulename = f"mendevi.models.{path.stem}"
        modulevar = sys.modules.get(modulename)
        if modulevar is None:
            spec = importlib.util.spec_from_file_location(modulename, path)
            modulevar = importlib.util.module_from_spec(spec)
            # Register before executing so pickle, which resolves classes
            # through sys.modules[cls.__module__], finds these very objects.
            sys.modules[modulename] = modulevar
            try:
                spec.loader.exec_module(modulevar)
            except BaseException:
                # Mirror the import system: never leave a broken module behind.
                sys.modules.pop(modulename, None)
                raise
        namespace = vars(modulevar)
        for name in sorted(namespace):  # same order as dir(module)
            var = namespace[name]
            if inspect.isclass(var) and issubclass(var, Model) and var is not Model:
                yield name, var


@functools.cache
def import_all_models() -> dict[str, type[Model]]:
    """Import all the classes in mendevi/models/*.py that inherit from Model.

    The discovery runs only once. The result is memoised by
    :func:`functools.cache`, so every call returns the *same* dict object.
    Treat it as read-only and copy it before mutating.
    ``import_all_models.cache_clear()`` forces a fresh discovery.

    Returns
    -------
    dict[str, type[Model]]
        Maps the name under which each class is bound in its module
        (normally the class name) to the model class itself. If several
        modules bind the same name, the last one discovered wins.

    Examples
    --------
    >>> from pprint import pprint
    >>> from mendevi.cst.model import import_all_models
    >>> pprint(import_all_models())
    {'EncodeLinear': <class 'mendevi.models.lr.EncodeLinear'>,
     'GPR': <class 'mendevi.models.gpr.GPR'>,
     'LR': <class 'mendevi.models.lr.LR'>,
     'NearestL2': <class 'mendevi.models.interp.NearestL2'>,
     'PowerCores': <class 'mendevi.models.power_cores.PowerCores'>}
    """
    return dict(_yield_all_models())