"""Unify all the models with a common structure."""
import abc
import logging
import numbers
import pathlib
import re
import sqlite3
import tempfile
import typing
import numpy as np
import torch
from context_verbose import Printer
from mendevi.cst.encoders import ENCODERS
from mendevi.cst.labels import LABELS
from mendevi.cst.profiles import PROFILES
from mendevi.database.create import create_database, is_sqlite
from mendevi.database.extract import SqlLinker
from mendevi.database.meta import get_extractor, merge_extractors
from mendevi.download.decapsulation import retrive_file
from mendevi.probe import probe_and_store
from mendevi.utils import compute_video_hash
# Machine epsilon of float32, added to the standard deviations used by the
# (de)normalization of the models so that a constant column never leads to a division by zero.
EPS = torch.finfo(torch.float32).eps
def to_torch(data: dict[str, list], dtype: torch.dtype = torch.float32) -> dict[str, torch.Tensor]:
    """Vectorize every field of ``data`` into a 2D torch tensor.

    Numeric fields become column vectors, the categorical fields
    (``effort``, ``encoder``, ``profile`` and ``mode``) are one-hot encoded.

    Parameters
    ----------
    data : dict[str, list]
        For each field name, the list of its values.
    dtype : torch.dtype, default=torch.float32
        The dtype of all the returned tensors.

    Returns
    -------
    dict[str, torch.Tensor]
        For each starting field, associate the torch matrix of size (n, k),
        where n is the number of elements and k is the dimension.
        k is 1 for lists of numbers and is the number of possible labels otherwise.

    Raises
    ------
    NotImplementedError
        If a ``hostname`` field is given, its one-hot encoding is not defined.

    Examples
    --------
    >>> from pprint import pprint
    >>> from mendevi.models.base import to_torch
    >>> data = {
    ...     "effort": ["fast", "medium", "slow"],
    ...     "encoder": ["libx264", "libsvtav1"],
    ...     "profile": ["sd", "hd", "fhd"],
    ...     "scalar": [0.0, 1.0, 2.0],
    ... }
    >>> pprint(to_torch(data))
    {'effort': tensor([[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]),
     'encoder': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]),
     'profile': tensor([[0., 0., 1., 0.],
            [0., 1., 0., 0.],
            [1., 0., 0., 0.]]),
     'scalar': tensor([[0.],
            [1.],
            [2.]])}
    """
    assert isinstance(data, dict), data.__class__.__name__

    def one_hot(index: dict, column: list) -> torch.Tensor:
        # index maps each possible label to its class number
        codes = torch.asarray([index[label] for label in column], dtype=int)
        return torch.nn.functional.one_hot(codes, num_classes=len(index)).to(dtype)

    tensors: dict[str, torch.Tensor] = {}
    for field, column in data.items():
        assert isinstance(column, list), column.__class__.__name__
        if field == "effort":
            tensors[field] = one_hot({"fast": 0, "medium": 1, "slow": 2}, column)
        elif field == "encoder":
            tensors[field] = one_hot({enc: i for i, enc in enumerate(sorted(ENCODERS))}, column)
        elif field == "profile":
            tensors[field] = one_hot({pro: i for i, pro in enumerate(sorted(PROFILES))}, column)
        elif field == "mode":
            tensors[field] = one_hot({"cbr": 0, "vbr": 1}, column)
        elif field == "hostname":
            msg = "The one hot encoding for hostname is not defined."
            raise NotImplementedError(msg)
        else:
            assert all(isinstance(v, numbers.Real) for v in column), column
            tensors[field] = torch.asarray(column, dtype=dtype)[:, None]
    return tensors
