Source code for mendevi.models.solver

"""Use a model to solve an optimization problem."""

import itertools
import pathlib
import shlex
import tempfile
import typing

from context_verbose import Printer

from mendevi.cmd import CmdFFMPEG
from mendevi.cst.encoders import ENCODERS
from mendevi.cst.profiles import PROFILES
from mendevi.cst.threads import THREADS
from mendevi.encode import get_transcode_cmd
from mendevi.utils import get_pix_fmt

from .base import Model

# Nominal resolution of the default "quality" axis of the grid search.
# NOTE(review): the "quality" list built below iterates i = 1 .. N_QUAL + 1,
# so it actually contains N_QUAL + 1 values (25 with the current setting),
# not N_QUAL. Confirm whether this off-by-one is intended.
N_QUAL: int = 24  # default number of tested quality points
# Fallback values explored by the Solver for each model input label
# that the caller does not explicitly provide in its own grid.
DEFAULT_GRID: dict[str, list] = {
    "effort": ["fast", "medium", "slow"],
    "encoder": sorted(ENCODERS),
    "mode": ["vbr", "cbr"],
    "profile": sorted(PROFILES),
    # evenly spaced fractions i/(N_QUAL+2), strictly between 0 and 1
    "quality": [i/(N_QUAL+2) for i in range(1, N_QUAL+2)],
    # a single candidate: the project-wide THREADS constant
    "threads": [THREADS],
}


