Skip to content

Commit

Permalink
factorize subject session
Browse files Browse the repository at this point in the history
  • Loading branch information
maximemulder committed Dec 22, 2024
1 parent 2aff368 commit 7605a83
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 116 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ include = [
"python/lib/config_file.py",
"python/lib/env.py",
"python/lib/file_system.py",
"python/lib/get_subject_session.py",
"python/lib/logging.py",
"python/lib/make_env.py",
"python/lib/validate_subject_info.py",
Expand Down
6 changes: 6 additions & 0 deletions python/lib/database_lib/session_db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""This class performs session table related database queries and common checks"""

from typing_extensions import deprecated

__license__ = "GPLv3"


@deprecated('Use `lib.db.models.session.DbSession` instead')
class SessionDB:
"""
This class performs database queries for session table.
Expand Down Expand Up @@ -35,6 +37,7 @@ def __init__(self, db, verbose):
self.db = db
self.verbose = verbose

@deprecated('Use `lib.db.queries.try_get_candidate_with_cand_id_visit_label` instead')
def create_session_dict(self, cand_id, visit_label):
"""
Queries the session table for a particular candidate ID and visit label and returns a dictionary
Expand All @@ -56,6 +59,7 @@ def create_session_dict(self, cand_id, visit_label):

return results[0] if results else None

@deprecated('Use `lib.db.queries.site.try_get_site_with_psc_id_visit_label` instead')
def get_session_center_info(self, pscid, visit_label):
"""
Get site information for a given visit.
Expand All @@ -77,6 +81,7 @@ def get_session_center_info(self, pscid, visit_label):

return results[0] if results else None

@deprecated('Use `lib.get_subject_session.get_candidate_next_visit_number` instead')
def determine_next_session_site_id_and_visit_number(self, cand_id):
"""
Determines the next session site and visit number based on the last session inserted for a given candidate.
Expand All @@ -99,6 +104,7 @@ def determine_next_session_site_id_and_visit_number(self, cand_id):

return results[0] if results else None

@deprecated('Use `lib.db.models.session.DbSession` instead')
def insert_into_session(self, fields, values):
"""
Insert a new row in the session table using fields list as column names and values as values.
Expand Down
3 changes: 1 addition & 2 deletions python/lib/db/models/notification_spool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ class DbNotificationSpool(Base):
origin : Mapped[Optional[str]] = mapped_column('Origin')
active : Mapped[bool] = mapped_column('Active', YNBool)

type : Mapped['db_notification_type.DbNotificationType'] \
= relationship('DbNotificationType')
type : Mapped['db_notification_type.DbNotificationType'] = relationship('DbNotificationType')
109 changes: 11 additions & 98 deletions python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import lib.utilities
from lib.database import Database
from lib.database_lib.config import Config
from lib.db.queries.session import try_get_session_with_cand_id_visit_label
from lib.dicom_archive import DicomArchive
from lib.exception.determine_subject_info_error import DetermineSubjectInfoError
from lib.exception.validate_subject_info_error import ValidateSubjectInfoError
Expand Down Expand Up @@ -192,10 +193,15 @@ def determine_study_info(self):

# get the CenterID from the session table if the PSCID and visit label exists
# and could be extracted from the database
self.session_obj.create_session_dict(self.subject_info.cand_id, self.subject_info.visit_label)
session_dict = self.session_obj.session_info_dict
if session_dict:
return {"CenterName": session_dict["MRI_alias"], "CenterID": session_dict["CenterID"]}

self.session = try_get_session_with_cand_id_visit_label(
self.env.db,
self.subject_info.cand_id,
self.subject_info.visit_label,
)

if self.session is not None:
return {"CenterName": self.session.site.mri_alias, "CenterID": self.session.site_id}

# if could not find center information based on cand_id and visit_label, use the
# patient name to match it to the site alias or MRI alias
Expand Down Expand Up @@ -223,7 +229,7 @@ def determine_scanner_info(self):
self.dicom_archive_obj.tarchive_info_dict['ScannerSerialNumber'],
self.dicom_archive_obj.tarchive_info_dict['ScannerModel'],
self.site_dict['CenterID'],
self.session_obj.session_info_dict['ProjectID'] if self.session_obj.session_info_dict else None
self.session.project_id if self.session is not None else None,
)

log_verbose(self.env, f"Found Scanner ID: {scanner_id}")
Expand All @@ -248,99 +254,6 @@ def validate_subject_info(self):
upload_id=self.upload_id, fields=('IsCandidateInfoValidated',), values=('0',)
)

def get_session_info(self):
"""
Creates the session info dictionary based on entries found in the session table.
"""

self.session_obj.create_session_dict(self.subject_info.cand_id, self.subject_info.visit_label)

if self.session_obj.session_info_dict:
log_verbose(self.env, f"Session ID for the file to insert is {self.session_obj.session_info_dict['ID']}")

