Skip to content

Commit

Permalink
add setting for timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
luto committed Jul 11, 2024
1 parent ea55d6e commit 77b2e59
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 32 deletions.
33 changes: 26 additions & 7 deletions src/shellinspector/parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import re
import typing
from dataclasses import dataclass
from dataclasses import replace
from contextlib import suppress
from enum import Enum
from pathlib import Path

Expand All @@ -20,7 +20,7 @@ class AssertMode(Enum):
IGNORE = "_"


@dataclass
@dataclasses.dataclass
class Command:
execution_mode: ExecutionMode
command: str
Expand All @@ -46,22 +46,31 @@ def short(self):
return f"{self.execution_mode.name}({self.user}@{self.host}) `{self.command}` (expect {self.line_count} lines, {self.assert_mode.name})"


@dataclass
@dataclasses.dataclass
class Error:
source_file: Path
source_line_no: int
source_line: str
message: str


@dataclass
@dataclasses.dataclass
class Settings:
timeout_seconds: int

def __init__(self, timeout_seconds=5):
self.timeout_seconds = timeout_seconds


@dataclasses.dataclass
class Specfile:
path: Path
commands: list[Command]
errors: list[Error]
environment: dict[str, str]
examples: list[dict[str, str]]
applied_example: dict
settings: Settings

