optimizers/optimizer_factory.py (new executable file, 32 lines added)
@@ -0,0 +1,32 @@
import torch.optim as optim


class OptimizerFactory:
    """Builds a torch optimizer from a config dictionary."""

    @staticmethod
    def create(config, params):
        # "type" is required; all other keys fall back to common defaults.
        optim_type = config["type"]
        lr = config.get("lr", 1e-3)
        if optim_type == "sgd":
            return optim.SGD(
                params,
                lr=lr,
                momentum=config.get("momentum", 0.9),
                weight_decay=config.get("weight_decay", 1e-4),
            )
        elif optim_type == "adam":
            return optim.Adam(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
            )
        else:
            raise NotImplementedError("Unknown optimizer: {}".format(optim_type))


""" ------------ Debug ------------ """
if __name__ == "__main__":
    from configs.config import ConfigManager

    ConfigManager.load_config_with("../configs/local_train_config.yaml")
    ConfigManager.print_config()
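For context, a minimal sketch of how the factory might be wired into a training script. The model, the config values, and the import path are hypothetical; only the keys actually read by OptimizerFactory.create ("type", "lr", "momentum", "weight_decay") are assumed.

import torch.nn as nn

from optimizers.optimizer_factory import OptimizerFactory

# Hypothetical model and optimizer config for illustration only.
model = nn.Linear(10, 2)
optimizer_config = {"type": "sgd", "lr": 0.01, "momentum": 0.9, "weight_decay": 5e-4}

optimizer = OptimizerFactory.create(optimizer_config, model.parameters())
print(optimizer)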