"""Structured configs for the MNIST model, dataset and trainer (module ``power_cogs.config.mnist_config``)."""

from dataclasses import field
from typing import Any, Dict, List, Optional

from omegaconf import MISSING  # Do not confuse with dataclass.MISSING
from pydantic.dataclasses import dataclass

from power_cogs.config.base import (
    BaseDatasetConfig,
    BaseModelConfig,
    BaseTrainerConfig,
    make_trainer_defaults,
)
from power_cogs.config.config_utils import add_configs


@dataclass
class MNISTModelConfig(BaseModelConfig):
    """Model section of the MNIST configuration.

    Stored in the ``trainer/model_config`` config group under the name
    ``mnist``. Dimensions left as ``None`` are not fixed by this config.
    NOTE(review): presumably they are filled in from the dataset elsewhere;
    confirm.
    """

    # Dotted import path of the model class (Hydra-style ``_target_`` key).
    _target_: str = "power_cogs.model.mnist_model.MNISTModel"

    # Layer sizes; ``None`` means "not set by this config".
    input_dims: Optional[int] = None
    # default_factory gives every instance its own fresh ``[32]`` list.
    hidden_dims: List[int] = field(default_factory=lambda: [32])
    output_dims: Optional[int] = None

    # Dotted path of the activation, presumably applied to the model output
    # and resolved by MNISTModel -- TODO confirm.
    output_activation: str = "torch.nn.functional.relu"

    # Initialisation options. NOTE(review): their exact meaning lives in
    # MNISTModel; the names suggest normal(std=normal_std) weight init and
    # optional zeroed biases.
    use_normal_init: bool = True
    normal_std: float = 0.01
    zero_bias: bool = False
@dataclass
class MNISTDatasetConfig(BaseDatasetConfig):
    """Dataset section of the MNIST configuration.

    Stored in the ``trainer/dataset_config`` config group under the name
    ``mnist``. Every option other than the target comes unchanged from
    ``BaseDatasetConfig``.
    """

    # Dotted import path of the dataset class (Hydra-style ``_target_`` key).
    _target_: str = "power_cogs.dataset.mnist_dataset.MNISTDataset"
# Defaults-list overrides used by the MNIST trainer config: choose the
# ``mnist`` option for both the model_config and dataset_config groups.
trainer_defaults = [
    dict(model_config="mnist"),
    dict(dataset_config="mnist"),
]
def _mnist_trainer_defaults() -> List[Any]:
    """Return the defaults list for :class:`MNISTTrainerConfig`.

    Delegates to ``make_trainer_defaults`` with the module-level MNIST
    overrides. It is called once per config instance because it is used as a
    ``default_factory``.
    """
    return make_trainer_defaults(overrides=trainer_defaults)


@dataclass
class MNISTTrainerConfig(BaseTrainerConfig):
    """Trainer section of the MNIST configuration.

    Stored in the ``trainer`` config group under the name ``mnist``. Its
    defaults list pulls in the ``mnist`` model and dataset configs.
    """

    # Dotted import path of the trainer class (Hydra-style ``_target_`` key).
    _target_: str = "power_cogs.trainer.mnist_trainer.MNISTTrainer"
    defaults: List[Any] = field(default_factory=_mnist_trainer_defaults)
# Top-level defaults list: select the ``mnist`` entry of the trainer group.
config_defaults = [dict(trainer="mnist")]
def _mnist_config_defaults() -> List[Any]:
    """Return a private deep copy of ``config_defaults`` for one MNISTConfig.

    A copy is returned, not the module-level list itself. Mutating one
    instance's ``defaults`` (or a dict entry inside it) therefore cannot
    change ``config_defaults`` or the defaults of any other instance.
    """
    import copy  # local import keeps this fix self-contained

    return copy.deepcopy(config_defaults)


@dataclass
class MNISTConfig:
    """Root MNIST configuration, registered under the name ``mnist`` with no group.

    Attributes:
        defaults: Defaults list. By default it selects the ``mnist`` trainer
            config; each instance gets its own independent copy.
        trainer: Trainer config. It defaults to omegaconf ``MISSING`` (a
            mandatory value), so it must be provided. It is presumably filled
            by composing the defaults list.
    """

    defaults: List[Any] = field(default_factory=_mnist_config_defaults)
    trainer: Any = MISSING
# Each structured config in this module, paired with the config group and name
# it is stored under. The root MNISTConfig has no group. ``add_configs``
# presumably registers them with Hydra's ConfigStore -- see config_utils.
config_dicts: List[Dict[str, Any]] = [
    {"group": "trainer/model_config", "name": "mnist", "node": MNISTModelConfig},
    {"group": "trainer/dataset_config", "name": "mnist", "node": MNISTDatasetConfig},
    {"group": "trainer", "name": "mnist", "node": MNISTTrainerConfig},
    {"name": "mnist", "node": MNISTConfig},
]
add_configs(config_dicts)