Skip to content

Commit

Permalink
Implement typing (#39)
Browse files Browse the repository at this point in the history
* type orm/

* more typing

* typing

* no typing

* CI master instead of main

* upd

* commit

* more typing

* unused

* rm test

---------

Co-authored-by: Koos85 <[email protected]>
  • Loading branch information
joente and Koos85 authored Oct 1, 2024
1 parent 0e60660 commit de1117c
Show file tree
Hide file tree
Showing 13 changed files with 166 additions and 120 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ name: CI
on:
push:
branches:
- main
- master
pull_request:
branches:
- main
- master

jobs:
build:
Expand All @@ -26,9 +26,6 @@ jobs:
python -m pip install --upgrade pip
pip install pytest pycodestyle
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Run tests with pytest
run: |
pytest
- name: Lint with PyCodeStyle
run: |
find . -name \*.py -exec pycodestyle {} +
2 changes: 1 addition & 1 deletion aiogcd/connector/client_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _read_token_file(self):
return token
return None

async def get(self):
async def get(self) -> str:
"""Returns the access token. If _refresh_ts is passed, the token will
be refreshed. A lock is used to prevent refreshing the token twice.
Expand Down
94 changes: 55 additions & 39 deletions aiogcd/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import json
import aiohttp
from typing import Iterable, Optional, Any, Union
from .client_token import Token
from .service_account_token import ServiceAccountToken
from .entity import Entity
Expand All @@ -26,7 +27,7 @@
_MAX_LOOPS = 128


def _get_api_endpoint():
def _get_api_endpoint() -> str:
emu_host = os.getenv('DATASTORE_EMULATOR_HOST')
if emu_host is None:
return DEFAULT_API_ENDPOINT
Expand All @@ -37,12 +38,12 @@ class GcdConnector:

def __init__(
self,
project_id,
client_id,
client_secret,
token_file,
scopes=DEFAULT_SCOPES,
namespace_id=None):
project_id: str,
client_id: str,
client_secret: str,
token_file: str,
scopes: Iterable[str] = DEFAULT_SCOPES,
namespace_id: Optional[str] = None):

self.project_id = project_id
self.namespace_id = namespace_id
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(
async def connect(self):
await self._token.connect()

async def insert_entities(self, entities):
async def insert_entities(self, entities) -> tuple[bool, ...]:
"""Returns a tuple containing boolean values. Each boolean value is
True in case of a successful mutation and False if not. The order of
booleans is the same as the supplied tuple or list.
Expand All @@ -88,7 +89,7 @@ async def insert_entities(self, entities):
# alias
entities = insert_entities

async def insert_entity(self, entity):
async def insert_entity(self, entity: Entity) -> bool:
"""Returns True if successful or False if not. In case of False then
most likely a conflict was detected.
Expand All @@ -100,7 +101,8 @@ async def insert_entity(self, entity):
"""
return (await self._commit_entities_or_keys([entity], 'insert'))[0]

async def upsert_entities(self, entities):
async def upsert_entities(self, entities: Iterable[Entity]) -> \
tuple[bool, ...]:
"""Returns a tuple containing boolean values. Each boolean value is
True in case of a successful mutation and False if not. The order of
booleans is the same as the supplied tuple or list.
Expand All @@ -112,7 +114,7 @@ async def upsert_entities(self, entities):
"""
return await self._commit_entities_or_keys(entities, 'upsert')

async def upsert_entity(self, entity):
async def upsert_entity(self, entity: Entity) -> bool:
"""Returns True if successful or False if not. In case of False then
most likely a conflict was detected.
Expand All @@ -123,7 +125,8 @@ async def upsert_entity(self, entity):
"""
return (await self._commit_entities_or_keys([entity], 'upsert'))[0]

async def update_entities(self, entities):
async def update_entities(self, entities: Iterable[Entity]) -> \
tuple[bool, ...]:
"""Returns a tuple containing boolean values. Each boolean value is
True in case of a successful mutation and False if not. The order of
booleans is the same as the supplied tuple or list.
Expand All @@ -135,7 +138,7 @@ async def update_entities(self, entities):
"""
return await self._commit_entities_or_keys(entities, 'update')

async def update_entity(self, entity):
async def update_entity(self, entity: Entity) -> bool:
"""Returns True if successful or False if not. In case of False then
most likely a conflict was detected.
Expand All @@ -146,7 +149,7 @@ async def update_entity(self, entity):
"""
return (await self._commit_entities_or_keys([entity], 'update'))[0]

async def delete_keys(self, keys):
async def delete_keys(self, keys: Iterable[Key]) -> tuple[bool, ...]:
"""Returns a tuple containing boolean values. Each boolean value is
True in case of a successful mutation and False if not. The order of
booleans is the same as the supplied tuple or list.
Expand All @@ -158,7 +161,7 @@ async def delete_keys(self, keys):
"""
return await self._commit_entities_or_keys(keys, 'delete')

async def delete_key(self, key):
async def delete_key(self, key: Key) -> bool:
"""Returns True if successful or False if not. In case of False then
most likely a conflict was detected.
Expand All @@ -169,7 +172,8 @@ async def delete_key(self, key):
"""
return (await self._commit_entities_or_keys([key], 'delete'))[0]

async def commit(self, mutations):
async def commit(self, mutations: Iterable[dict[str, Any]]) -> \
tuple[dict, ...]:
"""Commit mutations.
The only supported commit mode is NON_TRANSACTIONAL.
Expand Down Expand Up @@ -205,7 +209,7 @@ async def commit(self, mutations):
resp.status
))

async def run_query(self, data):
async def run_query(self, data) -> list[dict]:
"""Return entities by given query data.
:param data: see the following link for the data format:
Expand All @@ -216,7 +220,7 @@ async def run_query(self, data):
results, _ = await self._run_query(data)
return results

async def _run_query(self, data):
async def _run_query(self, data) -> tuple[list[dict], Optional[str]]:
results = []
cursor = None

Expand Down Expand Up @@ -276,11 +280,12 @@ async def _run_query(self, data):

return results, cursor

async def _get_entities_cursor(self, data):
async def _get_entities_cursor(self, data) -> \
tuple[list[Entity], Optional[str]]:
results, cursor = await self._run_query(data)
return [Entity(result['entity']) for result in results], cursor

async def get_entities(self, data):
async def get_entities(self, data) -> list[Entity]:
"""Return entities by given query data.
:param data: see the following link for the data format:
Expand All @@ -291,12 +296,12 @@ async def get_entities(self, data):
results, _ = await self._run_query(data)
return [Entity(result['entity']) for result in results]

async def get_keys(self, data):
async def get_keys(self, data) -> list[Key]:
data['query']['projection'] = [{'property': {'name': '__key__'}}]
results, _ = await self._run_query(data)
return [Key(result['entity']['key']) for result in results]

async def get_entity(self, data):
async def get_entity(self, data) -> Optional[Entity]:
"""Return an entity object by given query data.
:param data: see the following link for the data format:
Expand All @@ -308,19 +313,24 @@ async def get_entity(self, data):
result = await self.get_entities(data)
return result[0] if result else None

async def get_key(self, data):
async def get_key(self, data) -> Optional[Key]:
data['query']['limit'] = 1
result = await self.get_keys(data)
return result[0] if result else None

async def get_entities_by_kind(self, kind, offset=None, limit=None,
cursor=None):
async def get_entities_by_kind(self, kind: str,
offset: Optional[int] = None,
limit: Optional[int] = None,
cursor: Optional[str] = None) -> Union[
list[Entity],
tuple[list[Entity], Optional[str]]
]:
"""Returns entities by kind.
When a limit is set, this function returns a list and a cursor.
If no limit is used, then only the list will be returned.
"""
query = {'kind': [{'name': kind}]}
query: dict[str, Any] = {'kind': [{'name': kind}]}
data = {'query': query}
if cursor:
query['startCursor'] = cursor
Expand All @@ -333,10 +343,13 @@ async def get_entities_by_kind(self, kind, offset=None, limit=None,
query['limit'] = limit
return await self._get_entities_cursor(data)

async def get_entities_by_keys(self, keys, missing=None, deferred=None,
eventual=False):
async def get_entities_by_keys(self, keys: Iterable[Key],
missing: Optional[list[Any]] = None,
deferred: Optional[list[Key]] = None,
eventual: bool = False) -> list[Entity]:
"""Returns entity objects for the given keys or an empty list in case
no entity is found.
no entity is found. The order of entities might not be equal to the
order of provided keys.
:param keys: list of Key objects
:return: list of Entity objects.
Expand Down Expand Up @@ -384,8 +397,10 @@ def data():

return entities

async def get_entity_by_key(self, key, missing=None, deferred=None,
eventual=False):
async def get_entity_by_key(self, key: Key,
missing: Optional[list[Any]] = None,
deferred: Optional[list[Key]] = None,
eventual: bool = False) -> Optional[Entity]:
"""Returns an entity object for the given key or None in case no
entity is found.
Expand All @@ -397,23 +412,24 @@ async def get_entity_by_key(self, key, missing=None, deferred=None,
if entity:
return entity[0]

async def _get_headers(self):
async def _get_headers(self) -> dict[str, str]:
token = await self._token.get()
return {
'Authorization': 'Bearer {}'.format(token),
'Content-Type': 'application/json'
}

@staticmethod
def _check_mutation_result(entity_or_key, mutation_result):
def _check_mutation_result(entity_or_key, mutation_result) -> bool:
if 'key' in mutation_result:
# The automatically allocated key.
# Set only when the mutation allocated a key.
entity_or_key.key = Key(mutation_result['key'])

return not mutation_result.get('conflictDetected', False)

async def _commit_entities_or_keys(self, entities_or_keys, method):
async def _commit_entities_or_keys(self, entities_or_keys, method) -> \
tuple[bool, ...]:
mutations = [
{method: entity_or_key.get_dict()}
for entity_or_key in entities_or_keys]
Expand All @@ -429,11 +445,11 @@ async def _commit_entities_or_keys(self, entities_or_keys, method):
class GcdServiceAccountConnector(GcdConnector):
def __init__(
self,
project_id,
service_file,
session=None,
scopes=None,
namespace_id=None):
project_id: str,
service_file: str,
session: Optional[aiohttp.ClientSession] = None,
scopes: Optional[Iterable[str]] = None,
namespace_id: Optional[str] = None):

scopes = scopes or list(DEFAULT_SCOPES)
self.project_id = project_id
Expand Down
6 changes: 3 additions & 3 deletions aiogcd/connector/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class Decoder(Buffer):
_idx = None
_end = None

def __new__(cls, *args, ks=None):
def __new__(cls, *args, ks):
assert ks is not None, \
'Key string is required, for example: Decoder(ks=<ket_string>)'
'Key string is required, for example: Decoder(ks=<key_string>)'

decoder = super().__new__(cls)

Expand Down Expand Up @@ -97,7 +97,7 @@ def get_var_int64(self):

return result

def get_prefixed_string(self):
def get_prefixed_string(self) -> str:
n = self.get_var_int32()
if self._idx + n > len(self):
raise BufferDecodeError('truncated')
Expand Down
2 changes: 1 addition & 1 deletion aiogcd/connector/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class Entity:

def __init__(self, entity_res):
def __init__(self, entity_res: dict):
"""Initialize an Entity object.
Example:
Expand Down
12 changes: 8 additions & 4 deletions aiogcd/connector/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Author: Jeroen van der Heijden <[email protected]>
"""
import base64
from typing import Optional
from .buffer import Buffer
from .buffer import BufferDecodeError
from .path import Path
Expand Down Expand Up @@ -40,8 +41,10 @@ class Key:
"""
_ks = None

def __init__(self, *args, ks=None, path=None, project_id=None,
namespace_id=None):
def __init__(self, *args, ks: Optional[str] = None,
path: Optional[Path] = None,
project_id: Optional[str] = None,
namespace_id: Optional[str] = None):
if len(args) == 1 and isinstance(args[0], dict):
assert ks is None and path is None and project_id is None, \
self.KEY_INIT_MSG
Expand Down Expand Up @@ -147,8 +150,9 @@ def _extract_id_or_name(pair):
return None

@staticmethod
def _deserialize_ks(ks):
"""Returns a Key() object from a key string."""
def _deserialize_ks(ks: str):
"""Returns a tuple with the project_id, namespace_id and Path
from a key string."""

decoder = Decoder(ks=ks)
project_id = None
Expand Down
Loading

0 comments on commit de1117c

Please sign in to comment.