This commit is contained in:
2024-10-09 16:13:22 +00:00
commit 0ea3f048dc
437 changed files with 44406 additions and 0 deletions

0
optimizers/__init__.py Executable file
View File

32
optimizers/optimizer_factory.py Executable file
View File

@@ -0,0 +1,32 @@
import torch.optim as optim
class OptimizerFactory:
@staticmethod
def create(config, params):
optim_type = config["type"]
lr = config.get("lr", 1e-3)
if optim_type == "sgd":
return optim.SGD(
params,
lr=lr,
momentum=config.get("momentum", 0.9),
weight_decay=config.get("weight_decay", 1e-4),
)
elif optim_type == "adam":
return optim.Adam(
params,
lr=lr,
betas=config.get("betas", (0.9, 0.999)),
eps=config.get("eps", 1e-8),
)
else:
raise NotImplementedError("Unknown optimizers: {}".format(optim_type))
""" ------------ Debug ------------ """
if __name__ == "__main__":
from configs.config import ConfigManager
ConfigManager.load_config_with("../configs/local_train_config.yaml")
ConfigManager.print_config()