diff --git a/src/shellinspector/runner.py b/src/shellinspector/runner.py index 6c2c43c..ee8446a 100644 --- a/src/shellinspector/runner.py +++ b/src/shellinspector/runner.py @@ -22,6 +22,12 @@ LOGGER = logging.getLogger(Path(__file__).name) +class TimeoutException(Exception): + def __init__(self, output_so_far: str): + self.output_so_far = output_so_far + super().__init__() + + @dataclasses.dataclass class ShellinspectorPyContext: applied_example: dict @@ -90,10 +96,10 @@ def run_command(self, line): actual_output = actual_output.replace("\r\n", "\n") if found_prompt: - return True, actual_output + return actual_output else: self.close() - return False, actual_output + raise TimeoutException(actual_output) def set_environment(self, context): for k, v in context.items(): @@ -101,10 +107,7 @@ def set_environment(self, context): assert self.prompt() def get_environment(self): - success, output = self.run_command("export") - - if not success: - raise NotImplementedError() + output = self.run_command("export") env = {} @@ -137,7 +140,7 @@ def push_state(self): def pop_state(self): if self.closed: - raise Exception("Session is closed") + return self.sendline("echo $SHELLINSPECTOR_PROMPT_STATE") assert self.prompt() @@ -264,6 +267,27 @@ def _close_session(self, cmd): f"Session could not be closed, because it doesn't exist, command: {cmd}" ) + def _make_session(self, key, cmd, timeout_seconds): + LOGGER.debug("creating session: %s", key) + if cmd.host == "local": + LOGGER.debug("new local shell session") + session = self.sessions[key] = get_localshell(timeout_seconds) + else: + ssh_config = { + **self.ssh_config, + "username": cmd.user, + "server": self.ssh_config["server"], + "port": self.ssh_config["port"], + } + LOGGER.debug("connecting via SSH: %s", ssh_config) + session = get_ssh_session(ssh_config, timeout_seconds) + + if logging.root.level == logging.DEBUG: + # use .buffer here, because pexpect wants to write bytes, not strs + session.logfile = sys.stdout.buffer + + return session + def _get_session(self, cmd, timeout_seconds): """ Create or reuse a shell session used to run the given command. @@ -289,31 +313,17 @@ def _get_session(self, cmd, timeout_seconds): if key not in self.sessions: # connect, if there is no session - LOGGER.debug("creating session: %s", key) - if cmd.host == "local": - LOGGER.debug("new local shell session") - session = self.sessions[key] = get_localshell(timeout_seconds) - else: - ssh_config = { - **self.ssh_config, - "username": cmd.user, - "server": self.ssh_config["server"], - "port": self.ssh_config["port"], - } - LOGGER.debug("connecting via SSH: %s", ssh_config) - session = self.sessions[key] = get_ssh_session( - ssh_config, timeout_seconds - ) - - if logging.root.level == logging.DEBUG: - # use .buffer here, because pexpect wants to write bytes, not strs - session.logfile = sys.stdout.buffer + self.sessions[key] = self._make_session(key, cmd, timeout_seconds) + elif self.sessions[key].closed: + # destroy and reconnect, if there is a broken session + LOGGER.debug("closing failed session: %s", key) + self._close_session(cmd) + self.sessions[key] = self._make_session(key, cmd, timeout_seconds) else: # reuse, if we're already connected LOGGER.debug("reusing session: %s", key) - session = self.sessions[key] - return session + return self.sessions[key] def add_reporter(self, reporter): self.reporters.append(reporter) @@ -364,26 +374,28 @@ def _check_result(self, cmd, command_output, returncode): return False def _run_command(self, session, cmd): - success, command_output = session.run_command(cmd.command) - if not success: + try: + command_output = session.run_command(cmd.command) + except TimeoutException as ex: self.report( RunnerEvent.ERROR, cmd, { - "message": "could not find prompt for command", - "actual": command_output, + "message": "timeout, could not find prompt for command", + "actual": ex.output_so_far, }, ) return False - success, rc_output = session.run_command("echo $?") - if not success: + try: + rc_output = session.run_command("echo $?") + except TimeoutException as ex: self.report( RunnerEvent.ERROR, cmd, { - "message": "could not find prompt for return code", - "actual": rc_output, + "message": "timeout, could not find prompt for return code", + "actual": ex.output_so_far, }, ) return False diff --git a/tests/test_runner.py b/tests/test_runner.py index b28714a..ef837a9 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -795,7 +795,10 @@ def set_environment(self, env): [ ( RunnerEvent.ERROR, - {"message": "could not find prompt for command", "actual": "a"}, + { + "message": "timeout, could not find prompt for command", + "actual": "a", + }, ), ], ), @@ -806,7 +809,10 @@ def set_environment(self, env): [ ( RunnerEvent.ERROR, - {"message": "could not find prompt for return code", "actual": "0"}, + { + "message": "timeout, could not find prompt for return code", + "actual": "0", + }, ) ], ),