Skip to content

Commit

Permalink
fix: get_service_application_ids_dict: handle json_arrayagg for mysql…
Browse files Browse the repository at this point in the history
… and other dialects
  • Loading branch information
Kiryous committed Sep 30, 2024
1 parent 29718c1 commit 4837d5e
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions keep/topologies/topologies_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import ValidationError
from sqlalchemy.orm import joinedload, selectinload
from uuid import UUID
import json

from sqlmodel import Session, select

Expand Down Expand Up @@ -49,22 +50,38 @@ def get_service_application_ids_dict(
TopologyServiceApplication.service_id,
get_aggreated_field(
session,
TopologyServiceApplication.application_id, # type: ignore
TopologyServiceApplication.application_id, # type: ignore
"application_ids",
),
) # type: ignore
)
.where(TopologyServiceApplication.service_id.in_(service_ids))
.group_by(TopologyServiceApplication.service_id)
)
results = session.exec(query).all()
dialect_name = session.bind.dialect.name if session.bind else ""
result = {}
if session.bind is None:
raise ValueError("Session is not bound to a database")
if session.bind.dialect.name == "sqlite":
result = {}
for service_id, application_ids in results:
result[service_id] = [UUID(app_id) for app_id in application_ids.split(",")]
return result
return {service_id: application_ids for service_id, application_ids in results}
for application_id, service_ids in results:
if dialect_name == "postgresql":
# PostgreSQL returns a list of UUIDs
pass
elif dialect_name == "mysql":
# MySQL returns a JSON string, so we need to parse it
service_ids = json.loads(service_ids)
elif dialect_name == "sqlite":
# SQLite returns a comma-separated string
service_ids = [UUID(id) for id in service_ids.split(",")]
else:
if service_ids and isinstance(service_ids[0], UUID):
# If it's already a list of UUIDs (like in PostgreSQL), use it as is
pass
else:
# For any other case, try to convert to UUID
service_ids = [UUID(str(id)) for id in service_ids]
result[application_id] = service_ids

return result


class TopologiesService:
Expand Down

0 comments on commit 4837d5e

Please sign in to comment.