docs: token classification tutorial #5183

Merged
33 changes: 17 additions & 16 deletions argilla/docs/how_to_guides/query_export.md
@@ -48,7 +48,7 @@ To search for records with terms, you can use the `Dataset.records` attribute wi

query = rg.Query(query="my_term")

queried_records = list(dataset.records(query=query))
queried_records = dataset.records(query=query).to_list(flatten=True)
```

=== "Multiple search terms"
@@ -64,19 +64,19 @@ To search for records with terms, you can use the `Dataset.records` attribute wi

query = rg.Query(query="my_term1 my_term2")

queried_records = list(dataset.records(query=query))
queried_records = dataset.records(query=query).to_list(flatten=True)
```

## Filter by conditions

You can use the `Filter` class to define conditions and pass them to the `Dataset.records` attribute to fetch records that meet them. The supported operators are `==`, `>=`, `<=`, and `in`. Conditions use dot notation to target metadata, suggestions, or responses, and you can filter with a single condition or combine several.

| operator | description |
|----------|-------------|
| `==` | The `field` value is equal to the `value` |
| operator | description |
| -------- | --------------------------------------------------------- |
| `==` | The `field` value is equal to the `value` |
| `>=` | The `field` value is greater than or equal to the `value` |
| `<=` | The `field` value is less than or equal to the `value` |
| `in` | The `field` value is included in a list of values |
| `<=` | The `field` value is less than or equal to the `value` |
| `in`      | The `field` value is included in a list of values         |
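The semantics of these operators can be sketched in plain Python. This is only an illustration of the comparison logic applied to made-up record dictionaries, not the Argilla server-side implementation:

```python
# Hypothetical records with made-up fields, used only to illustrate
# the operator semantics from the table above.
records = [
    {"label": "positive", "score": 0.9},
    {"label": "negative", "score": 0.4},
    {"label": "neutral", "score": 0.7},
]

def matches(record, condition):
    """Return True if the record satisfies a (field, operator, value) condition."""
    field, operator, value = condition
    if operator == "==":
        return record[field] == value
    if operator == ">=":
        return record[field] >= value
    if operator == "<=":
        return record[field] <= value
    if operator == "in":
        return record[field] in value
    raise ValueError(f"Unsupported operator: {operator}")

# Keep records whose label is in a list of values, mirroring the `in` row.
kept = [r for r in records if matches(r, ("label", "in", ["positive", "neutral"]))]
```

In the SDK itself, the same condition tuple shape is what `rg.Filter` receives, so thinking of a filter as `(field, operator, value)` triples carries over directly.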

=== "Single condition"

@@ -91,7 +91,9 @@ You can use the `Filter` class to define the conditions and pass them to the `Da

filter_label = rg.Filter(("label", "==", "positive"))

filtered_records = list(dataset.records(query=rg.Query(filter=filter_label)))
filtered_records = dataset.records(query=rg.Query(filter=filter_label)).to_list(
flatten=True
)
```

=== "Multiple conditions"
@@ -114,10 +116,10 @@ You can use the `Filter` class to define the conditions and pass them to the `Da
]
)

filtered_records = list(dataset.records(
query=rg.Query(filter=filters)),
with_suggestions=True
)
filtered_records = dataset.records(
query=rg.Query(filter=filters),
with_suggestions=True,
).to_list(flatten=True)
```

## Filter by status
@@ -137,7 +139,7 @@ status_filter = rg.Query(
filter=rg.Filter(("response.status", "==", "submitted"))
)

filtered_records = list(dataset.records(status_filter))
filtered_records = dataset.records(status_filter).to_list(flatten=True)
```

## Query and filter a dataset
@@ -163,12 +165,11 @@ query_filter = rg.Query(
)
)

queried_filtered_records = list(dataset.records(
queried_filtered_records = dataset.records(
query=query_filter,
with_metadata=True,
with_suggestions=True
)
)
).to_list(flatten=True)
```

## Export records to a dictionary
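The `to_list(flatten=True)` call used in the examples above returns plain dictionaries in which nested record attributes are collapsed into dot-joined keys (for example `sentiment_label.responses`). A rough sketch of that flattening behaviour, using a made-up record — the exact key layout is an assumption for illustration, not the SDK's guaranteed output:

```python
# Hypothetical nested record, loosely shaped like an exported Argilla record.
record = {
    "review": "Great film",
    "sentiment_label": {"responses": ["positive"], "suggestions": ["positive"]},
}

def flatten(d, prefix=""):
    """Collapse nested dicts into dot-joined keys (illustrative only)."""
    flat = {}
    for key, value in d.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=f"{full_key}."))
        else:
            flat[full_key] = value
    return flat

flat_record = flatten(record)
# flat_record now has top-level keys like "review" and "sentiment_label.responses"
```

This flat shape is why downstream code can index exported records with strings such as `r["sentiment_label.responses"]` instead of walking nested objects.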
11 changes: 9 additions & 2 deletions argilla/docs/tutorials/index.md
@@ -10,11 +10,18 @@ These are the tutorials for *the Argilla SDK*. They provide step-by-step instruc

<div class="grid cards" markdown>

- __Text classification task__
- __Text classification__

---

Learn about a standard workflow to improve data quality for a text classification task.
[:octicons-arrow-right-24: Tutorial](text_classification.ipynb)

</div>
- __Token classification__

---

Learn about a standard workflow to improve data quality for a token classification task.
[:octicons-arrow-right-24: Tutorial](token_classification.ipynb)

</div>
47 changes: 24 additions & 23 deletions argilla/docs/tutorials/text_classification.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text classification task"
"# Text classification"
]
},
{
@@ -91,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
"# Uncomment the last line and set your HF_TOKEN if your space is private\n",
"client = rg.Argilla(\n",
" api_url=\"https://[your-owner-name]-[your_space_name].hf.space\",\n",
" api_key=\"owner.apikey\"\n",
" api_key=\"owner.apikey\",\n",
" # headers={\"Authorization\": f\"Bearer {HF_TOKEN}\"}\n",
")"
]
@@ -143,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -182,7 +182,7 @@
"outputs": [],
"source": [
"dataset = rg.Dataset(\n",
" name=\"text_classification_dataset1\",\n",
" name=\"text_classification_dataset\",\n",
" settings=settings,\n",
")\n",
"dataset.create()"
@@ -204,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -238,7 +238,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The next step is to add suggestions to the dataset. In our case, we will generate them using a zero-shot SetFit model. However, you can use a framework or technique of your choice."
"The next step is to add suggestions to the dataset. This will make things easier and faster for the annotation team. Suggestions will appear as preselected options, so annotators will only need to correct them. In our case, we will generate them using a zero-shot SetFit model. However, you can use a framework or technique of your choice."
]
},
{
@@ -250,7 +250,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -272,12 +272,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model_name, dataset):\n",
" \n",
" model = SetFitModel.from_pretrained(model_name)\n",
"\n",
" trainer = Trainer(\n",
Expand All @@ -286,7 +285,7 @@
" )\n",
"\n",
" trainer.train()\n",
" \n",
"\n",
" return model"
]
},
@@ -337,16 +336,15 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def predict(model, input, labels):\n",
" \n",
" model.labels = labels\n",
" \n",
"\n",
" prediction = model.predict([input])\n",
" \n",
"\n",
" return prediction[0]"
]
},
@@ -422,7 +420,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -431,12 +429,13 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"status_filter = rg.Query(filter = rg.Filter((\"status\", \"==\", \"submitted\")))\n",
"submitted = list(dataset.records(status_filter))"
"status_filter = rg.Query(filter=rg.Filter((\"response.status\", \"==\", \"submitted\")))\n",
"\n",
"submitted = dataset.records(status_filter).to_list(flatten=True)"
]
},
{
@@ -452,10 +451,12 @@
"metadata": {},
"outputs": [],
"source": [
"train_records = [{\n",
" \"text\" : r.fields[\"review\"],\n",
" \"label\" : r.responses.sentiment_label[0].value,\n",
" } for r in submitted\n",
"train_records = [\n",
" {\n",
" \"text\": r[\"review\"],\n",
" \"label\": r[\"sentiment_label.responses\"][0],\n",
" }\n",
" for r in submitted\n",
"]\n",
"train_dataset = Dataset.from_list(train_records)\n",
"train_dataset = sample_dataset(train_dataset, label_column=\"label\", num_samples=8)"