[PPO] Trajectory action out of range #564

Open

pbautista-apt opened this issue Feb 18, 2021 · 13 comments

@pbautista-apt

Hello,

I'm trying to implement a PPO agent on a custom environment whose action space is a Discrete space with bounds [0, 4), but the agent's policy is choosing actions out of that range.
action_space = spaces.Discrete(4)

I created two new networks:

actor = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params,
    preprocessing_layers=preprocess_layers,
    activation_fn=customReLU,
    preprocessing_combiner=tf.keras.layers.Concatenate(),
)
value_ = value_network.ValueNetwork(
    train_env.observation_spec(),
    activation_fn=customReLU,
    preprocessing_layers=preprocess_layers,
    preprocessing_combiner=tf.keras.layers.Concatenate(),
)
agent = ppo_agent.PPOAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_net=actor,
    value_net=value_,
    optimizer=optimizer,
    use_gae=True,
    use_td_lambda_return=True,
)

and verified that the bounds of the action_spec from the resulting agent are [0, 4):
[screenshot of the agent's action_spec]

However, during my training loop:

collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_episodes=10
)

for step in range(num_iterations):
    collect_driver.run()
    trajectories = replay_buffer.gather_all()
    train_loss = agent.train(experience=trajectories).loss
    replay_buffer.clear()

The loop runs anywhere from 0 to 400 iterations with no problems, but eventually I get an InvalidArgumentError:

InvalidArgumentError: Received a label value of 4 which is outside the valid range of [0, 4). Label values: 4 4...[Op:SparseSoftmaxCrossEntropyWithLogits]

Upon further inspection of the trajectories, it seems that the agent policy is outputting values outside the action bounds.
action=<tf.Tensor: shape=(1, 60), dtype=int64, numpy= array([[4, 4, ...], dtype=int64)>

I initially thought it was a problem with the activation functions of the networks, so I created a bounded ReLU to limit their outputs, but I still get the same problem.
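For context, the bounded ReLU was roughly of this shape (an illustrative sketch, not my exact code; the cap of 4.0 is arbitrary):

import tensorflow as tf

def customReLU(x):
    # Standard ReLU, but saturating at an upper cap so activations stay bounded.
    return tf.keras.activations.relu(x, max_value=4.0)

In hindsight, capping hidden-layer activations like this does not constrain what the policy's output distribution can sample, which would explain why it made no difference.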

Would this be an issue with my environment or with the networks I set up?

@ZakSingh

ZakSingh commented Mar 4, 2021

I'm getting the same exact problem with my environment. Were you able to find a solution?

@pbautista-apt
Author

@ZakSingh Not yet. The weird thing is that I don't run into the same problem with the DQN.

@ebrevdo
Contributor

ebrevdo commented Mar 18, 2021

@summer-yue PTAL?

@RachithP

RachithP commented Jun 23, 2021

I can confirm this issue for continuous action values as well, on tf-agents v0.7.1.
This seems linked to #121 and #216.

@ebrevdo
Contributor

ebrevdo commented Jun 23, 2021

@egonina current rotation; ptal?

@ebrevdo
Contributor

ebrevdo commented Jun 23, 2021

Perhaps ActorDistributionNetwork does not respect boundary values?

@ebrevdo
Contributor

ebrevdo commented Jun 23, 2021

I think the problem is in the construction of the Discrete output distribution here. Can you link to a gist with the full traceback of the error?

@RachithP

RachithP commented Jun 24, 2021

I don't think this will be helpful - Stack Trace.

Also, the output of print(self.action_spec()):

BoundedArraySpec(shape=(2,), dtype=dtype('float32'), name='action', minimum=-1.0, maximum=1.0)

So I expected the action to be in the range [-1, 1], but got (-1.2548504, 0.55205715).

As mentioned here, PPO does not handle action bound clipping.

@egonina
Contributor

egonina commented Jun 24, 2021

Have you tried the workaround in #216?

@kuanghuei can you PTAL as well since you're more familiar with PPO and have context on previous issues. Thanks!
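For reference, one environment-side mitigation (not necessarily the exact workaround described in #216) is to clip actions to the spec before the environment sees them, e.g. with TF-Agents' ActionClipWrapper; my_py_env below stands in for your existing PyEnvironment:

from tf_agents.environments import tf_py_environment, wrappers

# Clip any out-of-spec action to the action_spec bounds before the underlying
# Python environment receives it (mainly useful for continuous action specs).
clipped_py_env = wrappers.ActionClipWrapper(my_py_env)
train_env = tf_py_environment.TFPyEnvironment(clipped_py_env)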

@RachithP
Copy link

For my case, I just clipped action values in my env.
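Roughly like this (an illustrative helper, not my actual environment code):

import numpy as np

def clip_to_spec(action, action_spec):
    # Clamp whatever the policy produced back into the declared bounds
    # before using it inside the environment's _step().
    return np.clip(action, action_spec.minimum, action_spec.maximum)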

@ebrevdo
Contributor

ebrevdo commented Jun 24, 2021

You can alternatively pass a discrete_projection_net or continuous_projection_net argument to ActorDistributionNetwork: a function that builds a distribution which properly respects your action spec.

For example, if you are using discrete actions, the default is:

from tf_agents.networks import categorical_projection_network

def _categorical_projection_net(action_spec, logits_init_output_factor=0.1):
  return categorical_projection_network.CategoricalProjectionNetwork(
      action_spec, logits_init_output_factor=logits_init_output_factor)

But instead you can use something like:

import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.networks import sequential

def create_projection_net(action_spec):
  # Bounds are inclusive: a Discrete(4) spec has minimum=0 and maximum=3.
  num_actions = action_spec.maximum - action_spec.minimum + 1
  return sequential.Sequential([
      tf.keras.layers.Dense(num_actions),
      tf.keras.layers.Lambda(lambda logits: tfp.distributions.Categorical(
          logits=logits, dtype=action_spec.dtype)),
  ])

For a continuous network you could instead emit a TruncatedNormal.
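A rough sketch of that continuous case, assuming a rank-1 bounded action spec (the layer sizing, the softplus on the scale, and the helper name are illustrative choices, not TF-Agents defaults):

import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.networks import sequential

def create_continuous_projection_net(action_spec):
  num_dims = action_spec.shape[-1]

  def to_dist(params):
    loc, scale = tf.split(params, 2, axis=-1)
    # Truncating to the spec bounds means samples can never leave
    # [minimum, maximum].
    return tfp.distributions.TruncatedNormal(
        loc=loc,
        scale=tf.nn.softplus(scale) + 1e-3,
        low=action_spec.minimum,
        high=action_spec.maximum)

  return sequential.Sequential([
      tf.keras.layers.Dense(2 * num_dims),
      tf.keras.layers.Lambda(to_dist),
  ])

Whether this interacts cleanly with PPO's KL-penalty terms depends on tfp providing a KL for the chosen distribution, so treat it as a starting point rather than a drop-in fix.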

@ebrevdo
Contributor

ebrevdo commented Jun 24, 2021

You could also just build a complete Sequential that emits a Lambda creating a Distribution as the full action network instead of relying on ActorDistributionNetwork. This has been the recommended approach since ~late 2020.
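A rough sketch of that approach for a Discrete-style spec with minimum 0 (the names and layer sizes are illustrative, not a canonical recipe):

import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.networks import sequential

def build_actor_net(action_spec):
  num_actions = action_spec.maximum - action_spec.minimum + 1
  # The network's final output *is* the action distribution, so sampled
  # actions can only lie in the Categorical's support.
  return sequential.Sequential([
      tf.keras.layers.Dense(100, activation='relu'),
      tf.keras.layers.Dense(num_actions),
      tf.keras.layers.Lambda(lambda logits: tfp.distributions.Categorical(
          logits=logits, dtype=action_spec.dtype)),
  ])

You would then pass this directly as actor_net= to PPOAgent instead of an ActorDistributionNetwork.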

@basvanopheusden
Copy link

@ebrevdo Could you elaborate on how to do that? I was running into this issue in #216, and it continues to pop up occasionally despite the workaround.
