Source code for power_cogs.trainer.mnist_trainer

import attr
import numpy as np
import torch.nn.functional as F

# internal
from power_cogs.base.base_torch_trainer import BaseTorchTrainer


@attr.s
class MNISTTrainer(BaseTorchTrainer):
    """Trainer for MNIST-style classification using a cross-entropy objective.

    Copies the dataset's input/output dimensions into the model config and,
    on each ``train_iter`` call, runs one full pass over ``self.dataloader``,
    returning summary statistics of the per-batch losses.
    """

    # Plain class attribute (``@attr.s`` without ``attr.ib`` does not turn it
    # into an attrs field); presumably used by BaseTorchTrainer to locate the
    # config — confirm.
    _config_name_: str = "mnist"

    def post_dataset_setup(self):
        """Propagate the dataset's input/output dimensions into the model config.

        Writes ``self.dataset.input_dims`` and ``self.dataset.output_dims``
        into ``self.model_config`` under the ``"input_dims"`` and
        ``"output_dims"`` keys.
        """
        # NOTE(review): presumably called by BaseTorchTrainer after the
        # dataset is built and before the model is constructed — confirm.
        self.model_config["input_dims"] = self.dataset.input_dims
        self.model_config["output_dims"] = self.dataset.output_dims

    def train_iter(self, batch_size: int = 32, iteration: int = 0):
        """Run one pass over ``self.dataloader`` and return loss statistics.

        For every batch: zero the gradients, run the model on
        ``sample["data"]`` cast to float, compute cross-entropy against
        ``sample["targets"]``, backpropagate, then step the optimizer and the
        scheduler.

        Args:
            batch_size: Not used by this implementation (kept for interface
                compatibility; presumably the BaseTorchTrainer signature).
            iteration: Not used by this implementation (same as above).

        Returns:
            A dict with keys ``"out"`` (always ``None``), ``"metrics"``
            (min/max/mean/sum/median of the per-batch losses plus ``"loss"``,
            which equals the mean) and ``"loss"`` (the mean loss).

        Raises:
            ValueError: If the dataloader yields no batches, so no loss
                statistics can be computed.
        """
        losses = []
        for sample in self.dataloader:
            self.optimizer.zero_grad()
            data = sample["data"].float()
            targets = sample["targets"]
            out = self.model(data)
            loss = F.cross_entropy(out, targets)
            loss.backward()
            self.optimizer.step()
            # The scheduler is stepped once per batch, not once per pass.
            self.scheduler.step()
            losses.append(loss.item())

        if not losses:
            # Without this guard np.min/np.max raise an opaque
            # "zero-size array" ValueError (np.mean/np.median also warn).
            raise ValueError(
                "MNISTTrainer.train_iter: dataloader yielded no batches; "
                "cannot compute loss statistics."
            )

        loss_arr = np.asarray(losses)
        mean_loss = loss_arr.mean()
        train_dict = {
            "out": None,
            "metrics": {
                "loss": mean_loss,
                "min_loss": loss_arr.min(),
                "max_loss": loss_arr.max(),
                "mean_loss": mean_loss,
                "sum_loss": loss_arr.sum(),
                "median_loss": np.median(loss_arr),
            },
            "loss": mean_loss,
        }
        return train_dict