Skip to content

Commit

Permalink
refactor: avoid duplicate token passing to hf api
Browse files Browse the repository at this point in the history
  • Loading branch information
davidberenstein1957 committed Oct 11, 2024
1 parent 28e197a commit f456d42
Showing 1 changed file with 32 additions and 36 deletions.
68 changes: 32 additions & 36 deletions argilla/src/argilla/_helpers/_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,33 @@ def deploy_on_spaces(
overwrite: Optional[Union[bool, None]] = False,
) -> "Argilla":
"""
Deploys Argilla on Hugging Face Spaces.
Args:
api_key (str): The Argilla API key to be defined for the owner user and creator of the Space.
repo_name (Optional[str]): The ID of the repository where Argilla will be deployed. Defaults to "argilla".
org_name (Optional[str]): The name of the organization where Argilla will be deployed. Defaults to None.
hf_token (Optional[Union[str, None]]): The Hugging Face authentication token. Defaults to None.
space_storage (Optional[Union[str, SpaceStorage]]): The persistant storage size for the space. Defaults to None without persistant storage.
space_hardware (Optional[Union[str, SpaceHardware]]): The hardware configuration for the space. Defaults to "cpu-basic" with downtime after 48 hours of inactivity.
private (Optional[Union[bool, None]]): Whether the space should be private. Defaults to False.
overwrite (Optional[Union[bool, None]]): Whether to overwrite the config of an existing space. Defaults to False.
Returns:
Argilla: The Argilla client.
Example:
```Python
import argilla as rg
client = rg.Argilla.deploy_on_spaces(api_key="12345678")
```
Deploys Argilla on Hugging Face Spaces.
Args:
api_key (str): The Argilla API key to be defined for the owner user and creator of the Space.
repo_name (Optional[str]): The ID of the repository where Argilla will be deployed. Defaults to "argilla".
org_name (Optional[str]): The name of the organization where Argilla will be deployed. Defaults to None.
hf_token (Optional[Union[str, None]]): The Hugging Face authentication token. Defaults to None.
space_storage (Optional[Union[str, SpaceStorage]]): The persistant storage size for the space. Defaults to None without persistant storage.
space_hardware (Optional[Union[str, SpaceHardware]]): The hardware configuration for the space. Defaults to "cpu-basic" with downtime after 48 hours of inactivity.
private (Optional[Union[bool, None]]): Whether the space should be private. Defaults to False.
overwrite (Optional[Union[bool, None]]): Whether to overwrite the config of an existing space. Defaults to False.
Returns:
Argilla: The Argilla client.
Example:
```Python
import argilla as rg
api
client = rg.Argilla.deploy_on_spaces(api_key="12345678")
```
"""
hf_token = cls._acquire_hf_token(ht_token=hf_token)
api = HfApi(token=hf_token)
hf_api = HfApi(token=hf_token)

# Get the org name from the repo name or default to the current user
token_username = api.whoami(token=hf_token)["name"]
token_username = hf_api.whoami()["name"]
org_name = org_name or token_username
repo_id = f"{org_name}/{repo_name}"

Expand All @@ -81,50 +81,46 @@ def deploy_on_spaces(
]

# Check if the space already exists
if api.repo_exists(repo_id=repo_id, repo_type="space", token=hf_token):
if cls._check_if_stage_can_be_build(api.get_space_runtime(repo_id=repo_id, token=hf_token).stage):
api.restart_space(repo_id=repo_id, token=hf_token)
if hf_api.repo_exists(repo_id=repo_id, repo_type="space"):
if cls._check_if_stage_can_be_build(hf_api.get_space_runtime(repo_id=repo_id).stage):
hf_api.restart_space(repo_id=repo_id)

if overwrite:
for secret in secrets:
api.add_space_secret(
hf_api.add_space_secret(
repo_id=repo_id,
key=secret["key"],
value=secret["value"],
description=secret["description"],
token=hf_token,
)

if space_hardware:
api.request_space_hardware(repo_id=repo_id, hardware=space_hardware, token=hf_token)
hf_api.request_space_hardware(repo_id=repo_id, hardware=space_hardware)

if space_storage:
api.request_space_storage(repo_id=repo_id, storage=space_storage, token=hf_token)
hf_api.request_space_storage(repo_id=repo_id, storage=space_storage)
else:
cls._space_storage_warning()
else:
if space_storage is None:
cls._space_storage_warning()

api.duplicate_space(
hf_api.duplicate_space(
from_id=_ARGILLA_SPACE_TEMPLATE_REPO,
to_id=repo_id,
private=private,
token=hf_token,
exist_ok=True,
hardware=space_hardware,
storage=space_storage,
secrets=secrets,
)

repo_url: RepoUrl = api.create_repo(
repo_id=repo_id, repo_type="space", token=hf_token, exist_ok=True, space_sdk="docker"
)
repo_url: RepoUrl = hf_api.create_repo(repo_id=repo_id, repo_type="space", exist_ok=True, space_sdk="docker")
api_url: str = (
f"https://{cls._sanitize_url_component(org_name)}-{cls._sanitize_url_component(repo_name)}.hf.space/"
)
cls._log_message(cls, message=f"Argilla is being deployed at: {repo_url}")
while cls._check_if_running(api.get_space_runtime(repo_id=repo_id, token=hf_token).stage):
while cls._check_if_running(hf_api.get_space_runtime(repo_id=repo_id).stage):
time.sleep(_SLEEP_TIME)
cls._log_message(cls, message=f"Deployment in progress. Waiting {_SLEEP_TIME} seconds.")

Expand Down

0 comments on commit f456d42

Please sign in to comment.