Skip to content

Commit

Permalink
Client: get method with provider arguments (#1359)
Browse files Browse the repository at this point in the history
  • Loading branch information
IceKhan13 authored Jul 1, 2024
1 parent 3a9c41e commit 491ecf7
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 49 deletions.
24 changes: 10 additions & 14 deletions client/qiskit_serverless/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,9 @@ def list(self, **kwargs) -> List[QiskitFunction]:
"""Returns list of available programs."""
raise NotImplementedError

def get(self, title: str) -> Optional[QiskitFunction]:
def get(
self, title: str, provider: Optional[str] = None
) -> Optional[QiskitFunction]:
"""Returns qiskit function based on title provided."""
raise NotImplementedError

Expand Down Expand Up @@ -494,18 +496,10 @@ def list(self, **kwargs) -> List[QiskitFunction]:
"""Returns list of available programs."""
return self._job_client.get_programs(**kwargs)

def get(self, title: str) -> Optional[QiskitFunction]:
results = self._job_client.get_programs(title=title)
if len(results) > 1:
warnings.warn(
f"There are more than 1 program with title {title}"
"available. Returning most recent one. "
"If you want to get list of all functions "
"please, use `list` method."
)

functions = {function.title: function for function in results}
return functions.get(title)
def get(
self, title: str, provider: Optional[str] = None
) -> Optional[QiskitFunction]:
return self._job_client.get_program(title=title, provider=provider)

def _verify_token(self, token: str):
"""Verify token."""
Expand Down Expand Up @@ -744,7 +738,9 @@ def file_delete(self, file: str):
def list(self, **kwargs):
return self.client.get_programs(**kwargs)

def get(self, title: str) -> Optional[QiskitFunction]:
def get(
self, title: str, provider: Optional[str] = None
) -> Optional[QiskitFunction]:
functions = {
function.title: function for function in self.client.get_programs()
}
Expand Down
47 changes: 47 additions & 0 deletions client/qiskit_serverless/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
QiskitObjectsDecoder,
)
from qiskit_serverless.utils.json import is_jsonable, safe_json_request
from qiskit_serverless.utils.formatting import format_provider_name_and_title

RuntimeEnv = ray.runtime_env.RuntimeEnv

Expand Down Expand Up @@ -144,6 +145,12 @@ def get_programs(self, **kwargs):
"""Returns list of programs."""
raise NotImplementedError

def get_program(
self, title: str, provider: Optional[str] = None
) -> Optional[QiskitFunction]:
"""Returns program based on parameters."""
raise NotImplementedError


class RayJobClient(BaseJobClient):
"""RayJobClient."""
Expand Down Expand Up @@ -337,6 +344,21 @@ def get_programs(self, **kwargs):
for program in self._patterns
]

def get_program(
self, title: str, provider: Optional[str] = None
) -> Optional[QiskitFunction]:
"""Returns program based on parameters."""
all_programs = {
program.get("title"): QiskitFunction(
program.get("title"),
provider=program.get("provider", None),
raw_data=program,
job_client=self,
)
for program in self._patterns
}
return all_programs.get("title")


class GatewayJobClient(BaseJobClient):
"""GatewayJobClient."""
Expand Down Expand Up @@ -563,6 +585,31 @@ def get_programs(self, **kwargs):
for program in response_data
]

def get_program(
self, title: str, provider: Optional[str] = None
) -> Optional[QiskitFunction]:
"""Returns program based on parameters."""
provider, title = format_provider_name_and_title(
request_provider=provider, title=title
)

tracer = trace.get_tracer("client.tracer")
with tracer.start_as_current_span("program.get_by_title"):
response_data = safe_json_request(
request=lambda: requests.get(
f"{self.host}/api/{self.version}/programs/get_by_title/{title}",
headers={"Authorization": f"Bearer {self._token}"},
params={"provider": provider},
timeout=REQUESTS_TIMEOUT,
)
)
return QiskitFunction(
response_data.get("title"),
provider=response_data.get("provider", None),
raw_data=response_data,
job_client=self,
)


class Job:
"""Job."""
Expand Down
2 changes: 2 additions & 0 deletions client/qiskit_serverless/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
S3Storage
ErrorCodes
JsonSerializable
format_provider_name_and_title
"""

from .json import JsonSerializable
from .errors import ErrorCodes
from .storage import S3Storage, BaseStorage
from .formatting import format_provider_name_and_title
from .runtime_service_client import ServerlessRuntimeService
44 changes: 44 additions & 0 deletions client/qiskit_serverless/utils/formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# This code is a Qiskit project.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
==========================================================
Json utilities (:mod:`qiskit_serverless.utils.formatting`)
==========================================================
.. currentmodule:: qiskit_serverless.utils.formatting
Qiskit Serverless formatting utilities
======================================
.. autosummary::
:toctree: ../stubs/
format_provider_name_and_title
"""
from typing import Tuple, Union


