diff --git a/public_transit_client/client.py b/public_transit_client/client.py index 67075a5..61913a6 100644 --- a/public_transit_client/client.py +++ b/public_transit_client/client.py @@ -11,13 +11,13 @@ Coordinate, Departure, DistanceToStop, + QueryConfig, + RouterInfo, + ScheduleInfo, SearchType, Stop, StopConnection, TimeType, - QueryConfig, - ScheduleInfo, - RouterInfo, ) LOG = logging.getLogger(__name__) @@ -252,14 +252,14 @@ def _build_params_dict( ), } - if source is isinstance(source, tuple): + if isinstance(source, tuple): params["sourceLatitude"] = str(source[0]) params["sourceLongitude"] = str(source[1]) elif isinstance(source, str): params["sourceStopId"] = source if target is not None: - if target is isinstance(target, tuple): + if isinstance(target, tuple): params["targetLatitude"] = str(target[0]) params["targetLongitude"] = str(target[1]) elif isinstance(target, str): diff --git a/public_transit_client/model.py b/public_transit_client/model.py index a497426..e708e22 100644 --- a/public_transit_client/model.py +++ b/public_transit_client/model.py @@ -1,4 +1,4 @@ -from datetime import datetime, date +from datetime import date, datetime from enum import Enum from itertools import pairwise diff --git a/pyproject.toml b/pyproject.toml index f70acb0..4ac3502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "public-transit-client" -version = "1.0.0" +version = "1.0.1" description = "Client to access the public transit service API endpoints." authors = [ "Lukas Connolly ", diff --git a/tests/integration/test_integration_routing.py b/tests/integration/test_integration_routing.py index 8efd759..1bb91af 100644 --- a/tests/integration/test_integration_routing.py +++ b/tests/integration/test_integration_routing.py @@ -1,4 +1,4 @@ -from datetime import datetime, date +from datetime import date, datetime import pytest @@ -9,12 +9,12 @@ from public_transit_client.model import ( Connection, Coordinate, - StopConnection, - TimeType, QueryConfig, - TransportMode, RouterInfo, ScheduleInfo, + StopConnection, + TimeType, + TransportMode, ) HOST = "http://localhost:8080" @@ -70,6 +70,42 @@ def test_get_connections(client): assert connections[0].to_stop.id == to_stop +@pytest.mark.integration +def test_get_connections_coordinates(client): + from_coordinate = Coordinate(latitude=36.914, longitude=-116.761) + to_coordinate = Coordinate(latitude=36.881, longitude=-116.817) + departure_time = datetime(2008, 6, 1) + connections = client.get_connections( + source=from_coordinate, + target=to_coordinate, + time=departure_time, + time_type=TimeType.DEPARTURE, + ) + + assert isinstance(connections, list) + assert len(connections) > 0 + assert all(isinstance(connection, Connection) for connection in connections) + assert connections[0].from_coordinate == from_coordinate + assert connections[0].to_coordinate == to_coordinate + + +@pytest.mark.integration +def test_get_connections_coordinate_tuples(client): + from_coordinate = (36.914, -116.761) + to_coordinate = (36.881, -116.817) + departure_time = datetime(2008, 6, 1) + connections = client.get_connections( + source=from_coordinate, + target=to_coordinate, + time=departure_time, + time_type=TimeType.DEPARTURE, + ) + + assert isinstance(connections, list) + assert len(connections) > 0 + assert all(isinstance(connection, Connection) for connection in connections) + + @pytest.mark.integration def test_get_connections_invalid_stop(client): from_stop = "INVALID_STOP"