# Source code for sconce.trainers.autoencoder_trainer

from abc import ABC
from sconce.trainer import Trainer

from matplotlib import pyplot as plt


__all__ = ['AutoencoderMixin', 'AutoencoderTrainer']


class AutoencoderMixin(ABC):
    """Mixin that adds autoencoder visualization helpers to a trainer.

    Expects the host class (e.g. a :class:`~sconce.trainer.Trainer`) to
    provide ``self.validation_feed``, ``self.model``, and
    ``self._run_model`` — none of these are defined here.
    """

    def plot_input_output_pairs(self, title='A Sampling of Autoencoder Results',
            num_cols=10, figsize=(15, 3.2)):
        """Plot a 2-row grid comparing validation inputs with reconstructions.

        Top row: the first ``num_cols`` input images of the next validation
        batch.  Bottom row: the model's corresponding outputs.

        :param title: figure suptitle.
        :param num_cols: number of input/output pairs shown (must not exceed
            the validation batch size).
        :param figsize: matplotlib figure size ``(width, height)`` in inches.
        :returns: the matplotlib ``Figure``.
        """
        inputs, targets = self.validation_feed.next()
        # NOTE(review): the model is run with train=True while merely
        # visualizing; this may update stateful layers (e.g. batch-norm
        # running statistics) and differs from plot_latent_space, which
        # switches to eval mode.  Confirm this is intentional.
        out_dict = self._run_model(inputs, targets, train=True)
        outputs = out_dict['outputs']

        fig = plt.figure(figsize=figsize)
        fig.suptitle(title, fontsize=20)

        for i in range(num_cols):
            # assumes single-channel image batches shaped (N, 1, H, W) —
            # TODO confirm against the feeds used with this mixin
            input_image = inputs.data.cpu()[i][0]
            # outputs may come back flattened; view them with the input shape
            output_image = outputs.view_as(inputs).data.cpu()[i][0]

            ax = fig.add_subplot(2, num_cols, i + 1)
            ax.imshow(input_image, cmap='gray')
            if i == 0:
                ax.set_ylabel('Input')
            else:
                ax.axis('off')

            ax = fig.add_subplot(2, num_cols, num_cols + i + 1)
            ax.imshow(output_image, cmap='gray')
            if i == 0:
                ax.set_ylabel('Output')
            else:
                ax.axis('off')

        return fig

    def plot_latent_space(self, title="Latent Representation", figsize=(8, 8)):
        """Scatter-plot the latent encoding of the whole validation set.

        Runs the encoder over every validation batch (in eval mode) and
        plots the first two latent dimensions, colored by target label.

        :param title: figure suptitle.
        :param figsize: matplotlib figure size ``(width, height)`` in inches.
        :returns: the matplotlib ``Figure``.
        """
        fig = plt.figure(figsize=figsize)
        fig.suptitle(title, fontsize=20)

        # Evaluation mode: disable dropout/batch-norm updates while encoding.
        self.model.train(False)

        self.validation_feed.reset()
        for _ in range(len(self.validation_feed)):
            inputs, targets = self.validation_feed.next()
            x_latent = self.model.encode(inputs=inputs, targets=targets)
            x_latent_numpy = x_latent.cpu().data.numpy()
            # assumes a 2-D latent space: columns 0 and 1 are the plot axes —
            # TODO confirm for models with latent_dim != 2
            plt.scatter(x=x_latent_numpy.T[0], y=x_latent_numpy.T[1],
                    c=targets.cpu().data.numpy(), alpha=0.4)
        plt.colorbar()

        return fig
class AutoencoderTrainer(Trainer, AutoencoderMixin):
    """A :class:`~sconce.trainer.Trainer` augmented with the autoencoder
    plotting helpers from :class:`AutoencoderMixin`."""
    pass