transforms

Normalize

class Normalize(mean, std)

Normalize a tensor or array using a fixed mean and std, computing (x - mean) / std.

Parameters
  • mean (Union[Tensor, ndarray, int, float]) – may be a float, tensor or array. If it has dimensions (such as channels for images), its shape must match the shape of std.

  • std (Union[Tensor, ndarray, int, float]) – may be a float, tensor or array. If it has dimensions (such as channels for images), its shape must match the shape of mean.
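A minimal NumPy sketch of what the transform computes, assuming it is plain (x - mean) / std with standard broadcasting; this assumption reproduces the example outputs below. `normalize` is a hypothetical helper for illustration, not hearth's implementation, and the shape check is one reading of the parameter constraint above.

```python
import numpy as np


def normalize(x, mean, std):
    """Hypothetical stand-in for Normalize: returns (x - mean) / std."""
    x = np.asarray(x, dtype=np.float64)
    mean = np.asarray(mean, dtype=np.float64)
    std = np.asarray(std, dtype=np.float64)
    # Assumed constraint: if both statistics carry dimensions, their shapes must agree.
    if mean.ndim and std.ndim and mean.shape != std.shape:
        raise ValueError(f"mean shape {mean.shape} does not match std shape {std.shape}")
    # Broadcasting aligns trailing axes, so (3,) statistics apply per channel
    # to the last axis of an (H, W, 3) input.
    return (x - mean) / std


# Scalar statistics, as in the first example.
flat = normalize(np.linspace(0, 3, 5), 1.5, 1.1859)

# Per-channel statistics for a (4, 4, 3) input.
x = np.linspace(0, 16, 48).reshape(4, 4, 3)
y = normalize(x, [7.6596, 8.0, 8.3404], [4.8622, 4.8622, 4.8622])
print(flat.round(4), y.shape)
```

Because the channel axis is last, no reshaping of mean or std is needed; statistics for a channels-first layout would instead need trailing singleton axes, e.g. shape (3, 1, 1).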

Example

>>> import torch
>>> from hearth.data.transforms import Normalize
>>>
>>> transform = Normalize(mean=1.5, std=1.1859)
>>> x = torch.linspace(0, 3, 5)
>>> transform(x)
tensor([-1.2649, -0.6324,  0.0000,  0.6324,  1.2649])
>>> channel_transform = Normalize(mean=torch.tensor([7.6596, 8.0000, 8.3404]),
...                                 std=torch.tensor([4.8622, 4.8622, 4.8622]))
>>> x = torch.linspace(0, 16, 48).reshape(4, 4, 3)
>>> y = channel_transform(x)
>>> y.shape
torch.Size([4, 4, 3])
>>> torch.allclose(y.mean(dim=(0, 1)), torch.zeros(3), atol=1e-4)
True
>>> y.std(dim=(0, 1))
tensor([1.0000, 1.0000, 1.0000])
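The per-channel statistics in the example appear to be the mean and sample standard deviation of x over its first two axes. The sketch below, in NumPy for brevity, shows one way to derive such values from data; the `ddof=1` choice is an assumption made to match the unbiased default of `torch.std`.

```python
import numpy as np

# Same synthetic (height, width, channels) input as the example above.
x = np.linspace(0, 16, 48).reshape(4, 4, 3)

# Reduce over every axis except the trailing channel axis.
mean = x.mean(axis=(0, 1))
# ddof=1 gives the unbiased (sample) estimate, like torch.std's default.
std = x.std(axis=(0, 1), ddof=1)

print(np.round(mean, 4))
print(np.round(std, 4))
```

Since mean and std accept ndarray values, arrays computed this way could be passed as `Normalize(mean=mean, std=std)`.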

Pipeline

class Pipeline(*transforms)

Pipeline applies a chain of transforms to an input in order, passing each transform's output to the next.
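A minimal pure-Python sketch of this chaining behaviour, assuming a pipeline is simply an ordered list of callables. `SimplePipeline`, `to_floats` and `standardize` are hypothetical names used for illustration; they are not hearth's implementation, and the two functions only stand in for Tensorize and Normalize.

```python
from typing import Any, Callable, List


class SimplePipeline:
    """Hypothetical sketch: apply callables left to right."""

    def __init__(self, *transforms: Callable[[Any], Any]):
        self.transforms: List[Callable[[Any], Any]] = list(transforms)

    def __call__(self, x: Any) -> Any:
        # Each transform receives the previous transform's output.
        for transform in self.transforms:
            x = transform(x)
        return x

    def __len__(self) -> int:
        return len(self.transforms)

    def __repr__(self) -> str:
        inner = ", ".join(repr(t) for t in self.transforms)
        return f"{type(self).__name__}({inner})"


def to_floats(values):
    # Stand-in for Tensorize: coerce every element to float.
    return [float(v) for v in values]


def standardize(values, mean=-0.34, std=1.75):
    # Stand-in for Normalize, using the statistics from the example below.
    return [(v - mean) / std for v in values]


pipeline = SimplePipeline(to_floats, standardize)
result = pipeline([-3.0, -1.5, 0.3, 0.4, 2.1])
print(len(pipeline), [round(r, 4) for r in result])
```

With no transforms, the loop never runs and the input is returned unchanged, so an empty pipeline behaves as the identity.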

Example

>>> import torch
>>> import numpy as np
>>> from hearth.data.transforms import Normalize, Tensorize, Pipeline
>>>
>>> pipeline = Pipeline(Tensorize(dtype='float32'), Normalize(mean=-0.34, std=1.75))
>>> pipeline
Pipeline(Tensorize(dtype=torch.float32, device=cpu), Normalize(mean=-0.34, std=1.75))
>>> len(pipeline)
2
>>> x = np.array([-3.0, -1.5, .3, .4, 2.1])
>>> pipeline(x)
tensor([-1.5200, -0.6629,  0.3657,  0.4229,  1.3943])

Tensorize

class Tensorize(dtype=None, device='cpu')

Convert the given input to a tensor, with an optional dtype and device.

Parameters
  • dtype (Union[str, dtype, None]) – an optional dtype, given as a string such as 'float32' or as a torch.dtype. Defaults to None.

  • device (Union[str, device]) – the device to place the tensor on, given as a string or a torch.device. Defaults to ‘cpu’.
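A hedged sketch of the conversion using PyTorch directly. It assumes string dtypes are resolved by looking up the attribute of the same name on the `torch` module (the Pipeline example's repr shows 'float32' becoming torch.float32); `tensorize` is a hypothetical helper, not hearth's code.

```python
from typing import Any, Optional, Union

import torch


def tensorize(
    x: Any,
    dtype: Optional[Union[str, torch.dtype]] = None,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    """Hypothetical stand-in for Tensorize: convert x to a tensor."""
    if isinstance(dtype, str):
        # Assumed resolution: "float32" -> torch.float32.
        resolved = getattr(torch, dtype, None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"unknown dtype name: {dtype!r}")
        dtype = resolved
    # torch.as_tensor may share memory with x instead of copying when
    # the dtype and device already match.
    return torch.as_tensor(x, dtype=dtype, device=device)


t = tensorize([1.1, 2.2, 3.3], dtype="float32")
print(t.dtype, t.device)
```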

Example

>>> import torch
>>> from hearth.data.transforms import Tensorize
>>>
>>> transform = Tensorize(dtype='float32')
>>> transform([1.1, 2.2, 3.3])
tensor([1.1000, 2.2000, 3.3000])

Transform

class Transform(*args, **kwds)

Abstract base class for all transforms.
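This section does not document the base-class interface. As an illustration only, an abstract transform in plain Python could declare an abstract `__call__` using the standard `abc` module; `BaseTransform` and `AddOne` are hypothetical names, and hearth's actual Transform may define a different interface.

```python
import abc
from typing import Any


class BaseTransform(abc.ABC):
    """Hypothetical sketch of an abstract transform base class."""

    @abc.abstractmethod
    def __call__(self, x: Any) -> Any:
        """Apply the transform to x and return the result."""

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"


class AddOne(BaseTransform):
    """Hypothetical concrete transform used for illustration."""

    def __call__(self, x: float) -> float:
        return x + 1


print(AddOne(), AddOne()(1))
```

Declaring `__call__` abstract means the base class itself cannot be instantiated; only subclasses that implement it can.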