Skip to content

Commit

Permalink
Make permissions work with Paginated results
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Oct 17, 2024
1 parent ed5019b commit eda46e7
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 9 deletions.
8 changes: 7 additions & 1 deletion strawberry_django/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class OffsetPaginationInput:

@strawberry.type
class Paginated(Generic[NodeType]):
queryset: strawberry.Private[QuerySet]
queryset: strawberry.Private[Optional[QuerySet]]
pagination: strawberry.Private[OffsetPaginationInput]

@strawberry.field
Expand All @@ -46,13 +46,19 @@ def offset(self) -> int:
@strawberry.field(description="Total count of existing results.")
@django_resolver
def total_count(self, root) -> int:
if self.queryset is None:
return 0

return get_total_count(self.queryset)

@strawberry.field(description="List of paginated results.")
@django_resolver
def results(self) -> list[NodeType]:
from strawberry_django.optimizer import is_optimized_by_prefetching

if self.queryset is None:
return []

if is_optimized_by_prefetching(self.queryset):
results = self.queryset._result_cache # type: ignore
else:
Expand Down
4 changes: 4 additions & 0 deletions strawberry_django/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

from strawberry_django.auth.utils import aget_current_user, get_current_user
from strawberry_django.fields.types import OperationInfo, OperationMessage
from strawberry_django.pagination import OffsetPaginationInput, Paginated
from strawberry_django.resolvers import django_resolver

from .utils.query import filter_for_user
Expand Down Expand Up @@ -405,6 +406,9 @@ def handle_no_permission(self, exception: BaseException, *, info: Info):
if isinstance(ret_type, StrawberryList):
return []

if isinstance(ret_type, type) and issubclass(ret_type, Paginated):
return Paginated(queryset=None, pagination=OffsetPaginationInput())

