Skip to content

Commit

Permalink
fix: /topolygy handle json_arrayagg for mysql (#2051)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kiryous authored Sep 30, 2024
1 parent 8e95e56 commit 4e6dff0
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
33 changes: 25 additions & 8 deletions keep/topologies/topologies_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import ValidationError
from sqlalchemy.orm import joinedload, selectinload
from uuid import UUID
import json

from sqlmodel import Session, select

Expand Down Expand Up @@ -49,22 +50,38 @@ def get_service_application_ids_dict(
TopologyServiceApplication.service_id,
get_aggreated_field(
session,
TopologyServiceApplication.application_id, # type: ignore
TopologyServiceApplication.application_id, # type: ignore
"application_ids",
),
) # type: ignore
)
.where(TopologyServiceApplication.service_id.in_(service_ids))
.group_by(TopologyServiceApplication.service_id)
)
results = session.exec(query).all()
dialect_name = session.bind.dialect.name if session.bind else ""
result = {}
if session.bind is None:
raise ValueError("Session is not bound to a database")
if session.bind.dialect.name == "sqlite":
result = {}
for service_id, application_ids in results:
result[service_id] = [UUID(app_id) for app_id in application_ids.split(",")]
return result
return {service_id: application_ids for service_id, application_ids in results}
for application_id, service_ids in results:
if dialect_name == "postgresql":
# PostgreSQL returns a list of UUIDs
pass
elif dialect_name == "mysql":
# MySQL returns a JSON string, so we need to parse it
service_ids = json.loads(service_ids)
elif dialect_name == "sqlite":
# SQLite returns a comma-separated string
service_ids = [UUID(id) for id in service_ids.split(",")]
else:
if service_ids and isinstance(service_ids[0], UUID):
# If it's already a list of UUIDs (like in PostgreSQL), use it as is
pass
else:
# For any other case, try to convert to UUID
service_ids = [UUID(str(id)) for id in service_ids]
result[application_id] = service_ids

return result


class TopologiesService:
Expand Down
31 changes: 23 additions & 8 deletions tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def create_service(db_session, tenant_id, id):
service = TopologyService(
tenant_id=tenant_id,
service="test_service_" + id,
display_name="Test Service",
display_name=id,
repository="test_repository",
tags=["test_tag"],
description="test_description",
Expand Down Expand Up @@ -72,21 +72,28 @@ def test_get_all_topology_data(db_session):
def test_get_applications_by_tenant_id(db_session):
service_1 = create_service(db_session, SINGLE_TENANT_UUID, "1")
service_2 = create_service(db_session, SINGLE_TENANT_UUID, "2")
application = TopologyApplication(
application_1 = TopologyApplication(
tenant_id=SINGLE_TENANT_UUID,
name="Test Application",
services=[service_1, service_2],
)
db_session.add(application)
application_2 = TopologyApplication(
tenant_id=SINGLE_TENANT_UUID,
name="Test Application 2",
services=[service_1],
)
db_session.add(application_1)
db_session.add(application_2)
db_session.commit()

result = TopologiesService.get_applications_by_tenant_id(
SINGLE_TENANT_UUID, db_session
)
assert len(result) == 1
assert len(result) == 2
assert result[0].name == "Test Application"
assert len(result[0].services) == 2

assert result[1].name == "Test Application 2"
assert len(result[1].services) == 1

def test_create_application_by_tenant_id(db_session):
application_dto = TopologyApplicationDtoIn(name="New Application", services=[])
Expand Down Expand Up @@ -171,21 +178,29 @@ def test_get_applications(db_session, client, test_app):

service_1 = create_service(db_session, SINGLE_TENANT_UUID, "1")
service_2 = create_service(db_session, SINGLE_TENANT_UUID, "2")
service_3 = create_service(db_session, SINGLE_TENANT_UUID, "3")

application = TopologyApplication(
application_1 = TopologyApplication(
tenant_id=SINGLE_TENANT_UUID,
name="Test Application",
services=[service_1, service_2],
)
db_session.add(application)
application_2 = TopologyApplication(
tenant_id=SINGLE_TENANT_UUID,
name="Test Application 2",
services=[service_3],
)
db_session.add(application_1)
db_session.add(application_2)
db_session.commit()

response = client.get(
"/topology/applications", headers={"x-api-key": VALID_API_KEY}
)
assert response.status_code == 200
assert len(response.json()) == 1
assert len(response.json()) == 2
assert response.json()[0]["name"] == "Test Application"
assert response.json()[1]["services"][0]["name"] == "3"


@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
Expand Down

0 comments on commit 4e6dff0

Please sign in to comment.