# Source code for power_cogs.config.base.base_config
from __future__ import annotations
from dataclasses import field
from typing import Any, Dict, List, Optional
from omegaconf import MISSING # Do not confuse with dataclass.MISSING
from pydantic.dataclasses import dataclass
from power_cogs.config.config_utils import add_configs
from power_cogs.config.torch import AdamConf, DataLoaderConf, ExponentialLRConf
from power_cogs.config.tune.tune_config import TuneConfig
# Base config module: inherit from these config classes to implement your own config.
# Default Hydra config-group choices for a trainer: each entry maps a
# config group name to the "default" option registered for that group.
DEFAULTS = {
    "model_config": "default",
    "dataset_config": "default",
    "dataloader_config": "default",
    "optimizer_config": "default",
    "scheduler_config": "default",
    "logging_config": "default",
    "tune_config": "default",
}


def make_trainer_defaults(overrides=None):
    """Build a Hydra defaults list for a trainer config.

    Args:
        overrides: Optional list of single-entry ``{group: option}`` dicts
            that take precedence over :data:`DEFAULTS`. When a list is
            passed it is extended in place (preserving the original
            behavior for explicit callers). When omitted, a fresh list is
            created per call — the original signature used a mutable
            default argument (``overrides=[]``), so all no-arg callers
            shared (and aliased) one list object.

    Returns:
        list: ``overrides`` extended with a ``{group: "default"}`` entry
        for every group in :data:`DEFAULTS` not already overridden.
    """
    if overrides is None:
        overrides = []
    # Each override dict is expected to hold exactly one {group: option}
    # pair; collect the group names the caller already chose.
    override_keys = [next(iter(o)) for o in overrides]
    for group, option in DEFAULTS.items():
        if group not in override_keys:
            overrides.append({group: option})
    return overrides
# NOTE: the original scraped line carried a Sphinx "[docs]" artifact fused
# onto the decorator, which made the module unparseable; removed here.
@dataclass
class BaseModelConfig:
    """Base config for a model; ``_target_`` names the class to instantiate."""

    # Fully-qualified import path of the model class; MISSING means the
    # concrete config must supply it.
    _target_: str = MISSING
# NOTE: the original scraped line carried a Sphinx "[docs]" artifact fused
# onto the decorator, which made the module unparseable; removed here.
@dataclass
class BaseDatasetConfig:
    """Base config for a dataset; ``_target_`` names the class to instantiate."""

    # Fully-qualified import path of the dataset class; MISSING means the
    # concrete config must supply it.
    _target_: str = MISSING
# NOTE: the original scraped line carried a Sphinx "[docs]" artifact fused
# onto the decorator, which made the module unparseable; removed here.
@dataclass
class BaseLoggingConfig:
    """Base config for where training artifacts are written."""

    # Directory for model checkpoints (presumably relative to the run
    # directory — confirm against the trainer that consumes this).
    checkpoint_path: str = "checkpoints"
    # Tensorboard log directory; None appears to mean "no explicit path"
    # — verify how the trainer interprets it.
    tensorboard_log_path: Optional[str] = None
# NOTE: the original scraped line carried a Sphinx "[docs]" artifact fused
# onto the decorator, which made the module unparseable; removed here.
@dataclass
class BaseTrainerConfig:
    """Base config for trainers; concrete trainer configs inherit from this."""

    # Hydra defaults list. Pass an explicit fresh list into
    # make_trainer_defaults so every instance gets its own list object
    # (calling it with no args hits a shared mutable default argument).
    defaults: List[Any] = field(default_factory=lambda: make_trainer_defaults([]))
    # Fully-qualified import path of the trainer class to instantiate.
    _target_: str = MISSING
    # Optional human-readable run name.
    name: Optional[str] = None
    # Path to a pretrained checkpoint to load, if any.
    pretrained_path: Optional[str] = None
    visualize_output: bool = True
    use_cuda: bool = False
    device_id: int = 0
    # Early stopping: with the -inf default threshold, stoppage is
    # effectively disabled until a real threshold is configured.
    early_stoppage: bool = False
    loss_threshold: float = -float("inf")
    batch_size: int = 32
    epochs: int = 100
    checkpoint_interval: int = 100
    num_samples: Optional[int] = None
    # Nested config groups, filled in by Hydra via the defaults list.
    model_config: Any = MISSING
    dataset_config: Any = MISSING
    optimizer_config: Any = MISSING
    scheduler_config: Any = MISSING
    logging_config: Any = MISSING
    dataloader_config: Any = MISSING
    tune_config: Any = MISSING
    # Free-form extra config; default_factory gives each instance its own dict.
    config: Any = field(default_factory=lambda: {})
config_defaults = [{"trainer": "default"}]
[docs]@dataclass
class Config:
defaults: List[Any] = field(default_factory=lambda: config_defaults)
trainer: Any = MISSING
# Registration table for Hydra's ConfigStore: trainer sub-configs live
# under the "trainer/<group>" groups, followed by the trainer node itself
# and the root config.
config_dicts: List[Dict[str, Any]] = [
    {"group": "trainer/model_config", "name": "default", "node": BaseModelConfig},
    {"group": "trainer/dataset_config", "name": "default", "node": BaseDatasetConfig},
    {"group": "trainer/dataloader_config", "name": "default", "node": DataLoaderConf},
    {"group": "trainer/optimizer_config", "name": "default", "node": AdamConf},
    {"group": "trainer/scheduler_config", "name": "default", "node": ExponentialLRConf},
    {"group": "trainer/logging_config", "name": "default", "node": BaseLoggingConfig},
    {"group": "trainer/tune_config", "name": "default", "node": TuneConfig},
    {"group": "trainer", "name": "default", "node": BaseTrainerConfig},
    {"name": "default", "node": Config},
]

# Register every node above at import time.
add_configs(config_dicts)