Skip to content

Commit

Permalink
default line search to False (can be slow)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Feb 16, 2023
1 parent 1c00ed3 commit 906a2f1
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
11 changes: 7 additions & 4 deletions cca_zoo/models/_stochastic/_eigengame.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
c=0,
nesterov=True,
rho=0.1,
line_search=True,
line_search=False,
):
super().__init__(
latent_dims=latent_dims,
Expand Down Expand Up @@ -140,7 +140,10 @@ def _Aw(self, view, projections):
return view.T @ projections / view.shape[0]

def _Bw(self, view, projection, weight, c):
return (c * weight) + (1 - c) * (view.T @ projection) / projection.shape[0]
if c==1:
return (c * weight)
else:
return (c * weight) + (1 - c) * (view.T @ projection) / projection.shape[0]

def _get_terms(self, i, view, projections, v):
projections.mask[i] = True
Expand Down Expand Up @@ -247,7 +250,7 @@ def __init__(
epochs=1,
learning_rate=1,
nesterov=True,
line_search=True,
line_search=False,
):
super().__init__(
latent_dims=latent_dims,
Expand Down Expand Up @@ -341,7 +344,7 @@ def __init__(
epochs=1,
learning_rate=1,
nesterov=True,
line_search=True,
line_search=False,
):
super().__init__(
latent_dims=latent_dims,
Expand Down
1 change: 1 addition & 0 deletions cca_zoo/models/_stochastic/_ghagep.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
learning_rate=learning_rate,
nesterov=nesterov,
rho=rho,
line_search=False,
c=c
)

Expand Down
2 changes: 1 addition & 1 deletion examples/plot_dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
# %%
# Deep CCA EigenGame
# ----------------------------
dcca_eg = DCCA_EigenGame(latent_dims=LATENT_DIMS, encoders=[encoder_1, encoder_2])
dcca_eg = DCCA_EigenGame(latent_dims=LATENT_DIMS, encoders=[encoder_1, encoder_2],lr=1e-5)
trainer = pl.Trainer(
max_epochs=EPOCHS,
enable_checkpointing=False,
Expand Down

0 comments on commit 906a2f1

Please sign in to comment.