diff --git a/quetz/database.py b/quetz/database.py index 9810b12d..48e05487 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -70,8 +70,8 @@ def get_session_maker( def get_session(config: Config | None) -> Session: """Get a database session. - ea - Important note: this function is mocked during tests! + + Important note: this function is mocked during tests! """ if config is None: diff --git a/quetz/tasks/workers.py b/quetz/tasks/workers.py index 98dbacd3..2b98b61a 100644 --- a/quetz/tasks/workers.py +++ b/quetz/tasks/workers.py @@ -137,87 +137,75 @@ def job_wrapper( logger = logging.getLogger("quetz.worker") pkgstore = kwargs.pop("pkgstore", None) - db = kwargs.pop("db", None) - dao = kwargs.pop("dao", None) auth = kwargs.pop("auth", None) session = kwargs.pop("session", None) - if db: - close_session = False - elif dao: - db = dao.db - close_session = False - else: - db = get_session(config) - close_session = True - - user_id: Optional[str] - if task_id: - task = db.query(Task).filter(Task.id == task_id).one_or_none() - if not task: - raise KeyError(f"Task '{task_id}' not found") - # take extra arguments from job definition - if task.job.extra_args: - job_extra_args = json.loads(task.job.extra_args) - kwargs.update(job_extra_args) - if task.job.owner_id: - user_id = str(uuid.UUID(bytes=task.job.owner_id)) + with get_session(config) as db: + user_id: Optional[str] + if task_id: + task = db.query(Task).filter(Task.id == task_id).one_or_none() + if not task: + raise KeyError(f"Task '{task_id}' not found") + # take extra arguments from job definition + if task.job.extra_args: + job_extra_args = json.loads(task.job.extra_args) + kwargs.update(job_extra_args) + if task.job.owner_id: + user_id = str(uuid.UUID(bytes=task.job.owner_id)) + else: + user_id = None else: + task = None user_id = None - else: - task = None - user_id = None - - if not pkgstore: - pkgstore = config.get_package_store() - - dao = Dao(db) - - if not auth: - browser_session: Dict[str, str] = {} - api_key = None - if user_id: - browser_session["user_id"] = user_id - auth = Rules(api_key, browser_session, db) - if not session: - session = get_remote_session() - - if task: - task.status = TaskStatus.running - task.job.status = JobStatus.running - db.commit() - - callable_f: Callable = pickle.loads(func) if isinstance(func, bytes) else func - - extra_kwargs = prepare_arguments( - callable_f, - dao=dao, - auth=auth, - session=session, - config=config, - pkgstore=pkgstore, - user_id=user_id, - ) - - kwargs.update(extra_kwargs) - - try: - callable_f(**kwargs) - except Exception as exc: + + if not pkgstore: + pkgstore = config.get_package_store() + + dao = Dao(db) + + if not auth: + browser_session: Dict[str, str] = {} + api_key = None + if user_id: + browser_session["user_id"] = user_id + auth = Rules(api_key, browser_session, db) + if not session: + session = get_remote_session() + if task: - task.status = TaskStatus.failed - logger.error( - f"exception occurred when evaluating function {callable_f.__name__}:{exc}" + task.status = TaskStatus.running + task.job.status = JobStatus.running + db.commit() + + callable_f: Callable = pickle.loads(func) if isinstance(func, bytes) else func + + extra_kwargs = prepare_arguments( + callable_f, + dao=dao, + auth=auth, + session=session, + config=config, + pkgstore=pkgstore, + user_id=user_id, ) - if exc_passthrou: - raise exc - else: - if task: - task.status = TaskStatus.success - finally: - db.commit() - if close_session: - db.close() + + kwargs.update(extra_kwargs) + + try: + callable_f(**kwargs) + except Exception as exc: + if task: + task.status = TaskStatus.failed + logger.error( + f"exception occurred when evaluating function {callable_f.__name__}:{exc}" + ) + if exc_passthrou: + raise exc + else: + if task: + task.status = TaskStatus.success + finally: + db.commit() class AbstractWorker: diff --git a/quetz/tests/test_workers.py b/quetz/tests/test_workers.py index c109cbb2..b8a1353f 100644 --- a/quetz/tests/test_workers.py +++ b/quetz/tests/test_workers.py @@ -132,11 +132,11 @@ def db_cleanup(config): from quetz.database import get_session - db = get_session(config.sqlalchemy_database_url) - user = db.query(User).one_or_none() - if user: - db.delete(user) - db.commit() + with get_session(config) as db: + user = db.query(User).one_or_none() + if user: + db.delete(user) + db.commit() @pytest.mark.asyncio