diff --git a/examples/Getting probability scores out of LLM classification.ipynb b/examples/Getting probability scores out of LLM classification.ipynb
new file mode 100644
index 0000000..cf5f17a
--- /dev/null
+++ b/examples/Getting probability scores out of LLM classification.ipynb
@@ -0,0 +1,133 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "45ebdda2",
+   "metadata": {},
+   "source": [
+    "# Getting probability scores out of LLM classification\n",
+    "\n",
+    "When comparing traditional ML classifiers to LLM-based ones, a common problem is that most classifier performance metrics require a vector of confidence/probability scores across the available options, not just the most likely answer. \n",
+    "\n",
+    "Fortunately, the OpenAI API, for example, lets you query token logprobs for up to the 20 most likely tokens at each position of the response. \n",
+    "These still need to be masked (discarding irrelevant options), converted to probabilities, and normalized to sum to one. \n",
+    "\n",
+    "To spare you the hassle of doing this, we provide two functions: a binary classifier, which expects a yes/no question, and a multiple-choice classifier, which expects a multiple-choice question and a list of valid options. The latter also takes an optional boolean argument `include_other`; if it is true, the classifier includes an \"Other\" option in its output, for when none of the valid options fit. \n",
+    "\n",
+    "To keep it simple, the multiple-choice classifier only supports up to 9 options, so the LLM output can be a single digit (for speed and parsing simplicity). Feel free to contribute a version that supports a larger number of choices! ;)"
+   ]
+  },
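+  {
+   "cell_type": "markdown",
+   "id": "9f0e1d2c",
+   "metadata": {},
+   "source": [
+    "To make the masking and normalization step concrete, here is a minimal standalone sketch (illustrative names only; the package's own implementation lives in `logprobs_to_scores` in `wise_topic/classifier/classifier.py`):\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "def masked_probs(logprobs, valid_options):\n",
+    "    # Start at -inf so options absent from the top-k listing get probability ~0\n",
+    "    scores = np.full(len(valid_options), float(\"-inf\"))\n",
+    "    for i, option in enumerate(valid_options):\n",
+    "        for p in logprobs:  # LangChain-style dicts: {\"token\": ..., \"logprob\": ...}\n",
+    "            if p[\"token\"] == option:\n",
+    "                scores[i] = p[\"logprob\"]\n",
+    "    probs = np.exp(scores)  # logprobs back to probabilities\n",
+    "    return probs / probs.sum()  # renormalize over the surviving options\n",
+    "\n",
+    "# e.g. top-k logprobs for the first token of a yes(1)/no(0) answer\n",
+    "top_k = [{\"token\": \"1\", \"logprob\": -0.04}, {\"token\": \"0\", \"logprob\": -3.3}]\n",
+    "masked_probs(top_k, [\"0\", \"1\"])  # -> array([0.0369..., 0.9630...])\n",
+    "```"
+   ]
+  },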
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "f32f114d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pprint import pprint\n",
+    "\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "\n",
+    "try:\n",
+    "    import wise_topic\n",
+    "except ImportError:\n",
+    "    import os, sys\n",
+    "    sys.path.append(os.path.realpath(\"..\"))\n",
+    "\n",
+    "\n",
+    "from wise_topic import llm_classifier_binary, llm_classifier_multiple\n",
+    "llm = ChatOpenAI(model=\"gpt-4-turbo\", temperature=0)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "d7288d4b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{False: 0.03559647724243312, True: 0.9644035227575669}"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question1 = \"Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?\"\n",
+    "llm_classifier_binary(llm, question1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "c3081966",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{False: 0.9999912515146222, True: 8.748485377892584e-06}"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question2 = \"Consider a very friendly pet with a waggy tail. You know it's a cat or a dog. Is it a cat?\"\n",
+    "llm_classifier_binary(llm, question2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "0689d004",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'cat': 0.9372176977942116,\n",
+       " 'dog': 0.062782248112413,\n",
+       " 'dragon': 5.215838794110004e-09,\n",
+       " 'duck': 4.887753666874768e-08}"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question3 = \"Consider a very friendly pet with a fluffy tail. You know it's a cat, a dog, or a dragon. Which is it?\"\n",
+    "llm_classifier_multiple(llm, question3, [\"cat\", \"dog\", \"dragon\", \"duck\"], include_other=False)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python [conda env:llm3.11]",
+   "language": "python",
+   "name": "conda-env-llm3.11-py"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/Topic extraction using LLMs only.ipynb b/examples/Topic extraction using LLMs only.ipynb
index 035bcf5..61e44c6 100644
--- a/examples/Topic extraction using LLMs only.ipynb
+++ b/examples/Topic extraction using LLMs only.ipynb
@@ -16,7 +16,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 1,
    "id": "f32f114d",
    "metadata": {},
    "outputs": [],
@@ -55,17 +55,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 2,
    "id": "d7288d4b",
    "metadata": {},
    "outputs": [
    {
     "data": {
      "text/plain": [
-       "{'Winter': 5, 'Summer': 5, 'fish and their habitats and behaviors': 5}"
+       "{'Summer': 6, 'Winter': 5, 'threats to diverse fish species': 4}"
      ]
     },
-     "execution_count": 4,
+     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -94,27 +94,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "id": "a4757ef7",
    "metadata": {},
    "outputs": [
    {
     "data": {
      "text/plain": [
-       "{'Winter': {'messages': ['Winter brings the joy of snowfall and the excitement of skiing.',\n",
+       "{'Summer': {'messages': ['During summer, the days are long and the nights are warm and inviting.',\n",
+       "  'Tropical fish add vibrant color and life to coral reefs.',\n",
+       "  'The summer sun blazed high in the sky, bringing warmth to the sandy beaches.',\n",
+       "  'Ice cream sales soar as people seek relief from the summer heat.',\n",
+       "  'Many festivals and outdoor concerts are scheduled in the summer months.',\n",
+       "  'Families often choose summer for vacations to take advantage of the sunny weather.']},\n",
+       " 'Winter': {'messages': ['Winter storms can transform the landscape into a snowy wonderland.',\n",
+       "  'Winter brings the joy of snowfall and the excitement of skiing.',\n",
        "  'The cold winter nights are perfect for sipping hot chocolate by the fire.',\n",
        "  'Many animals hibernate or migrate to cope with the harsh winter conditions.',\n",
-       "  'Winter storms can transform the landscape into a snowy wonderland.',\n",
        "  \"Heating bills tend to rise as winter's chill sets in.\"]},\n",
-       " 'Summer': {'messages': ['Families often choose summer for vacations to take advantage of the sunny weather.',\n",
-       "  'Many festivals and outdoor concerts are scheduled in the summer months.',\n",
-       "  'The summer sun blazed high in the sky, bringing warmth to the sandy beaches.',\n",
-       "  'During summer, the days are long and the nights are warm and inviting.',\n",
-       "  'Ice cream sales soar as people seek relief from the summer heat.']},\n",
-       " 'fish and their habitats and behaviors': {'messages': ['Overfishing threatens many species of fish with extinction.',\n",
-       "  'Fish swim in schools to protect themselves from predators.',\n",
+       " 'threats to diverse fish species': {'messages': ['Fish swim in schools to protect themselves from predators.',\n",
        "  'Fish have a diverse range of habitats from deep oceans to shallow streams.',\n",
-       "  'Tropical fish add vibrant color and life to coral reefs.',\n",
+       "  'Overfishing threatens many species of fish with extinction.',\n",
        "  'Salmon migrate upstream during spawning season, a remarkable journey.']}}"
      ]
     },
diff --git a/examples/classifier_scores.py b/examples/classifier_scores.py
index 0d69e67..e11e74b 100644
--- a/examples/classifier_scores.py
+++ b/examples/classifier_scores.py
@@ -1,18 +1,20 @@
 from langchain_openai import ChatOpenAI
-from langchain_core.prompts import ChatPromptTemplate
-
-from wise_topic import llm_classifier, binary_prompt
+from wise_topic import llm_classifier_binary, llm_classifier_multiple
 
 llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
-question = "Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?"
-prompt_value = (
-    ChatPromptTemplate.from_template(binary_prompt)
-    .invoke({"question": question})
-    .to_messages()
-)
+question1 = "Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?"
+question2 = "Consider a very friendly pet with a waggy tail. You know it's a cat or a dog. Is it a cat?"
+for question in [question1, question2]:
+    out = llm_classifier_binary(llm, question)
+    print(question)
+    print(out)
 
-out = llm_classifier(llm, [prompt_value], ["0", "1"])
+question3 = "Consider a very friendly pet with a long tail. You know it's a cat, a dog, or a dragon. Which is it?"
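+# llm_classifier_multiple returns a dict mapping each answer option to a
+# normalized probability, read off the logprobs of the model's digit reply.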
+out = llm_classifier_multiple(llm, question3, ["cat", "dog", "dragon"])
+print(question3)
 print(out)
+
+print("done!")
diff --git a/examples/simple.py b/examples/topic_extraction.py
similarity index 100%
rename from examples/simple.py
rename to examples/topic_extraction.py
diff --git a/wise_topic/__init__.py b/wise_topic/__init__.py
index 05e7d2e..7a4ebb4 100644
--- a/wise_topic/__init__.py
+++ b/wise_topic/__init__.py
@@ -1,3 +1,5 @@
 from wise_topic.topic.greedy import greedy_topic_tree, tree_summary
-from wise_topic.classifier.classifier import llm_classifier
-from wise_topic.classifier.prompts import binary_prompt
+from wise_topic.classifier.classifier import (
+    llm_classifier_binary,
+    llm_classifier_multiple,
+)
diff --git a/wise_topic/classifier/classifier.py b/wise_topic/classifier/classifier.py
index 3d06a02..d157f9f 100644
--- a/wise_topic/classifier/classifier.py
+++ b/wise_topic/classifier/classifier.py
@@ -6,8 +6,55 @@
 from langchain_core.prompts import PromptTemplate
 from langchain_core.language_models import BaseChatModel
 
+from wise_topic.classifier.prompts import binary_prompt, multi_choice_prompt
 
-def llm_classifier(llm: BaseChatModel, messages, valid_options, top_logprobs=5, max_tokens=1):
+
+def llm_classifier_binary(llm: BaseChatModel, question: str):
+    prompt_value = (
+        PromptTemplate.from_template(binary_prompt)
+        .invoke({"question": question})
+        .to_messages()
+    )
+    out = llm_classifier(llm, [prompt_value], ["0", "1"])
+    return {False: out[0], True: out[1]}
+
+
+def llm_classifier_multiple(
+    llm: BaseChatModel,
+    question: str,
+    answer_options: List[str],
+    include_other: bool = False,
+):
+
+    assert (
+        len(answer_options) <= 9
+    ), "Only up to 9 answer options are supported at the moment"
+    categories = "\n".join([f"{i + 1}. {t}" for i, t in enumerate(answer_options)])
+    prompt_value = (
+        PromptTemplate.from_template(multi_choice_prompt(include_other))
+        .invoke({"question": question, "categories": categories})
+        .to_messages()
+    )
+
+    valid_outputs = [str(i) for i in range(len(answer_options) + 1)]
+    scores = llm_classifier(
+        llm,
+        [prompt_value],
+        valid_outputs if include_other else valid_outputs[1:],
+        top_logprobs=15,
+    )
+    if include_other:
+        used_options = ["Other"] + list(answer_options)
+    else:
+        used_options = list(answer_options)
+
+    out = {k: v for k, v in zip(used_options, scores)}
+    return out
+
+
+def llm_classifier(
+    llm: BaseChatModel, messages, valid_options, top_logprobs=5, max_tokens=1
+) -> np.ndarray:
     result = llm.generate(
         messages,
         logprobs=True,
@@ -22,12 +69,12 @@ def llm_classifier(llm: BaseChatModel, messages, valid_options, top_logprobs=5,
     return scores
 
 
-def logprobs_to_scores(logprobs, valid_options: List[str]):
+def logprobs_to_scores(logprobs, valid_options: List[str]) -> np.ndarray:
     scores = np.array(len(valid_options) * [float("-inf")])
     matches = False
     for i, c in enumerate(valid_options):
         for p in logprobs:
-            if isinstance(p, dict): # Langchain interface
+            if isinstance(p, dict):  # Langchain interface
                 token = p["token"]
                 logprob = p["logprob"]
             else:  # OpenAI interface
diff --git a/wise_topic/classifier/prompts.py b/wise_topic/classifier/prompts.py
index ab7ee88..65b70c6 100644
--- a/wise_topic/classifier/prompts.py
+++ b/wise_topic/classifier/prompts.py
@@ -5,11 +5,28 @@
 Take a deep breath and think carefully before you make your reply.
 """
 
-multi_choice_prompt = """
-{question}
-Choose the most suitable option from the following:
-{numbered_options}
-Return the number of the most suitable option. 
-Return just the one character digit, nothing else.
-Take a deep breath and think carefully before you make your reply.
-"""
+
+def multi_choice_prompt(include_other: bool):
+    out = (
+        """I am about to give you a numbered list of options.
+        Then I will pass to you a message (possibly, but not necessarily, a question),
+        after the word MESSAGE.
+        Return an integer that is the number of the option that best fits that message,
+        or if the message is a question, the number of the option that best answers the question.
+        """
+        + (
+            """
+            If no option fits the message, return 0.
+            """
+            if include_other
+            else ""
+        )
+        + """
+        Return only the number, without additional text.
+        {categories}
+        MESSAGE:
+        {question}
+        Take a deep breath and think carefully before you make your reply.
+        BEST MATCH OPTION NUMBER:"""
+    )
+    return out
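For quick reference, a minimal standalone usage sketch of the new prompt builder (not part of the patch above; the option list and sample question are made up for illustration):

from wise_topic.classifier.prompts import multi_choice_prompt

# Build the numbered category list the same way llm_classifier_multiple does,
# then render the template with include_other=True, so a message matching
# none of the options maps to the digit 0.
options = ["cat", "dog", "dragon"]
categories = "\n".join(f"{i + 1}. {t}" for i, t in enumerate(options))
template = multi_choice_prompt(include_other=True)
print(template.format(categories=categories, question="Which pet breathes fire?"))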