def format_provider_name_and_title(
request_provider, title
) -> Tuple[Union[str, None], str]:
"""
This method returns provider_name and title from a title with / if it contains it
"""
if request_provider:
return request_provider, title

title_split = title.split("/")
if len(title_split) == 1:
return None, title_split[0]

return title_split[0], title_split[1]
100 changes: 67 additions & 33 deletions gateway/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import mimetypes
import os
import time
from typing import Optional
from wsgiref.util import FileWrapper

from concurrency.exceptions import RecordModifiedError
Expand Down Expand Up @@ -111,41 +112,12 @@ def get_object(self):
def get_queryset(self):
author = self.request.user
title = self.request.query_params.get("title")
provider_name = self.request.query_params.get("provider")

logger.info("ProgramViewSet get view_program permission")
view_program_permission = Permission.objects.get(
codename=VIEW_PROGRAM_PERMISSION
)

# Groups logic
user_criteria = Q(user=author)
view_permission_criteria = Q(permissions=view_program_permission)
author_groups_with_view_permissions = Group.objects.filter(
user_criteria & view_permission_criteria
)
author_groups_with_view_permissions_count = (
author_groups_with_view_permissions.count()
)
logger.info(
"ProgramViewSet get author[%s] groups [%s]",
author.id,
author_groups_with_view_permissions_count,
)
author_programs = self._get_program_queryset_for_title_and_provider(
author=author, title=title, provider_name=provider_name
).distinct()

# Programs logic
author_criteria = Q(author=author)
author_groups_with_view_permissions_criteria = Q(
instances__in=author_groups_with_view_permissions
)
if title:
author_programs = Program.objects.filter(
(author_criteria | author_groups_with_view_permissions_criteria)
& Q(title=title)
).distinct()
else:
author_programs = Program.objects.filter(
author_criteria | author_groups_with_view_permissions_criteria
).distinct()
author_programs_count = author_programs.count()
logger.info(
"ProgramViewSet get author[%s] programs[%s]",
Expand Down Expand Up @@ -324,6 +296,68 @@ def run(self, request):

return Response(job_serializer.data)

@action(methods=["GET"], detail=False, url_path="get_by_title/(?P<title>[^/.]+)")
def get_by_title(self, request, title):
"""Returns programs by title."""
author = self.request.user
provider_name = self.request.query_params.get("provider")

result_program = self._get_program_queryset_for_title_and_provider(
author=author, title=title, provider_name=provider_name
).first()

if result_program:
return Response(self.get_serializer(result_program).data)

return Response(status=404)

def _get_program_queryset_for_title_and_provider(
self, author, title: str, provider_name: Optional[str]
):
"""Returns queryset for program for gived request, title and provider."""
view_program_permission = Permission.objects.get(
codename=VIEW_PROGRAM_PERMISSION
)

# Groups logic
user_criteria = Q(user=author)
view_permission_criteria = Q(permissions=view_program_permission)
author_groups_with_view_permissions = Group.objects.filter(
user_criteria & view_permission_criteria
)
author_groups_with_view_permissions_count = (
author_groups_with_view_permissions.count()
)
logger.info(
"ProgramViewSet get author[%s] groups [%s]",
author.id,
author_groups_with_view_permissions_count,
)

# Programs logic
author_criteria = Q(author=author)
author_groups_with_view_permissions_criteria = Q(
instances__in=author_groups_with_view_permissions
)
if title:
serializer = self.get_serializer_upload_program(data=self.request.data)
provider_name, title = serializer.get_provider_name_and_title(
provider_name, title
)
title_criteria = Q(title=title)
if provider_name:
title_criteria = Q(title=title, provider__name=provider_name)
result_queryset = Program.objects.filter(
(author_criteria | author_groups_with_view_permissions_criteria)
& title_criteria
)
else:
result_queryset = Program.objects.filter(
author_criteria | author_groups_with_view_permissions_criteria
)

return result_queryset


class JobViewSet(viewsets.GenericViewSet):
"""
Expand Down
Loading

0 comments on commit 491ecf7

Please sign in to comment.