Source code for topobenchmarkx.optimizer.optimizer
"""Optimizer class responsible of managing both optimizer and scheduler."""
import functools
from typing import Any
import torch.optim
from .base import AbstractOptimizer
TORCH_OPTIMIZERS = torch.optim.__dict__
TORCH_SCHEDULERS = torch.optim.lr_scheduler.__dict__
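# Both lookups expose the torch namespaces by attribute name, e.g.
# TORCH_OPTIMIZERS["Adam"] is torch.optim.Adam and
# TORCH_SCHEDULERS["StepLR"] is torch.optim.lr_scheduler.StepLR.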


class TBXOptimizer(AbstractOptimizer):
    """Optimizer class that manages both optimizer and scheduler, fully compatible with `torch.optim` classes.

    Parameters
    ----------
    optimizer_id : str
        Name of the torch optimizer class to be used.
    parameters : dict
        Parameters to be passed to the optimizer.
    scheduler : dict, optional
        Scheduler id and parameters to be used. Default is None.
    """

    def __init__(self, optimizer_id, parameters, scheduler=None) -> None:
        # Bind the optimizer class and its hyperparameters now; the model
        # parameters are supplied later, when the partial is called.
        self.optimizer = functools.partial(
            TORCH_OPTIMIZERS[optimizer_id], **parameters
        )
        if scheduler is not None:
            # Same idea for the scheduler: the optimizer instance is supplied later.
            scheduler_id = scheduler.get("scheduler_id")
            scheduler_params = scheduler.get("scheduler_params")
            self.scheduler = functools.partial(
                TORCH_SCHEDULERS[scheduler_id], **scheduler_params
            )
        else:
            self.scheduler = None

    def __repr__(self) -> str:
        if self.scheduler is not None:
            return f"{self.__class__.__name__}(optimizer={self.optimizer.func.__name__}, scheduler={self.scheduler.func.__name__})"
        else:
            return f"{self.__class__.__name__}(optimizer={self.optimizer.func.__name__})"