Question on inputs to get action_preds #4
Hello,

This is a good question regarding why we use x[:, 1] to predict actions. x holds the transformed representations we get after passing our initial inputs (stacked_inputs) through the transformer, and specifically x[:, 1] holds the transformed representation of each state token after several layers of causal (multi-headed) self-attention. The state's representation at timestep t (x[:, 1, t]) was computed by attending to the states at previous timesteps (0 to t-1), the returns-to-go from 0 to t, and the actions from 0 to t-1. So the relevant information about the return at the current timestep, as well as information from all previous timesteps, is already stored in the state's final representation after attention, which means that the final representation of the state at timestep t is sufficient for predicting the action at timestep t.

For a discrete action space with a deterministic policy, you would probably want torch.argmax(action_preds[0, -1]), which returns the index of the action with the largest logit (and hence the largest log-probability). If you want a stochastic policy (for exploration), you can instead softmax your logits to get probabilities that parameterize a categorical distribution, which you can then sample from. You can see how this is done in the sample function in the original DT repo: https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/utils.py. It would be interesting to see if a similar approach would work for Atari, or if changes would be needed.
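To make the two options concrete, here is a minimal, self-contained sketch; the shape of action_preds and the number of actions are illustrative assumptions, not this repo's exact API:

```python
import torch

# Illustrative stand-in for the logits the model's forward() would return:
# shape (batch, seq_len, num_actions). Not the repo's exact API.
batch, seq_len, num_actions = 1, 20, 6
action_preds = torch.randn(batch, seq_len, num_actions)

# Logits for the most recent timestep of the first sequence in the batch.
logits = action_preds[0, -1]

# Deterministic (greedy) policy: index of the action with the largest logit.
greedy_action = torch.argmax(logits).item()

# Stochastic policy: softmax the logits into probabilities, then sample from
# the resulting categorical distribution (useful for exploration).
probs = torch.softmax(logits, dim=-1)
sampled_action = torch.distributions.Categorical(probs=probs).sample().item()

print(greedy_action, sampled_action)
```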
Thank you for the quick and detailed reply! That all makes sense.
Hi,
Many thanks for the implementation!
In the paper, the authors say that the DT policy chooses an action given the previous K states and returns-to-go. However, as far as I can see, in your code (and also in the original DT code), the action_preds computed in the forward() method of DT use only x[:, 1] as input, which corresponds to the states.
Should the input not be x[:,0] (returns R_1, R_2, ...) and x[:, 1] (states S_1, S_2, ...)?
Also, am I correct that the way you would sample a single discrete action 'choice' at evaluation time from action_preds is by indexing action_preds[0, -1] to get the 'action predictions' for each possible action at the current timestep, and then calling torch.max(action_preds[0, -1])? I.e., similar to 'greedy sampling' from the action predictions in standard RL.
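For reference, here is a minimal sketch of the token interleaving the question refers to; all shapes, variable names, and the placeholder transformer are illustrative assumptions rather than this repo's exact code:

```python
import torch

# Embedded returns-to-go, states, and actions, each of shape
# (batch, seq_len, hidden). Shapes and names are illustrative.
batch, seq_len, hidden = 1, 20, 128
return_embeddings = torch.randn(batch, seq_len, hidden)
state_embeddings = torch.randn(batch, seq_len, hidden)
action_embeddings = torch.randn(batch, seq_len, hidden)

# Interleave tokens as (R_1, s_1, a_1, R_2, s_2, a_2, ...):
# stack to (batch, 3, seq_len, hidden), then flatten to (batch, 3*seq_len, hidden).
stacked_inputs = (
    torch.stack((return_embeddings, state_embeddings, action_embeddings), dim=1)
    .permute(0, 2, 1, 3)
    .reshape(batch, 3 * seq_len, hidden)
)

# Placeholder for the causal transformer: transformer_output = transformer(stacked_inputs)
transformer_output = stacked_inputs

# Undo the interleaving: x[:, 0] are the return tokens, x[:, 1] the state tokens,
# x[:, 2] the action tokens, each of shape (batch, seq_len, hidden).
x = transformer_output.reshape(batch, seq_len, 3, hidden).permute(0, 2, 1, 3)

# Predict actions from the state tokens (x[:, 1]).
predict_action = torch.nn.Linear(hidden, 6)  # 6 = illustrative number of actions
action_preds = predict_action(x[:, 1])
print(action_preds.shape)  # (batch, seq_len, 6)
```

Because the sequence is causally masked over (R_1, s_1, a_1, R_2, s_2, a_2, ...), the state token at timestep t has already attended to the return-to-go at t and to everything earlier, which is why predicting actions from x[:, 1] alone can suffice.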