def create_session(self):
"""
Function that will create a new visit in the session table for the imaging scans after verification
that all the information necessary for the creation of the visit are present.
"""

create_visit = self.subject_info.create_visit

if create_visit is None:
log_error_exit(
self.env,
f"Visit {self.subject_info.visit_label} for candidate {self.subject_info.cand_id} does not exist.",
lib.exitcode.GET_SESSION_ID_FAILURE,
)

# check that the project ID and cohort ID refers to an existing row in project_cohort_rel table
self.session_obj.create_proj_cohort_rel_info_dict(create_visit.project_id, create_visit.cohort_id)
if not self.session_obj.proj_cohort_rel_info_dict.keys():
log_error_exit(
self.env,
(
f"Cannot create visit with project ID {create_visit.project_id}"
f" and cohort ID {create_visit.cohort_id}:"
f" no such association in table project_cohort_rel"
),
lib.exitcode.CREATE_SESSION_FAILURE,
)

# determine the visit number and center ID for the next session to be created
center_id, visit_nb = self.determine_new_session_site_and_visit_nb()
if not center_id:
log_error_exit(
self.env,
(
f"No center ID found for candidate {self.subject_info.cand_id}"
f", visit {self.subject_info.visit_label}"
)
)
else:
log_verbose(self.env, f"Set newVisitNo = {visit_nb} and center ID = {center_id}")

# create the new visit
session_id = self.session_obj.insert_into_session(
{
'CandID': self.subject_info.cand_id,
'Visit_label': self.subject_info.visit_label,
'CenterID': center_id,
'VisitNo': visit_nb,
'Current_stage': 'Not Started',
'Scan_done': 'Y',
'Submitted': 'N',
'CohortID': create_visit.cohort_id,
'ProjectID': create_visit.project_id
}
)
if session_id:
self.get_session_info()

def determine_new_session_site_and_visit_nb(self):
"""
Determines the site and visit number of the new session to be created.
:returns: The center ID and visit number of the future new session
"""
visit_nb = 0
center_id = 0

if self.subject_info.is_phantom:
center_info_dict = self.session_obj.get_session_center_info(
self.subject_info.psc_id, self.subject_info.visit_label,
)

if center_info_dict:
center_id = center_info_dict["CenterID"]
visit_nb = 1
else:
center_info_dict = self.session_obj.get_next_session_site_id_and_visit_number(self.subject_info.cand_id)
if center_info_dict:
center_id = center_info_dict["CenterID"]
visit_nb = center_info_dict["newVisitNo"]

return center_id, visit_nb

