From 2bae26c0d4c33a1d81c077e92cb1002b22410ac7 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 13:02:41 +0100 Subject: [PATCH 01/11] refactor unused imports --- src/distilabel_dataset_generator/__init__.py | 5 ++--- src/distilabel_dataset_generator/apps/eval.py | 21 ++++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/distilabel_dataset_generator/__init__.py b/src/distilabel_dataset_generator/__init__.py index 1c9126c..582b6d5 100644 --- a/src/distilabel_dataset_generator/__init__.py +++ b/src/distilabel_dataset_generator/__init__.py @@ -1,7 +1,6 @@ import os import warnings -from pathlib import Path -from typing import Optional, Union +from typing import Optional import argilla as rg import distilabel @@ -10,7 +9,7 @@ DistilabelDatasetCard, size_categories_parser, ) -from huggingface_hub import DatasetCardData, HfApi, upload_file +from huggingface_hub import DatasetCardData, HfApi HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)] HF_TOKENS = [token for token in HF_TOKENS if token] diff --git a/src/distilabel_dataset_generator/apps/eval.py b/src/distilabel_dataset_generator/apps/eval.py index f415760..a20cade 100644 --- a/src/distilabel_dataset_generator/apps/eval.py +++ b/src/distilabel_dataset_generator/apps/eval.py @@ -39,9 +39,9 @@ extract_column_names, get_argilla_client, get_org_dropdown, + pad_or_truncate_list, process_columns, swap_visibility, - pad_or_truncate_list, ) @@ -580,6 +580,7 @@ def push_dataset( def show_pipeline_code_visibility(): return {pipeline_code_ui: gr.Accordion(visible=True)} + def hide_pipeline_code_visibility(): return {pipeline_code_ui: gr.Accordion(visible=False)} @@ -708,15 +709,15 @@ def hide_pipeline_code_visibility(): visible=False, ) as pipeline_code_ui: code = generate_pipeline_code( - repo_id=search_in.value, - aspects=aspects_instruction_response.value, - instruction_column=instruction_instruction_response, - 
response_columns=response_instruction_response, - prompt_template=prompt_template.value, - structured_output=structured_output.value, - num_rows=num_rows.value, - eval_type=eval_type.value, - ) + repo_id=search_in.value, + aspects=aspects_instruction_response.value, + instruction_column=instruction_instruction_response, + response_columns=response_instruction_response, + prompt_template=prompt_template.value, + structured_output=structured_output.value, + num_rows=num_rows.value, + eval_type=eval_type.value, + ) pipeline_code = gr.Code( value=code, language="python", From cd47483ac70f13aa2fc3d2583bc0d8ae50f37763 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 14:00:18 +0100 Subject: [PATCH 02/11] add support for custom BASE_URL, MODEL, APIKEY --- README.md | 8 +- app.py | 10 +- pyproject.toml | 16 +- src/distilabel_dataset_generator/__init__.py | 26 -- .../apps/__init__.py | 0 src/distilabel_dataset_generator/apps/base.py | 2 +- src/distilabel_dataset_generator/apps/eval.py | 12 +- src/distilabel_dataset_generator/apps/sft.py | 333 +++++++++--------- .../apps/textcat.py | 4 +- src/distilabel_dataset_generator/constants.py | 55 +++ .../pipelines/__init__.py | 0 .../pipelines/base.py | 6 +- .../pipelines/embeddings.py | 2 +- .../pipelines/eval.py | 29 +- .../pipelines/sft.py | 21 +- .../pipelines/textcat.py | 27 +- src/distilabel_dataset_generator/utils.py | 2 +- 17 files changed, 302 insertions(+), 251 deletions(-) create mode 100644 src/distilabel_dataset_generator/apps/__init__.py create mode 100644 src/distilabel_dataset_generator/constants.py create mode 100644 src/distilabel_dataset_generator/pipelines/__init__.py diff --git a/README.md b/README.md index 16c61ad..0637301 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,13 @@ pip install synthetic-dataset-generator ### Environment Variables -- `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run *Free* Inference Endpoints Requests. 
You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained). +- `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. + +Optionally, you can set the following environment variables to customize the generation process. + +- `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`. +- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`. +- `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`. Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables: diff --git a/app.py b/app.py index 04b9409..53ec94f 100644 --- a/app.py +++ b/app.py @@ -1,8 +1,8 @@ -from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface -from src.distilabel_dataset_generator.apps.eval import app as eval_app -from src.distilabel_dataset_generator.apps.faq import app as faq_app -from src.distilabel_dataset_generator.apps.sft import app as sft_app -from src.distilabel_dataset_generator.apps.textcat import app as textcat_app +from distilabel_dataset_generator._tabbedinterface import TabbedInterface +from distilabel_dataset_generator.apps.eval import app as eval_app +from distilabel_dataset_generator.apps.faq import app as faq_app +from distilabel_dataset_generator.apps.sft import app as sft_app +from distilabel_dataset_generator.apps.textcat import app as textcat_app theme = "argilla/argilla-theme" diff --git a/pyproject.toml b/pyproject.toml index 47c3a51..ddf19ae 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -5,6 +5,18 @@ description = "Build datasets using natural language" authors = [ {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"}, ] +tags = [ + "gradio", + "synthetic-data", + "huggingface", + "argilla", + "generative-ai", + "ai", +] +requires-python = "<3.13,>=3.10" +readme = "README.md" +license = {text = "Apache 2"} + dependencies = [ "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1", "gradio[oauth]<5.0.0", @@ -14,14 +26,10 @@ dependencies = [ "gradio-huggingfacehub-search>=0.0.7", "argilla>=2.4.0", ] -requires-python = "<3.13,>=3.10" -readme = "README.md" -license = {text = "apache 2"} [build-system] requires = ["pdm-backend"] build-backend = "pdm.backend" - [tool.pdm] distribution = true diff --git a/src/distilabel_dataset_generator/__init__.py b/src/distilabel_dataset_generator/__init__.py index 582b6d5..a0ec0dc 100644 --- a/src/distilabel_dataset_generator/__init__.py +++ b/src/distilabel_dataset_generator/__init__.py @@ -1,8 +1,5 @@ -import os -import warnings from typing import Optional -import argilla as rg import distilabel import distilabel.distiset from distilabel.utils.card.dataset_card import ( @@ -11,29 +8,6 @@ ) from huggingface_hub import DatasetCardData, HfApi -HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)] -HF_TOKENS = [token for token in HF_TOKENS if token] - -if len(HF_TOKENS) == 0: - raise ValueError( - "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints." 
- ) - -ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") -ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") -if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: - ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER") - ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER") - -if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: - warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set") - argilla_client = None -else: - argilla_client = rg.Argilla( - api_url=ARGILLA_API_URL, - api_key=ARGILLA_API_KEY, - ) - class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset): def _generate_card( diff --git a/src/distilabel_dataset_generator/apps/__init__.py b/src/distilabel_dataset_generator/apps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/distilabel_dataset_generator/apps/base.py b/src/distilabel_dataset_generator/apps/base.py index 0b6cc4f..895afbc 100644 --- a/src/distilabel_dataset_generator/apps/base.py +++ b/src/distilabel_dataset_generator/apps/base.py @@ -10,7 +10,7 @@ from gradio import OAuthToken from huggingface_hub import HfApi, upload_file -from src.distilabel_dataset_generator.utils import ( +from distilabel_dataset_generator.utils import ( _LOGGED_OUT_CSS, get_argilla_client, get_login_button, diff --git a/src/distilabel_dataset_generator/apps/eval.py b/src/distilabel_dataset_generator/apps/eval.py index 6e4a60a..1136fe1 100644 --- a/src/distilabel_dataset_generator/apps/eval.py +++ b/src/distilabel_dataset_generator/apps/eval.py @@ -16,25 +16,23 @@ from gradio_huggingfacehub_search import HuggingfaceHubSearch from huggingface_hub import HfApi -from src.distilabel_dataset_generator.apps.base import ( +from distilabel_dataset_generator.apps.base import ( hide_success_message, show_success_message, validate_argilla_user_workspace_dataset, validate_push_to_hub, ) -from src.distilabel_dataset_generator.pipelines.base import ( - DEFAULT_BATCH_SIZE, -) -from src.distilabel_dataset_generator.pipelines.embeddings 
import ( +from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE +from distilabel_dataset_generator.pipelines.embeddings import ( get_embeddings, get_sentence_embedding_dimensions, ) -from src.distilabel_dataset_generator.pipelines.eval import ( +from distilabel_dataset_generator.pipelines.eval import ( generate_pipeline_code, get_custom_evaluator, get_ultrafeedback_evaluator, ) -from src.distilabel_dataset_generator.utils import ( +from distilabel_dataset_generator.utils import ( column_to_list, extract_column_names, get_argilla_client, diff --git a/src/distilabel_dataset_generator/apps/sft.py b/src/distilabel_dataset_generator/apps/sft.py index fad57d1..d071c64 100644 --- a/src/distilabel_dataset_generator/apps/sft.py +++ b/src/distilabel_dataset_generator/apps/sft.py @@ -9,27 +9,25 @@ from distilabel.distiset import Distiset from huggingface_hub import HfApi -from src.distilabel_dataset_generator.apps.base import ( +from distilabel_dataset_generator.apps.base import ( hide_success_message, show_success_message, validate_argilla_user_workspace_dataset, validate_push_to_hub, ) -from src.distilabel_dataset_generator.pipelines.base import ( - DEFAULT_BATCH_SIZE, -) -from src.distilabel_dataset_generator.pipelines.embeddings import ( +from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE, SFT_AVAILABLE +from distilabel_dataset_generator.pipelines.embeddings import ( get_embeddings, get_sentence_embedding_dimensions, ) -from src.distilabel_dataset_generator.pipelines.sft import ( +from distilabel_dataset_generator.pipelines.sft import ( DEFAULT_DATASET_DESCRIPTIONS, generate_pipeline_code, get_magpie_generator, get_prompt_generator, get_response_generator, ) -from src.distilabel_dataset_generator.utils import ( +from distilabel_dataset_generator.utils import ( _LOGGED_OUT_CSS, get_argilla_client, get_org_dropdown, @@ -354,168 +352,175 @@ def hide_pipeline_code_visibility(): with gr.Blocks(css=_LOGGED_OUT_CSS) as app: with gr.Column() as 
main_ui: - gr.Markdown(value="## 1. Describe the dataset you want") - with gr.Row(): - with gr.Column(scale=2): - dataset_description = gr.Textbox( - label="Dataset description", - placeholder="Give a precise description of your desired dataset.", - ) - with gr.Accordion("Temperature", open=False): - temperature = gr.Slider( - minimum=0.1, - maximum=1, - value=0.8, - step=0.1, + if not SFT_AVAILABLE: + gr.Markdown( + value=f"## Supervised Fine-Tuning is not available for the {MODEL} model. Use Hugging Face Llama3 or Qwen2 models." + ) + else: + gr.Markdown(value="## 1. Describe the dataset you want") + with gr.Row(): + with gr.Column(scale=2): + dataset_description = gr.Textbox( + label="Dataset description", + placeholder="Give a precise description of your desired dataset.", + ) + with gr.Accordion("Temperature", open=False): + temperature = gr.Slider( + minimum=0.1, + maximum=1, + value=0.8, + step=0.1, + interactive=True, + show_label=False, + ) + load_btn = gr.Button( + "Create dataset", + variant="primary", + ) + with gr.Column(scale=2): + examples = gr.Examples( + examples=DEFAULT_DATASET_DESCRIPTIONS, + inputs=[dataset_description], + cache_examples=False, + label="Examples", + ) + with gr.Column(scale=1): + pass + + gr.HTML(value="
") + gr.Markdown(value="## 2. Configure your dataset") + with gr.Row(equal_height=False): + with gr.Column(scale=2): + system_prompt = gr.Textbox( + label="System prompt", + placeholder="You are a helpful assistant.", + ) + num_turns = gr.Number( + value=1, + label="Number of turns in the conversation", + minimum=1, + maximum=4, + step=1, interactive=True, - show_label=False, + info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).", ) - load_btn = gr.Button( - "Create dataset", - variant="primary", - ) - with gr.Column(scale=2): - examples = gr.Examples( - examples=DEFAULT_DATASET_DESCRIPTIONS, - inputs=[dataset_description], - cache_examples=False, - label="Examples", - ) - with gr.Column(scale=1): - pass - - gr.HTML(value="
") - gr.Markdown(value="## 2. Configure your dataset") - with gr.Row(equal_height=False): - with gr.Column(scale=2): - system_prompt = gr.Textbox( - label="System prompt", - placeholder="You are a helpful assistant.", - ) - num_turns = gr.Number( - value=1, - label="Number of turns in the conversation", - minimum=1, - maximum=4, - step=1, - interactive=True, - info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).", - ) - btn_apply_to_sample_dataset = gr.Button( - "Refresh dataset", variant="secondary" - ) - with gr.Column(scale=3): - dataframe = gr.Dataframe( - headers=["prompt", "completion"], - wrap=True, - height=500, - interactive=False, - ) - - gr.HTML(value="
") - gr.Markdown(value="## 3. Generate your dataset") - with gr.Row(equal_height=False): - with gr.Column(scale=2): - org_name = get_org_dropdown() - repo_name = gr.Textbox( - label="Repo name", - placeholder="dataset_name", - value=f"my-distiset-{str(uuid.uuid4())[:8]}", - interactive=True, - ) - num_rows = gr.Number( - label="Number of rows", - value=10, - interactive=True, - scale=1, - ) - private = gr.Checkbox( - label="Private dataset", - value=False, - interactive=True, - scale=1, - ) - btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2) - with gr.Column(scale=3): - success_message = gr.Markdown(visible=True) - with gr.Accordion( - "Do you want to go further? Customize and run with Distilabel", - open=False, - visible=False, - ) as pipeline_code_ui: - code = generate_pipeline_code( - system_prompt=system_prompt.value, - num_turns=num_turns.value, - num_rows=num_rows.value, + btn_apply_to_sample_dataset = gr.Button( + "Refresh dataset", variant="secondary" ) - pipeline_code = gr.Code( - value=code, - language="python", - label="Distilabel Pipeline Code", + with gr.Column(scale=3): + dataframe = gr.Dataframe( + headers=["prompt", "completion"], + wrap=True, + height=500, + interactive=False, ) - load_btn.click( - fn=generate_system_prompt, - inputs=[dataset_description, temperature], - outputs=[system_prompt], - show_progress=True, - ).then( - fn=generate_sample_dataset, - inputs=[system_prompt, num_turns], - outputs=[dataframe], - show_progress=True, - ) + gr.HTML(value="
") + gr.Markdown(value="## 3. Generate your dataset") + with gr.Row(equal_height=False): + with gr.Column(scale=2): + org_name = get_org_dropdown() + repo_name = gr.Textbox( + label="Repo name", + placeholder="dataset_name", + value=f"my-distiset-{str(uuid.uuid4())[:8]}", + interactive=True, + ) + num_rows = gr.Number( + label="Number of rows", + value=10, + interactive=True, + scale=1, + ) + private = gr.Checkbox( + label="Private dataset", + value=False, + interactive=True, + scale=1, + ) + btn_push_to_hub = gr.Button( + "Push to Hub", variant="primary", scale=2 + ) + with gr.Column(scale=3): + success_message = gr.Markdown(visible=True) + with gr.Accordion( + "Do you want to go further? Customize and run with Distilabel", + open=False, + visible=False, + ) as pipeline_code_ui: + code = generate_pipeline_code( + system_prompt=system_prompt.value, + num_turns=num_turns.value, + num_rows=num_rows.value, + ) + pipeline_code = gr.Code( + value=code, + language="python", + label="Distilabel Pipeline Code", + ) + + load_btn.click( + fn=generate_system_prompt, + inputs=[dataset_description, temperature], + outputs=[system_prompt], + show_progress=True, + ).then( + fn=generate_sample_dataset, + inputs=[system_prompt, num_turns], + outputs=[dataframe], + show_progress=True, + ) - btn_apply_to_sample_dataset.click( - fn=generate_sample_dataset, - inputs=[system_prompt, num_turns], - outputs=[dataframe], - show_progress=True, - ) + btn_apply_to_sample_dataset.click( + fn=generate_sample_dataset, + inputs=[system_prompt, num_turns], + outputs=[dataframe], + show_progress=True, + ) - btn_push_to_hub.click( - fn=validate_argilla_user_workspace_dataset, - inputs=[repo_name], - outputs=[success_message], - show_progress=True, - ).then( - fn=validate_push_to_hub, - inputs=[org_name, repo_name], - outputs=[success_message], - show_progress=True, - ).success( - fn=hide_success_message, - outputs=[success_message], - show_progress=True, - ).success( - 
fn=hide_pipeline_code_visibility, - inputs=[], - outputs=[pipeline_code_ui], - ).success( - fn=push_dataset, - inputs=[ - org_name, - repo_name, - system_prompt, - num_turns, - num_rows, - private, - ], - outputs=[success_message], - show_progress=True, - ).success( - fn=show_success_message, - inputs=[org_name, repo_name], - outputs=[success_message], - ).success( - fn=generate_pipeline_code, - inputs=[system_prompt, num_turns, num_rows], - outputs=[pipeline_code], - ).success( - fn=show_pipeline_code_visibility, - inputs=[], - outputs=[pipeline_code_ui], - ) + btn_push_to_hub.click( + fn=validate_argilla_user_workspace_dataset, + inputs=[repo_name], + outputs=[success_message], + show_progress=True, + ).then( + fn=validate_push_to_hub, + inputs=[org_name, repo_name], + outputs=[success_message], + show_progress=True, + ).success( + fn=hide_success_message, + outputs=[success_message], + show_progress=True, + ).success( + fn=hide_pipeline_code_visibility, + inputs=[], + outputs=[pipeline_code_ui], + ).success( + fn=push_dataset, + inputs=[ + org_name, + repo_name, + system_prompt, + num_turns, + num_rows, + private, + ], + outputs=[success_message], + show_progress=True, + ).success( + fn=show_success_message, + inputs=[org_name, repo_name], + outputs=[success_message], + ).success( + fn=generate_pipeline_code, + inputs=[system_prompt, num_turns, num_rows], + outputs=[pipeline_code], + ).success( + fn=show_pipeline_code_visibility, + inputs=[], + outputs=[pipeline_code_ui], + ) - app.load(fn=swap_visibility, outputs=main_ui) - app.load(fn=get_org_dropdown, outputs=[org_name]) + app.load(fn=swap_visibility, outputs=main_ui) + app.load(fn=get_org_dropdown, outputs=[org_name]) diff --git a/src/distilabel_dataset_generator/apps/textcat.py b/src/distilabel_dataset_generator/apps/textcat.py index 2666d0a..ff4775c 100644 --- a/src/distilabel_dataset_generator/apps/textcat.py +++ b/src/distilabel_dataset_generator/apps/textcat.py @@ -9,15 +9,13 @@ from distilabel.distiset 
import Distiset from huggingface_hub import HfApi +from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE from src.distilabel_dataset_generator.apps.base import ( hide_success_message, show_success_message, validate_argilla_user_workspace_dataset, validate_push_to_hub, ) -from src.distilabel_dataset_generator.pipelines.base import ( - DEFAULT_BATCH_SIZE, -) from src.distilabel_dataset_generator.pipelines.embeddings import ( get_embeddings, get_sentence_embedding_dimensions, diff --git a/src/distilabel_dataset_generator/constants.py b/src/distilabel_dataset_generator/constants.py new file mode 100644 index 0000000..c9aaaa4 --- /dev/null +++ b/src/distilabel_dataset_generator/constants.py @@ -0,0 +1,55 @@ +import os +import warnings + +import argilla as rg + +# Hugging Face +HF_TOKEN = os.getenv("HF_TOKEN") +if HF_TOKEN is None: + raise ValueError( + "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints." + ) + +# Inference +DEFAULT_BATCH_SIZE = 5 +MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct") +API_KEYS = ( + [os.getenv("HF_TOKEN")] + + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)] + + [os.getenv("API_KEY")] +) +API_KEYS = [token for token in API_KEYS if token] +BASE_URL = os.getenv("BASE_URL", "https://api-inference.huggingface.co/v1/") + +if BASE_URL != "https://api-inference.huggingface.co/v1/" and len(API_KEYS) == 0: + raise ValueError( + "API_KEY is not set. Ensure you have set the API_KEY environment variable that has access to the Hugging Face Inference Endpoints." + ) +if "Qwen2" not in MODEL and "Llama-3" not in MODEL: + SFT_AVAILABLE = False + warnings.warn( + "SFT_AVAILABLE is set to False because the model is not a Qwen or Llama model." 
+ ) + MAGPIE_PRE_QUERY_TEMPLATE = None +else: + SFT_AVAILABLE = True + if "Qwen2" in MODEL: + MAGPIE_PRE_QUERY_TEMPLATE = "qwen2" + else: + MAGPIE_PRE_QUERY_TEMPLATE = "llama3" + +# Argilla +ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") +ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") +if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: + ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER") + ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER") + +if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: + warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set") + argilla_client = None +else: + argilla_client = rg.Argilla( + api_url=ARGILLA_API_URL, + api_key=ARGILLA_API_KEY, + ) diff --git a/src/distilabel_dataset_generator/pipelines/__init__.py b/src/distilabel_dataset_generator/pipelines/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/distilabel_dataset_generator/pipelines/base.py b/src/distilabel_dataset_generator/pipelines/base.py index ec54f95..22510c2 100644 --- a/src/distilabel_dataset_generator/pipelines/base.py +++ b/src/distilabel_dataset_generator/pipelines/base.py @@ -1,12 +1,10 @@ -from src.distilabel_dataset_generator import HF_TOKENS +from distilabel_dataset_generator.constants import API_KEYS -DEFAULT_BATCH_SIZE = 5 TOKEN_INDEX = 0 -MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct" def _get_next_api_key(): global TOKEN_INDEX - api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)] + api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)] TOKEN_INDEX += 1 return api_key diff --git a/src/distilabel_dataset_generator/pipelines/embeddings.py b/src/distilabel_dataset_generator/pipelines/embeddings.py index bcd99ef..1ab3873 100644 --- a/src/distilabel_dataset_generator/pipelines/embeddings.py +++ b/src/distilabel_dataset_generator/pipelines/embeddings.py @@ -4,7 +4,7 @@ from sentence_transformers.models import StaticEmbedding # Initialize a StaticEmbedding module -static_embedding = 
StaticEmbedding.from_model2vec("minishlab/M2V_base_output") +static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M") model = SentenceTransformer(modules=[static_embedding]) diff --git a/src/distilabel_dataset_generator/pipelines/eval.py b/src/distilabel_dataset_generator/pipelines/eval.py index cf1d25b..ee2959a 100644 --- a/src/distilabel_dataset_generator/pipelines/eval.py +++ b/src/distilabel_dataset_generator/pipelines/eval.py @@ -5,18 +5,16 @@ UltraFeedback, ) -from src.distilabel_dataset_generator.pipelines.base import ( - MODEL, - _get_next_api_key, -) -from src.distilabel_dataset_generator.utils import extract_column_names +from distilabel_dataset_generator.constants import BASE_URL, MODEL +from distilabel_dataset_generator.pipelines.base import _get_next_api_key +from distilabel_dataset_generator.utils import extract_column_names def get_ultrafeedback_evaluator(aspect, is_sample): ultrafeedback_evaluator = UltraFeedback( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0, @@ -33,7 +31,7 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample) custom_evaluator = TextGeneration( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), structured_output={"format": "json", "schema": structured_output}, generation_kwargs={ @@ -62,7 +60,8 @@ def generate_ultrafeedback_pipeline_code( from distilabel.llms import InferenceEndpointsLLM MODEL = "{MODEL}" -os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained +BASE_URL = "{BASE_URL}" +os.environ["API_KEY"] = "hf_xxx" # 
https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]") data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries @@ -76,8 +75,8 @@ def generate_ultrafeedback_pipeline_code( ultrafeedback_evaluator = UltraFeedback( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, - api_key=os.environ["HF_TOKEN"], + base_url=BASE_URL, + api_key=os.environ["API_KEY"], generation_kwargs={{ "temperature": 0, "max_new_tokens": 2048, @@ -101,7 +100,8 @@ def generate_ultrafeedback_pipeline_code( from distilabel.llms import InferenceEndpointsLLM MODEL = "{MODEL}" -os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained +BASE_URL = "{BASE_URL}" +os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}") data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries @@ -119,8 +119,8 @@ def generate_ultrafeedback_pipeline_code( aspect=aspect, llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, - api_key=os.environ["HF_TOKEN"], + base_url=BASE_URL, + api_key=os.environ["API_KEY"], generation_kwargs={{ "temperature": 0, "max_new_tokens": 2048, @@ -157,6 +157,7 @@ def generate_custom_pipeline_code( from distilabel.llms import InferenceEndpointsLLM MODEL = "{MODEL}" +BASE_URL = "{BASE_URL}" CUSTOM_TEMPLATE = "{prompt_template}" 
os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained @@ -171,7 +172,7 @@ def generate_custom_pipeline_code( custom_evaluator = TextGeneration( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=os.environ["HF_TOKEN"], structured_output={{"format": "json", "schema": {structured_output}}}, generation_kwargs={{ diff --git a/src/distilabel_dataset_generator/pipelines/sft.py b/src/distilabel_dataset_generator/pipelines/sft.py index 240e973..920f40d 100644 --- a/src/distilabel_dataset_generator/pipelines/sft.py +++ b/src/distilabel_dataset_generator/pipelines/sft.py @@ -1,10 +1,12 @@ from distilabel.llms import InferenceEndpointsLLM from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration -from src.distilabel_dataset_generator.pipelines.base import ( +from distilabel_dataset_generator.constants import ( + BASE_URL, + MAGPIE_PRE_QUERY_TEMPLATE, MODEL, - _get_next_api_key, ) +from distilabel_dataset_generator.pipelines.base import _get_next_api_key INFORMATION_SEEKING_PROMPT = ( "You are an AI assistant designed to provide accurate and concise information on a wide" @@ -144,6 +146,7 @@ def get_prompt_generator(temperature): api_key=_get_next_api_key(), model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, generation_kwargs={ "temperature": temperature, "max_new_tokens": 2048, @@ -165,8 +168,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample): llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), - magpie_pre_query_template="llama3", + magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE, generation_kwargs={ "temperature": 0.9, "do_sample": True, @@ -184,8 +188,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample): llm=InferenceEndpointsLLM( 
model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), - magpie_pre_query_template="llama3", + magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE, generation_kwargs={ "temperature": 0.9, "do_sample": True, @@ -208,6 +213,7 @@ def get_response_generator(system_prompt, num_turns, is_sample): llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0.8, @@ -223,6 +229,7 @@ def get_response_generator(system_prompt, num_turns, is_sample): llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0.8, @@ -247,14 +254,16 @@ def generate_pipeline_code(system_prompt, num_turns, num_rows): from distilabel.llms import InferenceEndpointsLLM MODEL = "{MODEL}" +BASE_URL = "{BASE_URL}" SYSTEM_PROMPT = "{system_prompt}" -os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained +os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained with Pipeline(name="sft") as pipeline: magpie = MagpieGenerator( llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, magpie_pre_query_template="llama3", generation_kwargs={{ "temperature": 0.9, "do_sample": True, "max_new_tokens": 2048, "stop_sequences": {_STOP_SEQUENCES} }}, - api_key=os.environ["API_KEY"], + base_url=BASE_URL, ), n_turns={num_turns}, num_rows={num_rows}, diff --git a/src/distilabel_dataset_generator/pipelines/textcat.py 
b/src/distilabel_dataset_generator/pipelines/textcat.py index e17f594..1c88e86 100644 --- a/src/distilabel_dataset_generator/pipelines/textcat.py +++ b/src/distilabel_dataset_generator/pipelines/textcat.py @@ -1,5 +1,4 @@ import random -from pydantic import BaseModel, Field from typing import List from distilabel.llms import InferenceEndpointsLLM @@ -8,12 +7,11 @@ TextClassification, TextGeneration, ) +from pydantic import BaseModel, Field -from src.distilabel_dataset_generator.pipelines.base import ( - MODEL, - _get_next_api_key, -) -from src.distilabel_dataset_generator.utils import get_preprocess_labels +from distilabel_dataset_generator.constants import BASE_URL, MODEL +from distilabel_dataset_generator.pipelines.base import _get_next_api_key +from distilabel_dataset_generator.utils import get_preprocess_labels PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation. @@ -73,7 +71,7 @@ def get_prompt_generator(temperature): llm=InferenceEndpointsLLM( api_key=_get_next_api_key(), model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, structured_output={"format": "json", "schema": TextClassificationTask}, generation_kwargs={ "temperature": temperature, @@ -92,7 +90,7 @@ def get_textcat_generator(difficulty, clarity, is_sample): textcat_generator = GenerateTextClassificationData( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0.9, @@ -114,7 +112,7 @@ def get_labeller_generator(system_prompt, labels, num_labels): labeller_generator = TextClassification( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0.7, @@ -149,8 +147,9 @@ def generate_pipeline_code( from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, 
TextClassification"} MODEL = "{MODEL}" +BASE_URL = "{BASE_URL}" TEXT_CLASSIFICATION_TASK = "{system_prompt}" -os.environ["HF_TOKEN"] = ( +os.environ["API_KEY"] = ( "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained ) @@ -161,8 +160,8 @@ def generate_pipeline_code( textcat_generation = GenerateTextClassificationData( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, - api_key=os.environ["HF_TOKEN"], + base_url=BASE_URL, + api_key=os.environ["API_KEY"], generation_kwargs={{ "temperature": 0.8, "max_new_tokens": 2048, @@ -205,8 +204,8 @@ def generate_pipeline_code( textcat_labeller = TextClassification( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, - api_key=os.environ["HF_TOKEN"], + base_url=BASE_URL, + api_key=os.environ["API_KEY"], generation_kwargs={{ "temperature": 0.8, "max_new_tokens": 2048, diff --git a/src/distilabel_dataset_generator/utils.py b/src/distilabel_dataset_generator/utils.py index 68b0b77..26def99 100644 --- a/src/distilabel_dataset_generator/utils.py +++ b/src/distilabel_dataset_generator/utils.py @@ -15,7 +15,7 @@ from huggingface_hub import whoami from jinja2 import Environment, meta -from src.distilabel_dataset_generator import argilla_client +from distilabel_dataset_generator.constants import argilla_client _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}" From cc48701dea989098b2e8e044132d8c18d44d14ba Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 15:07:44 +0100 Subject: [PATCH 03/11] add info about deployment and usage --- README.md | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 0637301..821117e 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,14 @@ You can simply install the package with: pip install 
synthetic-dataset-generator ``` +### Quickstart + +```python +from synthetic_dataset_generator.app import demo + +demo.launch() +``` + ### Environment Variables - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. @@ -93,12 +101,6 @@ Optionally, you can also push your datasets to Argilla for further curation by s - `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla. - `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla. -## Quickstart - -```bash -python app.py -``` - ### Argilla integration Argilla is a open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/). @@ -110,3 +112,19 @@ Argilla is a open source tool for data curation. It allows you to annotate and r Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps. Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information. + +## Development + +Install the dependencies: + +```bash +python -m venv .venv +source .venv/bin/activate +pip install -e . 
+``` + +Run the app: + +```bash +python app.py +``` From 44fc64a29166116728473296c492e367b6e277d2 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 15:07:57 +0100 Subject: [PATCH 04/11] update imports for application usage --- app.py | 36 +---------------------- src/distilabel_dataset_generator/app.py | 38 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 35 deletions(-) create mode 100644 src/distilabel_dataset_generator/app.py diff --git a/app.py b/app.py index 53ec94f..a952cb7 100644 --- a/app.py +++ b/app.py @@ -1,38 +1,4 @@ -from distilabel_dataset_generator._tabbedinterface import TabbedInterface -from distilabel_dataset_generator.apps.eval import app as eval_app -from distilabel_dataset_generator.apps.faq import app as faq_app -from distilabel_dataset_generator.apps.sft import app as sft_app -from distilabel_dataset_generator.apps.textcat import app as textcat_app - -theme = "argilla/argilla-theme" - -css = """ -button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)} -button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)} -button.hf-login {background: var(--neutral-800); color: white} -button.hf-login:hover {background: var(--neutral-700); color: white} -.tabitem { border: 0; padding-inline: 0} -.main_ui_logged_out{opacity: 0.3; pointer-events: none} -.group_padding{padding: .55em} -.gallery-item {background: var(--background-fill-secondary); text-align: left} -.gallery {white-space: wrap} -#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none} -#system_prompt_examples { - color: var(--body-text-color) !important; - background-color: var(--block-background-fill) !important; -} -.container {padding-inline: 0 !important} -""" - -demo = TabbedInterface( - [textcat_app, sft_app, eval_app, faq_app], - ["Text Classification", "Supervised 
Fine-Tuning", "Evaluation", "FAQ"], - css=css, - title="Synthetic Data Generator", - head="Synthetic Data Generator", - theme=theme, -) - +from distilabel_dataset_generator.app import demo if __name__ == "__main__": demo.launch() diff --git a/src/distilabel_dataset_generator/app.py b/src/distilabel_dataset_generator/app.py new file mode 100644 index 0000000..53ec94f --- /dev/null +++ b/src/distilabel_dataset_generator/app.py @@ -0,0 +1,38 @@ +from distilabel_dataset_generator._tabbedinterface import TabbedInterface +from distilabel_dataset_generator.apps.eval import app as eval_app +from distilabel_dataset_generator.apps.faq import app as faq_app +from distilabel_dataset_generator.apps.sft import app as sft_app +from distilabel_dataset_generator.apps.textcat import app as textcat_app + +theme = "argilla/argilla-theme" + +css = """ +button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)} +button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)} +button.hf-login {background: var(--neutral-800); color: white} +button.hf-login:hover {background: var(--neutral-700); color: white} +.tabitem { border: 0; padding-inline: 0} +.main_ui_logged_out{opacity: 0.3; pointer-events: none} +.group_padding{padding: .55em} +.gallery-item {background: var(--background-fill-secondary); text-align: left} +.gallery {white-space: wrap} +#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none} +#system_prompt_examples { + color: var(--body-text-color) !important; + background-color: var(--block-background-fill) !important; +} +.container {padding-inline: 0 !important} +""" + +demo = TabbedInterface( + [textcat_app, sft_app, eval_app, faq_app], + ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"], + css=css, + title="Synthetic Data Generator", + head="Synthetic Data Generator", + 
theme=theme, +) + + +if __name__ == "__main__": + demo.launch() From 65217751a9331dc89ab50ec322b10802050374df Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 15:18:01 +0100 Subject: [PATCH 05/11] remove obsolete code --- .python-version | 1 - app.py | 5 + src/distilabel_dataset_generator/apps/base.py | 337 +----------------- src/distilabel_dataset_generator/apps/sft.py | 3 +- .../apps/textcat.py | 3 +- src/distilabel_dataset_generator/utils.py | 58 +-- 6 files changed, 9 insertions(+), 398 deletions(-) delete mode 100644 .python-version diff --git a/.python-version b/.python-version deleted file mode 100644 index 9f675fa..0000000 --- a/.python-version +++ /dev/null @@ -1 +0,0 @@ -synthetic-data-generator diff --git a/app.py b/app.py index a952cb7..a6e9bae 100644 --- a/app.py +++ b/app.py @@ -1,4 +1,9 @@ +import os + from distilabel_dataset_generator.app import demo +os.environ["API_KEY"] = "hf_..." +os.environ["API_KEY"] = "hf_..." + if __name__ == "__main__": demo.launch() diff --git a/src/distilabel_dataset_generator/apps/base.py b/src/distilabel_dataset_generator/apps/base.py index 895afbc..da19daf 100644 --- a/src/distilabel_dataset_generator/apps/base.py +++ b/src/distilabel_dataset_generator/apps/base.py @@ -1,6 +1,6 @@ import io import uuid -from typing import Any, Callable, List, Tuple, Union +from typing import List, Union import argilla as rg import gradio as gr @@ -11,161 +11,14 @@ from huggingface_hub import HfApi, upload_file from distilabel_dataset_generator.utils import ( - _LOGGED_OUT_CSS, get_argilla_client, - get_login_button, list_orgs, - swap_visibility, ) TEXTCAT_TASK = "text_classification" SFT_TASK = "supervised_fine_tuning" -def get_main_ui( - default_dataset_descriptions: List[str], - default_system_prompts: List[str], - default_datasets: List[pd.DataFrame], - fn_generate_system_prompt: Callable, - fn_generate_dataset: Callable, - task: str, -): - def fn_generate_sample_dataset(system_prompt, 
progress=gr.Progress()): - if system_prompt in default_system_prompts: - index = default_system_prompts.index(system_prompt) - if index < len(default_datasets): - return default_datasets[index] - if task == TEXTCAT_TASK: - result = fn_generate_dataset( - system_prompt=system_prompt, - difficulty="high school", - clarity="clear", - labels=[], - num_labels=1, - num_rows=1, - progress=progress, - is_sample=True, - ) - else: - result = fn_generate_dataset( - system_prompt=system_prompt, - num_turns=1, - num_rows=1, - progress=progress, - is_sample=True, - ) - return result - - with gr.Blocks( - title="🧬 Synthetic Data Generator", - head="🧬 Synthetic Data Generator", - css=_LOGGED_OUT_CSS, - ) as app: - with gr.Row(): - gr.HTML( - """

How does it work?

""" - ) - with gr.Row(): - gr.Markdown( - "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation." - ) - with gr.Row(): - gr.Column() - get_login_button() - gr.Column() - - gr.Markdown("## Iterate on a sample dataset") - with gr.Column() as main_ui: - ( - dataset_description, - examples, - btn_generate_system_prompt, - system_prompt, - sample_dataset, - btn_generate_sample_dataset, - ) = get_iterate_on_sample_dataset_ui( - default_dataset_descriptions=default_dataset_descriptions, - default_system_prompts=default_system_prompts, - default_datasets=default_datasets, - task=task, - ) - gr.Markdown("## Generate full dataset") - gr.Markdown( - "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub." - ) - with gr.Row(variant="panel") as custom_input_ui: - pass - - ( - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - final_dataset, - success_message, - ) = get_push_to_ui(default_datasets) - - sample_dataset.change( - fn=lambda x: x, - inputs=[sample_dataset], - outputs=[final_dataset], - ) - - btn_generate_system_prompt.click( - fn=fn_generate_system_prompt, - inputs=[dataset_description], - outputs=[system_prompt], - show_progress=True, - ).then( - fn=fn_generate_sample_dataset, - inputs=[system_prompt], - outputs=[sample_dataset], - show_progress=True, - ) - - btn_generate_sample_dataset.click( - fn=fn_generate_sample_dataset, - inputs=[system_prompt], - outputs=[sample_dataset], - show_progress=True, - ) - - app.load(fn=swap_visibility, outputs=main_ui) - app.load(get_org_dropdown, outputs=[org_name]) - - return ( - app, - main_ui, - 
custom_input_ui, - dataset_description, - examples, - btn_generate_system_prompt, - system_prompt, - sample_dataset, - btn_generate_sample_dataset, - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - final_dataset, - success_message, - ) - - def validate_argilla_user_workspace_dataset( dataset_name: str, add_to_existing_dataset: bool = True, @@ -205,176 +58,6 @@ def get_org_dropdown(oauth_token: Union[OAuthToken, None]): ) -def get_push_to_ui(default_datasets): - with gr.Column() as push_to_ui: - ( - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - ) = get_argilla_tab() - ( - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - ) = get_hf_tab() - final_dataset = get_final_dataset_row(default_datasets) - success_message = get_success_message_row() - return ( - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - final_dataset, - success_message, - ) - - -def get_iterate_on_sample_dataset_ui( - default_dataset_descriptions: List[str], - default_system_prompts: List[str], - default_datasets: List[pd.DataFrame], - task: str, -): - with gr.Column(): - dataset_description = gr.TextArea( - label="Give a precise description of your desired application. 
Check the examples for inspiration.", - value=default_dataset_descriptions[0], - lines=2, - ) - examples = gr.Examples( - elem_id="system_prompt_examples", - examples=[[example] for example in default_dataset_descriptions], - inputs=[dataset_description], - ) - with gr.Row(): - gr.Column(scale=1) - btn_generate_system_prompt = gr.Button( - value="Generate system prompt and sample dataset", variant="primary" - ) - gr.Column(scale=1) - - system_prompt = gr.TextArea( - label="System prompt for dataset generation. You can tune it and regenerate the sample.", - value=default_system_prompts[0], - lines=2 if task == TEXTCAT_TASK else 5, - ) - - with gr.Row(): - sample_dataset = gr.Dataframe( - value=default_datasets[0], - label=( - "Sample dataset. Text truncated to 256 tokens." - if task == TEXTCAT_TASK - else "Sample dataset. Prompts and completions truncated to 256 tokens." - ), - interactive=False, - wrap=True, - ) - - with gr.Row(): - gr.Column(scale=1) - btn_generate_sample_dataset = gr.Button( - value="Generate sample dataset", variant="primary" - ) - gr.Column(scale=1) - - return ( - dataset_description, - examples, - btn_generate_system_prompt, - system_prompt, - sample_dataset, - btn_generate_sample_dataset, - ) - - -def get_argilla_tab() -> Tuple[Any]: - with gr.Tab(label="Argilla"): - if get_argilla_client() is not None: - with gr.Row(variant="panel"): - dataset_name = gr.Textbox( - label="Dataset name", - placeholder="dataset_name", - value="my-distiset", - ) - add_to_existing_dataset = gr.Checkbox( - label="Allow adding records to existing dataset", - info="When selected, you do need to ensure the dataset options are the same as in the existing dataset.", - value=False, - interactive=True, - scale=1, - ) - - with gr.Row(variant="panel"): - btn_generate_full_dataset_argilla = gr.Button( - value="Generate", variant="primary", scale=2 - ) - btn_generate_and_push_to_argilla = gr.Button( - value="Generate and Push to Argilla", - variant="primary", - scale=2, - ) 
- btn_push_to_argilla = gr.Button( - value="Push to Argilla", variant="primary", scale=2 - ) - else: - gr.Markdown( - "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub." - ) - return ( - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - ) - - -def get_hf_tab() -> Tuple[Any]: - with gr.Tab("Hugging Face Hub"): - with gr.Row(variant="panel"): - org_name = get_org_dropdown() - repo_name = gr.Textbox( - label="Repo name", - placeholder="dataset_name", - value="my-distiset", - ) - private = gr.Checkbox( - label="Private dataset", - value=True, - interactive=True, - scale=1, - ) - with gr.Row(variant="panel"): - btn_generate_full_dataset = gr.Button( - value="Generate", variant="primary", scale=2 - ) - btn_generate_and_push_to_hub = gr.Button( - value="Generate and Push to Hub", variant="primary", scale=2 - ) - btn_push_to_hub = gr.Button(value="Push to Hub", variant="primary", scale=2) - return ( - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - ) - - def push_pipeline_code_to_hub( pipeline_code: str, org_name: str, @@ -455,24 +138,6 @@ def validate_push_to_hub(org_name, repo_name): return repo_id -def get_final_dataset_row(default_datasets) -> gr.Dataframe: - with gr.Row(): - final_dataset = gr.Dataframe( - value=default_datasets[0], - label="Generated dataset", - interactive=False, - wrap=True, - min_width=300, - ) - return final_dataset - - -def get_success_message_row() -> gr.Markdown: - with gr.Row(): - success_message = gr.Markdown(visible=False) - return success_message - - def show_success_message(org_name, repo_name) -> gr.Markdown: client = get_argilla_client() if client is None: diff --git a/src/distilabel_dataset_generator/apps/sft.py b/src/distilabel_dataset_generator/apps/sft.py index d071c64..f9655d3 100644 --- 
a/src/distilabel_dataset_generator/apps/sft.py +++ b/src/distilabel_dataset_generator/apps/sft.py @@ -28,7 +28,6 @@ get_response_generator, ) from distilabel_dataset_generator.utils import ( - _LOGGED_OUT_CSS, get_argilla_client, get_org_dropdown, swap_visibility, @@ -350,7 +349,7 @@ def hide_pipeline_code_visibility(): ###################### -with gr.Blocks(css=_LOGGED_OUT_CSS) as app: +with gr.Blocks() as app: with gr.Column() as main_ui: if not SFT_AVAILABLE: gr.Markdown( diff --git a/src/distilabel_dataset_generator/apps/textcat.py b/src/distilabel_dataset_generator/apps/textcat.py index ff4775c..43988ef 100644 --- a/src/distilabel_dataset_generator/apps/textcat.py +++ b/src/distilabel_dataset_generator/apps/textcat.py @@ -28,7 +28,6 @@ get_textcat_generator, ) from src.distilabel_dataset_generator.utils import ( - _LOGGED_OUT_CSS, get_argilla_client, get_org_dropdown, get_preprocess_labels, @@ -332,7 +331,7 @@ def hide_pipeline_code_visibility(): ###################### -with gr.Blocks(css=_LOGGED_OUT_CSS) as app: +with gr.Blocks() as app: with gr.Column() as main_ui: gr.Markdown("## 1. 
Describe the dataset you want") with gr.Row(): diff --git a/src/distilabel_dataset_generator/utils.py b/src/distilabel_dataset_generator/utils.py index 26def99..b894a87 100644 --- a/src/distilabel_dataset_generator/utils.py +++ b/src/distilabel_dataset_generator/utils.py @@ -6,10 +6,7 @@ import numpy as np import pandas as pd from gradio.oauth import ( - OAUTH_CLIENT_ID, - OAUTH_CLIENT_SECRET, - OAUTH_SCOPES, - OPENID_PROVIDER_URL, + OAuthToken, get_space, ) from huggingface_hub import whoami @@ -17,30 +14,6 @@ from distilabel_dataset_generator.constants import argilla_client -_LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}" - - -_CHECK_IF_SPACE_IS_SET = ( - all( - [ - OAUTH_CLIENT_ID, - OAUTH_CLIENT_SECRET, - OAUTH_SCOPES, - OPENID_PROVIDER_URL, - ] - ) - or get_space() is None -) - -if _CHECK_IF_SPACE_IS_SET: - from gradio.oauth import OAuthToken -else: - OAuthToken = str - - -def get_login_button(): - return gr.LoginButton(value="Sign in!", size="sm", scale=2).activate() - def get_duplicate_button(): if get_space() is not None: @@ -85,13 +58,6 @@ def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None): ) -def get_token(oauth_token: Union[OAuthToken, None]): - if oauth_token: - return oauth_token.token - else: - return "" - - def swap_visibility(oauth_token: Union[OAuthToken, None]): if oauth_token: return gr.update(elem_classes=["main_ui_logged_in"]) @@ -99,28 +65,6 @@ def swap_visibility(oauth_token: Union[OAuthToken, None]): return gr.update(elem_classes=["main_ui_logged_out"]) -def get_base_app(): - with gr.Blocks( - title="🧬 Synthetic Data Generator", - head="🧬 Synthetic Data Generator", - css=_LOGGED_OUT_CSS, - ) as app: - with gr.Row(): - gr.Markdown( - "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation." 
- ) - with gr.Row(): - gr.Column() - get_login_button() - gr.Column() - - gr.Markdown("## Iterate on a sample dataset") - with gr.Column() as main_ui: - pass - - return app - - def get_argilla_client() -> Union[rg.Argilla, None]: return argilla_client From 5ac0c97111618087bd8ad32e692765dae42d425f Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 15:20:58 +0100 Subject: [PATCH 06/11] remove obsolete code --- src/distilabel_dataset_generator/apps/base.py | 15 +-------------- src/distilabel_dataset_generator/constants.py | 7 +++++++ .../pipelines/embeddings.py | 5 +++-- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/distilabel_dataset_generator/apps/base.py b/src/distilabel_dataset_generator/apps/base.py index da19daf..aead4df 100644 --- a/src/distilabel_dataset_generator/apps/base.py +++ b/src/distilabel_dataset_generator/apps/base.py @@ -10,14 +10,11 @@ from gradio import OAuthToken from huggingface_hub import HfApi, upload_file +from distilabel_dataset_generator.constants import TEXTCAT_TASK from distilabel_dataset_generator.utils import ( get_argilla_client, - list_orgs, ) -TEXTCAT_TASK = "text_classification" -SFT_TASK = "supervised_fine_tuning" - def validate_argilla_user_workspace_dataset( dataset_name: str, @@ -48,16 +45,6 @@ def validate_argilla_user_workspace_dataset( return "" -def get_org_dropdown(oauth_token: Union[OAuthToken, None]): - orgs = list_orgs(oauth_token) - return gr.Dropdown( - label="Organization", - choices=orgs, - value=orgs[0] if orgs else None, - allow_custom_value=True, - ) - - def push_pipeline_code_to_hub( pipeline_code: str, org_name: str, diff --git a/src/distilabel_dataset_generator/constants.py b/src/distilabel_dataset_generator/constants.py index c9aaaa4..4732a09 100644 --- a/src/distilabel_dataset_generator/constants.py +++ b/src/distilabel_dataset_generator/constants.py @@ -3,6 +3,10 @@ import argilla as rg +# Tasks +TEXTCAT_TASK = "text_classification" +SFT_TASK = 
"supervised_fine_tuning" + # Hugging Face HF_TOKEN = os.getenv("HF_TOKEN") if HF_TOKEN is None: @@ -38,6 +42,9 @@ else: MAGPIE_PRE_QUERY_TEMPLATE = "llama3" +# Embeddings +STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M" + # Argilla ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") diff --git a/src/distilabel_dataset_generator/pipelines/embeddings.py b/src/distilabel_dataset_generator/pipelines/embeddings.py index 1ab3873..3275713 100644 --- a/src/distilabel_dataset_generator/pipelines/embeddings.py +++ b/src/distilabel_dataset_generator/pipelines/embeddings.py @@ -3,8 +3,9 @@ from sentence_transformers import SentenceTransformer from sentence_transformers.models import StaticEmbedding -# Initialize a StaticEmbedding module -static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M") +from distilabel_dataset_generator.constants import STATIC_EMBEDDING_MODEL + +static_embedding = StaticEmbedding.from_model2vec(STATIC_EMBEDDING_MODEL) model = SentenceTransformer(modules=[static_embedding]) From ad7d65a34695f11994fe48b1b5624e997d82f846 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 16:33:11 +0100 Subject: [PATCH 07/11] add examples for openai to readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 821117e..fad2e59 100644 --- a/README.md +++ b/README.md @@ -92,9 +92,9 @@ demo.launch() Optionally, you can set the following environment variables to customize the generation process. -- `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`. -- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`. -- `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`. +- `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`. 
+- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`. +- `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`, `sk-...`. Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables: From 8395748720a63a7cef540b3faae3ea2f76700670 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 16:34:05 +0100 Subject: [PATCH 08/11] fix InferenceEndpointsLLM --- src/distilabel_dataset_generator/__init__.py | 53 ++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/distilabel_dataset_generator/__init__.py b/src/distilabel_dataset_generator/__init__.py index a0ec0dc..9b8c50d 100644 --- a/src/distilabel_dataset_generator/__init__.py +++ b/src/distilabel_dataset_generator/__init__.py @@ -1,12 +1,64 @@ +import warnings from typing import Optional import distilabel import distilabel.distiset +from distilabel.llms import InferenceEndpointsLLM from distilabel.utils.card.dataset_card import ( DistilabelDatasetCard, size_categories_parser, ) from huggingface_hub import DatasetCardData, HfApi +from pydantic import ( + ValidationError, + model_validator, +) + + +class CustomInferenceEndpointsLLM(InferenceEndpointsLLM): + @model_validator(mode="after") # type: ignore + def only_one_of_model_id_endpoint_name_or_base_url_provided( + self, + ) -> "InferenceEndpointsLLM": + """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also + provided, a warning will be shown informing the user that the provided `base_url` will be ignored in + favour of the dynamically calculated one..""" + + if self.base_url and (self.model_id or self.endpoint_name): + warnings.warn( # type: ignore + f"Since the `base_url={self.base_url}` is available and either one of `model_id`" + " or `endpoint_name` is also provided, the `base_url` will either be ignored" + " or overwritten with the one generated from 
either of those args, for serverless" + " or dedicated inference endpoints, respectively." + ) + + if self.use_magpie_template and self.tokenizer_id is None: + raise ValueError( + "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please," + " set a `tokenizer_id` and try again." + ) + + if ( + self.model_id + and self.tokenizer_id is None + and self.structured_output is not None + ): + self.tokenizer_id = self.model_id + + if self.base_url and not (self.model_id or self.endpoint_name): + return self + + if self.model_id and not self.endpoint_name: + return self + + if self.endpoint_name and not self.model_id: + return self + + raise ValidationError( + f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is" + f" provided too, it will be overwritten instead. Found `model_id`={self.model_id}," + f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}." + ) class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset): @@ -111,3 +163,4 @@ def _get_card( distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag +distilabel.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM From 356864a343bf61ef3b5f99939b958721509f28b1 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 16:34:16 +0100 Subject: [PATCH 09/11] remove environment variables --- app.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/app.py b/app.py index a6e9bae..a952cb7 100644 --- a/app.py +++ b/app.py @@ -1,9 +1,4 @@ -import os - from distilabel_dataset_generator.app import demo -os.environ["API_KEY"] = "hf_..." -os.environ["API_KEY"] = "hf_..." 
- if __name__ == "__main__": demo.launch() From 6476dbffe9e97604716f067796d65d61ff235b0b Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 16:34:43 +0100 Subject: [PATCH 10/11] remove shields from readem --- README.md | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/README.md b/README.md index fad2e59..b5f272e 100644 --- a/README.md +++ b/README.md @@ -27,18 +27,6 @@ hf_oauth_scopes: ![Synthetic Data Generator](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/ui-full.png) -

- -CI - - -CI - - - - -

-

From 9feda8cc5967a748ff1c74d7e7db1efc41b2c250 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Dec 2024 16:41:23 +0100 Subject: [PATCH 11/11] update formatting --- README.md | 2 +- src/distilabel_dataset_generator/_tabbedinterface.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b5f272e..ca77bd8 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ hf_oauth_scopes:


- Synthetic Data Generator + 🧬 Synthetic Data Generator

Build datasets using natural language

diff --git a/src/distilabel_dataset_generator/_tabbedinterface.py b/src/distilabel_dataset_generator/_tabbedinterface.py index 277004f..4263c06 100644 --- a/src/distilabel_dataset_generator/_tabbedinterface.py +++ b/src/distilabel_dataset_generator/_tabbedinterface.py @@ -68,7 +68,9 @@ def __init__( with gr.Column(scale=3): pass with gr.Column(scale=2): - gr.LoginButton(value="Sign in!", variant="hf-login", size="sm", scale=2) + gr.LoginButton( + value="Sign in", variant="hf-login", size="sm", scale=2 + ) with Tabs(): for interface, tab_name in zip(interface_list, tab_names, strict=False): with Tab(label=tab_name):