-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Saving functionality. #3
Conversation
baselines/ddpg/ddpg.py
Outdated
@@ -42,6 +42,8 @@ def learn(network, env, | |||
tau=0.01, | |||
eval_env=None, | |||
param_noise_adaption_interval=50, | |||
load_path = None, | |||
save_path = '<specify/path>' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missed a comma
baselines/ddpg/ddpg.py
Outdated
@@ -269,5 +274,10 @@ def as_scalar(x): | |||
with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f: | |||
pickle.dump(eval_env.get_state(), f) | |||
|
|||
os.mkdirs(logdir,exist_ok=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Error in python mkdirs does not exist
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Error :
Traceback (most recent call last): File "/usr/home/brahma/anaconda3/envs/torcs/lib/python3.6/runpy.py", line 193, in _run_module_as_main "__main__", mod_spec) File "/usr/home/brahma/anaconda3/envs/torcs/lib/python3.6/runpy.py", line 85, in _run_code exec(code, run_globals) File "/usr/home/brahma/BURIDI/MADRAS/baselines/baselines/run.py", line 226, in <module> main(sys.argv) File "/usr/home/brahma/BURIDI/MADRAS/baselines/baselines/run.py", line 198, in main model, env = train(args, extra_args) File "/usr/home/brahma/BURIDI/MADRAS/baselines/baselines/run.py", line 81, in train **alg_kwargs File "/usr/home/brahma/BURIDI/MADRAS/baselines/baselines/ddpg/ddpg.py", line 98, in learn agent.load(load_path) TypeError: 'NoneType' object is not callable
Recreation:
- Ran the training with save path mentioned it created files with epoch numbers.
- Load any of the file to recreate the error. Loading using this command
python -m baselines.run --alg='ddpg' --env='Madras-v0' --load_path='file_path'
and also tried directly setting the load_path in the code.
Restating some issues here for convenience. @rudrasohan |
Why?
There was no previous option to support the saving of trained networks(ddpg). Hence to add that functionality.
What?
Built a saving function where you can manually specify a dir for the network to save and can also load the desired model back.
Testing
For saving:
Specify the path of the save directory in the field
specify/path
in theddpg.py
. Then run the code as usual.python -m baselines.run --alg='ddpg' --env='Madras-v0'
For loading:
python -m baselines.run --alg='ddpg' --env='Madras-v0' --load_path=/specify/path/to/save/file
NOTE:
The save function takes in a dir whereas the load function takes in a file.
This PR also includes #1 as that will be required for madras-env to be working.