-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cleanup and nicer examples for classifier scores
- Loading branch information
1 parent
63979d6
commit 1474161
Showing
7 changed files
with
239 additions
and
38 deletions.
There are no files selected for viewing
133 changes: 133 additions & 0 deletions
133
examples/Getting probability scores out of LLM classification.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "45ebdda2", | ||
"metadata": {}, | ||
"source": [ | ||
"# Getting probability scores out of LLM classification\n", | ||
"\n", | ||
"When comparing traditional ML classifiers to LLM-based ones, a common problem is that most classifier performance metrics require a vector of confidence/probability scores across the available options, not just the most likely answer. \n", | ||
"\n", | ||
"Fortunately, eg the OpenAI API allows to query token logprobs for up to 20 most likely tokens in each position of its response. \n", | ||
"These still need to be masked (discarding irrelevant options), converted to probabilities, and normalized to sum to one. \n", | ||
"\n", | ||
"To spare you the hassle of doing this, we provide two functions, a binary classifier (which expects a yes/no question), and a multiple-choice classifier that expects a multiple-choice question and a list of valid options. It also has an optional boolean argument `include_other`, which if true makes the classifier also include an \"Other\" option in its output, for when none of the valid options fit. \n", | ||
"\n", | ||
"To keep it simple, the multiple chocice classifier only supports up to 9 choice options, so the LLM output can be a single digit (for speed and parsing simplicity). Feel free to contribute a version that supports a larger number of choices! ;)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "f32f114d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pprint import pprint\n", | ||
"\n", | ||
"from langchain_openai import ChatOpenAI\n", | ||
"\n", | ||
"try:\n", | ||
" import wise_topic\n", | ||
"except ImportError:\n", | ||
" import os, sys\n", | ||
" sys.path.append(os.path.realpath(\"..\"))\n", | ||
"\n", | ||
"\n", | ||
"from wise_topic import llm_classifier_binary, llm_classifier_multiple\n", | ||
"llm = ChatOpenAI(model=\"gpt-4-turbo\", temperature=0)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "d7288d4b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{False: 0.03559647724243312, True: 0.9644035227575669}" | ||
] | ||
}, | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"question1 = \"Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?\"\n", | ||
"llm_classifier_binary(llm, question1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "c3081966", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{False: 0.9999912515146222, True: 8.748485377892584e-06}" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"question2 = \"Consider a very friendly pet with a waggy tail. You know it's a cat or a dog. Is it a cat?\"\n", | ||
"llm_classifier_binary(llm, question2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "0689d004", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'cat': 0.9372176977942116,\n", | ||
" 'dog': 0.062782248112413,\n", | ||
" 'dragon': 5.215838794110004e-09,\n", | ||
" 'duck': 4.887753666874768e-08}" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"question3 = \"Consider a very friendly pet with a fluffy tail. You know it's a cat, a dog, or a dragon. Which is it?\"\n", | ||
"llm_classifier_multiple(llm, question3, [\"cat\", \"dog\", \"dragon\", \"duck\"], include_other=False)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python [conda env:llm3.11]", | ||
"language": "python", | ||
"name": "conda-env-llm3.11-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,20 @@ | ||
from langchain_openai import ChatOpenAI | ||
|
||
from langchain_core.prompts import ChatPromptTemplate | ||
|
||
from wise_topic import llm_classifier, binary_prompt | ||
from wise_topic import llm_classifier_binary, llm_classifier_multiple | ||
|
||
|
||
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) | ||
question = "Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?" | ||
prompt_value = ( | ||
ChatPromptTemplate.from_template(binary_prompt) | ||
.invoke({"question": question}) | ||
.to_messages() | ||
) | ||
question1 = "Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?" | ||
question2 = "Consider a very friendly pet with a waggy tail. You know it's a cat or a dog. Is it a cat?" | ||
for question in [question1, question2]: | ||
out = llm_classifier_binary(llm, question) | ||
print(question) | ||
print(out) | ||
|
||
out = llm_classifier(llm, [prompt_value], ["0", "1"]) | ||
|
||
question3 = "Consider a very friendly pet with a long tail. You know it's a cat, a dog, or a dragon. Which is it?" | ||
out = llm_classifier_multiple(llm, question3, ["cat", "dog", "dragon"]) | ||
print(question3) | ||
print(out) | ||
|
||
print("done!") |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from wise_topic.topic.greedy import greedy_topic_tree, tree_summary | ||
from wise_topic.classifier.classifier import llm_classifier | ||
from wise_topic.classifier.prompts import binary_prompt | ||
from wise_topic.classifier.classifier import ( | ||
llm_classifier_binary, | ||
llm_classifier_multiple, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters