diff --git a/examples/benchmarks/RL_workshop.ipynb b/examples/benchmarks/RL_workshop.ipynb index 0018261..ac7ce66 100644 --- a/examples/benchmarks/RL_workshop.ipynb +++ b/examples/benchmarks/RL_workshop.ipynb @@ -696,7 +696,7 @@ "outputs": [], "source": [ "# Here, obs is a list of MoleculeState observations, with the first entry corresponding to the parent node\n", - "obs = env.reset()" + "obs, info = env.reset()" ] }, { @@ -728,7 +728,7 @@ ], "source": [ "done = False\n", - "obs = env.reset()\n", + "obs, info = env.reset()\n", "np.random.seed(0)\n", "\n", "while not done:\n", @@ -770,7 +770,7 @@ "outputs": [], "source": [ "model = policy_model()\n", - "obs = env.reset()" + "obs, info = env.reset()" ] }, { @@ -1045,7 +1045,7 @@ "\n", "env_preprocessor = get_preprocessor(env.observation_space)(env.observation_space)\n", "policy = trainer.get_policy()\n", - "obs = env.reset()" + "obs, info = env.reset()" ] }, { diff --git a/rlmolecule/molecule_state.py b/rlmolecule/molecule_state.py index df7a9d1..57791dc 100644 --- a/rlmolecule/molecule_state.py +++ b/rlmolecule/molecule_state.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Sequence, Type, Union -import gym +import gymnasium as gym import nfp import numpy as np import ray diff --git a/tests/test_molecule_policy.py b/tests/test_molecule_policy.py index 968b92b..6d4c57f 100644 --- a/tests/test_molecule_policy.py +++ b/tests/test_molecule_policy.py @@ -28,7 +28,7 @@ def molecule_env(qed_root: MoleculeState): def test_policy_model(molecule_env, single_layer_model): - observation, reward, terminal, info = molecule_env.step(0) + observation, reward, terminal, truncated, info = molecule_env.step(0) preprocessor = get_preprocessor(molecule_env.observation_space) obs = preprocessor(molecule_env.observation_space).transform(observation) diff --git a/tests/test_molecule_state.py b/tests/test_molecule_state.py index bdbfb82..a69113d 100644 --- a/tests/test_molecule_state.py +++ b/tests/test_molecule_state.py @@ -46,12 +46,12 @@ def test_prune_terminal(builder): assert repr(env.state.children[-1]) == "C (t)" # select the terminal state - obs, reward, terminal, info = env.step(len(env.state.children) - 1) + obs, reward, terminal, truncated, info = env.step(len(env.state.children) - 1) assert terminal assert np.isclose(reward, 0.3597849378839701) - obs = env.reset() - obs, reward, terminal, info = env.step(len(env.state.children) - 1) + obs, info = env.reset() + obs, reward, terminal, truncated, info = env.step(len(env.state.children) - 1) assert not terminal assert np.isclose(reward, 0) @@ -75,12 +75,12 @@ def test_prune_terminal_ray(ray_init): assert repr(env.state.children[-1]) == "C (t)" # select the terminal state - obs, reward, terminal, info = env.step(len(env.state.children) - 1) + obs, reward, terminal, truncated, info = env.step(len(env.state.children) - 1) assert terminal assert np.isclose(reward, 0.3597849378839701) - obs = env.reset() - obs, reward, terminal, info = env.step(len(env.state.children) - 1) + obs, info = env.reset() + obs, reward, terminal, truncated, info = env.step(len(env.state.children) - 1) assert not terminal assert np.isclose(reward, 0)