diff --git a/README.md b/README.md index eefe374..16c61ad 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ hf_oauth_scopes: CI -CI +CI @@ -80,16 +80,25 @@ pip install synthetic-dataset-generator ### Environment Variables -- `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run Inference Endpoints Requests. You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained). +- `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run *Free* Inference Endpoints Requests. You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained). + +Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables: + - `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla. - `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla. -## Quick Start +## Quickstart ```bash python app.py ``` +### Argilla integration + +Argilla is an open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/). + +![Argilla integration](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/argilla.png) + ## Custom synthetic data generation? Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps. 
diff --git a/assets/argilla.png b/assets/argilla.png new file mode 100644 index 0000000..0acf8d2 Binary files /dev/null and b/assets/argilla.png differ diff --git a/src/distilabel_dataset_generator/apps/base.py b/src/distilabel_dataset_generator/apps/base.py index 4027954..0b6cc4f 100644 --- a/src/distilabel_dataset_generator/apps/base.py +++ b/src/distilabel_dataset_generator/apps/base.py @@ -475,6 +475,27 @@ def get_success_message_row() -> gr.Markdown: def show_success_message(org_name, repo_name) -> gr.Markdown: client = get_argilla_client() + if client is None: + return gr.Markdown( + value=""" +
+

Dataset Published Successfully!

+

+ The generated dataset is in the right format for fine-tuning with TRL, AutoTrain, or other frameworks. Your dataset is now available at: + + https://huggingface.co/datasets/{org_name}/{repo_name} + +

+

+ By configuring an `ARGILLA_API_URL` and `ARGILLA_API_KEY` you can curate the dataset in Argilla. + Unfamiliar with Argilla? Here are some docs to help you get started: +
How to get started with Argilla +
How to curate data in Argilla +
How to export data once you have reviewed the dataset +

+
+ """ + ) argilla_api_url = client.api_url return gr.Markdown( value=f""" diff --git a/src/distilabel_dataset_generator/apps/eval.py b/src/distilabel_dataset_generator/apps/eval.py index f415760..6e4a60a 100644 --- a/src/distilabel_dataset_generator/apps/eval.py +++ b/src/distilabel_dataset_generator/apps/eval.py @@ -39,9 +39,9 @@ extract_column_names, get_argilla_client, get_org_dropdown, + pad_or_truncate_list, process_columns, swap_visibility, - pad_or_truncate_list, ) @@ -334,8 +334,10 @@ def push_dataset( push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private) try: progress(0.1, desc="Setting up user and workspace") - client = get_argilla_client() hf_user = HfApi().whoami(token=oauth_token.token)["name"] + client = get_argilla_client() + if client is None: + return "" if eval_type == "ultrafeedback": num_generations = len((dataframe["generations"][0])) fields = [ @@ -580,6 +582,7 @@ def push_dataset( def show_pipeline_code_visibility(): return {pipeline_code_ui: gr.Accordion(visible=True)} + def hide_pipeline_code_visibility(): return {pipeline_code_ui: gr.Accordion(visible=False)} @@ -708,15 +711,15 @@ def hide_pipeline_code_visibility(): visible=False, ) as pipeline_code_ui: code = generate_pipeline_code( - repo_id=search_in.value, - aspects=aspects_instruction_response.value, - instruction_column=instruction_instruction_response, - response_columns=response_instruction_response, - prompt_template=prompt_template.value, - structured_output=structured_output.value, - num_rows=num_rows.value, - eval_type=eval_type.value, - ) + repo_id=search_in.value, + aspects=aspects_instruction_response.value, + instruction_column=instruction_instruction_response, + response_columns=response_instruction_response, + prompt_template=prompt_template.value, + structured_output=structured_output.value, + num_rows=num_rows.value, + eval_type=eval_type.value, + ) pipeline_code = gr.Code( value=code, language="python", diff --git 
a/src/distilabel_dataset_generator/apps/sft.py b/src/distilabel_dataset_generator/apps/sft.py index 0e9986b..fad57d1 100644 --- a/src/distilabel_dataset_generator/apps/sft.py +++ b/src/distilabel_dataset_generator/apps/sft.py @@ -220,8 +220,10 @@ def push_dataset( push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private) try: progress(0.1, desc="Setting up user and workspace") - client = get_argilla_client() hf_user = HfApi().whoami(token=oauth_token.token)["name"] + client = get_argilla_client() + if client is None: + return "" if "messages" in dataframe.columns: settings = rg.Settings( fields=[ diff --git a/src/distilabel_dataset_generator/apps/textcat.py b/src/distilabel_dataset_generator/apps/textcat.py index e40464e..2666d0a 100644 --- a/src/distilabel_dataset_generator/apps/textcat.py +++ b/src/distilabel_dataset_generator/apps/textcat.py @@ -58,7 +58,10 @@ def generate_system_prompt(dataset_description, temperature, progress=gr.Progres labels = data["labels"] return system_prompt, labels -def generate_sample_dataset(system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()): + +def generate_sample_dataset( + system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress() +): dataframe = generate_dataset( system_prompt=system_prompt, difficulty=difficulty, @@ -138,11 +141,7 @@ def generate_dataset( # create final dataset distiset_results = [] for result in labeller_results: - record = { - key: result[key] - for key in ["labels", "text"] - if key in result - } + record = {key: result[key] for key in ["labels", "text"] if key in result} distiset_results.append(record) dataframe = pd.DataFrame(distiset_results) @@ -212,13 +211,16 @@ def push_dataset( push_dataset_to_hub( dataframe, org_name, repo_name, num_labels, labels, oauth_token, private ) + dataframe = dataframe[ (dataframe["text"].str.strip() != "") & (dataframe["text"].notna()) ] try: progress(0.1, desc="Setting up user and workspace") - client = 
get_argilla_client() hf_user = HfApi().whoami(token=oauth_token.token)["name"] + client = get_argilla_client() + if client is None: + return "" labels = get_preprocess_labels(labels) settings = rg.Settings( fields=[