power_cogs.tune package

Submodules

power_cogs.tune.tune_wrapper module

class power_cogs.tune.tune_wrapper.CustomMLflowLoggerCallback(*args, **kwargs)[source]

Bases: ray.tune.integration.mlflow.MLflowLoggerCallback

on_trial_save(iteration: int, trials: List[ray.tune.trial.Trial], trial: ray.tune.trial.Trial, **info)[source]

Called after receiving a checkpoint from a trial.

Arguments:
iteration (int): Number of iterations of the tuning loop.
trials (List[Trial]): List of trials.
trial (Trial): Trial that just saved a checkpoint.
**info: Kwargs dict for forward compatibility.
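For illustration, a minimal sketch of extending this hook in a subclass, in the spirit of CustomMLflowLoggerCallback; the subclass name and the extra iteration bookkeeping are hypothetical, not part of power_cogs.

from typing import List

from ray.tune.integration.mlflow import MLflowLoggerCallback
from ray.tune.trial import Trial

class CheckpointLoggingCallback(MLflowLoggerCallback):
    def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial, **info):
        # Delegate the standard MLflow bookkeeping to the parent class.
        super().on_trial_save(iteration, trials, trial, **info)
        # Hypothetical extra step: record which iteration produced a checkpoint.
        print(f"trial {trial.trial_id} saved a checkpoint at iteration {iteration}")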
class power_cogs.tune.tune_wrapper.TuneTrainer(config=None, logger_creator=None)[source]

Bases: ray.tune.trainable.Trainable

load_checkpoint(tmp_checkpoint_dir)[source]

Subclasses should override this to implement restore().

Warning:
In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed.

If Trainable.save_checkpoint returned a prefixed string, the prefix of that checkpoint string may be changed. This is because trial pausing depends on temporary directories.

The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.

See the example below.

class Example(Trainable):
    def save_checkpoint(self, checkpoint_path):
        print(checkpoint_path)
        return os.path.join(checkpoint_path, "my/check/point")

    def load_checkpoint(self, checkpoint):
        print(checkpoint)

>>> trainer = Example()
>>> obj = trainer.save_to_object()  # This is used when PAUSED.
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
>>> trainer.restore_from_object(obj)  # Note the different prefix.
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point

New in version 0.8.7.

Args:
checkpoint (str|dict): If a dict, the value is as returned by save_checkpoint. If a string, it is a checkpoint path that may have a different prefix than the one returned by save_checkpoint. The directory structure underneath the checkpoint_dir provided to save_checkpoint is preserved.
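As a sketch, a load_checkpoint override that handles both forms might look like the following; the self.state attribute and the pickle format are illustrative assumptions (they pair with the save_checkpoint sketch further below).

import pickle

from ray.tune import Trainable

class MyTrainable(Trainable):
    def load_checkpoint(self, checkpoint):
        if isinstance(checkpoint, dict):
            # Dict form: exactly the value returned by save_checkpoint.
            self.state = checkpoint["state"]
        else:
            # String form: the (possibly re-prefixed) path returned by
            # save_checkpoint; the structure under the checkpoint dir
            # is preserved.
            with open(checkpoint, "rb") as f:
                self.state = pickle.load(f)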
save_checkpoint(tmp_checkpoint_dir)[source]

Subclasses should override this to implement save().

Warning:
Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.

Use validate_save_restore to catch Trainable.save_checkpoint/ Trainable.load_checkpoint errors before execution.

>>> from ray.tune.utils import validate_save_restore
>>> validate_save_restore(MyTrainableClass)
>>> validate_save_restore(MyTrainableClass, use_object_store=True)

New in version 0.8.7.

Args:
tmp_checkpoint_dir (str): The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved.

Returns:
A dict or string. If a string, the return value is expected to be prefixed by tmp_checkpoint_dir. If a dict, the return value will be automatically serialized by Tune and passed to Trainable.load_checkpoint().
Examples:
>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
{"some": "data"}
>>> trainable.save_checkpoint("/tmp/bad_example")
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
setup(config)[source]

Subclasses should override this for custom initialization.

New in version 0.8.7.

Args:
config (dict): Hyperparameters and other configs given. Copy of self.config.
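A minimal sketch of a setup override follows; the config key "lr" is illustrative, not a key that TuneTrainer requires.

from ray.tune import Trainable

class MyTrainable(Trainable):
    def setup(self, config):
        # `config` is a copy of self.config with the hyperparameters
        # sampled for this trial.
        self.lr = config.get("lr", 1e-3)
        self.timestep = 0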
step()[source]

Subclasses should override this to implement train().

The return value will be automatically passed to the loggers. Users can also return tune.result.DONE or tune.result.SHOULD_CHECKPOINT as a key to manually trigger termination or checkpointing of this trial. Note that manual checkpointing only works when subclassing Trainables.

New in version 0.8.7.

Returns:
A dict that describes training progress.
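A minimal sketch of a step override, including the optional manual-termination key; the placeholder loss and the 100-step cutoff are illustrative.

from ray.tune import Trainable
from ray.tune.result import DONE

class MyTrainable(Trainable):
    def setup(self, config):
        self.timestep = 0

    def step(self):
        self.timestep += 1
        result = {"loss": 1.0 / self.timestep}  # Placeholder metric.
        if self.timestep >= 100:
            # Manually request termination of this trial.
            result[DONE] = True
        return result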
class power_cogs.tune.tune_wrapper.TuneWrapper(config: Any, trainer_config: Any, trainer_overrides: Any)[source]

Bases: object

tune()[source]

power_cogs.tune.tune_wrapper.create_stopper(config)[source]
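Since tune() and create_stopper() carry no docstrings here, the following usage sketch is an assumption based only on the signatures above; the expected dict keys are project-specific and left empty.

from power_cogs.tune.tune_wrapper import TuneWrapper

# The expected contents of these dicts are not documented here; fill
# them in from the project's own configs.
config = {}
trainer_config = {}
trainer_overrides = {}

wrapper = TuneWrapper(config, trainer_config, trainer_overrides)
wrapper.tune()  # Launches the Ray Tune run wrapped by this class.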

Module contents