# If it is a Connection, try to return an empty connection, but only if
# it is the only possibility available...
for ret_possibility in ret_types:
Expand Down
8 changes: 3 additions & 5 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ class Query:
issue_list_perm_required: list[IssueType] = strawberry_django.field(
extensions=[HasPerm(perms=["projects.view_issue"])],
)
issue_paginated_list_perm_required: Paginated[IssueType] = strawberry_django.field(
issues_paginated_perm_required: Paginated[IssueType] = strawberry_django.field(
extensions=[HasPerm(perms=["projects.view_issue"])],
)
issue_conn_perm_required: ListConnectionWithTotalCount[IssueType] = (
Expand All @@ -456,10 +456,8 @@ class Query:
issue_list_obj_perm_required_paginated: list[IssueType] = strawberry_django.field(
extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True
)
issue_paginated_list_obj_perm_required_paginated: Paginated[IssueType] = (
strawberry_django.field(
extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True
)
issues_paginated_obj_perm_required: Paginated[IssueType] = strawberry_django.field(
extensions=[HasRetvalPerm(perms=["projects.view_issue"])],
)
issue_conn_obj_perm_required: ListConnectionWithTotalCount[IssueType] = (
strawberry_django.connection(
Expand Down
4 changes: 2 additions & 2 deletions tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ type Query {
id: GlobalID!
): IssueType @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueListPermRequired: [IssueType!]! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issuePaginatedListPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issuesPaginatedPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueConnPermRequired(
"""Returns the items in the list that come before the specified cursor."""
before: String = null
Expand All @@ -766,7 +766,7 @@ type Query {
): IssueType @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueListObjPermRequired: [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueListObjPermRequiredPaginated(pagination: OffsetPaginationInput): [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issuePaginatedListObjPermRequiredPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issuesPaginatedObjPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true)
issueConnObjPermRequired(
"""Returns the items in the list that come before the specified cursor."""
before: String = null
Expand Down
167 changes: 166 additions & 1 deletion tests/test_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from .projects.faker import (
GroupFactory,
IssueFactory,
MilestoneFactory,
StaffUserFactory,
SuperuserUserFactory,
UserFactory,
)
from .utils import GraphQLTestClient
from .utils import GraphQLTestClient, assert_num_queries

PermKind: TypeAlias = Literal["user", "group", "superuser"]
perm_kinds: list[PermKind] = ["user", "group", "superuser"]
Expand Down Expand Up @@ -934,3 +935,167 @@ def test_conn_obj_perm_required(db, gql_client: GraphQLTestClient, kind: PermKin
"totalCount": 1,
},
}


@pytest.mark.django_db(transaction=True)
def test_query_paginated_with_permissions(db, gql_client: GraphQLTestClient):
query = """
query TestQuery ($pagination: OffsetPaginationInput) {
issuesPaginatedPermRequired (pagination: $pagination) {
totalCount
results {
name
milestone {
name
}
}
}
}
"""

milestone1 = MilestoneFactory.create()
milestone2 = MilestoneFactory.create()

issue1 = IssueFactory.create(milestone=milestone1)
issue2 = IssueFactory.create(milestone=milestone1)
issue3 = IssueFactory.create(milestone=milestone1)
issue4 = IssueFactory.create(milestone=milestone2)
issue5 = IssueFactory.create(milestone=milestone2)

# No user logged in
with assert_num_queries(0):
res = gql_client.query(query)

assert res.data == {
"issuesPaginatedPermRequired": {
"totalCount": 0,
"results": [],
}
}

user = UserFactory.create()

# User logged in without permissions
with gql_client.login(user):
with assert_num_queries(4):
res = gql_client.query(query)

assert res.data == {
"issuesPaginatedPermRequired": {
"totalCount": 0,
"results": [],
}
}

# User logged in with permissions
user.user_permissions.add(Permission.objects.get(codename="view_issue"))
with gql_client.login(user):
with assert_num_queries(6 if DjangoOptimizerExtension.enabled.get() else 11):
res = gql_client.query(query)

assert res.data == {
"issuesPaginatedPermRequired": {
"totalCount": 5,
"results": [
{"name": issue1.name, "milestone": {"name": milestone1.name}},
{"name": issue2.name, "milestone": {"name": milestone1.name}},
{"name": issue3.name, "milestone": {"name": milestone1.name}},
{"name": issue4.name, "milestone": {"name": milestone2.name}},
{"name": issue5.name, "milestone": {"name": milestone2.name}},
],
}
}

with assert_num_queries(6 if DjangoOptimizerExtension.enabled.get() else 8):
res = gql_client.query(query, variables={"pagination": {"limit": 2}})

assert res.data == {
"issuesPaginatedPermRequired": {
"totalCount": 5,
"results": [
{"name": issue1.name, "milestone": {"name": milestone1.name}},
{"name": issue2.name, "milestone": {"name": milestone1.name}},
],
}
}


@pytest.mark.django_db(transaction=True)
def test_query_paginated_with_obj_permissions(db, gql_client: GraphQLTestClient):
query = """
query TestQuery ($pagination: OffsetPaginationInput) {
issuesPaginatedObjPermRequired (pagination: $pagination) {
totalCount
results {
name
milestone {
name
}
}
}
}
"""

milestone1 = MilestoneFactory.create()
milestone2 = MilestoneFactory.create()

IssueFactory.create(milestone=milestone1)
issue2 = IssueFactory.create(milestone=milestone1)
IssueFactory.create(milestone=milestone1)
issue4 = IssueFactory.create(milestone=milestone2)
IssueFactory.create(milestone=milestone2)

# No user logged in
with assert_num_queries(0):
res = gql_client.query(query)

assert res.data == {
"issuesPaginatedObjPermRequired": {
"totalCount": 0,
"results": [],
}
}

user = UserFactory.create()

# User logged in without permissions
with gql_client.login(user):
with assert_num_queries(5):
res = gql_client.query(query)

assert res.data == {
"issuesPaginatedObjPermRequired": {
"totalCount": 0,
"results": [],
}
}

assign_perm("view_issue", user, issue2)
assign_perm("view_issue", user, issue4)

# User logged in with permissions
with gql_client.login(user):
with assert_num_queries(4 if DjangoOptimizerExtension.enabled.get() else 6):
res = gql_client.query(query)

assert res.data == {
"issuesPaginatedObjPermRequired": {
"totalCount": 2,
"results": [
{"name": issue2.name, "milestone": {"name": milestone1.name}},
{"name": issue4.name, "milestone": {"name": milestone2.name}},
],
}
}

with assert_num_queries(4 if DjangoOptimizerExtension.enabled.get() else 5):
res = gql_client.query(query, variables={"pagination": {"limit": 1}})

assert res.data == {
"issuesPaginatedObjPermRequired": {
"totalCount": 2,
"results": [
{"name": issue2.name, "milestone": {"name": milestone1.name}},
],
}
}

0 comments on commit eda46e7

Please sign in to comment.