mendevi.models.base.to_torch
- mendevi.models.base.to_torch(data: dict[str, list], dtype: torch.dtype = torch.float32) → dict[str, torch.Tensor]
Convert the fields into a tensor-compatible encoding.
Lists of numbers are cast to floating-point vectors, and lists of labels are one-hot encoded.
Returns
- dict[str, torch.Tensor]
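For each input field, a torch matrix of shape (n, k), where n is the number of elements in that field and k is the encoding dimension. k is 1 for lists of numbers; for lists of labels, k is the number of possible labels for that field. This can exceed the number of distinct labels actually present in the input: in the example below, two encoder values produce 15 columns.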
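The shape rule above can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not the mendevi implementation: `encode_field` is a hypothetical helper, nested lists stand in for tensors, and the profile vocabulary `["fhd", "hd", "sd", "uhd"]` is an assumption chosen so that the result matches the `profile` entry of the example. The label sets mendevi actually uses are not documented here.

```python
from __future__ import annotations

from typing import Optional, Sequence, Union

Value = Union[int, float, str]


def encode_field(
    values: Sequence[Value], vocabulary: Optional[Sequence[str]] = None
) -> list[list[float]]:
    """Hypothetical sketch: encode one field as an (n, k) nested list.

    Numbers -> (n, 1) column of floats.
    Labels  -> (n, k) one-hot rows, with k = len(vocabulary).
    """
    # Numeric field: one float per row, so k = 1.
    if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
        return [[float(v)] for v in values]

    # Assumption: without an explicit vocabulary, fall back to the sorted
    # distinct labels found in the data (k = number of distinct labels).
    if vocabulary is None:
        vocabulary = sorted(set(values))

    # Map each label to its column index in the vocabulary.
    index = {label: i for i, label in enumerate(vocabulary)}
    rows = []
    for v in values:
        row = [0.0] * len(vocabulary)  # k zeros
        row[index[v]] = 1.0  # set the label's column
        rows.append(row)
    return rows


# Hypothetical 4-label vocabulary: k = 4 even though only 3 labels appear.
print(encode_field(["sd", "hd", "fhd"], ["fhd", "hd", "sd", "uhd"]))
print(encode_field([0.0, 1.0, 2.0]))
```

Passing an explicit vocabulary is what lets k stay fixed across calls, which matches the `encoder` behaviour in the example (two values, 15 columns). Building the vocabulary from the data alone would make k depend on which labels happen to be present.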
Examples
>>> from pprint import pprint
>>> from mendevi.models.base import to_torch
>>> data = {
...     "effort": ["fast", "medium", "slow"],
...     "encoder": ["libx264", "libsvtav1"],
...     "profile": ["sd", "hd", "fhd"],
...     "scalar": [0.0, 1.0, 2.0],
... }
>>> pprint(to_torch(data))
{'effort': tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]]),
 'encoder': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]),
 'profile': tensor([[0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.]]),
 'scalar': tensor([[0.],
        [1.],
        [2.]])}