Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify migrations #1510

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Literal,
Mapping,
Sequence,
TypeVar,
get_args,
overload,
)
Expand All @@ -45,11 +46,19 @@
)
from .subproject import Subproject
from .template import Task, Template
from .tools import OS, Style, normalize_git_path, printf, readlink
from .types import MISSING, AnyByStrDict, JSONSerializable, RelativePath, StrOrPath
from .tools import OS, Style, cast_to_bool, normalize_git_path, printf, readlink
from .types import (
MISSING,
AnyByStrDict,
JSONSerializable,
RelativePath,
StrOrPath,
)
from .user_data import DEFAULT_DATA, AnswersMap, Question
from .vcs import get_git

_T = TypeVar("_T")


@dataclass(config=ConfigDict(extra="forbid"))
class Worker:
Expand Down Expand Up @@ -195,12 +204,14 @@
return self

@overload
def __exit__(self, type: None, value: None, traceback: None) -> None: ...
def __exit__(self, type: None, value: None, traceback: None) -> None:
...

Check warning on line 208 in copier/main.py

View check run for this annotation

Codecov / codecov/patch

copier/main.py#L208

Added line #L208 was not covered by tests

@overload
def __exit__(
self, type: type[BaseException], value: BaseException, traceback: TracebackType
) -> None: ...
) -> None:
...

Check warning on line 214 in copier/main.py

View check run for this annotation

Codecov / codecov/patch

copier/main.py#L214

Added line #L214 was not covered by tests

