Cleanup and nicer examples for classifier scores
EgorKraevTransferwise committed Jun 6, 2024
1 parent 63979d6 commit 1474161
Showing 7 changed files with 239 additions and 38 deletions.
133 changes: 133 additions & 0 deletions examples/Getting probability scores out of LLM classification.ipynb
@@ -0,0 +1,133 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "45ebdda2",
"metadata": {},
"source": [
"# Getting probability scores out of LLM classification\n",
"\n",
"When comparing traditional ML classifiers to LLM-based ones, a common problem is that most classifier performance metrics require a vector of confidence/probability scores across the available options, not just the most likely answer. \n",
"\n",
"Fortunately, eg the OpenAI API allows to query token logprobs for up to 20 most likely tokens in each position of its response. \n",
"These still need to be masked (discarding irrelevant options), converted to probabilities, and normalized to sum to one. \n",
"\n",
"To spare you the hassle of doing this, we provide two functions, a binary classifier (which expects a yes/no question), and a multiple-choice classifier that expects a multiple-choice question and a list of valid options. It also has an optional boolean argument `include_other`, which if true makes the classifier also include an \"Other\" option in its output, for when none of the valid options fit. \n",
"\n",
"To keep it simple, the multiple chocice classifier only supports up to 9 choice options, so the LLM output can be a single digit (for speed and parsing simplicity). Feel free to contribute a version that supports a larger number of choices! ;)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f32f114d",
"metadata": {},
"outputs": [],
"source": [
"from pprint import pprint\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"try:\n",
" import wise_topic\n",
"except ImportError:\n",
" import os, sys\n",
" sys.path.append(os.path.realpath(\"..\"))\n",
"\n",
"\n",
"from wise_topic import llm_classifier_binary, llm_classifier_multiple\n",
"llm = ChatOpenAI(model=\"gpt-4-turbo\", temperature=0)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d7288d4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{False: 0.03559647724243312, True: 0.9644035227575669}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question1 = \"Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?\"\n",
"llm_classifier_binary(llm, question1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c3081966",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{False: 0.9999912515146222, True: 8.748485377892584e-06}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question2 = \"Consider a very friendly pet with a waggy tail. You know it's a cat or a dog. Is it a cat?\"\n",
"llm_classifier_binary(llm, question2)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0689d004",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'cat': 0.9372176977942116,\n",
" 'dog': 0.062782248112413,\n",
" 'dragon': 5.215838794110004e-09,\n",
" 'duck': 4.887753666874768e-08}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question3 = \"Consider a very friendly pet with a fluffy tail. You know it's a cat, a dog, or a dragon. Which is it?\"\n",
"llm_classifier_multiple(llm, question3, [\"cat\", \"dog\", \"dragon\", \"duck\"], include_other=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:llm3.11]",
"language": "python",
"name": "conda-env-llm3.11-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
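For illustration, the masking-and-normalization step the notebook describes can be sketched in a few lines of numpy. This is only a sketch of the technique under stated assumptions, not the package's actual implementation; the function name and the logprob values in the example are invented.

import numpy as np

def mask_and_normalize(top_logprobs, valid_options):
    # Keep only tokens that are valid options; anything else is masked out
    # with -inf, which becomes probability zero after exponentiation.
    masked = np.array([top_logprobs.get(opt, float("-inf")) for opt in valid_options])
    probs = np.exp(masked)  # logprobs -> probabilities
    return probs / probs.sum()  # renormalize so the scores sum to one

# Hypothetical top-token logprobs for a binary (no=0 / yes=1) question:
print(mask_and_normalize({"1": -0.036, "0": -3.34, "The": -9.2}, ["0", "1"]))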
30 changes: 15 additions & 15 deletions examples/Topic extraction using LLMs only.ipynb
@@ -16,7 +16,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"id": "f32f114d",
"metadata": {},
"outputs": [],
@@ -55,17 +55,17 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "d7288d4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Winter': 5, 'Summer': 5, 'fish and their habitats and behaviors': 5}"
"{'Summer': 6, 'Winter': 5, 'threats to diverse fish species': 4}"
]
},
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -94,27 +94,27 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"id": "a4757ef7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Winter': {'messages': ['Winter brings the joy of snowfall and the excitement of skiing.',\n",
"{'Summer': {'messages': ['During summer, the days are long and the nights are warm and inviting.',\n",
" 'Tropical fish add vibrant color and life to coral reefs.',\n",
" 'The summer sun blazed high in the sky, bringing warmth to the sandy beaches.',\n",
" 'Ice cream sales soar as people seek relief from the summer heat.',\n",
" 'Many festivals and outdoor concerts are scheduled in the summer months.',\n",
" 'Families often choose summer for vacations to take advantage of the sunny weather.']},\n",
" 'Winter': {'messages': ['Winter storms can transform the landscape into a snowy wonderland.',\n",
" 'Winter brings the joy of snowfall and the excitement of skiing.',\n",
" 'The cold winter nights are perfect for sipping hot chocolate by the fire.',\n",
" 'Many animals hibernate or migrate to cope with the harsh winter conditions.',\n",
" 'Winter storms can transform the landscape into a snowy wonderland.',\n",
" \"Heating bills tend to rise as winter's chill sets in.\"]},\n",
" 'Summer': {'messages': ['Families often choose summer for vacations to take advantage of the sunny weather.',\n",
" 'Many festivals and outdoor concerts are scheduled in the summer months.',\n",
" 'The summer sun blazed high in the sky, bringing warmth to the sandy beaches.',\n",
" 'During summer, the days are long and the nights are warm and inviting.',\n",
" 'Ice cream sales soar as people seek relief from the summer heat.']},\n",
" 'fish and their habitats and behaviors': {'messages': ['Overfishing threatens many species of fish with extinction.',\n",
" 'Fish swim in schools to protect themselves from predators.',\n",
" 'threats to diverse fish species': {'messages': ['Fish swim in schools to protect themselves from predators.',\n",
" 'Fish have a diverse range of habitats from deep oceans to shallow streams.',\n",
" 'Tropical fish add vibrant color and life to coral reefs.',\n",
" 'Overfishing threatens many species of fish with extinction.',\n",
" 'Salmon migrate upstream during spawning season, a remarkable journey.']}}"
]
},
22 changes: 12 additions & 10 deletions examples/classifier_scores.py
@@ -1,18 +1,20 @@
from langchain_openai import ChatOpenAI

from langchain_core.prompts import ChatPromptTemplate

from wise_topic import llm_classifier, binary_prompt
from wise_topic import llm_classifier_binary, llm_classifier_multiple


llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
question = "Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?"
prompt_value = (
    ChatPromptTemplate.from_template(binary_prompt)
    .invoke({"question": question})
    .to_messages()
)
question1 = "Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?"
question2 = "Consider a very friendly pet with a waggy tail. You know it's a cat or a dog. Is it a cat?"
for question in [question1, question2]:
    out = llm_classifier_binary(llm, question)
    print(question)
    print(out)

out = llm_classifier(llm, [prompt_value], ["0", "1"])

question3 = "Consider a very friendly pet with a long tail. You know it's a cat, a dog, or a dragon. Which is it?"
out = llm_classifier_multiple(llm, question3, ["cat", "dog", "dragon"])
print(question3)
print(out)

print("done!")
File renamed without changes.
6 changes: 4 additions & 2 deletions wise_topic/__init__.py
@@ -1,3 +1,5 @@
from wise_topic.topic.greedy import greedy_topic_tree, tree_summary
from wise_topic.classifier.classifier import llm_classifier
from wise_topic.classifier.prompts import binary_prompt
from wise_topic.classifier.classifier import (
    llm_classifier_binary,
    llm_classifier_multiple,
)
53 changes: 50 additions & 3 deletions wise_topic/classifier/classifier.py
@@ -6,8 +6,55 @@
from langchain_core.prompts import PromptTemplate
from langchain_core.language_models import BaseChatModel

from wise_topic.classifier.prompts import binary_prompt, multi_choice_prompt

def llm_classifier(llm: BaseChatModel, messages, valid_options, top_logprobs=5, max_tokens=1):

def llm_classifier_binary(llm: BaseChatModel, question: str):
    prompt_value = (
        PromptTemplate.from_template(binary_prompt)
        .invoke({"question": question})
        .to_messages()
    )
    out = llm_classifier(llm, [prompt_value], ["0", "1"])
    return {False: out[0], True: out[1]}


def llm_classifier_multiple(
    llm: BaseChatModel,
    question: str,
    answer_options: List[str],
    include_other: bool = False,
):

    assert (
        len(answer_options) <= 9
    ), "Only up to 9 answer options are supported at the moment"
    categories = "\n".join([f"{i + 1}. {t}" for i, t in enumerate(answer_options)])
    prompt_value = (
        PromptTemplate.from_template(multi_choice_prompt(include_other))
        .invoke({"question": question, "categories": categories})
        .to_messages()
    )

    valid_outputs = [str(i) for i in range(len(answer_options) + 1)]
    scores = llm_classifier(
        llm,
        [prompt_value],
        valid_outputs if include_other else valid_outputs[1:],
        top_logprobs=15,
    )
    if include_other:
        used_options = ["Other"] + list(answer_options)
    else:
        used_options = list(answer_options)

    out = {k: v for k, v in zip(used_options, scores)}
    return out


def llm_classifier(
    llm: BaseChatModel, messages, valid_options, top_logprobs=5, max_tokens=1
) -> np.ndarray:
    result = llm.generate(
        messages,
        logprobs=True,
@@ -22,12 +22,69 @@ def llm_classifier(llm: BaseChatModel, messages, valid_options, top_logprobs=5,
    return scores


def logprobs_to_scores(logprobs, valid_options: List[str]):
def logprobs_to_scores(logprobs, valid_options: List[str]) -> np.ndarray:
    scores = np.array(len(valid_options) * [float("-inf")])
    matches = False
    for i, c in enumerate(valid_options):
        for p in logprobs:
            if isinstance(p, dict): # Langchain interface
            if isinstance(p, dict):  # Langchain interface
                token = p["token"]
                logprob = p["logprob"]
            else:  # OpenAI interface
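The diff viewer truncates the tails of llm_classifier and logprobs_to_scores. Conceptually, logprobs_to_scores fills the -inf-initialized scores vector with the logprob of each matched option token and then renormalizes; a hypothetical sketch of that final step (assumed, not the file's actual truncated code) might look like:

# Assumed continuation: scores[i] holds the logprob of valid_options[i],
# or -inf if that option was not among the returned top tokens.
probs = np.exp(scores)
return probs / probs.sum()  # normalized probabilities over valid_options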
33 changes: 25 additions & 8 deletions wise_topic/classifier/prompts.py
@@ -5,11 +5,28 @@
Take a deep breath and think carefully before you make your reply.
"""

multi_choice_prompt = """
{question}
Choose the most suitable option from the following:
{numbered_options}
Return the number of the most suitable option.
Return just the one character digit, nothing else.
Take a deep breath and think carefully before you make your reply.
"""

def multi_choice_prompt(include_other: bool):
    out = (
        """I am about to give you a numbered list of options.
Then I will pass to you a message (possibly, but not necessarily, a question),
after the word MESSAGE.
Return an integer that is the number of the option that best fits that message,
or if the message is a question, the number of the option that best answers the question.
"""
        + (
            """
If no option fits the message, return 0.
"""
            if include_other
            else ""
        )
        + """
Return only the number, without additional text.
{categories}
MESSAGE:
{question}
Take a deep breath and think carefully before you make your reply.
BEST MATCH OPTION NUMBER:"""
    )
    return out
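As a quick sanity check, the new template can be rendered with the same PromptTemplate flow used in classifier.py above. This is a hypothetical usage sketch; the question and category list are made-up examples.

from langchain_core.prompts import PromptTemplate

from wise_topic.classifier.prompts import multi_choice_prompt

# Number the options 1..N, matching what llm_classifier_multiple does.
categories = "\n".join(f"{i + 1}. {t}" for i, t in enumerate(["cat", "dog", "dragon"]))
prompt = PromptTemplate.from_template(multi_choice_prompt(include_other=True))
rendered = prompt.invoke(
    {"question": "A friendly pet with a fluffy tail", "categories": categories}
)
print(rendered.to_string())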
