Skip to content

Commit

Permalink
[REFACTOR] argilla-server: review OAuth owner condition (#5313)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR changes how the user role is determined when an OAuth
sign-in occurs.

- If the connected user matches the `USERNAME` environment variable, the
connected user becomes an owner.
- All other users are assigned the `annotator` role.
- All logic related to roles in an organization (ORG) has been removed until a
proper auth scope is found.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Refactor (change restructuring the codebase without changing
functionality)
- Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: José Francisco Calvo <[email protected]>
  • Loading branch information
frascuchon and jfcalvo authored Jul 29, 2024
1 parent 9f2d6dc commit 3430b66
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 196 deletions.
10 changes: 10 additions & 0 deletions argilla-server/docker/argilla-hf-spaces/scripts/start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,14 @@

set -e

# Preset OAuth env vars based on the variables injected into the space.
# See https://huggingface.co/docs/hub/en/spaces-oauth#create-an-oauth-app
# The expansions are quoted so that values containing spaces (e.g. a list of
# scopes) are never subject to word splitting or globbing, whatever shell runs this.
export OAUTH2_HUGGINGFACE_CLIENT_ID="$OAUTH_CLIENT_ID"
export OAUTH2_HUGGINGFACE_CLIENT_SECRET="$OAUTH_CLIENT_SECRET"
export OAUTH2_HUGGINGFACE_SCOPE="$OAUTH_SCOPES"

# Use the space author name as the username if none is provided.
# See https://huggingface.co/docs/hub/en/spaces-overview#helper-environment-variables for more details
export USERNAME="${USERNAME:-$SPACE_AUTHOR_NAME}"

honcho start
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,9 @@

set -e

# Preset oauth env vars based on injected space variables.
# See https://huggingface.co/docs/hub/en/spaces-oauth#create-an-oauth-app
export OAUTH2_HUGGINGFACE_CLIENT_ID=$OAUTH_CLIENT_ID
export OAUTH2_HUGGINGFACE_CLIENT_SECRET=$OAUTH_CLIENT_SECRET
export OAUTH2_HUGGINGFACE_SCOPE=$OAUTH_SCOPES

echo "Running database migrations"
python -m argilla_server database migrate

# Set the space author name as username if no provided.
# See https://huggingface.co/docs/hub/en/spaces-overview#helper-environment-variables for more details
USERNAME="${USERNAME:-$SPACE_AUTHOR_NAME}"

if [ -n "$USERNAME" ] && [ -n "$PASSWORD" ]; then
echo "Creating owner user with username ${USERNAME}"
python -m argilla_server database users create \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@


def add_exception_handlers(app: FastAPI):
@app.exception_handler(errors.AuthenticationError)
async def authentication_error(request, exc):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
# TODO: Once we move to v2.0 we can remove the content using detail attribute
# and use the new one using code and message.
# content={"code": exc.code, "message": exc.message},
content={"detail": str(exc)},
)

@app.exception_handler(errors.NotFoundError)
async def not_found_error_exception_handler(request, exc):
return JSONResponse(
Expand Down
55 changes: 22 additions & 33 deletions argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, Path
from fastapi import APIRouter, Depends, Request, Path
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -23,10 +23,9 @@
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.enums import UserRole
from argilla_server.errors.future import AuthenticationError, NotFoundError
from argilla_server.errors.future import NotFoundError
from argilla_server.models import User
from argilla_server.pydantic_v1 import Field, ValidationError
from argilla_server.security.authentication.jwt import JWT
from argilla_server.pydantic_v1 import Field
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.settings import settings
Expand Down Expand Up @@ -74,32 +73,22 @@ async def get_access_token(
provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise),
db: AsyncSession = Depends(get_async_db),
) -> Token:
try:
user_info = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims)
user = await User.get_by(db, username=user_info.username)
if user is None:
try:
user_create = UserOAuthCreate(
username=user_info.username,
first_name=user_info.first_name,
role=user_info.role,
)
except ValidationError as ex:
raise AuthenticationError("Could not authenticate user") from ex

user = await accounts.create_user_with_random_password(
db,
**user_create.dict(exclude_unset=True),
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
)
telemetry.track_user_created(user, is_oauth=True)

elif user.role != user_info.role:
raise AuthenticationError("Could not authenticate user")

