diff --git a/.ci_support/environment-docs.yml b/.ci_support/environment-docs.yml index 325bc83be..801490e24 100644 --- a/.ci_support/environment-docs.yml +++ b/.ci_support/environment-docs.yml @@ -13,9 +13,9 @@ dependencies: - pandas - phonopy - pyiron_base -- pymatgen - pyscal - scipy - seekpath - scikit-learn - spglib +- structuretoolkit \ No newline at end of file diff --git a/.ci_support/environment.yml b/.ci_support/environment.yml index 34b26611e..a2e6f2d73 100644 --- a/.ci_support/environment.yml +++ b/.ci_support/environment.yml @@ -1,25 +1,27 @@ channels: - conda-forge dependencies: +- aimsgb =1.1.0 - ase =3.22.1 +- atomistics =0.1.2 - coveralls - coverage - codacy-coverage - defusedxml =0.7.1 -- h5py =3.9.0 -- matplotlib-base =3.7.2 +- h5py =3.10.0 +- matplotlib-base =3.8.1 - mendeleev =0.14.0 -- mp-api =0.33.3 -- numpy =1.24.3 -- pandas =2.0.3 +- mp-api =0.37.5 +- numpy =1.26.0 +- pandas =2.1.3 - phonopy =2.20.0 - pint =0.22 -- pyiron_base =0.6.3 +- pyiron_base =0.6.9 - pylammpsmpi =0.2.6 - pymatgen =2023.8.10 - pyscal =2.10.18 -- scikit-learn =1.3.0 -- scipy =1.11.1 +- scikit-learn =1.3.2 +- scipy =1.11.3 - seekpath =2.1.0 -- spglib =2.0.2 -- structuretoolkit =0.0.6 +- spglib =2.1.0 +- structuretoolkit =0.0.12 diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 56001c060..2abc85bb0 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -1,3 +1,6 @@ +# This workflow is used to upload and deploy a new release to PyPi +# Based on https://github.com/pypa/gh-action-pypi-publish + name: PyPi Release on: @@ -5,16 +8,20 @@ on: pull_request: workflow_dispatch: -# based on https://github.com/pypa/gh-action-pypi-publish jobs: build: + if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' runs-on: ubuntu-latest - + environment: + name: pypi + url: https://pypi.org/p/pyiron_atomistics + permissions: + id-token: write steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: "3.10" + python-version: "3.11" - name: Install dependencies run: >- @@ -26,8 +33,4 @@ jobs: run: >- python setup.py sdist bdist_wheel - name: Publish distribution 📦 to PyPI - if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.pypi_password }} + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/binder/environment.yml b/binder/environment.yml index 9215c5e8f..f71426727 100644 --- a/binder/environment.yml +++ b/binder/environment.yml @@ -16,4 +16,4 @@ dependencies: - jupyterlab - pyiron-data >=0.0.22 - sqsgenerator -- pymatgen +- structuretoolkit diff --git a/pyiron_atomistics/_version.py b/pyiron_atomistics/_version.py index 9b0c840b6..f305839a7 100644 --- a/pyiron_atomistics/_version.py +++ b/pyiron_atomistics/_version.py @@ -4,8 +4,9 @@ # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -14,9 +15,11 @@ import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -32,8 +35,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool -def get_config(): + +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -51,14 +61,14 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -68,24 +78,39 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, + process = subprocess.Popen( + [command] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, ) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -96,18 +121,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -116,7 +143,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return { @@ -126,9 +153,8 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): "error": None, "date": None, } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print( @@ -139,41 +165,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -186,11 +219,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -199,7 +232,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -208,6 +241,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix) :] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r"\d", r): + continue if verbose: print("picking %s" % r) return { @@ -230,7 +268,9 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -241,7 +281,14 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -249,7 +296,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( + describe_out, rc = runner( GITS, [ "describe", @@ -258,7 +305,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): "--always", "--long", "--match", - "%s*" % tag_prefix, + f"{tag_prefix}[[:digit:]]*", ], cwd=root, ) @@ -266,16 +313,48 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -292,7 +371,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces @@ -318,26 +397,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -361,23 +441,70 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -404,12 +531,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -426,7 +582,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -446,7 +602,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -466,7 +622,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return { @@ -482,10 +638,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -504,7 +664,7 @@ def render(pieces, style): } -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -524,7 +684,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split("/"): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: return { diff --git a/pyiron_atomistics/atomistics/master/murnaghan.py b/pyiron_atomistics/atomistics/master/murnaghan.py index 0d5aa4271..86a89b3f9 100644 --- a/pyiron_atomistics/atomistics/master/murnaghan.py +++ b/pyiron_atomistics/atomistics/master/murnaghan.py @@ -3,13 +3,18 @@ # Distributed under the terms of "New BSD License", see the LICENSE file. from __future__ import print_function -from typing import Optional, Literal - +from typing import Optional + +from atomistics.shared.thermo.debye import DebyeModel +from atomistics.workflows.evcurve.fit import ( + EnergyVolumeFit, + fitfunction, + get_error, + fit_leastsq_eos, +) +from atomistics.workflows.evcurve.workflow import _strain_axes import matplotlib.pyplot as plt import numpy as np -import scipy.constants -import scipy.integrate -import scipy.optimize as spy from pyiron_atomistics.atomistics.master.parallel import AtomisticParallelMaster from pyiron_atomistics.atomistics.structure.atoms import Atoms, ase_to_pyiron @@ -27,199 +32,21 @@ __date__ = "Sep 1, 2017" -eV_div_A3_to_GPa = ( - 1e21 / scipy.constants.physical_constants["joule-electron volt relationship"][0] -) - - -def _debye_kernel(xi): - return xi**3 / (np.exp(xi) - 1) - - -def debye_integral(x): - return scipy.integrate.quad(_debye_kernel, 0, x)[0] - - -def debye_function(x): - if hasattr(x, "__len__"): - return np.array([3 / xx**3 * debye_integral(xx) for xx in x]) - return 3 / x**3 * debye_integral(x) - - -# https://gitlab.com/ase/ase/blob/master/ase/eos.py -def birchmurnaghan_energy(V, E0, B0, BP, V0): - "BirchMurnaghan equation from PRB 70, 224107" - eta = (V0 / V) ** (1 / 3) - return E0 + 9 * B0 * V0 / 16 * (eta**2 - 1) ** 2 * ( - 6 + BP * (eta**2 - 1) - 4 * eta**2 - ) - - -def vinet_energy(V, E0, B0, BP, V0): - "Vinet equation from PRB 70, 224107" - eta = (V / V0) ** (1 / 3) - return E0 + 2 * B0 * V0 / (BP - 1) ** 2 * ( - 2 - (5 + 3 * BP * (eta - 1) - 3 * eta) * np.exp(-3 * (BP - 1) * (eta - 1) / 2) - ) - - -def murnaghan(V, E0, B0, BP, V0): - "From PRB 28,5480 (1983" - E = E0 + B0 * V / BP * (((V0 / V) ** BP) / (BP - 1) + 1) - V0 * B0 / (BP - 1) - return E - - -def birch(V, E0, B0, BP, V0): - """ - From Intermetallic compounds: Principles and Practice, Vol. I: Principles - Chapter 9 pages 195-210 by M. Mehl. B. Klein, D. Papaconstantopoulos - paper downloaded from Web - - case where n=0 - """ - E = ( - E0 - + 9 / 8 * B0 * V0 * ((V0 / V) ** (2 / 3) - 1) ** 2 - + 9 / 16 * B0 * V0 * (BP - 4) * ((V0 / V) ** (2 / 3) - 1) ** 3 - ) - return E - - -def pouriertarantola(V, E0, B0, BP, V0): - "Pourier-Tarantola equation from PRB 70, 224107" - eta = (V / V0) ** (1 / 3) - squiggle = -3 * np.log(eta) - - E = E0 + B0 * V0 * squiggle**2 / 6 * (3 + squiggle * (BP - 2)) - return E - - -def fitfunction(parameters, vol, fittype="vinet"): - """ - Fit the energy volume curve - - Args: - parameters (list): [E0, B0, BP, V0] list of fit parameters - vol (float/numpy.dnarray): single volume or a vector of volumes as numpy array - fittype (str): on of the following ['birch', 'birchmurnaghan', 'murnaghan', 'pouriertarantola', 'vinet'] - - Returns: - (float/numpy.dnarray): single energy as float or a vector of energies as numpy array - """ - [E0, b0, bp, V0] = parameters - # Unit correction - B0 = b0 / eV_div_A3_to_GPa - BP = bp - V = vol - if fittype.lower() == "birchmurnaghan": - return birchmurnaghan_energy(V, E0, B0, BP, V0) - elif fittype.lower() == "vinet": - return vinet_energy(V, E0, B0, BP, V0) - elif fittype.lower() == "murnaghan": - return murnaghan(V, E0, B0, BP, V0) - elif fittype.lower() == "pouriertarantola": - return pouriertarantola(V, E0, B0, BP, V0) - elif fittype.lower() == "birch": - return birch(V, E0, B0, BP, V0) - else: - raise ValueError - - -def fit_leastsq(p0, datax, datay, fittype="vinet"): - """ - Least square fit - - Args: - p0 (list): [E0, B0, BP, V0] list of fit parameters - datax (float/numpy.dnarray): volumes to fit - datay (float/numpy.dnarray): energies corresponding to the volumes - fittype (str): on of the following ['birch', 'birchmurnaghan', 'murnaghan', 'pouriertarantola', 'vinet'] - - Returns: - list: [E0, B0, BP, V0], [E0_err, B0_err, BP_err, V0_err] - """ - # http://stackoverflow.com/questions/14581358/getting-standard-errors-on-fitted-parameters-using-the-optimize-leastsq-method-i - - errfunc = lambda p, x, y, fittype: fitfunction(p, x, fittype) - y - - pfit, pcov, infodict, errmsg, success = spy.leastsq( - errfunc, p0, args=(datax, datay, fittype), full_output=1, epsfcn=0.0001 - ) - - if (len(datay) > len(p0)) and pcov is not None: - s_sq = (errfunc(pfit, datax, datay, fittype) ** 2).sum() / ( - len(datay) - len(p0) - ) - pcov = pcov * s_sq - else: - pcov = np.inf - - error = [] - for i in range(len(pfit)): - try: - error.append(np.absolute(pcov[i][i]) ** 0.5) - except: - error.append(0.00) - pfit_leastsq = pfit - perr_leastsq = np.array(error) - return pfit_leastsq, perr_leastsq - - -class DebyeModel(object): +class MurnaghanDebyeModel(DebyeModel): """ Calculate Thermodynamic Properties based on the Murnaghan output """ def __init__(self, murnaghan, num_steps=50): self._murnaghan = murnaghan - - # self._atoms_per_cell = len(murnaghan.structure) - self._v_min = None - self._v_max = None - self._num_steps = None - - self._volume = None - self._init_volume() - - self.num_steps = num_steps - self._fit_volume = None - self._debye_T = None - - def _init_volume(self): - vol = self._murnaghan["output/volume"] - self._v_min, self._v_max = np.min(vol), np.max(vol) - - def _set_volume(self): - if self._v_min and self._v_max and self._num_steps: - self._volume = np.linspace(self._v_min, self._v_max, self._num_steps) - self._reset() - # print ('set_volume: ', self._num_steps) - - @property - def num_steps(self): - return self._num_steps - - @num_steps.setter - def num_steps(self, val): - self._num_steps = val - self._set_volume() - - @property - def volume(self): - if self._volume is None: - self._init_volume() - self._set_volume() - return self._volume - - @volume.setter - def volume(self, volume_lst): - self._volume = volume_lst - self._v_min = np.min(volume_lst) - self._v_max = np.max(volume_lst) - self._reset() - - def _reset(self): - self._debye_T = None + fit_dict = self._murnaghan.fit_dict.copy() + fit_dict["volume"] = self._murnaghan["output/volume"] + fit_dict["energy"] = self._murnaghan["output/energy"] + super().__init__( + fit_dict=fit_dict, + masses=self._murnaghan.structure.get_masses(), + num_steps=num_steps, + ) def polynomial(self, poly_fit=None, volumes=None): if poly_fit is None: @@ -230,69 +57,6 @@ def polynomial(self, poly_fit=None, volumes=None): return p_fit(self.volume) return p_fit(volumes) - @property - def debye_temperature(self): - if self._debye_T is not None: - return self._debye_T - - GPaTokBar = 10 - Ang3_to_Bohr3 = ( - scipy.constants.angstrom**3 - / scipy.constants.physical_constants["Bohr radius"][0] ** 3 - ) - convert = 67.48 # conversion factor, Moruzzi Eq. (4) - empirical = 0.617 # empirical factor, Moruzzi Eq. (6) - gamma_low, gamma_high = 1, 2 / 3 # low/high T gamma - - out = self._murnaghan["output"] - V0 = out["equilibrium_volume"] - B0 = out["equilibrium_bulk_modulus"] - Bp = out["equilibrium_b_prime"] - - vol = self.volume - - mass = set(self._murnaghan.structure.get_masses()) - if len(mass) > 1: - raise NotImplementedError( - "Debye temperature only for single species systems!" - ) - mass = list(mass)[0] - - r0 = (3 * V0 * Ang3_to_Bohr3 / (4 * np.pi)) ** (1.0 / 3.0) - debye_zero = empirical * convert * np.sqrt(r0 * B0 * GPaTokBar / mass) - # print('r0, B0, Bp, mass, V0', r0, B0, Bp, mass, V0) - # print('gamma_low, gamma_high: ', gamma_low, gamma_high) - # print('debye_zero, V0: ', debye_zero, V0) - if vol is None: - print("WARNING: vol: ", vol) - - debye_low = debye_zero * (V0 / vol) ** (-gamma_low + 0.5 * (1 + Bp)) - debye_high = debye_zero * (V0 / vol) ** (-gamma_high + 0.5 * (1 + Bp)) - - self._debye_T = (debye_low, debye_high) - return self._debye_T - - def energy_vib(self, T, debye_T=None, low_T_limit=True): - kB = 0.086173422 / 1000 # eV/K - if debye_T is None: - if low_T_limit: - debye_T = self.debye_temperature[0] # low - else: - debye_T = self.debye_temperature[1] # high - if hasattr(debye_T, "__len__"): - val = [ - 9.0 / 8.0 * kB * d_T - + T * kB * (3 * np.log(1 - np.exp(-d_T / T)) - debye_function(d_T / T)) - for d_T in debye_T - ] - val = np.array(val) - else: - val = 9.0 / 8.0 * kB * debye_T + T * kB * ( - 3 * np.log(1 - np.exp(-debye_T / T)) - debye_function(debye_T / T) - ) - atoms_per_cell = len(self._murnaghan.structure) - return atoms_per_cell * val - @property def publication(self): return { @@ -314,22 +78,6 @@ def publication(self): } -def _strain_axes( - structure: Atoms, axes: Literal["x", "y", "z"], volume_strain: float -) -> Atoms: - """ - Strain box along given axes to achieve given *volumetric* strain. - - Returns a copy. - """ - axes = np.array([a in axes for a in ("x", "y", "z")]) - num_axes = sum(axes) - # formula calculates the strain along each axis to achieve the overall volumetric strain - # beware that: (1+e)**x - 1 != e**x - strains = axes * ((1 + volume_strain) ** (1.0 / num_axes) - 1) - return structure.apply_strain(strains, return_box=True) - - class MurnaghanJobGenerator(JobGenerator): @property def parameter_list(self): @@ -361,275 +109,6 @@ def modify_job(self, job, parameter): return job -class EnergyVolumeFit(object): - """ - Fit energy volume curves - - Args: - volume_lst (list/numpy.dnarray): vector of volumes - energy_lst (list/numpy.dnarray): vector of energies - - Attributes: - - .. attribute:: volume_lst - - vector of volumes - - .. attribute:: energy_lst - - vector of energies - - .. attribute:: fit_dict - - dictionary of fit parameters - """ - - def __init__(self, volume_lst=None, energy_lst=None): - self._volume_lst = volume_lst - self._energy_lst = energy_lst - self._fit_dict = None - - @property - def volume_lst(self): - return self._volume_lst - - @volume_lst.setter - def volume_lst(self, vol_lst): - self._volume_lst = vol_lst - - @property - def energy_lst(self): - return self._energy_lst - - @energy_lst.setter - def energy_lst(self, eng_lst): - self._energy_lst = eng_lst - - @property - def fit_dict(self): - return self._fit_dict - - def _get_volume_and_energy_lst(self, volume_lst=None, energy_lst=None): - """ - Internal function to get the vector of volumes and the vector of energies - - Args: - volume_lst (list/numpy.dnarray/None): vector of volumes - energy_lst (list/numpy.dnarray/None): vector of energies - - Returns: - list: vector of volumes and vector of energies - """ - if volume_lst is None: - if self._volume_lst is None: - raise ValueError("Volume list not set.") - volume_lst = self._volume_lst - if energy_lst is None: - if self._energy_lst is None: - raise ValueError("Volume list not set.") - energy_lst = self._energy_lst - return volume_lst, energy_lst - - def fit_eos_general_intern(self, fittype="birchmurnaghan"): - self._fit_dict = self.fit_eos_general( - volume_lst=self._volume_lst, energy_lst=self._energy_lst, fittype=fittype - ) - - def fit_eos_general( - self, volume_lst=None, energy_lst=None, fittype="birchmurnaghan" - ): - """ - Fit on of the equations of state - - Args: - volume_lst (list/numpy.dnarray/None): vector of volumes - energy_lst (list/numpy.dnarray/None): vector of energies - fittype (str): on of the following ['birch', 'birchmurnaghan', 'murnaghan', 'pouriertarantola', 'vinet'] - - Returns: - dict: dictionary with fit results - """ - volume_lst, energy_lst = self._get_volume_and_energy_lst( - volume_lst=volume_lst, energy_lst=energy_lst - ) - fit_dict = {} - pfit_leastsq, perr_leastsq = self._fit_leastsq( - volume_lst=volume_lst, energy_lst=energy_lst, fittype=fittype - ) - fit_dict["fit_type"] = fittype - fit_dict["volume_eq"] = pfit_leastsq[3] - fit_dict["energy_eq"] = pfit_leastsq[0] - fit_dict["bulkmodul_eq"] = pfit_leastsq[1] - fit_dict["b_prime_eq"] = pfit_leastsq[2] - fit_dict["least_square_error"] = perr_leastsq # [e0, b0, bP, v0] - - return fit_dict - - def fit_polynomial(self, volume_lst=None, energy_lst=None, fit_order=3): - """ - Fit a polynomial - - Args: - volume_lst (list/numpy.dnarray/None): vector of volumes - energy_lst (list/numpy.dnarray/None): vector of energies - fit_order (int): Degree of the polynomial - - Returns: - dict: dictionary with fit results - """ - volume_lst, energy_lst = self._get_volume_and_energy_lst( - volume_lst=volume_lst, energy_lst=energy_lst - ) - fit_dict = {} - - # compute a polynomial fit - z = np.polyfit(volume_lst, energy_lst, fit_order) - p_fit = np.poly1d(z) - fit_dict["poly_fit"] = z - - # get equilibrium lattice constant - # search for the local minimum with the lowest energy - p_deriv_1 = np.polyder(p_fit, 1) - roots = np.roots(p_deriv_1) - - # volume_eq_lst = np.array([np.real(r) for r in roots if np.abs(np.imag(r)) < 1e-10]) - volume_eq_lst = np.array( - [ - np.real(r) - for r in roots - if ( - abs(np.imag(r)) < 1e-10 - and r >= min(volume_lst) - and r <= max(volume_lst) - ) - ] - ) - - e_eq_lst = p_fit(volume_eq_lst) - arg = np.argsort(e_eq_lst) - # print ("v_eq:", arg, e_eq_lst) - if len(e_eq_lst) == 0: - return None - e_eq = e_eq_lst[arg][0] - volume_eq = volume_eq_lst[arg][0] - - # get bulk modulus at equ. lattice const. - p_2deriv = np.polyder(p_fit, 2) - p_3deriv = np.polyder(p_fit, 3) - a2 = p_2deriv(volume_eq) - a3 = p_3deriv(volume_eq) - - b_prime = -(volume_eq * a3 / a2 + 1) - - fit_dict["fit_type"] = "polynomial" - fit_dict["fit_order"] = fit_order - fit_dict["volume_eq"] = volume_eq - fit_dict["energy_eq"] = e_eq - fit_dict["bulkmodul_eq"] = eV_div_A3_to_GPa * volume_eq * a2 - fit_dict["b_prime_eq"] = b_prime - fit_dict["least_square_error"] = self.get_error(volume_lst, energy_lst, p_fit) - return fit_dict - - def _fit_leastsq(self, volume_lst, energy_lst, fittype="birchmurnaghan"): - """ - Internal helper function for the least square fit - - Args: - volume_lst (list/numpy.dnarray/None): vector of volumes - energy_lst (list/numpy.dnarray/None): vector of energies - fittype (str): on of the following ['birch', 'birchmurnaghan', 'murnaghan', 'pouriertarantola', 'vinet'] - - Returns: - list: [E0, B0, BP, V0], [E0_err, B0_err, BP_err, V0_err] - """ - vol_lst = np.array(volume_lst).flatten() - eng_lst = np.array(energy_lst).flatten() - a, b, c = np.polyfit(vol_lst, eng_lst, 2) - v0 = -b / (2 * a) - pfit_leastsq, perr_leastsq = fit_leastsq( - [a * v0**2 + b * v0 + c, 2 * a * v0 * eV_div_A3_to_GPa, 4, v0], - vol_lst, - eng_lst, - fittype, - ) - return pfit_leastsq, perr_leastsq # [e0, b0, bP, v0] - - @staticmethod - def get_error(x_lst, y_lst, p_fit): - """ - - Args: - x_lst: - y_lst: - p_fit: - - Returns: - numpy.dnarray - """ - y_fit_lst = np.array(p_fit(x_lst)) - error_lst = (y_lst - y_fit_lst) ** 2 - return np.mean(error_lst) - - def fit_energy(self, volume_lst): - """ - Gives the energy value for the corresponding energy volume fit defined in the fit dictionary. - - Args: - volume_lst: list of volumes - - Returns: - list of energies - - """ - if not self._fit_dict: - return ValueError("parameter 'fit_dict' has to be defined!") - v = volume_lst - e0 = self._fit_dict["energy_eq"] - b0 = self._fit_dict["bulkmodul_eq"] / eV_div_A3_to_GPa - b_p = self._fit_dict["b_prime_eq"] - v0 = self._fit_dict["volume_eq"] - if self._fit_dict["fit_type"] == "birch": - return self.birch(v, e0, b0, b_p, v0) - elif self._fit_dict["fit_type"] == "birchmurnaghan": - return self.birchmurnaghan_energy(v, e0, b0, b_p, v0) - elif self._fit_dict["fit_type"] == "murnaghan": - return self.murnaghan(v, e0, b0, b_p, v0) - elif self._fit_dict["fit_type"] == "pouriertarantola": - return self.pouriertarantola(v, e0, b0, b_p, v0) - else: - return self.vinet_energy(v, e0, b0, b_p, v0) - - @staticmethod - def birchmurnaghan_energy(V, E0, B0, BP, V0): - "BirchMurnaghan equation from PRB 70, 224107" - return birchmurnaghan_energy(V, E0, B0, BP, V0) - - @staticmethod - def vinet_energy(V, E0, B0, BP, V0): - "Vinet equation from PRB 70, 224107" - return vinet_energy(V, E0, B0, BP, V0) - - @staticmethod - def murnaghan(V, E0, B0, BP, V0): - "From PRB 28,5480 (1983" - return murnaghan(V, E0, B0, BP, V0) - - @staticmethod - def birch(V, E0, B0, BP, V0): - """ - From Intermetallic compounds: Principles and Practice, Vol. I: Principles - Chapter 9 pages 195-210 by M. Mehl. B. Klein, D. Papaconstantopoulos - paper downloaded from Web - - case where n=0 - """ - return birch(V, E0, B0, BP, V0) - - @staticmethod - def pouriertarantola(V, E0, B0, BP, V0): - return pouriertarantola(V, E0, B0, BP, V0) - - # ToDo: not all abstract methods implemented class Murnaghan(AtomisticParallelMaster): """ @@ -684,7 +163,7 @@ def __init__(self, project, job_name): "The number of child jobs that are allowed to abort, before the whole job is considered aborted.", ) - self.debye_model = DebyeModel(self) + self.debye_model = None self.fit_module = EnergyVolumeFit() self.fit_dict = None @@ -708,6 +187,12 @@ def convergence_check(self) -> bool: @property def fit(self): + if self.debye_model is None and self.fit_dict is None: + raise ValueError( + "The fit object is only available after fitting the energy volume curve." + ) + elif self.debye_model is None: + self.debye_model = MurnaghanDebyeModel(self) return self.debye_model @property @@ -757,7 +242,7 @@ def _fit_eos_general(self, vol_erg_dic=None, fittype="birchmurnaghan"): return fit_dict def _fit_leastsq(self, volume_lst, energy_lst, fittype="birchmurnaghan"): - return self.fit_module._fit_leastsq( + return fit_leastsq_eos( volume_lst=volume_lst, energy_lst=energy_lst, fittype=fittype ) @@ -922,7 +407,7 @@ def plot( if self.fit_dict is not None: if self.input["fit_type"] == "polynomial": p_fit = np.poly1d(self.fit_dict["poly_fit"]) - least_square_error = self.fit_module.get_error(vol_lst, erg_lst, p_fit) + least_square_error = get_error(vol_lst, erg_lst, p_fit) ax.set_title("Murnaghan: error: " + str(least_square_error)) ax.plot( x_i / normalization, diff --git a/pyiron_atomistics/atomistics/master/phonopy.py b/pyiron_atomistics/atomistics/master/phonopy.py index 20cb16791..3724dc50d 100644 --- a/pyiron_atomistics/atomistics/master/phonopy.py +++ b/pyiron_atomistics/atomistics/master/phonopy.py @@ -10,16 +10,16 @@ import posixpath import scipy.constants from phonopy import Phonopy -from phonopy.structure.atoms import PhonopyAtoms from phonopy.units import VaspToTHz from phonopy.file_IO import write_FORCE_CONSTANTS -from pyiron_atomistics.atomistics.structure.atoms import Atoms +from pyiron_atomistics.atomistics.structure.atoms import ase_to_pyiron from pyiron_atomistics.atomistics.master.parallel import AtomisticParallelMaster from pyiron_atomistics.atomistics.structure.phonopy import ( publication as phonopy_publication, ) from pyiron_base import state, JobGenerator, ImportAlarm, deprecate +import structuretoolkit __author__ = "Jan Janssen, Yury Lysogorskiy" __copyright__ = ( @@ -52,40 +52,6 @@ def __init__(self, temperatures, free_energies, entropy, cv): self.cv = cv -def phonopy_to_atoms(ph_atoms): - """ - Convert Phonopy Atoms to ASE-like Atoms - Args: - ph_atoms: Phonopy Atoms object - - Returns: ASE-like Atoms object - - """ - return Atoms( - symbols=list(ph_atoms.get_chemical_symbols()), - positions=list(ph_atoms.get_positions()), - cell=list(ph_atoms.get_cell()), - pbc=True, - ) - - -def atoms_to_phonopy(atom): - """ - Convert ASE-like Atoms to Phonopy Atoms - Args: - atom: ASE-like Atoms - - Returns: - Phonopy Atoms - - """ - return PhonopyAtoms( - symbols=list(atom.get_chemical_symbols()), - scaled_positions=list(atom.get_scaled_positions()), - cell=list(atom.get_cell()), - ) - - class PhonopyJobGenerator(JobGenerator): @property def parameter_list(self): @@ -98,7 +64,9 @@ def parameter_list(self): return [ [ "{}_{}".format(self._master.ref_job.job_name, ind), - self._restore_magmoms(phonopy_to_atoms(sc)), + self._restore_magmoms( + ase_to_pyiron(structuretoolkit.common.phonopy_to_atoms(sc)) + ), ] for ind, sc in enumerate(supercells) ] @@ -193,7 +161,7 @@ def phonopy_pickling_disabled(self, disable): @property def _phonopy_unit_cell(self): if self.structure is not None: - return atoms_to_phonopy(self.structure) + return structuretoolkit.common.atoms_to_phonopy(self.structure) else: return None diff --git a/pyiron_atomistics/atomistics/structure/atoms.py b/pyiron_atomistics/atomistics/structure/atoms.py index 0fc49d6cf..9b24b481a 100644 --- a/pyiron_atomistics/atomistics/structure/atoms.py +++ b/pyiron_atomistics/atomistics/structure/atoms.py @@ -34,6 +34,7 @@ from pyiron_atomistics.atomistics.structure.periodic_table import ( PeriodicTable, ChemicalElement, + chemical_element_dict_to_hdf, ) from pyiron_base import state, deprecate from collections.abc import Sequence @@ -447,6 +448,54 @@ def copy(self): """ return self.__copy__() + def to_dict(self): + hdf_structure = { + "TYPE": str(type(self)), + "units": self.units, + "dimension": self.dimension, + "positions": self.positions, + "info": self.info, + } + for el in self.species: + if isinstance(el.tags, dict): + if "new_species" not in hdf_structure.keys(): + hdf_structure["new_species"] = {} + hdf_structure["new_species"][el.Abbreviation] = el.to_dict() + hdf_structure["species"] = [el.Abbreviation for el in self.species] + hdf_structure["indices"] = self.indices + + for tag, value in self.arrays.items(): + if tag in ["positions", "numbers", "indices"]: + continue + if "tags" not in hdf_structure.keys(): + hdf_structure["tags"] = {} + hdf_structure["tags"][tag] = value.tolist() + + if self.cell is not None: + # Convert ASE cell object to numpy array before storing + hdf_structure["cell"] = {"cell": np.array(self.cell), "pbc": self.pbc} + + if self.has("initial_magmoms"): + hdf_structure["spins"] = self.spins + # potentials with explicit bonds (TIP3P, harmonic, etc.) + if self.bonds is not None: + hdf_structure["explicit_bonds"] = self.bonds + + if self._high_symmetry_points is not None: + hdf_structure["high_symmetry_points"] = self._high_symmetry_points + + if self._high_symmetry_path is not None: + hdf_structure["high_symmetry_path"] = self._high_symmetry_path + + if self.calc is not None: + calc_dict = self.calc.todict() + calc_dict["label"] = self.calc.label + calc_dict["class"] = ( + self.calc.__class__.__module__ + "." + self.calc.__class__.__name__ + ) + hdf_structure["calculator"] = calc_dict + return hdf_structure + def to_hdf(self, hdf, group_name="structure"): """ Save the object in a HDF5 file @@ -457,53 +506,7 @@ def to_hdf(self, hdf, group_name="structure"): Group name with which the object should be stored. This same name should be used to retrieve the object """ - # import time - with hdf.open(group_name) as hdf_structure: - hdf_structure["TYPE"] = str(type(self)) - for el in self.species: - if isinstance(el.tags, dict): - with hdf_structure.open("new_species") as hdf_species: - el.to_hdf(hdf_species) - hdf_structure["species"] = [el.Abbreviation for el in self.species] - hdf_structure["indices"] = self.indices - - with hdf_structure.open("tags") as hdf_tags: - for tag, value in self.arrays.items(): - if tag in ["positions", "numbers", "indices"]: - continue - hdf_tags[tag] = value.tolist() - hdf_structure["units"] = self.units - hdf_structure["dimension"] = self.dimension - - if self.cell is not None: - with hdf_structure.open("cell") as hdf_cell: - # Convert ASE cell object to numpy array before storing - hdf_cell["cell"] = np.array(self.cell) - hdf_cell["pbc"] = self.pbc - - # hdf_structure["coordinates"] = self.positions # "Atomic coordinates" - hdf_structure["positions"] = self.positions # "Atomic coordinates" - if self.has("initial_magmoms"): - hdf_structure["spins"] = self.spins - # potentials with explicit bonds (TIP3P, harmonic, etc.) - if self.bonds is not None: - hdf_structure["explicit_bonds"] = self.bonds - - if self._high_symmetry_points is not None: - hdf_structure["high_symmetry_points"] = self._high_symmetry_points - - if self._high_symmetry_path is not None: - hdf_structure["high_symmetry_path"] = self._high_symmetry_path - - hdf_structure["info"] = self.info - - if self.calc is not None: - calc_dict = self.calc.todict() - calc_dict["label"] = self.calc.label - calc_dict["class"] = ( - self.calc.__class__.__module__ + "." + self.calc.__class__.__name__ - ) - hdf_structure["calculator"] = calc_dict + structure_dict_to_hdf(data_dict=self.to_dict(), hdf=hdf, group_name=group_name) def from_hdf(self, hdf, group_name="structure"): """ @@ -3439,3 +3442,26 @@ def __setitem__(self, key, value): ) for i, el in enumerate(replace_elements): self._structure[index_array[i]] = el + + +def structure_dict_to_hdf(data_dict, hdf, group_name="structure"): + with hdf.open(group_name) as hdf_structure: + for k, v in data_dict.items(): + if k not in ["new_species", "cell", "tags"]: + hdf_structure[k] = v + + if "new_species" in data_dict.keys(): + for el, el_dict in data_dict["new_species"].items(): + chemical_element_dict_to_hdf( + data_dict=el_dict, hdf=hdf_structure, group_name="new_species/" + el + ) + + dict_group_to_hdf(data_dict=data_dict, hdf=hdf_structure, group="tags") + dict_group_to_hdf(data_dict=data_dict, hdf=hdf_structure, group="cell") + + +def dict_group_to_hdf(data_dict, hdf, group): + if group in data_dict.keys(): + with hdf.open(group) as hdf_tags: + for k, v in data_dict[group].items(): + hdf_tags[k] = v diff --git a/pyiron_atomistics/atomistics/structure/factory.py b/pyiron_atomistics/atomistics/structure/factory.py index 085d97932..1850580d4 100644 --- a/pyiron_atomistics/atomistics/structure/factory.py +++ b/pyiron_atomistics/atomistics/structure/factory.py @@ -29,6 +29,7 @@ high_index_surface, get_high_index_surface_info, ) +from structuretoolkit.common import pymatgen_read_from_file from pyiron_atomistics.atomistics.structure.factories.ase import AseFactory from pyiron_atomistics.atomistics.structure.factories.atomsk import ( AtomskFactory, @@ -49,7 +50,6 @@ pymatgen_to_pyiron, ovito_to_pyiron, ) -from pymatgen.core import Structure from pyiron_atomistics.atomistics.structure.periodic_table import PeriodicTable from pyiron_base import state, PyironFactory, deprecate import types @@ -119,7 +119,7 @@ def read(self, *args, **kwargs): read.__doc__ = AseFactory.read.__doc__ def read_using_pymatgen(self, *args, **kwargs): - return pymatgen_to_pyiron(Structure.from_file(*args, **kwargs)) + return ase_to_pyiron(pymatgen_read_from_file(*args, **kwargs)) def read_using_ase(self, *args, **kwargs): return self.ase.read(*args, **kwargs) diff --git a/pyiron_atomistics/atomistics/structure/periodic_table.py b/pyiron_atomistics/atomistics/structure/periodic_table.py index e11a9b0fa..7c2529e23 100644 --- a/pyiron_atomistics/atomistics/structure/periodic_table.py +++ b/pyiron_atomistics/atomistics/structure/periodic_table.py @@ -3,9 +3,9 @@ # Distributed under the terms of "New BSD License", see the LICENSE file. from __future__ import print_function, unicode_literals +import pkgutil +import io import numpy as np -import os -from pyiron_base import state import mendeleev import pandas from functools import lru_cache @@ -154,22 +154,25 @@ def add_tags(self, tag_dic): """ (self.sub["tags"]).update(tag_dic) + def to_dict(self): + hdf_el = {} + # TODO: save all parameters that are different from the parent (e.g. modified mass) + if self.Parent is not None: + self._dataset = {"Parameter": ["Parent"], "Value": [self.Parent]} + hdf_el["elementData"] = self._dataset + # "Dictionary of element tag static" + hdf_el["tagData"] = {key: self.tags[key] for key in self.tags.keys()} + return hdf_el + def to_hdf(self, hdf): """ saves the element with his parameters into his hdf5 job file Args: hdf (Hdfio): Hdfio object which will be used """ - with hdf.open(self.Abbreviation) as hdf_el: # "Symbol of the chemical element" - # TODO: save all parameters that are different from the parent (e.g. modified mass) - if self.Parent is not None: - self._dataset = {"Parameter": ["Parent"], "Value": [self.Parent]} - hdf_el["elementData"] = self._dataset - with hdf_el.open( - "tagData" - ) as hdf_tag: # "Dictionary of element tag static" - for key in self.tags.keys(): - hdf_tag[key] = self.tags[key] + chemical_element_dict_to_hdf( + data_dict=self.to_dict(), hdf=hdf, group_name=self.Abbreviation + ) def from_hdf(self, hdf): """ @@ -201,7 +204,7 @@ def from_hdf(self, hdf): self.sub["tags"] = tag_dic -class PeriodicTable(object): +class PeriodicTable: """ An Object which stores an elementary table which can be modified for the current session """ @@ -393,27 +396,12 @@ def _get_periodic_table_df(file_name): """ if not file_name: - for resource_path in state.settings.resource_paths: - if os.path.exists(os.path.join(resource_path, "atomistics")): - resource_path = os.path.join(resource_path, "atomistics") - for path, folder_lst, file_lst in os.walk(resource_path): - for periodic_table_file_name in {"periodic_table.csv"}: - if ( - periodic_table_file_name in file_lst - and periodic_table_file_name.endswith(".csv") - ): - return pandas.read_csv( - os.path.join(path, periodic_table_file_name), - index_col=0, - ) - elif ( - periodic_table_file_name in file_lst - and periodic_table_file_name.endswith(".h5") - ): - return pandas.read_hdf( - os.path.join(path, periodic_table_file_name), mode="r" - ) - raise ValueError("Was not able to locate a periodic table. ") + return pandas.read_csv( + io.BytesIO( + pkgutil.get_data("pyiron_atomistics", "data/periodic_table.csv") + ), + index_col=0, + ) else: if file_name.endswith(".h5"): return pandas.read_hdf(file_name, mode="r") @@ -424,3 +412,13 @@ def _get_periodic_table_df(file_name): + file_name + " supported file formats are csv, h5." ) + + +def chemical_element_dict_to_hdf(data_dict, hdf, group_name): + with hdf.open(group_name) as hdf_el: + if "elementData" in data_dict.keys(): + hdf_el["elementData"] = data_dict["elementData"] + with hdf_el.open("tagData") as hdf_tag: + if "tagData" in data_dict.keys(): + for k, v in data_dict["tagData"].items(): + hdf_tag[k] = v diff --git a/pyiron_atomistics/atomistics/thermodynamics/thermo_bulk.py b/pyiron_atomistics/atomistics/thermodynamics/thermo_bulk.py index ce116cb0b..5c59b23d9 100644 --- a/pyiron_atomistics/atomistics/thermodynamics/thermo_bulk.py +++ b/pyiron_atomistics/atomistics/thermodynamics/thermo_bulk.py @@ -2,10 +2,7 @@ # Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department # Distributed under the terms of "New BSD License", see the LICENSE file. -from __future__ import print_function - -from copy import copy -import numpy as np +from atomistics.shared.thermo.thermo import ThermoBulk as AtomisticsThermoBulk __author__ = "Joerg Neugebauer, Jan Janssen" __copyright__ = ( @@ -19,7 +16,7 @@ __date__ = "Sep 1, 2017" -class ThermoBulk(object): +class ThermoBulk(AtomisticsThermoBulk): """ Class should provide all tools to compute bulk thermodynamic quantities. Central quantity is the Free Energy F(V,T). ToDo: Make it a (light weight) pyiron object (introduce a new tool rather than job object). @@ -30,485 +27,9 @@ class ThermoBulk(object): """ - eV_to_J_per_mol = 1.60217662e-19 * 6.022e23 - kB = 1 / 8.6173303e-5 - def __init__(self, project=None, name=None): # only for compatibility with pyiron objects self._project = project self._name = name - self._volumes = None - self._temperatures = None - self._energies = None - self._entropy = None - self._pressure = None - self._num_atoms = None - - self._fit_order = 3 - - def copy(self): - """ - - Returns: - - """ - cls = self.__class__ - result = cls.__new__(cls) - result.__init__() - result.__dict__["_volumes"] = copy(self._volumes) - result.__dict__["_temperatures"] = copy(self._temperatures) - result.__dict__["_energies"] = copy(self._energies) - result.__dict__["_fit_order"] = self._fit_order - return result - - def _reset_energy(self): - """ - - Returns: - - """ - if self._volumes is not None: - if self._temperatures is not None: - self._energies = np.zeros((len(self._temperatures), len(self._volumes))) - # self.energies = 0 - - @property - def num_atoms(self): - """ - - Returns: - - """ - if self._num_atoms is None: - return 1 # normalize per cell if number of atoms unknown - return self._num_atoms - - @num_atoms.setter - def num_atoms(self, num): - """ - - Args: - num: - - Returns: - - """ - self._num_atoms = num - - @property - def _coeff(self): - """ - - Returns: - - """ - return np.polyfit(self._volumes, self._energies.T, deg=self._fit_order) - - @property - def temperatures(self): - """ - - Returns: - - """ - return self._temperatures - - @property - def _d_temp(self): - """ - - Returns: - - """ - return self.temperatures[1] - self.temperatures[0] - - @property - def _d_vol(self): - """ - - Returns: - - """ - return self.volumes[1] - self.volumes[0] - - @temperatures.setter - def temperatures(self, temp_lst): - """ - - Args: - temp_lst: - - Returns: - - """ - if not hasattr(temp_lst, "__len__"): - raise ValueError("Requires list as input parameter") - len_temp = -1 - if self._temperatures is not None: - len_temp = len(self._temperatures) - self._temperatures = np.array(temp_lst) - if len(temp_lst) != len_temp: - self._reset_energy() - - @property - def volumes(self): - """ - - Returns: - - """ - return self._volumes - - @volumes.setter - def volumes(self, volume_lst): - """ - - Args: - volume_lst: - - Returns: - - """ - if not hasattr(volume_lst, "__len__"): - raise ValueError("Requires list as input parameter") - len_vol = -1 - if self._volumes is not None: - len_vol = len(self._volumes) - self._volumes = np.array(volume_lst) - if len(volume_lst) != len_vol: - self._reset_energy() - - @property - def entropy(self): - """ - - Returns: - - """ - if self._entropy is None: - self._compute_thermo() - return self._entropy - - @property - def pressure(self): - """ - - Returns: - - """ - if self._pressure is None: - self._compute_thermo() - return self._pressure - - @property - def energies(self): - """ - - Returns: - - """ - return self._energies - - @energies.setter - def energies(self, erg_lst): - """ - - Args: - erg_lst: - - Returns: - - """ - if np.ndim(erg_lst) == 2: - self._energies = erg_lst - elif np.ndim(erg_lst) == 1: - if len(erg_lst) == len(self.volumes): - self._energies = np.tile(erg_lst, (len(self.temperatures), 1)) - else: - raise ValueError() - else: - self._energies = ( - np.ones((len(self.volumes), len(self.temperatures))) * erg_lst - ) - - def set_temperatures( - self, temperature_min=0, temperature_max=1500, temperature_steps=50 - ): - """ - - Args: - temperature_min: - temperature_max: - temperature_steps: - - Returns: - - """ - self.temperatures = np.linspace( - temperature_min, temperature_max, temperature_steps - ) - - def set_volumes(self, volume_min, volume_max=None, volume_steps=10): - """ - - Args: - volume_min: - volume_max: - volume_steps: - - Returns: - - """ - if volume_max is None: - volume_max = 1.1 * volume_min - self.volumes = np.linspace(volume_min, volume_max, volume_steps) - - def meshgrid(self): - """ - - Returns: - - """ - return np.meshgrid(self.volumes, self.temperatures) - - def get_minimum_energy_path(self, pressure=None): - """ - - Args: - pressure: - - Returns: - - """ - if pressure is not None: - raise NotImplemented() - v_min_lst = [] - for c in self._coeff.T: - v_min = np.roots(np.polyder(c, 1)) - p_der2 = np.polyder(c, 2) - p_val2 = np.polyval(p_der2, v_min) - v_m_lst = v_min[p_val2 > 0] - if len(v_m_lst) > 0: - v_min_lst.append(v_m_lst[0]) - else: - v_min_lst.append(np.nan) - return np.array(v_min_lst) - - def get_free_energy(self, vol, pressure=None): - """ - - Args: - vol: - pressure: - - Returns: - - """ - if not pressure: - return np.polyval(self._coeff, vol) - else: - raise NotImplementedError() - - def interpolate_volume(self, volumes, fit_order=None): - """ - - Args: - volumes: - fit_order: - - Returns: - - """ - if fit_order is not None: - self._fit_order = fit_order - new = self.copy() - new.volumes = volumes - new.energies = np.array([np.polyval(self._coeff, v) for v in volumes]).T - return new - - def _compute_thermo(self): - """ - - Returns: - - """ - self._entropy, self._pressure = np.gradient( - -self.energies, self._d_temp, self._d_vol - ) - - def get_free_energy_p(self): - """ - - Returns: - - """ - coeff = np.polyfit(self._volumes, self.energies.T, deg=self._fit_order) - return np.polyval(coeff, self.get_minimum_energy_path()) - - def get_entropy_p(self): - """ - - Returns: - - """ - s_coeff = np.polyfit(self._volumes, self.entropy.T, deg=self._fit_order) - return np.polyval(s_coeff, self.get_minimum_energy_path()) - - def get_entropy_v(self): - """ - - Returns: - - """ - eq_volume = self.volumes[0] - s_coeff = np.polyfit(self.volumes, self.entropy.T, deg=self._fit_order) - const_v = eq_volume * np.ones(len(s_coeff.T)) - return np.polyval(s_coeff, const_v) - - def plot_free_energy(self): - """ - - Returns: - - """ - try: - import pylab as plt - except ImportError: - import matplotlib.pyplot as plt - plt.plot(self.temperatures, self.get_free_energy_p() / self.num_atoms) - plt.xlabel("Temperature [K]") - plt.ylabel("Free energy [eV]") - - def plot_entropy(self): - """ - - Returns: - - """ - try: - import pylab as plt - except ImportError: - import matplotlib.pyplot as plt - plt.plot( - self.temperatures, - self.eV_to_J_per_mol / self.num_atoms * self.get_entropy_p(), - label="S$_p$", - ) - plt.plot( - self.temperatures, - self.eV_to_J_per_mol / self.num_atoms * self.get_entropy_v(), - label="S$_V$", - ) - plt.legend() - plt.xlabel("Temperature [K]") - plt.ylabel("Entropy [J K$^{-1}$ mol-atoms$^{-1}$]") - - def plot_heat_capacity(self, to_kB=True): - """ - - Args: - to_kB: - - Returns: - - """ - try: - import pylab as plt - except ImportError: - import matplotlib.pyplot as plt - if to_kB: - units = self.kB / self.num_atoms - plt.ylabel("Heat capacity [kB]") - else: - units = self.eV_to_J_per_mol - plt.ylabel("Heat capacity [J K$^{-1}$ mol-atoms$^{-1}$]") - temps = self.temperatures[:-2] - c_p = temps * np.gradient(self.get_entropy_p(), self._d_temp)[:-2] - c_v = temps * np.gradient(self.get_entropy_v(), self._d_temp)[:-2] - plt.plot(temps, units * c_p, label="c$_p$") - plt.plot(temps, units * c_v, label="c$_v$") - plt.legend(loc="lower right") - plt.xlabel("Temperature [K]") - - def contour_pressure(self): - """ - - Returns: - - """ - try: - import pylab as plt - except ImportError: - import matplotlib.pyplot as plt - x, y = self.meshgrid() - p_coeff = np.polyfit(self.volumes, self.pressure.T, deg=self._fit_order) - p_grid = np.array([np.polyval(p_coeff, v) for v in self._volumes]).T - plt.contourf(x, y, p_grid) - plt.plot(self.get_minimum_energy_path(), self.temperatures) - plt.xlabel("Volume [$\AA^3$]") - plt.ylabel("Temperature [K]") - - def contour_entropy(self): - """ - - Returns: - - """ - try: - import pylab as plt - except ImportError: - import matplotlib.pyplot as plt - s_coeff = np.polyfit(self.volumes, self.entropy.T, deg=self._fit_order) - s_grid = np.array([np.polyval(s_coeff, v) for v in self.volumes]).T - x, y = self.meshgrid() - plt.contourf(x, y, s_grid) - plt.plot(self.get_minimum_energy_path(), self.temperatures) - plt.xlabel("Volume [$\AA^3$]") - plt.ylabel("Temperature [K]") - - def plot_contourf(self, ax=None, show_min_erg_path=False): - """ - - Args: - ax: - show_min_erg_path: - - Returns: - - """ - try: - import pylab as plt - except ImportError: - import matplotlib.pyplot as plt - x, y = self.meshgrid() - if ax is None: - fig, ax = plt.subplots(1, 1) - ax.contourf(x, y, self.energies) - if show_min_erg_path: - plt.plot(self.get_minimum_energy_path(), self.temperatures, "w--") - plt.xlabel("Volume [$\AA^3$]") - plt.ylabel("Temperature [K]") - return ax - - def plot_min_energy_path(self, *args, ax=None, **qwargs): - """ - - Args: - *args: - ax: - **qwargs: - - Returns: - - """ - try: - import pylab as plt - except ImportError: - import matplotlib.pyplot as plt - if ax is None: - fig, ax = plt.subplots(1, 1) - ax.xlabel("Volume [$\AA^3$]") - ax.ylabel("Temperature [K]") - ax.plot(self.get_minimum_energy_path(), self.temperatures, *args, **qwargs) - return ax + super().__init__() diff --git a/pyiron_atomistics/calphy/job.py b/pyiron_atomistics/calphy/job.py index 75b2b6c65..b6eb93355 100644 --- a/pyiron_atomistics/calphy/job.py +++ b/pyiron_atomistics/calphy/job.py @@ -533,6 +533,7 @@ def calc_mode_fe( self.input.temperature = temperature self.input.pressure = pressure + self.input.npt = pressure is not None self.input.reference_phase = reference_phase self.input.n_equilibration_steps = n_equilibration_steps self.input.n_switching_steps = n_switching_steps @@ -566,6 +567,7 @@ def calc_mode_ts( self.input.temperature = temperature self.input.pressure = pressure + self.input.npt = pressure is not None self.input.reference_phase = reference_phase self.input.n_equilibration_steps = n_equilibration_steps self.input.n_switching_steps = n_switching_steps @@ -596,6 +598,7 @@ def calc_mode_alchemy( raise ValueError("provide a temperature") self.input.temperature = temperature self.input.pressure = pressure + self.input.npt = pressure is not None self.input.reference_phase = reference_phase self.input.n_equilibration_steps = n_equilibration_steps self.input.n_switching_steps = n_switching_steps @@ -629,6 +632,7 @@ def calc_mode_pscale( self.input.temperature = temperature self.input.pressure = pressure + self.input.npt = True self.input.reference_phase = reference_phase self.input.n_equilibration_steps = n_equilibration_steps self.input.n_switching_steps = n_switching_steps @@ -659,6 +663,7 @@ def calc_free_energy( raise ValueError("provide a temperature") self.input.temperature = temperature self.input.pressure = pressure + self.input.npt = pressure is not None self.input.reference_phase = reference_phase self.input.n_equilibration_steps = n_equilibration_steps self.input.n_switching_steps = n_switching_steps diff --git a/tests/static/atomistics/periodic_table.csv b/pyiron_atomistics/data/periodic_table.csv similarity index 100% rename from tests/static/atomistics/periodic_table.csv rename to pyiron_atomistics/data/periodic_table.csv diff --git a/pyiron_atomistics/dft/master/murnaghan_dft.py b/pyiron_atomistics/dft/master/murnaghan_dft.py index 3316d030f..4c7bc8df3 100644 --- a/pyiron_atomistics/dft/master/murnaghan_dft.py +++ b/pyiron_atomistics/dft/master/murnaghan_dft.py @@ -3,7 +3,7 @@ # Distributed under the terms of "New BSD License", see the LICENSE file. from __future__ import print_function -from pyiron_atomistics.atomistics.master.murnaghan import Murnaghan, DebyeModel +from pyiron_atomistics.atomistics.master.murnaghan import Murnaghan, MurnaghanDebyeModel from pyiron_base import deprecate __author__ = "Joerg Neugebauer, Jan Janssen" diff --git a/pyiron_atomistics/dft/waves/electronic.py b/pyiron_atomistics/dft/waves/electronic.py index 97de8747d..2f9ac59ba 100644 --- a/pyiron_atomistics/dft/waves/electronic.py +++ b/pyiron_atomistics/dft/waves/electronic.py @@ -6,7 +6,11 @@ import numpy as np -from pyiron_atomistics.atomistics.structure.atoms import Atoms +from pyiron_atomistics.atomistics.structure.atoms import ( + Atoms, + structure_dict_to_hdf, + dict_group_to_hdf, +) from pyiron_atomistics.dft.waves.dos import Dos __author__ = "Sudarsan Surendralal" @@ -481,24 +485,33 @@ def to_hdf(self, hdf, group_name="electronic_structure"): hdf: Path to the hdf5 file/group in the file group_name: Name of the group under which the attributes are o be stored """ - with hdf.open(group_name) as h_es: - h_es["TYPE"] = str(type(self)) - if self.structure is not None: - self.structure.to_hdf(h_es) - h_es["k_points"] = self.kpoint_list - h_es["k_weights"] = self.kpoint_weights - h_es["eig_matrix"] = self.eigenvalue_matrix - h_es["occ_matrix"] = self.occupancy_matrix - if self.efermi is not None: - h_es["efermi"] = self.efermi - with h_es.open("dos") as h_dos: - h_dos["energies"] = self.dos_energies - h_dos["tot_densities"] = self.dos_densities - h_dos["int_densities"] = self.dos_idensities - if self.grand_dos_matrix is not None: - h_dos["grand_dos_matrix"] = self.grand_dos_matrix - if self.resolved_densities is not None: - h_dos["resolved_densities"] = self.resolved_densities + electronic_structure_dict_to_hdf( + data_dict=self.to_dict(), hdf=hdf, group_name=group_name + ) + + def to_dict(self): + h_es = { + "TYPE": str(type(self)), + "k_points": self.kpoint_list, + "k_weights": self.kpoint_weights, + "eig_matrix": self.eigenvalue_matrix, + "occ_matrix": self.occupancy_matrix, + } + if self.structure is not None: + h_es["structure"] = self.structure.to_dict() + if self.efermi is not None: + h_es["efermi"] = self.efermi + + h_es["dos"] = { + "energies": self.dos_energies, + "tot_densities": self.dos_densities, + "int_densities": self.dos_idensities, + } + if self.grand_dos_matrix is not None: + h_es["dos"]["grand_dos_matrix"] = self.grand_dos_matrix + if self.resolved_densities is not None: + h_es["dos"]["resolved_densities"] = self.resolved_densities + return h_es def from_hdf(self, hdf, group_name="electronic_structure"): """ @@ -841,3 +854,15 @@ def resolved_dos_matrix(self): @resolved_dos_matrix.setter def resolved_dos_matrix(self, val): self._resolved_dos_matrix = val + + +def electronic_structure_dict_to_hdf(data_dict, hdf, group_name): + with hdf.open(group_name) as h_es: + for k, v in data_dict.items(): + if k not in ["structure", "dos"]: + h_es[k] = v + + if "structure" in data_dict.keys(): + structure_dict_to_hdf(data_dict=data_dict["structure"], hdf=h_es) + + dict_group_to_hdf(data_dict=data_dict, hdf=h_es, group="dos") diff --git a/pyiron_atomistics/lammps/base.py b/pyiron_atomistics/lammps/base.py index 56c9150d9..9b3738eaa 100644 --- a/pyiron_atomistics/lammps/base.py +++ b/pyiron_atomistics/lammps/base.py @@ -25,13 +25,10 @@ UnfoldingPrism, structure_to_lammps, ) -from pyiron_atomistics.lammps.units import UnitConverter, LAMMPS_UNIT_CONVERSIONS +from pyiron_atomistics.lammps.units import LAMMPS_UNIT_CONVERSIONS from pyiron_atomistics.lammps.output import ( - collect_output_log, - collect_h5md_file, - collect_dump_file, - collect_errors, remap_indices, + parse_lammps_output, ) __author__ = "Joerg Neugebauer, Sudarsan Surendralal, Jan Janssen" @@ -386,20 +383,7 @@ def write_input(self): structure=self.structure, cutoff_radius=self.cutoff_radius ) lmp_structure.write_file(file_name="structure.inp", cwd=self.working_directory) - version_int_lst = self._get_executable_version_number() update_input_hdf5 = False - if ( - version_int_lst is not None - and "dump_modify" in self.input.control._dataset["Parameter"] - and ( - version_int_lst[0] < 2016 - or (version_int_lst[0] == 2016 and version_int_lst[1] < 11) - ) - ): - self.input.control["dump_modify"] = self.input.control[ - "dump_modify" - ].replace(" line ", " ") - update_input_hdf5 = True if not all(self.structure.pbc): self.input.control["boundary"] = " ".join( ["p" if coord else "f" for coord in self.structure.pbc] @@ -416,25 +400,6 @@ def write_input(self): ) self.input.potential.copy_pot_files(self.working_directory) - def _get_executable_version_number(self): - """ - Get the version of the executable - - Returns: - list: List of integers defining the version number - """ - if self.executable.version: - return [ - l - for l in [ - [int(i) for i in sv.split(".") if i.isdigit()] - for sv in self.executable.version.split("/")[-1].split("_") - ] - if len(l) > 0 - ][0] - else: - return None - @property def publication(self): return { @@ -454,6 +419,30 @@ def publication(self): } } + def collect_output_parser( + self, + cwd, + dump_h5_file_name="dump.h5", + dump_out_file_name="dump.out", + log_lammps_file_name="log.lammps", + ): + # Parse output files + return parse_lammps_output( + dump_h5_full_file_name=self.job_file_name( + file_name=dump_h5_file_name, cwd=cwd + ), + dump_out_full_file_name=self.job_file_name( + file_name=dump_out_file_name, cwd=cwd + ), + log_lammps_full_file_name=self.job_file_name( + file_name=log_lammps_file_name, cwd=cwd + ), + prism=self._prism, + structure=self.structure, + potential_elements=self.input.potential.get_element_lst(), + units=self.units, + ) + def collect_output(self): """ @@ -461,13 +450,22 @@ def collect_output(self): """ self.input.from_hdf(self._hdf5) - if os.path.isfile( - self.job_file_name(file_name="dump.h5", cwd=self.working_directory) - ): - self.collect_h5md_file(file_name="dump.h5", cwd=self.working_directory) - else: - self.collect_dump_file(file_name="dump.out", cwd=self.working_directory) - self.collect_output_log(file_name="log.lammps", cwd=self.working_directory) + hdf_dict = self.collect_output_parser( + cwd=self.working_directory, + dump_h5_file_name="dump.h5", + dump_out_file_name="dump.out", + log_lammps_file_name="log.lammps", + ) + + # Write to hdf + with self.project_hdf5.open("output/generic") as hdf_output: + for k, v in hdf_dict["generic"].items(): + hdf_output[k] = v + + with self.project_hdf5.open("output/lammps") as hdf_output: + for k, v in hdf_dict["lammps"].items(): + hdf_output[k] = v + if len(self.output.cells) > 0: final_structure = self.get_structure(iteration_step=-1) if final_structure is not None: @@ -498,34 +496,6 @@ def collect_logfiles(self): """ return - # TODO: make rotation of all vectors back to the original as in self.collect_dump_file - def collect_h5md_file(self, file_name="dump.h5", cwd=None): - """ - - Args: - file_name: - cwd: - - Returns: - - """ - uc = UnitConverter(self.units) - forces, positions, steps, cell = collect_h5md_file( - file_name=self.job_file_name(file_name=file_name, cwd=cwd), - prism=self._prism, - ) - with self.project_hdf5.open("output/generic") as h5_file: - h5_file["forces"] = uc.convert_array_to_pyiron_units( - np.array(forces), "forces" - ) - h5_file["positions"] = uc.convert_array_to_pyiron_units( - np.array(positions), "positions" - ) - h5_file["steps"] = uc.convert_array_to_pyiron_units( - np.array(steps), "steps" - ) - h5_file["cells"] = uc.convert_array_to_pyiron_units(np.array(cell), "cells") - def remap_indices(self, lammps_indices): """ Give the Lammps-dumped indices, re-maps these back onto the structure's indices to preserve the species. @@ -547,46 +517,6 @@ def remap_indices(self, lammps_indices): structure=self.structure, ) - def collect_output_log(self, file_name="log.lammps", cwd=None): - """ - general purpose routine to extract static from a lammps log file - - Args: - file_name: - cwd: - - Returns: - - """ - file_name = self.job_file_name(file_name=file_name, cwd=cwd) - collect_errors(file_name=file_name) - if os.path.exists(file_name): - generic_keys_lst, pressure_dict, df = collect_output_log( - file_name=file_name, - prism=self._prism, - ) - uc = UnitConverter(self.units) - with self.project_hdf5.open("output/generic") as hdf_output: - # This is a hack for backward comparability - for k, v in df.items(): - if k in generic_keys_lst: - hdf_output[k] = uc.convert_array_to_pyiron_units( - np.array(v), label=k - ) - # Store pressures as numpy arrays - for key, val in pressure_dict.items(): - hdf_output[key] = uc.convert_array_to_pyiron_units(val, label=key) - - with self.project_hdf5.open("output/lammps") as hdf_output: - # This is a hack for backward comparability - for k, v in df.items(): - if k not in generic_keys_lst: - hdf_output[k] = uc.convert_array_to_pyiron_units( - np.array(v), label=k - ) - else: - warnings.warn("LAMMPS warning: No log.lammps output file found.") - def calc_minimize( self, ionic_energy_tolerance=0.0, @@ -861,47 +791,6 @@ def read_restart_file(self, filename="restart.out"): ["dimension", "read_data", "boundary", "atom_style", "velocity"] ) - def collect_dump_file(self, file_name="dump.out", cwd=None): - """ - general purpose routine to extract static from a lammps dump file - - Args: - file_name: - cwd: - - Returns: - - """ - uc = UnitConverter(self.units) - file_name = self.job_file_name(file_name=file_name, cwd=cwd) - - if os.path.exists(file_name): - dump_dict = collect_dump_file( - file_name=file_name, - prism=self._prism, - structure=self.structure, - potential_elements=self.input.potential.get_element_lst(), - ) - # Write to hdf - with self.project_hdf5.open("output/generic") as hdf_output: - for k, v in dump_dict.pop("computes").items(): - hdf_output[k] = uc.convert_array_to_pyiron_units( - np.array(v), label=k - ) - - hdf_output["steps"] = uc.convert_array_to_pyiron_units( - np.array(dump_dict.pop("steps"), dtype=int), label="steps" - ) - - for k, v in dump_dict.items(): - if len(v) > 0: - hdf_output[k] = uc.convert_array_to_pyiron_units( - np.array(v), label=k - ) - - else: - warnings.warn("LAMMPS warning: No dump.out output file found.") - # Outdated functions: def set_potential(self, file_name): """ @@ -983,7 +872,7 @@ def _get_lammps_structure(self, structure=None, cutoff_radius=None): def _set_selective_dynamics(self): if "selective_dynamics" in self.structure.arrays.keys(): - sel_dyn = np.logical_not(self.structure.selective_dynamics) + sel_dyn = np.logical_not(np.stack(self.structure.selective_dynamics)) # Enter loop only if constraints present if len(np.argwhere(np.any(sel_dyn, axis=1)).flatten()) != 0: all_indices = np.arange(len(self.structure), dtype=int) diff --git a/pyiron_atomistics/lammps/output.py b/pyiron_atomistics/lammps/output.py index 5bac69f6c..02b3e6ffb 100644 --- a/pyiron_atomistics/lammps/output.py +++ b/pyiron_atomistics/lammps/output.py @@ -1,11 +1,20 @@ +from __future__ import annotations + from dataclasses import dataclass, field, asdict -from typing import List, Dict +from typing import Dict, List, Tuple, TYPE_CHECKING, Union import h5py from io import StringIO import numpy as np +import os import pandas as pd -from pyiron_base import extract_data_from_file import warnings +from pyiron_base import extract_data_from_file +from pyiron_atomistics.lammps.units import UnitConverter + + +if TYPE_CHECKING: + from pyiron_atomistics.atomistics.structure.atoms import Atoms + from pyiron_atomistics.lammps.structure import UnfoldingPrism @dataclass @@ -24,154 +33,84 @@ class DumpData: computes: Dict = field(default_factory=lambda: {}) -def collect_output_log(file_name, prism): - """ - general purpose routine to extract static from a lammps log file +def parse_lammps_output( + dump_h5_full_file_name: str, + dump_out_full_file_name: str, + log_lammps_full_file_name: str, + prism: UnfoldingPrism, + structure: Atoms, + potential_elements: Union[np.ndarray, List], + units: str, +) -> Dict: + dump_dict = _parse_dump( + dump_h5_full_file_name, + dump_out_full_file_name, + prism, + structure, + potential_elements, + ) - Args: - file_name: - prism: + generic_keys_lst, pressure_dict, df = _parse_log(log_lammps_full_file_name, prism) - Returns: + convert_units = UnitConverter(units).convert_array_to_pyiron_units - """ - with open(file_name, "r") as f: - read_thermo = False - thermo_lines = "" - dfs = [] - for l in f: - l = l.lstrip() + hdf_output = {"generic": {}, "lammps": {}} + hdf_generic = hdf_output["generic"] + hdf_lammps = hdf_output["lammps"] - if read_thermo: - if l.startswith("Loop"): - read_thermo = False - continue - thermo_lines += l + if "computes" in dump_dict.keys(): + for k, v in dump_dict.pop("computes").items(): + hdf_generic[k] = convert_units(np.array(v), label=k) - if l.startswith("Step"): - read_thermo = True - thermo_lines += l - - dfs.append( - pd.read_csv( - StringIO(thermo_lines), - sep="\s+", - engine="c", - ) - ) + hdf_generic["steps"] = convert_units( + np.array(dump_dict.pop("steps"), dtype=int), label="steps" + ) - if len(dfs) == 1: - df = dfs[0] + for k, v in dump_dict.items(): + if len(v) > 0: + hdf_generic[k] = convert_units(np.array(v), label=k) + + if df is not None and pressure_dict is not None and generic_keys_lst is not None: + for k, v in df.items(): + v = convert_units(np.array(v), label=k) + if k in generic_keys_lst: + hdf_generic[k] = v + else: # This is a hack for backward comparability + hdf_lammps[k] = v + + # Store pressures as numpy arrays + for key, val in pressure_dict.items(): + hdf_generic[key] = convert_units(val, label=key) else: - df = pd.concat[dfs] - - h5_dict = { - "Step": "steps", - "Temp": "temperature", - "PotEng": "energy_pot", - "TotEng": "energy_tot", - "Volume": "volume", - } - - for key in df.columns[df.columns.str.startswith("f_mean")]: - h5_dict[key] = key.replace("f_", "") - - df = df.rename(index=str, columns=h5_dict) - pressure_dict = dict() - if all( - [ - x in df.columns.values - for x in [ - "Pxx", - "Pxy", - "Pxz", - "Pxy", - "Pyy", - "Pyz", - "Pxz", - "Pyz", - "Pzz", - ] - ] - ): - pressures = ( - np.stack( - ( - df.Pxx, - df.Pxy, - df.Pxz, - df.Pxy, - df.Pyy, - df.Pyz, - df.Pxz, - df.Pyz, - df.Pzz, - ), - axis=-1, - ) - .reshape(-1, 3, 3) - .astype("float64") + warnings.warn("LAMMPS warning: No log.lammps output file found.") + + return hdf_output + + +def _parse_dump( + dump_h5_full_file_name: str, + dump_out_full_file_name: str, + prism: UnfoldingPrism, + structure: Atoms, + potential_elements: Union[np.ndarray, List], +) -> Dict: + if os.path.isfile(dump_h5_full_file_name): + return _collect_dump_from_h5md( + file_name=dump_h5_full_file_name, + prism=prism, ) - # Rotate pressures from Lammps frame to pyiron frame if necessary - if _check_ortho_prism(prism=prism): - rotation_matrix = prism.R.T - pressures = rotation_matrix.T @ pressures @ rotation_matrix - - df = df.drop( - columns=df.columns[ - ((df.columns.str.len() == 3) & df.columns.str.startswith("P")) - ] + elif os.path.exists(dump_out_full_file_name): + return _collect_dump_from_text( + file_name=dump_out_full_file_name, + prism=prism, + structure=structure, + potential_elements=potential_elements, ) - pressure_dict["pressures"] = pressures else: - warnings.warn( - "LAMMPS warning: log.lammps does not contain the required pressure values." - ) - if "mean_pressure[1]" in df.columns: - pressures = ( - np.stack( - ( - df["mean_pressure[1]"], - df["mean_pressure[4]"], - df["mean_pressure[5]"], - df["mean_pressure[4]"], - df["mean_pressure[2]"], - df["mean_pressure[6]"], - df["mean_pressure[5]"], - df["mean_pressure[6]"], - df["mean_pressure[3]"], - ), - axis=-1, - ) - .reshape(-1, 3, 3) - .astype("float64") - ) - if _check_ortho_prism(prism=prism): - rotation_matrix = prism.R.T - pressures = rotation_matrix.T @ pressures @ rotation_matrix - df = df.drop( - columns=df.columns[ - ( - df.columns.str.startswith("mean_pressure") - & df.columns.str.endswith("]") - ) - ] - ) - pressure_dict["mean_pressures"] = pressures - generic_keys_lst = list(h5_dict.values()) - return generic_keys_lst, pressure_dict, df + return {} -def collect_h5md_file(file_name, prism): - """ - - Args: - file_name: - cwd: - - Returns: - - """ +def _collect_dump_from_h5md(file_name: str, prism: UnfoldingPrism) -> Dict: if _check_ortho_prism(prism=prism): raise RuntimeError( "The Lammps output will not be mapped back to pyiron correctly." @@ -185,38 +124,22 @@ def collect_h5md_file(file_name, prism): np.eye(3) * np.array(cell_i.tolist()) for cell_i in h5md["/particles/all/box/edges/value"] ] - return forces, positions, steps, cell - - -def collect_errors(file_name): - """ - - Args: - file_name: - - Returns: - - """ - error = extract_data_from_file(file_name, tag="ERROR", num_args=1000) - if len(error) > 0: - error = " ".join(error[0]) - raise RuntimeError("Run time error occurred: " + str(error)) - else: - return True + return { + "forces": forces, + "positions": positions, + "steps": steps, + "cells": cell, + } -def collect_dump_file(file_name, prism, structure, potential_elements): +def _collect_dump_from_text( + file_name: str, + prism: UnfoldingPrism, + structure: Atoms, + potential_elements: Union[np.ndarray, List], +) -> Dict: """ general purpose routine to extract static from a lammps dump file - - Args: - file_name: - prism: - structure: - potential_elements: - - Returns: - """ rotation_lammps2orig = prism.R.T with open(file_name, "r") as f: @@ -351,30 +274,182 @@ def collect_dump_file(file_name, prism, structure, potential_elements): return asdict(dump) -def _check_ortho_prism(prism, rtol=0.0, atol=1e-08): +def _parse_log( + log_lammps_full_file_name: str, prism: UnfoldingPrism +) -> Union[Tuple[List[str], Dict, pd.DataFrame], Tuple[None, None, None]]: """ - Check if the rotation matrix of the UnfoldingPrism object is sufficiently close to a unit matrix + If it exists, parses the lammps log file and either raises an exception if errors + occurred or returns data. Just returns a tuple of Nones if there is no file at the + given location. Args: - prism (pyiron_atomistics.lammps.structure.UnfoldingPrism): UnfoldingPrism object to check - rtol (float): relative precision for numpy.isclose() - atol (float): absolute precision for numpy.isclose() + log_lammps_full_file_name (str): The path to the lammps log file. + prism (pyiron_atomistics.lammps.structure.UnfoldingPrism): For mapping between + lammps and pyiron structures Returns: - boolean: True or False + (list | None): Generic keys + (dict | None): Pressures + (pandas.DataFrame | None): A dataframe with the rest of the information + + Raises: + (RuntimeError): If there are "ERROR" tags in the log. """ - return np.isclose(prism.R, np.eye(3), rtol=rtol, atol=atol).all() + if os.path.exists(log_lammps_full_file_name): + _raise_exception_if_errors_found(file_name=log_lammps_full_file_name) + return _collect_output_log( + file_name=log_lammps_full_file_name, + prism=prism, + ) + else: + return None, None, None -def to_amat(l_list): +def _collect_output_log( + file_name: str, prism: UnfoldingPrism +) -> Tuple[List[str], Dict, pd.DataFrame]: """ + general purpose routine to extract static from a lammps log file + """ + with open(file_name, "r") as f: + read_thermo = False + thermo_lines = "" + for l in f: + l = l.lstrip() + + if read_thermo: + if l.startswith("Loop"): + read_thermo = False + continue + thermo_lines += l + + if l.startswith("Step"): + read_thermo = True + thermo_lines += l + + df = pd.read_csv(StringIO(thermo_lines), sep="\s+", engine="c") + + h5_dict = { + "Step": "steps", + "Temp": "temperature", + "PotEng": "energy_pot", + "TotEng": "energy_tot", + "Volume": "volume", + } + + for key in df.columns[df.columns.str.startswith("f_mean")]: + h5_dict[key] = key.replace("f_", "") + + df = df.rename(index=str, columns=h5_dict) + pressure_dict = dict() + if all( + [ + x in df.columns.values + for x in [ + "Pxx", + "Pxy", + "Pxz", + "Pxy", + "Pyy", + "Pyz", + "Pxz", + "Pyz", + "Pzz", + ] + ] + ): + pressures = ( + np.stack( + ( + df.Pxx, + df.Pxy, + df.Pxz, + df.Pxy, + df.Pyy, + df.Pyz, + df.Pxz, + df.Pyz, + df.Pzz, + ), + axis=-1, + ) + .reshape(-1, 3, 3) + .astype("float64") + ) + # Rotate pressures from Lammps frame to pyiron frame if necessary + if _check_ortho_prism(prism=prism): + rotation_matrix = prism.R.T + pressures = rotation_matrix.T @ pressures @ rotation_matrix + + df = df.drop( + columns=df.columns[ + ((df.columns.str.len() == 3) & df.columns.str.startswith("P")) + ] + ) + pressure_dict["pressures"] = pressures + else: + warnings.warn( + "LAMMPS warning: log.lammps does not contain the required pressure values." + ) + if "mean_pressure[1]" in df.columns: + pressures = ( + np.stack( + tuple(df[f"mean_pressure[{i}]"] for i in [1, 4, 5, 4, 2, 6, 5, 6, 3]), + axis=-1, + ) + .reshape(-1, 3, 3) + .astype("float64") + ) + if _check_ortho_prism(prism=prism): + rotation_matrix = prism.R.T + pressures = rotation_matrix.T @ pressures @ rotation_matrix + df = df.drop( + columns=df.columns[ + ( + df.columns.str.startswith("mean_pressure") + & df.columns.str.endswith("]") + ) + ] + ) + pressure_dict["mean_pressures"] = pressures + generic_keys_lst = list(h5_dict.values()) + return generic_keys_lst, pressure_dict, df + + +def _raise_exception_if_errors_found(file_name: str) -> None: + """ + Raises a `RuntimeError` if the `"ERROR"` tag is found in the file. Args: - l_list: + file_name (str): The file holding the LAMMPS log - Returns: + Raises: + (RuntimeError): if at least one "ERROR" tag is found + """ + error = extract_data_from_file(file_name, tag="ERROR", num_args=1000) + if len(error) > 0: + error = " ".join(error[0]) + raise RuntimeError("Run time error occurred: " + str(error)) + +def _check_ortho_prism( + prism: UnfoldingPrism, rtol: float = 0.0, atol: float = 1e-08 +) -> bool: """ + Check if the rotation matrix of the UnfoldingPrism object is sufficiently close to a unit matrix + + Args: + prism (pyiron_atomistics.lammps.structure.UnfoldingPrism): UnfoldingPrism object to check + rtol (float): relative precision for numpy.isclose() + atol (float): absolute precision for numpy.isclose() + + Returns: + boolean: True or False + """ + return np.isclose(prism.R, np.eye(3), rtol=rtol, atol=atol).all() + + +def to_amat(l_list: Union[np.ndarray, List]) -> List: lst = np.reshape(l_list, -1) if len(lst) == 9: ( @@ -413,7 +488,11 @@ def to_amat(l_list): return cell -def remap_indices(lammps_indices, potential_elements, structure): +def remap_indices( + lammps_indices: Union[np.ndarray, List], + potential_elements: Union[np.ndarray, List], + structure: Atoms, +) -> np.ndarray: """ Give the Lammps-dumped indices, re-maps these back onto the structure's indices to preserve the species. diff --git a/pyiron_atomistics/vasp/base.py b/pyiron_atomistics/vasp/base.py index 62521f21c..89cccb371 100644 --- a/pyiron_atomistics/vasp/base.py +++ b/pyiron_atomistics/vasp/base.py @@ -16,17 +16,28 @@ Potcar, strip_xc_from_potential_name, ) -from pyiron_atomistics.atomistics.structure.atoms import Atoms, CrystalStructure +from pyiron_atomistics.atomistics.structure.atoms import ( + Atoms, + CrystalStructure, + structure_dict_to_hdf, + dict_group_to_hdf, +) from pyiron_base import state, GenericParameters, deprecate -from pyiron_atomistics.vasp.outcar import Outcar -from pyiron_atomistics.vasp.oszicar import Oszicar +from pyiron_atomistics.vasp.parser.outcar import Outcar +from pyiron_atomistics.vasp.parser.oszicar import Oszicar from pyiron_atomistics.vasp.procar import Procar from pyiron_atomistics.vasp.structure import read_atoms, write_poscar, vasp_sorter from pyiron_atomistics.vasp.vasprun import Vasprun as Vr from pyiron_atomistics.vasp.vasprun import VasprunError, VasprunWarning -from pyiron_atomistics.vasp.volumetric_data import VaspVolumetricData +from pyiron_atomistics.vasp.volumetric_data import ( + VaspVolumetricData, + volumetric_data_dict_to_hdf, +) from pyiron_atomistics.vasp.potential import get_enmax_among_potentials -from pyiron_atomistics.dft.waves.electronic import ElectronicStructure +from pyiron_atomistics.dft.waves.electronic import ( + ElectronicStructure, + electronic_structure_dict_to_hdf, +) from pyiron_atomistics.dft.waves.bandstructure import Bandstructure from pyiron_atomistics.dft.bader import Bader import warnings @@ -376,21 +387,24 @@ def write_input(self): modified_elements=modified_elements, ) - # define routines that collect all output files - def collect_output(self): + def collect_output_parser(self, cwd): """ Collects the outputs and stores them to the hdf file """ if self.structure is None or len(self.structure) == 0: try: - self.structure = self.get_final_structure_from_file(filename="CONTCAR") + self.structure = self.get_final_structure_from_file( + cwd=cwd, filename="CONTCAR" + ) except IOError: - self.structure = self.get_final_structure_from_file(filename="POSCAR") + self.structure = self.get_final_structure_from_file( + cwd=cwd, filename="POSCAR" + ) self._sorted_indices = np.array(range(len(self.structure))) self._output_parser.structure = self.structure.copy() try: self._output_parser.collect( - directory=self.working_directory, sorted_indices=self.sorted_indices + directory=cwd, sorted_indices=self.sorted_indices ) except VaspCollectError: self.status.aborted = True @@ -398,15 +412,16 @@ def collect_output(self): # Try getting high precision positions from CONTCAR try: self._output_parser.structure = self.get_final_structure_from_file( - filename="CONTCAR" + cwd=cwd, + filename="CONTCAR", ) except (IOError, ValueError, FileNotFoundError): pass # Bader analysis - if os.path.isfile( - os.path.join(self.working_directory, "AECCAR0") - ) and os.path.isfile(os.path.join(self.working_directory, "AECCAR2")): + if os.path.isfile(os.path.join(cwd, "AECCAR0")) and os.path.isfile( + os.path.join(cwd, "AECCAR2") + ): bader = Bader(self) try: charges_orig, volumes_orig = bader.compute_bader_charges() @@ -431,7 +446,18 @@ def collect_output(self): self._output_parser.generic_output.dft_log_dict[ "bader_volumes" ] = volumes - self._output_parser.to_hdf(self._hdf5) + return self._output_parser.to_dict() + + # define routines that collect all output files + def collect_output(self): + """ + Collects the outputs and stores them to the hdf file + """ + output_dict_to_hdf( + data_dict=self.collect_output_parser(cwd=self.working_directory), + hdf=self._hdf5, + group_name="output", + ) if len(self._exclude_groups_hdf) > 0 or len(self._exclude_nodes_hdf) > 0: self.project_hdf5.rewrite_hdf5() @@ -766,7 +792,7 @@ def reset_output(self): """ self._output_parser = Output() - def get_final_structure_from_file(self, filename="CONTCAR"): + def get_final_structure_from_file(self, cwd, filename="CONTCAR"): """ Get the final structure of the simulation usually from the CONTCAR file @@ -776,7 +802,7 @@ def get_final_structure_from_file(self, filename="CONTCAR"): Returns: pyiron.atomistics.structure.atoms.Atoms: The final structure """ - filename = posixpath.join(self.working_directory, filename) + filename = posixpath.join(cwd, filename) if self.structure is None: try: output_structure = read_atoms(filename=filename) @@ -1993,8 +2019,12 @@ def collect(self, directory=os.getcwd(), sorted_indices=None): "n_elect" ] if len(self.outcar.parse_dict["magnetization"]) > 0: - magnetization = np.array(self.outcar.parse_dict["magnetization"]).copy() - final_magmoms = np.array(self.outcar.parse_dict["final_magmoms"]).copy() + magnetization = np.array( + self.outcar.parse_dict["magnetization"], dtype=object + ) + final_magmoms = np.array( + self.outcar.parse_dict["final_magmoms"], dtype=object + ) # magnetization[sorted_indices] = magnetization.copy() if len(final_magmoms) != 0: if len(final_magmoms.shape) == 3: @@ -2199,6 +2229,30 @@ def collect(self, directory=os.getcwd(), sorted_indices=None): ) self.generic_output.bands = self.electronic_structure + def to_dict(self): + hdf5_output = { + "description": self.description, + "generic": self.generic_output.to_dict(), + } + + if self._structure is not None: + hdf5_output["structure"] = self.structure.to_dict() + + if self.electrostatic_potential.total_data is not None: + hdf5_output[ + "electrostatic_potential" + ] = self.electrostatic_potential.to_dict() + + if self.charge_density.total_data is not None: + hdf5_output["charge_density"] = self.charge_density.to_dict() + + if len(self.electronic_structure.kpoint_list) > 0: + hdf5_output["electronic_structure"] = self.electronic_structure.to_dict() + + if len(self.outcar.parse_dict.keys()) > 0: + hdf5_output["outcar"] = self.outcar.to_dict_minimal() + return hdf5_output + def to_hdf(self, hdf): """ Save the object in a HDF5 file @@ -2207,34 +2261,7 @@ def to_hdf(self, hdf): hdf (pyiron_base.generic.hdfio.ProjectHDFio): HDF path to which the object is to be saved """ - with hdf.open("output") as hdf5_output: - hdf5_output["description"] = self.description - self.generic_output.to_hdf(hdf5_output) - try: - self.structure.to_hdf(hdf5_output) - except AttributeError: - pass - - # with hdf5_output.open("vasprun") as hvr: - # if self.vasprun.dict_vasprun is not None: - # for key, val in self.vasprun.dict_vasprun.items(): - # hvr[key] = val - - if self.electrostatic_potential.total_data is not None: - self.electrostatic_potential.to_hdf( - hdf5_output, group_name="electrostatic_potential" - ) - - if self.charge_density.total_data is not None: - self.charge_density.to_hdf(hdf5_output, group_name="charge_density") - - if len(self.electronic_structure.kpoint_list) > 0: - self.electronic_structure.to_hdf( - hdf=hdf5_output, group_name="electronic_structure" - ) - - if len(self.outcar.parse_dict.keys()) > 0: - self.outcar.to_hdf_minimal(hdf=hdf5_output, group_name="outcar") + output_dict_to_hdf(data_dict=self.to_dict(), hdf=hdf, group_name="output") def from_hdf(self, hdf): """ @@ -2297,15 +2324,20 @@ def to_hdf(self, hdf): hdf (pyiron_base.generic.hdfio.ProjectHDFio): HDF path to which the object is to be saved """ - with hdf.open("generic") as hdf_go: - # hdf_go["description"] = self.description - for key, val in self.log_dict.items(): - hdf_go[key] = val - with hdf_go.open("dft") as hdf_dft: - for key, val in self.dft_log_dict.items(): - hdf_dft[key] = val - if self.bands.eigenvalue_matrix is not None: - self.bands.to_hdf(hdf_dft, "bands") + generic_output_dict_to_hdf( + data_dict=self.to_dict(), hdf=hdf, group_name="generic" + ) + + def to_dict(self): + hdf_go, hdf_dft = {}, {} + for key, val in self.log_dict.items(): + hdf_go[key] = val + for key, val in self.dft_log_dict.items(): + hdf_dft[key] = val + hdf_go["dft"] = hdf_dft + if self.bands.eigenvalue_matrix is not None: + hdf_go["dft"]["bands"] = self.bands.to_dict() + return hdf_go def from_hdf(self, hdf): """ @@ -2527,3 +2559,73 @@ def get_k_mesh_by_cell(cell, kspace_per_in_ang=0.10): class VaspCollectError(ValueError): pass + + +def generic_output_dict_to_hdf(data_dict, hdf, group_name="generic"): + with hdf.open(group_name) as hdf_go: + for k, v in data_dict.items(): + if k not in ["dft"]: + hdf_go[k] = v + + with hdf_go.open("dft") as hdf_dft: + for k, v in data_dict["dft"].items(): + if k not in ["bands"]: + hdf_dft[k] = v + + if "bands" in data_dict["dft"].keys(): + electronic_structure_dict_to_hdf( + data_dict=data_dict["dft"]["bands"], + hdf=hdf_dft, + group_name="bands", + ) + + +def output_dict_to_hdf(data_dict, hdf, group_name="output"): + with hdf.open(group_name) as hdf5_output: + for k, v in data_dict.items(): + if k not in [ + "generic", + "structure", + "electrostatic_potential", + "charge_density", + "electronic_structure", + "outcar", + ]: + hdf5_output[k] = v + + if "generic" in data_dict.keys(): + generic_output_dict_to_hdf( + data_dict=data_dict["generic"], + hdf=hdf5_output, + group_name="generic", + ) + + if "structure" in data_dict.keys(): + structure_dict_to_hdf( + data_dict=data_dict["structure"], + hdf=hdf5_output, + group_name="structure", + ) + + if "electrostatic_potential" in data_dict.keys(): + volumetric_data_dict_to_hdf( + data_dict=data_dict["electrostatic_potential"], + hdf=hdf5_output, + group_name="electrostatic_potential", + ) + + if "charge_density" in data_dict.keys(): + volumetric_data_dict_to_hdf( + data_dict=data_dict["charge_density"], + hdf=hdf5_output, + group_name="charge_density", + ) + + if "electronic_structure" in data_dict.keys(): + electronic_structure_dict_to_hdf( + data_dict=data_dict["electronic_structure"], + hdf=hdf5_output, + group_name="electronic_structure", + ) + + dict_group_to_hdf(data_dict=data_dict, hdf=hdf5_output, group="outcar") diff --git a/pyiron_atomistics/vasp/interactive.py b/pyiron_atomistics/vasp/interactive.py index 4766263b7..cf68fdf96 100644 --- a/pyiron_atomistics/vasp/interactive.py +++ b/pyiron_atomistics/vasp/interactive.py @@ -6,7 +6,7 @@ import os from subprocess import Popen, PIPE -from pyiron_atomistics.vasp.outcar import Outcar +from pyiron_atomistics.vasp.parser.outcar import Outcar from pyiron_atomistics.vasp.base import VaspBase from pyiron_atomistics.vasp.structure import vasp_sorter from pyiron_atomistics.vasp.potential import VaspPotentialSetter diff --git a/pyiron_atomistics/vasp/metadyn.py b/pyiron_atomistics/vasp/metadyn.py index 07a03d1c1..39b4044aa 100644 --- a/pyiron_atomistics/vasp/metadyn.py +++ b/pyiron_atomistics/vasp/metadyn.py @@ -6,7 +6,7 @@ from pyiron_base import GenericParameters from pyiron_atomistics.vasp.vasp import Vasp from pyiron_atomistics.vasp.base import Input, Output -from pyiron_atomistics.vasp.report import Report +from pyiron_atomistics.vasp.parser.report import Report import os import posixpath diff --git a/tests/atomistics/__init__.py b/pyiron_atomistics/vasp/parser/__init__.py similarity index 100% rename from tests/atomistics/__init__.py rename to pyiron_atomistics/vasp/parser/__init__.py diff --git a/pyiron_atomistics/vasp/oszicar.py b/pyiron_atomistics/vasp/parser/oszicar.py similarity index 100% rename from pyiron_atomistics/vasp/oszicar.py rename to pyiron_atomistics/vasp/parser/oszicar.py diff --git a/pyiron_atomistics/vasp/outcar.py b/pyiron_atomistics/vasp/parser/outcar.py similarity index 99% rename from pyiron_atomistics/vasp/outcar.py rename to pyiron_atomistics/vasp/parser/outcar.py index be1844016..cd4619428 100644 --- a/pyiron_atomistics/vasp/outcar.py +++ b/pyiron_atomistics/vasp/parser/outcar.py @@ -152,6 +152,12 @@ def to_hdf_minimal(self, hdf, group_name="outcar"): hdf (pyiron_base.generic.hdfio.FileHDFio): HDF5 group or file group_name (str): Name of the HDF5 group """ + with hdf.open(group_name) as hdf5_output: + for k, v in self.to_dict_minimal().items(): + hdf5_output[k] = v + + def to_dict_minimal(self): + hdf5_output = {} unique_quantities = [ "kin_energy_error", "broyden_mixing", @@ -162,10 +168,10 @@ def to_hdf_minimal(self, hdf, group_name="outcar"): "energy_components", "resources", ] - with hdf.open(group_name) as hdf5_output: - for key in self.parse_dict.keys(): - if key in unique_quantities: - hdf5_output[key] = self.parse_dict[key] + for key in self.parse_dict.keys(): + if key in unique_quantities: + hdf5_output[key] = self.parse_dict[key] + return hdf5_output def from_hdf(self, hdf, group_name="outcar"): """ diff --git a/pyiron_atomistics/vasp/report.py b/pyiron_atomistics/vasp/parser/report.py similarity index 100% rename from pyiron_atomistics/vasp/report.py rename to pyiron_atomistics/vasp/parser/report.py diff --git a/pyiron_atomistics/vasp/volumetric_data.py b/pyiron_atomistics/vasp/volumetric_data.py index 86b1baa33..8f52898ac 100644 --- a/pyiron_atomistics/vasp/volumetric_data.py +++ b/pyiron_atomistics/vasp/volumetric_data.py @@ -277,6 +277,15 @@ def diff_data(self): def diff_data(self, val): self._diff_data = val + def to_dict(self): + volumetric_data_dict = { + "TYPE": str(type(self)), + "total": self.total_data, + } + if self.diff_data is not None: + volumetric_data_dict["diff"] = self.diff_data + return volumetric_data_dict + def to_hdf(self, hdf, group_name="volumetric_data"): """ Writes the data as a group to a HDF5 file @@ -286,11 +295,11 @@ def to_hdf(self, hdf, group_name="volumetric_data"): group_name (str): The name of the group under which the data must be stored as """ - with hdf.open(group_name) as hdf_vd: - hdf_vd["TYPE"] = str(type(self)) - hdf_vd["total"] = self.total_data - if self.diff_data is not None: - hdf_vd["diff"] = self.diff_data + volumetric_data_dict_to_hdf( + data_dict=self.to_dict(), + hdf=hdf, + group_name=group_name, + ) def from_hdf(self, hdf, group_name="volumetric_data"): """ @@ -308,3 +317,9 @@ def from_hdf(self, hdf, group_name="volumetric_data"): self._total_data = hdf_vd["total"] if "diff" in hdf_vd.list_nodes(): self._diff_data = hdf_vd["diff"] + + +def volumetric_data_dict_to_hdf(data_dict, hdf, group_name="volumetric_data"): + with hdf.open(group_name) as hdf_vd: + for k, v in data_dict.items(): + hdf_vd[k] = v diff --git a/setup.py b/setup.py index 5a7dcd12a..cc57b1a4a 100644 --- a/setup.py +++ b/setup.py @@ -43,24 +43,25 @@ ]), install_requires=[ 'ase==3.22.1', + 'atomistics==0.1.2', 'defusedxml==0.7.1', - 'h5py==3.9.0', - 'matplotlib==3.7.2', + 'h5py==3.10.0', + 'matplotlib==3.8.1', 'mendeleev==0.14.0', - 'mp-api==0.33.3', - 'numpy==1.24.3', - 'pandas==2.0.3', + 'mp-api==0.37.5', + 'numpy==1.26.0', + 'pandas==2.1.3', 'phonopy==2.20.0', 'pint==0.22', - 'pyiron_base==0.6.3', + 'pyiron_base==0.6.9', 'pylammpsmpi==0.2.6', 'pymatgen==2023.8.10', - 'scipy==1.11.1', + 'scipy==1.11.3', 'seekpath==2.1.0', - 'scikit-learn==1.3.0', - 'spglib==2.0.2', - 'structuretoolkit==0.0.6' + 'scikit-learn==1.3.2', + 'spglib==2.1.0', + 'structuretoolkit==0.0.12' ], cmdclass=versioneer.get_cmdclass(), - + package_data={'': ['data/*.csv']}, ) diff --git a/tests/atomistics/job/__init__.py b/tests/atomic/__init__.py similarity index 100% rename from tests/atomistics/job/__init__.py rename to tests/atomic/__init__.py diff --git a/tests/atomistics/master/__init__.py b/tests/atomic/job/__init__.py similarity index 100% rename from tests/atomistics/master/__init__.py rename to tests/atomic/job/__init__.py diff --git a/tests/atomistics/job/test_StructureContainer.py b/tests/atomic/job/test_StructureContainer.py similarity index 97% rename from tests/atomistics/job/test_StructureContainer.py rename to tests/atomic/job/test_StructureContainer.py index f6cdbf497..a5b6396d4 100644 --- a/tests/atomistics/job/test_StructureContainer.py +++ b/tests/atomic/job/test_StructureContainer.py @@ -35,7 +35,7 @@ def test_container(self): self.assertEqual(structure_container.job_id, self.project.get_job_ids()[0]) self.assertEqual(structure_container.job_name, "structure_container") self.assertTrue( - "atomistics/job/structure_testing/" + "atomic/job/structure_testing/" in structure_container.project_hdf5.project_path ) self.assertTrue(structure_container.status.finished) diff --git a/tests/atomistics/job/test_atomistic.py b/tests/atomic/job/test_atomistic.py similarity index 98% rename from tests/atomistics/job/test_atomistic.py rename to tests/atomic/job/test_atomistic.py index a5f561b40..98818f075 100644 --- a/tests/atomistics/job/test_atomistic.py +++ b/tests/atomic/job/test_atomistic.py @@ -87,7 +87,8 @@ def test_get_displacements(self): disp_ref.append(np.dot(diff, cell)) self.assertTrue(np.allclose(disp, disp_ref)) - @unittest.skipIf(os.name == 'nt', "Runs forever on Windows") + #@unittest.skipIf(os.name == 'nt', "Runs forever on Windows") + @unittest.skip def test_get_structure(self): """get_structure() should return structures with the exact values from the HDF files even if the size of structures changes.""" diff --git a/tests/atomistics/job/test_sqs.py b/tests/atomic/job/test_sqs.py similarity index 100% rename from tests/atomistics/job/test_sqs.py rename to tests/atomic/job/test_sqs.py diff --git a/tests/atomistics/job/test_transform_trajectory.py b/tests/atomic/job/test_transform_trajectory.py similarity index 100% rename from tests/atomistics/job/test_transform_trajectory.py rename to tests/atomic/job/test_transform_trajectory.py diff --git a/tests/atomistics/structure/__init__.py b/tests/atomic/master/__init__.py similarity index 100% rename from tests/atomistics/structure/__init__.py rename to tests/atomic/master/__init__.py diff --git a/tests/atomistics/master/test_elastic.py b/tests/atomic/master/test_elastic.py similarity index 100% rename from tests/atomistics/master/test_elastic.py rename to tests/atomic/master/test_elastic.py diff --git a/tests/atomistics/master/test_murnaghan.py b/tests/atomic/master/test_murnaghan.py similarity index 100% rename from tests/atomistics/master/test_murnaghan.py rename to tests/atomic/master/test_murnaghan.py diff --git a/tests/atomistics/master/test_murnaghan_master_modal.py b/tests/atomic/master/test_murnaghan_master_modal.py similarity index 100% rename from tests/atomistics/master/test_murnaghan_master_modal.py rename to tests/atomic/master/test_murnaghan_master_modal.py diff --git a/tests/atomistics/master/test_murnaghan_non_modal.py b/tests/atomic/master/test_murnaghan_non_modal.py similarity index 100% rename from tests/atomistics/master/test_murnaghan_non_modal.py rename to tests/atomic/master/test_murnaghan_non_modal.py diff --git a/tests/atomistics/master/test_phonopy.py b/tests/atomic/master/test_phonopy.py similarity index 100% rename from tests/atomistics/master/test_phonopy.py rename to tests/atomic/master/test_phonopy.py diff --git a/tests/atomistics/master/test_quasi.py b/tests/atomic/master/test_quasi.py similarity index 100% rename from tests/atomistics/master/test_quasi.py rename to tests/atomic/master/test_quasi.py diff --git a/tests/atomistics/structure/factories/__init__.py b/tests/atomic/structure/__init__.py similarity index 100% rename from tests/atomistics/structure/factories/__init__.py rename to tests/atomic/structure/__init__.py diff --git a/tests/atomistics/volumetric/__init__.py b/tests/atomic/structure/factories/__init__.py similarity index 100% rename from tests/atomistics/volumetric/__init__.py rename to tests/atomic/structure/factories/__init__.py diff --git a/tests/atomistics/structure/factories/test_aimsgb.py b/tests/atomic/structure/factories/test_aimsgb.py similarity index 100% rename from tests/atomistics/structure/factories/test_aimsgb.py rename to tests/atomic/structure/factories/test_aimsgb.py diff --git a/tests/atomistics/structure/factories/test_atomsk.py b/tests/atomic/structure/factories/test_atomsk.py similarity index 100% rename from tests/atomistics/structure/factories/test_atomsk.py rename to tests/atomic/structure/factories/test_atomsk.py diff --git a/tests/atomistics/structure/factories/test_compound.py b/tests/atomic/structure/factories/test_compound.py similarity index 100% rename from tests/atomistics/structure/factories/test_compound.py rename to tests/atomic/structure/factories/test_compound.py diff --git a/tests/atomistics/structure/test_analyse.py b/tests/atomic/structure/test_analyse.py similarity index 100% rename from tests/atomistics/structure/test_analyse.py rename to tests/atomic/structure/test_analyse.py diff --git a/tests/atomistics/structure/test_atom.py b/tests/atomic/structure/test_atom.py similarity index 100% rename from tests/atomistics/structure/test_atom.py rename to tests/atomic/structure/test_atom.py diff --git a/tests/atomistics/structure/test_atoms.py b/tests/atomic/structure/test_atoms.py similarity index 91% rename from tests/atomistics/structure/test_atoms.py rename to tests/atomic/structure/test_atoms.py index e9e6628fb..e7b725650 100644 --- a/tests/atomistics/structure/test_atoms.py +++ b/tests/atomic/structure/test_atoms.py @@ -13,8 +13,6 @@ Atoms, CrystalStructure, ase_to_pyiron, - pymatgen_to_pyiron, - pyiron_to_pymatgen, ) from pyiron_atomistics.atomistics.structure.factory import StructureFactory from pyiron_atomistics.atomistics.structure.periodic_table import ( @@ -26,7 +24,6 @@ from ase.cell import Cell as ASECell from ase.atoms import Atoms as ASEAtoms from ase.build import molecule -from pymatgen.core import Structure, Lattice from ase.calculators.morse import MorsePotential @@ -1774,161 +1771,6 @@ def test_cached_speed(self): "Atom creation not speed up to the required level by caches!", ) - def test_pymatgen_to_pyiron_conversion(self): - """ - Tests pymatgen_to_pyiron conversion functionality (implemented conversion path is pymatgen->ASE->pyiron) - Tests: - 1. If conversion works with no site-specific properties - 2. Equivalence in selective dynamics tags after conversion if only sel dyn is present - 3. Checks if other tags are affected when sel dyn is present (magmom is checked) - 4. Checks if other tags are affected when sel dyn is not present (magmom is checked) - """ - - coords = [[0, 0, 0], [0.75, 0.5, 0.75]] - lattice = Lattice.from_parameters( - a=4.2, b=4.2, c=4.2, alpha=120, beta=90, gamma=60 - ) - struct = Structure(lattice, ["Fe", "Fe"], coords) - - # First test make sure it actually works for structures without sel-dyn - pyiron_atoms_no_sd = pymatgen_to_pyiron(struct) - - # Check that it doesn't have any selective dynamics tags attached when it shouldn't - self.assertFalse( - hasattr(pyiron_atoms_no_sd, "selective_dynamics"), - "It's adding selective dynamics after conversion even when original object doesn't have it", - ) - - # Second test for equivalence in selective dynamics tags in pyiron Atoms vs pymatgen Structure - - struct_with_sd = struct.copy() - new_site_properties = struct.site_properties - new_site_properties["selective_dynamics"] = [ - [True, False, False] for site in struct - ] - struct_with_sd = struct.copy(site_properties=new_site_properties) - - pyiron_atoms_sd = pymatgen_to_pyiron(struct_with_sd) - - sd_equivalent = struct_with_sd.site_properties["selective_dynamics"] == [ - x.selective_dynamics.tolist() for x in pyiron_atoms_sd - ] - self.assertTrue( - sd_equivalent, - "Failed equivalence test of selective dynamics tags after conversion", - ) - - # Third test make sure no tags are erased (e.g. magmom) if selective dynamics are present - new_site_properties = struct.site_properties - new_site_properties["selective_dynamics"] = [ - [True, False, False] for site in struct - ] - new_site_properties["magmom"] = [0.61 for site in struct] - new_site_properties["magmom"][1] = 3.0 - struct_with_sd_magmom = struct.copy(site_properties=new_site_properties) - pyiron_atoms_sd_magmom = pymatgen_to_pyiron(struct_with_sd_magmom) - magmom_equivalent = struct_with_sd_magmom.site_properties["magmom"] == [ - x.spin for x in pyiron_atoms_sd_magmom - ] - sd_equivalent = struct_with_sd_magmom.site_properties["selective_dynamics"] == [ - x.selective_dynamics.tolist() for x in pyiron_atoms_sd_magmom - ] - self.assertTrue( - magmom_equivalent, - "Failed equivalence test of magnetic moment tags if selective dynamics present after conversion (it's messing with other site-specific properties)", - ) - self.assertTrue( - sd_equivalent, - "Failed equivalence test of selective dynamics tags if magmom site property is also present", - ) - - # Fourth test, make sure if other traits are present (e.g. magmom) but no sel dyn, the conversion works properly (check if magmom is transferred) - new_site_properties = struct.site_properties - new_site_properties["magmom"] = [0.61 for site in struct] - new_site_properties["magmom"][1] = 3.0 - struct_with_magmom = struct.copy(site_properties=new_site_properties) - pyiron_atoms_magmom = pymatgen_to_pyiron(struct_with_magmom) - magmom_equivalent = struct_with_magmom.site_properties["magmom"] == [ - x.spin for x in pyiron_atoms_magmom - ] - self.assertTrue( - magmom_equivalent, - "Failed to convert site-specific properties (checked magmom spin) when no selective dynamics was present)", - ) - # Make sure no sel dyn tags are added unnecessarily - self.assertFalse( - hasattr(pyiron_atoms_magmom, "selective_dynamics"), - "selective dynamics are added when there was none in original pymatgen Structure", - ) - - def test_pyiron_to_pymatgen_conversion(self): - """ - Tests pyiron_to_pymatgen conversion functionality (implemented conversion path is pyiron->ASE->pymatgen) - - Tests: - 1. If conversion works with no site-specific properties - 2. Equivalence in selective dynamics tags after conversion if only sel dyn is present - 3. Checks if other tags are affected when sel dyn is present (magmom is checked) - 4. Checks if other tags are affected when sel dyn is not present (magmom is checked) - """ - pyiron_atoms = StructureFactory().bulk( - name="Fe", crystalstructure="bcc", a=4.182 - ) * [1, 2, 1] - - # First, check conversion actually works - struct = pyiron_to_pymatgen(pyiron_atoms) - # Ensure no random selective dynamics are added - self.assertFalse("selective_dynamics" in struct.site_properties) - - # Second, ensure that when only sel_dyn is present (no other site-props present), conversion works - pyiron_atoms_sd = pyiron_atoms.copy() - pyiron_atoms_sd.add_tag(selective_dynamics=[False, True, True]) - pyiron_atoms_sd.selective_dynamics[1] = [True, False, False] - - struct_sd = pyiron_to_pymatgen(pyiron_atoms_sd) - self.assertTrue( - np.array_equal( - struct_sd.site_properties["selective_dynamics"], - pyiron_atoms_sd.selective_dynamics - ), - "Failed to produce equivalent selective dynamics after conversion!", - ) - - # Third, ensure when magnetic moment is present without selective dynamics, conversion works and magmom is transferred - pyiron_atoms_magmom = pyiron_atoms.copy() - pyiron_atoms_magmom.add_tag(spin=0.61) - pyiron_atoms_magmom.spin[1] = 3 - - struct_magmom = pyiron_to_pymatgen(pyiron_atoms_magmom) - self.assertTrue( - struct_magmom.site_properties["magmom"] - == [x.spin for x in pyiron_atoms_magmom], - "Failed to produce equivalent magmom when only magmom and no sel_dyn are present!", - ) - self.assertFalse( - "selective_dynamics" in struct_magmom.site_properties, - "Failed because selective dynamics was randomly added after conversion!", - ) - - # Fourth, ensure when both magmom and sd are present, conversion works and magmom+selective dynamics are transferred properly - pyiron_atoms_sd_magmom = pyiron_atoms_sd.copy() - pyiron_atoms_sd_magmom.add_tag(spin=0.61) - pyiron_atoms_sd_magmom.spin[1] = 3 - - struct_sd_magmom = pyiron_to_pymatgen(pyiron_atoms_sd_magmom) - self.assertTrue( - struct_sd_magmom.site_properties["magmom"] - == [x.spin for x in pyiron_atoms_sd_magmom], - "Failed to produce equivalent magmom when both magmom + sel_dyn are present!", - ) - self.assertTrue( - np.array_equal( - struct_sd_magmom.site_properties["selective_dynamics"], - pyiron_atoms_sd_magmom.selective_dynamics - ), - "Failed to produce equivalent sel_dyn when both magmom + sel_dyn are present!", - ) - def test_calc_to_hdf(self): """Calculators set on the structure should be properly reloaded after reading from HDF.""" structure = self.CO2.copy() diff --git a/tests/atomistics/structure/test_high_index_surface.py b/tests/atomic/structure/test_high_index_surface.py similarity index 100% rename from tests/atomistics/structure/test_high_index_surface.py rename to tests/atomic/structure/test_high_index_surface.py diff --git a/tests/atomistics/structure/test_neighbors.py b/tests/atomic/structure/test_neighbors.py similarity index 100% rename from tests/atomistics/structure/test_neighbors.py rename to tests/atomic/structure/test_neighbors.py diff --git a/tests/atomistics/structure/test_neighbors_trajectory.py b/tests/atomic/structure/test_neighbors_trajectory.py similarity index 100% rename from tests/atomistics/structure/test_neighbors_trajectory.py rename to tests/atomic/structure/test_neighbors_trajectory.py diff --git a/tests/atomistics/structure/test_periodic_table.py b/tests/atomic/structure/test_periodic_table.py similarity index 100% rename from tests/atomistics/structure/test_periodic_table.py rename to tests/atomic/structure/test_periodic_table.py diff --git a/tests/atomistics/structure/test_potentials.py b/tests/atomic/structure/test_potentials.py similarity index 100% rename from tests/atomistics/structure/test_potentials.py rename to tests/atomic/structure/test_potentials.py diff --git a/tests/atomistics/structure/test_pyscal.py b/tests/atomic/structure/test_pyscal.py similarity index 100% rename from tests/atomistics/structure/test_pyscal.py rename to tests/atomic/structure/test_pyscal.py diff --git a/tests/atomistics/structure/test_strain.py b/tests/atomic/structure/test_strain.py similarity index 100% rename from tests/atomistics/structure/test_strain.py rename to tests/atomic/structure/test_strain.py diff --git a/tests/atomistics/structure/test_structurestorage.py b/tests/atomic/structure/test_structurestorage.py similarity index 100% rename from tests/atomistics/structure/test_structurestorage.py rename to tests/atomic/structure/test_structurestorage.py diff --git a/tests/atomistics/structure/test_symmetry.py b/tests/atomic/structure/test_symmetry.py similarity index 100% rename from tests/atomistics/structure/test_symmetry.py rename to tests/atomic/structure/test_symmetry.py diff --git a/tests/atomistics/structure/test_visualize.py b/tests/atomic/structure/test_visualize.py similarity index 100% rename from tests/atomistics/structure/test_visualize.py rename to tests/atomic/structure/test_visualize.py diff --git a/tests/atomic/volumetric/__init__.py b/tests/atomic/volumetric/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/atomistics/volumetric/test_volumetric_data.py b/tests/atomic/volumetric/test_volumetric_data.py similarity index 100% rename from tests/atomistics/volumetric/test_volumetric_data.py rename to tests/atomic/volumetric/test_volumetric_data.py diff --git a/tests/lammps/test_base.py b/tests/lammps/test_base.py index 28387b0c9..e1fb353d8 100644 --- a/tests/lammps/test_base.py +++ b/tests/lammps/test_base.py @@ -133,30 +133,6 @@ def test_structure_charge(self): ], ) - def test_avilable_versions(self): - self.job.executable = os.path.abspath( - os.path.join( - self.execution_path, - "..", - "static", - "lammps", - "bin", - "run_lammps_2018.03.16.sh", - ) - ) - self.assertTrue([2018, 3, 16] == self.job._get_executable_version_number()) - self.job.executable = os.path.abspath( - os.path.join( - self.execution_path, - "..", - "static", - "lammps", - "bin", - "run_lammps_2018.03.16_mpi.sh", - ) - ) - self.assertTrue([2018, 3, 16] == self.job._get_executable_version_number()) - def _build_water(self, y0_shift=0.0): density = 1.0e-24 # g/A^3 n_mols = 27 @@ -448,25 +424,31 @@ def test_dump_parser(self): file_directory = os.path.join( self.execution_path, "..", "static", "lammps_test_files" ) - self.job.collect_dump_file(cwd=file_directory, file_name="dump_static.out") + output_dict = self.job.collect_output_parser( + cwd=file_directory, + dump_out_file_name="dump_static.out", + log_lammps_file_name="log_not_available" + ) self.assertTrue( - np.array_equal(self.job["output/generic/forces"].shape, (1, 2, 3)) + np.array_equal(output_dict["generic"]["forces"].shape, (1, 2, 3)) ) self.assertTrue( - np.array_equal(self.job["output/generic/positions"].shape, (1, 2, 3)) + np.array_equal(output_dict["generic"]["positions"].shape, (1, 2, 3)) ) self.assertTrue( - np.array_equal(self.job["output/generic/cells"].shape, (1, 3, 3)) + np.array_equal(output_dict["generic"]["cells"].shape, (1, 3, 3)) ) self.assertTrue( - np.array_equal(self.job["output/generic/indices"].shape, (1, 2)) + np.array_equal(output_dict["generic"]["indices"].shape, (1, 2)) ) # compare to old dump parser - old_output = collect_dump_file_old(job=self.job, cwd=file_directory, file_name="dump_static.out") - with self.job.project_hdf5.open("output/generic") as hdf_out: - for k, v in old_output.items(): - self.assertTrue(np.all(v == hdf_out[k])) - + old_output = collect_dump_file_old( + job=self.job, + cwd=file_directory, + file_name="dump_static.out" + ) + for k, v in old_output.items(): + self.assertTrue(np.all(v == output_dict["generic"][k])) def test_vcsgc_input(self): unit_cell = Atoms( @@ -749,8 +731,11 @@ def test_average(self): file_directory = os.path.join( self.execution_path, "..", "static", "lammps_test_files" ) - self.job.collect_dump_file(cwd=file_directory, file_name="dump_average.out") - self.job.collect_output_log(cwd=file_directory, file_name="log_average.lammps") + _ = self.job.collect_output_parser( + cwd=file_directory, + dump_out_file_name="dump_average.out", + log_lammps_file_name="log_average.lammps" + ) def test_validate(self): with self.assertRaises(ValueError): diff --git a/tests/vasp/test_oszicar.py b/tests/vasp/test_oszicar.py index 672a98bfe..b2cb9ccd3 100644 --- a/tests/vasp/test_oszicar.py +++ b/tests/vasp/test_oszicar.py @@ -5,7 +5,7 @@ import numpy as np import os import posixpath -from pyiron_atomistics.vasp.oszicar import Oszicar +from pyiron_atomistics.vasp.parser.oszicar import Oszicar import unittest diff --git a/tests/vasp/test_outcar.py b/tests/vasp/test_outcar.py index f275d0118..0da288151 100644 --- a/tests/vasp/test_outcar.py +++ b/tests/vasp/test_outcar.py @@ -6,7 +6,7 @@ import os import posixpath import numpy as np -from pyiron_atomistics.vasp.outcar import Outcar +from pyiron_atomistics.vasp.parser.outcar import Outcar class TestOutcar(unittest.TestCase): diff --git a/versioneer.py b/versioneer.py index 64fea1c89..1e3753e63 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,5 +1,5 @@ -# Version: 0.18 +# Version: 0.29 """The Versioneer - like a rocketeer, but for versions. @@ -7,18 +7,14 @@ ============== * like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer +* https://github.com/python-versioneer/python-versioneer * Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based +* License: Public Domain (Unlicense) +* Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3 +* [![Latest Version][pypi-image]][pypi-url] +* [![Build Status][travis-image]][travis-url] + +This is a tool for managing a recorded version number in setuptools-based python projects. The goal is to remove the tedious and error-prone "update the embedded version string" step from your release process. Making a new release should be as easy as recording a new tag in your version-control @@ -27,9 +23,38 @@ ## Quick Install -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results +Versioneer provides two installation modes. The "classic" vendored mode installs +a copy of versioneer into your repository. The experimental build-time dependency mode +is intended to allow you to skip this step and simplify the process of upgrading. + +### Vendored mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) + * Note that you will need to add `tomli; python_version < "3.11"` to your + build-time dependencies if you use `pyproject.toml` +* run `versioneer install --vendor` in your source tree, commit the results +* verify version information with `python setup.py version` + +### Build-time dependency mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) +* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`) + to the `requires` key of the `build-system` table in `pyproject.toml`: + ```toml + [build-system] + requires = ["setuptools", "versioneer[toml]"] + build-backend = "setuptools.build_meta" + ``` +* run `versioneer install --no-vendor` in your source tree, commit the results +* verify version information with `python setup.py version` ## Version Identifiers @@ -61,7 +86,7 @@ for example `git describe --tags --dirty --always` reports things like "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. +uncommitted changes). The version identifier is used for multiple purposes: @@ -166,7 +191,7 @@ Some situations are known to cause problems for Versioneer. This details the most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). +[issues page](https://github.com/python-versioneer/python-versioneer/issues). ### Subprojects @@ -180,7 +205,7 @@ `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI distributions (and upload multiple independently-installable tarballs). * Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. + provide bindings to Python (and perhaps other languages) in subdirectories. Versioneer will look for `.git` in parent directories, and most operations should get the right version string. However `pip` and `setuptools` have bugs @@ -194,9 +219,9 @@ Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in some later version. -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the issue from the Versioneer side in more detail. [pip PR#3176](https://github.com/pypa/pip/pull/3176) and [pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve @@ -224,31 +249,20 @@ cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into a different virtualenv), so this can be surprising. -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes this one, but upgrading to a newer version of setuptools should probably resolve it. -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - ## Updating Versioneer To upgrade your project to a new release of Versioneer, do the following: * install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace +* edit `setup.cfg` and `pyproject.toml`, if necessary, + to include any new configuration settings indicated by the release notes. + See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install --[no-]vendor` in your source tree, to replace `SRC/_version.py` * commit any changed files @@ -265,35 +279,70 @@ direction and include code from all supported VCS systems, reducing the number of intermediate scripts. +## Similar projects + +* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time + dependency +* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of + versioneer +* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools + plugin ## License To make Versioneer easier to embed, all its code is dedicated to the public domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . +Specifically, both are released under the "Unlicense", as described in +https://unlicense.org/. + +[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg +[pypi-url]: https://pypi.python.org/pypi/versioneer/ +[travis-image]: +https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg +[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer """ +# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring +# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements +# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error +# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with +# pylint:disable=attribute-defined-outside-init,too-many-arguments -from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser +import configparser import errno import json import os import re import subprocess import sys +from pathlib import Path +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import NoReturn +import functools + +have_tomllib = True +if sys.version_info >= (3, 11): + import tomllib +else: + try: + import tomli as tomllib + except ImportError: + have_tomllib = False class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + versionfile_source: str + versionfile_build: Optional[str] + parentdir_prefix: Optional[str] + verbose: Optional[bool] + -def get_root(): +def get_root() -> str: """Get the project root directory. We require that all commands are run from the project root, i.e. the @@ -301,13 +350,23 @@ def get_root(): """ root = os.path.realpath(os.path.abspath(os.getcwd())) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): # allow 'python path/to/setup.py COMMAND' root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): err = ("Versioneer was unable to run the project root directory. " "Versioneer requires setup.py to be executed from " "its immediate directory (like 'python setup.py COMMAND'), " @@ -321,43 +380,62 @@ def get_root(): # module-import table will cache the first one. So we can't use # os.path.dirname(__file__), as that will find whichever # versioneer.py was first imported, even in later projects. - me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: + if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) + % (os.path.dirname(my_path), versioneer_py)) except NameError: pass return root -def get_config_from_root(root): +def get_config_from_root(root: str) -> VersioneerConfig: """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or + # This might raise OSError (if setup.cfg is missing), or # configparser.NoSectionError (if it lacks a [versioneer] section), or # configparser.NoOptionError (if it lacks "VCS="). See the docstring at # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None + root_pth = Path(root) + pyproject_toml = root_pth / "pyproject.toml" + setup_cfg = root_pth / "setup.cfg" + section: Union[Dict[str, Any], configparser.SectionProxy, None] = None + if pyproject_toml.exists() and have_tomllib: + try: + with open(pyproject_toml, 'rb') as fobj: + pp = tomllib.load(fobj) + section = pp['tool']['versioneer'] + except (tomllib.TOMLDecodeError, KeyError) as e: + print(f"Failed to load config from {pyproject_toml}: {e}") + print("Try to load it from setup.cfg") + if not section: + parser = configparser.ConfigParser() + with open(setup_cfg) as cfg_file: + parser.read_file(cfg_file) + parser.get("versioneer", "VCS") # raise error if missing + + section = parser["versioneer"] + + # `cast`` really shouldn't be used, but its simplest for the + # common VersioneerConfig users at the moment. We verify against + # `None` values elsewhere where it matters + cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): + cfg.VCS = section['VCS'] + cfg.style = section.get("style", "") + cfg.versionfile_source = cast(str, section.get("versionfile_source")) + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = cast(str, section.get("tag_prefix")) + if cfg.tag_prefix in ("''", '""', None): cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") + cfg.parentdir_prefix = section.get("parentdir_prefix") + if isinstance(section, configparser.SectionProxy): + # Make sure configparser translates to bool + cfg.verbose = section.getboolean("verbose") + else: + cfg.verbose = section.get("verbose") + return cfg @@ -366,37 +444,48 @@ class NotThisMethod(Exception): # these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f + HANDLERS.setdefault(vcs, {})[method] = f return f return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -407,26 +496,25 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -LONG_VERSION_PY['git'] = ''' +LONG_VERSION_PY['git'] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -435,9 +523,11 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -453,8 +543,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -472,13 +569,13 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -487,22 +584,35 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -513,18 +623,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %%s" %% (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %%s (error)" %% dispcmd) print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -533,15 +645,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %%s but none started with prefix %%s" %% @@ -550,41 +661,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -597,11 +715,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %%d @@ -610,7 +728,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%%s', no digits" %% ",".join(refs - tags)) if verbose: @@ -619,6 +737,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %%s" %% r) return {"version": r, @@ -634,7 +757,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -645,8 +773,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %%s not under git control" %% root) @@ -654,24 +789,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -688,7 +856,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%%s'" %% describe_out) return pieces @@ -713,26 +881,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -757,23 +926,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%%d" %% (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] + rendered = "0.post0.dev%%d" %% pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -800,12 +1017,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -822,7 +1068,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -842,7 +1088,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -862,7 +1108,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -876,10 +1122,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -894,7 +1144,7 @@ def render(pieces, style): "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -915,7 +1165,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, @@ -942,41 +1192,48 @@ def get_versions(): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -989,11 +1246,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1002,7 +1259,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1011,6 +1268,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) return {"version": r, @@ -1026,7 +1288,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -1037,8 +1304,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1046,24 +1320,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -1080,7 +1387,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%s'" % describe_out) return pieces @@ -1105,19 +1412,20 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def do_vcs_install(manifest_in, versionfile_source, ipy): +def do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None: """Git-specific installation logic for Versioneer. For Git, this means creating/changing .gitattributes to mark _version.py @@ -1126,36 +1434,40 @@ def do_vcs_install(manifest_in, versionfile_source, ipy): GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] + files = [versionfile_source] if ipy: files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) + if "VERSIONEER_PEP518" not in globals(): + try: + my_path = __file__ + if my_path.endswith((".pyc", ".pyo")): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) present = False try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except OSError: pass if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") files.append(".gitattributes") run_command(GITS, ["add", "--"] + files) -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -1164,15 +1476,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %s but none started with prefix %s" % @@ -1181,7 +1492,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from +# This file was generated by 'versioneer.py' (0.29) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. @@ -1198,12 +1509,12 @@ def get_versions(): """ -def versions_from_file(filename): +def versions_from_file(filename: str) -> Dict[str, Any]: """Try to determine the version from _version.py if present.""" try: with open(filename) as f: contents = f.read() - except EnvironmentError: + except OSError: raise NotThisMethod("unable to read _version.py") mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) @@ -1215,9 +1526,8 @@ def versions_from_file(filename): return json.loads(mo.group(1)) -def write_to_version_file(filename, versions): +def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: """Write the given version number to the given _version.py file.""" - os.unlink(filename) contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: @@ -1226,14 +1536,14 @@ def write_to_version_file(filename, versions): print("set %s to '%s'" % (filename, versions["version"])) -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -1258,23 +1568,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -1301,12 +1659,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -1323,7 +1710,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -1343,7 +1730,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -1363,7 +1750,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -1377,10 +1764,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -1399,7 +1790,7 @@ class VersioneerBadRootError(Exception): """The project root directory is unknown or missing key files.""" -def get_versions(verbose=False): +def get_versions(verbose: bool = False) -> Dict[str, Any]: """Get the project version from whatever source is available. Returns dict with two keys: 'version' and 'full'. @@ -1414,7 +1805,7 @@ def get_versions(verbose=False): assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose + verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` assert cfg.versionfile_source is not None, \ "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" @@ -1475,13 +1866,17 @@ def get_versions(verbose=False): "date": None} -def get_version(): +def get_version() -> str: """Get the short version string for this project.""" return get_versions()["version"] -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" +def get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None): + """Get the custom setuptools subclasses used by Versioneer. + + If the package uses a different cmdclass (e.g. one from numpy), it + should be provide as an argument. + """ if "versioneer" in sys.modules: del sys.modules["versioneer"] # this fixes the "python setup.py develop" case (also 'install' and @@ -1495,25 +1890,25 @@ def get_cmdclass(): # parent is protected against the child's "import versioneer". By # removing ourselves from sys.modules here, before the child build # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 + # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - cmds = {} + cmds = {} if cmdclass is None else cmdclass.copy() - # we add "version" to both distutils and setuptools - from distutils.core import Command + # we add "version" to setuptools + from setuptools import Command class cmd_version(Command): description = "report generated version string" - user_options = [] - boolean_options = [] + user_options: List[Tuple[str, str, str]] = [] + boolean_options: List[str] = [] - def initialize_options(self): + def initialize_options(self) -> None: pass - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): + def run(self) -> None: vers = get_versions(verbose=True) print("Version: %s" % vers["version"]) print(" full-revisionid: %s" % vers.get("full-revisionid")) @@ -1523,7 +1918,7 @@ def run(self): print(" error: %s" % vers["error"]) cmds["version"] = cmd_version - # we override "build_py" in both distutils and setuptools + # we override "build_py" in setuptools # # most invocation pathways end up running build_py: # distutils/build -> build_py @@ -1538,18 +1933,25 @@ def run(self): # then does setup.py bdist_wheel, or sometimes setup.py install # setup.py egg_info -> ? + # pip install -e . and setuptool/editable_wheel will invoke build_py + # but the build_py command is not expected to copy any files. + # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py + if 'build_py' in cmds: + _build_py: Any = cmds['build_py'] else: - from distutils.command.build_py import build_py as _build_py + from setuptools.command.build_py import build_py as _build_py class cmd_build_py(_build_py): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_py.run(self) + if getattr(self, "editable_mode", False): + # During editable installs `.py` and data files are + # not copied to build_lib + return # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: @@ -1559,8 +1961,40 @@ def run(self): write_to_version_file(target_versionfile, versions) cmds["build_py"] = cmd_build_py + if 'build_ext' in cmds: + _build_ext: Any = cmds['build_ext'] + else: + from setuptools.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self) -> None: + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. + # As in place builds will already have a _version.py + # in the module dir, we do not need to write one. + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if not cfg.versionfile_build: + return + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + if not os.path.exists(target_versionfile): + print(f"Warning: {target_versionfile} does not exist, skipping " + "version update. This can happen if you are running build_ext " + "without first running build_py.") + return + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext + if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe + from cx_Freeze.dist import build_exe as _build_exe # type: ignore # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1569,7 +2003,7 @@ def run(self): # ... class cmd_build_exe(_build_exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1593,12 +2027,12 @@ def run(self): if 'py2exe' in sys.modules: # py2exe enabled? try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 + from py2exe.setuptools_buildexe import py2exe as _py2exe # type: ignore except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 + from py2exe.distutils_buildexe import py2exe as _py2exe # type: ignore class cmd_py2exe(_py2exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1619,14 +2053,51 @@ def run(self): }) cmds["py2exe"] = cmd_py2exe + # sdist farms its file list building out to egg_info + if 'egg_info' in cmds: + _egg_info: Any = cmds['egg_info'] + else: + from setuptools.command.egg_info import egg_info as _egg_info + + class cmd_egg_info(_egg_info): + def find_sources(self) -> None: + # egg_info.find_sources builds the manifest list and writes it + # in one shot + super().find_sources() + + # Modify the filelist and normalize it + root = get_root() + cfg = get_config_from_root(root) + self.filelist.append('versioneer.py') + if cfg.versionfile_source: + # There are rare cases where versionfile_source might not be + # included by default, so we must be explicit + self.filelist.append(cfg.versionfile_source) + self.filelist.sort() + self.filelist.remove_duplicates() + + # The write method is hidden in the manifest_maker instance that + # generated the filelist and was thrown away + # We will instead replicate their final normalization (to unicode, + # and POSIX-style paths) + from setuptools import unicode_utils + normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') + for f in self.filelist.files] + + manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') + with open(manifest_filename, 'w') as fobj: + fobj.write('\n'.join(normalized)) + + cmds['egg_info'] = cmd_egg_info + # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist + if 'sdist' in cmds: + _sdist: Any = cmds['sdist'] else: - from distutils.command.sdist import sdist as _sdist + from setuptools.command.sdist import sdist as _sdist class cmd_sdist(_sdist): - def run(self): + def run(self) -> None: versions = get_versions() self._versioneer_generated_versions = versions # unless we update this, the command will keep using the old @@ -1634,7 +2105,7 @@ def run(self): self.distribution.metadata.version = versions["version"] return _sdist.run(self) - def make_release_tree(self, base_dir, files): + def make_release_tree(self, base_dir: str, files: List[str]) -> None: root = get_root() cfg = get_config_from_root(root) _sdist.make_release_tree(self, base_dir, files) @@ -1687,21 +2158,26 @@ def make_release_tree(self, base_dir, files): """ -INIT_PY_SNIPPET = """ +OLD_SNIPPET = """ from ._version import get_versions __version__ = get_versions()['version'] del get_versions """ +INIT_PY_SNIPPET = """ +from . import {0} +__version__ = {0}.get_versions()['version'] +""" -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" + +def do_setup() -> int: + """Do main VCS-independent setup function for installing Versioneer.""" root = get_root() try: cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, + except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): + if isinstance(e, (OSError, configparser.NoSectionError)): print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: @@ -1721,62 +2197,37 @@ def do_setup(): ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") + maybe_ipy: Optional[str] = ipy if os.path.exists(ipy): try: with open(ipy, "r") as f: old = f.read() - except EnvironmentError: + except OSError: old = "" - if INIT_PY_SNIPPET not in old: + module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] + snippet = INIT_PY_SNIPPET.format(module) + if OLD_SNIPPET in old: + print(" replacing boilerplate in %s" % ipy) + with open(ipy, "w") as f: + f.write(old.replace(OLD_SNIPPET, snippet)) + elif snippet not in old: print(" appending to %s" % ipy) with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) + f.write(snippet) else: print(" %s unmodified" % ipy) else: print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") + maybe_ipy = None # Make VCS-specific changes. For git, this means creating/changing # .gitattributes to mark _version.py for export-subst keyword # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + do_vcs_install(cfg.versionfile_source, maybe_ipy) return 0 -def scan_setup_py(): +def scan_setup_py() -> int: """Validate the contents of setup.py against Versioneer's expectations.""" found = set() setters = False @@ -1813,10 +2264,14 @@ def scan_setup_py(): return errors +def setup_command() -> NoReturn: + """Set up Versioneer and exit with appropriate error code.""" + errors = do_setup() + errors += scan_setup_py() + sys.exit(1 if errors else 0) + + if __name__ == "__main__": cmd = sys.argv[1] if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1) + setup_command()