Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 14, 2024
1 parent 864a99f commit 1b1a21e
Show file tree
Hide file tree
Showing 107 changed files with 126 additions and 207 deletions.
4 changes: 3 additions & 1 deletion .github/actions/generate-credentials/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def generate_credentials() -> Dict[str, Any]:
credentials = {}
for user in ["owner", "admin", "annotator"]:
logging.info(f"Generating random credential for user '{user}'")
password = generate_password_from_secret(secret=SECRET, salt=f"{GITHUB_REF}/{user}", length=32)
password = generate_password_from_secret(
secret=SECRET, salt=f"{GITHUB_REF}/{user}", length=32
)
credentials[user] = password
return credentials

Expand Down
16 changes: 12 additions & 4 deletions .github/actions/slack-post-credentials/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def get_slack_channel_id(client: WebClient) -> Union[str, None]:
for channel in result["channels"]:
if channel["name"] == SLACK_CHANNEL_NAME:
channel_id = channel["id"]
logging.info(f"Found channel id for '{SLACK_CHANNEL_NAME}' channel: '{channel_id}'")
logging.info(
f"Found channel id for '{SLACK_CHANNEL_NAME}' channel: '{channel_id}'"
)
return channel_id


Expand All @@ -87,7 +89,9 @@ def get_pr_url(pr_number: int) -> str:
return f"https://github.com/argilla-io/argilla/pull/{pr_number}"


def get_thread_ts_pr_message(client: WebClient, channel_id: str, pr_number: int) -> Union[str, None]:
def get_thread_ts_pr_message(
client: WebClient, channel_id: str, pr_number: int
) -> Union[str, None]:
response = client.conversations_history(channel=channel_id, limit=1000)
response.validate()

Expand Down Expand Up @@ -119,7 +123,9 @@ def bot_already_replied(client: WebClient, channel_id: str, thread_ts: str) -> b
return False


