Source code for power_cogs.dataset.mnist_dataset
import typing
import numpy as np
import torch
from sklearn import datasets
# internal
from power_cogs.base.base_torch_dataset import BaseTorchDataset
[docs]class MNISTDataset(BaseTorchDataset):
def __init__(self):
super(MNISTDataset, self).__init__()
data = datasets.load_digits()
self.data = torch.from_numpy(data["data"] / 255.0)
self.targets = torch.from_numpy(data["target"])
self.input_dims = self.data.shape[-1]
self.output_dims = np.unique(self.targets).shape[0]
def __len__(self) -> int:
return self.data.shape[0]
def __getitem__(self, index: int) -> typing.Dict[str, typing.Any]:
return {
"data": self.data[index],
"targets": self.targets[index],
"indices": index,
}
[docs] def to_device(self, device: torch.device):
self.data.to(device)
self.targets.to(device)