From eda46e7a1d21be98b3701b15488b32845e83beca Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Thu, 17 Oct 2024 10:34:41 -0300 Subject: [PATCH] Make permissions work with Paginated results --- strawberry_django/pagination.py | 8 +- strawberry_django/permissions.py | 4 + tests/projects/schema.py | 8 +- tests/projects/snapshots/schema.gql | 4 +- tests/test_permissions.py | 167 +++++++++++++++++++++++++++- 5 files changed, 182 insertions(+), 9 deletions(-) diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 99982f07..34145afa 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -32,7 +32,7 @@ class OffsetPaginationInput: @strawberry.type class Paginated(Generic[NodeType]): - queryset: strawberry.Private[QuerySet] + queryset: strawberry.Private[Optional[QuerySet]] pagination: strawberry.Private[OffsetPaginationInput] @strawberry.field @@ -46,6 +46,9 @@ def offset(self) -> int: @strawberry.field(description="Total count of existing results.") @django_resolver def total_count(self, root) -> int: + if self.queryset is None: + return 0 + return get_total_count(self.queryset) @strawberry.field(description="List of paginated results.") @@ -53,6 +56,9 @@ def total_count(self, root) -> int: def results(self) -> list[NodeType]: from strawberry_django.optimizer import is_optimized_by_prefetching + if self.queryset is None: + return [] + if is_optimized_by_prefetching(self.queryset): results = self.queryset._result_cache # type: ignore else: diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index dab1995a..71ee7eb3 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -39,6 +39,7 @@ from strawberry_django.auth.utils import aget_current_user, get_current_user from strawberry_django.fields.types import OperationInfo, OperationMessage +from strawberry_django.pagination import OffsetPaginationInput, Paginated from strawberry_django.resolvers import django_resolver from .utils.query import filter_for_user @@ -405,6 +406,9 @@ def handle_no_permission(self, exception: BaseException, *, info: Info): if isinstance(ret_type, StrawberryList): return [] + if isinstance(ret_type, type) and issubclass(ret_type, Paginated): + return Paginated(queryset=None, pagination=OffsetPaginationInput()) + # If it is a Connection, try to return an empty connection, but only if # it is the only possibility available... for ret_possibility in ret_types: diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 3e8c1fa7..f9814c33 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -435,7 +435,7 @@ class Query: issue_list_perm_required: list[IssueType] = strawberry_django.field( extensions=[HasPerm(perms=["projects.view_issue"])], ) - issue_paginated_list_perm_required: Paginated[IssueType] = strawberry_django.field( + issues_paginated_perm_required: Paginated[IssueType] = strawberry_django.field( extensions=[HasPerm(perms=["projects.view_issue"])], ) issue_conn_perm_required: ListConnectionWithTotalCount[IssueType] = ( @@ -456,10 +456,8 @@ class Query: issue_list_obj_perm_required_paginated: list[IssueType] = strawberry_django.field( extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True ) - issue_paginated_list_obj_perm_required_paginated: Paginated[IssueType] = ( - strawberry_django.field( - extensions=[HasRetvalPerm(perms=["projects.view_issue"])], pagination=True - ) + issues_paginated_obj_perm_required: Paginated[IssueType] = strawberry_django.field( + extensions=[HasRetvalPerm(perms=["projects.view_issue"])], ) issue_conn_obj_perm_required: ListConnectionWithTotalCount[IssueType] = ( strawberry_django.connection( diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 212a75ed..2fc46b1c 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -742,7 +742,7 @@ type Query { id: GlobalID! ): IssueType @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListPermRequired: [IssueType!]! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) - issuePaginatedListPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuesPaginatedPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null @@ -766,7 +766,7 @@ type Query { ): IssueType @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequired: [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueListObjPermRequiredPaginated(pagination: OffsetPaginationInput): [IssueType!]! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) - issuePaginatedListObjPermRequiredPaginated(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) + issuesPaginatedObjPermRequired(pagination: OffsetPaginationInput): IssueTypePaginated! @hasRetvalPerm(permissions: [{app: "projects", permission: "view_issue"}], any: true) issueConnObjPermRequired( """Returns the items in the list that come before the specified cursor.""" before: String = null diff --git a/tests/test_permissions.py b/tests/test_permissions.py index c1c492ee..c6e79318 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -9,11 +9,12 @@ from .projects.faker import ( GroupFactory, IssueFactory, + MilestoneFactory, StaffUserFactory, SuperuserUserFactory, UserFactory, ) -from .utils import GraphQLTestClient +from .utils import GraphQLTestClient, assert_num_queries PermKind: TypeAlias = Literal["user", "group", "superuser"] perm_kinds: list[PermKind] = ["user", "group", "superuser"] @@ -934,3 +935,167 @@ def test_conn_obj_perm_required(db, gql_client: GraphQLTestClient, kind: PermKin "totalCount": 1, }, } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_with_permissions(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginatedPermRequired (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue1 = IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + issue3 = IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + issue5 = IssueFactory.create(milestone=milestone2) + + # No user logged in + with assert_num_queries(0): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 0, + "results": [], + } + } + + user = UserFactory.create() + + # User logged in without permissions + with gql_client.login(user): + with assert_num_queries(4): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 0, + "results": [], + } + } + + # User logged in with permissions + user.user_permissions.add(Permission.objects.get(codename="view_issue")) + with gql_client.login(user): + with assert_num_queries(6 if DjangoOptimizerExtension.enabled.get() else 11): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue3.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + {"name": issue5.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(6 if DjangoOptimizerExtension.enabled.get() else 8): + res = gql_client.query(query, variables={"pagination": {"limit": 2}}) + + assert res.data == { + "issuesPaginatedPermRequired": { + "totalCount": 5, + "results": [ + {"name": issue1.name, "milestone": {"name": milestone1.name}}, + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_query_paginated_with_obj_permissions(db, gql_client: GraphQLTestClient): + query = """ + query TestQuery ($pagination: OffsetPaginationInput) { + issuesPaginatedObjPermRequired (pagination: $pagination) { + totalCount + results { + name + milestone { + name + } + } + } + } + """ + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + IssueFactory.create(milestone=milestone1) + issue2 = IssueFactory.create(milestone=milestone1) + IssueFactory.create(milestone=milestone1) + issue4 = IssueFactory.create(milestone=milestone2) + IssueFactory.create(milestone=milestone2) + + # No user logged in + with assert_num_queries(0): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 0, + "results": [], + } + } + + user = UserFactory.create() + + # User logged in without permissions + with gql_client.login(user): + with assert_num_queries(5): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 0, + "results": [], + } + } + + assign_perm("view_issue", user, issue2) + assign_perm("view_issue", user, issue4) + + # User logged in with permissions + with gql_client.login(user): + with assert_num_queries(4 if DjangoOptimizerExtension.enabled.get() else 6): + res = gql_client.query(query) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 2, + "results": [ + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + {"name": issue4.name, "milestone": {"name": milestone2.name}}, + ], + } + } + + with assert_num_queries(4 if DjangoOptimizerExtension.enabled.get() else 5): + res = gql_client.query(query, variables={"pagination": {"limit": 1}}) + + assert res.data == { + "issuesPaginatedObjPermRequired": { + "totalCount": 2, + "results": [ + {"name": issue2.name, "milestone": {"name": milestone1.name}}, + ], + } + }