return Token(access_token=JWT.create(user_info))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
# TODO: Create exception handler for AuthenticationError
except AuthenticationError as e:
raise HTTPException(status_code=401, detail=str(e)) from e
userinfo = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims)

if not userinfo.username:
raise RuntimeError("OAuth error: Missing username")

user = await User.get_by(db, username=userinfo.username)
if user is None:
user = await accounts.create_user_with_random_password(
db,
**UserOAuthCreate(
username=userinfo.username,
first_name=userinfo.first_name,
role=userinfo.role,
).dict(exclude_unset=True),
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
)
telemetry.track_user_created(user, is_oauth=True)

return Token(access_token=accounts.generate_user_token(user))
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import BaseSettings, Field
from argilla_server.pydantic_v1 import BaseSettings, Field


class HuggingfaceSettings(BaseSettings):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Union, Optional

from typing import Any, Callable, Union
from argilla_server.enums import UserRole


def _parse_role_from_environment(userinfo: dict) -> Optional[UserRole]:
    """Resolve the role of an OAuth user from the process environment.

    This is a temporary solution, and it will be replaced by a proper sign-up process.

    Args:
        userinfo: User info mapping; its ``"username"`` entry is compared with the
            ``USERNAME`` environment variable.

    Returns:
        ``UserRole.owner`` when the username equals ``USERNAME``; ``None`` otherwise.
    """
    owner_username = os.getenv("USERNAME")
    username = userinfo.get("username")

    # A missing username or an unset/empty USERNAME must never grant ownership:
    # a plain equality check would treat `None == None` as a match and promote
    # an unidentified user to owner.
    if not owner_username or not username:
        return None

    if username == owner_username:
        return UserRole.owner

    return None


class Claims(dict):
Expand All @@ -29,3 +37,4 @@ def __init__(self, seq=None, **kwargs) -> None:
self["identity"] = kwargs.get("identity", self.get("identity", "sub"))
self["picture"] = kwargs.get("picture", self.get("picture", "picture"))
self["email"] = kwargs.get("email", self.get("email", "email"))
self["role"] = kwargs.get("role", _parse_role_from_environment)
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
# limitations under the License.

import logging
from typing import Union, Optional

from social_core.backends.open_id_connect import OpenIdConnectAuth

from argilla_server.enums import UserRole
from argilla_server.integrations.huggingface.spaces import HUGGINGFACE_SETTINGS
from argilla_server.logging import LoggingMixin
from argilla_server.security.authentication.claims import Claims
from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider
Expand All @@ -42,46 +39,9 @@ def oidc_endpoint(self) -> str:
_HF_PREFERRED_USERNAME = "preferred_username"


def _is_space_author(userinfo: dict, space_author: str) -> bool:
    """Return True if the space author name is the userinfo username. Otherwise, False.

    Args:
        userinfo: User info mapping; the username is read from the key named by
            ``_HF_PREFERRED_USERNAME``.
        space_author: Name of the space author. A falsy value (``None`` or ``""``)
            never matches.

    Returns:
        A real ``bool``, as the annotation promises.
    """
    # Coerce the guard to bool: a bare `space_author and ...` would hand the falsy
    # operand itself (None or "") back to the caller instead of False.
    return bool(space_author) and space_author == userinfo.get(_HF_PREFERRED_USERNAME)


def _find_org_from_userinfo(userinfo: dict, org_name: str) -> Optional[dict]:
    """Look up an organization by name among the ones listed in the userinfo.

    Args:
        userinfo: User info mapping; organizations are read from its ``"orgs"`` entry
            (a missing or falsy entry is treated as no organizations).
        org_name: Name compared against each organization's
            ``_HF_PREFERRED_USERNAME`` value.

    Returns:
        The first matching organization, or ``None`` when there is no match.
    """
    organizations = userinfo.get("orgs") or []
    matches = (
        candidate
        for candidate in organizations
        if org_name == candidate.get(_HF_PREFERRED_USERNAME)
    )
    # `next` with a default stops at the first match and yields None otherwise.
    return next(matches, None)


