diff --git a/argilla/docs/how_to_guides/query_export.md b/argilla/docs/how_to_guides/query_export.md
index c0f4ae1b13..aa252b793b 100644
--- a/argilla/docs/how_to_guides/query_export.md
+++ b/argilla/docs/how_to_guides/query_export.md
@@ -48,7 +48,7 @@ To search for records with terms, you can use the `Dataset.records` attribute wi
query = rg.Query(query="my_term")
- queried_records = list(dataset.records(query=query))
+ queried_records = dataset.records(query=query).to_list(flatten=True)
```
=== "Multiple search term"
@@ -64,19 +64,19 @@ To search for records with terms, you can use the `Dataset.records` attribute wi
query = rg.Query(query="my_term1 my_term2")
- queried_records = list(dataset.records(query=query))
+ queried_records = dataset.records(query=query).to_list(flatten=True)
```
## Filter by conditions
You can use the `Filter` class to define the conditions and pass them to the `Dataset.records` attribute to fetch records based on the conditions. Conditions include "==", ">=", "<=", or "in". Conditions can be combined with dot notation to filter records based on metadata, suggestions, or responses. You can use a single condition or multiple conditions to filter records.
-| operator | description |
-|----------|-------------|
-| `==` | The `field` value is equal to the `value` |
+| operator | description |
+| -------- | --------------------------------------------------------- |
+| `==` | The `field` value is equal to the `value` |
| `>=` | The `field` value is greater than or equal to the `value` |
-| `<=` | The `field` value is less than or equal to the `value` |
-| `in` | TThe `field` value is included in a list of values |
+| `<=` | The `field` value is less than or equal to the `value` |
+| `in` | TThe `field` value is included in a list of values |
=== "Single condition"
@@ -91,7 +91,9 @@ You can use the `Filter` class to define the conditions and pass them to the `Da
filter_label = rg.Filter(("label", "==", "positive"))
- filtered_records = list(dataset.records(query=rg.Query(filter=filter_label)))
+ filtered_records = dataset.records(query=rg.Query(filter=filter_label)).to_list(
+ flatten=True
+ )
```
=== "Multiple conditions"
@@ -114,10 +116,10 @@ You can use the `Filter` class to define the conditions and pass them to the `Da
]
)
- filtered_records = list(dataset.records(
- query=rg.Query(filter=filters)),
- with_suggestions=True
- )
+ filtered_records = dataset.records(
+ query=rg.Query(filter=filters),
+ with_suggestions=True,
+ ).to_list(flatten=True)
```
## Filter by status
@@ -137,7 +139,7 @@ status_filter = rg.Query(
filter=rg.Filter(("response.status", "==", "submitted"))
)
-filtered_records = list(dataset.records(status_filter))
+filtered_records = dataset.records(status_filter).to_list(flatten=True)
```
## Query and filter a dataset
@@ -163,12 +165,11 @@ query_filter = rg.Query(
)
)
-queried_filtered_records = list(dataset.records(
+queried_filtered_records = dataset.records(
query=query_filter,
with_metadata=True,
with_suggestions=True
-)
-)
+).to_list(flatten=True)
```
## Export records to a dictionary
diff --git a/argilla/docs/tutorials/index.md b/argilla/docs/tutorials/index.md
index 94fa165a56..af9ec0f56b 100644
--- a/argilla/docs/tutorials/index.md
+++ b/argilla/docs/tutorials/index.md
@@ -10,11 +10,18 @@ These are the tutorials for *the Argilla SDK*. They provide step-by-step instruc
-- __Text classification task__
+- __Text classification__
---
Learn about a standard workflow to improve data quality for a text classification task.
[:octicons-arrow-right-24: Tutorial](text_classification.ipynb)
-
\ No newline at end of file
+- __Token classification__
+
+ ---
+
+ Learn about a standard workflow to improve data quality for a token classification task.
+ [:octicons-arrow-right-24: Tutorial](token_classification.ipynb)
+
+
diff --git a/argilla/docs/tutorials/text_classification.ipynb b/argilla/docs/tutorials/text_classification.ipynb
index 9d3b6dffb7..f22ccf7295 100644
--- a/argilla/docs/tutorials/text_classification.ipynb
+++ b/argilla/docs/tutorials/text_classification.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Text classification task"
+ "# Text classification"
]
},
{
@@ -91,7 +91,7 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
"# Uncomment the last line and set your HF_TOKEN if your space is private\n",
"client = rg.Argilla(\n",
" api_url=\"https://[your-owner-name]-[your_space_name].hf.space\",\n",
- " api_key=\"owner.apikey\"\n",
+ " api_key=\"owner.apikey\",\n",
" # headers={\"Authorization\": f\"Bearer {HF_TOKEN}\"}\n",
")"
]
@@ -143,7 +143,7 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -182,7 +182,7 @@
"outputs": [],
"source": [
"dataset = rg.Dataset(\n",
- " name=\"text_classification_dataset1\",\n",
+ " name=\"text_classification_dataset\",\n",
" settings=settings,\n",
")\n",
"dataset.create()"
@@ -204,7 +204,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -238,7 +238,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "The next step is to add suggestions to the dataset. In our case, we will generate them using a zero-shot SetFit model. However, you can use a framework or technique of your choice."
+ "The next step is to add suggestions to the dataset. This will make things easier and faster for the annotation team. Suggestions will appear as preselected options, so annotators will only need to correct them. In our case, we will generate them using a zero-shot SetFit model. However, you can use a framework or technique of your choice."
]
},
{
@@ -250,7 +250,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -272,12 +272,11 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model_name, dataset):\n",
- " \n",
" model = SetFitModel.from_pretrained(model_name)\n",
"\n",
" trainer = Trainer(\n",
@@ -286,7 +285,7 @@
" )\n",
"\n",
" trainer.train()\n",
- " \n",
+ "\n",
" return model"
]
},
@@ -337,16 +336,15 @@
},
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def predict(model, input, labels):\n",
- " \n",
" model.labels = labels\n",
- " \n",
+ "\n",
" prediction = model.predict([input])\n",
- " \n",
+ "\n",
" return prediction[0]"
]
},
@@ -422,7 +420,7 @@
},
{
"cell_type": "code",
- "execution_count": 43,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -431,12 +429,13 @@
},
{
"cell_type": "code",
- "execution_count": 44,
+ "execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
- "status_filter = rg.Query(filter = rg.Filter((\"status\", \"==\", \"submitted\")))\n",
- "submitted = list(dataset.records(status_filter))"
+ "status_filter = rg.Query(filter=rg.Filter((\"response.status\", \"==\", \"submitted\")))\n",
+ "\n",
+ "submitted = dataset.records(status_filter).to_list(flatten=True)"
]
},
{
@@ -452,10 +451,12 @@
"metadata": {},
"outputs": [],
"source": [
- "train_records = [{\n",
- " \"text\" : r.fields[\"review\"],\n",
- " \"label\" : r.responses.sentiment_label[0].value,\n",
- " } for r in submitted\n",
+ "train_records = [\n",
+ " {\n",
+ " \"text\": r[\"review\"],\n",
+ " \"label\": r[\"sentiment_label.responses\"][0],\n",
+ " }\n",
+ " for r in submitted\n",
"]\n",
"train_dataset = Dataset.from_list(train_records)\n",
"train_dataset = sample_dataset(train_dataset, label_column=\"label\", num_samples=8)"
diff --git a/argilla/docs/tutorials/token_classification.ipynb b/argilla/docs/tutorials/token_classification.ipynb
new file mode 100644
index 0000000000..4c328f15ac
--- /dev/null
+++ b/argilla/docs/tutorials/token_classification.ipynb
@@ -0,0 +1,736 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Token classification\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we will show a standard workflow for a token classification task, in this case, using GLiNER, SpanMarker and Argilla.\n",
+ "\n",
+ "We will follow these steps:\n",
+ "\n",
+ "- Configure the Argilla dataset\n",
+ "- Add initial model suggestions\n",
+ "- Evaluate with Argilla\n",
+ "- Train your model\n",
+ "- Update the suggestions with the new model\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Getting started\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you have already deployed Argilla Server, you can skip this step. Otherwise, you can quickly deploy it in two different ways:\n",
+ "\n",
+ "- Remotely using a [HF Space](https://huggingface.co/new-space?template=argilla/argilla-template-space). ⚠️ If persistent storage is not enabled, you will lose your data when the server is stopped.\n",
+ "\n",
+ "!!! note\n",
+ "As this is a release candidate version, you'll need to manually change the version in the HF Space Files > Dockerfile to `argilla/argilla-quickstart:v2.0.0rc2`.\n",
+ "\n",
+ "- Locally using Docker: `docker run -d --name quickstart -p 6900:6900 argilla/argilla-quickstart:v2.0.0rc2`\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Set up the environment\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To complete this tutorial, you need to install the Argilla SDK and a few third-party libraries via `pip`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install argilla --pre"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install gliner==0.2.6 transformers==4.40.2 span_marker=1.5.0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's make the needed imports:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "\n",
+ "import argilla as rg\n",
+ "\n",
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
+ "from gliner import GLiNER\n",
+ "from span_marker import SpanMarkerModel, Trainer\n",
+ "from transformers import TrainingArguments"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You also need to connect to the Argilla server with the `api_url` and `api_key`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Replace api_url with your url if using Docker\n",
+ "# Replace api_key if you configured a custom API key\n",
+ "# Uncomment the last line and set your HF_TOKEN if your space is private\n",
+ "client = rg.Argilla(\n",
+ " api_url=\"https://[your-owner-name]-[your_space_name].hf.space\",\n",
+ " api_key=\"owner.apikey\",\n",
+ " # headers={\"Authorization\": f\"Bearer {HF_TOKEN}\"}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Configure and create the Argilla dataset\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we will need to configure the dataset. In the settings, we can specify the guidelines, fields, and questions. If needed, you can also add metadata and vectors. However, for our use case, we just need a text field and a span question. We will focus on Name Entity Recognition, but this workflow can also be applied to Span Classification, which differs in that the spans are less clearly defined and often overlap."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 128,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "labels = [\n",
+ " \"CARDINAL\",\n",
+ " \"DATE\",\n",
+ " \"PERSON\",\n",
+ " \"NORP\",\n",
+ " \"GPE\",\n",
+ " \"LAW\",\n",
+ " \"PERCENT\",\n",
+ " \"ORDINAL\",\n",
+ " \"MONEY\",\n",
+ " \"WORK_OF_ART\",\n",
+ " \"FAC\",\n",
+ " \"TIME\",\n",
+ " \"QUANTITY\",\n",
+ " \"PRODUCT\",\n",
+ " \"LANGUAGE\",\n",
+ " \"ORG\",\n",
+ " \"LOC\",\n",
+ " \"EVENT\",\n",
+ "]\n",
+ "\n",
+ "settings = rg.Settings(\n",
+ " guidelines=\"Classify individual tokens according to the specified categories, ensuring that any overlapping or nested entities are accurately captured.\",\n",
+ " fields=[\n",
+ " rg.TextField(\n",
+ " name=\"text\",\n",
+ " title=\"Text\",\n",
+ " use_markdown=False,\n",
+ " ),\n",
+ " ],\n",
+ " questions=[\n",
+ " rg.SpanQuestion(\n",
+ " name=\"span_label\",\n",
+ " field=\"text\",\n",
+ " labels=labels,\n",
+ " title=\"Classify the tokens according to the specified categories.\",\n",
+ " allow_overlapping=False,\n",
+ " )\n",
+ " ],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's create the dataset with the name and the defined settings:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = rg.Dataset(\n",
+ " name=\"token_classification_dataset\",\n",
+ " settings=settings,\n",
+ ")\n",
+ "dataset.create()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Add records\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We have created the dataset (you can check it in the UI), but we still need to add the data for annotation. In this case, we will use the `ontonote5` dataset from the [Hugging Face Hub](https://huggingface.co/datasets/tner/ontonotes5?row=0). Specifically, we will use 2100 samples from the `test` split."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 130,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hf_dataset = load_dataset(\"tner/ontonotes5\", split=\"test[:2100]\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will iterate over the Hugging Face dataset, adding data to the corresponding field in the `Record` object for the Argilla dataset. Then, we will easily add them to the dataset using `log`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "records = [rg.Record(fields={\"text\": \" \".join(row[\"tokens\"])}) for row in hf_dataset]\n",
+ "\n",
+ "dataset.records.log(records)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add initial model suggestions\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next step is to add suggestions to the dataset. This will make things easier and faster for the annotation team. Suggestions will appear as preselected options, so annotators will only need to correct them. In our case, we will generate them using a GLiNER model. However, you can use a framework or technique of your choice.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!!! note\n",
+ "For further information, you can check the [GLiNER repository](https://github.com/urchade/GLiNER) and the [original paper](https://arxiv.org/abs/2311.08526).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will start by loading the pre-trained GLiNER model. Specifically, we will use `gliner_mediumv2`, available in [Hugging Face Hub](https://huggingface.co/urchade/gliner_medium-v1).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gliner_model = GLiNER.from_pretrained(\"urchade/gliner_mediumv2.1\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we will create a function to generate predictions using this general model, which can identify the specified labels without being pre-trained on them. The function will return a dictionary formatted with the necessary schema to add entities to our Argilla dataset. This schema includes the keys 'start’ and ‘end’ to indicate the indices where the span begins and ends, as well as ‘label’ for the entity label.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 113,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict_gliner(model, text, labels, threshold):\n",
+ " entities = model.predict_entities(text, labels, threshold)\n",
+ " return [\n",
+ " {k: v for k, v in ent.items() if k not in {\"score\", \"text\"}} for ent in entities\n",
+ " ]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To update the records, we will need to retrieve them from the server and update them with the new suggestions. The `id` will always need to be provided as it is the records' identifier to update a record and avoid creating a new one.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = dataset.records.to_list(flatten=True)\n",
+ "updated_data = [\n",
+ " {\n",
+ " \"span_label\": predict_gliner(\n",
+ " model=gliner_model, text=sample[\"text\"], labels=labels, threshold=0.70\n",
+ " ),\n",
+ " \"id\": sample[\"id\"],\n",
+ " }\n",
+ " for sample in data\n",
+ "]\n",
+ "dataset.records.log(records=updated_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Voilà! We have added the suggestions to the dataset and they will appear in the UI marked with ✨.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Evaluate with Argilla\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can start the annotation process. Just open the dataset in the Argilla UI and start annotating the records. If the suggestions are correct, you can just click on `Submit`. Otherwise, you can select the correct label.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!!! note\n",
+ "Check this [how-to guide](../how_to_guides/annotate.md) to know more about annotating in the UI.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train your model\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "After the annotation, we will have a robust dataset to train our model for entity recognition. For our case, we will train a SpanMarker model, but you can select any model of your choice. So, let's start by retrieving the annotated records.\n",
+ "\n",
+ "> Check this [how-to guide](../how_to_guides/query_export.md) to learn more about filtering and querying in Argilla.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 132,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = client.datasets(\"token_classification_dataset\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In our case, we submitted 2000 annotations using the bulk view.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "status_filter = rg.Query(filter=rg.Filter((\"response.status\", \"==\", \"submitted\")))\n",
+ "\n",
+ "submitted = dataset.records(status_filter).to_list(flatten=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "SpanMarker accepts any dataset as long as it has the `tokens` and `ner_tags` columns. The `ner_tags` can be annotated using the IOB, IOB2, BIOES or BILOU labeling scheme, as well as regular unschemed labels. In our case, we have chosen to use the IOB format. Thus, we will define a function to extract the annotated NER tags according to this schema.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!!! note\n",
+ "For further information, you can check the [SpanMarker documentation](https://tomaarsen.github.io/SpanMarkerNER/).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_iob_tag_for_token(token_start, token_end, ner_spans):\n",
+ " for span in ner_spans:\n",
+ " if token_start >= span[\"start\"] and token_end <= span[\"end\"]:\n",
+ " if token_start == span[\"start\"]:\n",
+ " return f\"B-{span['label']}\"\n",
+ " else:\n",
+ " return f\"I-{span['label']}\"\n",
+ " return \"O\"\n",
+ "\n",
+ "\n",
+ "def extract_ner_tags(text, responses):\n",
+ " tokens = re.split(r\"(\\s+)\", text)\n",
+ " ner_tags = []\n",
+ "\n",
+ " current_position = 0\n",
+ " for token in tokens:\n",
+ " if token.strip():\n",
+ " token_start = current_position\n",
+ " token_end = current_position + len(token)\n",
+ " tag = get_iob_tag_for_token(token_start, token_end, responses)\n",
+ " ner_tags.append(tag)\n",
+ " current_position += len(token)\n",
+ "\n",
+ " return ner_tags"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's now extract them and save two lists with the tokens and NER tags, which will help us build our dataset to train the SpanMarker model.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokens = []\n",
+ "ner_tags = []\n",
+ "for r in submitted:\n",
+ " tags = extract_ner_tags(r[\"text\"], r[\"span_label.responses\"][0])\n",
+ " tks = r[\"text\"].split()\n",
+ " tokens.append(tks)\n",
+ " ner_tags.append(tags)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In addition, we will have to indicate the labels and they should be formatted as integers. So, we will retrieve them and map them.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "labels = list(set([item for sublist in ner_tags for item in sublist]))\n",
+ "\n",
+ "id2label = {i: label for i, label in enumerate(labels)}\n",
+ "label2id = {label: id_ for id_, label in id2label.items()}\n",
+ "\n",
+ "mapped_ner_tags = [[label2id[label] for label in ner_tag] for ner_tag in ner_tags]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finally, we will create a dataset with the train and validation sets.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "records = [\n",
+ " {\n",
+ " \"tokens\": token,\n",
+ " \"ner_tags\": ner_tag,\n",
+ " }\n",
+ " for token, ner_tag in zip(tokens, mapped_ner_tags)\n",
+ "]\n",
+ "span_dataset = DatasetDict(\n",
+ " {\n",
+ " \"train\": Dataset.from_list(records[:1500]),\n",
+ " \"validation\": Dataset.from_list(records[1501:2000]),\n",
+ " }\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, let's prepare to train our model. For this, it is recommended to use GPU. You can check if it is available as shown below.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "if torch.cuda.is_available():\n",
+ " device = torch.device(\"cuda\")\n",
+ " print(f\"Using {torch.cuda.get_device_name(0)}\")\n",
+ "elif torch.backends.mps.is_available():\n",
+ " device = torch.device(\"mps\")\n",
+ " print(\"Using MPS device\")\n",
+ "else:\n",
+ " device = torch.device(\"cpu\")\n",
+ " print(\"No GPU available, using CPU instead.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will define our model and arguments. In this case, we will use the `bert-base-cased`, available in the [Hugging Face Hub](https://huggingface.co/google-bert/bert-base-cased), but others can be applied.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!!! note\n",
+ "The training arguments are inherited from the Transformers library. You can check more information [here](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoder_id = \"bert-base-cased\"\n",
+ "model = SpanMarkerModel.from_pretrained(\n",
+ " encoder_id,\n",
+ " labels=labels,\n",
+ " model_max_length=256,\n",
+ " entity_max_length=8,\n",
+ ")\n",
+ "\n",
+ "args = TrainingArguments(\n",
+ " output_dir=\"models/span-marker\",\n",
+ " learning_rate=5e-5,\n",
+ " per_device_train_batch_size=8,\n",
+ " per_device_eval_batch_size=8,\n",
+ " num_train_epochs=1,\n",
+ " weight_decay=0.01,\n",
+ " warmup_ratio=0.1,\n",
+ " fp16=False, # Set to True if available\n",
+ " logging_first_step=True,\n",
+ " logging_steps=50,\n",
+ " evaluation_strategy=\"steps\",\n",
+ " save_strategy=\"steps\",\n",
+ " eval_steps=500,\n",
+ " save_total_limit=2,\n",
+ " dataloader_num_workers=2,\n",
+ ")\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " args=args,\n",
+ " train_dataset=span_dataset[\"train\"],\n",
+ " eval_dataset=span_dataset[\"validation\"],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's train it! This time, we use a high-quality human-annotated training set, so the results are expected to have improved.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.evaluate()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can save it locally or push it to the Hub. And then load it from there.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save and load locally\n",
+ "# model.save_pretrained(\"token_classification_model\")\n",
+ "# model = SpanMarkerModel.from_pretrained(\"token_classification_model\")\n",
+ "\n",
+ "# Push and load in HF\n",
+ "# model.push_to_hub(\"[username]/token_classification_model\")\n",
+ "# model = SpanMarkerModel.from_pretrained(\"[username]/token_classification_model\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "It's time to make the predictions! We will set a function that uses the `predict` method to get the suggested label. The model will infer the label based on the text. The function will return the spans in the corresponding structure for the Argilla dataset.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 117,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict_spanmarker(model, text):\n",
+ " entities = model.predict(text)\n",
+ " return [\n",
+ " {\n",
+ " \"start\": ent[\"char_start_index\"],\n",
+ " \"end\": ent[\"char_end_index\"],\n",
+ " \"label\": ent[\"label\"],\n",
+ " }\n",
+ " for ent in entities\n",
+ " ]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As the training data was of better quality, we can expect a better model. So we can update the remaining non-annotated records with the new model's suggestions.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = dataset.records.to_list(flatten=True)\n",
+ "updated_data = [\n",
+ " {\n",
+ " \"span_label\": predict_spanmarker(model=model, text=sample[\"text\"]),\n",
+ " \"id\": sample[\"id\"],\n",
+ " }\n",
+ " for sample in data\n",
+ "]\n",
+ "dataset.records.log(records=updated_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Conclusions\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we present an end-to-end example of a token classification task. This serves as the base, but it can be performed iteratively and seamlessly integrated into your workflow to ensure high-quality curation of your data and improved results.\n",
+ "\n",
+ "We started by configuring the dataset, adding records, and adding suggestions based on the GLiNer predictions. After the annotation process, we trained a SpanMarker model with the annotated data and updated the remaining records with the new suggestions.\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "argilla",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/argilla/mkdocs.yml b/argilla/mkdocs.yml
index 1f6041ee7e..0a29e6461f 100644
--- a/argilla/mkdocs.yml
+++ b/argilla/mkdocs.yml
@@ -144,7 +144,8 @@ nav:
- Migrate your legacy datasets to Argilla V2: how_to_guides/migrate_from_legacy_datasets.md
- Tutorials:
- tutorials/index.md
- - Text classification task: tutorials/text_classification.ipynb
+ - Text classification: tutorials/text_classification.ipynb
+ - Token classification: tutorials/token_classification.ipynb
- API Reference:
- Python SDK: reference/argilla/
- Community: