Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for SSH connection via aliases from ~/.ssh/config #790

Merged
merged 7 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 135 additions & 8 deletions smart_open/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@
"""

import getpass
import os
import logging
import urllib.parse

from typing import (
Dict,
Callable,
Tuple,
)

try:
import paramiko
except ImportError:
Expand All @@ -52,11 +59,43 @@
'sftp://username@host/path/file',
)

#
# Global storage for SSH config files.
#
_SSH_CONFIG_FILES = [os.path.expanduser("~/.ssh/config")]


def _unquote(text):
return text and urllib.parse.unquote(text)


def _str2bool(string):
if string == "no":
return False
if string == "yes":
return True
raise ValueError(f"Expected 'yes' / 'no', got {string}.")

#
# The parameter names used by Paramiko (and smart_open) slightly differ to
# those used in ~/.ssh/config, so we use a mapping to bridge the gap.
#
# The keys are option names as they appear in Paramiko (and smart_open)
# The values are a tuples containing:
#
# 1. their corresponding names in the ~/.ssh/config file
# 2. a callable to convert the parameter value from a string to the appropriate type
#
_PARAMIKO_CONFIG_MAP: Dict[str, Tuple[str, Callable]] = {
"timeout": ("connecttimeout", float),
"compress": ("compression", _str2bool),
"gss_auth": ("gssapiauthentication", _str2bool),
"gss_kex": ("gssapikeyexchange", _str2bool),
"gss_deleg_creds": ("gssapidelegatecredentials", _str2bool),
"gss_trust_dns": ("gssapitrustdns", _str2bool),
}


def parse_uri(uri_as_string):
split_uri = urllib.parse.urlsplit(uri_as_string)
assert split_uri.scheme in SCHEMES
Expand All @@ -65,7 +104,7 @@ def parse_uri(uri_as_string):
uri_path=_unquote(split_uri.path),
user=_unquote(split_uri.username),
host=split_uri.hostname,
port=int(split_uri.port or DEFAULT_PORT),
port=int(split_uri.port) if split_uri.port else None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we setting port to None instead of the default here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I recall correctly, this is to ensure that non-default ports from configuration files are loaded correctly. We think it best that the parse_uri function should not inject additional information, such as a default port, to a URI, as this is the job of the configuration parser. This is important for cases where a non-default port is specified in the config file for a connection, but not in the URI as-provided. If parse_uri injects the default here, then the non-default port from the config will be ignored, which is not what we want. We only want to override ports specified in the config if explicitly provided by the user as part of the URI.

password=_unquote(split_uri.password),
)

Expand All @@ -90,7 +129,98 @@ def _connect_ssh(hostname, username, port, password, transport_params):
return ssh


def open(path, mode='r', host=None, user=None, password=None, port=DEFAULT_PORT, transport_params=None):
def _maybe_fetch_config(host, username=None, password=None, port=None, transport_params=None):
# If all fields are set, return as-is.
if not any(arg is None for arg in (host, username, password, port, transport_params)):
return host, username, password, port, transport_params

if not host:
raise ValueError('you must specify the host to connect to')
if not transport_params:
transport_params = {}
if "connect_kwargs" not in transport_params:
transport_params["connect_kwargs"] = {}

# Attempt to load an OpenSSH config.
#
# Connections configured in this way are not guaranteed to perform exactly
# as they do in typical usage due to mismatches between the set of OpenSSH
# configuration options and those that Paramiko supports. We provide a best
# attempt, and support:
#
# - hostname -> address resolution
# - username inference
# - port inference
# - identityfile inference
# - connection timeout inference
# - compression selection
# - GSS configuration
#
connect_params = transport_params["connect_kwargs"]
config_files = [f for f in _SSH_CONFIG_FILES if os.path.exists(f)]
#
# This is the actual name of the host. The input host may actually be an
# alias.
#
actual_hostname = ""

for config_filename in config_files:
try:
cfg = paramiko.SSHConfig.from_path(config_filename)
except PermissionError:
continue

if host not in cfg.get_hostnames():
continue

cfg = cfg.lookup(host)
if username is None:
username = cfg.get("user", None)

if not actual_hostname:
actual_hostname = cfg["hostname"]

if port is None:
try:
port = int(cfg["port"])
except (IndexError, ValueError):
#
# Nb. ignore missing/invalid port numbers
#
pass

#
# Special case, as we can have multiple identity files, so we check
# that the identityfile list has len > 0. This should be redundant, but
# keeping it for safety.
#
if connect_params.get("key_filename") is None:
identityfile = cfg.get("identityfile", [])
if len(identityfile):
connect_params["key_filename"] = identityfile

for param_name, (sshcfg_name, from_str) in _PARAMIKO_CONFIG_MAP.items():
if connect_params.get(param_name) is None and sshcfg_name in cfg:
connect_params[param_name] = from_str(cfg[sshcfg_name])

#
# Continue working through other config files, if there are any,
# as they may contain more options for our host
#

if port is None:
port = DEFAULT_PORT

if not username:
username = getpass.getuser()

if actual_hostname:
host = actual_hostname

return host, username, password, port, transport_params


def open(path, mode='r', host=None, user=None, password=None, port=None, transport_params=None):
"""Open a file on a remote machine over SSH.

