From 277468630a34da7f8ed8c9695b9aa31a9db4045c Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 24 Jul 2024 17:56:48 +0200 Subject: [PATCH 01/13] [ENHANCEMENT] improve argilla deployments when running on spaces (#5255) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This is the main feature branch to improve the Argilla deployments when running on spaces. The main features that include this PR are: - The OAUTH client configuration will be automatically read from the HF environment (the `hf_oauth` flag must be set to `true`. See the [docs](https://huggingface.co/docs/hub/en/spaces-oauth#create-an-oauth-app) for more details) - If users want to create a specific owner without using the OAuth flow, `USERNAME` and `PASSWORD` env variables are available for that purpose (By default, `USERNAME` will be filled with the `SPACE_AUTHOR_NAME` value) - The Argilla template will be set to running Argilla with OAuth enabled by default (users don't need to fine-tune anything after cloning the template). - Workspaces defined in `.oauth.yaml` will be created automatically (By default, the template will provide the `argilla` workspace) - When the space author is an HF username, that user will be the argilla `owner` for the Argilla server. - When the space author is an HF organization, user roles will be computed from roles in the ORG -> (for now admin roles in the HF org will be mapped as `owner` roles in Argilla. The rest will be mapped as the `annotator` role). ## Tasks - [X] Update image README.md - [x] [Simplify environment variables](https://github.com/argilla-io/argilla/pull/5256) - [x] [Create a single user providing USERNAME and PASSWORD env variables (temporal solution)](https://github.com/argilla-io/argilla/pull/5256) - [x] [Reading injected `OAUTH_CLIENT_ID` and `OAUTH_CLIENT_SECRET` to avoid the OAuth app configuration step.](https://github.com/argilla-io/argilla/pull/5262) - [x] [Create the user with roles depending on the space privileges (user space VS org space)](https://github.com/argilla-io/argilla/pull/5299) - [x] [Create workspaces configured in `.oauth.yaml::allowed_workspaces` ](https://github.com/argilla-io/argilla/pull/5287) - [ ] Update docs - [x] [Rename quickstart image to `argilla-hf-spaces`](https://github.com/argilla-io/argilla/pull/5307) **Type of change** - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: José Francisco Calvo --- .../argilla-server.build-docker-images.yml | 29 ++-- argilla-server/CHANGELOG.md | 2 + .../Dockerfile | 21 +-- .../Procfile | 0 .../docker/argilla-hf-spaces/README.md | 15 ++ .../config/elasticsearch.yml | 0 .../requirements.txt | 0 .../scripts/start.sh} | 0 .../scripts/start_argilla_server.sh | 35 +++++ argilla-server/docker/quickstart/README.md | 134 ------------------ .../scripts/start_argilla_server.sh | 40 ------ argilla-server/pyproject.toml | 4 +- argilla-server/src/argilla_server/_app.py | 42 ++++-- .../argilla_server/api/handlers/v1/oauth2.py | 86 +++++------ .../argilla_server/api/schemas/v1/settings.py | 19 +-- .../src/argilla_server/contexts/settings.py | 5 +- .../argilla_server/integrations/__init__.py | 14 ++ .../integrations/huggingface/__init__.py | 14 ++ .../integrations/huggingface/spaces.py | 34 +++++ .../authentication/oauth2/__init__.py | 4 +- .../authentication/oauth2/auth_backend.py | 2 +- .../oauth2/providers/__init__.py | 42 ++++++ .../_base.py} | 7 +- .../oauth2/providers/_github.py | 28 ++++ .../oauth2/providers/_huggingface.py | 87 ++++++++++++ .../authentication/oauth2/settings.py | 10 +- .../oauth2/supported_providers.py | 51 ------- .../security/authentication/userinfo.py | 16 +++ .../src/argilla_server/utils/_telemetry.py | 6 +- .../handlers/v1/settings/test_get_settings.py | 2 +- .../tests/unit/api/handlers/v1/test_oauth2.py | 69 +++++++-- .../oauth2/providers/__init__.py | 14 ++ .../providers/test_huggingface_provider.py | 88 ++++++++++++ .../authentication/oauth2/test_settings.py | 4 +- .../security/authentication/test_userinfo.py | 53 +++++++ argilla-server/tests/unit/test_app.py | 64 ++++++++- 36 files changed, 677 insertions(+), 364 deletions(-) rename argilla-server/docker/{quickstart => argilla-hf-spaces}/Dockerfile (82%) rename argilla-server/docker/{quickstart => argilla-hf-spaces}/Procfile (100%) create mode 100644 argilla-server/docker/argilla-hf-spaces/README.md rename argilla-server/docker/{quickstart => argilla-hf-spaces}/config/elasticsearch.yml (100%) rename argilla-server/docker/{quickstart => argilla-hf-spaces}/requirements.txt (100%) rename argilla-server/docker/{quickstart/scripts/start_quickstart_argilla.sh => argilla-hf-spaces/scripts/start.sh} (100%) create mode 100755 argilla-server/docker/argilla-hf-spaces/scripts/start_argilla_server.sh delete mode 100644 argilla-server/docker/quickstart/README.md delete mode 100644 argilla-server/docker/quickstart/scripts/start_argilla_server.sh create mode 100644 argilla-server/src/argilla_server/integrations/__init__.py create mode 100644 argilla-server/src/argilla_server/integrations/huggingface/__init__.py create mode 100644 argilla-server/src/argilla_server/integrations/huggingface/spaces.py create mode 100644 argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py rename argilla-server/src/argilla_server/security/authentication/oauth2/{client_provider.py => providers/_base.py} (97%) create mode 100644 argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py create mode 100644 argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py delete mode 100644 argilla-server/src/argilla_server/security/authentication/oauth2/supported_providers.py create mode 100644 argilla-server/tests/unit/security/authentication/oauth2/providers/__init__.py create mode 100644 argilla-server/tests/unit/security/authentication/oauth2/providers/test_huggingface_provider.py create mode 100644 argilla-server/tests/unit/security/authentication/test_userinfo.py diff --git a/.github/workflows/argilla-server.build-docker-images.yml b/.github/workflows/argilla-server.build-docker-images.yml index 61913a6e5a..4201f52edd 100644 --- a/.github/workflows/argilla-server.build-docker-images.yml +++ b/.github/workflows/argilla-server.build-docker-images.yml @@ -49,14 +49,14 @@ jobs: echo "PLATFORMS=linux/amd64,linux/arm64" >> $GITHUB_ENV echo "IMAGE_TAG=v$PACKAGE_VERSION" >> $GITHUB_ENV echo "SERVER_DOCKER_IMAGE=argilla/argilla-server" >> $GITHUB_ENV - echo "QUICKSTART_DOCKER_IMAGE=argilla/argilla-quickstart" >> $GITHUB_ENV + echo "HF_SPACES_DOCKER_IMAGE=argilla/argilla-hf-spaces" >> $GITHUB_ENV echo "DOCKER_USERNAME=$DOCKER_USERNAME" >> $GITHUB_ENV echo "DOCKER_PASSWORD=$DOCKER_PASSWORD" >> $GITHUB_ENV else echo "PLATFORMS=linux/amd64" >> $GITHUB_ENV echo "IMAGE_TAG=$DOCKER_IMAGE_TAG" >> $GITHUB_ENV echo "SERVER_DOCKER_IMAGE=argilladev/argilla-server" >> $GITHUB_ENV - echo "QUICKSTART_DOCKER_IMAGE=argilladev/argilla-quickstart" >> $GITHUB_ENV + echo "HF_SPACES_DOCKER_IMAGE=argilladev/argilla-hf-spaces" >> $GITHUB_ENV echo "DOCKER_USERNAME=$DOCKER_USERNAME_DEV" >> $GITHUB_ENV echo "DOCKER_PASSWORD=$DOCKER_PASSWORD_DEV" >> $GITHUB_ENV fi @@ -92,7 +92,6 @@ jobs: uses: docker/build-push-action@v5 with: context: argilla-server/docker/server - file: argilla-server/docker/server/Dockerfile platforms: ${{ env.PLATFORMS }} tags: ${{ env.SERVER_DOCKER_IMAGE }}:${{ env.IMAGE_TAG }} labels: ${{ steps.meta.outputs.labels }} @@ -103,35 +102,33 @@ jobs: uses: docker/build-push-action@v5 with: context: argilla-server/docker/server - file: argilla-server/docker/server/Dockerfile platforms: ${{ env.PLATFORMS }} tags: ${{ env.SERVER_DOCKER_IMAGE }}:latest labels: ${{ steps.meta.outputs.labels }} push: true - - name: Build and push `argilla-quickstart` image + - name: Build and push `argilla-hf-spaces` image uses: docker/build-push-action@v5 with: - context: argilla-server/docker/quickstart - file: argilla-server/docker/quickstart/Dockerfile + context: argilla-server/docker/argilla-hf-spaces platforms: ${{ env.PLATFORMS }} - tags: ${{ env.QUICKSTART_DOCKER_IMAGE }}:${{ env.IMAGE_TAG }} + tags: ${{ env.HF_SPACES_DOCKER_IMAGE }}:${{ env.IMAGE_TAG }} labels: ${{ steps.meta.outputs.labels }} build-args: | ARGILLA_SERVER_IMAGE=${{ env.SERVER_DOCKER_IMAGE }} ARGILLA_VERSION=${{ env.IMAGE_TAG }} push: true - - name: Push latest `argilla-quickstart` image + - name: Push latest `argilla-hf-spaces` image if: ${{ inputs.is_release && inputs.publish_latest }} uses: docker/build-push-action@v5 with: - context: argilla-server/docker/quickstart - file: argilla-server/docker/quickstart/Dockerfile + context: argilla-server/docker/argilla-hf-spaces platforms: ${{ env.PLATFORMS }} - tags: ${{ env.QUICKSTART_DOCKER_IMAGE }}:latest + tags: ${{ env.HF_SPACES_DOCKER_IMAGE }}:latest labels: ${{ steps.meta.outputs.labels }} build-args: | + ARGILLA_SERVER_IMAGE=${{ env.SERVER_DOCKER_IMAGE }} ARGILLA_VERSION=${{ env.IMAGE_TAG }} push: true @@ -141,14 +138,14 @@ jobs: with: username: ${{ env.DOCKER_USERNAME }} password: ${{ env.DOCKER_PASSWORD }} - repository: argilla/argilla-server + repository: $${{ env.SERVER_DOCKER_IMAGE }} readme-filepath: argilla-server/README.md - - name: Docker Hub Description for `argilla-quickstart` + - name: Docker Hub Description for `argilla-hf-spaces` uses: peter-evans/dockerhub-description@v4 if: ${{ inputs.is_release && inputs.publish_latest }} with: username: ${{ secrets.AR_DOCKER_USERNAME }} password: ${{ secrets.AR_DOCKER_PASSWORD }} - repository: argilla/argilla-quickstart - readme-filepath: argilla-server/docker/quickstart/README.md + repository: $${{ env.HF_SPACES_DOCKER_IMAGE }} + readme-filepath: argilla-server/docker/argilla-hf-spaces/README.md diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 8b99c0c776..a1c6037986 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -29,6 +29,7 @@ These are the section headers that we use: - Added new `ARGILLA_DATABASE_POSTGRESQL_MAX_OVERFLOW` environment variable allowing to set the number of connections that can be opened above and beyond the `ARGILLA_DATABASE_POSTGRESQL_POOL_SIZE` setting. ([#5220](https://github.com/argilla-io/argilla/pull/5220)) - Added new `Server-Timing` header to all responses with the total time in milliseconds the server took to generate the response. ([#5239](https://github.com/argilla-io/argilla/pull/5239)) - Added `REINDEX_DATASETS` environment variable to Argilla server Docker image. ([#5268](https://github.com/argilla-io/argilla/pull/5268)) +- Added `argilla-hf-spaces` docker image for running Argilla server in HF spaces. ([#5307](https://github.com/argilla-io/argilla/pull/5307)) ### Changed @@ -51,6 +52,7 @@ These are the section headers that we use: - [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) - [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) - [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166)) +- Removed argilla quickstart docker image (Older versions are still available). ([#5307](https://github.com/argilla-io/argilla/pull/5307)) ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) diff --git a/argilla-server/docker/quickstart/Dockerfile b/argilla-server/docker/argilla-hf-spaces/Dockerfile similarity index 82% rename from argilla-server/docker/quickstart/Dockerfile rename to argilla-server/docker/argilla-hf-spaces/Dockerfile index 78f4d4119e..4c881a3c96 100644 --- a/argilla-server/docker/quickstart/Dockerfile +++ b/argilla-server/docker/argilla-hf-spaces/Dockerfile @@ -6,7 +6,7 @@ FROM ${ARGILLA_SERVER_IMAGE}:${ARGILLA_VERSION} USER root # Copy Argilla distribution files -COPY scripts/start_quickstart_argilla.sh /home/argilla +COPY scripts/start.sh /home/argilla COPY scripts/start_argilla_server.sh /home/argilla COPY Procfile /home/argilla COPY requirements.txt /packages/requirements.txt @@ -31,7 +31,7 @@ RUN \ chown argilla:argilla /etc/default/elasticsearch && \ # Install quickstart image dependencies pip install -r /packages/requirements.txt && \ - chmod +x /home/argilla/start_quickstart_argilla.sh && \ + chmod +x /home/argilla/start.sh && \ chmod +x /home/argilla/start_argilla_server.sh && \ # Give ownership of the data directory to the argilla user chown -R argilla:argilla /data && \ @@ -52,20 +52,9 @@ USER argilla ENV ELASTIC_CONTAINER=true ENV ES_JAVA_OPTS="-Xms1g -Xmx1g" -ENV OWNER_USERNAME=owner -ENV OWNER_PASSWORD=12345678 -ENV OWNER_API_KEY=owner.apikey - -ENV ADMIN_USERNAME=admin -ENV ADMIN_PASSWORD=12345678 -ENV ADMIN_API_KEY=admin.apikey - -ENV ANNOTATOR_USERNAME=argilla -ENV ANNOTATOR_PASSWORD=12345678 +ENV USERNAME="" +ENV PASSWORD="" ENV ARGILLA_HOME_PATH=/data/argilla -ENV ARGILLA_WORKSPACE=$ADMIN_USERNAME - -ENV UVICORN_PORT=6900 -CMD ["/bin/bash", "start_quickstart_argilla.sh"] +CMD ["/bin/bash", "start.sh"] diff --git a/argilla-server/docker/quickstart/Procfile b/argilla-server/docker/argilla-hf-spaces/Procfile similarity index 100% rename from argilla-server/docker/quickstart/Procfile rename to argilla-server/docker/argilla-hf-spaces/Procfile diff --git a/argilla-server/docker/argilla-hf-spaces/README.md b/argilla-server/docker/argilla-hf-spaces/README.md new file mode 100644 index 0000000000..521adc5f7a --- /dev/null +++ b/argilla-server/docker/argilla-hf-spaces/README.md @@ -0,0 +1,15 @@ +

+ Argilla +
+ Argilla +
+

+ +> This Docker image corresponds to the **Argilla Hugging Face Spaces deployment** and **can only be used to deploy Argilla inside the Hugging Face Hub**. For other type of deployments check the Argilla docs. + + +Argilla is a **collaboration tool for AI engineers and domain experts** that require **high-quality outputs, data ownership, and overall efficiency**. + +## Why use Argilla? + +Whether you are working on monitoring and improving complex **generative tasks** involving LLM pipelines with RAG, or you are working on a **predictive task** for things like AB-testing of span- and text-classification models. Our versatile platform helps you ensure **your data work pays off**. diff --git a/argilla-server/docker/quickstart/config/elasticsearch.yml b/argilla-server/docker/argilla-hf-spaces/config/elasticsearch.yml similarity index 100% rename from argilla-server/docker/quickstart/config/elasticsearch.yml rename to argilla-server/docker/argilla-hf-spaces/config/elasticsearch.yml diff --git a/argilla-server/docker/quickstart/requirements.txt b/argilla-server/docker/argilla-hf-spaces/requirements.txt similarity index 100% rename from argilla-server/docker/quickstart/requirements.txt rename to argilla-server/docker/argilla-hf-spaces/requirements.txt diff --git a/argilla-server/docker/quickstart/scripts/start_quickstart_argilla.sh b/argilla-server/docker/argilla-hf-spaces/scripts/start.sh similarity index 100% rename from argilla-server/docker/quickstart/scripts/start_quickstart_argilla.sh rename to argilla-server/docker/argilla-hf-spaces/scripts/start.sh diff --git a/argilla-server/docker/argilla-hf-spaces/scripts/start_argilla_server.sh b/argilla-server/docker/argilla-hf-spaces/scripts/start_argilla_server.sh new file mode 100755 index 0000000000..93cd012346 --- /dev/null +++ b/argilla-server/docker/argilla-hf-spaces/scripts/start_argilla_server.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +set -e + +# Preset oauth env vars based on injected space variables. +# See https://huggingface.co/docs/hub/en/spaces-oauth#create-an-oauth-app +export OAUTH2_HUGGINGFACE_CLIENT_ID=$OAUTH_CLIENT_ID +export OAUTH2_HUGGINGFACE_CLIENT_SECRET=$OAUTH_CLIENT_SECRET +export OAUTH2_HUGGINGFACE_SCOPE=$OAUTH_SCOPES + +echo "Running database migrations" +python -m argilla_server database migrate + +# Set the space author name as username if no provided. +# See https://huggingface.co/docs/hub/en/spaces-overview#helper-environment-variables for more details +USERNAME="${USERNAME:-$SPACE_AUTHOR_NAME}" + +if [ -n "$USERNAME" ] && [ -n "$PASSWORD" ]; then + echo "Creating owner user with username ${USERNAME}" + python -m argilla_server database users create \ + --first-name "$USERNAME" \ + --username "$USERNAME" \ + --password "$PASSWORD" \ + --role owner +else + echo "No username and password was provided. Skipping user creation" +fi + +# Forcing reindex on restart since elasticsearch data could be allocated in a non-persistent volume +echo "Reindexing existing datasets" +python -m argilla_server search-engine reindex + +# Start Argilla +echo "Starting Argilla" +python -m uvicorn argilla_server:app --host "0.0.0.0" diff --git a/argilla-server/docker/quickstart/README.md b/argilla-server/docker/quickstart/README.md deleted file mode 100644 index 3516605be9..0000000000 --- a/argilla-server/docker/quickstart/README.md +++ /dev/null @@ -1,134 +0,0 @@ -

- Argilla -
- Argilla -
-

-

- -CI - - -Codecov - -CI - -

- -

Open-source framework for data-centric NLP

-

Data Labeling, curation, and Inference Store

-

Designed for MLOps & Feedback Loops

- - -> 🆕 🔥 Play with Argilla UI with this [live-demo](https://argilla-live-demo.hf.space) powered by Hugging Face Spaces ( -> login:`argilla`, password:`12345678`) - -> 🆕 🔥 Since `1.2.0` Argilla supports vector search for finding the most similar records to a given one. This feature -> uses vector or semantic search combined with more traditional search (keyword and filter based). Learn more on -> this [deep-dive guide](https://docs.argilla.io/en/latest/guides/features/semantic-search.html) - - -![imagen](https://user-images.githubusercontent.com/1107111/204772677-facee627-9b3b-43ca-8533-bbc9b4e2d0aa.png) - - - - -

- - - - - - - - - -

-
- -

-

-Documentation | -Key Features | -Quickstart | -Principles | -Migration from Rubrix | -FAQ -

-

- -## Key Features - -### Advanced NLP labeling - -- Programmatic labeling - using [weak supervision](https://docs.argilla.io/en/latest/guides/techniques/weak_supervision.html). Built-in label - models (Snorkel, Flyingsquid) -- [Bulk-labeling](https://docs.argilla.io/en/latest/reference/webapp/features.html#bulk-annotate) - and [search-driven annotation](https://docs.argilla.io/en/latest/guides/features/queries.html) -- Iterate on training data with - any [pre-trained model](https://docs.argilla.io/en/latest/tutorials/libraries/huggingface.html) - or [library](https://docs.argilla.io/en/latest/tutorials/libraries/libraries.html) -- Efficiently review and refine annotations in the UI and with Python -- Use Argilla built-in metrics and methods - for [finding label and data errors (e.g., cleanlab)](https://docs.argilla.io/en/latest/tutorials/notebooks/monitoring-textclassification-cleanlab-explainability.html) -- Simple integration - with [active learning workflows](https://docs.argilla.io/en/latest/tutorials/techniques/active_learning.html) - -### Monitoring - -- Close the gap between production data and data collection activities -- [Auto-monitoring](https://docs.argilla.io/en/latest/guides/steps/3_deploying.html) - for [major NLP libraries and pipelines](https://docs.argilla.io/en/latest/tutorials/libraries/libraries.html) (spaCy, - Hugging Face, FlairNLP) -- [ASGI middleware](https://docs.argilla.io/en/latest/tutorials/notebooks/deploying-texttokenclassification-fastapi.html) - for HTTP endpoints -- Argilla Metrics to understand data and model - issues, [like entity consistency for NER models](https://docs.argilla.io/en/latest/guides/steps/4_monitoring.html) -- Integrated with Kibana for custom dashboards - -### Team workspaces - -- Bring different users and roles into the NLP data and model lifecycles -- Organize data collection, review and monitoring into - different [workspaces](https://docs.argilla.io/en/latest/getting_started/installation/user_management.html#workspace) -- Manage workspace access for different users - -## Quickstart - -Argilla is composed of a Python Server with Elasticsearch as the database layer, and a Python Client to create and -manage datasets. - -To get started you just need to run the docker image with following command: - -``` bash - docker run -d --name quickstart -p 6900:6900 argilla/argilla-quickstart:latest -``` - -This will run the latest quickstart docker image with 3 users `owner`, `admin`, and `argilla`. The password for these users is `12345678`. You can also configure these [environment variables](#environment-variables) as per you needs. - -### Environment Variables - -- `OWNER_USERNAME`: The owner username to log in Argilla. The default owner username is `owner`. By setting up - a custom username you can use your own username to login into the app. -- `OWNER_PASSWORD`: This sets a custom password for login into the app with the `owner` username. The default - password is `12345678`. By setting up a custom password you can use your own password to login into the app. -- `OWNER_API_KEY`: Argilla provides a Python library to interact with the app (read, write, and update data, log model - predictions, etc.). If you don't set this variable, the library and your app will use the default API key - i.e. `owner.apikey`. If you want to secure your app for reading and writing data, we recommend you to set up this - variable. The API key you choose can be any string of your choice and you can check an online generator if you like. -- `ADMIN_USERNAME`: The admin username to log in Argilla. The default admin username is `admin`. By setting up - a custom username you can use your own username to login into the app. -- `ADMIN_PASSWORD`: This sets a custom password for login into the app with the `argilla` username. The default - password is `12345678`. By setting up a custom password you can use your own password to login into the app. -- `ADMIN_API_KEY`: Argilla provides a Python library to interact with the app (read, write, and update data, log model - predictions, etc.). If you don't set this variable, the library and your app will use the default API key - i.e. `admin.apikey`. If you want to secure your app for reading and writing data, we recommend you to set up this - variable. The API key you choose can be any string of your choice and you can check an online generator if you like. -- `ANNOTATOR_USERNAME`: The annotator username to login in Argilla. The default annotator username is `argilla`. By setting - up a custom username you can use your own username to login into the app. -- `ANNOTATOR_PASSWORD`: This sets a custom password for login into the app with the `argilla` username. The default password - is `12345678`. By setting up a custom password you can use your own password to login into the app. -- `ARGILLA_WORKSPACE`: The name of a workspace that will be created and used by default for admin and annotator users. The default value will be the one defined by `ADMIN_USERNAME` environment variable. diff --git a/argilla-server/docker/quickstart/scripts/start_argilla_server.sh b/argilla-server/docker/quickstart/scripts/start_argilla_server.sh deleted file mode 100644 index 23a3a24336..0000000000 --- a/argilla-server/docker/quickstart/scripts/start_argilla_server.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env bash - -set -e - -echo "Running database migrations" -python -m argilla_server database migrate - -echo "Creating owner user" -python -m argilla_server database users create \ - --first-name "Owner" \ - --username "$OWNER_USERNAME" \ - --password "$OWNER_PASSWORD" \ - --api-key "$OWNER_API_KEY" \ - --role owner \ - --workspace "$ARGILLA_WORKSPACE" - -echo "Creating admin user" -python -m argilla_server database users create \ - --first-name "Admin" \ - --username "$ADMIN_USERNAME" \ - --password "$ADMIN_PASSWORD" \ - --api-key "$ADMIN_API_KEY" \ - --role admin \ - --workspace "$ARGILLA_WORKSPACE" - -echo "Creating annotator user" -python -m argilla_server database users create \ - --first-name "Annotator" \ - --username "$ANNOTATOR_USERNAME" \ - --password "$ANNOTATOR_PASSWORD" \ - --role annotator \ - --workspace "$ARGILLA_WORKSPACE" - -# Forcing reindex on restart since elasticsearch data could be allocated in a non-persistent volume -echo "Reindexing existing datasets" -python -m argilla_server search-engine reindex - -# Start Argilla -echo "Starting Argilla" -python -m uvicorn argilla_server:app --host "0.0.0.0" diff --git a/argilla-server/pyproject.toml b/argilla-server/pyproject.toml index 364d6a3dcf..ad6bb7e86c 100644 --- a/argilla-server/pyproject.toml +++ b/argilla-server/pyproject.toml @@ -173,5 +173,5 @@ server-dev.composite = [ ] test = { cmd = "pytest", env_file = ".env.test" } -build-server-image = { shell = "cp -R dist docker/server && docker build -t argilla/argilla-server:local docker/server" } -build-quickstart-image = { shell = "docker build --build-arg ARGILLA_VERSION=local -t argilla/argilla-quickstart:local docker/quickstart" } +docker-build-argilla-server = { shell = "pdm build && cp -R dist docker/server && docker build -t argilla/argilla-server:local docker/server" } +docker-build-argilla-hf-spaces = { shell = "pdm run docker-build-argilla-server && docker build --build-arg ARGILLA_VERSION=local -t argilla/argilla-hf-spaces:local docker/argilla-hf-spaces" } diff --git a/argilla-server/src/argilla_server/_app.py b/argilla-server/src/argilla_server/_app.py index c30b4007d2..5285d59b29 100644 --- a/argilla-server/src/argilla_server/_app.py +++ b/argilla-server/src/argilla_server/_app.py @@ -36,7 +36,7 @@ from argilla_server.contexts import accounts from argilla_server.database import get_async_db from argilla_server.logging import configure_logging -from argilla_server.models import User +from argilla_server.models import User, Workspace from argilla_server.search_engine import get_search_engine from argilla_server.settings import settings from argilla_server.static_rewrite import RewriteStaticFiles @@ -179,21 +179,35 @@ def show_telemetry_warning(): _LOGGER.warning(message) -async def configure_database(): - async def check_default_user(db: AsyncSession): - def _user_has_default_credentials(user: User): - return user.api_key == DEFAULT_API_KEY or accounts.verify_password(DEFAULT_PASSWORD, user.password_hash) - - default_user = await accounts.get_user_by_username(db, DEFAULT_USERNAME) - if default_user and _user_has_default_credentials(default_user): - _LOGGER.warning( - f"User {DEFAULT_USERNAME!r} with default credentials has been found in the database. " - "If you are using argilla in a production environment this can be a serious security problem. " - f"We recommend that you create a new admin user and then delete the default {DEFAULT_USERNAME!r} one." - ) +async def _create_oauth_allowed_workspaces(db: AsyncSession): + from argilla_server.security.settings import settings as security_settings + + if not security_settings.oauth.enabled: + return + + for allowed_workspace in security_settings.oauth.allowed_workspaces: + if await Workspace.get_by(db, name=allowed_workspace.name) is None: + _LOGGER.info(f"Creating workspace with name {allowed_workspace.name!r}") + await accounts.create_workspace(db, {"name": allowed_workspace.name}) + +async def _show_default_user_warning(db: AsyncSession): + def _user_has_default_credentials(user: User): + return user.api_key == DEFAULT_API_KEY or accounts.verify_password(DEFAULT_PASSWORD, user.password_hash) + + default_user = await User.get_by(db, username=DEFAULT_USERNAME) + if default_user and _user_has_default_credentials(default_user): + _LOGGER.warning( + f"User {DEFAULT_USERNAME!r} with default credentials has been found in the database. " + "If you are using argilla in a production environment this can be a serious security problem. " + f"We recommend that you create a new admin user and then delete the default {DEFAULT_USERNAME!r} one." + ) + + +async def configure_database(): async with contextlib.asynccontextmanager(get_async_db)() as db: - await check_default_user(db) + await _show_default_user_warning(db) + await _create_oauth_allowed_workspaces(db) async def configure_search_engine(): diff --git a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py index 96f5912492..b1f61ef09c 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py @@ -11,18 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Path from fastapi.responses import RedirectResponse from sqlalchemy.ext.asyncio import AsyncSession from argilla_server import telemetry from argilla_server.api.schemas.v1.oauth2 import Provider, Providers, Token +from argilla_server.api.schemas.v1.users import UserCreate from argilla_server.contexts import accounts from argilla_server.database import get_async_db from argilla_server.enums import UserRole -from argilla_server.errors.future import AuthenticationError +from argilla_server.errors.future import AuthenticationError, NotFoundError from argilla_server.models import User +from argilla_server.pydantic_v1 import Field, ValidationError from argilla_server.security.authentication.jwt import JWT from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider from argilla_server.security.authentication.userinfo import UserInfo @@ -31,11 +34,26 @@ router = APIRouter(prefix="/oauth2", tags=["Authentication"]) -_USER_ROLE_ON_CREATION = UserRole.annotator +class UserOAuthCreate(UserCreate): + """This schema is used to validate the creation of a new user by using the oauth userinfo""" + + username: str = Field(min_length=1) + role: Optional[UserRole] + password: Optional[str] = None + + +def get_provider_by_name_or_raise(provider: str = Path()) -> OAuth2ClientProvider: + if not settings.oauth.enabled: + raise NotFoundError(message="OAuth2 is not enabled") + + if provider in settings.oauth.providers: + return settings.oauth.providers[provider] + + raise NotFoundError(message=f"OAuth Provider '{provider}' not found") @router.get("/providers", response_model=Providers) -def list_providers(_request: Request) -> Providers: +def list_providers() -> Providers: if not settings.oauth.enabled: return Providers(items=[]) @@ -43,61 +61,45 @@ def list_providers(_request: Request) -> Providers: @router.get("/providers/{provider}/authentication") -def get_authentication(request: Request, provider: str) -> RedirectResponse: - _check_oauth_enabled_or_raise() - - provider = _get_provider_by_name_or_raise(provider) +def get_authentication( + request: Request, + provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise), +) -> RedirectResponse: return provider.authorization_redirect(request) @router.get("/providers/{provider}/access-token", response_model=Token) async def get_access_token( request: Request, - provider: str, + provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise), db: AsyncSession = Depends(get_async_db), ) -> Token: - _check_oauth_enabled_or_raise() - try: - provider = _get_provider_by_name_or_raise(provider) - user_info = UserInfo(await provider.get_user_data(request)) - - user_info.use_claims(provider.claims) - username = user_info.username - - user = await accounts.get_user_by_username(db, username) + user_info = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims) + user = await User.get_by(db, username=user_info.username) if user is None: + try: + user_create = UserOAuthCreate( + username=user_info.username, + first_name=user_info.first_name, + role=user_info.role, + ) + except ValidationError as ex: + raise AuthenticationError("Could not authenticate user") from ex + user = await accounts.create_user_with_random_password( db, - username=username, - first_name=user_info.name, - role=_USER_ROLE_ON_CREATION, + **user_create.dict(exclude_unset=True), workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces], ) telemetry.track_user_created(user, is_oauth=True) - elif not _is_user_created_by_oauth_provider(user): - # User should sign in using username/password workflow + + elif user.role != user_info.role: raise AuthenticationError("Could not authenticate user") return Token(access_token=JWT.create(user_info)) except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=400, detail=str(e)) from e + # TODO: Create exception handler for AuthenticationError except AuthenticationError as e: - raise HTTPException(status_code=401, detail=str(e)) - - -def _check_oauth_enabled_or_raise() -> None: - if not settings.oauth.enabled: - raise HTTPException(status_code=404, detail="OAuth2 is not enabled") - - -def _get_provider_by_name_or_raise(provider_name: str) -> OAuth2ClientProvider: - if provider_name not in settings.oauth.providers: - raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found") - return settings.oauth.providers[provider_name] - - -def _is_user_created_by_oauth_provider(user: User) -> bool: - # TODO: We must link the created user with the provider, and base this check on that. - # For now, we just validate the user role on creation. - return user.role == _USER_ROLE_ON_CREATION + raise HTTPException(status_code=401, detail=str(e)) from e diff --git a/argilla-server/src/argilla_server/api/schemas/v1/settings.py b/argilla-server/src/argilla_server/api/schemas/v1/settings.py index 62df486aef..a8d00562ca 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/settings.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/settings.py @@ -14,29 +14,14 @@ from typing import Optional -from argilla_server.pydantic_v1 import BaseModel, BaseSettings, Field +from argilla_server.integrations.huggingface.spaces import HuggingfaceSettings +from argilla_server.pydantic_v1 import BaseModel class ArgillaSettings(BaseModel): show_huggingface_space_persistent_storage_warning: Optional[bool] -class HuggingfaceSettings(BaseSettings): - space_id: str = Field(None, env="SPACE_ID") - space_title: str = Field(None, env="SPACE_TITLE") - space_subdomain: str = Field(None, env="SPACE_SUBDOMAIN") - space_host: str = Field(None, env="SPACE_HOST") - space_repo_name: str = Field(None, env="SPACE_REPO_NAME") - space_author_name: str = Field(None, env="SPACE_AUTHOR_NAME") - # NOTE: Hugging Face has a typo in their environment variable name, - # using PERSISTANT instead of PERSISTENT. We will use the correct spelling in our code. - space_persistent_storage_enabled: bool = Field(False, env="PERSISTANT_STORAGE_ENABLED") - - @property - def is_running_on_huggingface(self) -> bool: - return bool(self.space_id) - - class Settings(BaseModel): argilla: ArgillaSettings huggingface: Optional[HuggingfaceSettings] diff --git a/argilla-server/src/argilla_server/contexts/settings.py b/argilla-server/src/argilla_server/contexts/settings.py index d2f483a185..c7ca9e4dec 100644 --- a/argilla-server/src/argilla_server/contexts/settings.py +++ b/argilla-server/src/argilla_server/contexts/settings.py @@ -14,11 +14,10 @@ from typing import Union -from argilla_server.api.schemas.v1.settings import ArgillaSettings, HuggingfaceSettings, Settings +from argilla_server.api.schemas.v1.settings import ArgillaSettings, Settings +from argilla_server.integrations.huggingface.spaces import HUGGINGFACE_SETTINGS, HuggingfaceSettings from argilla_server.settings import settings -HUGGINGFACE_SETTINGS = HuggingfaceSettings() - def get_settings() -> Settings: return Settings( diff --git a/argilla-server/src/argilla_server/integrations/__init__.py b/argilla-server/src/argilla_server/integrations/__init__.py new file mode 100644 index 0000000000..4b6cecae7f --- /dev/null +++ b/argilla-server/src/argilla_server/integrations/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/argilla-server/src/argilla_server/integrations/huggingface/__init__.py b/argilla-server/src/argilla_server/integrations/huggingface/__init__.py new file mode 100644 index 0000000000..4b6cecae7f --- /dev/null +++ b/argilla-server/src/argilla_server/integrations/huggingface/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/argilla-server/src/argilla_server/integrations/huggingface/spaces.py b/argilla-server/src/argilla_server/integrations/huggingface/spaces.py new file mode 100644 index 0000000000..6b95187866 --- /dev/null +++ b/argilla-server/src/argilla_server/integrations/huggingface/spaces.py @@ -0,0 +1,34 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseSettings, Field + + +class HuggingfaceSettings(BaseSettings): + space_id: str = Field(None, env="SPACE_ID") + space_title: str = Field(None, env="SPACE_TITLE") + space_subdomain: str = Field(None, env="SPACE_SUBDOMAIN") + space_host: str = Field(None, env="SPACE_HOST") + space_repo_name: str = Field(None, env="SPACE_REPO_NAME") + space_author_name: str = Field(None, env="SPACE_AUTHOR_NAME") + # NOTE: Hugging Face has a typo in their environment variable name, + # using PERSISTANT instead of PERSISTENT. We will use the correct spelling in our code. + space_persistent_storage_enabled: bool = Field(False, env="PERSISTANT_STORAGE_ENABLED") + + @property + def is_running_on_huggingface(self) -> bool: + return bool(self.space_id) + + +HUGGINGFACE_SETTINGS = HuggingfaceSettings() diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py b/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py index 7c77829cdf..f8dcc52a18 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .client_provider import OAuth2ClientProvider # noqa -from .settings import OAuth2Settings # noqa +from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider # noqa +from argilla_server.security.authentication.oauth2.settings import OAuth2Settings # noqa __all__ = ["OAuth2Settings", "OAuth2ClientProvider"] diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py b/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py index a196fd244d..53774a2bf8 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py @@ -19,7 +19,7 @@ from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser from argilla_server.security.authentication.jwt import JWT -from argilla_server.security.authentication.oauth2.client_provider import OAuth2ClientProvider +from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider from argilla_server.security.authentication.userinfo import UserInfo diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py new file mode 100644 index 0000000000..0950bc30d2 --- /dev/null +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from argilla_server.errors.future import NotFoundError +from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider +from argilla_server.security.authentication.oauth2.providers._github import GitHubClientProvider +from argilla_server.security.authentication.oauth2.providers._huggingface import HuggingfaceClientProvider + +__all__ = [ + "OAuth2ClientProvider", + "GitHubClientProvider", + "HuggingfaceClientProvider", + "get_provider_by_name", +] + +_ALL_SUPPORTED_OAUTH2_PROVIDERS = { + GitHubClientProvider.name: GitHubClientProvider, + HuggingfaceClientProvider.name: HuggingfaceClientProvider, +} + + +def get_provider_by_name(name: str) -> Type["OAuth2ClientProvider"]: + """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" + if provider_class := _ALL_SUPPORTED_OAUTH2_PROVIDERS.get(name): + return provider_class + else: + raise NotFoundError( + f"Unsupported provider {name}. " f"Supported providers are {_ALL_SUPPORTED_OAUTH2_PROVIDERS.keys()}" + ) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/client_provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py similarity index 97% rename from argilla-server/src/argilla_server/security/authentication/oauth2/client_provider.py rename to argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py index e6ddc9a069..239a7c9fcf 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/client_provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py @@ -17,16 +17,17 @@ import random import re import string -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Any, ClassVar, Type, Optional, Union, List, Tuple from urllib.parse import urljoin import httpx -from fastapi import Request -from fastapi.responses import RedirectResponse from oauthlib.oauth2 import WebApplicationClient from social_core.backends.oauth import BaseOAuth2 from social_core.exceptions import AuthException + from social_core.strategy import BaseStrategy +from starlette.requests import Request +from starlette.responses import RedirectResponse from argilla_server.errors import future from argilla_server.security.authentication.claims import Claims diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py new file mode 100644 index 0000000000..ea4f3f1918 --- /dev/null +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py @@ -0,0 +1,28 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from social_core.backends.github import GithubOAuth2 + +from argilla_server.security.authentication.claims import Claims +from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider + + +class GitHubClientProvider(OAuth2ClientProvider): + claims = Claims( + picture="avatar_url", + identity=lambda user: f"{user.provider}:{user.id}", + username="login", + ) + backend_class = GithubOAuth2 + name = "github" diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py new file mode 100644 index 0000000000..ea7bb23a79 --- /dev/null +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py @@ -0,0 +1,87 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Union, Optional + +from social_core.backends.open_id_connect import OpenIdConnectAuth + +from argilla_server.enums import UserRole +from argilla_server.integrations.huggingface.spaces import HUGGINGFACE_SETTINGS +from argilla_server.logging import LoggingMixin +from argilla_server.security.authentication.claims import Claims +from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider + +_LOGGER = logging.getLogger("argilla.security.oauth2.providers.huggingface") + + +class HuggingfaceOpenId(OpenIdConnectAuth): + """Huggingface OpenID Connect authentication backend.""" + + name = "huggingface" + + OIDC_ENDPOINT = "https://huggingface.co" + AUTHORIZATION_URL = "https://huggingface.co/oauth/authorize" + ACCESS_TOKEN_URL = "https://huggingface.co/oauth/token" + + def oidc_endpoint(self) -> str: + return self.OIDC_ENDPOINT + + +_HF_PREFERRED_USERNAME = "preferred_username" + + +def _is_space_author(userinfo: dict, space_author: str) -> bool: + """Return True if the space author name is the userinfo username. Otherwise, False""" + return space_author and space_author == userinfo.get(_HF_PREFERRED_USERNAME) + + +def _find_org_from_userinfo(userinfo: dict, org_name: str) -> Optional[dict]: + """Find the organization by name from the userinfo""" + for org in userinfo.get("orgs") or []: + if org_name == org.get(_HF_PREFERRED_USERNAME): + return org + + +def _get_user_role_by_org(org: dict) -> Union[UserRole, None]: + """Return the computed UserRole from the role found in a organization (if any)""" + _ROLE_IN_ORG = "roleInOrg" + _ROLES_MAPPING = {"admin": UserRole.owner} + + org_role = None + if _ROLE_IN_ORG not in org: + _LOGGER.warning(f"Cannot find the user role info in org {org}. Review granted permissions") + else: + org_role = org[_ROLE_IN_ORG] + + return _ROLES_MAPPING.get(org_role) or UserRole.annotator + + +class HuggingfaceClientProvider(OAuth2ClientProvider, LoggingMixin): + """Specialized HuggingFace OAuth2 provider.""" + + @staticmethod + def parse_role_from_userinfo(userinfo: dict) -> Union[str, None]: + """Parse the Argilla user role from info provided as part of the user info""" + space_author_name = HUGGINGFACE_SETTINGS.space_author_name + + if _is_space_author(userinfo, space_author_name): + return UserRole.owner + elif org := _find_org_from_userinfo(userinfo, space_author_name): + return _get_user_role_by_org(org) + return UserRole.annotator + + claims = Claims(username=_HF_PREFERRED_USERNAME, role=parse_role_from_userinfo, first_name="name") + backend_class = HuggingfaceOpenId + name = "huggingface" diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py index e4627d0d95..e4771bad07 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py @@ -16,8 +16,7 @@ import yaml -from argilla_server.security.authentication.oauth2.client_provider import OAuth2ClientProvider -from argilla_server.security.authentication.oauth2.supported_providers import ALL_SUPPORTED_OAUTH2_PROVIDERS +from argilla_server.security.authentication.oauth2.providers import get_provider_by_name, OAuth2ClientProvider __all__ = ["OAuth2Settings"] @@ -92,12 +91,7 @@ def _build_providers(cls, settings: dict) -> List["OAuth2ClientProvider"]: for provider in settings.pop("providers", []): name = provider.pop("name") - - provider_class = ALL_SUPPORTED_OAUTH2_PROVIDERS.get(name) - if not provider_class: - raise ValueError( - f"Unsupported provider {name}. Supported providers are {ALL_SUPPORTED_OAUTH2_PROVIDERS.keys()}" - ) + provider_class = get_provider_by_name(name) providers.append(provider_class.from_dict(provider)) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/supported_providers.py b/argilla-server/src/argilla_server/security/authentication/oauth2/supported_providers.py deleted file mode 100644 index 65fc715b7e..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/supported_providers.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from social_core.backends.github import GithubOAuth2 -from social_core.backends.open_id_connect import OpenIdConnectAuth - -from argilla_server.security.authentication.claims import Claims -from argilla_server.security.authentication.oauth2.client_provider import OAuth2ClientProvider - - -class HuggingfaceOpenId(OpenIdConnectAuth): - """Huggingface OpenID Connect authentication backend.""" - - name = "huggingface" - - OIDC_ENDPOINT = "https://huggingface.co" - AUTHORIZATION_URL = "https://huggingface.co/oauth/authorize" - ACCESS_TOKEN_URL = "https://huggingface.co/oauth/token" - - def oidc_endpoint(self) -> str: - return self.OIDC_ENDPOINT - - -class GitHubClientProvider(OAuth2ClientProvider): - claims = Claims(picture="avatar_url", identity=lambda user: f"{user.provider}:{user.id}", username="login") - backend_class = GithubOAuth2 - name = "github" - - -class HuggingfaceClientProvider(OAuth2ClientProvider): - """Specialized HuggingFace OAuth2 provider.""" - - claims = Claims(username="preferred_username") - backend_class = HuggingfaceOpenId - name = "huggingface" - - -_providers = [GitHubClientProvider, HuggingfaceClientProvider] - -ALL_SUPPORTED_OAUTH2_PROVIDERS = {provider_class.name: provider_class for provider_class in _providers} diff --git a/argilla-server/src/argilla_server/security/authentication/userinfo.py b/argilla-server/src/argilla_server/security/authentication/userinfo.py index f61549f853..54220fc027 100644 --- a/argilla-server/src/argilla_server/security/authentication/userinfo.py +++ b/argilla-server/src/argilla_server/security/authentication/userinfo.py @@ -16,8 +16,11 @@ from starlette.authentication import BaseUser +from argilla_server.enums import UserRole from argilla_server.security.authentication.claims import Claims +_DEFAULT_USER_ROLE = UserRole.annotator + class UserInfo(BaseUser, dict): """User info from a provider.""" @@ -26,6 +29,19 @@ class UserInfo(BaseUser, dict): def is_authenticated(self) -> bool: return True + @property + def username(self) -> str: + return self["username"] + + @property + def first_name(self) -> str: + return self.get("first_name") or self.username + + @property + def role(self) -> UserRole: + role = self.get("role") or _DEFAULT_USER_ROLE + return UserRole(role) + def use_claims(self, claims: Optional[Claims]) -> "UserInfo": claims = claims or {} diff --git a/argilla-server/src/argilla_server/utils/_telemetry.py b/argilla-server/src/argilla_server/utils/_telemetry.py index 6d03cf5379..8015ba5319 100644 --- a/argilla-server/src/argilla_server/utils/_telemetry.py +++ b/argilla-server/src/argilla_server/utils/_telemetry.py @@ -14,6 +14,8 @@ import logging import os +from argilla_server.integrations.huggingface.spaces import HUGGINGFACE_SETTINGS + _LOGGER = logging.getLogger(__name__) @@ -33,9 +35,7 @@ def server_deployment_type() -> str: def is_running_on_huggingface_space() -> bool: """Returns True if the current process is running inside a Huggingface Space, False otherwise.""" - from argilla_server.api.schemas.v1.settings import HuggingfaceSettings - - return HuggingfaceSettings().is_running_on_huggingface + return HUGGINGFACE_SETTINGS.is_running_on_huggingface def is_running_on_docker_container() -> bool: diff --git a/argilla-server/tests/unit/api/handlers/v1/settings/test_get_settings.py b/argilla-server/tests/unit/api/handlers/v1/settings/test_get_settings.py index cd4a39e8da..e52bd3a5ea 100644 --- a/argilla-server/tests/unit/api/handlers/v1/settings/test_get_settings.py +++ b/argilla-server/tests/unit/api/handlers/v1/settings/test_get_settings.py @@ -16,9 +16,9 @@ from unittest import mock import pytest -from argilla_server.api.schemas.v1.settings import HuggingfaceSettings from argilla_server.contexts import settings as settings_context from argilla_server.contexts.settings import HUGGINGFACE_SETTINGS +from argilla_server.integrations.huggingface.spaces import HuggingfaceSettings from argilla_server.settings import settings as argilla_server_settings from httpx import AsyncClient diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py index 81b55ba152..34abf0e415 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py @@ -15,16 +15,16 @@ from unittest import mock import pytest -from argilla_server.enums import UserRole -from argilla_server.errors.future import AuthenticationError -from argilla_server.models import User -from argilla_server.security.authentication import JWT -from argilla_server.security.authentication.oauth2 import OAuth2Settings from httpcore import URL from httpx import AsyncClient from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from argilla_server.enums import UserRole +from argilla_server.errors.future import AuthenticationError +from argilla_server.models import User +from argilla_server.security.authentication import JWT +from argilla_server.security.authentication.oauth2 import OAuth2Settings from tests.factories import AdminFactory, AnnotatorFactory @@ -143,7 +143,7 @@ async def test_provider_huggingface_access_token( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.client_provider.OAuth2ClientProvider._fetch_user_data", + "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", return_value={"preferred_username": "username", "name": "name"}, ): response = await async_client.get( @@ -163,6 +163,57 @@ async def test_provider_huggingface_access_token( assert user is not None assert user.role == UserRole.annotator + async def test_provider_huggingface_access_token_with_missing_username( + self, + async_client: AsyncClient, + db: AsyncSession, + owner_auth_header: dict, + default_oauth_settings: OAuth2Settings, + ): + with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): + with mock.patch( + "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", + return_value={"name": "name"}, + ): + response = await async_client.get( + "/api/v1/oauth2/providers/huggingface/access-token", + params={"code": "code", "state": "valid"}, + headers=owner_auth_header, + cookies={"oauth2_state": "valid"}, + ) + + assert response.status_code == 401 + + async def test_provider_huggingface_access_token_with_missing_name( + self, + async_client: AsyncClient, + db: AsyncSession, + owner_auth_header: dict, + default_oauth_settings: OAuth2Settings, + ): + with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): + with mock.patch( + "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", + return_value={"preferred_username": "username"}, + ): + response = await async_client.get( + "/api/v1/oauth2/providers/huggingface/access-token", + params={"code": "code", "state": "valid"}, + headers=owner_auth_header, + cookies={"oauth2_state": "valid"}, + ) + + assert response.status_code == 200 + + json_response = response.json() + assert JWT.decode(json_response["access_token"])["username"] == "username" + assert json_response["token_type"] == "bearer" + + user = await db.scalar(select(User).filter_by(username="username")) + assert user is not None + assert user.role == UserRole.annotator + assert user.first_name == "username" + async def test_provider_access_token_with_oauth_disabled( self, async_client: AsyncClient, @@ -224,7 +275,7 @@ async def test_provider_access_token_with_authentication_error( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.client_provider.OAuth2ClientProvider._fetch_user_data", + "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", side_effect=AuthenticationError("error"), ): response = await async_client.get( @@ -247,7 +298,7 @@ async def test_provider_access_token_with_unauthorized_user( with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.client_provider.OAuth2ClientProvider._fetch_user_data", + "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", return_value={"preferred_username": admin.username, "name": admin.first_name}, ): response = await async_client.get( @@ -270,7 +321,7 @@ async def test_provider_access_token_with_same_username( with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.client_provider.OAuth2ClientProvider._fetch_user_data", + "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", return_value={"preferred_username": user.username, "name": user.first_name}, ): response = await async_client.get( diff --git a/argilla-server/tests/unit/security/authentication/oauth2/providers/__init__.py b/argilla-server/tests/unit/security/authentication/oauth2/providers/__init__.py new file mode 100644 index 0000000000..4b6cecae7f --- /dev/null +++ b/argilla-server/tests/unit/security/authentication/oauth2/providers/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/argilla-server/tests/unit/security/authentication/oauth2/providers/test_huggingface_provider.py b/argilla-server/tests/unit/security/authentication/oauth2/providers/test_huggingface_provider.py new file mode 100644 index 0000000000..88b526212e --- /dev/null +++ b/argilla-server/tests/unit/security/authentication/oauth2/providers/test_huggingface_provider.py @@ -0,0 +1,88 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytest_mock import MockerFixture + +from argilla_server.enums import UserRole +from argilla_server.integrations.huggingface.spaces import HuggingfaceSettings +from argilla_server.security.authentication.oauth2.providers import HuggingfaceClientProvider +from argilla_server.security.authentication.oauth2.providers import _huggingface + + +class TestHuggingfaceOauthProvider: + def test_parse_role_from_userinfo_for_space_author(self, mocker: "MockerFixture"): + mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="author")) + + userinfo = {"preferred_username": "author"} + role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) + assert role == UserRole.owner + + def test_parse_role_without_spaces_info(self, mocker: "MockerFixture"): + mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name=None)) + + userinfo = {"preferred_username": "author"} + role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) + assert role == UserRole.annotator + + def test_parse_role_with_different_author_name(self, mocker: "MockerFixture"): + mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="other")) + + userinfo = {"preferred_username": "author"} + role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) + assert role == UserRole.annotator + + def test_parse_role_with_missing_username(self): + userinfo = {} + role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) + assert role == UserRole.annotator + + def test_parse_role_with_admin_role_in_org(self, mocker: "MockerFixture"): + mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="org")) + + userinfo = { + "preferred_username": "author", + "orgs": [{"preferred_username": "org", "roleInOrg": "admin"}], + } + role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) + assert role == UserRole.owner + + def test_parse_role_with_non_admin_role_in_org(self, mocker: "MockerFixture"): + mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="org")) + + userinfo = { + "preferred_username": "author", + "orgs": [{"preferred_username": "org", "roleInOrg": "contributor"}], + } + role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) + assert role == UserRole.annotator + + def test_parse_role_for_other_org_author(self, mocker: "MockerFixture"): + mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="other_org")) + + userinfo = { + "preferred_username": "author", + "orgs": [{"preferred_username": "org", "roleInOrg": "contributor"}], + } + role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) + assert role == UserRole.annotator + + def test_parse_role_with_missing_org_role_info(self, mocker: "MockerFixture"): + mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="org")) + + userinfo = { + "preferred_username": "author", + "orgs": [{"preferred_username": "org"}], + } + role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) + assert role == UserRole.annotator diff --git a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py index ceeef30132..c0175f273b 100644 --- a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py +++ b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py @@ -13,12 +13,14 @@ # limitations under the License. import pytest + +from argilla_server.errors.future import NotFoundError from argilla_server.security.authentication.oauth2 import OAuth2Settings class TestOAuth2Settings: def test_configure_unsupported_provider(self): - with pytest.raises(ValueError): + with pytest.raises(NotFoundError): OAuth2Settings.from_dict({"providers": [{"name": "unsupported"}]}) def test_configure_github_provider(self): diff --git a/argilla-server/tests/unit/security/authentication/test_userinfo.py b/argilla-server/tests/unit/security/authentication/test_userinfo.py new file mode 100644 index 0000000000..918e99ba5e --- /dev/null +++ b/argilla-server/tests/unit/security/authentication/test_userinfo.py @@ -0,0 +1,53 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from argilla_server.enums import UserRole +from argilla_server.security.authentication import UserInfo +from argilla_server.security.authentication.claims import Claims + + +class TestUserInfo: + def test_get_user_name_without_claims(self): + userinfo = UserInfo() + with pytest.raises(KeyError): + userinfo.username # noqa + + def test_get_userinfo_first_name(self): + userinfo = UserInfo({"username": "user", "first_name": "User"}) + assert userinfo.first_name == "User" + + def test_get_default_userinfo_first_name(self): + userinfo = UserInfo({"username": "user"}) + assert userinfo.first_name == "user" + + def test_get_default_userinfo_role(self): + userinfo = UserInfo({"username": "user"}) + assert userinfo.role == UserRole.annotator + + def test_get_userinfo_role(self): + userinfo = UserInfo({"username": "user", "role": "owner"}) + assert userinfo.role == UserRole.owner + + def test_get_userinfo_with_claims(self): + userinfo = UserInfo({"username": "user"}).use_claims( + Claims( + first_name=lambda user: user["username"].upper(), + last_name=lambda user: "Peter", + ) + ) + + assert userinfo.first_name == "USER" + assert userinfo.last_name == "Peter" diff --git a/argilla-server/tests/unit/test_app.py b/argilla-server/tests/unit/test_app.py index aff6075f3d..7472f661bf 100644 --- a/argilla-server/tests/unit/test_app.py +++ b/argilla-server/tests/unit/test_app.py @@ -12,13 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import cast +from unittest import mock import pytest -from argilla_server._app import create_server_app +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server._app import create_server_app, configure_database, _create_oauth_allowed_workspaces +from argilla_server.models import Workspace +from argilla_server.security.authentication.oauth2 import OAuth2Settings +from argilla_server.security.authentication.oauth2.settings import AllowedWorkspace from argilla_server.settings import Settings, settings from starlette.routing import Mount from starlette.testclient import TestClient +from tests.factories import WorkspaceFactory + @pytest.fixture def test_settings(): @@ -27,6 +36,7 @@ def test_settings(): settings.base_url = "/" +@pytest.mark.asyncio class TestApp: def test_create_app_with_base_url(self, test_settings: Settings): base_url = "/base/url" @@ -56,3 +66,55 @@ def test_server_timing_header(self): response = client.get("/api/v1/version") assert response.headers["Server-Timing"] + + async def test_create_allowed_workspaces(self, db: AsyncSession): + with mock.patch( + "argilla_server.security.settings.Settings.oauth", + new_callable=lambda: OAuth2Settings.from_dict( + { + "enabled": True, + "allowed_workspaces": [{"name": "ws1"}, {"name": "ws2"}], + } + ), + ): + await _create_oauth_allowed_workspaces(db) + + workspaces = (await db.scalars(select(Workspace))).all() + assert len(workspaces) == 2 + assert set([ws.name for ws in workspaces]) == {"ws1", "ws2"} + + async def test_create_allowed_workspaces_with_oauth_disabled(self, db: AsyncSession): + with mock.patch( + "argilla_server.security.settings.Settings.oauth", + new_callable=lambda: OAuth2Settings.from_dict( + { + "enabled": False, + "allowed_workspaces": [{"name": "ws1"}, {"name": "ws2"}], + } + ), + ): + await _create_oauth_allowed_workspaces(db) + + workspaces = (await db.scalars(select(Workspace))).all() + assert len(workspaces) == 0 + + async def test_create_workspaces_with_empty_workspaces_list(self, db: AsyncSession): + with mock.patch( + "argilla_server.security.settings.Settings.oauth", new_callable=lambda: OAuth2Settings(enabled=True) + ): + await _create_oauth_allowed_workspaces(db) + + workspaces = (await db.scalars(select(Workspace))).all() + assert len(workspaces) == 0 + + async def test_create_workspaces_with_existing_workspaces(self, db: AsyncSession): + ws = await WorkspaceFactory.create(name="test") + + with mock.patch( + "argilla_server.security.settings.Settings.oauth", + new_callable=lambda: OAuth2Settings(enabled=True, allowed_workspaces=[AllowedWorkspace(name=ws.name)]), + ): + await _create_oauth_allowed_workspaces(db) + + workspaces = (await db.scalars(select(Workspace))).all() + assert len(workspaces) == 1 From e61abb03ef883b9a706821227586f4ac4a986461 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dami=C3=A1n=20Pumar?= Date: Thu, 25 Jul 2024 10:07:36 +0200 Subject: [PATCH 02/13] =?UTF-8?q?=E2=9C=A8=20Catch=20unhandled=20exception?= =?UTF-8?q?s=20(#5306)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../mode/useBulkAnnotationViewModel.ts | 13 ++--- .../mode/useFocusAnnotationViewModel.ts | 48 +++++++++---------- 2 files changed, 27 insertions(+), 34 deletions(-) diff --git a/argilla-frontend/components/features/annotation/container/mode/useBulkAnnotationViewModel.ts b/argilla-frontend/components/features/annotation/container/mode/useBulkAnnotationViewModel.ts index 561bfb1957..6ace8e3e3b 100644 --- a/argilla-frontend/components/features/annotation/container/mode/useBulkAnnotationViewModel.ts +++ b/argilla-frontend/components/features/annotation/container/mode/useBulkAnnotationViewModel.ts @@ -7,7 +7,6 @@ import { AvailableStatus, BulkAnnotationUseCase, } from "~/v1/domain/usecases/bulk-annotation-use-case"; -import { useDebounce } from "~/v1/infrastructure/services/useDebounce"; import { useNotifications } from "~/v1/infrastructure/services/useNotifications"; import { useTranslate } from "~/v1/infrastructure/services/useTranslate"; @@ -17,7 +16,6 @@ export const useBulkAnnotationViewModel = ({ records: Records; }) => { const notification = useNotifications(); - const debounceForSubmit = useDebounce(300); const affectAllRecords = ref(false); const progress = ref(0); @@ -44,11 +42,12 @@ export const useBulkAnnotationViewModel = ({ recordReference: Record, selectedRecords: Record[] ) => { + let allSuccessful = false; try { const totalRecords = records.total; const isAffectingAllRecords = affectAllRecords.value; - const allSuccessful = await bulkAnnotationUseCase.execute( + allSuccessful = await bulkAnnotationUseCase.execute( status, criteria, recordReference, @@ -73,19 +72,13 @@ export const useBulkAnnotationViewModel = ({ type: "info", }); } - - progress.value = 0; - - await debounceForSubmit.wait(); - - return allSuccessful; } catch { } finally { affectAllRecords.value = false; progress.value = 0; } - return false; + return allSuccessful; }; const discard = async ( diff --git a/argilla-frontend/components/features/annotation/container/mode/useFocusAnnotationViewModel.ts b/argilla-frontend/components/features/annotation/container/mode/useFocusAnnotationViewModel.ts index 722562ac20..ba5fc77b19 100644 --- a/argilla-frontend/components/features/annotation/container/mode/useFocusAnnotationViewModel.ts +++ b/argilla-frontend/components/features/annotation/container/mode/useFocusAnnotationViewModel.ts @@ -4,11 +4,8 @@ import { Record } from "~/v1/domain/entities/record/Record"; import { DiscardRecordUseCase } from "~/v1/domain/usecases/discard-record-use-case"; import { SubmitRecordUseCase } from "~/v1/domain/usecases/submit-record-use-case"; import { SaveDraftUseCase } from "~/v1/domain/usecases/save-draft-use-case"; -import { useDebounce } from "~/v1/infrastructure/services/useDebounce"; export const useFocusAnnotationViewModel = () => { - const debounceForSubmit = useDebounce(300); - const isDraftSaving = ref(false); const isDiscarding = ref(false); const isSubmitting = ref(false); @@ -17,33 +14,36 @@ export const useFocusAnnotationViewModel = () => { const saveDraftUseCase = useResolve(SaveDraftUseCase); const discard = async (record: Record) => { - isDiscarding.value = true; - - await discardUseCase.execute(record); - - await debounceForSubmit.wait(); - - isDiscarding.value = false; + try { + isDiscarding.value = true; + + await discardUseCase.execute(record); + } catch { + } finally { + isDiscarding.value = false; + } }; const submit = async (record: Record) => { - isSubmitting.value = true; - - await submitUseCase.execute(record); - - await debounceForSubmit.wait(); - - isSubmitting.value = false; + try { + isSubmitting.value = true; + + await submitUseCase.execute(record); + } catch { + } finally { + isSubmitting.value = false; + } }; const saveAsDraft = async (record: Record) => { - isDraftSaving.value = true; - - await saveDraftUseCase.execute(record); - - await debounceForSubmit.wait(); - - isDraftSaving.value = false; + try { + isDraftSaving.value = true; + + await saveDraftUseCase.execute(record); + } catch { + } finally { + isDraftSaving.value = false; + } }; return { From c8716753ed097fcef80082765e8b03f47bf398a1 Mon Sep 17 00:00:00 2001 From: Natalia Elvira <126158523+nataliaElv@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:55:01 +0200 Subject: [PATCH 03/13] Docs: 2.0 minor changes (#5295) # Description Closes #5303 Closes #5261 **Type of change** - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: davidberenstein1957 --- CONTRIBUTING.md | 2 +- README.md | 26 +++++++++--------- argilla-frontend/README.md | 3 +-- argilla-server/README.md | 2 +- argilla-v1/README.md | 23 ++++++++-------- argilla-v1/environment_dev.yml | 2 +- argilla-v1/pyproject.toml | 2 +- argilla/README.md | 24 ++++++++--------- argilla/docs/community/contributor.md | 4 +-- argilla/docs/community/index.md | 2 +- argilla/docs/getting_started/faq.md | 6 ++--- argilla/docs/how_to_guides/distribution.md | 2 +- argilla/docs/how_to_guides/index.md | 27 ++++++++++--------- argilla/docs/index.md | 27 ++++++++++--------- argilla/mkdocs.yml | 17 ++++++------ argilla/pdm.lock | 18 +++++++++++-- argilla/pyproject.toml | 1 + argilla/src/argilla/client.py | 10 +++---- .../tests/integration/test_export_dataset.py | 1 + docs/_source/_common/next_steps.md | 2 +- docs/_source/community/contributing.md | 4 +-- 21 files changed, 109 insertions(+), 96 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c5693649ac..f463e297aa 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,7 +13,7 @@ However you choose to contribute, please be mindful and respect our ## Need some help? -We understand that getting started might be a bit difficult, therefore, you can join our discord channel [#argilla-general](https://discord.gg/hugging-face-879548962464493619) by selecting *Argilla* in *Channels and Roles* after joining the Discord. +We understand that getting started might be a bit difficult, therefore, you can join our discord channel [#argilla-distilabel-general](https://discord.gg/hugging-face-879548962464493619) by selecting *Argilla* in *Channels and Roles* after joining the Discord. ## Want to work on your own? diff --git a/README.md b/README.md index 4b53b4f67c..38a1caad45 100644 --- a/README.md +++ b/README.md @@ -32,13 +32,13 @@

-Argilla is a **collaboration platform for AI engineers and domain experts** that require **high-quality outputs, full data ownership, and overall efficiency**. +Argilla is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. If you just want to get started, we recommend our [UI demo](https://demo.argilla.io/sign-in?auth=ZGVtbzoxMjM0NTY3OA%3D%3D) or our [free Hugging Face Spaces deployment integration](https://huggingface.co/new-space?template=argilla/argilla-template-space). Curious, and want to know more? Read our [documentation](https://argilla-io.github.io/argilla/latest/). ## Why use Argilla? -Whether you are working on monitoring and improving complex **generative tasks** involving LLM pipelines with RAG, or you are working on a **predictive task** for things like AB-testing of span- and text-classification models. Our versatile platform helps you ensure **your data work pays off**. +Argilla can be used for collecting human feedback for a wide variety of AI projects like traditional NLP (text classification, NER, etc.), LLMs (RAG, preference tuning, etc.), or multimodal models (text to image, etc.). Argilla's programmatic approach lets you build workflows for continuous evaluation and model improvement. The goal of Argilla is to ensure your data work pays off by quickly iterating on the right data and models. ### Improve your AI output quality through data quality @@ -46,11 +46,11 @@ Compute is expensive and output quality is important. We help you focus on data, ### Take control of your data and models -Most AI platforms are black boxes. Argilla is different. We believe that you should be the owner of both your data and your models. That's why we provide you with all the tools your team needs to **manage your data and models in a way that suits you best**. +Most AI tools are black boxes. Argilla is different. We believe that you should be the owner of both your data and your models. That's why we provide you with all the tools your team needs to **manage your data and models in a way that suits you best**. ### Improve efficiency by quickly iterating on the right data and models -Gathering data is a time-consuming process. Argilla helps by providing a platform that allows you to **interact with your data in a more engaging way**. This means you can quickly and easily label your data with filters, AI feedback suggestions and semantic search. So you can focus on training your models and monitoring their performance. +Gathering data is a time-consuming process. Argilla helps by providing a tool that allows you to **interact with your data in a more engaging way**. This means you can quickly and easily label your data with filters, AI feedback suggestions and semantic search. So you can focus on training your models and monitoring their performance. ## 🏘️ Community @@ -58,7 +58,7 @@ We are an open-source community-driven project and we love to hear from you. Her - [Community Meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB): listen in or present during one of our bi-weekly events. -- [Discord](http://hf.co/join/discord): get direct support from the community in #argilla-general and #argilla-help. +- [Discord](http://hf.co/join/discord): get direct support from the community in #argilla-distilabel-general and #argilla-distilabel-help. - [Roadmap](https://github.com/orgs/argilla-io/projects/10/views/1): plans change but we love to discuss those with our community so feel encouraged to participate. @@ -66,18 +66,18 @@ We are an open-source community-driven project and we love to hear from you. Her ### Open-source datasets and models -Argilla is a tool that can be used to achieve and keep **high-quality data standards** with a **focus on NLP and LLMs**. Our community uses Argilla to create amazing open-source [datasets](https://huggingface.co/datasets?other=argilla) and [models](https://huggingface.co/models?other=distilabel), and **we love contributions to open-source** ourselves too. +The community uses Argilla to create amazing open-source [datasets](https://huggingface.co/datasets?library=library:argilla&sort=trending) and [models](https://huggingface.co/models?other=distilabel). -- Our [cleaned UltraFeedback dataset](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) and the [Notus](https://huggingface.co/argilla/notus-7b-v1) and [Notux](https://huggingface.co/argilla/notux-8x7b-v1) models, where we improved benchmark and empirical human judgment for the Mistral and Mixtral models with cleaner data using **human feedback**. -- Our [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) and the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B), show how we improve model performance by filtering out 50% of the original dataset through **human and AI feedback**. +- [Cleaned UltraFeedback dataset](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) used to fine-tune the [Notus](https://huggingface.co/argilla/notus-7b-v1) and [Notux](https://huggingface.co/argilla/notux-8x7b-v1) models. The original UltraFeedback dataset was curated using Argilla UI filters to find and report a bug in the original data generation code. Based on this data curation process, Argilla built this new version of the UltraFeedback dataset and fine-tuned Notus, outperforming Zephyr on several benchmarks. +- [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) used to fine-tune the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B). This dataset was built by combining human curation in Argilla with AI feedback from distilabel, leading to an improved version of the Intel Orca dataset and outperforming models fine-tuned on the original dataset. -### Internal Use cases +### Examples Use cases -AI teams from companies like [the Red Cross](https://510.global/), [Loris.ai](https://loris.ai/) and [Prolific](https://www.prolific.com/) use Argilla to **improve the quality and efficiency of AI** projects. They shared their experiences in our [AI community meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB). +AI teams from companies like [the Red Cross](https://510.global/), [Loris.ai](https://loris.ai/) and [Prolific](https://www.prolific.com/) use Argilla to improve the quality and efficiency of AI projects. They shared their experiences in our [AI community meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB). -- AI for good: [the Red Cross presentation](https://youtu.be/ZsCqrAhzkFU?feature=shared) showcases **how their experts and AI team collaborate** by classifying and redirecting requests from refugees of the Ukrainian crisis to streamline the support processes of the Red Cross. -- Customer support: during [the Loris meetup](https://youtu.be/jWrtgf2w4VU?feature=shared) they showed how their AI team uses unsupervised and few-shot contrastive learning to help them **quickly validate and gain labelled samples for a huge amount of multi-label classifiers**. -- Research studies: [the showcase from Prolific](https://youtu.be/ePDlhIxnuAs?feature=shared) announced their integration with our platform. They use it to actively **distribute data collection projects** among their annotating workforce. This allows them to quickly and **efficiently collect high-quality data** for their research studies. +- AI for good: [the Red Cross presentation](https://youtu.be/ZsCqrAhzkFU?feature=shared) showcases how the Red Cross domain experts and AI team collaborated by classifying and redirecting requests from refugees of the Ukrainian crisis to streamline the support processes of the Red Cross. +- Customer support: during [the Loris meetup](https://youtu.be/jWrtgf2w4VU?feature=shared) they showed how their AI team uses unsupervised and few-shot contrastive learning to help them quickly validate and gain labelled samples for a huge amount of multi-label classifiers. +- Research studies: [the showcase from Prolific](https://youtu.be/ePDlhIxnuAs?feature=shared) announced their integration with our platform. They use it to actively distribute data collection projects among their annotating workforce. This allows Prolific to quickly and efficiently collect high-quality data for research studies. ## 👨‍💻 Getting started diff --git a/argilla-frontend/README.md b/argilla-frontend/README.md index 1f4402973f..1e99d21cfe 100644 --- a/argilla-frontend/README.md +++ b/argilla-frontend/README.md @@ -31,8 +31,7 @@

-Argilla is a **collaboration platform for AI engineers and domain experts** that require **high-quality outputs, full data ownership, and overall efficiency**. - +Argilla is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. If you just want to get started, we recommend our [UI demo](https://demo.argilla.io/sign-in?auth=ZGVtbzoxMjM0NTY3OA%3D%3D) or our [free Hugging Face Spaces deployment integration](https://huggingface.co/new-space?template=argilla/argilla-template-space). Curious, and want to know more? Read our [documentation](https://argilla-io.github.io/argilla/latest/). This repository only contains developer info about the front end. If you want to get started, we recommend taking a diff --git a/argilla-server/README.md b/argilla-server/README.md index 3406281946..bff0cd111e 100644 --- a/argilla-server/README.md +++ b/argilla-server/README.md @@ -31,7 +31,7 @@

-Argilla is a **collaboration platform for AI engineers and domain experts** that require **high-quality outputs, full data ownership, and overall efficiency**. +Argilla is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. This repository only contains developer info about the backend server. If you want to get started, we recommend taking a look at our [main repository](https://github.com/argilla-io/argilla) or our [documentation](https://argilla-io.github.io/argilla/latest/). diff --git a/argilla-v1/README.md b/argilla-v1/README.md index bd0c21217f..4266984fd3 100644 --- a/argilla-v1/README.md +++ b/argilla-v1/README.md @@ -32,8 +32,7 @@

-Argilla is a **collaboration platform for AI engineers and domain experts** that require **high-quality outputs, full data ownership, and overall efficiency**. - +Argilla is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. > [!NOTE] > This README represents the 1.29 SDK version. We have stopped development for the 1.x SDK version, while still committing to bug fixes. If you are looking for the README of the 2.x SDK version take a look [here](../README.md). @@ -49,7 +48,7 @@ Compute is expensive and output quality is important. We help you focus on data, ### Take control of your data and models -Most AI platforms are black boxes. Argilla is different. We believe that you should be the owner of both your data and your models. That's why we provide you with all the tools your team needs to **manage your data and models in a way that suits you best**. +Most AI tools are black boxes. Argilla is different. We believe that you should be the owner of both your data and your models. That's why we provide you with all the tools your team needs to **manage your data and models in a way that suits you best**. ### Improve efficiency by quickly iterating on the right data and models @@ -61,7 +60,7 @@ We are an open-source community-driven project and we love to hear from you. Her - [Community Meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB): listen in or present during one of our bi-weekly events. -- [Discord](http://hf.co/join/discord"): get direct support from the community in #argilla-general and #argilla-help. +- [Discord](http://hf.co/join/discord"): get direct support from the community in #argilla-distilabel-general and #argilla-distilabel-help. - [Roadmap](https://github.com/orgs/argilla-io/projects/10/views/1): plans change but we love to discuss those with our community so feel encouraged to participate. @@ -69,18 +68,18 @@ We are an open-source community-driven project and we love to hear from you. Her ### Open-source datasets and models -Argilla is a tool that can be used to achieve and keep **high-quality data standards** with a **focus on NLP and LLMs**. Our community uses Argilla to create amazing open-source [datasets](https://huggingface.co/datasets?other=argilla) and [models](https://huggingface.co/models?other=distilabel), and **we love contributions to open-source** ourselves too. +The community uses Argilla to create amazing open-source [datasets](https://huggingface.co/datasets?library=library:argilla&sort=trending) and [models](https://huggingface.co/models?other=distilabel). -- Our [cleaned UltraFeedback dataset](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) and the [Notus](https://huggingface.co/argilla/notus-7b-v1) and [Notux](https://huggingface.co/argilla/notux-8x7b-v1) models, where we improved benchmark and empirical human judgment for the Mistral and Mixtral models with cleaner data using **human feedback**. -- Our [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) and the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B), show how we improve model performance by filtering out 50% of the original dataset through **human and AI feedback**. +- [Cleaned UltraFeedback dataset](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) used to fine-tune the [Notus](https://huggingface.co/argilla/notus-7b-v1) and [Notux](https://huggingface.co/argilla/notux-8x7b-v1) models. The original UltraFeedback dataset was curated using Argilla UI filters to find and report a bug in the original data generation code. Based on this data curation process, Argilla built this new version of the UltraFeedback dataset and fine-tuned Notus, outperforming Zephyr on several benchmarks. +- [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) used to fine-tune the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B). This dataset was built by combining human curation in Argilla with AI feedback from distilabel, leading to an improved version of the Intel Orca dataset and outperforming models fine-tuned on the original dataset. -### Internal Use cases +### Examples Use cases -AI teams from companies like [the Red Cross](https://510.global/), [Loris.ai](https://loris.ai/) and [Prolific](https://www.prolific.com/) use Argilla to **improve the quality and efficiency of AI** projects. They shared their experiences in our [AI community meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB). +AI teams from companies like [the Red Cross](https://510.global/), [Loris.ai](https://loris.ai/) and [Prolific](https://www.prolific.com/) use Argilla to improve the quality and efficiency of AI projects. They shared their experiences in our [AI community meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB). -- AI for good: [the Red Cross presentation](https://youtu.be/ZsCqrAhzkFU?feature=shared) showcases **how their experts and AI team collaborate** by classifying and redirecting requests from refugees of the Ukrainian crisis to streamline the support processes of the Red Cross. -- Customer support: during [the Loris meetup](https://youtu.be/jWrtgf2w4VU?feature=shared) they showed how their AI team uses unsupervised and few-shot contrastive learning to help them **quickly validate and gain labelled samples for a huge amount of multi-label classifiers**. -- Research studies: [the showcase from Prolific](https://youtu.be/ePDlhIxnuAs?feature=shared) announced their integration with our platform. They use it to actively **distribute data collection projects** among their annotating workforce. This allows them to quickly and **efficiently collect high-quality data** for their research studies. +- AI for good: [the Red Cross presentation](https://youtu.be/ZsCqrAhzkFU?feature=shared) showcases how the Red Cross domain experts and AI team collaborated by classifying and redirecting requests from refugees of the Ukrainian crisis to streamline the support processes of the Red Cross. +- Customer support: during [the Loris meetup](https://youtu.be/jWrtgf2w4VU?feature=shared) they showed how their AI team uses unsupervised and few-shot contrastive learning to help them quickly validate and gain labelled samples for a huge amount of multi-label classifiers. +- Research studies: [the showcase from Prolific](https://youtu.be/ePDlhIxnuAs?feature=shared) announced their integration with our platform. They use it to actively distribute data collection projects among their annotating workforce. This allows Prolific to quickly and efficiently collect high-quality data for research studies. ## 👨‍💻 Getting started diff --git a/argilla-v1/environment_dev.yml b/argilla-v1/environment_dev.yml index fa74c0387a..61fa5e1518 100644 --- a/argilla-v1/environment_dev.yml +++ b/argilla-v1/environment_dev.yml @@ -40,7 +40,7 @@ dependencies: # extra test dependencies - cleanlab~=2.0.0 # With this version, tests are failing - datasets>1.17.0,!= 2.3.2 # TODO: push_to_hub fails up to 2.3.2, check patches when they come out eventually - - huggingface_hub>=0.5.0 + - huggingface_hub>=0.5.0, <0.24 - flair>=0.12.2 - scipy~=1.12.0 # To avoid error importing scipy.linalg (https://github.com/argilla-io/argilla/actions/runs/8689560057/job/23828206007#step:12:3851) - faiss-cpu diff --git a/argilla-v1/pyproject.toml b/argilla-v1/pyproject.toml index a3089362ce..6a873c2859 100644 --- a/argilla-v1/pyproject.toml +++ b/argilla-v1/pyproject.toml @@ -63,7 +63,7 @@ integrations = [ # TODO: `push_to_hub` fails up to 2.3.2, check patches when they come out eventually "datasets > 1.17.0,!= 2.3.2", # TODO: some backward comp. problems introduced in 0.5.0 - "huggingface_hub >= 0.5.0", + "huggingface_hub >= 0.5.0,< 0.24", # Version 0.12 fixes a known installation issue related to `sentencepiece` and `tokenizers`, more at https://github.com/flairNLP/flair/issues/3129 # Version 0.12.2 relaxes the `huggingface_hub` dependency "flair >= 0.12.2", diff --git a/argilla/README.md b/argilla/README.md index 4b53b4f67c..c1453d312f 100644 --- a/argilla/README.md +++ b/argilla/README.md @@ -32,8 +32,7 @@

-Argilla is a **collaboration platform for AI engineers and domain experts** that require **high-quality outputs, full data ownership, and overall efficiency**. - +Argilla is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. If you just want to get started, we recommend our [UI demo](https://demo.argilla.io/sign-in?auth=ZGVtbzoxMjM0NTY3OA%3D%3D) or our [free Hugging Face Spaces deployment integration](https://huggingface.co/new-space?template=argilla/argilla-template-space). Curious, and want to know more? Read our [documentation](https://argilla-io.github.io/argilla/latest/). ## Why use Argilla? @@ -46,7 +45,7 @@ Compute is expensive and output quality is important. We help you focus on data, ### Take control of your data and models -Most AI platforms are black boxes. Argilla is different. We believe that you should be the owner of both your data and your models. That's why we provide you with all the tools your team needs to **manage your data and models in a way that suits you best**. +Most AI tools are black boxes. Argilla is different. We believe that you should be the owner of both your data and your models. That's why we provide you with all the tools your team needs to **manage your data and models in a way that suits you best**. ### Improve efficiency by quickly iterating on the right data and models @@ -58,7 +57,7 @@ We are an open-source community-driven project and we love to hear from you. Her - [Community Meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB): listen in or present during one of our bi-weekly events. -- [Discord](http://hf.co/join/discord): get direct support from the community in #argilla-general and #argilla-help. +- [Discord](http://hf.co/join/discord): get direct support from the community in #argilla-distilabel-general and #argilla-distilabel-help. - [Roadmap](https://github.com/orgs/argilla-io/projects/10/views/1): plans change but we love to discuss those with our community so feel encouraged to participate. @@ -66,18 +65,18 @@ We are an open-source community-driven project and we love to hear from you. Her ### Open-source datasets and models -Argilla is a tool that can be used to achieve and keep **high-quality data standards** with a **focus on NLP and LLMs**. Our community uses Argilla to create amazing open-source [datasets](https://huggingface.co/datasets?other=argilla) and [models](https://huggingface.co/models?other=distilabel), and **we love contributions to open-source** ourselves too. +The community uses Argilla to create amazing open-source [datasets](https://huggingface.co/datasets?library=library:argilla&sort=trending) and [models](https://huggingface.co/models?other=distilabel). -- Our [cleaned UltraFeedback dataset](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) and the [Notus](https://huggingface.co/argilla/notus-7b-v1) and [Notux](https://huggingface.co/argilla/notux-8x7b-v1) models, where we improved benchmark and empirical human judgment for the Mistral and Mixtral models with cleaner data using **human feedback**. -- Our [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) and the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B), show how we improve model performance by filtering out 50% of the original dataset through **human and AI feedback**. +- [Cleaned UltraFeedback dataset](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) used to fine-tune the [Notus](https://huggingface.co/argilla/notus-7b-v1) and [Notux](https://huggingface.co/argilla/notux-8x7b-v1) models. The original UltraFeedback dataset was curated using Argilla UI filters to find and report a bug in the original data generation code. Based on this data curation process, Argilla built this new version of the UltraFeedback dataset and fine-tuned Notus, outperforming Zephyr on several benchmarks. +- [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) used to fine-tune the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B). This dataset was built by combining human curation in Argilla with AI feedback from distilabel, leading to an improved version of the Intel Orca dataset and outperforming models fine-tuned on the original dataset. -### Internal Use cases +### Examples Use cases -AI teams from companies like [the Red Cross](https://510.global/), [Loris.ai](https://loris.ai/) and [Prolific](https://www.prolific.com/) use Argilla to **improve the quality and efficiency of AI** projects. They shared their experiences in our [AI community meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB). +AI teams from companies like [the Red Cross](https://510.global/), [Loris.ai](https://loris.ai/) and [Prolific](https://www.prolific.com/) use Argilla to improve the quality and efficiency of AI projects. They shared their experiences in our [AI community meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB). -- AI for good: [the Red Cross presentation](https://youtu.be/ZsCqrAhzkFU?feature=shared) showcases **how their experts and AI team collaborate** by classifying and redirecting requests from refugees of the Ukrainian crisis to streamline the support processes of the Red Cross. -- Customer support: during [the Loris meetup](https://youtu.be/jWrtgf2w4VU?feature=shared) they showed how their AI team uses unsupervised and few-shot contrastive learning to help them **quickly validate and gain labelled samples for a huge amount of multi-label classifiers**. -- Research studies: [the showcase from Prolific](https://youtu.be/ePDlhIxnuAs?feature=shared) announced their integration with our platform. They use it to actively **distribute data collection projects** among their annotating workforce. This allows them to quickly and **efficiently collect high-quality data** for their research studies. +- AI for good: [the Red Cross presentation](https://youtu.be/ZsCqrAhzkFU?feature=shared) showcases how the Red Cross domain experts and AI team collaborated by classifying and redirecting requests from refugees of the Ukrainian crisis to streamline the support processes of the Red Cross. +- Customer support: during [the Loris meetup](https://youtu.be/jWrtgf2w4VU?feature=shared) they showed how their AI team uses unsupervised and few-shot contrastive learning to help them quickly validate and gain labelled samples for a huge amount of multi-label classifiers. +- Research studies: [the showcase from Prolific](https://youtu.be/ePDlhIxnuAs?feature=shared) announced their integration with our platform. They use it to actively distribute data collection projects among their annotating workforce. This allows Prolific to quickly and efficiently collect high-quality data for research studies. ## 👨‍💻 Getting started @@ -154,4 +153,3 @@ To help our community with the creation of contributions, we have created our [c - diff --git a/argilla/docs/community/contributor.md b/argilla/docs/community/contributor.md index b45c2be013..fe800a1fd8 100644 --- a/argilla/docs/community/contributor.md +++ b/argilla/docs/community/contributor.md @@ -21,8 +21,8 @@ Discord is a handy tool for more casual conversations and to answer day-to-day q When part of the Hugging Face Discord, you can select "Channels & roles" and select "Argilla" along with any of the other groups that are interesting to you. "Argilla" will cover anything about argilla and distilabel. You can join the following channels: -* **#argilla-general**: 📣 Stay up-to-date and general discussions. -* **#argilla-help**: 🙋‍♀️ Need assistance? We're always here to help. Select the appropriate label (argilla or distilabel) for your issue and post it. +* **#argilla-distilabel-general**: 📣 Stay up-to-date and general discussions. +* **#argilla-distilabel-help**: 🙋‍♀️ Need assistance? We're always here to help. Select the appropriate label (argilla or distilabel) for your issue and post it. So now there is only one thing left to do: introduce yourself and talk to the community. You'll always be welcome! 🤗👋 diff --git a/argilla/docs/community/index.md b/argilla/docs/community/index.md index fe9fa26642..48df2adb87 100644 --- a/argilla/docs/community/index.md +++ b/argilla/docs/community/index.md @@ -13,7 +13,7 @@ We are an open-source community-driven project not only focused on building a gr --- - In our Discord channels (#argilla-general and #argilla-help), you can get direct support from the community. + In our Discord channels (#argilla-distilabel-general and #argilla-distilabel-help), you can get direct support from the community. [:octicons-arrow-right-24: Discord ↗](http://hf.co/join/discord) diff --git a/argilla/docs/getting_started/faq.md b/argilla/docs/getting_started/faq.md index 5c0722b497..e4076a04c3 100644 --- a/argilla/docs/getting_started/faq.md +++ b/argilla/docs/getting_started/faq.md @@ -7,7 +7,7 @@ hide: toc ??? Question "What is Argilla?" - Argilla is a collaboration platform for AI engineers and domain experts that require high-quality outputs, full data ownership, and overall efficiency. It is designed to help you achieve and keep high-quality data standards, store your training data, store the results of your models, evaluate their performance, and improve the data through human and AI feedback. + Argilla is a collaboration tool for AI engineers and domain experts that require high-quality outputs, full data ownership, and overall efficiency. It is designed to help you achieve and keep high-quality data standards, store your training data, store the results of your models, evaluate their performance, and improve the data through human and AI feedback. ??? Question "Does Argilla cost money?" @@ -19,7 +19,7 @@ hide: toc ??? Question "Does Argilla train models?" - No. Argilla is a collaboration platform to achieve and keep high-quality data standards. You can use Argilla to store your training data, store the results of your models, evaluate their performance and improve the data. For training models, you can use any machine learning framework or library that you prefer even though we recommend starting with [Hugging Face Transformers](https://github.com/huggingface/transformers). + No. Argilla is a collaboration tool to achieve and keep high-quality data standards. You can use Argilla to store your training data, store the results of your models, evaluate their performance and improve the data. For training models, you can use any machine learning framework or library that you prefer even though we recommend starting with [Hugging Face Transformers](https://github.com/huggingface/transformers). ??? Question "Does Argilla provide annotation workforces?" @@ -31,7 +31,7 @@ hide: toc Furthermore, Argilla places particular emphasis on smooth integration with other tools in the community, particularly within the realms of MLOps and NLP. So, its compatibility with popular frameworks like spaCy and Hugging Face makes it exceptionally user-friendly and accessible. - Finally, platforms like Snorkel, Prodigy or Scale, while more comprehensive, often require a significant commitment. Argilla, on the other hand, works more as a component within the MLOps ecosystem, allowing users to begin with specific use cases and then scale up as needed. This flexibility is particularly beneficial for users and customers who prefer to start small and expand their applications over time, as opposed to committing to an all-encompassing platform from the outset. + Finally, platforms like Snorkel, Prodigy or Scale, while more comprehensive, often require a significant commitment. Argilla, on the other hand, works more as a tool within the MLOps ecosystem, allowing users to begin with specific use cases and then scale up as needed. This flexibility is particularly beneficial for users and customers who prefer to start small and expand their applications over time, as opposed to committing to an all-encompassing tool from the outset. ??? Question "What is the difference between Argilla 2.0 and the legacy datasets in 1.0?" diff --git a/argilla/docs/how_to_guides/distribution.md b/argilla/docs/how_to_guides/distribution.md index 94332df052..fedb2e2d16 100644 --- a/argilla/docs/how_to_guides/distribution.md +++ b/argilla/docs/how_to_guides/distribution.md @@ -6,7 +6,7 @@ description: In this section, we will provide a step-by-step guide to show how t This guide explains how you can use Argilla’s **automatic task distribution** to efficiently divide the task of annotating a dataset among multiple team members. -Owners and admins can define the minimum number of submitted responses expected for each record depending on whether the dataset should have annotation overlap and how much. Argilla will use this setting to handle automatically the records that will be shown in the pending queues of all users with access to the dataset. +Owners and admins can define the minimum number of submitted responses expected for each record. Argilla will use this setting to handle automatically the records that will be shown in the pending queues of all users with access to the dataset. When a record has met the minimum number of submissions, the status of the record will change to `completed` and the record will be removed from the `Pending` queue of all team members, so they can focus on providing responses where they are most needed. The dataset’s annotation task will be fully completed once all records have the `completed` status. diff --git a/argilla/docs/how_to_guides/index.md b/argilla/docs/how_to_guides/index.md index 25cd4245f6..0ad4a839bc 100644 --- a/argilla/docs/how_to_guides/index.md +++ b/argilla/docs/how_to_guides/index.md @@ -27,7 +27,7 @@ These guides provide step-by-step instructions for common scenarios, including d [:octicons-arrow-right-24: How-to guide](workspace.md) -- __Manage and create datasets__ +- __Create, update, and delete datasets__ --- @@ -43,37 +43,38 @@ These guides provide step-by-step instructions for common scenarios, including d [:octicons-arrow-right-24: How-to guide](record.md) -- __Query and filter a dataset__ +- __Distribute the annotation__ --- - Learn how to query and filter a `Dataset`. + Learn how to use Argilla's automatic task distribution to annotate as a team efficiently. - [:octicons-arrow-right-24: How-to guide](query.md) + [:octicons-arrow-right-24: How-to guide](distribution.md) -- __Importing and exporting datasets and records__ +- __Annotate a dataset__ --- - Learn how to export your dataset or its records to Python, your local disk, or the Hugging face Hub. + Learn how to use the Argilla UI to navigate datasets and submit responses. - [:octicons-arrow-right-24: How-to guide](import_export.md) + [:octicons-arrow-right-24: How-to guide](annotate.md) -- __Annotate a dataset__ +- __Query and filter a dataset__ --- - Learn how to use the Argilla UI to navigate datasets and submit responses. + Learn how to query and filter a `Dataset`. - [:octicons-arrow-right-24: How-to guide](annotate.md) + [:octicons-arrow-right-24: How-to guide](query.md) -- __Distribute the annotation__ +- __Import and export datasets and records__ --- - Learn how to use Argilla's automatic task distribution to annotate as a team efficiently. + Learn how to export your dataset or its records to Python, your local disk, or the Hugging face Hub. + + [:octicons-arrow-right-24: How-to guide](import_export.md) - [:octicons-arrow-right-24: How-to guide](distribution.md) diff --git a/argilla/docs/index.md b/argilla/docs/index.md index 097f4cd062..c92b2abaa0 100644 --- a/argilla/docs/index.md +++ b/argilla/docs/index.md @@ -1,11 +1,11 @@ --- -description: Argilla is a **collaboration platform for AI engineers and domain experts** that require **high-quality outputs, full data ownership, and overall efficiency**. +description: Argilla is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. hide: navigation --- # Welcome to Argilla 2.x -Argilla is a **collaboration platform for AI engineers and domain experts** that require **high-quality outputs, full data ownership, and overall efficiency**. +Argilla is a collaboration tool for AI engineers and domain experts who need to build high-quality datasets for their projects. !!! INFO "Looking for Argilla 1.x?" Looking for documentation for Argilla 1.x? Visit [the latest release](https://docs.argilla.io/en/latest/). @@ -27,7 +27,7 @@ Argilla is a **collaboration platform for AI engineers and domain experts** that --- - Get familiar with basic and complex workflows for Argilla. Learn how to managing `Users`, `Workspaces`, `Datasets`, and `Records` to set up your data annotation projects. + Get familiar with basic and advanced workflows for Argilla. Learn how to manage `Users`, `Workspaces`, `Datasets`, and `Records` to set up your data annotation projects. [:octicons-arrow-right-24: Learn more](how_to_guides/index.md) @@ -37,7 +37,8 @@ Argilla is a **collaboration platform for AI engineers and domain experts** that ## Why use Argilla? -Whether you are working on monitoring and improving complex **generative tasks** involving LLM pipelines with RAG, or you are working on a **predictive task** for things like AB-testing of span- and text-classification models. Our versatile platform helps you ensure **your data work pays off**. +Argilla can be used for collecting human feedback for a wide variety of AI projects like traditional NLP (text classification, NER, etc.), LLMs (RAG, preference tuning, etc.), or multimodal models (text to image, etc.). Argilla's programmatic approach lets you build workflows for continuous evaluation and model improvement. The goal of Argilla is to ensure your data work pays off by quickly iterating on the right data and models. +

Improve your AI output quality through data quality

@@ -45,26 +46,26 @@ Compute is expensive and output quality is important. We help you focus on data,

Take control of your data and models

-Most AI platforms are black boxes. Argilla is different. We believe that you should be the owner of both your data and your models. That's why we provide you with all the tools your team needs to **manage your data and models in a way that suits you best**. +Most AI tools are black boxes. Argilla is different. We believe that you should be the owner of both your data and your models. That's why we provide you with all the tools your team needs to **manage your data and models in a way that suits you best**.

Improve efficiency by quickly iterating on the right data and models

-Gathering data is a time-consuming process. Argilla helps by providing a platform that allows you to **interact with your data in a more engaging way**. This means you can quickly and easily label your data with filters, AI feedback suggestions and semantic search. So you can focus on training your models and monitoring their performance. +Gathering data is a time-consuming process. Argilla helps by providing a tool that allows you to **interact with your data in a more engaging way**. This means you can quickly and easily label your data with filters, AI feedback suggestions and semantic search. So you can focus on training your models and monitoring their performance. ## What do people build with Argilla?

Datasets and models

-Argilla is a tool that can be used to achieve and keep **high-quality data standards** with a **focus on NLP and LLMs**. Our community uses Argilla to create amazing open-source [datasets](https://huggingface.co/datasets?other=argilla) and [models](https://huggingface.co/models?other=distilabel), and **we love contributions to open-source** ourselves too. +The community uses Argilla to create amazing open-source [datasets](https://huggingface.co/datasets?library=library:argilla&sort=trending) and [models](https://huggingface.co/models?other=distilabel). -- Our [cleaned UltraFeedback dataset](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) and the [Notus](https://huggingface.co/argilla/notus-7b-v1) and [Notux](https://huggingface.co/argilla/notux-8x7b-v1) models, where we improved benchmark and empirical human judgment for the Mistral and Mixtral models with cleaner data using **human feedback**. -- Our [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) and the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B), show how we improve model performance by filtering out 50% of the original dataset through **human and AI feedback**. +- [Cleaned UltraFeedback dataset](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) used to fine-tune the [Notus](https://huggingface.co/argilla/notus-7b-v1) and [Notux](https://huggingface.co/argilla/notux-8x7b-v1) models. The original UltraFeedback dataset was curated using Argilla UI filters to find and report a bug in the original data generation code. Based on this data curation process, Argilla built this new version of the UltraFeedback dataset and fine-tuned Notus, outperforming Zephyr on several benchmarks. +- [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) used to fine-tune the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B). This dataset was built by combining human curation in Argilla with AI feedback from distilabel, leading to an improved version of the Intel Orca dataset and outperforming models fine-tuned on the original dataset.

Projects and pipelines

-AI teams from companies like [the Red Cross](https://510.global/), [Loris.ai](https://loris.ai/) and [Prolific](https://www.prolific.com/) use Argilla to **improve the quality and efficiency of AI** projects. They shared their experiences in our [AI community meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB). +AI teams from companies like [the Red Cross](https://510.global/), [Loris.ai](https://loris.ai/) and [Prolific](https://www.prolific.com/) use Argilla to improve the quality and efficiency of AI projects. They shared their experiences in our [AI community meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB). -- AI for good: [the Red Cross presentation](https://youtu.be/ZsCqrAhzkFU?feature=shared) showcases **how their experts and AI team collaborate** by classifying and redirecting requests from refugees of the Ukrainian crisis to streamline the support processes of the Red Cross. -- Customer support: during [the Loris meetup](https://youtu.be/jWrtgf2w4VU?feature=shared) they showed how their AI team uses unsupervised and few-shot contrastive learning to help them **quickly validate and gain labelled samples for a huge amount of multi-label classifiers**. -- Research studies: [the showcase from Prolific](https://youtu.be/ePDlhIxnuAs?feature=shared) announced their integration with our platform. They use it to actively **distribute data collection projects** among their annotating workforce. This allows them to quickly and **efficiently collect high-quality data** for their research studies. +- AI for good: [the Red Cross presentation](https://youtu.be/ZsCqrAhzkFU?feature=shared) showcases how the Red Cross domain experts and AI team collaborated by classifying and redirecting requests from refugees of the Ukrainian crisis to streamline the support processes of the Red Cross. +- Customer support: during [the Loris meetup](https://youtu.be/jWrtgf2w4VU?feature=shared) they showed how their AI team uses unsupervised and few-shot contrastive learning to help them quickly validate and gain labelled samples for a huge amount of multi-label classifiers. +- Research studies: [the showcase from Prolific](https://youtu.be/ePDlhIxnuAs?feature=shared) announced their integration with our platform. They use it to actively distribute data collection projects among their annotating workforce. This allows Prolific to quickly and efficiently collect high-quality data for research studies. diff --git a/argilla/mkdocs.yml b/argilla/mkdocs.yml index c7dd0fe28c..099fb7d06a 100644 --- a/argilla/mkdocs.yml +++ b/argilla/mkdocs.yml @@ -135,14 +135,15 @@ nav: - FAQ: getting_started/faq.md - How-to guides: - how_to_guides/index.md - - Manage users and credentials: how_to_guides/user.md - - Manage workspaces: how_to_guides/workspace.md - - Manage and create datasets: how_to_guides/dataset.md - - Add, update, and delete records: how_to_guides/record.md - - Query and filter records: how_to_guides/query.md - - Importing and exporting datasets and records: how_to_guides/import_export.md - - Annotate a dataset: how_to_guides/annotate.md - - Distribute the annotation task: how_to_guides/distribution.md + - Basic: + - Manage users and credentials: how_to_guides/user.md + - Manage workspaces: how_to_guides/workspace.md + - Create, update and delete datasets: how_to_guides/dataset.md + - Add, update, and delete records: how_to_guides/record.md + - Distribute the annotation task: how_to_guides/distribution.md + - Annotate datasets: how_to_guides/annotate.md + - Query and filter records: how_to_guides/query.md + - Import and export datasets: how_to_guides/import_export.md - Advanced: - Use Markdown to format rich content: how_to_guides/use_markdown_to_format_rich_content.md - Migrate your legacy datasets to Argilla V2: how_to_guides/migrate_from_legacy_datasets.md diff --git a/argilla/pdm.lock b/argilla/pdm.lock index da59020f7e..b1ac73bb56 100644 --- a/argilla/pdm.lock +++ b/argilla/pdm.lock @@ -4,8 +4,8 @@ [metadata] groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] -lock_version = "4.4.1" -content_hash = "sha256:106793ca83dbd9b4c7125bf4bf3c466284aa245672ed7641c94a71308e549255" +lock_version = "4.4.2" +content_hash = "sha256:80432e7b8d98d1ac852cc1fa0b571731462229737f0432103ec05c8ffd71578b" [[package]] name = "aiohttp" @@ -2095,6 +2095,20 @@ files = [ {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, ] +[[package]] +name = "pytest-retry" +version = "1.6.3" +requires_python = ">=3.9" +summary = "Adds the ability to retry flaky tests in CI environments" +groups = ["dev"] +dependencies = [ + "pytest>=7.0.0", +] +files = [ + {file = "pytest_retry-1.6.3-py3-none-any.whl", hash = "sha256:e96f7df77ee70b0838d1085f9c3b8b5b7d74bf8947a0baf32e2b8c71b27683c8"}, + {file = "pytest_retry-1.6.3.tar.gz", hash = "sha256:36ccfa11c8c8f9ddad5e20375182146d040c20c4a791745139c5a99ddf1b557d"}, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" diff --git a/argilla/pyproject.toml b/argilla/pyproject.toml index 0c2d4f4d59..a4ac8b0864 100644 --- a/argilla/pyproject.toml +++ b/argilla/pyproject.toml @@ -62,6 +62,7 @@ dev = [ "CairoSVG >= 2.7.1", "mknotebooks >= 0.8.0", "argilla-v1[listeners] @ file:///${PROJECT_ROOT}/../argilla-v1", + "pytest-retry>=1.5", ] [tool.pdm.scripts] diff --git a/argilla/src/argilla/client.py b/argilla/src/argilla/client.py index 36b0a21ed8..14dc9863c7 100644 --- a/argilla/src/argilla/client.py +++ b/argilla/src/argilla/client.py @@ -16,7 +16,7 @@ from abc import abstractmethod from collections.abc import Sequence from functools import cached_property -from typing import TYPE_CHECKING, overload, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union, overload from uuid import UUID from argilla import _api @@ -25,12 +25,10 @@ from argilla._exceptions import ArgillaError, NotFoundError from argilla._helpers import GenericIterator from argilla._helpers._resource_repr import ResourceHTMLReprMixin -from argilla._models import UserModel, WorkspaceModel, DatasetModel, ResourceModel +from argilla._models import DatasetModel, ResourceModel, UserModel, WorkspaceModel if TYPE_CHECKING: - from argilla import Workspace - from argilla import Dataset - from argilla import User + from argilla import Dataset, User, Workspace __all__ = ["Argilla"] @@ -160,7 +158,7 @@ def __len__(self) -> int: return len(self._api.list()) def add(self, user: "User") -> "User": - """Add a new user to the Argilla platform. + """Add a new user to Argilla. Args: user: User object. diff --git a/argilla/tests/integration/test_export_dataset.py b/argilla/tests/integration/test_export_dataset.py index 63a295e06a..5cad1ac996 100644 --- a/argilla/tests/integration/test_export_dataset.py +++ b/argilla/tests/integration/test_export_dataset.py @@ -71,6 +71,7 @@ def token(): return os.getenv("HF_TOKEN_ARGILLA_INTERNAL_TESTING") +@pytest.mark.flaky(retries=3, only_on=[OSError]) # I/O error hub consistency CICD pipline @pytest.mark.parametrize("with_records_export", [True, False]) class TestDiskImportExportMixin: def test_export_dataset_to_disk( diff --git a/docs/_source/_common/next_steps.md b/docs/_source/_common/next_steps.md index 7415471930..438fa82b7e 100644 --- a/docs/_source/_common/next_steps.md +++ b/docs/_source/_common/next_steps.md @@ -11,7 +11,7 @@ tabindex="0" > -🙋‍♀️ Join the Argilla community on [Discord](http://hf.co/join/discord) and get direct support from the community in #argilla-general and #argilla-help. +🙋‍♀️ Join the Argilla community on [Discord](http://hf.co/join/discord) and get direct support from the community in #argilla-distilabel-general and #argilla-distilabel-help. ⭐ Argilla [Github repo](https://github.com/argilla-io/argilla) to stay updated about new releases and tutorials. diff --git a/docs/_source/community/contributing.md b/docs/_source/community/contributing.md index e68f01c58d..d42e16aae9 100644 --- a/docs/_source/community/contributing.md +++ b/docs/_source/community/contributing.md @@ -25,8 +25,8 @@ Discord is a handy tool for more casual conversations and to answer day-to-day q When part of the Hugging Face Discord, you can select "Channels & roles" and select "Argilla" along with any of the other groups that are interesting to you. "Argilla" will cover anything about argilla and distilabel. You can join the following channels: -* **#argilla-general**: 📣 Stay up-to-date and general discussions. -* **#argilla-help**: 🙋‍♀️ Need assistance? We're always here to help. Select the appropriate label (argilla or distilabel) for your issue and post it. +* **#argilla-distilabel-general**: 📣 Stay up-to-date and general discussions. +* **#argilla-distilabel-help**: 🙋‍♀️ Need assistance? We're always here to help. Select the appropriate label (argilla or distilabel) for your issue and post it. So now there is only one thing left to do: introduce yourself and talk to the community. You'll always be welcome! 🤗👋 From 903432919428c0deb108b8ff641aede15b4a4167 Mon Sep 17 00:00:00 2001 From: Leire Date: Thu, 25 Jul 2024 16:26:40 +0200 Subject: [PATCH 04/13] fix: UI - styles for task distribution section in settings (#5310) - [x] Fix missing styles in distribution task section in setting page - [x] Update translation for the error toast --- .../settings/SettingsInfoReadOnly.vue | 33 +++++++++++++++++-- argilla-frontend/translation/en.js | 2 +- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/argilla-frontend/components/features/annotation/settings/SettingsInfoReadOnly.vue b/argilla-frontend/components/features/annotation/settings/SettingsInfoReadOnly.vue index 0b74ba23d9..446d753e9e 100644 --- a/argilla-frontend/components/features/annotation/settings/SettingsInfoReadOnly.vue +++ b/argilla-frontend/components/features/annotation/settings/SettingsInfoReadOnly.vue @@ -26,13 +26,13 @@ v-text="$t('taskDistribution')" /> -
+
@@ -94,4 +94,33 @@ export default { margin: 0; } } + +.form-group { + display: flex; + align-items: center; + width: 100%; + gap: $base-space; + &__input--read-only { + display: flex; + flex-direction: row; + align-items: center; + width: 80px; + height: 24px; + padding: $base-space * 2; + border: 1px solid $black-20; + border-radius: $border-radius; + background: $black-4; + border: 1px solid $black-20; + opacity: 0.6; + } +} +.info-icon { + color: $black-37; + margin-right: $base-space * 2; + &[data-title] { + position: relative; + overflow: visible; + @include tooltip-mini("top", $base-space); + } +} diff --git a/argilla-frontend/translation/en.js b/argilla-frontend/translation/en.js index b67c315726..d8fb47072f 100644 --- a/argilla-frontend/translation/en.js +++ b/argilla-frontend/translation/en.js @@ -238,7 +238,7 @@ export default { }, update_distribution_with_existing_responses: { message: - "Distribution settings cannot be modified for a published dataset", + "This dataset has responses. Task distribution settings cannot be modified", }, }, http: { From bc9ae81c6a81f3dc3a8fc2be10c38cde3e3d32c7 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 29 Jul 2024 08:20:51 +0200 Subject: [PATCH 05/13] [BUGFIX] `argilla-server`: Increase oauth state cookie max age to 90 seconds (#5314) # Description This PR fixes errors when using OAuth. The default max-age for the OAuth `state` (a random code used for security reasons) has been increased to avoid cookie expiration making the flow fail. **Type of change** - Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../security/authentication/oauth2/providers/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py index 239a7c9fcf..ef6586f92d 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py @@ -49,7 +49,7 @@ class OAuth2ClientProvider: """OAuth2 flow handler of a certain provider.""" OAUTH_STATE_COOKIE_NAME = "oauth2_state" - OAUTH_STATE_COOKIE_MAX_AGE = 10 + OAUTH_STATE_COOKIE_MAX_AGE = 90 name: ClassVar[str] backend_class: ClassVar[Type[BaseOAuth2]] From 9f2d6dcafd269660377b77b200c57591a82bd106 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dami=C3=A1n=20Pumar?= Date: Mon, 29 Jul 2024 12:02:54 +0200 Subject: [PATCH 06/13] =?UTF-8?q?=F0=9F=90=9B=20Fix=20oauth=20(#5316)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Paco Aranda Co-authored-by: Francisco Aranda --- argilla-frontend/middleware/auth-guard.ts | 11 +++- argilla-frontend/v1/di/di.ts | 8 +-- .../v1/domain/services/IAuthService.ts | 7 +++ .../v1/domain/services/IOAuthRepository.ts | 2 +- .../domain/usecases/oauth-login-use-case.ts | 14 ++++- .../usecases/oauth-login-usecase.test.ts | 58 +++++++++++++++++++ .../repositories/OAuthRepository.ts | 6 +- 7 files changed, 91 insertions(+), 15 deletions(-) create mode 100644 argilla-frontend/v1/domain/services/IAuthService.ts create mode 100644 argilla-frontend/v1/domain/usecases/oauth-login-usecase.test.ts diff --git a/argilla-frontend/middleware/auth-guard.ts b/argilla-frontend/middleware/auth-guard.ts index 7e97a3f596..f333e69d1e 100644 --- a/argilla-frontend/middleware/auth-guard.ts +++ b/argilla-frontend/middleware/auth-guard.ts @@ -24,7 +24,9 @@ export default ({ $auth, route, redirect }: Context) => { switch (route.name) { case "sign-in": if ($auth.loggedIn) return redirect("/"); + if (route.params.omitCTA) return; + if (isRunningOnHuggingFace()) { // eslint-disable-next-line @typescript-eslint/no-unused-vars const { redirect: _, ...query } = route.query; @@ -36,12 +38,15 @@ export default ({ $auth, route, redirect }: Context) => { } break; case "oauth-provider-callback": - if (!Object.keys(route.query).length) redirect("/"); + if (!Object.keys(route.query).length) return redirect("/"); + break; case "welcome-hf-sign-in": - if (!isRunningOnHuggingFace()) redirect("/"); - break; + if ($auth.loggedIn) return redirect("/"); + + if (!isRunningOnHuggingFace()) return redirect("/"); + break; default: if (!$auth.loggedIn) { if (route.path !== "/") { diff --git a/argilla-frontend/v1/di/di.ts b/argilla-frontend/v1/di/di.ts index d176256e27..6a3ff2975f 100644 --- a/argilla-frontend/v1/di/di.ts +++ b/argilla-frontend/v1/di/di.ts @@ -73,9 +73,7 @@ export const loadDependencyContainer = (context: Context) => { register(VectorRepository).withDependency(useAxios).build(), register(AgentRepository).withDependency(useAxios).build(), register(EnvironmentRepository).withDependency(useAxios).build(), - register(OAuthRepository) - .withDependencies(useAxios, useRoutes, useAuth) - .build(), + register(OAuthRepository).withDependencies(useAxios, useRoutes).build(), register(WorkspaceRepository).withDependency(useAxios).build(), register(DeleteDatasetUseCase).withDependency(DatasetRepository).build(), @@ -184,7 +182,9 @@ export const loadDependencyContainer = (context: Context) => { .withDependency(EnvironmentRepository) .build(), - register(OAuthLoginUseCase).withDependency(OAuthRepository).build(), + register(OAuthLoginUseCase) + .withDependencies(OAuthRepository, useAuth) + .build(), ]; Container.register(dependencies); diff --git a/argilla-frontend/v1/domain/services/IAuthService.ts b/argilla-frontend/v1/domain/services/IAuthService.ts new file mode 100644 index 0000000000..9b203c9a92 --- /dev/null +++ b/argilla-frontend/v1/domain/services/IAuthService.ts @@ -0,0 +1,7 @@ +import { HTTPResponse } from "@nuxtjs/auth-next"; + +export interface IAuthService { + logout(...args: unknown[]): Promise; + + setUserToken(token: string): Promise; +} diff --git a/argilla-frontend/v1/domain/services/IOAuthRepository.ts b/argilla-frontend/v1/domain/services/IOAuthRepository.ts index eb1d0a11e1..02c3863c1c 100644 --- a/argilla-frontend/v1/domain/services/IOAuthRepository.ts +++ b/argilla-frontend/v1/domain/services/IOAuthRepository.ts @@ -9,5 +9,5 @@ export interface IOAuthRepository { authorize(provider: ProviderType): void; - login(provider: ProviderType, oauthParams: OAuthParams): Promise; + login(provider: ProviderType, oauthParams: OAuthParams): Promise; } diff --git a/argilla-frontend/v1/domain/usecases/oauth-login-use-case.ts b/argilla-frontend/v1/domain/usecases/oauth-login-use-case.ts index 88125c504b..f6db87aae1 100644 --- a/argilla-frontend/v1/domain/usecases/oauth-login-use-case.ts +++ b/argilla-frontend/v1/domain/usecases/oauth-login-use-case.ts @@ -3,10 +3,14 @@ import { OAuthProvider, ProviderType, } from "../entities/oauth/OAuthProvider"; +import { IAuthService } from "../services/IAuthService"; import { IOAuthRepository } from "../services/IOAuthRepository"; export class OAuthLoginUseCase { - constructor(private readonly oauthRepository: IOAuthRepository) {} + constructor( + private readonly oauthRepository: IOAuthRepository, + private readonly auth: IAuthService + ) {} async getProviders(): Promise { try { @@ -20,7 +24,11 @@ export class OAuthLoginUseCase { return this.oauthRepository.authorize(provider); } - login(provider: ProviderType, oauthParams: OAuthParams) { - return this.oauthRepository.login(provider, oauthParams); + async login(provider: ProviderType, oauthParams: OAuthParams) { + await this.auth.logout(); + + const token = await this.oauthRepository.login(provider, oauthParams); + + if (token) await this.auth.setUserToken(token); } } diff --git a/argilla-frontend/v1/domain/usecases/oauth-login-usecase.test.ts b/argilla-frontend/v1/domain/usecases/oauth-login-usecase.test.ts new file mode 100644 index 0000000000..5a7e7f6736 --- /dev/null +++ b/argilla-frontend/v1/domain/usecases/oauth-login-usecase.test.ts @@ -0,0 +1,58 @@ +import { mock } from "@codescouts/test/jest"; +import { IOAuthRepository } from "../services/IOAuthRepository"; +import { IAuthService } from "../services/IAuthService"; +import { OAuthLoginUseCase } from "./oauth-login-use-case"; + +describe("OAuthLoginUseCase should", () => { + test("logout user before call the login on backend", async () => { + const oauthRepository = mock(); + const auth = mock(); + + oauthRepository.login.mockRejectedValue(new Error("FAKE")); + + const useCase = new OAuthLoginUseCase(oauthRepository, auth); + + try { + await useCase.login("huggingface", null); + } catch { + expect(auth.logout).toHaveBeenCalledTimes(1); + } + }); + + test("call the backend after logout", async () => { + const oauthRepository = mock(); + const auth = mock(); + + const useCase = new OAuthLoginUseCase(oauthRepository, auth); + + await useCase.login("huggingface", null); + + expect(oauthRepository.login).toHaveBeenCalledWith("huggingface", null); + }); + + test("save token if the token is defined", async () => { + const oauthRepository = mock(); + const auth = mock(); + + oauthRepository.login.mockResolvedValue("FAKE_TOKEN"); + + const useCase = new OAuthLoginUseCase(oauthRepository, auth); + + await useCase.login("huggingface", null); + + expect(auth.setUserToken).toHaveBeenCalledWith("FAKE_TOKEN"); + }); + + test("no save token if the token is not defined", async () => { + const oauthRepository = mock(); + const auth = mock(); + + oauthRepository.login.mockResolvedValue(""); + + const useCase = new OAuthLoginUseCase(oauthRepository, auth); + + await useCase.login("huggingface", null); + + expect(auth.setUserToken).toHaveBeenCalledTimes(0); + }); +}); diff --git a/argilla-frontend/v1/infrastructure/repositories/OAuthRepository.ts b/argilla-frontend/v1/infrastructure/repositories/OAuthRepository.ts index dd871cd839..7b5a515ff2 100644 --- a/argilla-frontend/v1/infrastructure/repositories/OAuthRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/OAuthRepository.ts @@ -1,5 +1,4 @@ import { type NuxtAxiosInstance } from "@nuxtjs/axios"; -import { Auth } from "@nuxtjs/auth-next"; import { Response } from "../types"; import { useRunningEnvironment } from "../services/useRunningEnvironment"; import { largeCache } from "./AxiosCache"; @@ -24,8 +23,7 @@ export class OAuthRepository implements IOAuthRepository { private readonly axios: NuxtAxiosInstance; constructor( axios: NuxtAxiosInstance, - private readonly router: RouterService, - private readonly auth: Auth + private readonly router: RouterService ) { this.axios = axios.create({ withCredentials: false, @@ -74,7 +72,7 @@ export class OAuthRepository implements IOAuthRepository { params, }); - if (data.access_token) await this.auth.setUserToken(data.access_token); + return data.access_token; } catch (error) { throw { response: OAUTH_API_ERRORS.ERROR_FETCHING_OAUTH_ACCESS_TOKEN, From 3430b662d3c32f24822bcb33d37f8c4d57d9fa28 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 29 Jul 2024 13:00:18 +0200 Subject: [PATCH 07/13] [REFACTOR] `argilla-server`: rewiew OAuth owner condition (#5313) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR changes the behavior to detect the user role when an OAuth sign-in occurs. - If the connected user matches the `USERNAME`, the connected user becomes an owner. - The rest of the users will be defined as `annotator`. - All logic related to roles in ORG has been removed until finding a proper auth scope. **Type of change** - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: José Francisco Calvo --- .../docker/argilla-hf-spaces/scripts/start.sh | 10 +++ .../scripts/start_argilla_server.sh | 10 --- .../api/errors/v1/exception_handlers.py | 10 +++ .../argilla_server/api/handlers/v1/oauth2.py | 55 +++++------- .../integrations/huggingface/spaces.py | 2 +- .../security/authentication/claims.py | 11 ++- .../oauth2/providers/_huggingface.py | 42 +-------- .../tests/unit/api/handlers/v1/test_oauth2.py | 19 ++-- .../oauth2/providers/__init__.py | 14 --- .../providers/test_huggingface_provider.py | 88 ------------------- .../security/authentication/test_userinfo.py | 8 ++ 11 files changed, 73 insertions(+), 196 deletions(-) delete mode 100644 argilla-server/tests/unit/security/authentication/oauth2/providers/__init__.py delete mode 100644 argilla-server/tests/unit/security/authentication/oauth2/providers/test_huggingface_provider.py diff --git a/argilla-server/docker/argilla-hf-spaces/scripts/start.sh b/argilla-server/docker/argilla-hf-spaces/scripts/start.sh index e1672fe31e..28b5e2a1e1 100644 --- a/argilla-server/docker/argilla-hf-spaces/scripts/start.sh +++ b/argilla-server/docker/argilla-hf-spaces/scripts/start.sh @@ -2,4 +2,14 @@ set -e +# Preset oauth env vars based on injected space variables. +# See https://huggingface.co/docs/hub/en/spaces-oauth#create-an-oauth-app +export OAUTH2_HUGGINGFACE_CLIENT_ID=$OAUTH_CLIENT_ID +export OAUTH2_HUGGINGFACE_CLIENT_SECRET=$OAUTH_CLIENT_SECRET +export OAUTH2_HUGGINGFACE_SCOPE=$OAUTH_SCOPES + +# Set the space author name as username if no provided. +# See https://huggingface.co/docs/hub/en/spaces-overview#helper-environment-variables for more details +export USERNAME="${USERNAME:-$SPACE_AUTHOR_NAME}" + honcho start diff --git a/argilla-server/docker/argilla-hf-spaces/scripts/start_argilla_server.sh b/argilla-server/docker/argilla-hf-spaces/scripts/start_argilla_server.sh index 93cd012346..2333da11b0 100755 --- a/argilla-server/docker/argilla-hf-spaces/scripts/start_argilla_server.sh +++ b/argilla-server/docker/argilla-hf-spaces/scripts/start_argilla_server.sh @@ -2,19 +2,9 @@ set -e -# Preset oauth env vars based on injected space variables. -# See https://huggingface.co/docs/hub/en/spaces-oauth#create-an-oauth-app -export OAUTH2_HUGGINGFACE_CLIENT_ID=$OAUTH_CLIENT_ID -export OAUTH2_HUGGINGFACE_CLIENT_SECRET=$OAUTH_CLIENT_SECRET -export OAUTH2_HUGGINGFACE_SCOPE=$OAUTH_SCOPES - echo "Running database migrations" python -m argilla_server database migrate -# Set the space author name as username if no provided. -# See https://huggingface.co/docs/hub/en/spaces-overview#helper-environment-variables for more details -USERNAME="${USERNAME:-$SPACE_AUTHOR_NAME}" - if [ -n "$USERNAME" ] && [ -n "$PASSWORD" ]; then echo "Creating owner user with username ${USERNAME}" python -m argilla_server database users create \ diff --git a/argilla-server/src/argilla_server/api/errors/v1/exception_handlers.py b/argilla-server/src/argilla_server/api/errors/v1/exception_handlers.py index 5d9aa03732..c3bf0f4633 100644 --- a/argilla-server/src/argilla_server/api/errors/v1/exception_handlers.py +++ b/argilla-server/src/argilla_server/api/errors/v1/exception_handlers.py @@ -19,6 +19,16 @@ def add_exception_handlers(app: FastAPI): + @app.exception_handler(errors.AuthenticationError) + async def authentication_error(request, exc): + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + # TODO: Once we move to v2.0 we can remove the content using detail attribute + # and use the new one using code and message. + # content={"code": exc.code, "message": exc.message}, + content={"detail": str(exc)}, + ) + @app.exception_handler(errors.NotFoundError) async def not_found_error_exception_handler(request, exc): return JSONResponse( diff --git a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py index b1f61ef09c..5c8908bf80 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Path +from fastapi import APIRouter, Depends, Request, Path from fastapi.responses import RedirectResponse from sqlalchemy.ext.asyncio import AsyncSession @@ -23,10 +23,9 @@ from argilla_server.contexts import accounts from argilla_server.database import get_async_db from argilla_server.enums import UserRole -from argilla_server.errors.future import AuthenticationError, NotFoundError +from argilla_server.errors.future import NotFoundError from argilla_server.models import User -from argilla_server.pydantic_v1 import Field, ValidationError -from argilla_server.security.authentication.jwt import JWT +from argilla_server.pydantic_v1 import Field from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider from argilla_server.security.authentication.userinfo import UserInfo from argilla_server.security.settings import settings @@ -74,32 +73,22 @@ async def get_access_token( provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise), db: AsyncSession = Depends(get_async_db), ) -> Token: - try: - user_info = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims) - user = await User.get_by(db, username=user_info.username) - if user is None: - try: - user_create = UserOAuthCreate( - username=user_info.username, - first_name=user_info.first_name, - role=user_info.role, - ) - except ValidationError as ex: - raise AuthenticationError("Could not authenticate user") from ex - - user = await accounts.create_user_with_random_password( - db, - **user_create.dict(exclude_unset=True), - workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces], - ) - telemetry.track_user_created(user, is_oauth=True) - - elif user.role != user_info.role: - raise AuthenticationError("Could not authenticate user") - - return Token(access_token=JWT.create(user_info)) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - # TODO: Create exception handler for AuthenticationError - except AuthenticationError as e: - raise HTTPException(status_code=401, detail=str(e)) from e + userinfo = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims) + + if not userinfo.username: + raise RuntimeError("OAuth error: Missing username") + + user = await User.get_by(db, username=userinfo.username) + if user is None: + user = await accounts.create_user_with_random_password( + db, + **UserOAuthCreate( + username=userinfo.username, + first_name=userinfo.first_name, + role=userinfo.role, + ).dict(exclude_unset=True), + workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces], + ) + telemetry.track_user_created(user, is_oauth=True) + + return Token(access_token=accounts.generate_user_token(user)) diff --git a/argilla-server/src/argilla_server/integrations/huggingface/spaces.py b/argilla-server/src/argilla_server/integrations/huggingface/spaces.py index 6b95187866..04140b1438 100644 --- a/argilla-server/src/argilla_server/integrations/huggingface/spaces.py +++ b/argilla-server/src/argilla_server/integrations/huggingface/spaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pydantic import BaseSettings, Field +from argilla_server.pydantic_v1 import BaseSettings, Field class HuggingfaceSettings(BaseSettings): diff --git a/argilla-server/src/argilla_server/security/authentication/claims.py b/argilla-server/src/argilla_server/security/authentication/claims.py index 4fd4d50d53..a34696a685 100644 --- a/argilla-server/src/argilla_server/security/authentication/claims.py +++ b/argilla-server/src/argilla_server/security/authentication/claims.py @@ -11,8 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +from typing import Any, Callable, Union, Optional -from typing import Any, Callable, Union +from argilla_server.enums import UserRole + + +def _parse_role_from_environment(userinfo: dict) -> Optional[UserRole]: + """This is a temporal solution, and it will be replaced by a proper Sign up process""" + if userinfo["username"] == os.getenv("USERNAME"): + return UserRole.owner class Claims(dict): @@ -29,3 +37,4 @@ def __init__(self, seq=None, **kwargs) -> None: self["identity"] = kwargs.get("identity", self.get("identity", "sub")) self["picture"] = kwargs.get("picture", self.get("picture", "picture")) self["email"] = kwargs.get("email", self.get("email", "email")) + self["role"] = kwargs.get("role", _parse_role_from_environment) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py index ea7bb23a79..57365930d8 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py @@ -13,12 +13,9 @@ # limitations under the License. import logging -from typing import Union, Optional from social_core.backends.open_id_connect import OpenIdConnectAuth -from argilla_server.enums import UserRole -from argilla_server.integrations.huggingface.spaces import HUGGINGFACE_SETTINGS from argilla_server.logging import LoggingMixin from argilla_server.security.authentication.claims import Claims from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider @@ -42,46 +39,9 @@ def oidc_endpoint(self) -> str: _HF_PREFERRED_USERNAME = "preferred_username" -def _is_space_author(userinfo: dict, space_author: str) -> bool: - """Return True if the space author name is the userinfo username. Otherwise, False""" - return space_author and space_author == userinfo.get(_HF_PREFERRED_USERNAME) - - -def _find_org_from_userinfo(userinfo: dict, org_name: str) -> Optional[dict]: - """Find the organization by name from the userinfo""" - for org in userinfo.get("orgs") or []: - if org_name == org.get(_HF_PREFERRED_USERNAME): - return org - - -def _get_user_role_by_org(org: dict) -> Union[UserRole, None]: - """Return the computed UserRole from the role found in a organization (if any)""" - _ROLE_IN_ORG = "roleInOrg" - _ROLES_MAPPING = {"admin": UserRole.owner} - - org_role = None - if _ROLE_IN_ORG not in org: - _LOGGER.warning(f"Cannot find the user role info in org {org}. Review granted permissions") - else: - org_role = org[_ROLE_IN_ORG] - - return _ROLES_MAPPING.get(org_role) or UserRole.annotator - - class HuggingfaceClientProvider(OAuth2ClientProvider, LoggingMixin): """Specialized HuggingFace OAuth2 provider.""" - @staticmethod - def parse_role_from_userinfo(userinfo: dict) -> Union[str, None]: - """Parse the Argilla user role from info provided as part of the user info""" - space_author_name = HUGGINGFACE_SETTINGS.space_author_name - - if _is_space_author(userinfo, space_author_name): - return UserRole.owner - elif org := _find_org_from_userinfo(userinfo, space_author_name): - return _get_user_role_by_org(org) - return UserRole.annotator - - claims = Claims(username=_HF_PREFERRED_USERNAME, role=parse_role_from_userinfo, first_name="name") + claims = Claims(username=_HF_PREFERRED_USERNAME, first_name="name") backend_class = HuggingfaceOpenId name = "huggingface" diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py index 34abf0e415..7f8d6ddee1 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py @@ -159,7 +159,7 @@ async def test_provider_huggingface_access_token( assert JWT.decode(json_response["access_token"])["username"] == "username" assert json_response["token_type"] == "bearer" - user = (await db.execute(select(User).where(User.username == "username"))).scalar_one_or_none() + user = await db.scalar(select(User).filter_by(username="username")) assert user is not None assert user.role == UserRole.annotator @@ -182,7 +182,7 @@ async def test_provider_huggingface_access_token_with_missing_username( cookies={"oauth2_state": "valid"}, ) - assert response.status_code == 401 + assert response.status_code == 500 async def test_provider_huggingface_access_token_with_missing_name( self, @@ -244,7 +244,7 @@ async def test_provider_access_token_with_not_found_code( response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", headers=owner_auth_header ) - assert response.status_code == 400 + assert response.status_code == 422 assert response.json() == {"detail": "'code' parameter was not found in callback request"} async def test_provider_access_token_with_not_found_state( @@ -254,7 +254,7 @@ async def test_provider_access_token_with_not_found_state( response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code"}, headers=owner_auth_header ) - assert response.status_code == 400 + assert response.status_code == 422 assert response.json() == {"detail": "'state' parameter was not found in callback request"} async def test_provider_access_token_with_invalid_state( @@ -267,7 +267,7 @@ async def test_provider_access_token_with_invalid_state( headers=owner_auth_header, cookies={"oauth2_state": "valid"}, ) - assert response.status_code == 400 + assert response.status_code == 422 assert response.json() == {"detail": "'state' parameter does not match"} async def test_provider_access_token_with_authentication_error( @@ -287,7 +287,7 @@ async def test_provider_access_token_with_authentication_error( assert response.status_code == 401 assert response.json() == {"detail": "error"} - async def test_provider_access_token_with_unauthorized_user( + async def test_provider_access_token_with_already_created_user( self, async_client: AsyncClient, db: AsyncSession, @@ -307,8 +307,11 @@ async def test_provider_access_token_with_unauthorized_user( headers=owner_auth_header, cookies={"oauth2_state": "valid"}, ) - assert response.status_code == 401 - assert response.json() == {"detail": "Could not authenticate user"} + assert response.status_code == 200 + + userinfo = JWT.decode(response.json()["access_token"]) + assert userinfo["username"] == admin.username + assert userinfo["role"] == admin.role async def test_provider_access_token_with_same_username( self, diff --git a/argilla-server/tests/unit/security/authentication/oauth2/providers/__init__.py b/argilla-server/tests/unit/security/authentication/oauth2/providers/__init__.py deleted file mode 100644 index 4b6cecae7f..0000000000 --- a/argilla-server/tests/unit/security/authentication/oauth2/providers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/argilla-server/tests/unit/security/authentication/oauth2/providers/test_huggingface_provider.py b/argilla-server/tests/unit/security/authentication/oauth2/providers/test_huggingface_provider.py deleted file mode 100644 index 88b526212e..0000000000 --- a/argilla-server/tests/unit/security/authentication/oauth2/providers/test_huggingface_provider.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pytest_mock import MockerFixture - -from argilla_server.enums import UserRole -from argilla_server.integrations.huggingface.spaces import HuggingfaceSettings -from argilla_server.security.authentication.oauth2.providers import HuggingfaceClientProvider -from argilla_server.security.authentication.oauth2.providers import _huggingface - - -class TestHuggingfaceOauthProvider: - def test_parse_role_from_userinfo_for_space_author(self, mocker: "MockerFixture"): - mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="author")) - - userinfo = {"preferred_username": "author"} - role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) - assert role == UserRole.owner - - def test_parse_role_without_spaces_info(self, mocker: "MockerFixture"): - mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name=None)) - - userinfo = {"preferred_username": "author"} - role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) - assert role == UserRole.annotator - - def test_parse_role_with_different_author_name(self, mocker: "MockerFixture"): - mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="other")) - - userinfo = {"preferred_username": "author"} - role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) - assert role == UserRole.annotator - - def test_parse_role_with_missing_username(self): - userinfo = {} - role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) - assert role == UserRole.annotator - - def test_parse_role_with_admin_role_in_org(self, mocker: "MockerFixture"): - mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="org")) - - userinfo = { - "preferred_username": "author", - "orgs": [{"preferred_username": "org", "roleInOrg": "admin"}], - } - role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) - assert role == UserRole.owner - - def test_parse_role_with_non_admin_role_in_org(self, mocker: "MockerFixture"): - mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="org")) - - userinfo = { - "preferred_username": "author", - "orgs": [{"preferred_username": "org", "roleInOrg": "contributor"}], - } - role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) - assert role == UserRole.annotator - - def test_parse_role_for_other_org_author(self, mocker: "MockerFixture"): - mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="other_org")) - - userinfo = { - "preferred_username": "author", - "orgs": [{"preferred_username": "org", "roleInOrg": "contributor"}], - } - role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) - assert role == UserRole.annotator - - def test_parse_role_with_missing_org_role_info(self, mocker: "MockerFixture"): - mocker.patch.object(_huggingface, "HUGGINGFACE_SETTINGS", HuggingfaceSettings(space_author_name="org")) - - userinfo = { - "preferred_username": "author", - "orgs": [{"preferred_username": "org"}], - } - role = HuggingfaceClientProvider.parse_role_from_userinfo(userinfo) - assert role == UserRole.annotator diff --git a/argilla-server/tests/unit/security/authentication/test_userinfo.py b/argilla-server/tests/unit/security/authentication/test_userinfo.py index 918e99ba5e..8203d56c66 100644 --- a/argilla-server/tests/unit/security/authentication/test_userinfo.py +++ b/argilla-server/tests/unit/security/authentication/test_userinfo.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import pytest +from pytest_mock import MockerFixture from argilla_server.enums import UserRole from argilla_server.security.authentication import UserInfo @@ -51,3 +53,9 @@ def test_get_userinfo_with_claims(self): assert userinfo.first_name == "USER" assert userinfo.last_name == "Peter" + + def test_get_userinfo_role_with_username_env(self, mocker: MockerFixture): + mocker.patch.dict(os.environ, {"USERNAME": "user"}) + + userinfo = UserInfo({"id": "user"}).use_claims(Claims(username="id")) + assert userinfo.role == UserRole.owner From 445978169b87c9cbe94f6526a2e5b27eb72ef430 Mon Sep 17 00:00:00 2001 From: Leire Date: Mon, 29 Jul 2024 13:14:53 +0200 Subject: [PATCH 08/13] fix: restore data after an error while updating the distribution task setting (#5312) --- .../annotation/container/fields/Record.vue | 2 +- .../annotation/settings/SettingsInfo.vue | 4 +- .../annotation/settings/SettingsMetadata.vue | 2 +- .../annotation/settings/SettingsQuestions.vue | 6 ++- .../domain/entities/dataset/Dataset.test.ts | 50 ++++++++----------- .../v1/domain/entities/dataset/Dataset.ts | 16 +++--- .../v1/domain/entities/question/Question.ts | 7 ++- .../update-dataset-setting-use-case.ts | 8 ++- 8 files changed, 50 insertions(+), 45 deletions(-) diff --git a/argilla-frontend/components/features/annotation/container/fields/Record.vue b/argilla-frontend/components/features/annotation/container/fields/Record.vue index 620c2c8aa1..17f4eea661 100644 --- a/argilla-frontend/components/features/annotation/container/fields/Record.vue +++ b/argilla-frontend/components/features/annotation/container/fields/Record.vue @@ -62,7 +62,7 @@ export default { if ( this.record?.questions .filter((q) => q.isSpanType) - .some((q) => q.isModified) + .some((q) => q.isAnswerModified) ) { this.onSelectedRecord(true); } diff --git a/argilla-frontend/components/features/annotation/settings/SettingsInfo.vue b/argilla-frontend/components/features/annotation/settings/SettingsInfo.vue index 4176a3e404..737d957efb 100644 --- a/argilla-frontend/components/features/annotation/settings/SettingsInfo.vue +++ b/argilla-frontend/components/features/annotation/settings/SettingsInfo.vue @@ -56,7 +56,7 @@ @@ -91,7 +91,7 @@ diff --git a/argilla-frontend/components/features/annotation/settings/SettingsMetadata.vue b/argilla-frontend/components/features/annotation/settings/SettingsMetadata.vue index 088b88e44c..ffa96d8b84 100644 --- a/argilla-frontend/components/features/annotation/settings/SettingsMetadata.vue +++ b/argilla-frontend/components/features/annotation/settings/SettingsMetadata.vue @@ -60,7 +60,7 @@ diff --git a/argilla-frontend/components/features/annotation/settings/SettingsQuestions.vue b/argilla-frontend/components/features/annotation/settings/SettingsQuestions.vue index 8b6751cfea..d320da948b 100644 --- a/argilla-frontend/components/features/annotation/settings/SettingsQuestions.vue +++ b/argilla-frontend/components/features/annotation/settings/SettingsQuestions.vue @@ -106,14 +106,16 @@ type="button" class="secondary light small" @on-click="restore(question)" - :disabled="!question.isModified" + :disabled="!question.isSettingsModified" > diff --git a/argilla-frontend/v1/domain/entities/dataset/Dataset.test.ts b/argilla-frontend/v1/domain/entities/dataset/Dataset.test.ts index 1228cdb77b..7fed4a0fae 100644 --- a/argilla-frontend/v1/domain/entities/dataset/Dataset.test.ts +++ b/argilla-frontend/v1/domain/entities/dataset/Dataset.test.ts @@ -66,47 +66,41 @@ describe("Dataset", () => { }); describe("restore", () => { - describe("restoreDistribution should", () => { - test("restore only the distribution based on original values", () => { - const dataset = createEmptyDataset(); - dataset.distribution.minSubmitted = 20; + test("restore only the distribution based on original values", () => { + const dataset = createEmptyDataset(); + dataset.distribution.minSubmitted = 20; - expect(dataset.distribution).not.toEqual(dataset.original.distribution); + expect(dataset.distribution).not.toEqual(dataset.original.distribution); - dataset.restoreDistribution(); + dataset.restore("distribution"); - expect(dataset.distribution).toEqual(dataset.original.distribution); - }); + expect(dataset.distribution).toEqual(dataset.original.distribution); }); - describe("restoreMetadata should", () => { - test("restore only the metadata info based on original values", () => { - const dataset = createEmptyDataset(); - dataset.allowExtraMetadata = true; + test("restore only the metadata info based on original values", () => { + const dataset = createEmptyDataset(); + dataset.allowExtraMetadata = true; - expect(dataset.allowExtraMetadata).not.toEqual( - dataset.original.allowExtraMetadata - ); + expect(dataset.allowExtraMetadata).not.toEqual( + dataset.original.allowExtraMetadata + ); - dataset.restoreMetadata(); + dataset.restore("metadata"); - expect(dataset.allowExtraMetadata).toEqual( - dataset.original.allowExtraMetadata - ); - }); + expect(dataset.allowExtraMetadata).toEqual( + dataset.original.allowExtraMetadata + ); }); - describe("restoreGuidelines should", () => { - test("restore only the guidelines based on original values", () => { - const dataset = createEmptyDataset(); - dataset.guidelines = "NEW GUIDELINES"; + test("restore only the guidelines based on original values", () => { + const dataset = createEmptyDataset(); + dataset.guidelines = "NEW GUIDELINES"; - expect(dataset.guidelines).not.toEqual(dataset.original.guidelines); + expect(dataset.guidelines).not.toEqual(dataset.original.guidelines); - dataset.restoreGuidelines(); + dataset.restore("guidelines"); - expect(dataset.guidelines).toEqual(dataset.original.guidelines); - }); + expect(dataset.guidelines).toEqual(dataset.original.guidelines); }); }); diff --git a/argilla-frontend/v1/domain/entities/dataset/Dataset.ts b/argilla-frontend/v1/domain/entities/dataset/Dataset.ts index 15103f1dd6..7d4eb76e37 100644 --- a/argilla-frontend/v1/domain/entities/dataset/Dataset.ts +++ b/argilla-frontend/v1/domain/entities/dataset/Dataset.ts @@ -58,21 +58,23 @@ export class Dataset { ); } - restore() { - this.restoreGuidelines(); - this.restoreMetadata(); - this.restoreDistribution(); + restore(part: "guidelines" | "metadata" | "distribution") { + if (part === "guidelines") return this.restoreGuidelines(); + + if (part === "metadata") return this.restoreMetadata(); + + if (part === "distribution") return this.restoreDistribution(); } - restoreGuidelines() { + private restoreGuidelines() { this.guidelines = this.original.guidelines; } - restoreMetadata() { + private restoreMetadata() { this.allowExtraMetadata = this.original.allowExtraMetadata; } - restoreDistribution() { + private restoreDistribution() { this.distribution = { ...this.original.distribution, }; diff --git a/argilla-frontend/v1/domain/entities/question/Question.ts b/argilla-frontend/v1/domain/entities/question/Question.ts index 4df30e3ce7..15e8fb6851 100644 --- a/argilla-frontend/v1/domain/entities/question/Question.ts +++ b/argilla-frontend/v1/domain/entities/question/Question.ts @@ -97,12 +97,15 @@ export class Question { return this.type.isRatingType; } + public get isAnswerModified(): boolean { + return !this.answer.isEqual(this.original.answer); + } + public get isModified(): boolean { return ( this.title !== this.original.title || this.description !== this.original.description || - !this.settings.isEqual(this.original.settings) || - !this.answer.isEqual(this.original.answer) + !this.settings.isEqual(this.original.settings) ); } diff --git a/argilla-frontend/v1/domain/usecases/dataset-setting/update-dataset-setting-use-case.ts b/argilla-frontend/v1/domain/usecases/dataset-setting/update-dataset-setting-use-case.ts index 8f8258123d..06437b9e68 100644 --- a/argilla-frontend/v1/domain/usecases/dataset-setting/update-dataset-setting-use-case.ts +++ b/argilla-frontend/v1/domain/usecases/dataset-setting/update-dataset-setting-use-case.ts @@ -8,9 +8,13 @@ export class UpdateDatasetSettingUseCase { dataset: Dataset, part: "guidelines" | "metadata" | "distribution" ) { - const response = await this.update(dataset, part); + try { + const response = await this.update(dataset, part); - dataset.update(response.when, part); + dataset.update(response.when, part); + } catch (e) { + dataset.restore(part); + } } private update( From b495d9750a6ebade0ee7ae0643c8a0c70ed33f51 Mon Sep 17 00:00:00 2001 From: Natalia Elvira <126158523+nataliaElv@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:17:35 +0200 Subject: [PATCH 09/13] Docs: update migration guides (#5311) # Description **Type of change** - Documentation update **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Paco Aranda --- .../migrate_from_legacy_datasets.md | 212 ++++++++++++++---- argilla/docs/how_to_guides/record.md | 14 +- 2 files changed, 179 insertions(+), 47 deletions(-) diff --git a/argilla/docs/how_to_guides/migrate_from_legacy_datasets.md b/argilla/docs/how_to_guides/migrate_from_legacy_datasets.md index 9c54cdfd79..92dbe9074f 100644 --- a/argilla/docs/how_to_guides/migrate_from_legacy_datasets.md +++ b/argilla/docs/how_to_guides/migrate_from_legacy_datasets.md @@ -23,7 +23,6 @@ The guide will take you through three steps: 2. **Define the new dataset** in the Argilla V2 format. 3. **Upload the dataset records** to the new Argilla V2 dataset format and attributes. - ### Step 1: Retrieve the legacy dataset Connect to the Argilla V1 server via the new `argilla` package. First, you should install an extra dependency: @@ -32,6 +31,7 @@ pip install "argilla[legacy]" ``` Now, you can use the `v1` module to connect to the Argilla V1 server. + ```python import argilla.v1 as rg_v1 @@ -68,30 +68,101 @@ client = rg.Argilla() Next, define the new dataset settings: -```python -settings = rg.Settings( - fields=[ - rg.TextField(name="text"), # (1) - ], - questions=[ - rg.LabelQuestion(name="label", labels=settings_v1.label_schema), # (2) - ], - metadata=[ - rg.TermsMetadataProperty(name="split"), # (3) - ], - vectors=[ - rg.VectorField(name='mini-lm-sentence-transformers', dimensions=384), # (4) - ], -) -``` +=== "For single-label classification" + + ```python + settings = rg.Settings( + fields=[ + rg.TextField(name="text"), # (1) + ], + questions=[ + rg.LabelQuestion(name="label", labels=settings_v1.label_schema), + ], + metadata=[ + rg.TermsMetadataProperty(name="split"), # (2) + ], + vectors=[ + rg.VectorField(name='mini-lm-sentence-transformers', dimensions=384), # (3) + ], + ) + ``` + + 1. The default field in `DatasetForTextClassification` is `text`, but make sure you provide all fields included in `record.inputs`. + + 2. Make sure you provide all relevant metadata fields available in the dataset. + + 3. Make sure you provide all relevant vectors available in the dataset. + +=== "For multi-label classification" + + ```python + settings = rg.Settings( + fields=[ + rg.TextField(name="text"), # (1) + ], + questions=[ + rg.MultiLabelQuestion(name="labels", labels=settings_v1.label_schema), + ], + metadata=[ + rg.TermsMetadataProperty(name="split"), # (2) + ], + vectors=[ + rg.VectorField(name='mini-lm-sentence-transformers', dimensions=384), # (3) + ], + ) + ``` + + 1. The default field in `DatasetForTextClassification` is `text`, but we should provide all fields included in `record.inputs`. + + 2. Make sure you provide all relevant metadata fields available in the dataset. + + 3. Make sure you provide all relevant vectors available in the dataset. + +=== "For token classification" + + ```python + settings = rg.Settings( + fields=[ + rg.TextField(name="text"), + ], + questions=[ + rg.SpanQuestion(name="spans", labels=settings_v1.label_schema), + ], + metadata=[ + rg.TermsMetadataProperty(name="split"), # (1) + ], + vectors=[ + rg.VectorField(name='mini-lm-sentence-transformers', dimensions=384), # (2) + ], + ) + ``` -1. The default name for text classification is `text`, but we should provide all names included in `record.inputs`. + 1. Make sure you provide all relevant metadata fields available in the dataset. -2. The basis question for text classification is a `LabelQuestion` for single-label or `MultiLabelQuestion` for multi-label classification. + 2. Make sure you provide all relevant vectors available in the dataset. -3. Here, we need to provide all relevant metadata fields. +=== "For text generation" -4. The vectors fields available in the dataset. + ```python + settings = rg.Settings( + fields=[ + rg.TextField(name="text"), + ], + questions=[ + rg.TextQuestion(name="text_generation"), + ], + metadata=[ + rg.TermsMetadataProperty(name="split"), # (1) + ], + vectors=[ + rg.VectorField(name='mini-lm-sentence-transformers', dimensions=384), # (2) + ], + ) + ``` + + 1. We should provide all relevant metadata fields available in the dataset. + + 2. We should provide all relevant vectors available in the dataset. Finally, create the new dataset on the Argilla V2 server: @@ -127,25 +198,41 @@ Here are a set of example functions to convert the records for single-label and if prediction := data.get("prediction"): label, score = prediction[0].values() agent = data["prediction_agent"] - suggestions.append(rg.Suggestion(question_name="label", value=label, score=score, agent=agent)) + suggestions.append( + rg.Suggestion( + question_name="label", # (1) + value=label, + score=score, + agent=agent + ) + ) if annotation := data.get("annotation"): user_id = users_by_name.get(data["annotation_agent"], current_user).id - responses.append(rg.Response(question_name="label", value=annotation, user_id=user_id)) + responses.append( + rg.Response( + question_name="label", # (2) + value=annotation, + user_id=user_id + ) + ) - vectors = (data.get("vectors") or {}) return rg.Record( id=data["id"], fields=data["inputs"], # The inputs field should be a dictionary with the same keys as the `fields` in the settings metadata=data["metadata"], # The metadata field should be a dictionary with the same keys as the `metadata` in the settings - vectors=[rg.Vector(name=name, values=value) for name, value in vectors.items()], + vectors=data.get("vectors") or {}, suggestions=suggestions, responses=responses, ) ``` + 1. Make sure the `question_name` matches the name of the question in question settings. + + 2. Make sure the `question_name` matches the name of the question in question settings. + === "For multi-label classification" ```python @@ -157,25 +244,41 @@ Here are a set of example functions to convert the records for single-label and if prediction := data.get("prediction"): labels, scores = zip(*[(pred["label"], pred["score"]) for pred in prediction]) agent = data["prediction_agent"] - suggestions.append(rg.Suggestion(question_name="labels", value=labels, score=scores, agent=agent)) + suggestions.append( + rg.Suggestion( + question_name="labels", # (1) + value=labels, + score=scores, + agent=agent + ) + ) if annotation := data.get("annotation"): user_id = users_by_name.get(data["annotation_agent"], current_user).id - responses.append(rg.Response(question_name="label", value=annotation, user_id=user_id)) + responses.append( + rg.Response( + question_name="labels", # (2) + value=annotation, + user_id=user_id + ) + ) - vectors = data.get("vectors") or {} return rg.Record( id=data["id"], fields=data["inputs"], # The inputs field should be a dictionary with the same keys as the `fields` in the settings metadata=data["metadata"], # The metadata field should be a dictionary with the same keys as the `metadata` in the settings - vectors=[rg.Vector(name=name, values=value) for name, value in vectors.items()], + vectors=data.get("vectors") or {}, suggestions=suggestions, responses=responses, ) ``` + 1. Make sure the `question_name` matches the name of the question in question settings. + + 2. Make sure the `question_name` matches the name of the question in question settings. + === "For token classification" ```python @@ -187,27 +290,43 @@ Here are a set of example functions to convert the records for single-label and if prediction := data.get("prediction"): scores = [span["score"] for span in prediction] agent = data["prediction_agent"] - suggestions.append(rg.Suggestion(question_name="spans", value=prediction, score=scores, agent=agent)) + suggestions.append( + rg.Suggestion( + question_name="spans", # (1) + value=prediction, + score=scores, + agent=agent + ) + ) if annotation := data.get("annotation"): user_id = users_by_name.get(data["annotation_agent"], current_user).id - responses.append(rg.Response(question_name="spans", value=annotation, user_id=user_id)) + responses.append( + rg.Response( + question_name="spans", # (2) + value=annotation, + user_id=user_id + ) + ) - vectors = data.get("vectors") or {} return rg.Record( id=data["id"], fields={"text": data["text"]}, # The inputs field should be a dictionary with the same keys as the `fields` in the settings metadata=data["metadata"], # The metadata field should be a dictionary with the same keys as the `metadata` in the settings - vectors=[rg.Vector(name=name, values=value) for name, value in vectors.items()], + vectors=data.get("vectors") or {}, # The vectors field should be a dictionary with the same keys as the `vectors` in the settings suggestions=suggestions, responses=responses, ) ``` -=== "For Text generation" + 1. Make sure the `question_name` matches the name of the question in question settings. + + 2. Make sure the `question_name` matches the name of the question in question settings. + +=== "For text generation" ```python def map_to_record_for_text_generation(data: dict, users_by_name: dict, current_user: rg.User) -> rg.Record: @@ -219,32 +338,45 @@ Here are a set of example functions to convert the records for single-label and first = prediction[0] agent = data["prediction_agent"] suggestions.append( - rg.Suggestion(question_name="text_generation", value=first["text"], score=first["score"], agent=agent) + rg.Suggestion( + question_name="text_generation", # (1) + value=first["text"], + score=first["score"], + agent=agent + ) ) if annotation := data.get("annotation"): # From data[annotation] user_id = users_by_name.get(data["annotation_agent"], current_user).id - responses.append(rg.Response(question_name="text_generation", value=annotation, user_id=user_id)) + responses.append( + rg.Response( + question_name="text_generation", # (2) + value=annotation, + user_id=user_id + ) + ) - vectors = (data.get("vectors") or {}) return rg.Record( id=data["id"], fields={"text": data["text"]}, # The inputs field should be a dictionary with the same keys as the `fields` in the settings metadata=data["metadata"], # The metadata field should be a dictionary with the same keys as the `metadata` in the settings - vectors=[rg.Vector(name=name, values=value) for name, value in vectors.items()], + vectors=data.get("vectors") or {}, # The vectors field should be a dictionary with the same keys as the `vectors` in the settings suggestions=suggestions, responses=responses, ) ``` + 1. Make sure the `question_name` matches the name of the question in question settings. + + 2. Make sure the `question_name` matches the name of the question in question settings. + The functions above depend on the `users_by_name` dictionary and the `current_user` object to assign responses to users, we need to load the existing users. You can retrieve the users from the Argilla V2 server and the current user as follows: ```python -# For users_by_name = {user.username: user for user in client.users} current_user = client.me ``` @@ -260,5 +392,5 @@ for data in hf_records: # Upload the records to the new dataset dataset.records.log(records) ``` -You have now successfully migrated your legacy dataset to Argilla V2. For more guides on how to use the Argilla SDK, please refer to the [How to guides](index.md). +You have now successfully migrated your legacy dataset to Argilla V2. For more guides on how to use the Argilla SDK, please refer to the [How to guides](index.md). diff --git a/argilla/docs/how_to_guides/record.md b/argilla/docs/how_to_guides/record.md index b2a152ae86..dd4ae0f928 100644 --- a/argilla/docs/how_to_guides/record.md +++ b/argilla/docs/how_to_guides/record.md @@ -231,18 +231,18 @@ You can associate vectors, like text embeddings, to your records. They can be us "question": "Do you need oxygen to breathe?", "answer": "Yes" }, - vectors=[ - rg.Vector("my_vector", [0.1, 0.2, 0.3]) - ], + vectors={ + "my_vector": [0.1, 0.2, 0.3] + }, ), rg.Record( fields={ "question": "What is the boiling point of water?", "answer": "100 degrees Celsius" }, - vectors=[ - rg.Vector("my_vector", [0.2, 0.5, 0.3]) - ], + vectors={ + "my_vector": [0.2, 0.5, 0.3] + }, ), ] dataset.records.log(records) @@ -476,7 +476,7 @@ dataset.records.log(records=updated_data) for record in dataset.records(with_vectors=True): record.vectors["new_vector"] = [ 0, 1, 2, 3, 4, 5 ] - record.vector["v"] = [ 0.1, 0.2, 0.3 ] + record.vectors["v"] = [ 0.1, 0.2, 0.3 ] updated_records.append(record) From 8ac8d85cdca6a223bb9791ca2c4cb2750f1258fa Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 29 Jul 2024 13:51:58 +0200 Subject: [PATCH 10/13] [CI] `argilla`: Update services for integration tests (#5320) # Description With the latest changes on docker images, the quickstart image is not used anymore. This PR updates the environment when running integration tests to user the official argilla server image, since API_KEY cannot be set up for the new `argilla-hf-spaces` image. **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .github/workflows/argilla.yml | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/.github/workflows/argilla.yml b/.github/workflows/argilla.yml index 89d6dc8110..117ca824de 100644 --- a/.github/workflows/argilla.yml +++ b/.github/workflows/argilla.yml @@ -20,16 +20,21 @@ on: jobs: build: services: - argilla-quickstart: - image: argilladev/argilla-quickstart:develop + argilla-server: + image: argilladev/argilla-server:develop ports: - 6900:6900 env: - ANNOTATOR_USERNAME: annotator - OWNER_USERNAME: argilla - OWNER_API_KEY: argilla.apikey - ADMIN_USERNAME: admin - ADMIN_API_KEY: admin.apikey + ARGILLA_ENABLE_TELEMETRY: 0 + ARGILLA_ELASTICSEARCH: http://elasticsearch:9200 + DEFAULT_USER_ENABLED: 1 + elasticsearch: + image: docker.elastic.co/elasticsearch/elasticsearch:8.8.2 + ports: + - 9200:9200 + env: + discovery.type: single-node + xpack.security.enabled: false runs-on: ubuntu-latest defaults: run: @@ -50,7 +55,7 @@ jobs: - name: Install dependencies run: | pdm install - - name: Wait for argilla-quickstart to start + - name: Wait for argilla server to start run: | while ! curl -XGET http://localhost:6900/api/_status; do sleep 5; done - name: Set huggingface hub credentials From 3b91292aa1e4771f1b9d8b2c9a7eb2ad8f676e03 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 29 Jul 2024 13:57:29 +0200 Subject: [PATCH 11/13] chore: Update and align error message (#5319) # Description **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-frontend/translation/de.js | 4 ++++ argilla-frontend/translation/en.js | 2 +- argilla-server/src/argilla_server/validators/datasets.py | 2 +- .../unit/api/handlers/v1/datasets/test_update_dataset.py | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/argilla-frontend/translation/de.js b/argilla-frontend/translation/de.js index 0ef4ac053a..a7d1045dd0 100644 --- a/argilla-frontend/translation/de.js +++ b/argilla-frontend/translation/de.js @@ -224,6 +224,10 @@ export default { missing_vector: { message: "Vektor nicht im ausgewählten Datensatz gefunden", }, + update_distribution_with_existing_responses: { + message: + "Die Verteilungseinstellungen können für einen Datensatz mit Benutzerantworten nicht geändert werden", + }, }, http: { 401: { diff --git a/argilla-frontend/translation/en.js b/argilla-frontend/translation/en.js index d8fb47072f..db5b04e2b6 100644 --- a/argilla-frontend/translation/en.js +++ b/argilla-frontend/translation/en.js @@ -238,7 +238,7 @@ export default { }, update_distribution_with_existing_responses: { message: - "This dataset has responses. Task distribution settings cannot be modified", + "Distribution settings can't be modified for a dataset containing user responses", }, }, http: { diff --git a/argilla-server/src/argilla_server/validators/datasets.py b/argilla-server/src/argilla_server/validators/datasets.py index 0df3bc9bfa..490b909dbe 100644 --- a/argilla-server/src/argilla_server/validators/datasets.py +++ b/argilla-server/src/argilla_server/validators/datasets.py @@ -50,5 +50,5 @@ async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) async def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None: if dataset_attrs.get("distribution") is not None and (await dataset.responses_count) > 0: raise UpdateDistributionWithExistingResponsesError( - "Distribution settings cannot be modified for a dataset with records including responses" + "Distribution settings can't be modified for a dataset containing user responses" ) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py index bbba4b9140..a21bc4ad7d 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -142,7 +142,7 @@ async def test_update_dataset_distribution_for_dataset_with_responses( assert response.json() == { "code": "update_distribution_with_existing_responses", - "message": "Distribution settings cannot be modified for a dataset with records including responses", + "message": "Distribution settings can't be modified for a dataset containing user responses", } async def test_update_dataset_distribution_with_invalid_strategy( From 7d5a37762698dcb4ee80cd9ddde2485018258a38 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 29 Jul 2024 16:04:53 +0200 Subject: [PATCH 12/13] fix: 5322 bug pythondeployment fix failing integration tests import export (#5323) # Description Hub integration tests are sometimes a bid flaky due to IO errors. Closes #5322 **Type of change** - Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** NA **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla/tests/integration/test_export_dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/argilla/tests/integration/test_export_dataset.py b/argilla/tests/integration/test_export_dataset.py index 5cad1ac996..940ff3c6b5 100644 --- a/argilla/tests/integration/test_export_dataset.py +++ b/argilla/tests/integration/test_export_dataset.py @@ -22,6 +22,9 @@ import argilla as rg import pytest +from huggingface_hub.utils._errors import BadRequestError, FileMetadataError, HfHubHTTPError + +_RETRIES = 5 @pytest.fixture @@ -71,7 +74,7 @@ def token(): return os.getenv("HF_TOKEN_ARGILLA_INTERNAL_TESTING") -@pytest.mark.flaky(retries=3, only_on=[OSError]) # I/O error hub consistency CICD pipline +@pytest.mark.flaky(retries=_RETRIES, only_on=[OSError]) # I/O consistency CICD pipline @pytest.mark.parametrize("with_records_export", [True, False]) class TestDiskImportExportMixin: def test_export_dataset_to_disk( @@ -135,6 +138,9 @@ def test_import_dataset_from_disk( assert new_dataset.settings.questions[0].name == "label" +@pytest.mark.flaky( + retries=_RETRIES, only_on=[BadRequestError, FileMetadataError, HfHubHTTPError] +) # Hub consistency CICD pipline @pytest.mark.skipif( not os.getenv("HF_TOKEN_ARGILLA_INTERNAL_TESTING"), reason="You are missing a token to write to `argilla-internal-testing` org on the Hugging Face Hub", From ad31d608d4d222a70834ea373c167c2aacbc7641 Mon Sep 17 00:00:00 2001 From: Sara Han <127759186+sdiazlor@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:58:19 +0200 Subject: [PATCH 13/13] fix: add import task distribution (#5326) # Description Added the import, so that we can use TaskDistribution as part of `import argilla as rg` Closes # **Type of change** - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla/src/argilla/settings/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/argilla/src/argilla/settings/__init__.py b/argilla/src/argilla/settings/__init__.py index af38765efc..3989088018 100644 --- a/argilla/src/argilla/settings/__init__.py +++ b/argilla/src/argilla/settings/__init__.py @@ -17,3 +17,4 @@ from argilla.settings._vector import * # noqa: F403 from argilla.settings._question import * # noqa: F403 from argilla.settings._resource import * # noqa: F403 +from argilla.settings._task_distribution import * # noqa: F403