diff --git a/gymnasium_robotics/envs/maze/maze.py b/gymnasium_robotics/envs/maze/maze.py index 361fc81c..f700f471 100644 --- a/gymnasium_robotics/envs/maze/maze.py +++ b/gymnasium_robotics/envs/maze/maze.py @@ -150,11 +150,19 @@ def make_maze( maze._unique_reset_locations += maze._combined_locations # Save new xml with maze to a temporary file - with tempfile.TemporaryDirectory() as tmp_dir: - temp_xml_path = path.join(path.dirname(tmp_dir), "ant_maze.xml") - tree.write(temp_xml_path) - - return maze, temp_xml_path + # Make temporary file object and make the string path to our new file + tmp_dir = tempfile.TemporaryDirectory() + temp_xml_path = path.join(tmp_dir.name, "ant_maze.xml") + + # Write the new xml to the temporary file + with open(temp_xml_path, "wb") as xml_file: + tree.write(xml_file) + + return ( + maze, + temp_xml_path, + tmp_dir, # The tmp_dir object is returned to keep it alive + ) class MazeEnv(GoalEnv): @@ -172,7 +180,7 @@ def __init__( self.reward_type = reward_type self.continuing_task = continuing_task - self.maze, self.tmp_xml_file_path = Maze.make_maze( + self.maze, self.tmp_xml_file_path, self.tmp_dir = Maze.make_maze( agent_xml_path, maze_map, maze_size_scaling, maze_height ) @@ -308,3 +316,7 @@ def compute_truncated( def update_target_site_pos(self, pos): raise NotImplementedError + + def __del__(self): + self.tmp_dir.cleanup() + super().__del__() diff --git a/gymnasium_robotics/envs/maze/maze_v4.py b/gymnasium_robotics/envs/maze/maze_v4.py index e601a805..6be75e59 100644 --- a/gymnasium_robotics/envs/maze/maze_v4.py +++ b/gymnasium_robotics/envs/maze/maze_v4.py @@ -12,7 +12,6 @@ """ import math import tempfile -import time import xml.etree.ElementTree as ET from os import path from typing import Dict, List, Optional, Union @@ -53,7 +52,6 @@ def __init__( maze_size_scaling: float, maze_height: float, ): - self._maze_map = maze_map self._maze_size_scaling = maze_size_scaling self._maze_height = maze_height @@ -235,12 +233,19 @@ def make_maze( maze._unique_reset_locations += maze._combined_locations # Save new xml with maze to a temporary file - with tempfile.TemporaryDirectory() as tmp_dir: - temp_xml_name = f"ant_maze{str(time.time())}.xml" - temp_xml_path = path.join(path.dirname(tmp_dir), temp_xml_name) - tree.write(temp_xml_path) - - return maze, temp_xml_path + # Make temporary file object and make the string path to our new file + tmp_dir = tempfile.TemporaryDirectory() + temp_xml_path = path.join(tmp_dir.name, "ant_maze.xml") + + # Write the new xml to the temporary file + with open(temp_xml_path, "wb") as xml_file: + tree.write(xml_file) + + return ( + maze, + temp_xml_path, + tmp_dir, # The tmp_dir object is returned to keep it alive + ) class MazeEnv(GoalEnv): @@ -256,11 +261,10 @@ def __init__( position_noise_range: float = 0.25, **kwargs, ): - self.reward_type = reward_type self.continuing_task = continuing_task self.reset_target = reset_target - self.maze, self.tmp_xml_file_path = Maze.make_maze( + self.maze, self.tmp_xml_file_path, self.tmp_dir = Maze.make_maze( agent_xml_path, maze_map, maze_size_scaling, maze_height ) @@ -419,3 +423,7 @@ def update_target_site_pos(self, pos): """Override this method to update the site qpos in the MuJoCo simulation after a new goal is selected. This is mainly for visualization purposes.""" raise NotImplementedError + + def __del__(self): + self.tmp_dir.cleanup() + super().__del__()