Skip to content

Commit

Permalink
Adds session instance to CredentialsManager constructor
Browse files Browse the repository at this point in the history
Moves access checks to CredentialsService for better separation of concerns
Refactoring to get the CredentialsManager ready for unit test
  • Loading branch information
arash77 committed Jan 20, 2025
1 parent 6c26250 commit 2e5953a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 64 deletions.
64 changes: 22 additions & 42 deletions lib/galaxy/managers/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
from sqlalchemy import select
from sqlalchemy.orm import aliased

from galaxy.exceptions import (
AuthenticationRequired,
ItemOwnershipException,
RequestParameterInvalidException,
)
from galaxy.managers.context import ProvidesUserContext
from galaxy.exceptions import RequestParameterInvalidException
from galaxy.model import (
CredentialsGroup,
Secret,
Expand All @@ -23,27 +18,22 @@
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema.credentials import SOURCE_TYPE
from galaxy.schema.fields import DecodedDatabaseIdField
from galaxy.schema.schema import FlexibleUserIdType


class CredentialsManager:
"""Manager object shared by controllers for interacting with credentials."""

def __init__(self, session: galaxy_scoped_session) -> None:
self.session = session

def get_user_credentials(
self,
trans: ProvidesUserContext,
user_id: FlexibleUserIdType,
user_id: DecodedDatabaseIdField,
source_type: Optional[SOURCE_TYPE] = None,
source_id: Optional[str] = None,
user_credentials_id: Optional[DecodedDatabaseIdField] = None,
group_id: Optional[DecodedDatabaseIdField] = None,
) -> List[Tuple[UserCredentials, CredentialsGroup]]:
if trans.anonymous:
raise AuthenticationRequired("You need to be logged in to access your credentials.")
if user_id == "current":
user_id = trans.user.id
elif trans.user.id != user_id:
raise ItemOwnershipException("You can only access your own credentials.")
user_cred_alias, group_alias = aliased(UserCredentials), aliased(CredentialsGroup)
stmt = (
select(user_cred_alias, group_alias)
Expand All @@ -61,27 +51,25 @@ def get_user_credentials(
if group_id:
stmt = stmt.where(group_alias.id == group_id)

result = trans.sa_session.execute(stmt).all()
result = self.session.execute(stmt).all()
return [(row[0], row[1]) for row in result]

def fetch_credentials(
self,
session: galaxy_scoped_session,
group_id: DecodedDatabaseIdField,
) -> Tuple[List[Variable], List[Secret]]:
variables = list(
session.execute(select(Variable).where(Variable.user_credential_group_id == group_id)).scalars().all()
self.session.execute(select(Variable).where(Variable.user_credential_group_id == group_id)).scalars().all()
)

secrets = list(
session.execute(select(Secret).where(Secret.user_credential_group_id == group_id)).scalars().all()
self.session.execute(select(Secret).where(Secret.user_credential_group_id == group_id)).scalars().all()
)

return variables, secrets

def add_user_credentials(
self,
session: galaxy_scoped_session,
db_user_credentials: List[Tuple[UserCredentials, CredentialsGroup]],
user_id: DecodedDatabaseIdField,
reference: str,
Expand All @@ -96,13 +84,12 @@ def add_user_credentials(
source_type=source_type,
source_id=source_id,
)
session.add(user_credentials)
session.flush()
self.session.add(user_credentials)
self.session.flush()
return user_credentials.id

def add_group(
self,
session: galaxy_scoped_session,
db_user_credentials: List[Tuple[UserCredentials, CredentialsGroup]],
user_credentials_id: DecodedDatabaseIdField,
group_name: str,
Expand All @@ -114,13 +101,12 @@ def add_group(
)
if not credentials_group:
credentials_group = CredentialsGroup(name=group_name, user_credentials_id=user_credentials_id)
session.add(credentials_group)
session.flush()
self.session.add(credentials_group)
self.session.flush()
return credentials_group.id

def add_variable(
self,
session: galaxy_scoped_session,
variables: List[Variable],
user_credential_group_id: DecodedDatabaseIdField,
variable_name: str,
Expand All @@ -138,11 +124,10 @@ def add_variable(
name=variable_name,
value=variable_value,
)
session.add(variable)
self.session.add(variable)

def add_secret(
self,
session: galaxy_scoped_session,
secrets: List[Secret],
user_credential_group_id: DecodedDatabaseIdField,
secret_name: str,
Expand All @@ -160,36 +145,31 @@ def add_secret(
name=secret_name,
already_set=True if secret_value else False,
)
session.add(secret)
self.session.add(secret)

def commit_session(
self,
session: galaxy_scoped_session,
) -> None:
with transaction(session):
session.commit()
def commit_session(self) -> None:
with transaction(self.session):
self.session.commit()

def update_current_group(
self,
trans: ProvidesUserContext,
user_id: DecodedDatabaseIdField,
user_credentials_id: DecodedDatabaseIdField,
group_name: str,
) -> None:
db_user_credentials = self.get_user_credentials(trans, trans.user.id, user_credentials_id=user_credentials_id)
db_user_credentials = self.get_user_credentials(user_id, user_credentials_id=user_credentials_id)
for user_credentials, credentials_group in db_user_credentials:
if credentials_group.name == group_name:
user_credentials.current_group_id = credentials_group.id
trans.sa_session.add(user_credentials)
self.session.add(user_credentials)
break
else:
raise RequestParameterInvalidException("Group not found to set as current.")

def delete_rows(
self,
session: galaxy_scoped_session,
rows_to_delete: List,
) -> None:
for row in rows_to_delete:
session.delete(row)
with transaction(session):
session.commit()
self.session.delete(row)
self.commit_session()
54 changes: 32 additions & 22 deletions lib/galaxy/webapps/galaxy/services/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
)

from galaxy.exceptions import (
AuthenticationRequired,
ItemOwnershipException,
ObjectNotFound,
RequestParameterInvalidException,
)
Expand Down Expand Up @@ -50,7 +52,8 @@ def list_user_credentials(
source_id: Optional[str] = None,
) -> UserCredentialsListResponse:
"""Lists all credentials the user has provided (credentials themselves are not included)."""
return self._list_user_credentials(trans, user_id, source_type, source_id)
self._check_access(trans, user_id)
return self._list_user_credentials(trans, source_type, source_id)

def provide_credential(
self,
Expand All @@ -59,8 +62,9 @@ def provide_credential(
payload: CreateSourceCredentialsPayload,
) -> UserCredentialsListResponse:
"""Allows users to provide credentials for a group of secrets and variables."""
self._create_or_update_credentials(trans, user_id, payload)
return self._list_user_credentials(trans, user_id, payload.source_type, payload.source_id)
self._check_access(trans, user_id)
self._create_or_update_credentials(trans, payload)
return self._list_user_credentials(trans, payload.source_type, payload.source_id)

def delete_credentials(
self,
Expand All @@ -70,8 +74,9 @@ def delete_credentials(
group_id: Optional[DecodedDatabaseIdField] = None,
) -> None:
"""Deletes a specific credential group or all credentials for a specific service."""
self._check_access(trans, user_id)
db_user_credentials = self._credentials_manager.get_user_credentials(
trans, user_id, user_credentials_id=user_credentials_id, group_id=group_id
trans.user.id, user_credentials_id=user_credentials_id, group_id=group_id
)
if not db_user_credentials:
raise ObjectNotFound("No credentials found.")
Expand All @@ -83,22 +88,21 @@ def delete_credentials(
if credentials_group.name == "default":
raise RequestParameterInvalidException("Cannot delete the default group.")
if credentials_group.id == uc.current_group_id:
self._credentials_manager.update_current_group(trans, uc.id, "default")
variables, secrets = self._credentials_manager.fetch_credentials(trans.sa_session, credentials_group.id)
self._credentials_manager.update_current_group(trans.user.id, uc.id, "default")
variables, secrets = self._credentials_manager.fetch_credentials(credentials_group.id)
rows_to_delete.extend([credentials_group, *variables, *secrets])
self._credentials_manager.delete_rows(trans.sa_session, rows_to_delete)
self._credentials_manager.delete_rows(rows_to_delete)

def _list_user_credentials(
self,
trans: ProvidesUserContext,
user_id: FlexibleUserIdType,
source_type: Optional[SOURCE_TYPE] = None,
source_id: Optional[str] = None,
) -> UserCredentialsListResponse:
db_user_credentials = self._credentials_manager.get_user_credentials(trans, user_id, source_type, source_id)
db_user_credentials = self._credentials_manager.get_user_credentials(trans.user.id, source_type, source_id)
credentials_dict = self._map_user_credentials(db_user_credentials)
for user_credentials, credentials_group in db_user_credentials:
variables, secrets = self._credentials_manager.fetch_credentials(trans.sa_session, credentials_group.id)
variables, secrets = self._credentials_manager.fetch_credentials(credentials_group.id)
group = credentials_dict[user_credentials.id]["groups"].get(credentials_group.name, {})
group["variables"].extend(
{"id": variable.id, "name": variable.name, "value": variable.value} for variable in variables
Expand Down Expand Up @@ -146,43 +150,49 @@ def _map_user_credentials(
def _create_or_update_credentials(
self,
trans: ProvidesUserContext,
user_id: FlexibleUserIdType,
payload: CreateSourceCredentialsPayload,
) -> None:
session = trans.sa_session
source_type, source_id = payload.source_type, payload.source_id
db_user_credentials = self._credentials_manager.get_user_credentials(trans, user_id, source_type, source_id)
db_user_credentials = self._credentials_manager.get_user_credentials(trans.user.id, source_type, source_id)
for service_payload in payload.credentials:
reference = service_payload.reference
current_group_name = service_payload.current_group
if not current_group_name:
current_group_name = "default"
user_credentials_id = self._credentials_manager.add_user_credentials(
session, db_user_credentials, trans.user.id, reference, source_type, source_id
db_user_credentials, trans.user.id, reference, source_type, source_id
)
for group in service_payload.groups:
group_name = group.name
user_credential_group_id = self._credentials_manager.add_group(
session, db_user_credentials, user_credentials_id, group_name, reference
db_user_credentials, user_credentials_id, group_name, reference
)
variables, secrets = self._credentials_manager.fetch_credentials(session, user_credential_group_id)
variables, secrets = self._credentials_manager.fetch_credentials(user_credential_group_id)
user_vault = UserVaultWrapper(self._app.vault, trans.user)
for variable_payload in group.variables:
variable_name, variable_value = variable_payload.name, variable_payload.value
if variable_value is None:
continue
self._credentials_manager.add_variable(
session, variables, user_credential_group_id, variable_name, variable_value
variables, user_credential_group_id, variable_name, variable_value
)
for secret_payload in group.secrets:
secret_name, secret_value = secret_payload.name, secret_payload.value
if secret_value is None:
continue
vault_ref = f"{source_type}|{source_id}|{reference}|{group_name}|{secret_name}"
user_vault.write_secret(vault_ref, secret_value)
self._credentials_manager.add_secret(
session, secrets, user_credential_group_id, secret_name, secret_value
)
self._credentials_manager.add_secret(secrets, user_credential_group_id, secret_name, secret_value)

self._credentials_manager.update_current_group(trans.user.id, user_credentials_id, current_group_name)
self._credentials_manager.commit_session()

self._credentials_manager.update_current_group(trans, user_credentials_id, current_group_name)
self._credentials_manager.commit_session(session)
def _check_access(
self,
trans: ProvidesUserContext,
user_id: FlexibleUserIdType,
) -> None:
if trans.anonymous:
raise AuthenticationRequired("You need to be logged in to access your credentials.")
if user_id != "current" and user_id != trans.user.id:
raise ItemOwnershipException("You can only access your own credentials.")

0 comments on commit 2e5953a

Please sign in to comment.