docs: feedback tutorials
sdiazlor committed Jul 11, 2024
1 parent 164838d commit 5681e40
Showing 2 changed files with 118 additions and 89 deletions.
28 changes: 13 additions & 15 deletions argilla/docs/tutorials/text_classification.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
-"# Text classification task"
+"# Text classification"
]
},
{
@@ -119,7 +119,7 @@
"# Uncomment the last line and set your HF_TOKEN if your space is private\n",
"client = rg.Argilla(\n",
" api_url=\"https://[your-owner-name]-[your_space_name].hf.space\",\n",
-" api_key=\"owner.apikey\"\n",
+" api_key=\"owner.apikey\",\n",
" # headers={\"Authorization\": f\"Bearer {HF_TOKEN}\"}\n",
")"
]
@@ -238,7 +238,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
-"The next step is to add suggestions to the dataset. In our case, we will generate them using a zero-shot SetFit model. However, you can use a framework or technique of your choice."
+"The next step is to add suggestions to the dataset. This will make things easier and faster for the annotation team. Suggestions will appear as preselected options, so annotators will only need to correct them. In our case, we will generate them using a zero-shot SetFit model. However, you can use a framework or technique of your choice."
]
},
{
@@ -277,7 +277,6 @@
"outputs": [],
"source": [
"def train_model(model_name, dataset):\n",
-" \n",
" model = SetFitModel.from_pretrained(model_name)\n",
"\n",
" trainer = Trainer(\n",
@@ -286,7 +285,7 @@
" )\n",
"\n",
" trainer.train()\n",
-" \n",
+"\n",
" return model"
]
},
@@ -342,11 +341,10 @@
"outputs": [],
"source": [
"def predict(model, input, labels):\n",
-" \n",
" model.labels = labels\n",
-" \n",
+"\n",
" prediction = model.predict([input])\n",
-" \n",
+"\n",
" return prediction[0]"
]
},
@@ -435,9 +433,7 @@
"metadata": {},
"outputs": [],
"source": [
-"status_filter = rg.Query(\n",
-" filter=rg.Filter((\"response.status\", \"==\", \"submitted\"))\n",
-")\n",
+"status_filter = rg.Query(filter=rg.Filter((\"response.status\", \"==\", \"submitted\")))\n",
"\n",
"submitted = dataset.records(status_filter).to_list(flatten=True)"
]
@@ -455,10 +451,12 @@
"metadata": {},
"outputs": [],
"source": [
-"train_records = [{\n",
-" \"text\" : r[\"review\"],\n",
-" \"label\" : r[\"sentiment_label.responses\"][0],\n",
-" } for r in submitted\n",
+"train_records = [\n",
+" {\n",
+" \"text\": r[\"review\"],\n",
+" \"label\": r[\"sentiment_label.responses\"][0],\n",
+" }\n",
+" for r in submitted\n",
"]\n",
"train_dataset = Dataset.from_list(train_records)\n",
"train_dataset = sample_dataset(train_dataset, label_column=\"label\", num_samples=8)"