# Source code for power_cogs.config.base.base_config

from __future__ import annotations

from dataclasses import field
from typing import Any, Dict, List, Optional

from omegaconf import MISSING  # Do not confuse with dataclass.MISSING
from pydantic.dataclasses import dataclass

from power_cogs.config.config_utils import add_configs
from power_cogs.config.torch import AdamConf, DataLoaderConf, ExponentialLRConf
from power_cogs.config.tune.tune_config import TuneConfig

# Base configuration: inherit from the config classes in this module when you
# want to implement your own config.
# DEFAULTS maps each trainer config group to the option used when the caller
# does not override that group.
DEFAULTS = {
    "model_config": "default",
    "dataset_config": "default",
    "dataloader_config": "default",
    "optimizer_config": "default",
    "scheduler_config": "default",
    "logging_config": "default",
    "tune_config": "default",
}


def make_trainer_defaults(overrides=None):
    """Return a defaults list with every missing ``DEFAULTS`` group filled in.

    Each entry of ``overrides`` is expected to be a mapping of the form
    ``{group: option}``; only the first key of an entry is inspected. Groups
    from ``DEFAULTS`` that no entry names are appended, in ``DEFAULTS`` order,
    as ``{group: DEFAULTS[group]}``.

    Args:
        overrides: Optional list of ``{group: option}`` entries. A list that
            is passed in is extended in place and returned. When omitted (or
            ``None``) a brand-new list is created on every call.

    Returns:
        The override entries followed by the default entries that were
        missing from them.

    Raises:
        IndexError: If an entry of ``overrides`` is an empty mapping.
    """
    # A ``[]`` default would be built once at definition time and shared by
    # every call, so all callers would receive (and could corrupt) the very
    # same list object.
    if overrides is None:
        overrides = []

    # Snapshot the overridden groups before extending the list.
    override_keys = {list(entry.keys())[0] for entry in overrides}
    overrides.extend(
        {group: option}
        for group, option in DEFAULTS.items()
        if group not in override_keys
    )
    return overrides
@dataclass
class BaseModelConfig:
    """Base config node for a model; inherit from it to define your own."""

    # No usable default: OmegaConf's MISSING marks the value as mandatory, so
    # a concrete config has to supply it.
    # NOTE(review): presumably the import path of the object to build — confirm.
    _target_: str = MISSING
@dataclass
class BaseDatasetConfig:
    """Base config node for a dataset; inherit from it to define your own."""

    # No usable default: OmegaConf's MISSING marks the value as mandatory, so
    # a concrete config has to supply it.
    # NOTE(review): presumably the import path of the object to build — confirm.
    _target_: str = MISSING
@dataclass
class BaseLoggingConfig:
    """Base config node holding the logging-related paths."""

    # Defaults to the literal path "checkpoints".
    # NOTE(review): presumably where checkpoints are written — confirm.
    checkpoint_path: str = "checkpoints"
    # Unset (None) by default.
    # NOTE(review): presumably the TensorBoard log location — confirm.
    tensorboard_log_path: Optional[str] = None
@dataclass
class BaseTrainerConfig:
    """Base config node for a trainer; inherit from it to define your own.

    ``defaults`` is produced by ``make_trainer_defaults()``. ``_target_`` and
    every ``*_config`` field default to OmegaConf's ``MISSING`` and therefore
    have to be supplied by a concrete config.
    """

    # Defaults list, built by calling make_trainer_defaults with no arguments.
    defaults: List[Any] = field(default_factory=make_trainer_defaults)

    # Mandatory: no default is provided.
    _target_: str = MISSING

    # Scalar trainer options.
    # NOTE(review): meanings are inferred from the field names only — confirm
    # against the trainer implementation.
    name: Optional[str] = None
    pretrained_path: Optional[str] = None
    visualize_output: bool = True
    use_cuda: bool = False
    device_id: int = 0
    early_stoppage: bool = False
    loss_threshold: float = float("-inf")
    batch_size: int = 32
    epochs: int = 100
    checkpoint_interval: int = 100
    num_samples: Optional[int] = None

    # Nested config nodes; all mandatory (MISSING).
    model_config: Any = MISSING
    dataset_config: Any = MISSING
    optimizer_config: Any = MISSING
    scheduler_config: Any = MISSING
    logging_config: Any = MISSING
    dataloader_config: Any = MISSING
    tune_config: Any = MISSING

    # Free-form extra settings; each instance starts with its own empty dict.
    config: Any = field(default_factory=dict)
# Defaults list used by the root ``Config``: pick the "default" trainer option.
config_defaults = [{"trainer": "default"}]


@dataclass
class Config:
    """Root config node: a defaults list plus a mandatory ``trainer`` node."""

    # Hand every instance its own copy of ``config_defaults``. Returning the
    # module-level list itself would let a mutation made through one instance
    # leak into the module constant and into every other instance.
    defaults: List[Any] = field(
        default_factory=lambda: [
            dict(entry) if isinstance(entry, dict) else entry
            for entry in config_defaults
        ]
    )
    # Mandatory: OmegaConf's MISSING marks the value as required.
    trainer: Any = MISSING
# One entry per config node handed to ``add_configs``. Every entry carries a
# ``name`` and a ``node``; all but the root ``Config`` entry also carry a
# ``group``.
config_dicts: List[Dict[str, Any]] = [
    {"group": "trainer/model_config", "name": "default", "node": BaseModelConfig},
    {"group": "trainer/dataset_config", "name": "default", "node": BaseDatasetConfig},
    {"group": "trainer/dataloader_config", "name": "default", "node": DataLoaderConf},
    {"group": "trainer/optimizer_config", "name": "default", "node": AdamConf},
    {"group": "trainer/scheduler_config", "name": "default", "node": ExponentialLRConf},
    {"group": "trainer/logging_config", "name": "default", "node": BaseLoggingConfig},
    {"group": "trainer/tune_config", "name": "default", "node": TuneConfig},
    {"group": "trainer", "name": "default", "node": BaseTrainerConfig},
    {"name": "default", "node": Config},
]

# Import-time side effect: pass the whole table to the project helper.
add_configs(config_dicts)