From 6ce908ca2c2c7ebb0be7993b536ddccb98ba590a Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Wed, 27 Dec 2023 19:47:13 +0100 Subject: [PATCH] Added CLI app framework Close #6 --- labs.yml | 19 +- pyproject.toml | 28 ++- src/databricks/labs/blueprint/__about__.py | 2 +- src/databricks/labs/blueprint/__main__.py | 120 +++++++++++++ src/databricks/labs/blueprint/cli.py | 79 +++++++++ src/databricks/labs/blueprint/entrypoint.py | 7 +- src/databricks/labs/blueprint/wheels.py | 185 +++++++++++--------- tests/unit/test_cli.py | 35 ++++ tests/unit/test_wheels.py | 33 ++-- 9 files changed, 389 insertions(+), 119 deletions(-) create mode 100644 src/databricks/labs/blueprint/__main__.py create mode 100644 src/databricks/labs/blueprint/cli.py create mode 100644 tests/unit/test_cli.py diff --git a/labs.yml b/labs.yml index 1050347..e9bde97 100644 --- a/labs.yml +++ b/labs.yml @@ -2,9 +2,22 @@ name: blueprint description: Common libraries for Databricks Labs install: - script: src/databricks/labs/ucx/install.py -entrypoint: src/databricks/labs/ucx/cli.py + script: src/databricks/labs/blueprint/__init__.py +entrypoint: src/databricks/labs/blueprint/__main__.py min_python: 3.10 commands: - - name: init + - name: me + description: shows current username + flags: + - name: greeting + default: Hello + description: Greeting prefix + - name: workspaces + is_account: true + description: shows current workspaces + - name: init-project + is_unauthenticated: true description: initializes new project + flags: + - name: target + description: target folder diff --git a/pyproject.toml b/pyproject.toml index 8c7f0de..fec7806 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,3 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build] -sources = ["src"] -include = ["src"] - [project] name = "databricks-labs-blueprint" dynamic = ["version"] @@ -14,9 +6,6 @@ readme = "README.md" license-files = { paths = ["LICENSE", "NOTICE"] } requires-python = ">=3.10.6" # latest available in DBR 13.2 keywords = ["Databricks"] -authors = [ - { name = "Serge Smertin", email = "serge.smertin@databricks.com" }, -] classifiers = [ "Development Status :: 3 - Alpha", "License :: Other/Proprietary License", @@ -25,11 +14,19 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", ] -dependencies = ["databricks-sdk~=0.16.0"] +dependencies = ["databricks-sdk"] [project.urls] -Issues = "https://github.com/databricks/blueprint/issues" -Source = "https://github.com/databricks/blueprint" +Issues = "https://github.com/databrickslabs/blueprint/issues" +Source = "https://github.com/databrickslabs/blueprint" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +sources = ["src"] +include = ["src"] [tool.hatch.version] path = "src/databricks/labs/blueprint/__about__.py" @@ -68,7 +65,6 @@ verify = ["black --check .", "mypy ."] [tool.isort] -skip_glob = ["notebooks/*.py"] profile = "black" [tool.pytest.ini_options] @@ -93,7 +89,7 @@ branch = true parallel = true [tool.coverage.report] -omit = ["*/working-copy/*", "*/fresh_wheel_file/*"] +omit = ["*/working-copy/*", 'src/databricks/labs/blueprint/__main__.py'] exclude_lines = [ "no cov", "if __name__ == .__main__.:", diff --git a/src/databricks/labs/blueprint/__about__.py b/src/databricks/labs/blueprint/__about__.py index 5664bed..4c8f1c7 100644 --- a/src/databricks/labs/blueprint/__about__.py +++ b/src/databricks/labs/blueprint/__about__.py @@ -1,2 +1,2 @@ -# DO NOT MODIFY THIS FILE +# DO NOT MODIFY THIS FILE BY HAND __version__ = "0.0.1" diff --git a/src/databricks/labs/blueprint/__main__.py b/src/databricks/labs/blueprint/__main__.py new file mode 100644 index 0000000..b43f691 --- /dev/null +++ b/src/databricks/labs/blueprint/__main__.py @@ -0,0 +1,120 @@ +from pathlib import Path + +from databricks.labs.blueprint.cli import App +from databricks.labs.blueprint.entrypoint import ( + find_project_root, + get_logger, + relative_paths, +) +from databricks.labs.blueprint.tui import Prompts + +blueprint = App(__file__) +logger = get_logger(__file__) + +main_py_file = '''from databricks.sdk import AccountClient, WorkspaceClient +from databricks.labs.blueprint.entrypoint import get_logger +from databricks.labs.blueprint.cli import App + +__app__ = App(__file__) +logger = get_logger(__file__) + + +@__app__.command +def me(w: WorkspaceClient, greeting: str): + """Shows current username""" + logger.info(f"{greeting}, {w.current_user.me().user_name}!") + + +@__app__.command(is_account=True) +def workspaces(a: AccountClient): + """Shows workspaces""" + for ws in a.workspaces.list(): + logger.info(f"Workspace: {ws.workspace_name} ({ws.workspace_id})") + + +if "__main__" == __name__: + __app__() + +''' + +labs_yml_file = """--- +name: __app__ +description: Common libraries for Databricks Labs +install: + script: src/databricks/labs/__app__/__init__.py +entrypoint: src/databricks/labs/__app__/__main__.py +min_python: 3.10 + - name: me + description: shows current username + flags: + - name: greeting + default: Hello + description: Greeting prefix + - name: workspaces + is_account: true + description: shows current workspaces +""" + + +@blueprint.command(is_unauthenticated=True) +def init_project(target): + """Creates the required boilerplate structure""" + prompts = Prompts() + + project_root = find_project_root() + target_folder = Path(target) + + project_name = prompts.question("Name of the project", default=target_folder.name) + src_dir, dst_dir = relative_paths(project_root, target_folder.absolute()) + + ignore_names = { + ".git", + ".venv", + ".databricks", + ".mypy_cache", + ".idea", + ".coverage", + "htmlcov", + "__pycache__", + "tests", + ".databricks-login.json", + "coverage.xml", + "dist", + } + queue: list[Path] = [src_dir] # type: ignore[annotation-unchecked] + while queue: + current = queue.pop(0) + if current.name in ignore_names: + continue + if current.is_file(): + relative_file_name = current.as_posix().replace("blueprint", project_name) + dst_file = dst_dir / relative_file_name + dst_file.parent.mkdir(exist_ok=True, parents=True) + with current.open("r") as r, dst_file.open("w") as w: + content = r.read().replace("blueprint", project_name) + content = content.replace("databricks-sdk", "databricks-labs-blueprint") + w.write(content) + continue + virtual_env_marker = current / "pyvenv.cfg" + if virtual_env_marker.exists(): + continue + for file in current.iterdir(): + if file.as_posix() == "src/databricks/labs/blueprint": + continue + queue.append(file) + inner_package_dir = dst_dir / "src" / "databricks" / "labs" / project_name + inner_package_dir.mkdir(parents=True, exist_ok=True) + with (inner_package_dir / "__main__.py").open("w") as f: + f.write(main_py_file.replace("__app__", project_name)) + with (inner_package_dir / "__init__.py").open("w") as f: + f.write(f"from databricks.labs.{project_name}.__about__ import __version__") + with (inner_package_dir / "__about__.py").open("w") as f: + f.write('# DO NOT MODIFY THIS FILE BY HAND\n__version__ = "0.0.1"\n') + with (dst_dir / "labs.yml").open("w") as f: + f.write(labs_yml_file.replace("__app__", project_name)) + with (dst_dir / "CODEOWNERS").open("w") as f: + f.write(f"* @nfx\n/src @databrickslabs/{project_name}-write\n/tests @databrickslabs/{project_name}-write\n") + + +if "__main__" == __name__: + blueprint() diff --git a/src/databricks/labs/blueprint/cli.py b/src/databricks/labs/blueprint/cli.py new file mode 100644 index 0000000..27bffa9 --- /dev/null +++ b/src/databricks/labs/blueprint/cli.py @@ -0,0 +1,79 @@ +import json +import logging +from dataclasses import dataclass +from typing import Callable + +from databricks.sdk import AccountClient, WorkspaceClient + +from databricks.labs.blueprint.entrypoint import get_logger, run_main +from databricks.labs.blueprint.wheels import ProductInfo + + +@dataclass +class Command: + name: str + description: str + fn: Callable[..., None] + is_account: bool = False + is_unauthenticated: bool = False + + def needs_workspace_client(self): + if self.is_unauthenticated: + return False + if self.is_account: + return False + return True + + +class App: + def __init__(self, __file: str): + self._mapping: dict[str, Command] = {} + self._logger = get_logger(__file) + self._product_info = ProductInfo() + + def command(self, is_account: bool = False, is_unauthenticated: bool = False): + def decorator(func): + command_name = func.__name__.replace("_", "-") + if not func.__doc__: + raise SyntaxError(f"{func.__name__} must have some doc comment") + self._mapping[command_name] = Command( + name=command_name, + description=func.__doc__, + fn=func, + is_account=is_account, + is_unauthenticated=is_unauthenticated, + ) + return func + + return decorator + + def _route(self, raw): + payload = json.loads(raw) + command = payload["command"] + if command not in self._mapping: + msg = f"cannot find command: {command}" + raise KeyError(msg) + flags = payload["flags"] + log_level = flags.pop("log_level") + if log_level == "disabled": + log_level = "info" + databricks_logger = logging.getLogger("databricks") + databricks_logger.setLevel(log_level.upper()) + kwargs = {k.replace("-", "_"): v for k, v in flags.items()} + try: + product_name = self._product_info.product_name() + product_version = self._product_info.version() + if self._mapping[command].needs_workspace_client(): + kwargs["w"] = WorkspaceClient(product=product_name, product_version=product_version) + elif self._mapping[command].is_account: + kwargs["a"] = AccountClient(product=product_name, product_version=product_version) + self._mapping[command].fn(**kwargs) + except Exception as err: + logger = self._logger.getChild(command) + if log_level.lower() in ("debug", "trace"): + logger.error(f"Failed to call {command}", exc_info=err) + else: + logger.error(f"{err.__class__.__name__}: {err}") + + def __call__(self): + run_main(self._route) diff --git a/src/databricks/labs/blueprint/entrypoint.py b/src/databricks/labs/blueprint/entrypoint.py index 19208cb..1d4d33a 100644 --- a/src/databricks/labs/blueprint/entrypoint.py +++ b/src/databricks/labs/blueprint/entrypoint.py @@ -16,8 +16,11 @@ def get_logger(file_name: str): entrypoint = Path(file_name).absolute() relative = entrypoint.relative_to(project_root).as_posix() - relative = relative.lstrip("src" + os.sep) - relative = relative.rstrip(".py") + relative = relative.removeprefix("src" + os.sep) + relative = relative.removesuffix("/__main__.py") + relative = relative.removesuffix("/__init__.py") + relative = relative.removesuffix("/cli.py") + relative = relative.removesuffix(".py") module_name = relative.replace(os.sep, ".") logger = logging.getLogger(module_name) diff --git a/src/databricks/labs/blueprint/wheels.py b/src/databricks/labs/blueprint/wheels.py index a29fd64..782f7ed 100644 --- a/src/databricks/labs/blueprint/wheels.py +++ b/src/databricks/labs/blueprint/wheels.py @@ -18,46 +18,34 @@ logger = logging.getLogger(__name__) -_IGNORE_DIR_NAMES = {".git", ".venv", ".databricks", ".mypy_cache", ".github", ".idea", ".coverage", "htmlcov"} - - -class Wheels(AbstractContextManager): - """Wheel builder""" - - __version: str | None = None - +IGNORE_DIR_NAMES = { + ".git", + ".venv", + ".databricks", + ".mypy_cache", + ".github", + ".idea", + ".coverage", + "htmlcov", + "__pycache__", + "tests", +} + + +class ProductInfo: def __init__( self, - ws: WorkspaceClient, - install_state: InstallState, *, - github_org: str = "databrickslabs", - verbose: bool = False, version_file_name: str = "__about__.py", project_root_finder: Callable[[], Path] | None = None, + github_org: str = "databrickslabs", ): if not project_root_finder: project_root_finder = find_project_root - self._ws = ws - self._install_state = install_state self._github_org = github_org - self._verbose = verbose self._version_file_name = version_file_name self._project_root_finder = project_root_finder - def is_git_checkout(self) -> bool: - project_root = self._project_root_finder() - git_config = project_root / ".git" / "config" - return git_config.exists() - - def is_unreleased_version(self) -> bool: - return "+" in self.version() - - def released_version(self) -> str: - project_root = self._project_root_finder() - version_file = self._find_version_file(project_root, [self._version_file_name]) - return self._read_version(version_file) - def version(self): """Returns current version of the project""" if hasattr(self, "__version"): @@ -66,49 +54,53 @@ def version(self): # normal install, downloaded releases won't have the .git folder self.__version = self.released_version() return self.__version + self.__version = self.unreleased_version() + return self.__version + + def product_name(self) -> str: + project_root = self._project_root_finder() + version_file = self.version_file_in(project_root) + version_file_folder = version_file.parent + return version_file_folder.name.replace("_", "-") + + def released_version(self) -> str: + project_root = self._project_root_finder() + version_file = self.version_file_in(project_root) + return self._read_version(version_file) + + def is_git_checkout(self) -> bool: + project_root = self._project_root_finder() + git_config = project_root / ".git" / "config" + return git_config.exists() + + def is_unreleased_version(self) -> bool: + return "+" in self.version() + + def unreleased_version(self) -> str: try: - self.__version = self._pep0440_version_from_git() - return self.__version + out = subprocess.run(["git", "describe", "--tags"], stdout=subprocess.PIPE, check=True) # noqa S607 + git_detached_version = out.stdout.decode("utf8") + dv = SemVer.parse(git_detached_version) + datestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + # new commits on main branch since the last tag + new_commits = dv.pre_release.split("-")[0] if dv.pre_release else None + # show that it's a version different from the released one in stats + bump_patch = dv.patch + 1 + # create something that is both https://semver.org and https://peps.python.org/pep-0440/ + semver_and_pep0440 = f"{dv.major}.{dv.minor}.{bump_patch}+{new_commits}{datestamp}" + # validate the semver + SemVer.parse(semver_and_pep0440) + return semver_and_pep0440 except subprocess.CalledProcessError as err: - logger.error( + logger.warning( "Cannot determine unreleased version. This can be fixed by adding " " `git fetch --prune --unshallow` to your CI configuration.", exc_info=err, ) - self.__version = self.released_version() - return self.__version + return self.released_version() - @staticmethod - def _pep0440_version_from_git(): - out = subprocess.run(["git", "describe", "--tags"], stdout=subprocess.PIPE, check=True) # noqa S607 - git_detached_version = out.stdout.decode("utf8") - dv = SemVer.parse(git_detached_version) - datestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - # new commits on main branch since the last tag - new_commits = dv.pre_release.split("-")[0] if dv.pre_release else None - # show that it's a version different from the released one in stats - bump_patch = dv.patch + 1 - # create something that is both https://semver.org and https://peps.python.org/pep-0440/ - semver_and_pep0440 = f"{dv.major}.{dv.minor}.{bump_patch}+{new_commits}{datestamp}" - # validate the semver - SemVer.parse(semver_and_pep0440) - return semver_and_pep0440 - - def upload_to_dbfs(self) -> str: - with self._local_wheel.open("rb") as f: - self._ws.dbfs.mkdirs(self._remote_dir_name) - logger.info(f"Uploading wheel to dbfs:{self._remote_wheel}") - self._ws.dbfs.upload(self._remote_wheel, f, overwrite=True) - return self._remote_wheel - - def upload_to_wsfs(self) -> str: - with self._local_wheel.open("rb") as f: - self._ws.workspace.mkdirs(self._remote_dir_name) - logger.info(f"Uploading wheel to /Workspace{self._remote_wheel}") - self._ws.workspace.upload(self._remote_wheel, f, overwrite=True, format=ImportFormat.AUTO) - return self._remote_wheel - - def _find_version_file(self, root: Path, names: list[str]) -> Path: + def version_file_in(self, root: Path) -> Path: + names = [self._version_file_name] queue: list[Path] = [root] while queue: current = queue.pop(0) @@ -132,6 +124,43 @@ def _read_version(version_file: Path) -> str: raise SyntaxError("Cannot find __version__") return version_data["__version__"] + +class Wheels(AbstractContextManager): + """Wheel builder""" + + __version: str | None = None + + def __init__( + self, + ws: WorkspaceClient, + install_state: InstallState, + product_info: ProductInfo, + *, + verbose: bool = False, + project_root_finder: Callable[[], Path] | None = None, + ): + if not project_root_finder: + project_root_finder = find_project_root + self._ws = ws + self._install_state = install_state + self._product_info = product_info + self._verbose = verbose + self._project_root_finder = project_root_finder + + def upload_to_dbfs(self) -> str: + with self._local_wheel.open("rb") as f: + self._ws.dbfs.mkdirs(self._remote_dir_name) + logger.info(f"Uploading wheel to dbfs:{self._remote_wheel}") + self._ws.dbfs.upload(self._remote_wheel, f, overwrite=True) + return self._remote_wheel + + def upload_to_wsfs(self) -> str: + with self._local_wheel.open("rb") as f: + self._ws.workspace.mkdirs(self._remote_dir_name) + logger.info(f"Uploading wheel to /Workspace{self._remote_wheel}") + self._ws.workspace.upload(self._remote_wheel, f, overwrite=True, format=ImportFormat.AUTO) + return self._remote_wheel + def __enter__(self) -> "Wheels": self._tmp_dir = tempfile.TemporaryDirectory() self._local_wheel = self._build_wheel(self._tmp_dir.name, verbose=self._verbose) @@ -156,7 +185,7 @@ def _build_wheel(self, tmp_dir: str, *, verbose: bool = False): stdout = subprocess.DEVNULL stderr = subprocess.DEVNULL project_root = self._project_root_finder() - if self.is_git_checkout() and self.is_unreleased_version(): + if self._product_info.is_git_checkout() and self._product_info.is_unreleased_version(): # working copy becomes project root for building a wheel project_root = self._copy_root_to(tmp_dir) # and override the version file @@ -172,23 +201,23 @@ def _build_wheel(self, tmp_dir: str, *, verbose: bool = False): return next(Path(tmp_dir).glob("*.whl")) def _override_version_to_unreleased(self, tmp_dir_path: Path): - version_file = self._find_version_file(tmp_dir_path, [self._version_file_name]) + version_file = self._product_info.version_file_in(tmp_dir_path) with version_file.open("w") as f: - f.write(f'__version__ = "{self.version()}"') + f.write(f'__version__ = "{self._product_info.version()}"') def _copy_root_to(self, tmp_dir: str | Path): project_root = self._project_root_finder() tmp_dir_path = Path(tmp_dir) / "working-copy" + # copy everything to a temporary directory - shutil.copytree(project_root, tmp_dir_path, ignore=self._copy_ignore) - return tmp_dir_path + def copy_ignore(_, names: list[str]): + # callable(src, names) -> ignored_names + ignored_names = [] + for name in names: + if name not in IGNORE_DIR_NAMES: + continue + ignored_names.append(name) + return ignored_names - @staticmethod - def _copy_ignore(_, names: list[str]): - # callable(src, names) -> ignored_names - ignored_names = [] - for name in names: - if name not in _IGNORE_DIR_NAMES: - continue - ignored_names.append(name) - return ignored_names + shutil.copytree(project_root, tmp_dir_path, ignore=copy_ignore) + return tmp_dir_path diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py new file mode 100644 index 0000000..8fa7e60 --- /dev/null +++ b/tests/unit/test_cli.py @@ -0,0 +1,35 @@ +import json +import sys +from unittest import mock + +from databricks.labs.blueprint.cli import App + + +def test_commands(): + some = mock.Mock() + app = App(__file__) + + @app.command(is_unauthenticated=True) + def foo(name: str): + """Some comment""" + some(name) + + with mock.patch.object( + sys, + "argv", + [ + ..., + json.dumps( + { + "command": "foo", + "flags": { + "name": "y", + "log_level": "disabled", + }, + } + ), + ], + ): + app() + + some.assert_called_with("y") diff --git a/tests/unit/test_wheels.py b/tests/unit/test_wheels.py index 7268b0a..58cad1b 100644 --- a/tests/unit/test_wheels.py +++ b/tests/unit/test_wheels.py @@ -8,7 +8,7 @@ from databricks.labs.blueprint.__about__ import __version__ from databricks.labs.blueprint.entrypoint import is_in_debug from databricks.labs.blueprint.installer import InstallState -from databricks.labs.blueprint.wheels import Wheels +from databricks.labs.blueprint.wheels import ProductInfo, Wheels def test_build_and_upload_wheel(): @@ -16,7 +16,9 @@ def test_build_and_upload_wheel(): state = create_autospec(InstallState) state.product.return_value = "blueprint" state.install_folder.return_value = "~/.blueprint" - wheels = Wheels(ws, state) + product_info = ProductInfo() + + wheels = Wheels(ws, state, product_info) with wheels: assert os.path.exists(wheels._local_wheel) @@ -40,27 +42,20 @@ def test_build_and_upload_wheel(): def test_unreleased_version(tmp_path): if not is_in_debug(): pytest.skip("fails without `git fetch --prune --unshallow` configured") - ws = create_autospec(WorkspaceClient) - state = create_autospec(InstallState) - state.product.return_value = "blueprint" - state.install_folder.return_value = "~/.blueprint" - - wheels = Wheels(ws, state) - assert not __version__ == wheels.version() - assert __version__ == wheels.released_version() - assert wheels.is_unreleased_version() - assert wheels.is_git_checkout() + product_info = ProductInfo() + assert not __version__ == product_info.version() + assert __version__ == product_info.released_version() + assert product_info.is_unreleased_version() + assert product_info.is_git_checkout() def test_released_version(tmp_path): ws = create_autospec(WorkspaceClient) state = create_autospec(InstallState) - state.product.return_value = "blueprint" - state.install_folder.return_value = "~/.blueprint" - working_copy = Wheels(ws, state)._copy_root_to(tmp_path) - wheels = Wheels(ws, state, project_root_finder=lambda: working_copy) + working_copy = Wheels(ws, state, ProductInfo())._copy_root_to(tmp_path) + product_info = ProductInfo(project_root_finder=lambda: working_copy) - assert __version__ == wheels.version() - assert not wheels.is_unreleased_version() - assert not wheels.is_git_checkout() + assert __version__ == product_info.version() + assert not product_info.is_unreleased_version() + assert not product_info.is_git_checkout()