use run-context in SshUrlOperations properly
This commit changes the call-tree location
of the run-context usage. Previously, the run
context was applied in `SshUrlOperations._stat`.
That was wrong because some callers of `_stat`
expect a runner result, in this case a
generator, to still yield data after
`_stat` has returned. That was of course not
the case, because the context exit handler in
`_stat` would exhaust the generator while
cleaning up the subprocess resources.
christian-monch committed Oct 17, 2023
1 parent c1a9c47 commit 7a62d96
Showing 1 changed file with 67 additions and 68 deletions.
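The failure mode described in the commit message is easy to reproduce outside of datalad-next. The following minimal sketch uses hypothetical names (run_sketch and stat_sketch are illustrations, not the real run() and _stat): a run-context whose exit handler drains the output generator hands an exhausted stream to any caller that keeps reading after the with block.

    # Minimal sketch of the bug (hypothetical names, not datalad-next code).
    from contextlib import contextmanager

    @contextmanager
    def run_sketch(chunks):
        # stands in for run(): yields a generator over subprocess output
        stream = iter(chunks)
        try:
            yield stream
        finally:
            # cleanup on context exit: drain remaining output so the
            # (imaginary) subprocess can be reaped; this exhausts the stream
            for _ in stream:
                pass

    def stat_sketch():
        # the old pattern: the run-context is opened and left inside the callee
        with run_sketch([b'header', b'payload-1', b'payload-2']) as stream:
            next(stream)                 # consume only the header
            return {'_stream': stream}   # the exit handler drains the rest here

    print(list(stat_sketch()['_stream']))  # [] -- the payload is gone

This commit applies the inverse pattern: the run-context is opened in the public entry points (stat, download), and the stream is consumed by the new _get_props helper before the context exits.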
135 changes: 67 additions & 68 deletions datalad_next/url_operations/ssh.py
@@ -34,7 +34,7 @@
 
 from datalad_next.runners.data_processors import (
     pattern_border_processor,
-    process_from,
+    ProcessorPipeline,
 )
 from datalad_next.runners.run import run
 from datalad_next.utils.consts import COPY_BUFSIZE
@@ -86,56 +86,10 @@ def stat(self,
         See :meth:`datalad_next.url_operations.UrlOperations.stat`
         for parameter documentation and exception behavior.
         """
-        props = self._stat(
-            url,
-            cmd=SshUrlOperations._stat_cmd,
-        )
-        return {k: v for k, v in props.items() if not k.startswith('_')}
-
-    def _stat(self, url: str, cmd: str) -> Dict:
-        # any stream must start with this magic marker, or we do not
-        # recognize what is happening
-        # after this marker, the server will send the size of the
-        # to-be-downloaded file in bytes, followed by another magic
-        # b'\1', and the file content after that
-        need_magic = b'\1\2\3'
-        expected_size_str = b''
-        expected_size = None
-
         ssh_cat = _SshCat(url)
-        cmd = ssh_cat.get_cmd(cmd)
+        cmd = ssh_cat.get_cmd(SshUrlOperations._stat_cmd)
         with run(cmd, protocol_class=StdOutCaptureGeneratorProtocol) as stream:
-
-            # The try clause enables us to execute the code after the context
-            # handler if the iterator stops unexpectedly. That would, for
-            # example be the case, if the ssh-subprocess terminates prematurely.
-            try:
-                filtered_stream = process_from(
-                    stream,
-                    [pattern_border_processor(need_magic)]
-                )
-
-                # The first chunk should start with the magic
-                chunk = next(filtered_stream)
-                if chunk[:len(need_magic)] != need_magic:
-                    raise RuntimeError("Protocol error: report header not received")
-                chunk = chunk[len(need_magic):]
-
-                # The length is transferred now and terminated by b'\x01'.
-                while b'\x01' not in chunk:
-                    chunk += next(filtered_stream)
-
-                marker_index = chunk.index(b'\x01')
-                expected_size = int(chunk[:marker_index])
-
-                props = {
-                    'content-length': expected_size,
-                    '_stream': chain([chunk[marker_index + 1:]], filtered_stream),
-                }
-                return props
-
-            except StopIteration:
-                pass
+            props = self._get_props(stream)
 
         # At this point the subprocess has either exited, was terminated, or
        # was killed.
@@ -147,6 +101,48 @@ def _stat(self, url: str, cmd: str) -> Dict:
                 url,
                 message=f'ssh process returned {stream.return_code}'
             )
+        return {k: v for k, v in props.items() if not k.startswith('_')}
 
+    def _get_props(self, stream: Generator) -> Dict:
+        # any stream must start with this magic marker, or we do not
+        # recognize what is happening
+        # after this marker, the server will send the size of the
+        # to-be-downloaded file in bytes, followed by another magic
+        # b'\1', and the file content after that
+        magic_marker = b'\1\2\3'
+
+        # Create a pipeline object that contains a single data
+        # processor, i.e. the "pattern_border_processor". It guarantees that
+        # each chunk has at least the size of the pattern and that no chunk
+        # ends with a pattern prefix (except for the last chunk).
+        # (We could have used the convenience wrapper "process_from", but we
+        # want to remove the filter again below. This requires us to have a
+        # ProcessorPipeline object.)
+        pipeline = ProcessorPipeline([pattern_border_processor(magic_marker)])
+        filtered_stream = pipeline.process_from(stream)
+
+        # The first chunk should start with the magic marker, i.e. b'\1\2\3'
+        chunk = next(filtered_stream)
+        if chunk[:len(magic_marker)] != magic_marker:
+            raise RuntimeError("Protocol error: report header not received")
+
+        # Remove the filter again. The chunk is extended to contain all
+        # data that was buffered in the pipeline.
+        chunk = b''.join([chunk[len(magic_marker):]] + pipeline.finalize())
+
+        # The length is transferred now and terminated by b'\x01'.
+        while b'\x01' not in chunk:
+            chunk += next(stream)
+
+        marker_index = chunk.index(b'\x01')
+        expected_size = int(chunk[:marker_index])
+        chunk = chunk[marker_index + 1:]
+        props = {
+            'content-length': expected_size,
+            '_stream': chain([chunk], stream) if chunk else stream
+        }
+        return props
+
+
     def download(self,
                  from_url: str,
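The new _get_props makes the wire framing explicit: a b'\1\2\3' magic marker, the content length as ASCII digits, a b'\x01' terminator, and then the file content. The following self-contained sketch parses that framing under simplified assumptions (plain byte iterators instead of the runner stream, and a first chunk long enough to hold the whole marker, which is the guarantee pattern_border_processor provides in the real code):

    # Sketch of the report framing parsed by _get_props (assumptions above).
    from itertools import chain

    MAGIC = b'\1\2\3'

    def parse_report_header(stream):
        chunk = next(stream)
        if chunk[:len(MAGIC)] != MAGIC:
            raise RuntimeError('Protocol error: report header not received')
        chunk = chunk[len(MAGIC):]
        # the length field may be split across chunks; read until b'\x01'
        while b'\x01' not in chunk:
            chunk += next(stream)
        size_bytes, _, rest = chunk.partition(b'\x01')
        # any over-read content is chained back in front of the stream,
        # which is why _get_props hands out '_stream' instead of the original
        return int(size_bytes), chain([rest], stream) if rest else stream

    size, content = parse_report_header(iter([MAGIC + b'11\x01hel', b'lo world']))
    assert size == 11 and b''.join(content) == b'hello world'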
@@ -173,13 +169,17 @@ def download(self,
 
         dst_fp = None
 
-        try:
-            props = self._stat(
-                from_url,
-                cmd=f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}',
-            )
-            stream = props.pop('_stream')
+        ssh_cat = _SshCat(from_url)
+        cmd = ssh_cat.get_cmd(f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}')
+        with run(cmd, protocol_class=StdOutCaptureGeneratorProtocol) as stream:
+
+            props = self._get_props(stream)
             expected_size = props['content-length']
+            # The stream might have changed because already fetched, but not
+            # yet processed, data is now chained in front of it. Therefore we
+            # get the updated stream from the props.
+            stream = props.pop('_stream')
+
             dst_fp = sys.stdout.buffer if to_path is None \
                 else open(to_path, 'wb')
             # Localize variable access to minimize overhead
@@ -199,19 +199,18 @@ def download(self,
                 self._progress_report_update(
                     progress_id, ('Downloaded chunk',), len(chunk))
             props.update(hasher.get_hexdigest())
-            return props
-        except CommandError as e:
-            if e.code == 244:
-                # this is the special code for a file-not-found
-                raise UrlOperationsResourceUnknown(from_url) from e
-            else:
-                # wrap this into the datalad-standard, but keep the
-                # original exception linked
-                raise UrlOperationsRemoteError(from_url, message=str(e)) from e
-        finally:
-            if dst_fp and to_path is not None:
-                dst_fp.close()
-            self._progress_report_stop(progress_id, ('Finished download',))
+
+        # At this point the subprocess has either exited, was terminated, or
+        # was killed.
+        if stream.return_code == 244:
+            # this is the special code for a file-not-found
+            raise UrlOperationsResourceUnknown(from_url)
+        elif stream.return_code != 0:
+            raise UrlOperationsRemoteError(
+                from_url,
+                message=f'ssh process returned {stream.return_code}'
+            )
+        return props
 
     def upload(self,
                from_path: Path | None,
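The comments in _get_props rely on a guarantee of pattern_border_processor: every emitted chunk is at least pattern-sized, and no chunk except the last ends with a proper prefix of the pattern, so the magic marker can never be split across two chunks. Here is a sketch of a re-chunker with those assumed semantics (it is not the datalad-next implementation):

    # Sketch of the chunk-boundary guarantee described in _get_props.
    def border_safe(chunks, pattern):
        held = b''
        for chunk in chunks:
            data = held + chunk
            # longest proper prefix of `pattern` that `data` ends with
            hold = next(
                (k for k in range(len(pattern) - 1, 0, -1)
                 if data.endswith(pattern[:k])),
                0,
            )
            # emit only if the guaranteed minimum chunk size is reached
            if len(data) - hold >= len(pattern):
                yield data[:len(data) - hold]
                held = data[len(data) - hold:]
            else:
                held = data
        if held:
            yield held  # the final chunk is exempt from the guarantee

    marker = b'\1\2\3'
    print(list(border_safe(iter([b'ab\1', b'\2\3cd']), marker)))
    # [b'ab\x01\x02\x03cd'] -- the marker arrives in one piece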
