Skip to content

Commit

Permalink
handle timeouts more gracefully
Browse files Browse the repository at this point in the history
  • Loading branch information
luto authored and brutus committed Sep 5, 2024
1 parent 0b86141 commit 30a8649
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 38 deletions.
84 changes: 48 additions & 36 deletions src/shellinspector/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
LOGGER = logging.getLogger(Path(__file__).name)


class TimeoutException(Exception):
def __init__(self, output_so_far: str):
self.output_so_far = output_so_far
super().__init__()


@dataclasses.dataclass
class ShellinspectorPyContext:
applied_example: dict
Expand Down Expand Up @@ -90,21 +96,18 @@ def run_command(self, line):
actual_output = actual_output.replace("\r\n", "\n")

if found_prompt:
return True, actual_output
return actual_output
else:
self.close()
return False, actual_output
raise TimeoutException(actual_output)

def set_environment(self, context):
for k, v in context.items():
self.sendline(f"export {k}={shlex.quote(str(v))}")
assert self.prompt()

def get_environment(self):
success, output = self.run_command("export")

if not success:
raise NotImplementedError()
output = self.run_command("export")

env = {}

Expand Down Expand Up @@ -137,7 +140,7 @@ def push_state(self):

def pop_state(self):
if self.closed:
raise Exception("Session is closed")
return

self.sendline("echo $SHELLINSPECTOR_PROMPT_STATE")
assert self.prompt()
Expand Down Expand Up @@ -264,6 +267,27 @@ def _close_session(self, cmd):
f"Session could not be closed, because it doesn't exist, command: {cmd}"
)

def _make_session(self, key, cmd, timeout_seconds):
LOGGER.debug("creating session: %s", key)
if cmd.host == "local":
LOGGER.debug("new local shell session")
session = self.sessions[key] = get_localshell(timeout_seconds)
else:
ssh_config = {
**self.ssh_config,
"username": cmd.user,
"server": self.ssh_config["server"],
"port": self.ssh_config["port"],
}
LOGGER.debug("connecting via SSH: %s", ssh_config)
session = get_ssh_session(ssh_config, timeout_seconds)

if logging.root.level == logging.DEBUG:
# use .buffer here, because pexpect wants to write bytes, not strs
session.logfile = sys.stdout.buffer

return session

def _get_session(self, cmd, timeout_seconds):
"""
Create or reuse a shell session used to run the given command.
Expand All @@ -289,31 +313,17 @@ def _get_session(self, cmd, timeout_seconds):

if key not in self.sessions:
# connect, if there is no session
LOGGER.debug("creating session: %s", key)
if cmd.host == "local":
LOGGER.debug("new local shell session")
session = self.sessions[key] = get_localshell(timeout_seconds)
else:
ssh_config = {
**self.ssh_config,
"username": cmd.user,
"server": self.ssh_config["server"],
"port": self.ssh_config["port"],
}
LOGGER.debug("connecting via SSH: %s", ssh_config)
session = self.sessions[key] = get_ssh_session(
ssh_config, timeout_seconds
)

if logging.root.level == logging.DEBUG:
# use .buffer here, because pexpect wants to write bytes, not strs
session.logfile = sys.stdout.buffer
self.sessions[key] = self._make_session(key, cmd, timeout_seconds)
elif self.sessions[key].closed:
# destroy and reconnect, if there is a broken session
LOGGER.debug("closing failed session: %s", key)
self._close_session(cmd)
self.sessions[key] = self._make_session(key, cmd, timeout_seconds)
else:
# reuse, if we're already connected
LOGGER.debug("reusing session: %s", key)
session = self.sessions[key]

return session
return self.sessions[key]

def add_reporter(self, reporter):
self.reporters.append(reporter)
Expand Down Expand Up @@ -364,26 +374,28 @@ def _check_result(self, cmd, command_output, returncode):
return False

def _run_command(self, session, cmd):
success, command_output = session.run_command(cmd.command)
if not success:
try:
command_output = session.run_command(cmd.command)
except TimeoutException as ex:
self.report(
RunnerEvent.ERROR,
cmd,
{
"message": "could not find prompt for command",
"actual": command_output,
"message": "timeout, could not find prompt for command",
"actual": ex.output_so_far,
},
)
return False

success, rc_output = session.run_command("echo $?")
if not success:
try:
rc_output = session.run_command("echo $?")
except TimeoutException as ex:
self.report(
RunnerEvent.ERROR,
cmd,
{
"message": "could not find prompt for return code",
"actual": rc_output,
"message": "timeout, could not find prompt for return code",
"actual": ex.output_so_far,
},
)
return False
Expand Down
10 changes: 8 additions & 2 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,10 @@ def set_environment(self, env):
[
(
RunnerEvent.ERROR,
{"message": "could not find prompt for command", "actual": "a"},
{
"message": "timeout, could not find prompt for command",
"actual": "a",
},
),
],
),
Expand All @@ -806,7 +809,10 @@ def set_environment(self, env):
[
(
RunnerEvent.ERROR,
{"message": "could not find prompt for return code", "actual": "0"},
{
"message": "timeout, could not find prompt for return code",
"actual": "0",
},
)
],
),
Expand Down

0 comments on commit 30a8649

Please sign in to comment.