diff --git a/.ci_support/environment-docs.yml b/.ci_support/environment-docs.yml index 66253bbd5..992965c8d 100644 --- a/.ci_support/environment-docs.yml +++ b/.ci_support/environment-docs.yml @@ -7,23 +7,23 @@ dependencies: - myst-parser - conda =24.7.1 - conda_subprocess =0.0.5 -- cloudpickle =3.1.0 +- cloudpickle =3.1.1 - gitpython =3.1.44 -- h5io_browser =0.1.5 +- h5io_browser =0.1.6 - h5py =3.12.1 - jinja2 =3.1.5 - monty =2025.1.9 -- numpy =2.2.1 +- numpy =2.2.2 - pandas =2.2.3 - pint =0.24.4 - psutil =6.1.1 -- pyfileindex =0.0.32 +- pyfileindex =0.0.33 - pyiron_dataclasses =0.0.1 - pyiron_snippets =0.1.4 -- executorlib =0.0.7 +- executorlib =0.0.8 - pysqa =0.2.3 -- pytables =3.10.1 -- sqlalchemy =2.0.36 +- pytables =3.10.2 +- sqlalchemy =2.0.37 - tqdm =4.67.1 - traitlets =5.14.3 - jupyter-book =1.0.0 diff --git a/.ci_support/environment-mini.yml b/.ci_support/environment-mini.yml index 9d8c0ded0..47a232dc0 100644 --- a/.ci_support/environment-mini.yml +++ b/.ci_support/environment-mini.yml @@ -1,20 +1,20 @@ channels: - conda-forge dependencies: -- cloudpickle =3.1.0 -- h5io_browser =0.1.5 +- cloudpickle =3.1.1 +- h5io_browser =0.1.6 - h5py =3.12.1 - monty =2025.1.9 -- numpy =2.2.1 +- numpy =2.2.2 - pandas =2.2.3 - psutil =6.1.1 -- pyfileindex =0.0.32 +- pyfileindex =0.0.33 - pyiron_dataclasses =0.0.1 - pyiron_snippets =0.1.3 -- executorlib =0.0.7 +- executorlib =0.0.8 - pysqa =0.2.3 -- pytables =3.10.1 -- sqlalchemy =2.0.36 +- pytables =3.10.2 +- sqlalchemy =2.0.37 - tqdm =4.67.1 - traitlets =5.14.3 - setuptools diff --git a/.ci_support/environment.yml b/.ci_support/environment.yml index 357f72825..78f764214 100644 --- a/.ci_support/environment.yml +++ b/.ci_support/environment.yml @@ -5,22 +5,22 @@ dependencies: - codacy-coverage - conda =24.7.1 - conda_subprocess =0.0.5 -- cloudpickle =3.1.0 +- cloudpickle =3.1.1 - gitpython =3.1.44 -- h5io_browser =0.1.5 +- h5io_browser =0.1.6 - h5py =3.12.1 - jinja2 =3.1.5 - monty =2025.1.9 -- numpy =2.2.1 +- numpy =2.2.2 - pandas =2.2.3 - pint =0.24.4 - psutil =6.1.1 -- pyfileindex =0.0.32 +- pyfileindex =0.0.33 - pyiron_dataclasses =0.0.1 - pyiron_snippets =0.1.4 -- executorlib =0.0.7 +- executorlib =0.0.8 - pysqa =0.2.3 -- pytables =3.10.1 -- sqlalchemy =2.0.36 +- pytables =3.10.2 +- sqlalchemy =2.0.37 - tqdm =4.67.1 - traitlets =5.14.3 diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 5a469bc40..e07482f47 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -28,10 +28,6 @@ jobs: - operating-system: ubuntu-latest python-version: '3.11' label: linux-64-py-3-11 - - - operating-system: ubuntu-latest - python-version: '3.10' - label: linux-64-py-3-10 steps: - uses: actions/checkout@v4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3b10414e2..07657a097 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.2 hooks: - id: ruff name: ruff lint diff --git a/.readthedocs.yml b/.readthedocs.yml index 6bb1b86ef..1bbb7ec71 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,9 +6,9 @@ version: 2 build: - os: "ubuntu-22.04" + os: "ubuntu-24.04" tools: - python: "mambaforge-22.9" + python: "mambaforge-23.11" jobs: pre_build: # Generate the Sphinx configuration for this Jupyter Book so it builds. @@ -21,10 +21,11 @@ build: # Build documentation in the docs/ directory with Sphinx sphinx: builder: html + configuration: docs/conf.py # Optionally build your docs in additional formats such as PDF and ePub formats: [] # Install pyiron from conda conda: - environment: .ci_support/environment-docs.yml \ No newline at end of file + environment: .ci_support/environment-docs.yml diff --git a/binder/environment.yml b/binder/environment.yml index 1ff33c3e6..50d76a9c7 100644 --- a/binder/environment.yml +++ b/binder/environment.yml @@ -4,22 +4,22 @@ dependencies: - python - conda =24.7.1 - conda_subprocess =0.0.5 -- cloudpickle =3.1.0 +- cloudpickle =3.1.1 - gitpython =3.1.44 -- h5io_browser =0.1.5 +- h5io_browser =0.1.6 - h5py =3.12.1 - jinja2 =3.1.5 - monty =2025.1.9 -- numpy =2.2.1 +- numpy =2.2.2 - pandas =2.2.3 - pint =0.24.4 - psutil =6.1.1 -- pyfileindex =0.0.32 +- pyfileindex =0.0.33 - pyiron_dataclasses =0.0.1 - pyiron_snippets =0.1.4 -- executorlib =0.0.7 +- executorlib =0.0.8 - pysqa =0.2.3 -- pytables =3.10.1 -- sqlalchemy =2.0.36 +- pytables =3.10.2 +- sqlalchemy =2.0.37 - tqdm =4.67.1 - traitlets =5.14.3 diff --git a/pyiron_base/cli/control.py b/pyiron_base/cli/control.py index cafc57c65..4b3ba5416 100644 --- a/pyiron_base/cli/control.py +++ b/pyiron_base/cli/control.py @@ -54,7 +54,7 @@ def main() -> None: mod.register(sub_parser) except AttributeError: warnings.warn( - "module '{}' does not define main or register " "function, ignoring" + "module '{}' does not define main or register function, ignoring" ) args = parser.parse_args() diff --git a/pyiron_base/database/filetable.py b/pyiron_base/database/filetable.py index 7a6d789bc..ad1194a9a 100644 --- a/pyiron_base/database/filetable.py +++ b/pyiron_base/database/filetable.py @@ -280,7 +280,7 @@ def get_items_dict( return_all_columns (bool): return all columns or only the 'id' - still the format stays the same. Returns: - list: the function returns a list of dicts like get_items_sql, but it does not format datetime: + list: the function returns a list of dicts, but it does not format datetime: [{'chemicalformula': u'Ni108', 'computer': u'mapc157', 'hamilton': u'LAMMPS', diff --git a/pyiron_base/database/generic.py b/pyiron_base/database/generic.py index f85b25f44..ace4ea4ad 100644 --- a/pyiron_base/database/generic.py +++ b/pyiron_base/database/generic.py @@ -281,7 +281,7 @@ def _job_dict( element_lst (list): list of elements required in the chemical formular - by default None Returns: - list: the function returns a list of dicts like get_items_sql, but it does not format datetime: + list: the function returns a list of dicts, but it does not format datetime: [{'chemicalformula': u'Ni108', 'computer': u'mapc157', 'hamilton': u'LAMMPS', @@ -312,10 +312,30 @@ def _job_dict( 'username': u'test'},.......] """ + if not self._sql_lite: + + def escape(s, escape_char="\\", special_chars="_%"): + """Insert escape_char in front of special_chars, unless present. + + Handles the cases where s already contains escaped characters, + including the escape character itself. + + Defaults for LIKE in SQL statements.""" + for c in special_chars: + if c in s: + s = s.replace(escape_char + c, c) + s = s.replace(c, escape_char + c) + return s + + else: + + def escape(s, escape_char="\\", special_chars="_%"): + return s + dict_clause = {} # FOR GET_ITEMS_SQL: clause = [] if user is not None: - dict_clause["username"] = str(user) + dict_clause["username"] = escape(str(user)) # FOR GET_ITEMS_SQL: clause.append("username = '" + self.user + "'") if sql_query is not None: # FOR GET_ITEMS_SQL: clause.append(self.sql_query) @@ -329,18 +349,18 @@ def _job_dict( {str(element.split()[0]): element.split()[2] for element in cl_split} ) if job is not None: - dict_clause["job"] = str(job) + dict_clause["job"] = escape(str(job)) if project_path == "./": project_path = "" if recursive: - dict_clause["project"] = str(project_path) + "%" + dict_clause["project"] = escape(str(project_path)) + "%" else: - dict_clause["project"] = str(project_path) + dict_clause["project"] = escape(str(project_path)) if sub_job_name is None: dict_clause["subjob"] = None elif sub_job_name != "%": - dict_clause["subjob"] = str(sub_job_name) + dict_clause["subjob"] = escape(str(sub_job_name)) if element_lst is not None: dict_clause["element_lst"] = element_lst @@ -502,100 +522,6 @@ def change_column_type( else: raise PermissionError("Not avilable in viewer mode.") - def get_items_sql( - self, where_condition: Optional[str] = None, sql_statement: Optional[str] = None - ) -> List[dict]: - """ - Submit an SQL query to the database - - Args: - where_condition (str): SQL where query, query like: "project LIKE 'lammps.phonons.Ni_fcc%'" - sql_statement (str): general SQL query, normal SQL statement - - Returns: - list: get a list of dictionaries, where each dictionary represents one item of the table like: - [{u'chemicalformula': u'BO', - u'computer': u'localhost', - u'hamilton': u'VAMPS', - u'hamversion': u'1.1', - u'id': 1, - u'job': u'testing', - u'masterid': None, - u'parentid': 0, - u'project': u'database.testing', - u'projectpath': u'/TESTING', - u'status': u'KAAAA', - u'subjob': u'testJob', - u'timestart': u'2016-05-02 11:31:04.253377', - u'timestop': u'2016-05-02 11:31:04.371165', - u'totalcputime': 0.117788, - u'username': u'User'}, - {u'chemicalformula': u'BO', - u'computer': u'localhost', - u'hamilton': u'VAMPS', - u'hamversion': u'1.1', - u'id': 2, - u'job': u'testing', - u'masterid': 0, - u'parentid': 0, - u'project': u'database.testing', - u'projectpath': u'/TESTING', - u'status': u'KAAAA', - u'subjob': u'testJob', - u'timestart': u'2016-05-02 11:31:04.253377', - u'timestop': u'2016-05-02 11:31:04.371165', - u'totalcputime': 0.117788, - u'username': u'User'}.....] - """ - - if where_condition: - where_condition = ( - where_condition.replace("like", "similar to") - if self._engine.dialect.name == "postgresql" - else where_condition - ) - try: - query = "select * from " + self.table_name + " where " + where_condition - query.replace("%", "%%") - result = self.conn.execute(text(query)) - except Exception as except_msg: - print("EXCEPTION in get_items_sql: ", except_msg) - raise ValueError("EXCEPTION in get_items_sql: ", except_msg) - elif sql_statement: - sql_statement = ( - sql_statement.replace("like", "similar to") - if self._engine.dialect.name == "postgresql" - else sql_statement - ) - # TODO: make it save against SQL injection - result = self.conn.execute(text(sql_statement)) - else: - result = self.conn.execute(text("select * from " + self.table_name)) - row = result.mappings().all() - if not self._keep_connection: - self.conn.close() - - # change the date of str datatype back into datetime object - output_list = [] - for col in row: - # ensures working with db entries, which are camel case - timestop_index = [item.lower() for item in col.keys()].index("timestop") - timestart_index = [item.lower() for item in col.keys()].index("timestart") - tmp_values = list(col.values()) - if (tmp_values[timestop_index] and tmp_values[timestart_index]) is not None: - # changes values - try: - tmp_values[timestop_index] = datetime.strptime( - str(tmp_values[timestop_index]), "%Y-%m-%d %H:%M:%S.%f" - ) - tmp_values[timestart_index] = datetime.strptime( - str(tmp_values[timestart_index]), "%Y-%m-%d %H:%M:%S.%f" - ) - except ValueError: - print("error in: ", str(col)) - output_list += [dict(zip(col.keys(), tmp_values))] - return output_list - def _check_chem_formula_length(self, par_dict: dict) -> dict: """ performs a check whether the length of chemical formula exceeds the defined limit @@ -900,7 +826,7 @@ def get_items_dict( return_all_columns (bool): return all columns or only the 'id' - still the format stays the same. Returns: - list: the function returns a list of dicts like get_items_sql, but it does not format datetime: + list: the function returns a list of dicts, but it does not format datetime: [{'chemicalformula': u'Ni108', 'computer': u'mapc157', 'hamilton': u'LAMMPS', @@ -974,10 +900,10 @@ def get_items_dict( self.conn.connection.create_function("like", 2, self.regexp) result = self.conn.execute(query) - row = result.fetchall() + results = [row._asdict() for row in result.fetchall()] if not self._keep_connection: self.conn.close() - return [dict(zip(col._mapping.keys(), col._mapping.values())) for col in row] + return results def get_job_status(self, job_id: int) -> Union[str, None]: try: diff --git a/pyiron_base/jobs/job/core.py b/pyiron_base/jobs/job/core.py index a018d1158..1e1c1de28 100644 --- a/pyiron_base/jobs/job/core.py +++ b/pyiron_base/jobs/job/core.py @@ -876,6 +876,7 @@ def _internal_copy_to( _copy_database_entry( new_job_core=new_job_core, job_copied_id=self.job_id, + username=state.settings.login_user, ) else: new_job_core.reset_job_id(job_id=None) diff --git a/pyiron_base/jobs/job/util.py b/pyiron_base/jobs/job/util.py index 5e8dc9116..8a8a962d5 100644 --- a/pyiron_base/jobs/job/util.py +++ b/pyiron_base/jobs/job/util.py @@ -42,7 +42,9 @@ def _copy_database_entry( - new_job_core: "pyiron_base.jobs.job.generic.GenericJob", job_copied_id: int + new_job_core: "pyiron_base.jobs.job.generic.GenericJob", + job_copied_id: int, + username: Optional[str] = None, ) -> None: """ Copy database entry from previous job @@ -50,6 +52,7 @@ def _copy_database_entry( Args: new_job_core (GenericJob): Copy of the job object job_copied_id (int): Job id of the copied job + username (str): Optional name of the user to copy the job to """ db_entry = new_job_core.project.db.get_item_by_id(job_copied_id) if db_entry is not None: @@ -57,6 +60,8 @@ def _copy_database_entry( db_entry["subjob"] = new_job_core.project_hdf5.h5_path db_entry["project"] = new_job_core.project_hdf5.project_path db_entry["projectpath"] = new_job_core.project_hdf5.root_path + if username is not None: + db_entry["username"] = username del db_entry["id"] job_id = new_job_core.project.db.add_item_dict(db_entry) new_job_core.reset_job_id(job_id=job_id) diff --git a/pyiron_base/jobs/master/list.py b/pyiron_base/jobs/master/list.py index 13d233c50..e2c255f3a 100644 --- a/pyiron_base/jobs/master/list.py +++ b/pyiron_base/jobs/master/list.py @@ -218,7 +218,7 @@ def run_if_refresh(self): Internal helper function the run if refresh function is called when the job status is 'refresh'. If the job was suspended previously, the job is going to be started again, to be continued. """ - log_str = "{}, status: {}, finished: {} parallel master " "refresh".format( + log_str = "{}, status: {}, finished: {} parallel master refresh".format( self.job_info_str, self.status, self.is_finished() ) self._logger.info(log_str) diff --git a/pyiron_base/jobs/master/parallel.py b/pyiron_base/jobs/master/parallel.py index c68520851..f22635a95 100644 --- a/pyiron_base/jobs/master/parallel.py +++ b/pyiron_base/jobs/master/parallel.py @@ -308,7 +308,7 @@ def run_if_refresh(self): Internal helper function the run if refresh function is called when the job status is 'refresh'. If the job was suspended previously, the job is going to be started again, to be continued. """ - log_str = "{}, status: {}, finished: {} parallel master " "refresh".format( + log_str = "{}, status: {}, finished: {} parallel master refresh".format( self.job_info_str, self.status, self.is_finished() ) self._logger.info(log_str) diff --git a/pyiron_base/project/delayed.py b/pyiron_base/project/delayed.py index 39fb0ff7b..0355b7595 100644 --- a/pyiron_base/project/delayed.py +++ b/pyiron_base/project/delayed.py @@ -259,10 +259,19 @@ def draw(self): draw(node_dict=node_dict, edge_lst=edge_lst) def get_python_result(self): - if isinstance(self._result, dict): - return self._result[self._output_key] - else: + if isinstance(self._result, dict) and self._output_key is not None: + return self._result[str(self._output_key)] + elif isinstance(self._result, list): + if self._list_index is not None: + return self._result[self._list_index] + elif self._output_key is not None: + return self._result[int(self._output_key)] + else: + return self._result + elif self._output_key is not None: return getattr(self._result.output, self._output_key) + else: + return self._result def get_file_result(self): return getattr(self._result.files, self._output_file) @@ -277,8 +286,10 @@ def pull(self): return self.get_python_result() elif self._output_file is not None: return self.get_file_result() - elif self._list_index is not None: + elif isinstance(self._result, list) and self._list_index is not None: return self._result[self._list_index] + elif isinstance(self._result, dict) and self._list_index is not None: + return self._result[str(self._list_index)] else: return self._result diff --git a/pyiron_base/storage/flattenedstorage.py b/pyiron_base/storage/flattenedstorage.py index f6a0dfe43..693686741 100644 --- a/pyiron_base/storage/flattenedstorage.py +++ b/pyiron_base/storage/flattenedstorage.py @@ -1040,7 +1040,7 @@ def read_array(name, hdf): # itemsize of original a is four bytes per character, so divide by four to get # length of the orignal stored unicode string; np.dtype('U1').itemsize is just a # platform agnostic way of knowing how wide a unicode charater is for numpy - dtype=f"U{a.dtype.itemsize//np.dtype('U1').itemsize}", + dtype=f"U{a.dtype.itemsize // np.dtype('U1').itemsize}", ) return a diff --git a/pyproject.toml b/pyproject.toml index 85b27e7a5..582760c66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,20 +23,20 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "cloudpickle==3.1.0", - "executorlib==0.0.7", - "h5io_browser==0.1.5", + "cloudpickle==3.1.1", + "executorlib==0.0.8", + "h5io_browser==0.1.6", "h5py==3.12.1", - "numpy==2.2.1", + "numpy==2.2.2", "monty==2025.1.9", "pandas==2.2.3", "psutil==6.1.1", - "pyfileindex==0.0.32", + "pyfileindex==0.0.33", "pyiron_dataclasses==0.0.1", "pyiron_snippets==0.1.4", "pysqa==0.2.3", - "sqlalchemy==2.0.36", - "tables==3.10.1", + "sqlalchemy==2.0.37", + "tables==3.10.2", "tqdm==4.67.1", "traitlets==5.14.3", ] diff --git a/tests/unit/database/test_database_access.py b/tests/unit/database/test_database_access.py index b08c94284..07d0b42a4 100644 --- a/tests/unit/database/test_database_access.py +++ b/tests/unit/database/test_database_access.py @@ -16,11 +16,108 @@ from datetime import datetime from random import choice from string import ascii_uppercase +from typing import List, Optional from pyiron_base.database.generic import DatabaseAccess from pyiron_base._tests import PyironTestCase from sqlalchemy import text +# legacy method of DatabaseAccess; kept here only to test _get_items_dict +def get_items_sql( + self, where_condition: Optional[str] = None, sql_statement: Optional[str] = None +) -> List[dict]: + """ + Submit an SQL query to the database + + Args: + where_condition (str): SQL where query, query like: "project LIKE 'lammps.phonons.Ni_fcc%'" + sql_statement (str): general SQL query, normal SQL statement + + Returns: + list: get a list of dictionaries, where each dictionary represents one item of the table like: + [{u'chemicalformula': u'BO', + u'computer': u'localhost', + u'hamilton': u'VAMPS', + u'hamversion': u'1.1', + u'id': 1, + u'job': u'testing', + u'masterid': None, + u'parentid': 0, + u'project': u'database.testing', + u'projectpath': u'/TESTING', + u'status': u'KAAAA', + u'subjob': u'testJob', + u'timestart': u'2016-05-02 11:31:04.253377', + u'timestop': u'2016-05-02 11:31:04.371165', + u'totalcputime': 0.117788, + u'username': u'User'}, + {u'chemicalformula': u'BO', + u'computer': u'localhost', + u'hamilton': u'VAMPS', + u'hamversion': u'1.1', + u'id': 2, + u'job': u'testing', + u'masterid': 0, + u'parentid': 0, + u'project': u'database.testing', + u'projectpath': u'/TESTING', + u'status': u'KAAAA', + u'subjob': u'testJob', + u'timestart': u'2016-05-02 11:31:04.253377', + u'timestop': u'2016-05-02 11:31:04.371165', + u'totalcputime': 0.117788, + u'username': u'User'}.....] + """ + + if where_condition: + where_condition = ( + where_condition.replace("like", "similar to") + if self._engine.dialect.name == "postgresql" + else where_condition + ) + try: + query = "select * from " + self.table_name + " where " + where_condition + query.replace("%", "%%") + result = self.conn.execute(text(query)) + except Exception as except_msg: + print("EXCEPTION in get_items_sql: ", except_msg) + raise ValueError("EXCEPTION in get_items_sql: ", except_msg) + elif sql_statement: + sql_statement = ( + sql_statement.replace("like", "similar to") + if self._engine.dialect.name == "postgresql" + else sql_statement + ) + # TODO: make it save against SQL injection + result = self.conn.execute(text(sql_statement)) + else: + result = self.conn.execute(text("select * from " + self.table_name)) + row = result.mappings().all() + if not self._keep_connection: + self.conn.close() + + # change the date of str datatype back into datetime object + output_list = [] + for col in row: + # ensures working with db entries, which are camel case + timestop_index = [item.lower() for item in col.keys()].index("timestop") + timestart_index = [item.lower() for item in col.keys()].index("timestart") + tmp_values = list(col.values()) + if (tmp_values[timestop_index] and tmp_values[timestart_index]) is not None: + # changes values + try: + tmp_values[timestop_index] = datetime.strptime( + str(tmp_values[timestop_index]), "%Y-%m-%d %H:%M:%S.%f" + ) + tmp_values[timestart_index] = datetime.strptime( + str(tmp_values[timestart_index]), "%Y-%m-%d %H:%M:%S.%f" + ) + except ValueError: + print("error in: ", str(col)) + output_list += [dict(zip(col.keys(), tmp_values))] + return output_list + + class TestDatabaseAccess(PyironTestCase): """ Standard Unittest of the DatabaseAccess class @@ -81,70 +178,6 @@ def test_get_table_headings(self): for item in heading_list: self.assertTrue(item in self.database.get_table_headings()) - def test_get_items_sql(self): - """ - Tests get_items_sql function - Returns: - """ - self.add_items("Blub") - self.add_items("Bluk") - # has to return a list - self.assertIsInstance( - self.database.get_items_sql("chemicalformula LIKE 'Blu%'"), list - ) - self.assertRaises( - Exception, self.database.get_items_sql, "A Wrong where Clause" - ) # Where clause must be right - # A valid sqlstatement should also return a valid list - self.assertIsInstance( - self.database.get_items_sql( - where_condition="", sql_statement="select * from simulation" - ), - list, - ) - par_dict = self.add_items("BO") - key = par_dict["id"] - # be sure that get_items_sql returns right result with right statement - result = self.database.get_items_sql( - where_condition="", - sql_statement="select * from simulation " "where id=%s" % key, - )[-1] - self.assertTrue(par_dict.items() <= result.items()) - - def test_get_items_sql_like_regex(self): - """ - Tests the regex functionality of 'like' - Returns: - """ - elem1 = self.add_items("H4Ni2") - elem2 = self.add_items("H2") - elem3 = self.add_items("H6") - elem4 = self.add_items("HeNi2") - elem5 = self.add_items("H12Ni5") - elem6 = self.add_items("H12") - elem7 = self.add_items("He2") - - # H([0-9]*) matches H2, H6 and H12 - self.assertEqual( - [elem2, elem3, elem6], - self.database.get_items_sql(r"chemicalformula like 'H([0-9]*)'"), - ) - # He(\d)*(Ni)?\d* matches HeNi2, He2 - self.assertEqual( - [elem4, elem7], - self.database.get_items_sql(r"chemicalformula like 'He(\d)*(Ni)?\d*'"), - ) - # H\d*Ni\d* matches H4Ni2, H12Ni5 - self.assertEqual( - [elem1, elem5], - self.database.get_items_sql(r"chemicalformula like 'H\d*Ni\d*'"), - ) - # assert that not something random really is in the Database, recommended by Samuel Hartke - # Murat: 'Just ignore the line!' - self.assertEqual( - [], self.database.get_items_sql(r"chemicalformula like 'B\d[a-z]'") - ) - def test_add_item_dict(self): """ Tests add_item_dict function @@ -227,7 +260,7 @@ def test_get_items_dict_and(self): item_dict = {"hamilton": "VAMPE", "hamversion": "1.1"} self.assertEqual( self.database.get_items_dict(item_dict), - self.database.get_items_sql("hamilton='VAMPE' and hamversion='1.1'"), + get_items_sql(self.database, "hamilton='VAMPE' and hamversion='1.1'"), ) def test_get_items_dict_project(self): @@ -272,8 +305,8 @@ def test_get_items_dict_or(self): # tests an example or statement item_dict = {"chemicalformula": ["Blub", "Blab"]} # assert that both the sql and non-sql methods give the same result - sql_db = self.database.get_items_sql( - "chemicalformula='Blub' or chemicalformula='Blab'" + sql_db = get_items_sql( + self.database, "chemicalformula='Blub' or chemicalformula='Blab'" ) dict_db = self.database.get_items_dict(item_dict) for item in sql_db: @@ -288,7 +321,7 @@ def test_get_items_dict_like(self): # tests an example like statement item_dict = {"status": "%AA%"} # assert that both the sql and non-sql methods give the same result - sql_db = self.database.get_items_sql("status like '%AA%'") + sql_db = get_items_sql(self.database, "status like '%AA%'") dict_db = self.database.get_items_dict(item_dict) for item in sql_db: self.assertTrue(item in dict_db) diff --git a/tests/unit/flex/test_decorator.py b/tests/unit/flex/test_decorator.py index 2db56c852..20f8ebd04 100644 --- a/tests/unit/flex/test_decorator.py +++ b/tests/unit/flex/test_decorator.py @@ -41,6 +41,24 @@ def my_function_b(a, b=8): self.assertEqual(len(nodes_dict), 6) self.assertEqual(len(edges_lst), 6) + def test_delayed_return_types(self): + @job + def my_function_a(a, b=8): + return [a + b] + + @job(cores=2, output_key_lst=["0"]) + def my_function_b(a, b=8): + return [a + b] + + c = my_function_a(a=1, b=2, pyiron_project=self.project, list_length=1) + for a in c: + d = my_function_b(a=a, b=3, pyiron_project=self.project) + self.assertEqual(d.pull(), [6]) + self.assertEqual(c.pull(), [3]) + nodes_dict, edges_lst = d.get_graph() + self.assertEqual(len(nodes_dict), 7) + self.assertEqual(len(edges_lst), 6) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/job/test_jobtypechoice.py b/tests/unit/job/test_jobtypechoice.py index a358c155c..8da717725 100644 --- a/tests/unit/job/test_jobtypechoice.py +++ b/tests/unit/job/test_jobtypechoice.py @@ -33,7 +33,7 @@ def test_attr(self): getattr(self.jobtypechoice, k) except AttributeError: self.fail( - "job class {} in JOB_CLASS_DICT, but not on " "JobTypeChoice".format(k) + "job class {} in JOB_CLASS_DICT, but not on JobTypeChoice".format(k) ) def test_extend_job_class_dict(self): @@ -44,14 +44,13 @@ def test_extend_job_class_dict(self): JOB_CLASS_DICT["TestClass"] = "my.own.test.module" self.assertTrue( "TestClass" in dir(self.jobtypechoice), - "new job class added to JOB_CLASS_DICT, but not " "returned in dir()", + "new job class added to JOB_CLASS_DICT, but not returned in dir()", ) try: getattr(self.jobtypechoice, "TestClass") except AttributeError: self.fail( - "new job class added to JOB_CLASS_DICT, but not defined " - "JobTypeChoice" + "new job class added to JOB_CLASS_DICT, but not defined JobTypeChoice" ) @@ -81,5 +80,5 @@ def test_attr(self): getattr(self.job_factory, k) except AttributeError: self.fail( - "job class {} in JOB_CLASS_DICT, but not on " "JobTypeChoice".format(k) + "job class {} in JOB_CLASS_DICT, but not on JobTypeChoice".format(k) ) diff --git a/tests/unit/project/test_jobloader.py b/tests/unit/project/test_jobloader.py index 7b3573ffd..59b13c139 100644 --- a/tests/unit/project/test_jobloader.py +++ b/tests/unit/project/test_jobloader.py @@ -22,8 +22,7 @@ def test_load(self): self.assertEqual( len(self.project.job_table()), len(self.project.load.__dir__()), - msg="Tab completion (`__dir__`) should see both jobs at this project " - "level", + msg="Tab completion (`__dir__`) should see both jobs at this project level", # Note: When job names are duplicated at different sub-project levels, the # job name occurs in the __dir__ multiple times, even though it will only # show up in the tab-completion menu once (where it accesses the top-most