def check_if_tarchive_validated_in_db(self):
"""
Checks whether the DICOM archive was previously validated in the database (as per the value present
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def _move_and_update_dicom_archive(self):
archive_location = self.dicom_archive_obj.tarchive_info_dict["ArchiveLocation"]

fields_to_update = ("SessionID",)
values_for_update = (self.session_obj.session_id,)
values_for_update = (self.session.id,)
pattern = re.compile("^[0-9]{4}/")
if acq_date and not pattern.match(archive_location):
# move the DICOM archive into a year subfolder
Expand Down Expand Up @@ -412,7 +412,7 @@ def _update_mri_upload(self):
self.imaging_upload_obj.update_mri_upload(
upload_id=self.upload_id,
fields=("Inserting", "InsertionComplete", "number_of_mincInserted", "number_of_mincCreated", "SessionID"),
values=("0", "1", len(files_inserted_list), len(self.nifti_files_to_insert), self.session_obj.session_id)
values=("0", "1", len(files_inserted_list), len(self.nifti_files_to_insert), self.session.id)
)

def _get_summary_of_insertion(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lib.dcm2bids_imaging_pipeline_lib.base_pipeline import BasePipeline
from lib.exception.determine_subject_info_error import DetermineSubjectInfoError
from lib.exception.validate_subject_info_error import ValidateSubjectInfoError
from lib.get_subject_session import get_subject_session
from lib.logging import log_error_exit, log_verbose
from lib.validate_subject_info import validate_subject_info

Expand Down Expand Up @@ -110,9 +111,7 @@ def __init__(self, loris_getopt_obj, script_name):
# ---------------------------------------------------------------------------------------------
# Determine/create the session the file should be linked to
# ---------------------------------------------------------------------------------------------
self.get_session_info()
if not self.session_obj.session_info_dict:
self.create_session()
self.session = get_subject_session(self.env, self.subject_info)

# ---------------------------------------------------------------------------------------------
# Determine acquisition protocol (or register into mri_protocol_violated_scans and exits)
Expand Down Expand Up @@ -169,9 +168,9 @@ def __init__(self, loris_getopt_obj, script_name):
self.exclude_violations_list = []
if not self.bypass_extra_checks:
self.violations_summary = self.imaging_obj.run_extra_file_checks(
self.session_obj.session_info_dict['ProjectID'],
self.session_obj.session_info_dict['CohortID'],
self.session_obj.session_info_dict['Visit_label'],
self.session.project_id,
self.session.cohort_id,
self.session.visit_label,
self.scan_type_id,
self.json_file_dict
)
Expand Down Expand Up @@ -357,15 +356,15 @@ def _determine_acquisition_protocol(self):
self.json_file_dict['DeviceSerialNumber'],
self.json_file_dict['ManufacturersModelName'],
self.site_dict['CenterID'],
self.session_obj.session_info_dict['ProjectID']
self.session.project_id,
)

# get the list of lines in the mri_protocol table that apply to the given scan based on the protocol group
protocols_list = self.imaging_obj.get_list_of_eligible_protocols_based_on_session_info(
self.session_obj.session_info_dict['ProjectID'],
self.session_obj.session_info_dict['CohortID'],
self.session_obj.session_info_dict['CenterID'],
self.session_obj.session_info_dict['Visit_label'],
self.session.project_id,
self.session.cohort_id,
self.session.site_id,
self.session.visit_label,
self.scanner_id
)

Expand Down Expand Up @@ -458,7 +457,7 @@ def _determine_new_nifti_assembly_rel_path(self):
# determine NIfTI file name
new_nifti_name = self._construct_nifti_filename(file_bids_entities_dict)
already_inserted_filenames = self.imaging_obj.get_list_of_files_already_inserted_for_session_id(
self.session_obj.session_info_dict['ID']
self.session.id,
)
while new_nifti_name in already_inserted_filenames:
file_bids_entities_dict['run'] += 1
Expand Down Expand Up @@ -680,7 +679,7 @@ def _register_into_files_and_parameter_file(self, nifti_rel_path):
)

files_insert_info_dict = {
'SessionID': self.session_obj.session_info_dict['ID'],
'SessionID': self.session.id,
'File': nifti_rel_path,
'SeriesUID': scan_param['SeriesInstanceUID'] if 'SeriesInstanceUID' in scan_param.keys() else None,
'EchoTime': scan_param['EchoTime'] if 'EchoTime' in scan_param.keys() else None,
Expand Down
83 changes: 83 additions & 0 deletions python/lib/get_subject_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import cast

import lib.exitcode
from lib.config_file import SubjectInfo
from lib.db.models.candidate import DbCandidate
from lib.db.models.session import DbSession
from lib.db.queries.candidate import try_get_candidate_with_cand_id
from lib.db.queries.session import try_get_session_with_cand_id_visit_label
from lib.db.queries.site import try_get_site_with_cand_id_visit_label
from lib.env import Env
from lib.logging import log_error_exit, log_verbose


def get_candidate_next_visit_number(candidate: DbCandidate) -> int:
"""
Get the next visit number for a new session for a given candidate.
"""

visit_numbers = [session.visit_number for session in candidate.sessions if session.visit_number is not None]
return max(*visit_numbers, 0) + 1


def get_subject_session(env: Env, subject_info: SubjectInfo) -> DbSession:
"""
Get the imaging session corresponding to a given subject configuration.
This function first looks for an adequate session in the database, and returns it if one is
found. If no session is found, this function creates a new session in the database if the
subject configuration allows it, or exits the program otherwise.
"""

session = _get_subject_session(env, subject_info)
log_verbose(env, f"Session ID for the file to insert is {session.id}")
return session


def _get_subject_session(env: Env, subject_info: SubjectInfo) -> DbSession:
"""
Implementation of `get_subject_session`.
"""

session = try_get_session_with_cand_id_visit_label(env.db, subject_info.cand_id, subject_info.visit_label)
if session is not None:
return session

if subject_info.create_visit is None:
log_error_exit(
env,
f"Visit {subject_info.visit_label} for candidate {subject_info.cand_id} does not exist.",
lib.exitcode.GET_SESSION_ID_FAILURE,
)

if subject_info.is_phantom:
site = try_get_site_with_cand_id_visit_label(env.db, subject_info.cand_id, subject_info.visit_label)
visit_number = 1
else:
candidate = try_get_candidate_with_cand_id(env.db, subject_info.cand_id)
# Safe because it has been checked that the candidate exists in `validate_subject_info`
candidate = cast(DbCandidate, candidate)
site = candidate.registration_site
visit_number = get_candidate_next_visit_number(candidate)

if site is None:
log_error_exit(
env,
f"No center ID found for candidate {subject_info.cand_id}, visit {subject_info.visit_label}"
)

session = DbSession(
cand_id = subject_info.cand_id,
site_id = site.id,
visit_number = visit_number,
current_stage = 'Not Started',
scan_done = True,
submitted = False,
project_id = subject_info.create_visit.project_id,
cohort_id = subject_info.create_visit.cohort_id,
)

env.db.add(session)
env.db.commit()

return session
Loading

0 comments on commit 7605a83

Please sign in to comment.