Skip to content

Commit

Permalink
adding load model
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Schmidt committed Sep 5, 2019
1 parent dacd248 commit 0f33437
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
8 changes: 7 additions & 1 deletion mnist_cnn/mnist_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,14 @@ def get_model(data, lr=1e-1):
run.fit(4000, learn)

# Evaluate model
from inspection import evaluate
evaluate(valid_ds, learn.model)

# Save model
state = learn.model.state_dict()
torch.save(state, './mnist_cnn_small_1')
torch.save(state, './mnist_cnn_small_1.model')

# +
# Load model
# m = learn.model
# m.load_state_dict((torch.load('./mnist_cnn_small_1.model')))
14 changes: 12 additions & 2 deletions mnist_cnn/mnist_cnn_big.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,18 @@ def get_model(data, lr=1e-1): #1e-1
run.fit(4000, learn)

# Evaluate model
from inspection import evaluate
evaluate(valid_ds, learn.model)

# +
# Save model
state = learn.model.state_dict()
torch.save(state, './mnist_cnn_small_1')
# state = learn.model.state_dict()
# torch.save(state, './mnist_cnn_big_1.model')

# +
# Load model
# m = learn.model
# m.load_state_dict((torch.load('./mnist_cnn_big_1.model')))
# -


0 comments on commit 0f33437

Please sign in to comment.