From e6c5dc1b6b080c8668524e199f3673cf506e4b51 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Thu, 8 Feb 2024 23:11:07 -0600 Subject: [PATCH] WIP, add pre/post copy and pre/post update tasks --- copier/main.py | 20 ++++++++ copier/template.py | 44 +++++++++++++++++ tests/test_pre_copy.py | 109 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+) create mode 100644 tests/test_pre_copy.py diff --git a/copier/main.py b/copier/main.py index 3d90e853f..f8897e800 100644 --- a/copier/main.py +++ b/copier/main.py @@ -220,11 +220,27 @@ def _check_unsafe(self, mode: Literal["copy", "update"]) -> None: features.add("jinja_extensions") if self.template.tasks: features.add("tasks") + if self.template.pre_copy: + features.add("pre_copy") + if self.template.pre_update: + features.add("pre_update") + if self.template.post_copy: + features.add("post_copy") + if self.template.post_update: + features.add("post_update") if mode == "update" and self.subproject.template: if self.subproject.template.jinja_extensions: features.add("jinja_extensions") if self.subproject.template.tasks: features.add("tasks") + if self.subproject.template.pre_copy: + features.add("pre_copy") + if self.subproject.template.pre_update: + features.add("pre_update") + if self.subproject.template.post_copy: + features.add("post_copy") + if self.subproject.template.post_update: + features.add("post_update") for stage in get_args(Literal["before", "after"]): if self.template.migration_tasks(stage, self.subproject.template): features.add("migrations") @@ -746,6 +762,7 @@ def run_copy(self) -> None: was_existing = self.subproject.local_abspath.exists() src_abspath = self.template_copy_root try: + self._execute_tasks(self.template.pre_copy) if not self.quiet: # TODO Unify printing tools print( @@ -757,6 +774,7 @@ def run_copy(self) -> None: # TODO Unify printing tools print("") # padding space self._execute_tasks(self.template.tasks) + self._execute_tasks(self.template.post_copy) except Exception: if not was_existing and self.cleanup_on_error: rmtree(self.subproject.local_abspath) @@ -818,12 +836,14 @@ def run_update(self) -> None: # asking for confirmation raise UserMessageError("Enable overwrite to update a subproject.") self._print_message(self.template.message_before_update) + self._execute_tasks(self.template.pre_update) if not self.quiet: # TODO Unify printing tools print( f"Updating to template version {self.template.version}", file=sys.stderr ) self._apply_update() + self._execute_tasks(self.template.post_update) self._print_message(self.template.message_after_update) def _apply_update(self): diff --git a/copier/template.py b/copier/template.py index dacd7a318..dabbc9fdb 100644 --- a/copier/template.py +++ b/copier/template.py @@ -453,6 +453,50 @@ def tasks(self) -> Sequence[Task]: for cmd in self.config_data.get("tasks", []) ] + @cached_property + def pre_copy(self) -> Sequence[Task]: + """Get pre-copy tasks defined in the template. + + See [pre_copy][]. + """ + return [ + Task(cmd=cmd, extra_env={"STAGE": "pre_copy"}) + for cmd in self.config_data.get("pre_copy", []) + ] + + @cached_property + def post_copy(self) -> Sequence[Task]: + """Get post-copy tasks defined in the template. + + See [post_copy][]. + """ + return [ + Task(cmd=cmd, extra_env={"STAGE": "post_copy"}) + for cmd in self.config_data.get("post_copy", []) + ] + + @cached_property + def pre_update(self) -> Sequence[Task]: + """Get pre-update tasks defined in the template. + + See [pre_update][]. + """ + return [ + Task(cmd=cmd, extra_env={"STAGE": "pre_update"}) + for cmd in self.config_data.get("pre_update", []) + ] + + @cached_property + def post_update(self) -> Sequence[Task]: + """Get post-update tasks defined in the template. + + See [post_update][]. + """ + return [ + Task(cmd=cmd, extra_env={"STAGE": "post_update"}) + for cmd in self.config_data.get("post_update", []) + ] + @cached_property def templates_suffix(self) -> str: """Get the suffix defined for templates. diff --git a/tests/test_pre_copy.py b/tests/test_pre_copy.py new file mode 100644 index 000000000..a68226073 --- /dev/null +++ b/tests/test_pre_copy.py @@ -0,0 +1,109 @@ +from pathlib import Path +from typing import Literal, Optional + +import pytest + +import copier + +from .helpers import BRACKET_ENVOPS_JSON, SUFFIX_TMPL, build_file_tree + + +@pytest.fixture(scope="module") +def template_path(tmp_path_factory: pytest.TempPathFactory) -> str: + root = tmp_path_factory.mktemp("demo_pre_copy") + build_file_tree( + { + (root / "copier.yaml"): ( + f"""\ + _templates_suffix: {SUFFIX_TMPL} + _envops: {BRACKET_ENVOPS_JSON} + + other_file: bye + + # This tests two things: + # 1. That the tasks are being executed in the destination folder; and + # 2. That the tasks are being executed in order, one after another + _pre_copy: + - mkdir hello + - cd hello && touch world + - touch [[ other_file ]] + - ["[[ _copier_python ]]", "-c", "open('pyfile', 'w').close()"] + """ + ) + } + ) + return str(root) + + +def test_render_tasks(template_path: str, tmp_path: Path) -> None: + copier.run_copy(template_path, tmp_path, data={"other_file": "custom"}, unsafe=True) + assert (tmp_path / "custom").is_file() + + +def test_copy_tasks(template_path: str, tmp_path: Path) -> None: + copier.run_copy( + template_path, tmp_path, quiet=True, defaults=True, overwrite=True, unsafe=True + ) + assert (tmp_path / "hello").exists() + assert (tmp_path / "hello").is_dir() + assert (tmp_path / "hello" / "world").exists() + assert (tmp_path / "bye").is_file() + assert (tmp_path / "pyfile").is_file() + + +def test_pretend_mode(tmp_path_factory: pytest.TempPathFactory) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + (src / "copier.yml"): ( + """ + _pre_copy: + - touch created-by-pre-copy.txt + """ + ) + } + ) + copier.run_copy(str(src), dst, pretend=True, unsafe=True) + assert not (dst / "created-by-pre-copy.txt").exists() + + +@pytest.mark.parametrize( + "os, filename", + [ + ("linux", "linux.txt"), + ("macos", "macos.txt"), + ("windows", "windows.txt"), + (None, "unsupported.txt"), + ], +) +def test_os_specific_tasks( + tmp_path_factory: pytest.TempPathFactory, + monkeypatch: pytest.MonkeyPatch, + os: Optional[Literal["linux", "macos", "windows"]], + filename: str, +) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + (src / "copier.yml"): ( + """\ + _pre_copy: + - >- + {% if _copier_conf.os == 'linux' %} + touch linux.txt + {% elif _copier_conf.os == 'macos' %} + touch macos.txt + {% elif _copier_conf.os == 'windows' %} + touch windows.txt + {% elif _copier_conf.os is none %} + touch unsupported.txt + {% else %} + touch never.txt + {% endif %} + """ + ) + } + ) + monkeypatch.setattr("copier.main.OS", os) + copier.run_copy(str(src), dst, unsafe=True) + assert (dst / filename).exists()