class Solver:
    """Minimizes a metric for a given model.

    Performs a grid search on all free parameters.

    Examples
    --------
    >>> import cutcutcodec
    >>> from mendevi.models import Solver
    >>> from mendevi.models.lr import EncodeLinear
    >>> model = EncodeLinear().fit("x264_vs_openh264.db", table="t_enc_encode")
    >>> solver = Solver(model, lambda **kwd: kwd["log_energy_per_frame"] + (kwd["psnr"]-35)**2)
    >>> video = cutcutcodec.utils.get_project_root() / "media" / "video" / "intro.webm"
    >>> values, loss = solver.solve(video, encoder=["libopenh264", "libx264"])
    >>>
    """

    def __init__(
        self,
        model: Model,
        loss: typing.Callable,
        grid: dict[str, list] | None = None,
    ) -> None:
        """Prepare the solver.

        Parameters
        ----------
        model : :py:class:`mendevi.models.base.Model`
            The instantiated and fitted model, ready to be evaluated.
        loss : callable
            The cost function, which takes the value of the labels as input
            (as keyword arguments) and returns a scalar.
        grid : dict[str, list], optional
            Allows you to define the list of values to be tested.
            It is a dictionary that for each input label,
            associate the values to be tested.
            Labels absent from this dictionary fall back on ``DEFAULT_GRID``.
        """
        # input verifications
        assert isinstance(model, Model), \
            f"you must to provide a model, not {model.__class__.__name__}"
        assert model.is_fit(), "the model has not been trained, you must first call the fit method."
        assert callable(loss), loss.__class__.__name__
        grid = grid or {}
        assert isinstance(grid, dict), grid.__class__.__name__
        assert all(k in model.input_labels for k in grid), (sorted(grid), model.input_labels)
        assert all(isinstance(v, list) for v in grid.values()), grid

        self._model = model
        self._loss = loss

        # set parameters for the grid search
        # for each input label, associate the values to be tested
        self._grid: dict[str, list] = {}
        for label in self._model.input_labels:
            if label in grid:  # the user choice takes priority
                self._grid[label] = grid[label]
            elif label in DEFAULT_GRID:
                self._grid[label] = DEFAULT_GRID[label]
            # otherwise the label has to be provided later, at solve time

    @staticmethod
    def _enc_cmd(
        video: pathlib.Path | str | None,
        values: dict[str, list],
    ) -> typing.Iterator[CmdFFMPEG]:
        """To each point, yield the ffmpeg encode command.

        Parameters
        ----------
        video : pathlike or None
            The source video to be transcoded.
        values : dict[str, list]
            For each label, the value of every point. All lists have the same length.

        Yields
        ------
        cmd : CmdFFMPEG
            The transcoding command of each point, in the order of the points.
            Nothing is yielded if ``video`` is None or if one of the
            labels "encoder", "mode" or "quality" is missing.
        """
        keys = list(values)
        if video is None or not {"encoder", "mode", "quality"}.issubset(keys):
            return  # reject not encoding cases
        # ``solve`` accepts a str as well, and ``.stem`` is required below
        video = pathlib.Path(video)
        dst = pathlib.Path(tempfile.gettempdir()) / f"{video.stem}_optimal.mp4"
        for point_tuple in zip(*(values[k] for k in keys), strict=True):
            point = dict(zip(keys, point_tuple, strict=True))
            if "height" in point and "width" in point:
                default_resolution = (point["height"], point["width"])
            else:
                default_resolution = None
            cmd = get_transcode_cmd(
                video,
                dst,
                encoder=point["encoder"],
                mode=point["mode"],
                quality=point["quality"],
                effort=point.get("effort", "medium"),
                threads=point.get("threads", THREADS),
                filter=point.get("filter"),
                fps=point.get("fps"),
                resolution=point.get("resolution", default_resolution),
                pix_fmt=point.get("pix_fmt", get_pix_fmt(video)),
            )
            # keep only the first video stream, the destination remains the last item
            cmd.output = ["-map", "0:v:0", *cmd.output]
            yield cmd

    def solve(
        self,
        video: pathlib.Path | str | None = None,
        **kwargs: typing.Any,
    ) -> tuple[dict[str, list], list[float]]:
        """Test all combinations as a cartesian product.

        Parameters
        ----------
        video : pathlike, optional
            Use this to extract the missing parameters.
        **kwargs : dict, optional
            Use to define or redefine certain input parameters to be explored.
            Explore the cartesian product of all the values.

        Returns
        -------
        values : dict[str, list]
            Associate the list of provided and predicted values with each label, input and output.
        loss : list[float]
            The value of the loss function for each point.
            This list is sorted in ascending order, ``values`` follows the same order.
        """
        with Printer("Solve...", color="pink") as prt:
            # preparation of cartesian product of input parameters
            # priority: explicit kwargs > video properties > grid
            if video is not None:
                kwargs = {**self._model.extract_video_props(video), **kwargs}
            kwargs = self._model.vectorize_kwargs({**self._grid, **kwargs}, _skip_duplication=True)
            keys = sorted(kwargs)  # for repeatability
            values = dict(
                zip(
                    keys,
                    zip(*itertools.product(*(kwargs[k] for k in keys)), strict=True),
                    strict=True,
                ),
            )

            # predictions
            values |= self._model.predict(**values)
            keys = sorted(values)

            # compute loss
            loss = [
                self._loss(**dict(zip(keys, p, strict=True)))
                for p in zip(*(values[k] for k in keys), strict=True)
            ]

            # sort
            idx = sorted(range(len(loss)), key=loss.__getitem__)
            values = {k: [v[i] for i in idx] for k, v in values.items()}
            loss = [loss[i] for i in idx]

            if not loss:  # nothing has been explored, there is no best point to display
                return values, loss

            # display the best point
            with Printer(f"Result for loss = {loss[0]:#.4g}:", color="green"):
                lbl_len = max(map(len, values))
                for lbl in keys:
                    best = values[lbl][0]
                    # the accuracy is only known for the predicted labels
                    acc = self._model.accuracy[lbl] if lbl in self._model.accuracy else None
                    std = "" if acc is None else f" \u00B1 {acc:#.4g}"
                    if isinstance(best, float):
                        prt.print(f"{lbl:<{lbl_len}}: {best:#.4g}{std}")
                    else:
                        prt.print(f"{lbl:<{lbl_len}}: {best}{std}")
                    # case log, also display the value in the linear domain
                    if lbl.startswith("log_"):
                        mean = 10.0**best
                        if acc is None:  # no interval can be given without accuracy
                            prt.print(f"{lbl[4:]:<{lbl_len}}: {mean:#.4g}")
                        else:
                            maxi = 10.0**(best + acc)
                            mini = 10.0**(best - acc)
                            prt.print(
                                f"{lbl[4:]:<{lbl_len}}: {mean:#.4g} in [{mini:#.4g}, {maxi:#.4g}]",
                            )
                # the points are sorted, so the first command is the optimal one
                try:
                    cmd = next(iter(self._enc_cmd(video, values)))
                except StopIteration:  # not an encoding case
                    pass
                else:
                    prt.print(cmd)
                    prt.print(
                        f"mendevi probe {shlex.quote(str(cmd.output[-1]))} "
                        f"-r {shlex.quote(str(video))} -d /tmp/solver.db",
                    )
        return values, loss