mendevi.models.base.to_torch

mendevi.models.base.to_torch(data: dict[str, list], dtype: dtype = torch.float32) → dict[str, Tensor]

Convert the fields into a tensor-compatible encoding.

Cast lists of numbers into floating-point vectors and encode labels as one-hot vectors.

Returns

dict[str, torch.Tensor]

For each input field, the associated torch tensor of shape (n, k), where n is the number of elements and k is the encoding dimension. k is 1 for lists of numbers; for labels, k is the number of possible labels for that field.
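The (n, k) shapes described above can be illustrated with a minimal sketch in plain PyTorch. This is not the implementation of to_torch: the helper names encode_numbers and encode_labels and the EFFORTS vocabulary are hypothetical, and the real library defines its own label sets and ordering.

```python
import torch
import torch.nn.functional as F

# Hypothetical label vocabulary, for illustration only;
# the library defines its own label sets and ordering.
EFFORTS = ["fast", "medium", "slow"]


def encode_numbers(values: list[float], dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Return an (n, 1) column tensor built from a list of numbers."""
    return torch.tensor(values, dtype=dtype).unsqueeze(1)


def encode_labels(
    labels: list[str], vocab: list[str], dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Return an (n, k) one-hot tensor, where k = len(vocab)."""
    # Map each label to its position in the vocabulary.
    indices = torch.tensor([vocab.index(label) for label in labels], dtype=torch.long)
    # one_hot yields an integer tensor, so cast it to the requested dtype.
    return F.one_hot(indices, num_classes=len(vocab)).to(dtype)


scalars = encode_numbers([0.0, 1.0, 2.0])
efforts = encode_labels(["fast", "medium", "slow"], EFFORTS)
print(scalars.shape)  # torch.Size([3, 1])
print(efforts.shape)  # torch.Size([3, 3])
```

In this sketch k comes from the size of the vocabulary rather than from the labels present in the input, so a field with two input labels can still produce more than two columns.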

Examples

>>> from pprint import pprint
>>> from mendevi.models.base import to_torch
>>> data = {
...     "effort": ["fast", "medium", "slow"],
...     "encoder": ["libx264", "libsvtav1"],
...     "profile": ["sd", "hd", "fhd"],
...     "scalar": [0.0, 1.0, 2.0],
... }
>>> pprint(to_torch(data))
{'effort': tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]]),
 'encoder': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]),
 'profile': tensor([[0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.]]),
 'scalar': tensor([[0.],
        [1.],
        [2.]])}