-
Notifications
You must be signed in to change notification settings - Fork 686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Type hints #293
base: master
Are you sure you want to change the base?
Type hints #293
Conversation
The latest updates on your projects. Learn more about Vercel for Git ↗︎
|
@@ -159,15 +166,24 @@ def get_action_and_value(self, x, action=None): | |||
# env setup | |||
envs = gym.vector.SyncVectorEnv( | |||
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] | |||
) | |||
) # type:ignore[abstract] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SyncVectorEnv inherits from VectorEnv which inherits from Env. For older gym versions (I'm currently on 0.23.1), Env is an ABC with abstract method render
that is not overriden by any of the vector envs. Since it's fixed in the newest gym release and not our issue, I'm ignoring here.
# Handling gym shapes being Optionals (variant 1) | ||
# Personally i'd prefer the asserts | ||
assert isinstance(envs.single_observation_space.shape, tuple), "shape of observation space must be defined" | ||
assert isinstance(envs.single_action_space.shape, tuple), "shape of action space must be defined" | ||
|
||
# Handling gym shapes being Optionals (variant 2) | ||
# Once could also cast inside each call but in my eyes that's not conducive to readability | ||
obs_space_shape = cast(tuple[int, ...], envs.single_observation_space.shape) | ||
action_space_shape = cast(tuple[int, ...], envs.single_action_space.shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gym spaces can in theory return None
shapes. Mypy will complain about this when concatenating the shape tuples later.
- Option 1 is to use a cast either once here or every time the spaces are accessed. I don't think that's very readable.
- Option 2 is to assert that the space shapes are tuples. Doing it once here fixes all errors for the rest of the code. Since there's an assert in this place already anyway I think this is the better option.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Option 1 is more preferrable
@@ -92,11 +96,11 @@ def __init__(self, env): | |||
nn.Linear(84, env.single_action_space.n), | |||
) | |||
|
|||
def forward(self, x): | |||
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've used FloatTensor here but we should just use torch.Tensor
if type hints are pursued further.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! The PR looks good. I have left some comments
@@ -82,7 +83,10 @@ def thunk(): | |||
|
|||
# ALGO LOGIC: initialize agent here: | |||
class QNetwork(nn.Module): | |||
def __init__(self, env): | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you remove this space?
@@ -92,11 +96,11 @@ def __init__(self, env): | |||
nn.Linear(84, env.single_action_space.n), | |||
) | |||
|
|||
def forward(self, x): | |||
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine
# Handling gym shapes being Optionals (variant 1) | ||
# Personally i'd prefer the asserts | ||
assert isinstance(envs.single_observation_space.shape, tuple), "shape of observation space must be defined" | ||
assert isinstance(envs.single_action_space.shape, tuple), "shape of action space must be defined" | ||
|
||
# Handling gym shapes being Optionals (variant 2) | ||
# Once could also cast inside each call but in my eyes that's not conducive to readability | ||
obs_space_shape = cast(tuple[int, ...], envs.single_observation_space.shape) | ||
action_space_shape = cast(tuple[int, ...], envs.single_action_space.shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Option 1 is more preferrable
def get_action_and_value(self, x, action=None): | ||
def get_action_and_value( | ||
self, x: torch.Tensor, action: Optional[torch.Tensor] = None | ||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding from __future__ import annotations
at the top would guarantee the 3.7+ compatibility, and make dropping support later on quite easy.
Description
As discussed on Discord, I've done basic type hints for PPO and DQN. Everything checks out with mypy 0.982 (
mypy cleanrl/ppo.py --show-error-codes --ignore-missing-imports
). Of course we can have a discussion about whether that's the checker that will be used if we go further down the road of implementing this. I'll put comments in noteworthy places.The tests fail because
tuple[int]
orlist[int]
only works from Python 3.9 on.Types of changes
Checklist:
pre-commit run --all-files
passes (required).