Skip to content

Commit

Permalink
Update oauth2 handler to update workspaces if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbauriegel authored Jan 16, 2025
1 parent 55c64b3 commit 58bb93a
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.errors.future import NotFoundError
from argilla_server.models import User
from argilla_server.models import User, UserRole, Workspace, WorkspaceUser
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.settings import settings
Expand Down Expand Up @@ -61,20 +61,46 @@ async def get_access_token(
if not userinfo.username:
raise RuntimeError("OAuth error: Missing username")

user = await User.get_by(db, username=userinfo.username)
if user is None:
user_w_workspace = await accounts.get_user_by_username(db, username=userinfo.username)
if user_w_workspace is None:
exs_workspaces = await accounts.list_workspaces(db)
exs_workspaces = [w.name for w in exs_workspaces]
default_available_workspaces = [workspace.name for workspace in settings.oauth.allowed_workspaces]
workspaces = userinfo.available_workspaces or default_available_workspaces
# Check first if workspaces exist
workspaces = [w for w in workspaces if w in exs_workspaces]

user = await accounts.create_user_with_random_password(
user_w_workspace = await accounts.create_user_with_random_password(
db,
username=userinfo.username,
first_name=userinfo.first_name,
last_name=userinfo.last_name,
role=userinfo.role,
workspaces=workspaces,
)
else:
if user.role != userinfo.role:
user = await user.update(db, role=userinfo.role)
# With existing user update the role if needed
if user_w_workspace.role != userinfo.role:
user_w_workspace = await user_w_workspace.update(db, role=userinfo.role)
# With existing user update the workspaces if needed
if user_w_workspace.role != UserRole.owner and set(user_w_workspace.workspaces) != set(
userinfo.available_workspaces
):
for workspace_name in userinfo.available_workspaces:
workspace = await Workspace.get_by(db, name=workspace_name)
if not workspace:
continue

return Token(access_token=accounts.generate_user_token(user))
await WorkspaceUser.create(
db,
workspace_id=workspace.id,
user_id=user_w_workspace.id,
autocommit=False,
)
for workspace in user_w_workspace.workspaces:
if workspace.name not in userinfo.available_workspaces:
ws_user = await WorkspaceUser.get_by(db, workspace_id=workspace.id, user_id=user_w_workspace.id)
await ws_user.delete(db, autocommit=False)
await db.commit()

return Token(access_token=accounts.generate_user_token(user_w_workspace))

0 comments on commit 58bb93a

Please sign in to comment.