def __init__(
self, path, commands=None, errors=None, environment=None, examples=None
Expand All @@ -72,12 +81,13 @@ def __init__(
self.environment = environment or {}
self.examples = examples or []
self.applied_example = None
self.settings = Settings()

def copy(self):
return Specfile(
self.path,
[replace(c) for c in self.commands],
[replace(e) for e in self.errors],
[dataclasses.replace(c) for c in self.commands],
[dataclasses.replace(e) for e in self.errors],
self.environment.copy(),
[e.copy() for e in self.examples],
)
Expand Down Expand Up @@ -274,6 +284,15 @@ def parse(path: str, stream: typing.IO) -> Specfile:

setattr(specfile, key, value)

frontmatter_settings = frontmatter.get("settings", {})
global_settings = config.get("settings", {})

for key in dataclasses.fields(specfile.settings):
with suppress(LookupError):
setattr(specfile.settings, key.name, global_settings[key.name])
with suppress(LookupError):
setattr(specfile.settings, key.name, frontmatter_settings[key.name])

parse_commands(specfile, commands)

return specfile
18 changes: 10 additions & 8 deletions src/shellinspector/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,16 @@ def disable_color():
del os.environ["TERM"]


def get_ssh_session(ssh_config):
def get_ssh_session(ssh_config, timeout_seconds):
with disable_color():
shell = RemoteShell(timeout=5)
shell = RemoteShell(timeout=timeout_seconds)
shell.login(**ssh_config)
return shell


def get_localshell():
def get_localshell(timeout_seconds):
with disable_color():
shell = LocalShell(timeout=5)
shell = LocalShell(timeout=timeout_seconds)
shell.login()
return shell

Expand Down Expand Up @@ -258,7 +258,7 @@ def _close_session(self, cmd):
f"Session could not be closed, because it doesn't exist, command: {cmd}"
)

def _get_session(self, cmd):
def _get_session(self, cmd, timeout_seconds):
"""
Create or reuse a shell session used to run the given command.
Expand Down Expand Up @@ -286,7 +286,7 @@ def _get_session(self, cmd):
LOGGER.debug("creating session: %s", key)
if cmd.host == "local":
LOGGER.debug("new local shell session")
session = self.sessions[key] = get_localshell()
session = self.sessions[key] = get_localshell(timeout_seconds)
else:
ssh_config = {
**self.ssh_config,
Expand All @@ -295,7 +295,9 @@ def _get_session(self, cmd):
"port": self.ssh_config["port"],
}
LOGGER.debug("connecting via SSH: %s", ssh_config)
session = self.sessions[key] = get_ssh_session(ssh_config)
session = self.sessions[key] = get_ssh_session(
ssh_config, timeout_seconds
)

if logging.root.level == logging.DEBUG:
# use .buffer here, because pexpect wants to write bytes, not strs
Expand Down Expand Up @@ -390,7 +392,7 @@ def run(self, specfile):
try:
for cmd in specfile.commands:
self.report(RunnerEvent.COMMAND_STARTING, cmd, {})
session = self._get_session(cmd)
session = self._get_session(cmd, specfile.settings.timeout_seconds)

if cmd.execution_mode == ExecutionMode.PYTHON:
ctx = ShellinspectorPyContext({}, {})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ environment:
FROM_CONFIG: 1
examples:
- FROM_CONFIG: 1
settings:
timeout_seconds: 99
1 change: 1 addition & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,3 +537,4 @@ def test_global_config_default():

assert specfile.environment == {"FROM_CONFIG": 1}
assert specfile.examples == [{"FROM_CONFIG": 1}]
assert specfile.settings.timeout_seconds == 99
90 changes: 73 additions & 17 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,14 @@ def test_remoteshell_get_environment(ssh_config):


def test_get_localshell():
shell = get_localshell()
shell = get_localshell(5)
shell.sendline("echo a")
assert shell.prompt(), shell.before
assert shell.before.decode() == "a\r\n"


def test_get_ssh_session(ssh_config):
shell = get_ssh_session(ssh_config)
shell = get_ssh_session(ssh_config, 5)
shell.sendline("echo a")
assert shell.prompt(), shell.before
assert shell.before.decode() == "a\r\n"
Expand Down Expand Up @@ -238,6 +238,22 @@ def rep(*args, **kwargs):

@pytest.fixture
def command_local_echo_literal():
return Command(
ExecutionMode.USER,
"echo a",
None,
None,
"local",
AssertMode.LITERAL,
"a\n",
"/some.ispec",
1,
"$ echo a",
)


@pytest.fixture
def command_local_echo_literal_fail():
return Command(
ExecutionMode.USER,
"echo a",
Expand All @@ -252,6 +268,22 @@ def command_local_echo_literal():
)


@pytest.fixture
def command_remote_echo_literal():
return Command(
ExecutionMode.ROOT,
"echo a",
"root",
None,
"remote",
AssertMode.LITERAL,
"a\n",
"/some.ispec",
1,
"$ echo a",
)


@pytest.fixture
def command_local_echo_regex():
return Command(
Expand Down Expand Up @@ -289,7 +321,7 @@ def command_local_echo_ignore():
(
# LITERAL
(
lazy_fixture("command_local_echo_literal"),
lazy_fixture("command_local_echo_literal_fail"),
["a", 0],
True,
[
Expand All @@ -304,7 +336,7 @@ def command_local_echo_ignore():
),
# LITERAL & FAIL-Tests
(
lazy_fixture("command_local_echo_literal"),
lazy_fixture("command_local_echo_literal_fail"),
["b", 0],
False,
[
Expand All @@ -319,7 +351,7 @@ def command_local_echo_ignore():
],
),
(
lazy_fixture("command_local_echo_literal"),
lazy_fixture("command_local_echo_literal_fail"),
["a", 1],
False,
[
Expand All @@ -334,7 +366,7 @@ def command_local_echo_ignore():
],
),
(
lazy_fixture("command_local_echo_literal"),
lazy_fixture("command_local_echo_literal_fail"),
["b", 1],
False,
[
Expand Down Expand Up @@ -466,23 +498,23 @@ def test_get_session(make_runner, ssh_config, user, host, expected_class):
"$ echo a",
)

session1 = runner._get_session(cmd)
session1 = runner._get_session(cmd, 5)

assert isinstance(session1, expected_class)

session1.sendline("echo a")
assert session1.prompt()
assert session1.before.decode().strip() == "a"

session2 = runner._get_session(cmd)
session2 = runner._get_session(cmd, 5)
assert id(session1) == id(session2)

cmd.session_name = "a"

session3 = runner._get_session(cmd)
session3 = runner._get_session(cmd, 5)
assert id(session1) != id(session3)

session4 = runner._get_session(cmd)
session4 = runner._get_session(cmd, 5)
assert id(session3) == id(session4)


Expand All @@ -503,7 +535,31 @@ def test_get_session_unknown_host(make_runner, ssh_config):
)

with pytest.raises(Exception, match="Unknown host: xxx.*"):
runner._get_session(cmd)
runner._get_session(cmd, 5)


def test_timeout_setting(
make_runner, ssh_config, command_local_echo_literal, command_remote_echo_literal
):
runner, events = make_runner(ssh_config)

specfile = Specfile("virtual.ispec")
specfile.commands = [command_local_echo_literal, command_remote_echo_literal]

runner.run(specfile)

for event in events:
assert event[0][0] in (
RunnerEvent.COMMAND_STARTING,
RunnerEvent.COMMAND_PASSED,
RunnerEvent.RUN_SUCCEEDED,
), event

sessions = list(runner.sessions.values())
assert len(sessions) == 2

for session in sessions:
assert session.timeout == 5


def test_logout(make_runner, ssh_config):
Expand Down Expand Up @@ -751,22 +807,22 @@ def set_environment(self, env):
)
def test_run_command(
make_runner,
command_local_echo_literal,
command_local_echo_literal_fail,
prompt_works,
actual_output,
expected_result,
expected_events,
):
session = FakeSession(prompt_works, actual_output)
runner, events = make_runner()
result = runner._run_command(session, command_local_echo_literal)
result = runner._run_command(session, command_local_echo_literal_fail)
assert result == expected_result, events

assert len(events) == len(expected_events)

for i in range(len(events)):
assert events[i][0][0] == expected_events[i][0]
assert events[i][0][1] == command_local_echo_literal
assert events[i][0][1] == command_local_echo_literal_fail
assert events[i][1] == expected_events[i][1]


Expand Down Expand Up @@ -805,17 +861,17 @@ def test_run_command(
)
def test_run1(
make_runner,
command_local_echo_literal,
command_local_echo_literal_fail,
prompt_works,
actual_output,
expected_result,
expected_events,
):
runner, events = make_runner()
session = FakeSession(prompt_works, actual_output)
runner._get_session = lambda cmd: session
runner._get_session = lambda cmd, timeout: session
specfile = Specfile("virtual.ispec")
specfile.commands = [command_local_echo_literal]
specfile.commands = [command_local_echo_literal_fail]

result = runner.run(specfile)
assert result == expected_result, events
Expand Down

0 comments on commit 77b2e59

Please sign in to comment.