
Question on inputs to get action_preds #4

Open
cwfparsonson opened this issue Nov 23, 2022 · 2 comments


cwfparsonson commented Nov 23, 2022

Hi,

Many thanks for the implementation!

In the paper, the authors say that the DT policy chooses an action given the previous K states and returns-to-go. However, as far as I can see, in your code (and also in the original DT code), the action_preds calculated in the forward() method of DT only uses x[:,1] as input, which corresponds to the states.

Should the input not be x[:,0] (returns R_1, R_2, ...) and x[:, 1] (states S_1, S_2, ...)?

Also, am I correct that the way you would sample a single discrete action 'choice' at evaluation time from action_preds is by indexing action_preds[0, -1] to get the 'action predictions' for each possible action at the current time step, and then calling torch.max(action_preds[0, -1])? I.e., similar to 'greedy sampling' from the action predictions in standard RL.

cwfparsonson changed the title from "Question on inputs to the action_preds" to "Question on inputs to get action_preds" on Nov 23, 2022
daniellawson9999 (Owner) commented Nov 24, 2022

Hello,

This is a good question regarding why we are using x[:, 1] to predict actions. x is the set of transformed representations that we get after passing our initial inputs (stacked_inputs) through the transformer; specifically, x[:, 1] contains the transformed representation of each state after several layers of causal (multi-headed) self-attention. The state's representation at time t (x[t, 1]) was computed by attending to the states at previous timesteps (0 to t-1), the returns (from 0 to t), and the actions (from 0 to t-1). So, relevant information about the return at the current timestep, as well as information from all previous timesteps, is stored in the final representation of the state after applying attention, which means that the final representation of the state at time t can be sufficient for predicting the action at time t.
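
To make that concrete, here is a minimal sketch of the interleaving and un-interleaving described above, assuming the (R_t, s_t, a_t) token ordering used by DT; the tensor names and sizes are illustrative, not the repo's exact code:

```python
# Rough sketch (illustrative names/sizes, not the repo's exact code) of how the
# (return, state, action) embeddings are interleaved before the transformer and
# un-interleaved afterwards, so that x[:, 1] holds the state representations.
import torch

batch_size, seq_len, hidden = 2, 5, 128  # hypothetical sizes

# Hypothetical per-modality embeddings, each of shape (batch, seq_len, hidden)
returns_emb = torch.randn(batch_size, seq_len, hidden)
state_emb = torch.randn(batch_size, seq_len, hidden)
action_emb = torch.randn(batch_size, seq_len, hidden)

# Interleave as (R_1, s_1, a_1, R_2, s_2, a_2, ...): shape (batch, 3 * seq_len, hidden)
stacked_inputs = (
    torch.stack((returns_emb, state_emb, action_emb), dim=1)  # (batch, 3, seq_len, hidden)
    .permute(0, 2, 1, 3)                                      # (batch, seq_len, 3, hidden)
    .reshape(batch_size, 3 * seq_len, hidden)
)

# stacked_inputs would now go through the causal transformer; this placeholder
# just stands in for the transformer's output, which has the same shape.
transformer_out = stacked_inputs

# Un-interleave: x[:, 0] = return tokens, x[:, 1] = state tokens, x[:, 2] = action tokens
x = transformer_out.reshape(batch_size, seq_len, 3, hidden).permute(0, 2, 1, 3)
state_reps = x[:, 1]  # (batch, seq_len, hidden); x[:, 1, t] has attended to R_1..R_t, s_1..s_t, a_1..a_{t-1}
```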

For a discrete action space and a deterministic policy, you would probably want torch.argmax(action_preds[0, -1]), which returns the index of the action with the largest log-probability. If you want a stochastic policy (for exploration), you can instead apply a softmax to your logits to get probabilities that parameterize a categorical distribution, which you can then sample from. You can see how this is done in the sample function in the original DT repo: https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/utils.py. It would be interesting to see if a similar approach would work for Atari, or if changes would be needed.
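
As a rough sketch of both options (names and shapes are illustrative, assuming action_preds holds per-action logits of shape (batch, seq_len, num_actions)):

```python
# Rough sketch: greedy vs. stochastic action selection from the logits in
# action_preds for a discrete action space (illustrative shapes).
import torch

num_actions = 6  # hypothetical
action_preds = torch.randn(1, 20, num_actions)  # (batch, seq_len, num_actions) logits

logits = action_preds[0, -1]  # logits for the current (last) timestep

# Deterministic / greedy: index of the action with the largest logit
greedy_action = torch.argmax(logits)

# Stochastic (useful for exploration): softmax the logits into probabilities
# and sample from the resulting categorical distribution
probs = torch.softmax(logits, dim=-1)
sampled_action = torch.distributions.Categorical(probs=probs).sample()
```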

cwfparsonson (Author) commented

Thank you for the quick and detailed reply! That all makes sense.