def __exit__(
self,
Expand Down Expand Up @@ -277,13 +288,21 @@
tasks: The list of tasks to run.
"""
for i, task in enumerate(tasks):
extra_context = {f"_{k}": v for k, v in task.extra_vars.items()}

if not cast_to_bool(self._render_value(task.condition, extra_context)):
continue

task_cmd = task.cmd
if isinstance(task_cmd, str):
task_cmd = self._render_string(task_cmd)
task_cmd = self._render_string(task_cmd, extra_context)
use_shell = True
else:
task_cmd = [self._render_string(str(part)) for part in task_cmd]
task_cmd = [
self._render_string(str(part), extra_context) for part in task_cmd
]
use_shell = False

if not self.quiet:
print(
colors.info
Expand All @@ -292,7 +311,15 @@
)
if self.pretend:
continue
with local.cwd(self.subproject.local_abspath), local.env(**task.extra_env):

working_directory = (
# We can't use _render_path here, as that function has special handling for files in the template
self.subproject.local_abspath
/ Path(self._render_string(str(task.working_directory), extra_context))
).absolute()

extra_env = {k.upper(): str(v) for k, v in task.extra_vars.items()}
with local.cwd(working_directory), local.env(**extra_env):
subprocess.run(task_cmd, shell=use_shell, check=True, env=local.env)

def _render_context(self) -> Mapping[str, Any]:
Expand Down Expand Up @@ -709,15 +736,37 @@
return None
return result

def _render_string(self, string: str) -> str:
def _render_string(
self, string: str, extra_context: AnyByStrDict | None = None
) -> str:
"""Render one templated string.

Args:
string:
The template source string.

extra_context:
Additional variables to use for rendering the template.
"""
tpl = self.jinja_env.from_string(string)
return tpl.render(**self._render_context())
return tpl.render(**self._render_context(), **(extra_context or {}))

def _render_value(
self, value: _T, extra_context: AnyByStrDict | None = None
) -> str | _T:
"""Render a value, which may or may not be a templated string.

Args:
value:
The value to render.

extra_context:
Additional variables to use for rendering the template.
"""
try:
return self._render_string(value, extra_context=extra_context) # type: ignore[arg-type]
except TypeError:
return value

@cached_property
def subproject(self) -> Subproject:
Expand Down
112 changes: 87 additions & 25 deletions copier/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
UnsupportedVersionError,
)
from .tools import copier_version, handle_remove_readonly
from .types import AnyByStrDict, Env, VCSTypes
from .types import AnyByStrDict, VCSTypes
from .vcs import checkout_latest_tag, clone, get_git, get_repo

# Default list of files in the template to exclude from the rendered project
Expand Down Expand Up @@ -153,12 +153,26 @@
cmd:
Command to execute.

extra_env:
Additional environment variables to set while executing the command.
extra_vars:
Additional variables for the task.
Will be available as Jinja variables for rendering of `cmd`, `condition`
and `working_directory` and as environment variables while the task is
running.
As Jinja variables they will be prefixed by an underscore, while as
environment variables they will be upper cased.

condition:
The condition when a conditional task runs.

working_directory:
The directory from inside where to execute the task.
If `None`, the project directory will be used.
"""

cmd: str | Sequence[str]
extra_env: Env = field(default_factory=dict)
extra_vars: dict[str, Any] = field(default_factory=dict)
condition: str | bool = True
working_directory: Path = Path(".")


@dataclass
Expand Down Expand Up @@ -370,27 +384,64 @@
"""
result: list[Task] = []
if not (self.version and from_template.version):
return result
extra_env: Env = {
"STAGE": stage,
"VERSION_FROM": str(from_template.commit),
"VERSION_TO": str(self.commit),
"VERSION_PEP440_FROM": str(from_template.version),
"VERSION_PEP440_TO": str(self.version),
return []

Check warning on line 387 in copier/template.py

View check run for this annotation

Codecov / codecov/patch

copier/template.py#L387

Added line #L387 was not covered by tests
extra_vars: dict[str, Any] = {
"stage": stage,
"version_from": from_template.commit,
"version_to": self.commit,
"version_pep440_from": from_template.version,
"version_pep440_to": self.version,
}
migration: dict[str, Any]
for migration in self._raw_config.get("_migrations", []):
current = parse(migration["version"])
if self.version >= current > from_template.version:
extra_env = {
**extra_env,
"VERSION_CURRENT": migration["version"],
"VERSION_PEP440_CURRENT": str(current),
}
result.extend(
Task(cmd=cmd, extra_env=extra_env)
for cmd in migration.get(stage, [])
if any(key in migration for key in ("before", "after")):
# Legacy configuration format
warn(
"This migration configuration is deprecated. Please switch to the new format.",
category=DeprecationWarning,
)
current = parse(migration["version"])
if self.version >= current > from_template.version:
extra_vars = {
**extra_vars,
"version_current": migration["version"],
"version_pep440_current": current,
}
result.extend(
Task(cmd=cmd, extra_vars=extra_vars)
for cmd in migration.get(stage, [])
)
else:
# New configuration format
if isinstance(migration, (str, list)):
result.append(
Task(
cmd=migration,
extra_vars=extra_vars,
condition='{{ _stage == "after" }}',
)
)
else:
condition = migration.get("when", '{{ _stage == "after" }}')
working_directory = Path(migration.get("working_directory", "."))
if "version" in migration:
current = parse(migration["version"])
if not (self.version >= current > from_template.version):
continue
extra_vars = {
**extra_vars,
"version_current": migration["version"],
"version_pep440_current": current,
}
result.append(
Task(
cmd=migration["command"],
extra_vars=extra_vars,
condition=condition,
working_directory=working_directory,
)
)

return result

@cached_property
Expand Down Expand Up @@ -456,10 +507,21 @@

See [tasks][].
"""
return [
Task(cmd=cmd, extra_env={"STAGE": "task"})
for cmd in self.config_data.get("tasks", [])
]
extra_vars = {"stage": "task"}
tasks = []
for task in self.config_data.get("tasks", []):
if isinstance(task, dict):
tasks.append(
Task(
cmd=task["command"],
extra_vars=extra_vars,
condition=task.get("when", "true"),
working_directory=Path(task.get("working_directory", ".")),
)
)
else:
tasks.append(Task(cmd=task, extra_vars=extra_vars))
return tasks

@cached_property
def templates_suffix(self) -> str:
Expand Down
Loading
Loading