[docs]
def from_torch(data: dict[str, torch.Tensor]) -> dict[str, list]:
"""Back convertion from vectorized torch to python list, bijection of :py:func:`to_torch`."""
out_data: dict[str, list] = {}
for name, values in data.items():
assert isinstance(values, torch.Tensor), values.__class__.__name__
assert values.ndim == 2, values.shape
match name:
case "effort" | "encoder" | "profile" | "mode" | "hostname":
raise NotImplementedError
case _:
out_data[name] = values.squeeze(1).tolist()
return out_data
[docs]
class Model(metaclass=MetaModel):
"""Common structure to all models.
Attributes
----------
cite : str
The latex bibtext model citation.
parameters : torch.Tensor | None
The trainable parameters of the model (read and write).
input_labels : list[str]
The name of all input parameters (readonly).
output_labels : list[str]
The name of all output parameters (readonly).
accuracy : dict[str, float]
For each output label, associate the standard deviation of the associated average error.
This dictionary is constructed when the .fit method is called (readonly).
"""
def __init__(self, title: str | None = None, **kwargs: dict) -> None:
"""Initialise the model.
Parameters
----------
title : str, optional
The model title.
**kwargs : dict
Includes the following fields.
sources : str
All sources for the model, the conference paper, the authors, etc.
input_labels : list[str]
The name of all input parameters.
The possibles values are `mendevi.plot.axis.Name`.
output_labels : list[str]
The name of all output parameters.
The possibles values are `mendevi.plot.axis.Name`.
parameters : object, optional
The learnable parameters for regressive models.
"""
assert set(kwargs).issubset({"sources", "input_labels", "output_labels", "parameters"})
# check input_labels
input_labels = kwargs.get("input_labels", [])
assert hasattr(input_labels, "__iter__"), input_labels.__class__.__name__
input_labels = list(input_labels)
assert input_labels, "input must be not empty"
assert all(isinstance(lab, str) and lab in LABELS for lab in input_labels), input_labels
self._input_labels = input_labels
# check output_labels
output_labels = kwargs.get("output_labels", [])
assert hasattr(output_labels, "__iter__"), output_labels.__class__.__name__
output_labels = list(output_labels)
assert output_labels, "output must be not empty"
assert all(isinstance(lab, str) and lab in LABELS for lab in output_labels), output_labels
self._output_labels = output_labels
# check parameters
self._parameters = kwargs.get("parameters")
self._std_mean: dict = {}
self._accuracy: dict | None = None
# check title
if title is None:
title = (
f"{'regressive ' if self._parameters is not None else ''}model "
f"to predict {', '.join(sorted(self._output_labels))} "
f"from {', '.join(sorted(self._input_labels))}"
)
else:
assert isinstance(title, str), title.__class__.__name__
self._title = title
# check authors
sources = kwargs.get("sources", "")
assert isinstance(sources, str), sources.__class__.__name__
self._sources = sources
@abc.abstractmethod
def _fit(self, values: dict[str]) -> None:
"""Perform data regression to find the model's hyperparameters.
Hyperparameters must be added to the ``self.parameters`` dictionary.
Parameters
----------
values: dict[str, list]
For each input and output label (ie self.input_labels union self.output_labels),
associate the list of corresponding values.
Notes
-----
If the model contains hyperparameters,
it is essential that the child class redefines one of the methods
self._fit, self._fit_vect or self._fit_vect_norm.
"""
# level 1
if (
self.__class__._fit_vect is not Model._fit_vect # noqa: SLF001
or self.__class__._fit_vect_norm is not Model._fit_vect_norm # noqa: SLF001
):
self._fit_vect(to_torch(values))
# level 0
else:
raise NotImplementedError
@abc.abstractmethod
def _fit_vect(self, values: dict[str, torch.Tensor]) -> None:
"""Do the same as ``self._fit``, but provide tensor rather than lists."""
# level 2
if self.__class__._fit_vect_norm is not Model._fit_vect_norm: # noqa: SLF001
assert set(self._input_labels) | set(self._output_labels) == set(values)
self._std_mean = {
name: (
torch.std_mean(data, dim=0, keepdim=True)
if data.shape[1] == 1 else
(1.0, 0.0) # do nothing for onehot encoding
)
for name, data in values.items()
}
self._fit_vect_norm(
{
name: (data - self._std_mean[name][1]) / (self._std_mean[name][0] + EPS)
for name, data in values.items()
},
)
# level 1
else:
raise NotImplementedError
@abc.abstractmethod
def _fit_vect_norm(self, values: dict[str, torch.Tensor]) -> None:
"""Do the same as ``self._fit_vect``, but the tensors are centered and reduced."""
raise NotImplementedError
@abc.abstractmethod
def _predict(self, values: dict[str, list]) -> dict[str, list]:
"""Perform model prediction.
Parameters
----------
values: dict[str, list]
For each input label (ie self.input_labels), associate the list of corresponding values.
Returns
-------
output: dict[str, list]
For each ouput label (ie self.output_labels),
associate the list of corresponding predicted values.
Notes
-----
The child class must to redefines one of the methods
self._predict, self._predict_vect or self._predict_vect_norm.
"""
# level 1
if (
self.__class__._predict_vect is not Model._predict_vect # noqa: SLF001
or self.__class__._predict_vect_norm is not Model._predict_vect_norm # noqa: SLF001
):
return from_torch(self._predict_vect(to_torch(values)))
# level 0
raise NotImplementedError
@abc.abstractmethod
def _predict_vect(self, values: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Do the same as ``self._predict``, but provide and return tensor rather than lists."""
# level 2
if self.__class__._predict_vect_norm is not Model._predict_vect_norm: # noqa: SLF001
assert set(values).issubset(set(self._input_labels)), (values, self._input_labels)
assert self._std_mean, "you must to call fit before predict"
norm_prediction = self._predict_vect_norm(
{
name: (data - self._std_mean[name][1]) / (self._std_mean[name][0] + EPS)
for name, data in values.items()
},
)
assert set(self._output_labels) == set(norm_prediction)
return {
name: (
norm_prediction[name] * (self._std_mean[name][0]+EPS) + self._std_mean[name][1]
)
for name in self._output_labels
}
# level 1
raise NotImplementedError
@abc.abstractmethod
def _predict_vect_norm(self, values: dict[str, torch.Tensor]) -> dict[str]:
"""Implement the heart of the model."""
raise NotImplementedError
@property
def accuracy(self) -> dict[str, float]:
"""Return the error std for each output label."""
if self._accuracy is None:
msg = "you must call the .fit method before you can know the accuracy of the model"
return RuntimeError(msg)
return self._accuracy.copy() # copy for readonly
@property
def cite(self) -> str:
"""Return the bibtex citation."""
raise NotImplementedError
[docs]
def fit(
self,
database: pathlib.Path | str,
select: str | None = None,
query: str | None = None,
table: str | None = None,
) -> typing.Self:
"""Fit the trainable hyper parameters of the model.
Parameters
----------
database : pathlike
The training database.
select : str, optional
The python expression to keep the line, like ``mendevi plot --filter``.
query : str, optional
If provided, use this sql query to perform the request,
otherwise (default) attemps to guess the query.
table : str, optional
The main sql table juste after the FROM in SELECT.
It helps to choose the write query when there is several candidates.
Return
------
self
A reference to the inplace fitted model.
"""
with Printer(f"Fit {self._title!r}...", color="pink") as prt:
# verification
database = retrive_file(database)
assert is_sqlite(database), f"{database} is not a valid SQL database"
# get sql query
prt.print("get SQL extractor")
atom_names, line_extractor = merge_extractors(
set(self._input_labels) | set(self._output_labels),
select=select,
return_callable=True,
)
if query is None:
# search all SQL queries
select = {s for lbl in atom_names for s in get_extractor(lbl).func.select}
if len(queries := SqlLinker(*select).sql) == 0:
msg = "fail to create the SQL query, please provide it yourself"
raise RuntimeError(msg)
# select good query
if table is not None:
queries = {re.search(r"FROM\s+(?P<tab>\w+)", q)["tab"]: q for q in queries}
if table not in queries:
msg = f"possible queries from {', '.join(queries)}, not {table}"
raise ValueError(msg)
queries = [queries[table]]
# warning
if len(queries) > 1:
logging.getLogger(__name__).warning(
"several request founded %s, please provide the table or the request",
queries,
)
query = queries.pop(0)
else:
assert isinstance(query, str), query.__class__.__name__
# perform sql request
prt.print("perform SQL query")
values = {label: [] for label in set(self._input_labels) | set(self._output_labels)}
with sqlite3.connect(f"file:{database}?mode=ro", uri=True) as conn:
conn.row_factory = sqlite3.Row
for raw in conn.execute(query):
for label, value in line_extractor(dict(raw)).items():
values[label].append(value)
# fit the model
prt.print("fit the model")
self._fit(values)
# evaluate the model error
with prt("Model accuracy:"):
prediction = to_torch(
self._predict({lbl: values[lbl] for lbl in self._input_labels}),
)
validation = to_torch({lbl: values[lbl] for lbl in self._output_labels})
self._accuracy = {
name: float(((prediction[name] - validation[name])**2).mean().sqrt())
for name in self._output_labels
}
for atom_names, std_err in self._accuracy.items():
prt.print(f"std of {atom_names}: {std_err:.4g}")
prt.print_time()
return self
@property
def input_labels(self) -> list[str]:
"""Return the name of all input parameters."""
return self._input_labels.copy()
@property
def output_labels(self) -> list[str]:
"""Return the name of all output parameters."""
return self._output_labels.copy()
@property
def parameters(self) -> torch.Tensor or None:
"""Return the trainable parameters of the model."""
return self._parameters
@parameters.setter
def parameters(self, new_params: torch.Tensor) -> None:
"""Update the parameters."""
if self._parameters is not None and new_params.__class__ != self._parameters.__class__:
logging.getLogger(__name__).warning("change the type of parameters")
self._parameters = new_params
[docs]
def predict(self, *input_args: tuple, **input_kwargs: dict) -> dict[str]:
"""Perform the prediction(s) of this model.
Parameters
----------
*input_args, **input_kwargs
The parameters values, with the keys defined during initialisation.
Returns
-------
prediction : dict[str]
Associate each ouput variable with the prediction.
"""
with Printer(f"Predict {self._title!r}...", color="pink") as prt:
# check args
prt.print("check args")
values: dict[str] = {}
for i, arg in enumerate(input_args):
assert i != len(self._input_labels), (
f"only {len(self._input)} arguments expeted {self._input_labels}, "
f"{input_args} given"
)
values[self._input_labels[i]] = arg
for name, arg in input_kwargs.items():
if name in values:
msg = f"argument {name} given twice"
raise ValueError(msg)
if name not in self._input_labels:
msg = f"only {self._input_labels} arguments excpected, not {name}"
raise ValueError(msg)
values[name] = arg
# cast args to list
prt.print("cast args")
for name, arg in values.copy().items():
match arg:
case numbers.Real() | str():
values[name] = [arg]
case np.ndarray() | torch.Tensor():
values[name] = arg.tolist()
case tuple():
values[name] = list(arg)
prt.print(f"{name} = {values[name]!s:.80}")
# duplicate singloton args
prt.print("duplicate args")
size = max(map(len, values.values()))
values = {name: arg * size if len(arg) == 1 else arg for name, arg in values.items()}
# predict
prt.print("predict")
prediction = self._predict(values)
# check output
assert isinstance(prediction, dict), prediction.__class__.__name__
assert prediction.keys() == set(self._output_labels), \
f"_predict must return {self._output_labels}, not {sorted(prediction)}"
prt.print_time()
return prediction
[docs]
def predict_from_video(
self, video: pathlib.Path | str, *args: tuple, **kwargs: dict,
) -> dict[str]:
"""Simplify the predict method by automatically extracting parameters from the video.
Parameters
----------
video : pathlike
The path to the video whose behavior we want to ``predict``.
*args, **kwargs, optional
The other arguments are passed to the prdict method.
Returns
-------
prediction : dict[str]
The value returned by the ``predict`` method.
"""
# To avoid redundancy, and thus be able to use the rest of the API,
# the decision was made to call the probe function on the video
# rather than extracting the parameters manually.
# We therefore create a small temporary database
# that allows us to reuse the other main functions of mendevi to serve this purpose.
video = pathlib.Path(video).expanduser()
assert video.exists(), video
# 1) creation of the database
database = pathlib.Path(tempfile.gettempdir()) / "model.db"
if not database.exists():
create_database(database)
# 2) find the parameters for probe
atom_names, _ = merge_extractors(set(self._input_labels), return_callable=False)
probe_kwargs = {
"rms_sobel": "rms_sobel" in atom_names,
"rms_time_diff": "rms_time_diff" in atom_names,
"uvq": "uvq" in atom_names,
"ref": {}, # for comparative metrics
}
# 3) call probe
if (conn := probe_and_store(database, video, **probe_kwargs)) is not None:
conn.close()
# 4) extract fields from SQL database
vid_id: bytes = compute_video_hash(video)
with sqlite3.connect(f"file:{database}?mode=ro", uri=True) as conn:
conn.row_factory = sqlite3.Row
vid_fields = {
info["name"][4:] for info in conn.execute("PRAGMA table_info(t_vid_video)")
} - {"id"}
input_labels = {lbl for lbl in self._input_labels if any(lbl in f for f in vid_fields)}
_, line_extractor = merge_extractors(input_labels, return_callable=True)
raw = conn.execute("SELECT * FROM t_vid_video WHERE vid_id = ?", (vid_id,)).fetchone()
vid_kwargs = line_extractor(dict(raw))
# 5) call the classical prediction
return self.predict(*args, **kwargs, **vid_kwargs)