Source code for power_cogs.model.mnist_model
from typing import List, Optional
import torch
from power_cogs.base import BaseTorchModel
# internal
from power_cogs.utils.torch_utils import create_linear_network
class MNISTModel(BaseTorchModel):
    """Model wrapping a network built by ``create_linear_network``.

    The network is created from ``input_dims``, ``hidden_dims`` and
    ``output_dims``; an optional activation is applied to its output in
    :meth:`forward`.

    Args:
        input_dims: Input size passed to ``create_linear_network``.
        hidden_dims: Hidden layer sizes passed to ``create_linear_network``.
            Defaults to ``[32]`` when ``None``.
        output_dims: Output size passed to ``create_linear_network``.
        output_activation: Optional Python expression (as a string) that
            evaluates to a callable applied to the network output, e.g.
            ``"torch.sigmoid"``. When ``None`` no activation is set here.
        use_normal_init: If true, re-initialise matching layers from a
            zero-mean normal distribution.
        normal_std: Standard deviation used for the normal initialisation.
        zero_bias: If true, biases of matching layers are set to zero
            instead of being drawn from the normal distribution.
    """

    def __init__(
        self,
        input_dims: int = 64,
        hidden_dims: Optional[List[int]] = None,
        output_dims: int = 10,
        output_activation: Optional[str] = None,
        use_normal_init: bool = True,
        normal_std: float = 0.01,
        zero_bias: bool = False,
    ):
        super().__init__()

        # A fresh list per instance: a literal list default would be shared
        # by every model constructed without an explicit ``hidden_dims``.
        if hidden_dims is None:
            hidden_dims = [32]

        self.input_shape = input_dims
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims

        if output_activation is not None:
            # SECURITY NOTE(review): ``eval`` executes arbitrary code. Only
            # pass trusted, developer-controlled strings here; never forward
            # user- or network-supplied input to this parameter.
            self.output_activation = eval(output_activation)

        self.net = create_linear_network(input_dims, hidden_dims, output_dims)

        def init_weights(m):
            # Previously only ``Conv3d`` was matched, which left the layers of
            # this network untouched and made ``use_normal_init`` a no-op.
            # Presumably ``create_linear_network`` builds ``torch.nn.Linear``
            # layers -- TODO confirm. ``Conv3d`` is kept so any existing
            # behaviour that relied on it is preserved.
            if isinstance(m, (torch.nn.Linear, torch.nn.Conv3d)):
                torch.nn.init.normal_(m.weight, std=normal_std)
                if getattr(m, "bias", None) is not None:
                    if zero_bias:
                        torch.nn.init.zeros_(m.bias)
                    else:
                        torch.nn.init.normal_(m.bias, std=normal_std)

        if use_normal_init:
            with torch.no_grad():
                self.apply(init_weights)

    def forward(self, x):
        """Run ``x`` through the network and the optional output activation.

        Args:
            x: Input passed directly to ``self.net``.

        Returns:
            The network output, passed through ``self.output_activation``
            when one is available; otherwise the raw network output.
        """
        x = self.net(x)
        # ``output_activation`` is only assigned in ``__init__`` when an
        # activation string was supplied, so look it up defensively instead
        # of raising ``AttributeError`` for the default configuration.
        activation = getattr(self, "output_activation", None)
        if activation is None:
            return x
        return activation(x)