def reply_thread_with_credentials(client: WebClient, channel_id: str, thread_ts: str) -> None:
def reply_thread_with_credentials(
client: WebClient, channel_id: str, thread_ts: str
) -> None:
client.chat_postMessage(
channel=channel_id,
text=f"Credentials for PR deployed environment (use as password and API key):\n- URL: {URL}\n- owner: '{OWNER}'\n- admin: '{ADMIN}'\n- annotator: '{ANNOTATOR}'",
Expand Down Expand Up @@ -153,7 +159,9 @@ def reply_thread_with_credentials(client: WebClient, channel_id: str, thread_ts:
pr_number = get_pull_request_number()
if pr_number is None:
logging.error(f"Could not parse `GITHUB_REF` ({GITHUB_REF}) to get PR number")
raise ValueError(f"Could not parse `GITHUB_REF` ({GITHUB_REF}) to get PR number")
raise ValueError(
f"Could not parse `GITHUB_REF` ({GITHUB_REF}) to get PR number"
)

client = get_slack_client()

Expand Down
1 change: 0 additions & 1 deletion argilla-server/src/argilla_server/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from argilla_server.database import get_async_db
from argilla_server.logging import configure_logging
from argilla_server.models import User
from argilla_server.pydantic_v1.errors import ConfigError
from argilla_server.search_engine import get_search_engine
from argilla_server.security import auth
from argilla_server.settings import settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
Create Date: 2023-07-24 12:47:11.715011
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from argilla_server import models
from argilla_server.api.policies.v1 import DatasetPolicy, MetadataPropertyPolicy, authorize, is_authorized
from argilla_server.api.schemas.v1.datasets import (
Dataset as DatasetSchema,
Expand All @@ -38,10 +37,10 @@
MetadataPropertyCreate,
)
from argilla_server.api.schemas.v1.vector_settings import VectorSettings, VectorSettingsCreate, VectorsSettings
from argilla_server.contexts import accounts, datasets
from argilla_server.contexts import datasets
from argilla_server.database import get_async_db
from argilla_server.enums import ResponseStatus
from argilla_server.models import Dataset, User, Workspace
from argilla_server.models import Dataset, User
from argilla_server.search_engine import (
SearchEngine,
get_search_engine,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import argilla_server.search_engine as search_engine
from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized
from argilla_server.api.schemas.v1.datasets import Dataset as DatasetSchema
from argilla_server.api.schemas.v1.records import (
Filters,
FilterScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from uuid import UUID

from fastapi import APIRouter, Depends, Security, status
from fastapi import APIRouter, Depends, Security
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from uuid import UUID

from fastapi import APIRouter, Depends, Security, status
from fastapi import APIRouter, Depends, Security
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _check_oauth_enabled_or_raise() -> None:


def _get_provider_by_name_or_raise(provider_name: str) -> OAuth2ClientProvider:
if not provider_name in settings.oauth.providers:
if provider_name not in settings.oauth.providers:
raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
return settings.oauth.providers[provider_name]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from uuid import UUID

from fastapi import APIRouter, Depends, Security, status
from fastapi import APIRouter, Depends, Security
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand Down
3 changes: 1 addition & 2 deletions argilla-server/src/argilla_server/api/handlers/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from uuid import UUID

from fastapi import APIRouter, Depends, Query, Security, status
Expand All @@ -26,7 +25,7 @@
from argilla_server.api.schemas.v1.responses import Response, ResponseCreate
from argilla_server.api.schemas.v1.suggestions import Suggestion as SuggestionSchema
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate, Suggestions
from argilla_server.contexts import datasets, questions
from argilla_server.contexts import datasets
from argilla_server.database import get_async_db
from argilla_server.errors.future.base_errors import NotFoundError, UnprocessableEntityError
from argilla_server.models import Dataset, Question, Record, Suggestion, User
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from uuid import UUID

from fastapi import APIRouter, Depends, Security, status
from fastapi import APIRouter, Depends, Security
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from uuid import UUID

from fastapi import APIRouter, Depends, Security, status
from fastapi import APIRouter, Depends, Security
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand Down
2 changes: 0 additions & 2 deletions argilla-server/src/argilla_server/api/handlers/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, Request, Security, status
Expand All @@ -25,7 +24,6 @@
from argilla_server.api.schemas.v1.workspaces import Workspaces
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.errors.future import NotUniqueError
from argilla_server.models import User
from argilla_server.security import auth

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from uuid import UUID

from fastapi import APIRouter, Depends, Security, status
from fastapi import APIRouter, Depends, Security
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Workspaces,
WorkspaceUserCreate,
)
from argilla_server.contexts import accounts, datasets
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.errors.future import NotFoundError, UnprocessableEntityError
from argilla_server.models import User, Workspace, WorkspaceUser
Expand Down
3 changes: 1 addition & 2 deletions argilla-server/src/argilla_server/api/policies/v1/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Awaitable, Callable, Optional
from uuid import UUID
from typing import Awaitable, Callable

from argilla_server.errors import ForbiddenOperationError
from argilla_server.models import User
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.api.schemas.v1.fields import FieldName
from argilla_server.enums import OptionsOrder, QuestionType
from argilla_server.pydantic_v1 import BaseModel, Field, conlist, constr, root_validator, validator
from argilla_server.pydantic_v1 import BaseModel, Field, conlist, constr, root_validator
from argilla_server.settings import settings

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional
from typing import Optional

from argilla_server.pydantic_v1 import BaseModel, BaseSettings, Field

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Union
from uuid import UUID

from argilla_server.api.schemas.v1.questions import QuestionName
Expand Down
5 changes: 0 additions & 5 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCr
return RecordsBulk(items=records)

async def _upsert_records_relationships(self, records: List[Record], records_create: List[RecordCreate]) -> None:

records_and_suggestions = list(zip(records, [r.suggestions for r in records_create]))
records_and_responses = list(zip(records, [r.responses for r in records_create]))
records_and_vectors = list(zip(records, [r.vectors for r in records_create]))
Expand All @@ -88,7 +87,6 @@ async def _upsert_records_relationships(self, records: List[Record], records_cre
async def _upsert_records_suggestions(
self, records_and_suggestions: List[Tuple[Record, List[SuggestionCreate]]]
) -> List[Suggestion]:

upsert_many_suggestions = []
for idx, (record, suggestions) in enumerate(records_and_suggestions):
try:
Expand Down Expand Up @@ -121,7 +119,6 @@ async def _upsert_records_suggestions(
async def _upsert_records_responses(
self, records_and_responses: List[Tuple[Record, List[UserResponseCreate]]]
) -> List[Response]:

user_ids = [response.user_id for _, responses in records_and_responses for response in responses or []]
users_by_id = await fetch_users_by_ids_as_dict(self._db, user_ids)

Expand Down Expand Up @@ -152,7 +149,6 @@ async def _upsert_records_responses(
async def _upsert_records_vectors(
self, records_and_vectors: List[Tuple[Record, Dict[str, List[float]]]]
) -> List[Vector]:

upsert_many_vectors = []
for idx, (record, vectors) in enumerate(records_and_vectors):
try:
Expand Down Expand Up @@ -225,7 +221,6 @@ async def _fetch_existing_dataset_records(
dataset: Dataset,
records_upsert: List[RecordUpsert],
) -> Dict[Union[str, UUID], Record]:

records_by_external_id = await fetch_records_by_external_ids_as_dict(
self._db, dataset, [r.external_id for r in records_upsert]
)
Expand Down
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/cli/rich.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any
from typing import Any

from rich.console import Console, RenderableType
from rich.panel import Panel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def _reindex_dataset(db: AsyncSession, search_engine: SearchEngine, progre


async def _reindex_datasets(db: AsyncSession, search_engine: SearchEngine, progress: Progress) -> None:
task = progress.add_task(f"reindexing feedback datasets...", total=await Reindexer.count_datasets(db))
task = progress.add_task("reindexing feedback datasets...", total=await Reindexer.count_datasets(db))

async for dataset in Reindexer.reindex_datasets(db, search_engine):
await _reindex_dataset_records(db, search_engine, progress, dataset)
Expand Down
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from passlib.context import CryptContext
from sqlalchemy import exists, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload
from sqlalchemy.orm import selectinload

from argilla_server.contexts import datasets
from argilla_server.enums import UserRole
Expand Down
4 changes: 2 additions & 2 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@
VectorSettingsCreate,
)
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
from argilla_server.contexts import accounts, questions
from argilla_server.contexts import accounts
from argilla_server.enums import DatasetStatus, RecordInclude, UserRole
from argilla_server.errors.future import NotFoundError, NotUniqueError, UnprocessableEntityError
from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError
from argilla_server.models import (
Dataset,
Field,
Expand Down
6 changes: 1 addition & 5 deletions argilla-server/src/argilla_server/contexts/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

import argilla_server.errors.future as errors
from argilla_server.api.schemas.v1.questions import (
QuestionCreate,
QuestionUpdate,
)
from argilla_server.models import Dataset, Question, User
from argilla_server.models import Dataset, Question
from argilla_server.validators.questions import (
QuestionCreateValidator,
QuestionDeleteValidator,
Expand Down
8 changes: 3 additions & 5 deletions argilla-server/src/argilla_server/contexts/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Iterable, Sequence
from typing import Dict, Sequence
from uuid import UUID

from sqlalchemy import select, sql
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from argilla_server.models import Dataset, Record, Suggestion
from argilla_server.models import Dataset, Record


async def list_dataset_records_by_ids(
db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID]
) -> Sequence[Record]:

query = select(Record).filter(Record.id.in_(record_ids), Record.dataset_id == dataset_id)
return (await db.execute(query)).unique().scalars().all()


async def list_dataset_records_by_external_ids(
db: AsyncSession, dataset_id: UUID, external_ids: Sequence[str]
) -> Sequence[Record]:

query = (
select(Record)
.filter(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id)
Expand Down
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/contexts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Union
from typing import Union

from argilla_server.api.schemas.v1.settings import ArgillaSettings, HuggingfaceSettings, Settings
from argilla_server.settings import settings
Expand Down
1 change: 0 additions & 1 deletion argilla-server/src/argilla_server/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Common helper functions
"""
import logging
from typing import Any, Dict, List, Optional

_LOGGER = logging.getLogger("argilla_server")

Expand Down
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,7 @@ async def __value_count_aggregation(self, index_name: str, field_name: str, quer

async def __stats_aggregation(self, index_name: str, field_name: str, query: dict) -> dict:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-stats-aggregation.html
aggregation_name = f"numeric_stats"
aggregation_name = "numeric_stats"

stats_agg = {aggregation_name: {"stats": {"field": field_name}}}

Expand Down
Loading

0 comments on commit 1b1a21e

Please sign in to comment.