Expects authentication to be already set up via existing keys on the local machine.
Expand Down Expand Up @@ -125,12 +255,9 @@ def open(path, mode='r', host=None, user=None, password=None, port=DEFAULT_PORT,
If ``username`` or ``password`` are specified in *both* the uri and
``transport_params``, ``transport_params`` will take precedence
"""
if not host:
raise ValueError('you must specify the host to connect to')
if not user:
user = getpass.getuser()
if not transport_params:
transport_params = {}
host, user, password, port, transport_params = _maybe_fetch_config(
host, user, password, port, transport_params
)

key = (host, user)

Expand Down
11 changes: 11 additions & 0 deletions smart_open/tests/test_data/ssh.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Host another-host
HostName another-host-domain.com
User another-user
Port 2345
IdentityFile /path/to/key/file
ConnectTimeout 20
Compression yes
GSSAPIAuthentication no
GSSAPIKeyExchange no
GSSAPIDelegateCredentials no
GSSAPITrustDns no
6 changes: 3 additions & 3 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_scp(self):
self.assertEqual(uri.uri_path, '/path/to/file')
self.assertEqual(uri.user, 'user')
self.assertEqual(uri.host, 'host')
self.assertEqual(uri.port, 22)
self.assertEqual(uri.port, None)
self.assertEqual(uri.password, None)

def test_scp_with_pass(self):
Expand All @@ -319,7 +319,7 @@ def test_scp_with_pass(self):
self.assertEqual(uri.uri_path, '/path/to/file')
self.assertEqual(uri.user, 'user')
self.assertEqual(uri.host, 'host')
self.assertEqual(uri.port, 22)
self.assertEqual(uri.port, None)
self.assertEqual(uri.password, 'pass')

def test_sftp(self):
Expand All @@ -329,7 +329,7 @@ def test_sftp(self):
self.assertEqual(uri.uri_path, '/path/to/file')
self.assertEqual(uri.user, None)
self.assertEqual(uri.host, 'host')
self.assertEqual(uri.port, 22)
self.assertEqual(uri.port, None)
self.assertEqual(uri.password, None)

def test_sftp_with_user_and_pass(self):
Expand Down
61 changes: 61 additions & 0 deletions smart_open/tests/test_ssh.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# -*- coding: utf-8 -*-

import logging
import os
import tempfile
import unittest
from unittest import mock
import uuid

from paramiko import SSHException

import smart_open.ssh

_TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "test_data")
_CONFIG_PATH = os.path.join(_TEST_DATA_PATH, "ssh.cfg")


def mock_ssh(func):
def wrapper(*args, **kwargs):
Expand All @@ -20,6 +26,13 @@ def wrapper(*args, **kwargs):


class SSHOpen(unittest.TestCase):
def setUp(self):
self._cfg_files = smart_open.ssh._SSH_CONFIG_FILES
smart_open.ssh._SSH_CONFIG_FILES = [_CONFIG_PATH]

def tearDown(self):
smart_open.ssh._SSH_CONFIG_FILES = self._cfg_files

@mock_ssh
def test_open(self, mock_connect, get_transp_mock):
smart_open.open("ssh://user:pass@some-host/")
Expand Down Expand Up @@ -68,6 +81,54 @@ def mocked_open_sftp():
mock_connect.assert_called_with("some-host", 22, username="user", password="pass")
mock_sftp.open.assert_called_once()

@mock_ssh
def test_open_with_openssh_config(self, mock_connect, get_transp_mock):
smart_open.open("ssh://another-host/")
mock_connect.assert_called_with(
"another-host-domain.com",
2345,
username="another-user",
key_filename=["/path/to/key/file"],
timeout=20.,
compress=True,
gss_auth=False,
gss_kex=False,
gss_deleg_creds=False,
gss_trust_dns=False,
)

@mock_ssh
def test_open_with_openssh_config_override_port(self, mock_connect, get_transp_mock):
smart_open.open("ssh://another-host:22/")
mock_connect.assert_called_with(
"another-host-domain.com",
22,
username="another-user",
key_filename=["/path/to/key/file"],
timeout=20.,
compress=True,
gss_auth=False,
gss_kex=False,
gss_deleg_creds=False,
gss_trust_dns=False,
)

@mock_ssh
def test_open_with_openssh_config_override_user(self, mock_connect, get_transp_mock):
smart_open.open("ssh://new-user@another-host/")
mock_connect.assert_called_with(
"another-host-domain.com",
2345,
username="new-user",
key_filename=["/path/to/key/file"],
timeout=20.,
compress=True,
gss_auth=False,
gss_kex=False,
gss_deleg_creds=False,
gss_trust_dns=False,
)


if __name__ == "__main__":
logging.basicConfig(format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG)
Expand Down
Loading