Skip to content

Commit

Permalink
Add reading of setuptools metadata to find smart_open transport / com…
Browse files Browse the repository at this point in the history
…pressor extensions
  • Loading branch information
arthurlm committed Apr 30, 2022
1 parent fe6cf99 commit e6384f7
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 1 deletion.
15 changes: 15 additions & 0 deletions smart_open/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#
"""Implements the compression layer of the ``smart_open`` library."""
import logging
import importlib
import importlib.metadata
import os.path

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -145,3 +147,16 @@ def compression_wrapper(file_obj, mode, compression):
#
register_compressor('.bz2', _handle_bz2)
register_compressor('.gz', _handle_gzip)


def _register_compressor_entry_point(ep):
try:
assert len(ep.name) > 0, "At least one char is required for ep.name"
extension = ".{}".format(ep.name)
register_compressor(extension, ep.load())
except Exception:
logger.warning("Fail to load smart_open compressor extension: %s (target: %s)", ep.name, ep.value)


for ep in importlib.metadata.entry_points().select(group='smart_open_compressor'):
_register_compressor_entry_point(ep)
10 changes: 10 additions & 0 deletions smart_open/tests/fixtures/compressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""Some no-op compressor"""


def handle_foo():
...


def handle_bar():
...
56 changes: 55 additions & 1 deletion smart_open/tests/test_transport.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,44 @@
# -*- coding: utf-8 -*-
from importlib.metadata import EntryPoint
import pytest
import unittest

from smart_open.transport import register_transport, get_transport
from smart_open.transport import (
register_transport, get_transport, _REGISTRY, _ERRORS, _register_transport_entry_point
)


def unregister_transport(x):
if x in _REGISTRY:
del _REGISTRY[x]
if x in _ERRORS:
del _ERRORS[x]


def assert_transport_not_registered(scheme):
with pytest.raises(NotImplementedError):
get_transport(scheme)


def assert_transport_registered(scheme):
transport = get_transport(scheme)
assert transport.SCHEME == scheme


class TransportTest(unittest.TestCase):
def tearDown(self):
unregister_transport("foo")
unregister_transport("missing")

def test_registry_requires_declared_schemes(self):
with pytest.raises(ValueError):
register_transport('smart_open.tests.fixtures.no_schemes_transport')

def test_registry_valid_transport(self):
assert_transport_not_registered("foo")
register_transport('smart_open.tests.fixtures.good_transport')
assert_transport_registered("foo")

def test_registry_errors_on_double_register_scheme(self):
register_transport('smart_open.tests.fixtures.good_transport')
with pytest.raises(AssertionError):
Expand All @@ -20,3 +48,29 @@ def test_registry_errors_get_transport_for_module_with_missing_deps(self):
register_transport('smart_open.tests.fixtures.missing_deps_transport')
with pytest.raises(ImportError):
get_transport("missing")

def test_register_entry_point_valid(self):
assert_transport_not_registered("foo")
_register_transport_entry_point(EntryPoint(
"foo",
"smart_open.tests.fixtures.good_transport",
"smart_open_transport",
))
assert_transport_registered("foo")

def test_register_entry_point_catch_bad_data(self):
_register_transport_entry_point(EntryPoint(
"invalid",
"smart_open.some_totaly_invalid_module",
"smart_open_transport",
))

def test_register_entry_point_for_module_with_missing_deps(self):
assert_transport_not_registered("missing")
_register_transport_entry_point(EntryPoint(
"missing",
"smart_open.tests.fixtures.missing_deps_transport",
"smart_open_transport",
))
with pytest.raises(ImportError):
get_transport("missing")
12 changes: 12 additions & 0 deletions smart_open/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""
import importlib
import importlib.metadata
import logging

import smart_open.local_file
Expand Down Expand Up @@ -102,5 +103,16 @@ def get_transport(scheme):
register_transport('smart_open.ssh')
register_transport('smart_open.webhdfs')


def _register_transport_entry_point(ep):
try:
register_transport(ep.value)
except Exception:
logger.warning("Fail to load smart_open transport extension: %s (target: %s)", ep.name, ep.value)


for ep in importlib.metadata.entry_points().select(group='smart_open_transport'):
_register_transport_entry_point(ep)

SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY.keys()))
"""The transport schemes that the local installation of ``smart_open`` supports."""

0 comments on commit e6384f7

Please sign in to comment.