def _get_user_role_by_org(org: dict) -> Union[UserRole, None]:
    """Compute the UserRole matching the role the user holds in an organization.

    Args:
        org: Organization mapping; the role is read from its ``"roleInOrg"`` entry.

    Returns:
        ``UserRole.owner`` when the organization role is ``"admin"``; otherwise
        ``UserRole.annotator`` (also when the role entry is absent, in which case
        a warning is logged).
    """
    role_key = "roleInOrg"

    if role_key in org:
        role_in_org = org[role_key]
    else:
        _LOGGER.warning(f"Cannot find the user role info in org {org}. Review granted permissions")
        role_in_org = None

    # Only "admin" maps to a privileged role; everything else is an annotator.
    mapped_role = {"admin": UserRole.owner}.get(role_in_org)
    return mapped_role or UserRole.annotator


class HuggingfaceClientProvider(OAuth2ClientProvider, LoggingMixin):
"""Specialized HuggingFace OAuth2 provider."""

@staticmethod
def parse_role_from_userinfo(userinfo: dict) -> Union[str, None]:
"""Parse the Argilla user role from info provided as part of the user info"""
space_author_name = HUGGINGFACE_SETTINGS.space_author_name

if _is_space_author(userinfo, space_author_name):
return UserRole.owner
elif org := _find_org_from_userinfo(userinfo, space_author_name):
return _get_user_role_by_org(org)
return UserRole.annotator

claims = Claims(username=_HF_PREFERRED_USERNAME, role=parse_role_from_userinfo, first_name="name")
claims = Claims(username=_HF_PREFERRED_USERNAME, first_name="name")
backend_class = HuggingfaceOpenId
name = "huggingface"
19 changes: 11 additions & 8 deletions argilla-server/tests/unit/api/handlers/v1/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def test_provider_huggingface_access_token(
assert JWT.decode(json_response["access_token"])["username"] == "username"
assert json_response["token_type"] == "bearer"

user = (await db.execute(select(User).where(User.username == "username"))).scalar_one_or_none()
user = await db.scalar(select(User).filter_by(username="username"))
assert user is not None
assert user.role == UserRole.annotator

Expand All @@ -182,7 +182,7 @@ async def test_provider_huggingface_access_token_with_missing_username(
cookies={"oauth2_state": "valid"},
)

assert response.status_code == 401
assert response.status_code == 500

async def test_provider_huggingface_access_token_with_missing_name(
self,
Expand Down Expand Up @@ -244,7 +244,7 @@ async def test_provider_access_token_with_not_found_code(
response = await async_client.get(
"/api/v1/oauth2/providers/huggingface/access-token", headers=owner_auth_header
)
assert response.status_code == 400
assert response.status_code == 422
assert response.json() == {"detail": "'code' parameter was not found in callback request"}

async def test_provider_access_token_with_not_found_state(
Expand All @@ -254,7 +254,7 @@ async def test_provider_access_token_with_not_found_state(
response = await async_client.get(
"/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code"}, headers=owner_auth_header
)
assert response.status_code == 400
assert response.status_code == 422
assert response.json() == {"detail": "'state' parameter was not found in callback request"}

async def test_provider_access_token_with_invalid_state(
Expand All @@ -267,7 +267,7 @@ async def test_provider_access_token_with_invalid_state(
headers=owner_auth_header,
cookies={"oauth2_state": "valid"},
)
assert response.status_code == 400
assert response.status_code == 422
assert response.json() == {"detail": "'state' parameter does not match"}

async def test_provider_access_token_with_authentication_error(
Expand All @@ -287,7 +287,7 @@ async def test_provider_access_token_with_authentication_error(
assert response.status_code == 401
assert response.json() == {"detail": "error"}

async def test_provider_access_token_with_unauthorized_user(
async def test_provider_access_token_with_already_created_user(
self,
async_client: AsyncClient,
db: AsyncSession,
Expand All @@ -307,8 +307,11 @@ async def test_provider_access_token_with_unauthorized_user(
headers=owner_auth_header,
cookies={"oauth2_state": "valid"},
)
assert response.status_code == 401
assert response.json() == {"detail": "Could not authenticate user"}
assert response.status_code == 200

userinfo = JWT.decode(response.json()["access_token"])
assert userinfo["username"] == admin.username
assert userinfo["role"] == admin.role

async def test_provider_access_token_with_same_username(
self,
Expand Down

This file was deleted.

This file was deleted.

Loading

0 comments on commit 3430b66

Please sign in to comment.