Skip to content

Commit

Permalink
fix(optimizer): handle existing select_related in querysets (#515)
Browse files Browse the repository at this point in the history
* fix(optimizer): handle existing select_related in querysets

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
taobojlen and pre-commit-ci[bot] authored May 11, 2024
1 parent f67efd4 commit 0efbc1d
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 4 deletions.
30 changes: 26 additions & 4 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,18 +285,40 @@ def apply(

only_set = set(self.only)
select_related_only_set = set()
select_related_set = set(self.select_related)

# inspect the queryset to find any existing select_related fields
def get_related_fields_with_prefix(
queryset_select_related: dict[str, Any], prefix=""
):
fields = []
for parent, nested in queryset_select_related.items():
current_path = f"{prefix}{parent}"
fields.append(current_path)
if nested: # If there are nested relations, dive deeper
fields.extend(
get_related_fields_with_prefix(
nested, prefix=current_path + "__"
)
)
return fields

if isinstance(qs.query.select_related, dict):
select_related_set.update(
get_related_fields_with_prefix(qs.query.select_related)
)

if config.enable_select_related and self.select_related:
qs = qs.select_related(*self.select_related)
if config.enable_select_related and select_related_set:
qs = qs.select_related(*select_related_set)

for select_related in self.select_related:
for select_related in select_related_set:
if select_related in only_set:
continue

if not any(only.startswith(select_related) for only in only_set):
select_related_only_set.add(select_related)

if config.enable_only and only_set:
if config.enable_only and (only_set or select_related_only_set):
qs = qs.only(*(only_set | select_related_only_set))

if config.enable_annotate and self.annotate:
Expand Down
9 changes: 9 additions & 0 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,15 @@ class TagType(relay.Node):
name: strawberry.auto
issues: ListConnectionWithTotalCount[IssueType] = strawberry_django.connection()

@strawberry_django.field
def issues_with_selected_related_milestone_and_project(self) -> List[IssueType]:
# here, the `select_related` is on the queryset directly, and not on the field
return (
self.issues.all() # type: ignore
.select_related("milestone", "milestone__project")
.order_by("id")
)


@strawberry_django.type(Quiz)
class QuizType(relay.Node):
Expand Down
1 change: 1 addition & 0 deletions tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,7 @@ type TagType implements Node {
"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): IssueTypeConnection!
issuesWithSelectedRelatedMilestoneAndProject: [IssueType!]!
}

"""A connection to a list of items."""
Expand Down
1 change: 1 addition & 0 deletions tests/projects/snapshots/schema_with_inheritance.gql
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ type TagType implements Node {
"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): IssueTypeConnection!
issuesWithSelectedRelatedMilestoneAndProject: [IssueType!]!
}

type UserType implements Node {
Expand Down
37 changes: 37 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,43 @@ def test_query_select_related_without_only(db, gql_client: GraphQLTestClient):
}


@pytest.mark.django_db(transaction=True)
def test_handles_existing_select_related(db, gql_client: GraphQLTestClient):
"""select_related should not cause errors, even if the field does not get queried."""
# We're *not* querying the issues' milestones, even though it's
# prefetched.
query = """
query TestQuery {
tagList {
issuesWithSelectedRelatedMilestoneAndProject {
id
name
}
}
}
"""

tag = TagFactory.create()

issues = IssueFactory.create_batch(3)
for issue in issues:
tag.issues.add(issue)

with assert_num_queries(2):
res = gql_client.query(query)

assert res.data == {
"tagList": [
{
"issuesWithSelectedRelatedMilestoneAndProject": [
{"id": to_base64("IssueType", t.id), "name": t.name}
for t in sorted(issues, key=lambda i: i.pk)
],
},
],
}


@pytest.mark.django_db(transaction=True)
def test_query_nested_connection_with_filter(db, gql_client: GraphQLTestClient):
query = """
Expand Down

0 comments on commit 0efbc1d

Please sign in to comment.