diff --git a/cca_zoo/linear/_gradient/_base.py b/cca_zoo/linear/_gradient/_base.py index 6c4d3769..c9b0b373 100644 --- a/cca_zoo/linear/_gradient/_base.py +++ b/cca_zoo/linear/_gradient/_base.py @@ -20,7 +20,7 @@ DEFAULT_LOADER_KWARGS = dict(pin_memory=False, drop_last=False, shuffle=True) -DEFAULT_OPTIMIZER_KWARGS = dict(optimizer="SGD", momentum=0.9, nesterov=True) +DEFAULT_OPTIMIZER_KWARGS = dict(optimizer="Adam") class BaseGradientModel(BaseModel, pl.LightningModule):