From 1a8360b3c5ce6d6852a151638a0d62ad82a77204 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 17 Dec 2024 16:20:45 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/dqn/dqn_atari.py | 2 +- sota-implementations/dqn/dqn_cartpole.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 6f249ff810a..f6bcf3044cb 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -86,7 +86,7 @@ def main(cfg: "DictConfig"): # noqa: F821 delay_value=True, ) loss_module.set_keys(done="end-of-life", terminated="end-of-life") - loss_module.make_value_estimator(gamma=cfg.loss.gamma) + loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device) target_net_updater = HardUpdate( loss_module, value_network_update_interval=cfg.loss.hard_update_freq ) diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index c82490ca677..69689dd4c92 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -69,7 +69,7 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_function="l2", delay_value=True, ) - loss_module.make_value_estimator(gamma=cfg.loss.gamma) + loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device) loss_module = loss_module.to(device) target_net_updater = HardUpdate( loss_module, value_network_update_interval=cfg.loss.hard_update_freq