diff --git a/src/shellinspector/parser.py b/src/shellinspector/parser.py index 315fc42..730c8f3 100644 --- a/src/shellinspector/parser.py +++ b/src/shellinspector/parser.py @@ -1,7 +1,7 @@ +import dataclasses import re import typing -from dataclasses import dataclass -from dataclasses import replace +from contextlib import suppress from enum import Enum from pathlib import Path @@ -20,7 +20,7 @@ class AssertMode(Enum): IGNORE = "_" -@dataclass +@dataclasses.dataclass class Command: execution_mode: ExecutionMode command: str @@ -46,7 +46,7 @@ def short(self): return f"{self.execution_mode.name}({self.user}@{self.host}) `{self.command}` (expect {self.line_count} lines, {self.assert_mode.name})" -@dataclass +@dataclasses.dataclass class Error: source_file: Path source_line_no: int @@ -54,7 +54,15 @@ class Error: message: str -@dataclass +@dataclasses.dataclass +class Settings: + timeout_seconds: int + + def __init__(self, timeout_seconds=5): + self.timeout_seconds = timeout_seconds + + +@dataclasses.dataclass class Specfile: path: Path commands: list[Command] @@ -62,6 +70,7 @@ class Specfile: environment: dict[str, str] examples: list[dict[str, str]] applied_example: dict + settings: Settings def __init__( self, path, commands=None, errors=None, environment=None, examples=None @@ -72,12 +81,13 @@ def __init__( self.environment = environment or {} self.examples = examples or [] self.applied_example = None + self.settings = Settings() def copy(self): return Specfile( self.path, - [replace(c) for c in self.commands], - [replace(e) for e in self.errors], + [dataclasses.replace(c) for c in self.commands], + [dataclasses.replace(e) for e in self.errors], self.environment.copy(), [e.copy() for e in self.examples], ) @@ -274,6 +284,15 @@ def parse(path: str, stream: typing.IO) -> Specfile: setattr(specfile, key, value) + frontmatter_settings = frontmatter.get("settings", {}) + global_settings = config.get("settings", {}) + + for key in dataclasses.fields(specfile.settings): + with suppress(LookupError): + setattr(specfile.settings, key.name, global_settings[key.name]) + with suppress(LookupError): + setattr(specfile.settings, key.name, frontmatter_settings[key.name]) + parse_commands(specfile, commands) return specfile diff --git a/src/shellinspector/runner.py b/src/shellinspector/runner.py index dbe6334..6031b49 100644 --- a/src/shellinspector/runner.py +++ b/src/shellinspector/runner.py @@ -198,16 +198,16 @@ def disable_color(): del os.environ["TERM"] -def get_ssh_session(ssh_config): +def get_ssh_session(ssh_config, timeout_seconds): with disable_color(): - shell = RemoteShell(timeout=5) + shell = RemoteShell(timeout=timeout_seconds) shell.login(**ssh_config) return shell -def get_localshell(): +def get_localshell(timeout_seconds): with disable_color(): - shell = LocalShell(timeout=5) + shell = LocalShell(timeout=timeout_seconds) shell.login() return shell @@ -258,7 +258,7 @@ def _close_session(self, cmd): f"Session could not be closed, because it doesn't exist, command: {cmd}" ) - def _get_session(self, cmd): + def _get_session(self, cmd, timeout_seconds): """ Create or reuse a shell session used to run the given command. @@ -286,7 +286,7 @@ def _get_session(self, cmd): LOGGER.debug("creating session: %s", key) if cmd.host == "local": LOGGER.debug("new local shell session") - session = self.sessions[key] = get_localshell() + session = self.sessions[key] = get_localshell(timeout_seconds) else: ssh_config = { **self.ssh_config, @@ -295,7 +295,9 @@ def _get_session(self, cmd): "port": self.ssh_config["port"], } LOGGER.debug("connecting via SSH: %s", ssh_config) - session = self.sessions[key] = get_ssh_session(ssh_config) + session = self.sessions[key] = get_ssh_session( + ssh_config, timeout_seconds + ) if logging.root.level == logging.DEBUG: # use .buffer here, because pexpect wants to write bytes, not strs @@ -390,7 +392,7 @@ def run(self, specfile): try: for cmd in specfile.commands: self.report(RunnerEvent.COMMAND_STARTING, cmd, {}) - session = self._get_session(cmd) + session = self._get_session(cmd, specfile.settings.timeout_seconds) if cmd.execution_mode == ExecutionMode.PYTHON: ctx = ShellinspectorPyContext({}, {}) diff --git a/tests/fixtures/parse_global_config/combine/shellinspector.yaml b/tests/fixtures/parse_global_config/combine/shellinspector.yaml index 9c53800..b406b3e 100644 --- a/tests/fixtures/parse_global_config/combine/shellinspector.yaml +++ b/tests/fixtures/parse_global_config/combine/shellinspector.yaml @@ -3,3 +3,5 @@ environment: FROM_CONFIG: 1 examples: - FROM_CONFIG: 1 +settings: + timeout_seconds: 99 diff --git a/tests/test_parser.py b/tests/test_parser.py index a33a882..84d0165 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -537,3 +537,4 @@ def test_global_config_default(): assert specfile.environment == {"FROM_CONFIG": 1} assert specfile.examples == [{"FROM_CONFIG": 1}] + assert specfile.settings.timeout_seconds == 99 diff --git a/tests/test_runner.py b/tests/test_runner.py index f9dcc55..3d66ad6 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -203,14 +203,14 @@ def test_remoteshell_get_environment(ssh_config): def test_get_localshell(): - shell = get_localshell() + shell = get_localshell(5) shell.sendline("echo a") assert shell.prompt(), shell.before assert shell.before.decode() == "a\r\n" def test_get_ssh_session(ssh_config): - shell = get_ssh_session(ssh_config) + shell = get_ssh_session(ssh_config, 5) shell.sendline("echo a") assert shell.prompt(), shell.before assert shell.before.decode() == "a\r\n" @@ -238,6 +238,22 @@ def rep(*args, **kwargs): @pytest.fixture def command_local_echo_literal(): + return Command( + ExecutionMode.USER, + "echo a", + None, + None, + "local", + AssertMode.LITERAL, + "a\n", + "/some.ispec", + 1, + "$ echo a", + ) + + +@pytest.fixture +def command_local_echo_literal_fail(): return Command( ExecutionMode.USER, "echo a", @@ -252,6 +268,22 @@ def command_local_echo_literal(): ) +@pytest.fixture +def command_remote_echo_literal(): + return Command( + ExecutionMode.ROOT, + "echo a", + "root", + None, + "remote", + AssertMode.LITERAL, + "a\n", + "/some.ispec", + 1, + "$ echo a", + ) + + @pytest.fixture def command_local_echo_regex(): return Command( @@ -289,7 +321,7 @@ def command_local_echo_ignore(): ( # LITERAL ( - lazy_fixture("command_local_echo_literal"), + lazy_fixture("command_local_echo_literal_fail"), ["a", 0], True, [ @@ -304,7 +336,7 @@ def command_local_echo_ignore(): ), # LITERAL & FAIL-Tests ( - lazy_fixture("command_local_echo_literal"), + lazy_fixture("command_local_echo_literal_fail"), ["b", 0], False, [ @@ -319,7 +351,7 @@ def command_local_echo_ignore(): ], ), ( - lazy_fixture("command_local_echo_literal"), + lazy_fixture("command_local_echo_literal_fail"), ["a", 1], False, [ @@ -334,7 +366,7 @@ def command_local_echo_ignore(): ], ), ( - lazy_fixture("command_local_echo_literal"), + lazy_fixture("command_local_echo_literal_fail"), ["b", 1], False, [ @@ -466,7 +498,7 @@ def test_get_session(make_runner, ssh_config, user, host, expected_class): "$ echo a", ) - session1 = runner._get_session(cmd) + session1 = runner._get_session(cmd, 5) assert isinstance(session1, expected_class) @@ -474,15 +506,15 @@ def test_get_session(make_runner, ssh_config, user, host, expected_class): assert session1.prompt() assert session1.before.decode().strip() == "a" - session2 = runner._get_session(cmd) + session2 = runner._get_session(cmd, 5) assert id(session1) == id(session2) cmd.session_name = "a" - session3 = runner._get_session(cmd) + session3 = runner._get_session(cmd, 5) assert id(session1) != id(session3) - session4 = runner._get_session(cmd) + session4 = runner._get_session(cmd, 5) assert id(session3) == id(session4) @@ -503,7 +535,31 @@ def test_get_session_unknown_host(make_runner, ssh_config): ) with pytest.raises(Exception, match="Unknown host: xxx.*"): - runner._get_session(cmd) + runner._get_session(cmd, 5) + + +def test_timeout_setting( + make_runner, ssh_config, command_local_echo_literal, command_remote_echo_literal +): + runner, events = make_runner(ssh_config) + + specfile = Specfile("virtual.ispec") + specfile.commands = [command_local_echo_literal, command_remote_echo_literal] + + runner.run(specfile) + + for event in events: + assert event[0][0] in ( + RunnerEvent.COMMAND_STARTING, + RunnerEvent.COMMAND_PASSED, + RunnerEvent.RUN_SUCCEEDED, + ), event + + sessions = list(runner.sessions.values()) + assert len(sessions) == 2 + + for session in sessions: + assert session.timeout == 5 def test_logout(make_runner, ssh_config): @@ -751,7 +807,7 @@ def set_environment(self, env): ) def test_run_command( make_runner, - command_local_echo_literal, + command_local_echo_literal_fail, prompt_works, actual_output, expected_result, @@ -759,14 +815,14 @@ def test_run_command( ): session = FakeSession(prompt_works, actual_output) runner, events = make_runner() - result = runner._run_command(session, command_local_echo_literal) + result = runner._run_command(session, command_local_echo_literal_fail) assert result == expected_result, events assert len(events) == len(expected_events) for i in range(len(events)): assert events[i][0][0] == expected_events[i][0] - assert events[i][0][1] == command_local_echo_literal + assert events[i][0][1] == command_local_echo_literal_fail assert events[i][1] == expected_events[i][1] @@ -805,7 +861,7 @@ def test_run_command( ) def test_run1( make_runner, - command_local_echo_literal, + command_local_echo_literal_fail, prompt_works, actual_output, expected_result, @@ -813,9 +869,9 @@ def test_run1( ): runner, events = make_runner() session = FakeSession(prompt_works, actual_output) - runner._get_session = lambda cmd: session + runner._get_session = lambda cmd, timeout: session specfile = Specfile("virtual.ispec") - specfile.commands = [command_local_echo_literal] + specfile.commands = [command_local_echo_literal_fail] result = runner.run(specfile) assert result == expected_result, events