diff --git a/docs/docs/how_to/graph_constructing.ipynb b/docs/docs/how_to/graph_constructing.ipynb index 5ca45d736453f..d36ca54c8cf24 100644 --- a/docs/docs/how_to/graph_constructing.ipynb +++ b/docs/docs/how_to/graph_constructing.ipynb @@ -52,7 +52,7 @@ } ], "source": [ - "%pip install --upgrade --quiet langchain langchain-community langchain-openai langchain-experimental neo4j" + "%pip install --upgrade --quiet langchain langchain-neo4j langchain-openai langchain-experimental neo4j" ] }, { @@ -102,7 +102,7 @@ "source": [ "import os\n", "\n", - "from langchain_community.graphs import Neo4jGraph\n", + "from langchain_neo4j import Neo4jGraph\n", "\n", "os.environ[\"NEO4J_URI\"] = \"bolt://localhost:7687\"\n", "os.environ[\"NEO4J_USERNAME\"] = \"neo4j\"\n", diff --git a/docs/docs/how_to/graph_mapping.ipynb b/docs/docs/how_to/graph_mapping.ipynb index cd98ca00b67a3..146f479e27d32 100644 --- a/docs/docs/how_to/graph_mapping.ipynb +++ b/docs/docs/how_to/graph_mapping.ipynb @@ -33,7 +33,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install --upgrade --quiet langchain langchain-community langchain-openai neo4j" + "%pip install --upgrade --quiet langchain langchain-neo4j langchain-openai neo4j" ] }, { @@ -116,7 +116,7 @@ } ], "source": [ - "from langchain_community.graphs import Neo4jGraph\n", + "from langchain_neo4j import Neo4jGraph\n", "\n", "graph = Neo4jGraph()\n", "\n", @@ -364,11 +364,12 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_community.chains.graph_qa.cypher_utils import (\n", + "from langchain_neo4j.chains.graph_qa.cypher_utils import (\n", " CypherQueryCorrector,\n", " Schema,\n", ")\n", "\n", + "graph.refresh_schema()\n", "# Cypher validation tool for relationship directions\n", "corrector_schema = [\n", " Schema(el[\"start\"], el[\"type\"], el[\"end\"])\n", diff --git a/docs/docs/how_to/graph_prompting.ipynb b/docs/docs/how_to/graph_prompting.ipynb index 0b83559e7e195..db4922fb3a2da 100644 --- a/docs/docs/how_to/graph_prompting.ipynb +++ b/docs/docs/how_to/graph_prompting.ipynb @@ -36,7 +36,7 @@ } ], "source": [ - "%pip install --upgrade --quiet langchain langchain-community langchain-openai neo4j" + "%pip install --upgrade --quiet langchain langchain-neo4j langchain-openai neo4j" ] }, { @@ -113,7 +113,7 @@ } ], "source": [ - "from langchain_community.graphs import Neo4jGraph\n", + "from langchain_neo4j import Neo4jGraph\n", "\n", "graph = Neo4jGraph()\n", "\n", @@ -188,12 +188,16 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.chains import GraphCypherQAChain\n", + "from langchain_neo4j import GraphCypherQAChain\n", "from langchain_openai import ChatOpenAI\n", "\n", "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", "chain = GraphCypherQAChain.from_llm(\n", - " graph=graph, llm=llm, exclude_types=[\"Genre\"], verbose=True\n", + " graph=graph,\n", + " llm=llm,\n", + " exclude_types=[\"Genre\"],\n", + " verbose=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -356,8 +360,8 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_community.vectorstores import Neo4jVector\n", "from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n", + "from langchain_neo4j import Neo4jVector\n", "from langchain_openai import OpenAIEmbeddings\n", "\n", "example_selector = SemanticSimilarityExampleSelector.from_examples(\n", @@ -468,7 +472,11 @@ "source": [ "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", "chain = GraphCypherQAChain.from_llm(\n", - " graph=graph, llm=llm, cypher_prompt=prompt, verbose=True\n", + " graph=graph,\n", + " llm=llm,\n", + " cypher_prompt=prompt,\n", + " verbose=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, diff --git a/docs/docs/how_to/graph_semantic.ipynb b/docs/docs/how_to/graph_semantic.ipynb index 94578939d7b48..4cd717d861297 100644 --- a/docs/docs/how_to/graph_semantic.ipynb +++ b/docs/docs/how_to/graph_semantic.ipynb @@ -44,7 +44,7 @@ } ], "source": [ - "%pip install --upgrade --quiet langchain langchain-community langchain-openai neo4j" + "%pip install --upgrade --quiet langchain langchain-neo4j langchain-openai neo4j" ] }, { @@ -127,7 +127,7 @@ } ], "source": [ - "from langchain_community.graphs import Neo4jGraph\n", + "from langchain_neo4j import Neo4jGraph\n", "\n", "graph = Neo4jGraph()\n", "\n", @@ -242,8 +242,8 @@ "\n", "\n", "class InformationTool(BaseTool):\n", - " name = \"Information\"\n", - " description = (\n", + " name: str = \"Information\"\n", + " description: str = (\n", " \"useful for when you need to answer questions about various actors or movies\"\n", " )\n", " args_schema: Type[BaseModel] = InformationInput\n", diff --git a/docs/docs/integrations/chat/writer.ipynb b/docs/docs/integrations/chat/writer.ipynb index a76752ef2f64c..d7a47b9c4767a 100644 --- a/docs/docs/integrations/chat/writer.ipynb +++ b/docs/docs/integrations/chat/writer.ipynb @@ -17,7 +17,7 @@ "source": [ "# ChatWriter\n", "\n", - "This notebook provides a quick overview for getting started with Writer [chat models](/docs/concepts/chat_models).\n", + "This notebook provides a quick overview for getting started with Writer [chat models](/docs/concepts/#chat-models).\n", "\n", "Writer has several chat models. You can find information about their latest models and their costs, context windows, and supported input types in the [Writer docs](https://dev.writer.com/home/models).\n", "\n", @@ -25,21 +25,20 @@ ] }, { - "cell_type": "markdown", - "id": "e49f1e0d", "metadata": {}, + "cell_type": "markdown", "source": [ "## Overview\n", "\n", "### Integration details\n", - "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/openai) | Package downloads | Package latest |\n", - "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", - "| ChatWriter | langchain-community | ❌ | ❌ | ❌ | ❌ | ❌ |\n", + "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: |:----------:| :---: | :---: |\n", + "| ChatWriter | langchain-community | ❌ | ❌ | ❌ | ❌ | ❌ |\n", "\n", "### Model features\n", - "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | Image input | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", - "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", - "| ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | \n", + "| [Tool calling](/docs/how_to/tool_calling) | Structured output | JSON mode | Image input | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | Logprobs |\n", + "| :---: |:-----------------:| :---: | :---: | :---: | :---: | :---: | :---: |:--------------------------------:|:--------:|\n", + "| ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ |\n", "\n", "## Setup\n", "\n", @@ -48,15 +47,16 @@ "### Credentials\n", "\n", "Head to [Writer AI Studio](https://app.writer.com/aistudio/signup?utm_campaign=devrel) to sign up to OpenAI and generate an API key. Once you've done this set the WRITER_API_KEY environment variable:" - ] + ], + "id": "617a6e98205ab7c8" }, { "cell_type": "code", "id": "e817fe2e-4f1d-4533-b19e-2400b1cf6ce8", "metadata": { "ExecuteTime": { - "end_time": "2024-10-24T13:51:54.323678Z", - "start_time": "2024-10-24T13:51:42.127404Z" + "end_time": "2024-11-14T09:46:26.800627Z", + "start_time": "2024-11-14T09:27:59.652281Z" } }, "source": [ @@ -64,7 +64,7 @@ "import os\n", "\n", "if not os.environ.get(\"WRITER_API_KEY\"):\n", - " os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key: \")" + " os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key:\")" ], "outputs": [], "execution_count": 1 @@ -84,23 +84,24 @@ "id": "2113471c-75d7-45df-b784-d78da4ef7aba", "metadata": { "ExecuteTime": { - "end_time": "2024-10-24T13:52:49.262240Z", - "start_time": "2024-10-24T13:52:47.564879Z" + "end_time": "2024-11-14T09:46:32.415354Z", + "start_time": "2024-11-14T09:46:26.826112Z" } }, - "source": [ - "%pip install -qU langchain-community writer-sdk" - ], + "source": "%pip install -qU langchain-community writer-sdk", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m24.2\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m24.3.1\u001B[0m\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], - "execution_count": 4 + "execution_count": 2 }, { "cell_type": "markdown", @@ -118,8 +119,8 @@ "metadata": { "tags": [], "ExecuteTime": { - "end_time": "2024-10-24T13:52:38.822950Z", - "start_time": "2024-10-24T13:52:38.674441Z" + "end_time": "2024-11-14T09:46:33.504711Z", + "start_time": "2024-11-14T09:46:32.574505Z" } }, "source": [ @@ -129,24 +130,10 @@ " model=\"palmyra-x-004\",\n", " temperature=0.7,\n", " max_tokens=1000,\n", - " # api_key=\"...\", # if you prefer to pass api key in directly instaed of using env vars\n", - " # base_url=\"...\",\n", " # other params...\n", ")" ], - "outputs": [ - { - "ename": "ImportError", - "evalue": "cannot import name 'ChatWriter' from 'langchain_community.chat_models' (/home/yanomaly/PycharmProjects/whitesnake/writer/langсhain/libs/community/langchain_community/chat_models/__init__.py)", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mImportError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[0;32mIn[3], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mlangchain_community\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mchat_models\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ChatWriter\n\u001B[1;32m 3\u001B[0m llm \u001B[38;5;241m=\u001B[39m ChatWriter(\n\u001B[1;32m 4\u001B[0m model\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpalmyra-x-004\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[1;32m 5\u001B[0m temperature\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0.7\u001B[39m,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 9\u001B[0m \u001B[38;5;66;03m# other params...\u001B[39;00m\n\u001B[1;32m 10\u001B[0m )\n", - "\u001B[0;31mImportError\u001B[0m: cannot import name 'ChatWriter' from 'langchain_community.chat_models' (/home/yanomaly/PycharmProjects/whitesnake/writer/langсhain/libs/community/langchain_community/chat_models/__init__.py)" - ] - } - ], + "outputs": [], "execution_count": 3 }, { @@ -159,12 +146,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "ce16ad78-8e6f-48cd-954e-98be75eb5836", "metadata": { - "tags": [] + "tags": [], + "ExecuteTime": { + "end_time": "2024-11-14T09:46:38.856174Z", + "start_time": "2024-11-14T09:46:33.520062Z" + } }, - "outputs": [], "source": [ "messages = [\n", " (\n", @@ -173,19 +162,127 @@ " ),\n", " (\"human\", \"Write a poem about Python.\"),\n", "]\n", - "ai_msg = llm.invoke(messages)\n", - "ai_msg" - ] + "ai_msg = llm.invoke(messages)" + ], + "outputs": [], + "execution_count": 4 }, { "cell_type": "code", - "execution_count": null, "id": "2cd224b8-4499-41fb-a604-d53a7ff17b2e", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:38.866651Z", + "start_time": "2024-11-14T09:46:38.863817Z" + } + }, + "source": [ + "print(ai_msg.content)" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In realms of code, where logic weaves and flows,\n", + "A language rises, Python by its name,\n", + "With syntax clear, where elegance it shows,\n", + "A serpent, wise, that time and space can tame.\n", + "\n", + "Born from the mind of Guido, pure and bright,\n", + "Its beauty lies in simplicity and grace,\n", + "A tool of power, yet gentle in its might,\n", + "In every programmer's heart, a cherished place.\n", + "\n", + "It dances through the data, vast and deep,\n", + "With libraries that span the digital realm,\n", + "From machine learning's secrets to keep,\n", + "To web development, it wields the helm.\n", + "\n", + "In the hands of the novice and the sage,\n", + "Python spins the threads of digital dreams,\n", + "A language that can turn the age,\n", + "With a gentle learning curve, its appeal gleams.\n", + "\n", + "It's more than code, a community it builds,\n", + "Where knowledge freely flows, and all are heard,\n", + "In Python's world, the future unfolds,\n", + "A language of the people, for the world.\n", + "\n", + "So here's to Python, in its gentle might,\n", + "A master of the modern coding art,\n", + "May it continue to light our path each night,\n", + "In the vast, evolving world of code, its heart.\n" + ] + } + ], + "execution_count": 5 + }, + { "metadata": {}, + "cell_type": "markdown", + "source": "## Streaming", + "id": "35b3a5b3dabef65" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:38.914883Z", + "start_time": "2024-11-14T09:46:38.912564Z" + } + }, + "cell_type": "code", + "source": "ai_stream = llm.stream(messages)", + "id": "2725770182bf96dc", "outputs": [], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:43.226449Z", + "start_time": "2024-11-14T09:46:38.955512Z" + } + }, + "cell_type": "code", "source": [ - "print(ai_msg.content)" - ] + "for chunk in ai_stream:\n", + " print(chunk.content, end=\"\")" + ], + "id": "a48410d9488162e3", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In realms of code where logic weaves,\n", + "A language rises, Python, it breezes,\n", + "With syntax clear and simple to read,\n", + "Through its elegance, our spirits are fed.\n", + "\n", + "Like rivers flowing, smooth and serene,\n", + "Its structure harmonious, a coder's dream,\n", + "Indentations guide the flow of control,\n", + "In Python's world, confusion takes no toll.\n", + "\n", + "A vast library, a treasure trove so bright,\n", + "For web and data, it offers its might,\n", + "With modules and packages, a rich array,\n", + "Python empowers us to code in play.\n", + "\n", + "From AI to scripts, in flexibility it thrives,\n", + "A language of the future, as many now derive,\n", + "Its community, a beacon of support and cheer,\n", + "With Python, the possibilities are vast, far and near.\n", + "\n", + "So here's to Python, in its gentle grace,\n", + "A tool that enhances, a language that embraces,\n", + "The art of coding, with a fluent, flowing pen,\n", + "In the Python world, we code, and we begin." + ] + } + ], + "execution_count": 7 }, { "cell_type": "markdown", @@ -199,12 +296,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "fbb043e6", "metadata": { - "tags": [] + "tags": [], + "ExecuteTime": { + "end_time": "2024-11-14T09:46:50.721645Z", + "start_time": "2024-11-14T09:46:43.234590Z" + } }, - "outputs": [], "source": [ "from langchain_core.prompts import ChatPromptTemplate\n", "\n", @@ -225,7 +324,20 @@ " \"input\": \"Write a poem about Java.\",\n", " }\n", ")" - ] + ], + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessageChunk(content='In the realm of code, where logic weaves and flows, \\nA language rises, like a phoenix from the code\\'s throes. \\nJava, the name, a cup of coffee\\'s steam, \\nBrewed in the minds, where digital dreams gleam.\\n\\nWith syntax clear, like morning\\'s misty hue, \\nIn classes and objects, it spins a tale so true. \\nA platform agnostic, with a byte to spare, \\nAcross the devices, it journeys everywhere.\\n\\nInheritance and polymorphism, its power\\'s core, \\nLike ancient runes, in every line they bore. \\nEncapsulation, a shield, with data it does hide, \\nIn the vast jungle of code, it stands as a guide.\\n\\nFrom applets small, to vast, server-side apps, \\nIts threads run swift, through the computing traps. \\nA language of the people, by the people, for the people’s use, \\nBuilt on the principle, \"write once, run anywhere, with no excuse.\"\\n\\nIn the heart of Android, it beats, a steady drum, \\nCrafting experiences, in every smartphone\\'s hum. \\nIn the cloud, in the enterprise, its presence is vast, \\nA cornerstone of computing, built to last.\\n\\nOh Java, thy elegance, thy robust design, \\nA language that stands, in any computing line. \\nWith every update, with every new release, \\nThy community grows, with a vibrant, diverse peace.\\n\\nSo here\\'s to Java, the versatile, the grand, \\nA language that shapes the digital land. \\nMay it continue to evolve, to grow, to inspire, \\nIn the endless quest of turning thoughts into digital fire.', additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 345, 'prompt_tokens': 33, 'total_tokens': 378, 'completion_tokens_details': None, 'prompt_token_details': None}, 'model_name': 'palmyra-x-004', 'system_fingerprint': 'v1', 'finish_reason': 'stop'}, id='run-a5b4be59-0eb0-41bd-80f7-72477861b0bd-0')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 8 }, { "cell_type": "markdown", @@ -251,10 +363,13 @@ }, { "cell_type": "code", - "execution_count": 6, "id": "b7ea7690-ec7a-4337-b392-e87d1f39a6ec", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:50.891937Z", + "start_time": "2024-11-14T09:46:50.733463Z" + } + }, "source": [ "from pydantic import BaseModel, Field\n", "\n", @@ -266,20 +381,26 @@ "\n", "\n", "llm_with_tools = llm.bind_tools([GetWeather])" - ] + ], + "outputs": [], + "execution_count": 9 }, { "cell_type": "code", - "execution_count": null, "id": "1d1ab955-6a68-42f8-bb5d-86eb1111478a", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:51.725422Z", + "start_time": "2024-11-14T09:46:50.904699Z" + } + }, "source": [ "ai_msg = llm_with_tools.invoke(\n", " \"what is the weather like in New York City\",\n", - ")\n", - "ai_msg" - ] + ")" + ], + "outputs": [], + "execution_count": 10 }, { "cell_type": "markdown", @@ -292,13 +413,30 @@ }, { "cell_type": "code", - "execution_count": null, "id": "166cb7ce-831d-4a7c-9721-abc107f11084", - "metadata": {}, - "outputs": [], - "source": [ - "ai_msg.tool_calls" - ] + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:51.744202Z", + "start_time": "2024-11-14T09:46:51.738431Z" + } + }, + "source": "print(ai_msg.tool_calls)", + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'GetWeather',\n", + " 'args': {'location': 'New York City, NY'},\n", + " 'id': 'chatcmpl-tool-fe70912c800d40fc8700d604d4823001',\n", + " 'type': 'tool_call'}]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 11 }, { "cell_type": "markdown", diff --git a/docs/docs/integrations/document_loaders/parsers/azure_openai_whisper_parser.ipynb b/docs/docs/integrations/document_loaders/parsers/azure_openai_whisper_parser.ipynb index b3dadb1f0ad81..6b8894491f30f 100644 --- a/docs/docs/integrations/document_loaders/parsers/azure_openai_whisper_parser.ipynb +++ b/docs/docs/integrations/document_loaders/parsers/azure_openai_whisper_parser.ipynb @@ -115,7 +115,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The `AzureOpenAIWhisperParser` can also be used in conjuction with audio loaders, like the `YoutubeAudioLoader` with a `GenericLoader`." + "The `AzureOpenAIWhisperParser` can also be used in conjunction with audio loaders, like the `YoutubeAudioLoader` with a `GenericLoader`." ] }, { diff --git a/docs/docs/integrations/graphs/apache_age.ipynb b/docs/docs/integrations/graphs/apache_age.ipynb index b3c39e974ab6a..588567c018cdc 100644 --- a/docs/docs/integrations/graphs/apache_age.ipynb +++ b/docs/docs/integrations/graphs/apache_age.ipynb @@ -45,8 +45,8 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.chains import GraphCypherQAChain\n", "from langchain_community.graphs.age_graph import AGEGraph\n", + "from langchain_neo4j import GraphCypherQAChain\n", "from langchain_openai import ChatOpenAI" ] }, @@ -169,7 +169,7 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True\n", + " ChatOpenAI(temperature=0), graph=graph, verbose=True, allow_dangerous_requests=True\n", ")" ] }, @@ -236,7 +236,11 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, top_k=2\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " top_k=2,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -295,7 +299,11 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, return_intermediate_steps=True\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " return_intermediate_steps=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -348,7 +356,11 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, return_direct=True\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " return_direct=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -435,6 +447,7 @@ " graph=graph,\n", " verbose=True,\n", " cypher_prompt=CYPHER_GENERATION_PROMPT,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -503,6 +516,7 @@ " cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n", " qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-16k\"),\n", " verbose=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -574,6 +588,7 @@ " qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-16k\"),\n", " verbose=True,\n", " exclude_types=[\"Movie\"],\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -622,6 +637,7 @@ " graph=graph,\n", " verbose=True,\n", " validate_cypher=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, diff --git a/docs/docs/integrations/graphs/diffbot.ipynb b/docs/docs/integrations/graphs/diffbot.ipynb index 06a8ed4e21ba2..9c4a5e866f497 100644 --- a/docs/docs/integrations/graphs/diffbot.ipynb +++ b/docs/docs/integrations/graphs/diffbot.ipynb @@ -45,7 +45,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install --upgrade --quiet langchain langchain-experimental langchain-openai neo4j wikipedia" + "%pip install --upgrade --quiet langchain langchain-experimental langchain-openai langchain-neo4j neo4j wikipedia" ] }, { @@ -124,7 +124,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_community.graphs import Neo4jGraph\n", + "from langchain_neo4j import Neo4jGraph\n", "\n", "url = \"bolt://localhost:7687\"\n", "username = \"neo4j\"\n", @@ -186,7 +186,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.chains import GraphCypherQAChain\n", + "from langchain_neo4j import GraphCypherQAChain\n", "from langchain_openai import ChatOpenAI\n", "\n", "chain = GraphCypherQAChain.from_llm(\n", @@ -194,6 +194,7 @@ " qa_llm=ChatOpenAI(temperature=0, model_name=\"gpt-3.5-turbo\"),\n", " graph=graph,\n", " verbose=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, diff --git a/docs/docs/integrations/graphs/memgraph.ipynb b/docs/docs/integrations/graphs/memgraph.ipynb index 85eb7497dbf74..4dc8d33be4b86 100644 --- a/docs/docs/integrations/graphs/memgraph.ipynb +++ b/docs/docs/integrations/graphs/memgraph.ipynb @@ -53,7 +53,7 @@ "metadata": {}, "outputs": [], "source": [ - "pip install langchain langchain-openai neo4j gqlalchemy --user" + "pip install langchain langchain-neo4j langchain-openai neo4j gqlalchemy --user" ] }, { @@ -74,9 +74,9 @@ "import os\n", "\n", "from gqlalchemy import Memgraph\n", - "from langchain.chains import GraphCypherQAChain\n", "from langchain_community.graphs import MemgraphGraph\n", "from langchain_core.prompts import PromptTemplate\n", + "from langchain_neo4j import GraphCypherQAChain\n", "from langchain_openai import ChatOpenAI" ] }, @@ -259,7 +259,11 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, model_name=\"gpt-3.5-turbo\"\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " model_name=\"gpt-3.5-turbo\",\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -363,7 +367,11 @@ "source": [ "# Return the result of querying the graph directly\n", "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, return_direct=True\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " return_direct=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -412,7 +420,11 @@ "source": [ "# Return all the intermediate steps of query execution\n", "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, return_intermediate_steps=True\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " return_intermediate_steps=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -465,7 +477,11 @@ "source": [ "# Limit the maximum number of results returned by query\n", "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, top_k=2\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " top_k=2,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -530,7 +546,11 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, model_name=\"gpt-3.5-turbo\"\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " model_name=\"gpt-3.5-turbo\",\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -628,6 +648,7 @@ " graph=graph,\n", " verbose=True,\n", " model_name=\"gpt-3.5-turbo\",\n", + " allow_dangerous_requests=True,\n", ")" ] }, diff --git a/docs/docs/integrations/graphs/neo4j_cypher.ipynb b/docs/docs/integrations/graphs/neo4j_cypher.ipynb index 3348f0b8d26ba..8dd9824d67146 100644 --- a/docs/docs/integrations/graphs/neo4j_cypher.ipynb +++ b/docs/docs/integrations/graphs/neo4j_cypher.ipynb @@ -46,8 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.chains import GraphCypherQAChain\n", - "from langchain_community.graphs import Neo4jGraph\n", + "from langchain_neo4j import GraphCypherQAChain, Neo4jGraph\n", "from langchain_openai import ChatOpenAI" ] }, @@ -61,6 +60,28 @@ "graph = Neo4jGraph(url=\"bolt://localhost:7687\", username=\"neo4j\", password=\"password\")" ] }, + { + "cell_type": "markdown", + "id": "8c663e91", + "metadata": {}, + "source": [ + "We default to OpenAI models in this guide." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "51c88001", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if \"OPENAI_API_KEY\" not in os.environ:\n", + " os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" + ] + }, { "cell_type": "markdown", "id": "995ea9b9", @@ -203,7 +224,7 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True\n", + " ChatOpenAI(temperature=0), graph=graph, verbose=True, allow_dangerous_requests=True\n", ")" ] }, @@ -264,7 +285,11 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, top_k=2\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " top_k=2,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -324,7 +349,11 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, return_intermediate_steps=True\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " return_intermediate_steps=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -377,7 +406,11 @@ "outputs": [], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0), graph=graph, verbose=True, return_direct=True\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " verbose=True,\n", + " return_direct=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -465,6 +498,7 @@ " graph=graph,\n", " verbose=True,\n", " cypher_prompt=CYPHER_GENERATION_PROMPT,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -527,6 +561,7 @@ " cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n", " qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-16k\"),\n", " verbose=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -592,6 +627,7 @@ " qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-16k\"),\n", " verbose=True,\n", " exclude_types=[\"Movie\"],\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -640,6 +676,7 @@ " graph=graph,\n", " verbose=True,\n", " validate_cypher=True,\n", + " allow_dangerous_requests=True,\n", ")" ] }, @@ -734,6 +771,7 @@ " graph=graph,\n", " verbose=True,\n", " use_function_response=True,\n", + " allow_dangerous_requests=True,\n", ")\n", "chain.invoke({\"query\": \"Who played in Top Gun?\"})" ] @@ -790,6 +828,7 @@ " verbose=True,\n", " use_function_response=True,\n", " function_response_system=\"Respond as a pirate!\",\n", + " allow_dangerous_requests=True,\n", ")\n", "chain.invoke({\"query\": \"Who played in Top Gun?\"})" ] diff --git a/docs/docs/integrations/llms/writer.ipynb b/docs/docs/integrations/llms/writer.ipynb index 7488eff3efe16..bc17ba76582dd 100644 --- a/docs/docs/integrations/llms/writer.ipynb +++ b/docs/docs/integrations/llms/writer.ipynb @@ -4,120 +4,161 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Writer\n", + "# Writer LLM\n", "\n", "[Writer](https://writer.com/) is a platform to generate different language content.\n", "\n", "This example goes over how to use LangChain to interact with `Writer` [models](https://dev.writer.com/docs/models).\n", "\n", - "You have to get the WRITER_API_KEY [here](https://dev.writer.com/docs)." + "## Setup\n", + "\n", + "To access Writer models you'll need to create a Writer account, get an API key, and install the `writer-sdk` and `langchain-community` packages.\n", + "\n", + "### Credentials\n", + "\n", + "Head to [Writer AI Studio](https://app.writer.com/aistudio/signup?utm_campaign=devrel) to sign up to OpenAI and generate an API key. Once you've done this set the WRITER_API_KEY environment variable:" ] }, { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T11:10:46.824961Z", + "start_time": "2024-11-14T11:10:44.864137Z" + } + }, "cell_type": "code", - "execution_count": 4, + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if not os.environ.get(\"WRITER_API_KEY\"):\n", + " os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key:\")" + ], + "outputs": [], + "execution_count": 1 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Installation\n", + "\n", + "The LangChain Writer integration lives in the `langchain-community` package:" + ] + }, + { "metadata": { - "tags": [] + "ExecuteTime": { + "end_time": "2024-11-14T11:10:48.297429Z", + "start_time": "2024-11-14T11:10:46.843983Z" + } }, + "cell_type": "code", + "source": "%pip install -qU langchain-community writer-sdk", "outputs": [ { - "name": "stdin", + "name": "stdout", "output_type": "stream", "text": [ - " ········\n" + "\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m24.2\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m24.3.1\u001B[0m\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n", + "Note: you may need to restart the kernel to use updated packages.\n" ] } ], - "source": [ - "from getpass import getpass\n", - "\n", - "WRITER_API_KEY = getpass()" - ] + "execution_count": 2 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Now we can initialize our model object to interact with writer LLMs" }, { - "cell_type": "code", - "execution_count": 5, "metadata": { - "tags": [] + "ExecuteTime": { + "end_time": "2024-11-14T11:10:49.818902Z", + "start_time": "2024-11-14T11:10:48.580516Z" + } }, - "outputs": [], + "cell_type": "code", "source": [ - "import os\n", + "from langchain_community.llms import Writer as WriterLLM\n", "\n", - "os.environ[\"WRITER_API_KEY\"] = WRITER_API_KEY" - ] + "llm = WriterLLM(\n", + " temperature=0.7,\n", + " max_tokens=1000,\n", + " # other params...\n", + ")" + ], + "outputs": [], + "execution_count": 3 }, { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.chains import LLMChain\n", - "from langchain_community.llms import Writer\n", - "from langchain_core.prompts import PromptTemplate" - ] + "metadata": {}, + "cell_type": "markdown", + "source": "## Invocation" }, { - "cell_type": "code", - "execution_count": 7, "metadata": { - "tags": [] + "jupyter": { + "is_executing": true + }, + "ExecuteTime": { + "start_time": "2024-11-14T11:10:49.832822Z" + } }, + "cell_type": "code", + "source": "response_text = llm.invoke(input=\"Write a poem\")", "outputs": [], - "source": [ - "template = \"\"\"Question: {question}\n", - "\n", - "Answer: Let's think step by step.\"\"\"\n", - "\n", - "prompt = PromptTemplate.from_template(template)" - ] + "execution_count": null }, { + "metadata": {}, "cell_type": "code", - "execution_count": 14, - "metadata": { - "tags": [] - }, + "source": "print(response_text)", "outputs": [], - "source": [ - "# If you get an error, probably, you need to set up the \"base_url\" parameter that can be taken from the error log.\n", - "\n", - "llm = Writer()" - ] + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Streaming" }, { + "metadata": {}, "cell_type": "code", - "execution_count": 15, - "metadata": { - "tags": [] - }, + "source": "stream_response = llm.stream(input=\"Tell me a fairytale\")", "outputs": [], - "source": [ - "llm_chain = LLMChain(prompt=prompt, llm=llm)" - ] + "execution_count": null }, { + "metadata": {}, "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, + "source": [ + "for chunk in stream_response:\n", + " print(chunk, end=\"\")" + ], "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", "source": [ - "question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n", + "## Async\n", "\n", - "llm_chain.run(question)" + "Writer support asynchronous calls via **ainvoke()** and **astream()** methods" ] }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], - "source": [] + "cell_type": "markdown", + "source": [ + "## API reference\n", + "\n", + "For detailed documentation of all Writer features, head to our [API reference](https://dev.writer.com/api-guides/api-reference/completion-api/text-generation#text-generation)." + ] } ], "metadata": { diff --git a/docs/docs/integrations/memory/neo4j_chat_message_history.ipynb b/docs/docs/integrations/memory/neo4j_chat_message_history.ipynb index d9dc80020c8ae..20cfd9ce68c10 100644 --- a/docs/docs/integrations/memory/neo4j_chat_message_history.ipynb +++ b/docs/docs/integrations/memory/neo4j_chat_message_history.ipynb @@ -19,7 +19,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_community.chat_message_histories import Neo4jChatMessageHistory\n", + "from langchain_neo4j import Neo4jChatMessageHistory\n", "\n", "history = Neo4jChatMessageHistory(\n", " url=\"bolt://localhost:7687\",\n", diff --git a/docs/docs/integrations/providers/neo4j.mdx b/docs/docs/integrations/providers/neo4j.mdx index 929b622d612ee..2b8d8f683cc9f 100644 --- a/docs/docs/integrations/providers/neo4j.mdx +++ b/docs/docs/integrations/providers/neo4j.mdx @@ -11,7 +11,7 @@ ## Installation and Setup -- Install the Python SDK with `pip install neo4j` +- Install the Python SDK with `pip install neo4j langchain-neo4j` ## VectorStore @@ -20,7 +20,7 @@ The Neo4j vector index is used as a vectorstore, whether for semantic search or example selection. ```python -from langchain_community.vectorstores import Neo4jVector +from langchain_neo4j import Neo4jVector ``` See a [usage example](/docs/integrations/vectorstores/neo4jvector) @@ -31,8 +31,7 @@ There exists a wrapper around Neo4j graph database that allows you to generate C and use them to retrieve relevant information from the database. ```python -from langchain_community.graphs import Neo4jGraph -from langchain.chains import GraphCypherQAChain +from langchain_neo4j import GraphCypherQAChain, Neo4jGraph ``` See a [usage example](/docs/integrations/graphs/neo4j_cypher) @@ -45,7 +44,7 @@ By coupling Diffbot's NLP API with Neo4j, a graph database, you can create power These graph structures are fully queryable and can be integrated into various applications. ```python -from langchain_community.graphs import Neo4jGraph +from langchain_neo4j import Neo4jGraph from langchain_experimental.graph_transformers.diffbot import DiffbotGraphTransformer ``` @@ -56,5 +55,5 @@ See a [usage example](/docs/integrations/graphs/diffbot) See a [usage example](/docs/integrations/memory/neo4j_chat_message_history). ```python -from langchain.memory import Neo4jChatMessageHistory +from langchain_neo4j import Neo4jChatMessageHistory ``` diff --git a/docs/docs/integrations/retrievers/self_query/neo4j_self_query.ipynb b/docs/docs/integrations/retrievers/self_query/neo4j_self_query.ipynb index 6bc4f718dfc84..0d2f14521277b 100644 --- a/docs/docs/integrations/retrievers/self_query/neo4j_self_query.ipynb +++ b/docs/docs/integrations/retrievers/self_query/neo4j_self_query.ipynb @@ -99,8 +99,8 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_community.vectorstores import Neo4jVector\n", "from langchain_core.documents import Document\n", + "from langchain_neo4j import Neo4jVector\n", "from langchain_openai import OpenAIEmbeddings\n", "\n", "embeddings = OpenAIEmbeddings()" diff --git a/docs/docs/integrations/vectorstores/neo4jvector.ipynb b/docs/docs/integrations/vectorstores/neo4jvector.ipynb index cc489d5c48352..87f521c2fad4c 100644 --- a/docs/docs/integrations/vectorstores/neo4jvector.ipynb +++ b/docs/docs/integrations/vectorstores/neo4jvector.ipynb @@ -34,7 +34,7 @@ "source": [ "# Pip install necessary package\n", "%pip install --upgrade --quiet neo4j\n", - "%pip install --upgrade --quiet langchain-openai langchain-community\n", + "%pip install --upgrade --quiet langchain-openai langchain-neo4j\n", "%pip install --upgrade --quiet tiktoken" ] }, @@ -75,8 +75,8 @@ "outputs": [], "source": [ "from langchain_community.document_loaders import TextLoader\n", - "from langchain_community.vectorstores import Neo4jVector\n", "from langchain_core.documents import Document\n", + "from langchain_neo4j import Neo4jVector\n", "from langchain_openai import OpenAIEmbeddings\n", "from langchain_text_splitters import CharacterTextSplitter" ] diff --git a/docs/docs/tutorials/graph.ipynb b/docs/docs/tutorials/graph.ipynb index d189d1af21fbe..4130bae5a84f3 100644 --- a/docs/docs/tutorials/graph.ipynb +++ b/docs/docs/tutorials/graph.ipynb @@ -45,7 +45,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install --upgrade --quiet langchain langchain-community langchain-openai neo4j" + "%pip install --upgrade --quiet langchain langchain-neo4j langchain-openai neo4j" ] }, { @@ -123,7 +123,7 @@ } ], "source": [ - "from langchain_community.graphs import Neo4jGraph\n", + "from langchain_neo4j import Neo4jGraph\n", "\n", "graph = Neo4jGraph()\n", "\n", @@ -233,11 +233,13 @@ } ], "source": [ - "from langchain.chains import GraphCypherQAChain\n", + "from langchain_neo4j import GraphCypherQAChain\n", "from langchain_openai import ChatOpenAI\n", "\n", "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", - "chain = GraphCypherQAChain.from_llm(graph=graph, llm=llm, verbose=True)\n", + "chain = GraphCypherQAChain.from_llm(\n", + " graph=graph, llm=llm, verbose=True, allow_dangerous_requests=True\n", + ")\n", "response = chain.invoke({\"query\": \"What was the cast of the Casino?\"})\n", "response" ] @@ -286,7 +288,11 @@ ], "source": [ "chain = GraphCypherQAChain.from_llm(\n", - " graph=graph, llm=llm, verbose=True, validate_cypher=True\n", + " graph=graph,\n", + " llm=llm,\n", + " verbose=True,\n", + " validate_cypher=True,\n", + " allow_dangerous_requests=True,\n", ")\n", "response = chain.invoke({\"query\": \"What was the cast of the Casino?\"})\n", "response" diff --git a/libs/community/langchain_community/chains/graph_qa/cypher.py b/libs/community/langchain_community/chains/graph_qa/cypher.py index 91a5ba606621b..760ce66731206 100644 --- a/libs/community/langchain_community/chains/graph_qa/cypher.py +++ b/libs/community/langchain_community/chains/graph_qa/cypher.py @@ -7,6 +7,7 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain +from langchain_core._api.deprecation import deprecated from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import ( @@ -44,6 +45,11 @@ """ +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.chains.graph_qa.cypher.extract_cypher", +) def extract_cypher(text: str) -> str: """Extract Cypher code from a text. @@ -62,6 +68,11 @@ def extract_cypher(text: str) -> str: return matches[0] if matches else text +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.chains.graph_qa.cypher.construct_schema", +) def construct_schema( structured_schema: Dict[str, Any], include_types: List[str], @@ -124,6 +135,11 @@ def filter_func(x: str) -> bool: ) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.chains.graph_qa.cypher.get_function_response", +) def get_function_response( question: str, context: List[Dict[str, Any]] ) -> List[BaseMessage]: @@ -149,6 +165,11 @@ def get_function_response( return messages +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.GraphCypherQAChain", +) class GraphCypherQAChain(Chain): """Chain for question-answering against a graph by generating Cypher statements. diff --git a/libs/community/langchain_community/chains/graph_qa/cypher_utils.py b/libs/community/langchain_community/chains/graph_qa/cypher_utils.py index c123cac9b52f3..4d8c7c45572fb 100644 --- a/libs/community/langchain_community/chains/graph_qa/cypher_utils.py +++ b/libs/community/langchain_community/chains/graph_qa/cypher_utils.py @@ -2,9 +2,16 @@ from collections import namedtuple from typing import Any, Dict, List, Optional, Tuple +from langchain_core._api.deprecation import deprecated + Schema = namedtuple("Schema", ["left_node", "relation", "right_node"]) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.chains.graph_qa.cypher_utils.CypherQueryCorrector", +) class CypherQueryCorrector: """ Used to correct relationship direction in generated Cypher statements. diff --git a/libs/community/langchain_community/chat_message_histories/neo4j.py b/libs/community/langchain_community/chat_message_histories/neo4j.py index aeca69cdad8c0..5a054c706de25 100644 --- a/libs/community/langchain_community/chat_message_histories/neo4j.py +++ b/libs/community/langchain_community/chat_message_histories/neo4j.py @@ -1,5 +1,6 @@ from typing import List, Optional, Union +from langchain_core._api.deprecation import deprecated from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, messages_from_dict from langchain_core.utils import get_from_dict_or_env @@ -7,6 +8,11 @@ from langchain_community.graphs import Neo4jGraph +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.Neo4jChatMessageHistory", +) class Neo4jChatMessageHistory(BaseChatMessageHistory): """Chat message history stored in a Neo4j database.""" diff --git a/libs/community/langchain_community/chat_models/google_palm.py b/libs/community/langchain_community/chat_models/google_palm.py index 77038256c7d82..e9bd5928a040d 100644 --- a/libs/community/langchain_community/chat_models/google_palm.py +++ b/libs/community/langchain_community/chat_models/google_palm.py @@ -219,7 +219,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel): To use you must have the google.generativeai Python package installed and either: - 1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or + 1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or 2. Pass your API key using the google_api_key kwarg to the ChatGoogle constructor. diff --git a/libs/community/langchain_community/chat_models/writer.py b/libs/community/langchain_community/chat_models/writer.py index 945b9d8b0b6d2..4101b6e23eb35 100644 --- a/libs/community/langchain_community/chat_models/writer.py +++ b/libs/community/langchain_community/chat_models/writer.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import logging from typing import ( Any, @@ -11,7 +12,6 @@ Iterator, List, Literal, - Mapping, Optional, Sequence, Tuple, @@ -26,8 +26,6 @@ from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, - agenerate_from_stream, - generate_from_stream, ) from langchain_core.messages import ( AIMessage, @@ -40,99 +38,49 @@ ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable +from langchain_core.utils import get_from_dict_or_env from langchain_core.utils.function_calling import convert_to_openai_tool -from pydantic import BaseModel, ConfigDict, Field, SecretStr +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator logger = logging.getLogger(__name__) -def _convert_message_to_dict(message: BaseMessage) -> dict: - """Convert a LangChain message to a Writer message dict.""" - message_dict = {"role": "", "content": message.content} - - if isinstance(message, ChatMessage): - message_dict["role"] = message.role - elif isinstance(message, HumanMessage): - message_dict["role"] = "user" - elif isinstance(message, AIMessage): - message_dict["role"] = "assistant" - if message.tool_calls: - message_dict["tool_calls"] = [ - { - "id": tool["id"], - "type": "function", - "function": {"name": tool["name"], "arguments": tool["args"]}, - } - for tool in message.tool_calls - ] - elif isinstance(message, SystemMessage): - message_dict["role"] = "system" - elif isinstance(message, ToolMessage): - message_dict["role"] = "tool" - message_dict["tool_call_id"] = message.tool_call_id - else: - raise ValueError(f"Got unknown message type: {type(message)}") - - if message.name: - message_dict["name"] = message.name - - return message_dict - - -def _convert_dict_to_message(response_dict: Dict[str, Any]) -> BaseMessage: - """Convert a Writer message dict to a LangChain message.""" - role = response_dict["role"] - content = response_dict.get("content", "") - - if role == "user": - return HumanMessage(content=content) - elif role == "assistant": - additional_kwargs = {} - if tool_calls := response_dict.get("tool_calls"): - additional_kwargs["tool_calls"] = tool_calls - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) - elif role == "system": - return SystemMessage(content=content) - elif role == "tool": - return ToolMessage( - content=content, - tool_call_id=response_dict["tool_call_id"], - name=response_dict.get("name"), - ) - else: - return ChatMessage(content=content, role=role) - - class ChatWriter(BaseChatModel): """Writer chat model. To use, you should have the ``writer-sdk`` Python package installed, and the - environment variable ``WRITER_API_KEY`` set with your API key. + environment variable ``WRITER_API_KEY`` set with your API key or pass 'api_key' + init param. Example: .. code-block:: python from langchain_community.chat_models import ChatWriter - chat = ChatWriter(model="palmyra-x-004") + chat = ChatWriter( + api_key="your key" + model="palmyra-x-004" + ) """ client: Any = Field(default=None, exclude=True) #: :meta private: async_client: Any = Field(default=None, exclude=True) #: :meta private: + + api_key: Optional[SecretStr] = Field(default=None) + """Writer API key.""" + model_name: str = Field(default="palmyra-x-004", alias="model") """Model name to use.""" + temperature: float = 0.7 """What sampling temperature to use.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - writer_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") - """Writer API key.""" - writer_api_base: Optional[str] = Field(default=None, alias="base_url") - """Base URL for API requests.""" - streaming: bool = False - """Whether to stream the results or not.""" + n: int = 1 """Number of chat completions to generate for each prompt.""" + max_tokens: Optional[int] = None """Maximum number of tokens to generate.""" @@ -149,37 +97,159 @@ def _identifying_params(self) -> Dict[str, Any]: return { "model_name": self.model_name, "temperature": self.temperature, - "streaming": self.streaming, **self.model_kwargs, } - def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling Writer API.""" + return { + "model": self.model_name, + "temperature": self.temperature, + "n": self.n, + "max_tokens": self.max_tokens, + **self.model_kwargs, + } + + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: + """Validates that api key is passed and creates Writer clients.""" + try: + from writerai import AsyncClient, Client + except ImportError as e: + raise ImportError( + "Could not import writerai python package. " + "Please install it with `pip install writerai`." + ) from e + + if not values.get("client"): + values.update( + { + "client": Client( + api_key=get_from_dict_or_env( + values, "api_key", "WRITER_API_KEY" + ) + ) + } + ) + + if not values.get("async_client"): + values.update( + { + "async_client": AsyncClient( + api_key=get_from_dict_or_env( + values, "api_key", "WRITER_API_KEY" + ) + ) + } + ) + + if not ( + type(values.get("client")) is Client + and type(values.get("async_client")) is AsyncClient + ): + raise ValueError( + "'client' attribute must be with type 'Client' and " + "'async_client' must be with type 'AsyncClient' from 'writerai' package" + ) + + return values + + def _create_chat_result(self, response: Any) -> ChatResult: generations = [] - for choice in response["choices"]: - message = _convert_dict_to_message(choice["message"]) + for choice in response.choices: + message = self._convert_writer_to_langchain(choice.message) gen = ChatGeneration( message=message, - generation_info=dict(finish_reason=choice.get("finish_reason")), + generation_info=dict(finish_reason=choice.finish_reason), ) generations.append(gen) - token_usage = response.get("usage", {}) + token_usage = {} + + if response.usage: + token_usage = response.usage.__dict__ llm_output = { "token_usage": token_usage, "model_name": self.model_name, - "system_fingerprint": response.get("system_fingerprint", ""), + "system_fingerprint": response.system_fingerprint, } return ChatResult(generations=generations, llm_output=llm_output) - def _convert_messages_to_dicts( + @staticmethod + def _convert_langchain_to_writer(message: BaseMessage) -> dict: + """Convert a LangChain message to a Writer message dict.""" + message_dict = {"role": "", "content": message.content} + + if isinstance(message, ChatMessage): + message_dict["role"] = message.role + elif isinstance(message, HumanMessage): + message_dict["role"] = "user" + elif isinstance(message, AIMessage): + message_dict["role"] = "assistant" + if message.tool_calls: + message_dict["tool_calls"] = [ + { + "id": tool["id"], + "type": "function", + "function": {"name": tool["name"], "arguments": tool["args"]}, + } + for tool in message.tool_calls + ] + elif isinstance(message, SystemMessage): + message_dict["role"] = "system" + elif isinstance(message, ToolMessage): + message_dict["role"] = "tool" + message_dict["tool_call_id"] = message.tool_call_id + else: + raise ValueError(f"Got unknown message type: {type(message)}") + + if message.name: + message_dict["name"] = message.name + + return message_dict + + @staticmethod + def _convert_writer_to_langchain(response_message: Any) -> BaseMessage: + """Convert a Writer message to a LangChain message.""" + if not isinstance(response_message, dict): + response_message = json.loads( + json.dumps(response_message, default=lambda o: o.__dict__) + ) + + role = response_message.get("role", "") + content = response_message.get("content") + if not content: + content = "" + + if role == "user": + return HumanMessage(content=content) + elif role == "assistant": + additional_kwargs = {} + if tool_calls := response_message.get("tool_calls", []): + additional_kwargs["tool_calls"] = tool_calls + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=content) + elif role == "tool": + return ToolMessage( + content=content, + tool_call_id=response_message.get("tool_call_id", ""), + name=response_message.get("name", ""), + ) + else: + return ChatMessage(content=content, role=role) + + def _convert_messages_to_writer( self, messages: List[BaseMessage], stop: Optional[List[str]] = None ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """Convert a list of LangChain messages to List of Writer dicts.""" params = { "model": self.model_name, "temperature": self.temperature, "n": self.n, - "stream": self.streaming, **self.model_kwargs, } if stop: @@ -187,7 +257,7 @@ def _convert_messages_to_dicts( if self.max_tokens is not None: params["max_tokens"] = self.max_tokens - message_dicts = [_convert_message_to_dict(m) for m in messages] + message_dicts = [self._convert_langchain_to_writer(m) for m in messages] return message_dicts, params def _stream( @@ -197,17 +267,17 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - message_dicts, params = self._convert_messages_to_dicts(messages, stop) + message_dicts, params = self._convert_messages_to_writer(messages, stop) params = {**params, **kwargs, "stream": True} response = self.client.chat.chat(messages=message_dicts, **params) for chunk in response: - delta = chunk["choices"][0].get("delta") - if not delta or not delta.get("content"): + delta = chunk.choices[0].delta + if not delta or not delta.content: continue - chunk = _convert_dict_to_message( - {"role": "assistant", "content": delta["content"]} + chunk = self._convert_writer_to_langchain( + {"role": "assistant", "content": delta.content} ) chunk = ChatGenerationChunk(message=chunk) @@ -223,17 +293,17 @@ async def _astream( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - message_dicts, params = self._convert_messages_to_dicts(messages, stop) + message_dicts, params = self._convert_messages_to_writer(messages, stop) params = {**params, **kwargs, "stream": True} response = await self.async_client.chat.chat(messages=message_dicts, **params) async for chunk in response: - delta = chunk["choices"][0].get("delta") - if not delta or not delta.get("content"): + delta = chunk.choices[0].delta + if not delta or not delta.content: continue - chunk = _convert_dict_to_message( - {"role": "assistant", "content": delta["content"]} + chunk = self._convert_writer_to_langchain( + {"role": "assistant", "content": delta.content} ) chunk = ChatGenerationChunk(message=chunk) @@ -249,12 +319,7 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if self.streaming: - return generate_from_stream( - self._stream(messages, stop, run_manager, **kwargs) - ) - - message_dicts, params = self._convert_messages_to_dicts(messages, stop) + message_dicts, params = self._convert_messages_to_writer(messages, stop) params = {**params, **kwargs} response = self.client.chat.chat(messages=message_dicts, **params) return self._create_chat_result(response) @@ -266,28 +331,11 @@ async def _agenerate( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if self.streaming: - return await agenerate_from_stream( - self._astream(messages, stop, run_manager, **kwargs) - ) - - message_dicts, params = self._convert_messages_to_dicts(messages, stop) + message_dicts, params = self._convert_messages_to_writer(messages, stop) params = {**params, **kwargs} response = await self.async_client.chat.chat(messages=message_dicts, **params) return self._create_chat_result(response) - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling Writer API.""" - return { - "model": self.model_name, - "temperature": self.temperature, - "stream": self.streaming, - "n": self.n, - "max_tokens": self.max_tokens, - **self.model_kwargs, - } - def bind_tools( self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index d3e8860c89131..dd2a7937f7f81 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -1,6 +1,7 @@ from hashlib import md5 from typing import Any, Dict, List, Optional +from langchain_core._api.deprecation import deprecated from langchain_core.utils import get_from_dict_or_env from langchain_community.graphs.graph_document import GraphDocument @@ -51,6 +52,11 @@ ) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.graphs.neo4j_graph.clean_string_values", +) def clean_string_values(text: str) -> str: """Clean string values for schema. @@ -65,6 +71,11 @@ def clean_string_values(text: str) -> str: return text.replace("\n", " ").replace("\r", " ") +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.graphs.neo4j_graph.value_sanitize", +) def value_sanitize(d: Any) -> Any: """Sanitize the input dictionary or list. @@ -111,6 +122,11 @@ def value_sanitize(d: Any) -> Any: return d +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.graphs.neo4j_graph._get_node_import_query", +) def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str: if baseEntityLabel: return ( @@ -134,6 +150,11 @@ def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str: ) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.graphs.neo4j_graph._get_rel_import_query", +) def _get_rel_import_query(baseEntityLabel: bool) -> str: if baseEntityLabel: return ( @@ -158,6 +179,11 @@ def _get_rel_import_query(baseEntityLabel: bool) -> str: ) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.graphs.neo4j_graph._format_schema", +) def _format_schema(schema: Dict, is_enhanced: bool) -> str: formatted_node_props = [] formatted_rel_props = [] @@ -287,10 +313,20 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: ) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.graphs.neo4j_graph._remove_backticks", +) def _remove_backticks(text: str) -> str: return text.replace("`", "") +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.Neo4jGraph", +) class Neo4jGraph(GraphStore): """Neo4j database wrapper for various graph operations. diff --git a/libs/community/langchain_community/llms/writer.py b/libs/community/langchain_community/llms/writer.py index d82a346c43616..e68909d06e13e 100644 --- a/libs/community/langchain_community/llms/writer.py +++ b/libs/community/langchain_community/llms/writer.py @@ -1,108 +1,89 @@ -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional -import requests -from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain_core.language_models.llms import LLM -from langchain_core.utils import get_from_dict_or_env, pre_init -from pydantic import ConfigDict - -from langchain_community.llms.utils import enforce_stop_tokens +from langchain_core.outputs import GenerationChunk +from langchain_core.utils import get_from_dict_or_env +from pydantic import ConfigDict, Field, SecretStr, model_validator class Writer(LLM): """Writer large language models. - To use, you should have the environment variable ``WRITER_API_KEY`` and - ``WRITER_ORG_ID`` set with your API key and organization ID respectively. + To use, you should have the ``writer-sdk`` Python package installed, and the + environment variable ``WRITER_API_KEY`` set with your API key. Example: .. code-block:: python - from langchain_community.llms import Writer - writer = Writer(model_id="palmyra-base") + from langchain_community.llms import Writer as WriterLLM + from writerai import Writer, AsyncWriter + + client = Writer() + async_client = AsyncWriter() + + chat = WriterLLM( + client=client, + async_client=async_client + ) """ - writer_org_id: Optional[str] = None - """Writer organization ID.""" + client: Any = Field(default=None, exclude=True) #: :meta private: + async_client: Any = Field(default=None, exclude=True) #: :meta private: - model_id: str = "palmyra-instruct" - """Model name to use.""" + api_key: Optional[SecretStr] = Field(default=None) + """Writer API key.""" - min_tokens: Optional[int] = None - """Minimum number of tokens to generate.""" + model_name: str = Field(default="palmyra-x-003-instruct", alias="model") + """Model name to use.""" max_tokens: Optional[int] = None - """Maximum number of tokens to generate.""" + """The maximum number of tokens that the model can generate in the response.""" - temperature: Optional[float] = None - """What sampling temperature to use.""" + temperature: Optional[float] = 0.7 + """Controls the randomness of the model's outputs. Higher values lead to more + random outputs, while lower values make the model more deterministic.""" top_p: Optional[float] = None - """Total probability mass of tokens to consider at each step.""" + """Used to control the nucleus sampling, where only the most probable tokens + with a cumulative probability of top_p are considered for sampling, providing + a way to fine-tune the randomness of predictions.""" stop: Optional[List[str]] = None - """Sequences when completion generation will stop.""" - - presence_penalty: Optional[float] = None - """Penalizes repeated tokens regardless of frequency.""" - - repetition_penalty: Optional[float] = None - """Penalizes repeated tokens according to frequency.""" + """Specifies stopping conditions for the model's output generation. This can + be an array of strings or a single string that the model will look for as a + signal to stop generating further tokens.""" best_of: Optional[int] = None - """Generates this many completions server-side and returns the "best".""" - - logprobs: bool = False - """Whether to return log probabilities.""" - - n: Optional[int] = None - """How many completions to generate.""" + """Specifies the number of completions to generate and return the best one. + Useful for generating multiple outputs and choosing the best based on some + criteria.""" - writer_api_key: Optional[str] = None - """Writer API key.""" - - base_url: Optional[str] = None - """Base url to use, if None decides based on model name.""" - - model_config = ConfigDict( - extra="forbid", - ) + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" - @pre_init - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and organization id exist in environment.""" - - writer_api_key = get_from_dict_or_env( - values, "writer_api_key", "WRITER_API_KEY" - ) - values["writer_api_key"] = writer_api_key - - writer_org_id = get_from_dict_or_env(values, "writer_org_id", "WRITER_ORG_ID") - values["writer_org_id"] = writer_org_id - - return values + model_config = ConfigDict(populate_by_name=True) @property def _default_params(self) -> Mapping[str, Any]: """Get the default parameters for calling Writer API.""" return { - "minTokens": self.min_tokens, - "maxTokens": self.max_tokens, + "max_tokens": self.max_tokens, "temperature": self.temperature, - "topP": self.top_p, + "top_p": self.top_p, "stop": self.stop, - "presencePenalty": self.presence_penalty, - "repetitionPenalty": self.repetition_penalty, - "bestOf": self.best_of, - "logprobs": self.logprobs, - "n": self.n, + "best_of": self.best_of, + **self.model_kwargs, } @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" return { - **{"model_id": self.model_id, "writer_org_id": self.writer_org_id}, + "model": self.model_name, **self._default_params, } @@ -111,6 +92,51 @@ def _llm_type(self) -> str: """Return type of llm.""" return "writer" + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: + """Validates that api key is passed and creates Writer clients.""" + try: + from writerai import AsyncClient, Client + except ImportError as e: + raise ImportError( + "Could not import writerai python package. " + "Please install it with `pip install writerai`." + ) from e + + if not values.get("client"): + values.update( + { + "client": Client( + api_key=get_from_dict_or_env( + values, "api_key", "WRITER_API_KEY" + ) + ) + } + ) + + if not values.get("async_client"): + values.update( + { + "async_client": AsyncClient( + api_key=get_from_dict_or_env( + values, "api_key", "WRITER_API_KEY" + ) + ) + } + ) + + if not ( + type(values.get("client")) is Client + and type(values.get("async_client")) is AsyncClient + ): + raise ValueError( + "'client' attribute must be with type 'Client' and " + "'async_client' must be with type 'AsyncClient' from 'writerai' package" + ) + + return values + def _call( self, prompt: str, @@ -118,41 +144,54 @@ def _call( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: - """Call out to Writer's completions endpoint. - - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - The string generated by the model. - - Example: - .. code-block:: python - - response = Writer("Tell me a joke.") - """ - if self.base_url is not None: - base_url = self.base_url - else: - base_url = ( - "https://enterprise-api.writer.com/llm" - f"/organization/{self.writer_org_id}" - f"/model/{self.model_id}/completions" - ) - params = {**self._default_params, **kwargs} - response = requests.post( - url=base_url, - headers={ - "Authorization": f"{self.writer_api_key}", - "Content-Type": "application/json", - "Accept": "application/json", - }, - json={"prompt": prompt, **params}, - ) - text = response.text + params = {**self._identifying_params, **kwargs} + if stop is not None: + params.update({"stop": stop}) + text = self.client.completions.create(prompt=prompt, **params).choices[0].text + return text + + async def _acall( + self, + prompt: str, + stop: Optional[list[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + params = {**self._identifying_params, **kwargs} if stop is not None: - # I believe this is required since the stop tokens - # are not enforced by the model parameters - text = enforce_stop_tokens(text, stop) + params.update({"stop": stop}) + response = await self.async_client.completions.create(prompt=prompt, **params) + text = response.choices[0].text return text + + def _stream( + self, + prompt: str, + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + params = {**self._identifying_params, **kwargs, "stream": True} + if stop is not None: + params.update({"stop": stop}) + response = self.client.completions.create(prompt=prompt, **params) + for chunk in response: + if run_manager: + run_manager.on_llm_new_token(chunk.value) + yield GenerationChunk(text=chunk.value) + + async def _astream( + self, + prompt: str, + stop: Optional[list[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + params = {**self._identifying_params, **kwargs, "stream": True} + if stop is not None: + params.update({"stop": stop}) + response = await self.async_client.completions.create(prompt=prompt, **params) + async for chunk in response: + if run_manager: + await run_manager.on_llm_new_token(chunk.value) + yield GenerationChunk(text=chunk.value) diff --git a/libs/community/langchain_community/query_constructors/neo4j.py b/libs/community/langchain_community/query_constructors/neo4j.py index ecb62452069ac..2ce1de136fcb4 100644 --- a/libs/community/langchain_community/query_constructors/neo4j.py +++ b/libs/community/langchain_community/query_constructors/neo4j.py @@ -1,5 +1,6 @@ from typing import Dict, Tuple, Union +from langchain_core._api.deprecation import deprecated from langchain_core.structured_query import ( Comparator, Comparison, @@ -10,6 +11,11 @@ ) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.query_constructors.neo4j.Neo4jTranslator", +) class Neo4jTranslator(Visitor): """Translate `Neo4j` internal query language elements to valid filters.""" diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py index 7f7f3f97dd875..03d97a5a9d034 100644 --- a/libs/community/langchain_community/vectorstores/neo4j_vector.py +++ b/libs/community/langchain_community/vectorstores/neo4j_vector.py @@ -16,6 +16,7 @@ ) import numpy as np +from langchain_core._api.deprecation import deprecated from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.utils import get_from_dict_or_env @@ -63,6 +64,11 @@ ) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.SearchType", +) class SearchType(str, enum.Enum): """Enumerator of the Distance strategies.""" @@ -73,6 +79,11 @@ class SearchType(str, enum.Enum): DEFAULT_SEARCH_TYPE = SearchType.VECTOR +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.IndexType", +) class IndexType(str, enum.Enum): """Enumerator of the index types.""" @@ -83,6 +94,11 @@ class IndexType(str, enum.Enum): DEFAULT_INDEX_TYPE = IndexType.NODE +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector._get_search_index_query", +) def _get_search_index_query( search_type: SearchType, index_type: IndexType = DEFAULT_INDEX_TYPE ) -> str: @@ -119,6 +135,11 @@ def _get_search_index_query( ) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.check_if_not_null", +) def check_if_not_null(props: List[str], values: List[Any]) -> None: """Check if the values are not None or empty string""" for prop, value in zip(props, values): @@ -126,6 +147,11 @@ def check_if_not_null(props: List[str], values: List[Any]) -> None: raise ValueError(f"Parameter `{prop}` must not be None or empty string") +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.sort_by_index_name", +) def sort_by_index_name( lst: List[Dict[str, Any]], index_name: str ) -> List[Dict[str, Any]]: @@ -133,6 +159,11 @@ def sort_by_index_name( return sorted(lst, key=lambda x: x.get("name") != index_name) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.remove_lucene_chars", +) def remove_lucene_chars(text: str) -> str: """Remove Lucene special characters""" special_chars = [ @@ -161,6 +192,11 @@ def remove_lucene_chars(text: str) -> str: return text.strip() +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.dict_to_yaml_str", +) def dict_to_yaml_str(input_dict: Dict, indent: int = 0) -> str: """ Convert a dictionary to a YAML-like string without using external libraries. @@ -186,6 +222,11 @@ def dict_to_yaml_str(input_dict: Dict, indent: int = 0) -> str: return yaml_str +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.combine_queries", +) def combine_queries( input_queries: List[Tuple[str, Dict[str, Any]]], operator: str ) -> Tuple[str, Dict[str, Any]]: @@ -220,6 +261,11 @@ def combine_queries( return combined_query, combined_params +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.collect_params", +) def collect_params( input_data: List[Tuple[str, Dict[str, str]]], ) -> Tuple[List[str], Dict[str, Any]]: @@ -247,6 +293,11 @@ def collect_params( return (query_parts, params) +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector._handle_field_filter", +) def _handle_field_filter( field: str, value: Any, param_number: int = 1 ) -> Tuple[str, Dict]: @@ -348,6 +399,11 @@ def _handle_field_filter( raise NotImplementedError() +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.vectorstores.neo4j_vector.construct_metadata_filter", +) def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]: """Construct a metadata filter. @@ -430,6 +486,11 @@ def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]: raise ValueError("Got an empty dictionary for filters.") +@deprecated( + since="0.3.8", + removal="1.0", + alternative_import="langchain_neo4j.Neo4jVector", +) class Neo4jVector(VectorStore): """`Neo4j` vector index. diff --git a/libs/community/scripts/check_pydantic.sh b/libs/community/scripts/check_pydantic.sh index 99cb222d2b26e..ca83c483d515a 100755 --- a/libs/community/scripts/check_pydantic.sh +++ b/libs/community/scripts/check_pydantic.sh @@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini # PRs that increase the current count will not be accepted. # PRs that decrease update the code in the repository # and allow decreasing the count of are welcome! -current_count=126 +current_count=125 if [ "$count" -gt "$current_count" ]; then echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator." diff --git a/libs/community/tests/integration_tests/llms/test_writer.py b/libs/community/tests/integration_tests/llms/test_writer.py deleted file mode 100644 index db8ad809144b0..0000000000000 --- a/libs/community/tests/integration_tests/llms/test_writer.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Test Writer API wrapper.""" - -from langchain_community.llms.writer import Writer - - -def test_writer_call() -> None: - """Test valid call to Writer.""" - llm = Writer() - output = llm.invoke("Say foo:") - assert isinstance(output, str) diff --git a/libs/community/tests/unit_tests/chat_models/test_writer.py b/libs/community/tests/unit_tests/chat_models/test_writer.py index 944a9dfeaba1f..2524f62957d1b 100644 --- a/libs/community/tests/unit_tests/chat_models/test_writer.py +++ b/libs/community/tests/unit_tests/chat_models/test_writer.py @@ -1,61 +1,251 @@ -"""Unit tests for Writer chat model integration.""" - import json -from typing import Any, Dict, List -from unittest.mock import AsyncMock, MagicMock, patch +from typing import Any, Dict, List, Literal, Optional, Tuple, Type +from unittest import mock +from unittest.mock import AsyncMock, MagicMock import pytest from langchain_core.callbacks.manager import CallbackManager +from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_tests.unit_tests import ChatModelUnitTests from pydantic import SecretStr -from langchain_community.chat_models.writer import ChatWriter, _convert_dict_to_message +from langchain_community.chat_models.writer import ChatWriter from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +"""Classes for mocking Writer responses.""" + + +class ChoiceDelta: + def __init__(self, content: str): + self.content = content + + +class ChunkChoice: + def __init__(self, index: int, finish_reason: str, delta: ChoiceDelta): + self.index = index + self.finish_reason = finish_reason + self.delta = delta + + +class ChatCompletionChunk: + def __init__( + self, + id: str, + object: str, + created: int, + model: str, + choices: List[ChunkChoice], + ): + self.id = id + self.object = object + self.created = created + self.model = model + self.choices = choices + + +class ToolCallFunction: + def __init__(self, name: str, arguments: str): + self.name = name + self.arguments = arguments + + +class ChoiceMessageToolCall: + def __init__(self, id: str, type: str, function: ToolCallFunction): + self.id = id + self.type = type + self.function = function + + +class Usage: + def __init__( + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + ): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.total_tokens = total_tokens + + +class ChoiceMessage: + def __init__( + self, + role: str, + content: str, + tool_calls: Optional[List[ChoiceMessageToolCall]] = None, + ): + self.role = role + self.content = content + self.tool_calls = tool_calls + + +class Choice: + def __init__(self, index: int, finish_reason: str, message: ChoiceMessage): + self.index = index + self.finish_reason = finish_reason + self.message = message + + +class Chat: + def __init__( + self, + id: str, + object: str, + created: int, + system_fingerprint: str, + model: str, + usage: Usage, + choices: List[Choice], + ): + self.id = id + self.object = object + self.created = created + self.system_fingerprint = system_fingerprint + self.model = model + self.usage = usage + self.choices = choices + + +@pytest.mark.requires("writerai") +class TestChatWriterCustom: + """Test case for ChatWriter""" + + @pytest.fixture(autouse=True) + def mock_unstreaming_completion(self) -> Chat: + """Fixture providing a mock API response.""" + return Chat( + id="chat-12345", + object="chat.completion", + created=1699000000, + model="palmyra-x-004", + system_fingerprint="v1", + usage=Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18), + choices=[ + Choice( + index=0, + finish_reason="stop", + message=ChoiceMessage( + role="assistant", + content="Hello! How can I help you?", + ), + ) + ], + ) + + @pytest.fixture(autouse=True) + def mock_tool_call_choice_response(self) -> Chat: + return Chat( + id="chat-12345", + object="chat.completion", + created=1699000000, + model="palmyra-x-004", + system_fingerprint="v1", + usage=Usage(prompt_tokens=29, completion_tokens=32, total_tokens=61), + choices=[ + Choice( + index=0, + finish_reason="tool_calls", + message=ChoiceMessage( + role="assistant", + content="", + tool_calls=[ + ChoiceMessageToolCall( + id="call_abc123", + type="function", + function=ToolCallFunction( + name="GetWeather", + arguments='{"location": "London"}', + ), + ) + ], + ), + ) + ], + ) + + @pytest.fixture(autouse=True) + def mock_streaming_chunks(self) -> List[ChatCompletionChunk]: + """Fixture providing mock streaming response chunks.""" + return [ + ChatCompletionChunk( + id="chat-12345", + object="chat.completion", + created=1699000000, + model="palmyra-x-004", + choices=[ + ChunkChoice( + index=0, + finish_reason="stop", + delta=ChoiceDelta(content="Hello! "), + ) + ], + ), + ChatCompletionChunk( + id="chat-12345", + object="chat.completion", + created=1699000000, + model="palmyra-x-004", + choices=[ + ChunkChoice( + index=0, + finish_reason="stop", + delta=ChoiceDelta(content="How can I help you?"), + ) + ], + ), + ] -class TestChatWriter: def test_writer_model_param(self) -> None: """Test different ways to initialize the chat model.""" test_cases: List[dict] = [ - {"model_name": "palmyra-x-004", "writer_api_key": "test-key"}, - {"model": "palmyra-x-004", "writer_api_key": "test-key"}, - {"model_name": "palmyra-x-004", "writer_api_key": "test-key"}, + { + "model_name": "palmyra-x-004", + "api_key": "key", + }, + { + "model": "palmyra-x-004", + "api_key": "key", + }, + { + "model_name": "palmyra-x-004", + "api_key": "key", + }, { "model": "palmyra-x-004", - "writer_api_key": "test-key", "temperature": 0.5, + "api_key": "key", }, ] for case in test_cases: chat = ChatWriter(**case) assert chat.model_name == "palmyra-x-004" - assert chat.writer_api_key - assert chat.writer_api_key.get_secret_value() == "test-key" assert chat.temperature == (0.5 if "temperature" in case else 0.7) - def test_convert_dict_to_message_human(self) -> None: + def test_convert_writer_to_langchain_human(self) -> None: """Test converting a human message dict to a LangChain message.""" message = {"role": "user", "content": "Hello"} - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, HumanMessage) assert result.content == "Hello" - def test_convert_dict_to_message_ai(self) -> None: + def test_convert_writer_to_langchain_ai(self) -> None: """Test converting an AI message dict to a LangChain message.""" message = {"role": "assistant", "content": "Hello"} - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, AIMessage) assert result.content == "Hello" - def test_convert_dict_to_message_system(self) -> None: + def test_convert_writer_to_langchain_system(self) -> None: """Test converting a system message dict to a LangChain message.""" message = {"role": "system", "content": "You are a helpful assistant"} - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, SystemMessage) assert result.content == "You are a helpful assistant" - def test_convert_dict_to_message_tool_call(self) -> None: + def test_convert_writer_to_langchain_tool_call(self) -> None: """Test converting a tool call message dict to a LangChain message.""" content = json.dumps({"result": 42}) message = { @@ -64,12 +254,12 @@ def test_convert_dict_to_message_tool_call(self) -> None: "content": content, "tool_call_id": "call_abc123", } - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, ToolMessage) assert result.name == "get_number" assert result.content == content - def test_convert_dict_to_message_with_tool_calls(self) -> None: + def test_convert_writer_to_langchain_with_tool_calls(self) -> None: """Test converting an AIMessage with tool calls.""" message = { "role": "assistant", @@ -85,131 +275,55 @@ def test_convert_dict_to_message_with_tool_calls(self) -> None: } ], } - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, AIMessage) assert result.tool_calls assert len(result.tool_calls) == 1 assert result.tool_calls[0]["name"] == "get_weather" assert result.tool_calls[0]["args"]["location"] == "London" - @pytest.fixture(autouse=True) - def mock_completion(self) -> Dict[str, Any]: - """Fixture providing a mock API response.""" - return { - "id": "chat-12345", - "object": "chat.completion", - "created": 1699000000, - "model": "palmyra-x-004", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! How can I help you?", - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}, - } - - @pytest.fixture(autouse=True) - def mock_response(self) -> Dict[str, Any]: - response = { - "id": "chat-12345", - "choices": [ - { - "message": { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_abc123", - "type": "function", - "function": { - "name": "GetWeather", - "arguments": '{"location": "London"}', - }, - } - ], - }, - "finish_reason": "tool_calls", - } - ], - } - return response - - @pytest.fixture(autouse=True) - def mock_streaming_chunks(self) -> List[Dict[str, Any]]: - """Fixture providing mock streaming response chunks.""" - return [ - { - "id": "chat-12345", - "object": "chat.completion.chunk", - "created": 1699000000, - "model": "palmyra-x-004", - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - "content": "Hello", - }, - "finish_reason": None, - } - ], - }, - { - "id": "chat-12345", - "object": "chat.completion.chunk", - "created": 1699000000, - "model": "palmyra-x-004", - "choices": [ - { - "index": 0, - "delta": { - "content": "!", - }, - "finish_reason": "stop", - } - ], - }, - ] - - def test_sync_completion(self, mock_completion: Dict[str, Any]) -> None: + def test_sync_completion( + self, mock_unstreaming_completion: List[ChatCompletionChunk] + ) -> None: """Test basic chat completion with mocked response.""" - chat = ChatWriter(api_key=SecretStr("test-key")) + chat = ChatWriter(api_key=SecretStr("key")) + mock_client = MagicMock() - mock_client.chat.chat.return_value = mock_completion + mock_client.chat.chat.return_value = mock_unstreaming_completion - with patch.object(chat, "client", mock_client): + with mock.patch.object(chat, "client", mock_client): message = HumanMessage(content="Hi there!") response = chat.invoke([message]) assert isinstance(response, AIMessage) assert response.content == "Hello! How can I help you?" - async def test_async_completion(self, mock_completion: Dict[str, Any]) -> None: + @pytest.mark.asyncio + async def test_async_completion( + self, mock_unstreaming_completion: List[ChatCompletionChunk] + ) -> None: """Test async chat completion with mocked response.""" - chat = ChatWriter(api_key=SecretStr("test-key")) - mock_client = AsyncMock() - mock_client.chat.chat.return_value = mock_completion + chat = ChatWriter(api_key=SecretStr("key")) + + mock_async_client = AsyncMock() + mock_async_client.chat.chat.return_value = mock_unstreaming_completion - with patch.object(chat, "async_client", mock_client): + with mock.patch.object(chat, "async_client", mock_async_client): message = HumanMessage(content="Hi there!") response = await chat.ainvoke([message]) assert isinstance(response, AIMessage) assert response.content == "Hello! How can I help you?" - def test_sync_streaming(self, mock_streaming_chunks: List[Dict[str, Any]]) -> None: + def test_sync_streaming( + self, mock_streaming_chunks: List[ChatCompletionChunk] + ) -> None: """Test sync streaming with callback handler.""" callback_handler = FakeCallbackHandler() callback_manager = CallbackManager([callback_handler]) chat = ChatWriter( - streaming=True, + api_key=SecretStr("key"), callback_manager=callback_manager, max_tokens=10, - api_key=SecretStr("test-key"), ) mock_client = MagicMock() @@ -217,42 +331,46 @@ def test_sync_streaming(self, mock_streaming_chunks: List[Dict[str, Any]]) -> No mock_response.__iter__.return_value = mock_streaming_chunks mock_client.chat.chat.return_value = mock_response - with patch.object(chat, "client", mock_client): + with mock.patch.object(chat, "client", mock_client): message = HumanMessage(content="Hi") - response = chat.invoke([message]) - - assert isinstance(response, AIMessage) + response = chat.stream([message]) + response_message = "" + for chunk in response: + response_message += str(chunk.content) assert callback_handler.llm_streams > 0 - assert response.content == "Hello!" + assert response_message == "Hello! How can I help you?" + @pytest.mark.asyncio async def test_async_streaming( - self, mock_streaming_chunks: List[Dict[str, Any]] + self, mock_streaming_chunks: List[ChatCompletionChunk] ) -> None: """Test async streaming with callback handler.""" callback_handler = FakeCallbackHandler() callback_manager = CallbackManager([callback_handler]) chat = ChatWriter( - streaming=True, + api_key=SecretStr("key"), callback_manager=callback_manager, max_tokens=10, - api_key=SecretStr("test-key"), ) - mock_client = AsyncMock() + mock_async_client = AsyncMock() mock_response = AsyncMock() mock_response.__aiter__.return_value = mock_streaming_chunks - mock_client.chat.chat.return_value = mock_response + mock_async_client.chat.chat.return_value = mock_response - with patch.object(chat, "async_client", mock_client): + with mock.patch.object(chat, "async_client", mock_async_client): message = HumanMessage(content="Hi") - response = await chat.ainvoke([message]) - - assert isinstance(response, AIMessage) + response = chat.astream([message]) + response_message = "" + async for chunk in response: + response_message += str(chunk.content) assert callback_handler.llm_streams > 0 - assert response.content == "Hello!" + assert response_message == "Hello! How can I help you?" - def test_sync_tool_calling(self, mock_response: Dict[str, Any]) -> None: + def test_sync_tool_calling( + self, mock_tool_call_choice_response: Dict[str, Any] + ) -> None: """Test synchronous tool calling functionality.""" from pydantic import BaseModel, Field @@ -261,23 +379,27 @@ class GetWeather(BaseModel): location: str = Field(..., description="The location to get weather for") - mock_client = MagicMock() - mock_client.chat.chat.return_value = mock_response + chat = ChatWriter(api_key=SecretStr("key")) - chat = ChatWriter(api_key=SecretStr("test-key"), client=mock_client) + mock_client = MagicMock() + mock_client.chat.chat.return_value = mock_tool_call_choice_response chat_with_tools = chat.bind_tools( tools=[GetWeather], tool_choice="GetWeather", ) - response = chat_with_tools.invoke("What's the weather in London?") - assert isinstance(response, AIMessage) - assert response.tool_calls - assert response.tool_calls[0]["name"] == "GetWeather" - assert response.tool_calls[0]["args"]["location"] == "London" + with mock.patch.object(chat, "client", mock_client): + response = chat_with_tools.invoke("What's the weather in London?") + assert isinstance(response, AIMessage) + assert response.tool_calls + assert response.tool_calls[0]["name"] == "GetWeather" + assert response.tool_calls[0]["args"]["location"] == "London" - async def test_async_tool_calling(self, mock_response: Dict[str, Any]) -> None: + @pytest.mark.asyncio + async def test_async_tool_calling( + self, mock_tool_call_choice_response: Dict[str, Any] + ) -> None: """Test asynchronous tool calling functionality.""" from pydantic import BaseModel, Field @@ -286,18 +408,101 @@ class GetWeather(BaseModel): location: str = Field(..., description="The location to get weather for") - mock_client = AsyncMock() - mock_client.chat.chat.return_value = mock_response + mock_async_client = AsyncMock() + mock_async_client.chat.chat.return_value = mock_tool_call_choice_response - chat = ChatWriter(api_key=SecretStr("test-key"), async_client=mock_client) + chat = ChatWriter(api_key=SecretStr("key")) chat_with_tools = chat.bind_tools( tools=[GetWeather], tool_choice="GetWeather", ) - response = await chat_with_tools.ainvoke("What's the weather in London?") - assert isinstance(response, AIMessage) - assert response.tool_calls - assert response.tool_calls[0]["name"] == "GetWeather" - assert response.tool_calls[0]["args"]["location"] == "London" + with mock.patch.object(chat, "async_client", mock_async_client): + response = await chat_with_tools.ainvoke("What's the weather in London?") + assert isinstance(response, AIMessage) + assert response.tool_calls + assert response.tool_calls[0]["name"] == "GetWeather" + assert response.tool_calls[0]["args"]["location"] == "London" + + +@pytest.mark.requires("writerai") +class TestChatWriterStandart(ChatModelUnitTests): + """Test case for ChatWriter that inherits from standard LangChain tests.""" + + @property + def chat_model_class(self) -> Type[BaseChatModel]: + """Return ChatWriter model class.""" + return ChatWriter + + @property + def chat_model_params(self) -> Dict: + """Return any additional parameters needed.""" + return { + "api_key": "fake-api-key", + "model_name": "palmyra-x-004", + } + + @property + def has_tool_calling(self) -> bool: + """Writer supports tool/function calling.""" + return True + + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice in tests.""" + return "auto" + + @property + def has_structured_output(self) -> bool: + """Writer does not yet support structured output.""" + return False + + @property + def supports_image_inputs(self) -> bool: + """Writer does not support image inputs.""" + return False + + @property + def supports_video_inputs(self) -> bool: + """Writer does not support video inputs.""" + return False + + @property + def returns_usage_metadata(self) -> bool: + """Writer returns token usage information.""" + return True + + @property + def supports_anthropic_inputs(self) -> bool: + """Writer does not support anthropic inputs.""" + return False + + @property + def supports_image_tool_message(self) -> bool: + """Writer does not support image tool message.""" + return False + + @property + def supported_usage_metadata_details( + self, + ) -> Dict[ + Literal["invoke", "stream"], + List[ + Literal[ + "audio_input", + "audio_output", + "reasoning_output", + "cache_read_input", + "cache_creation_input", + ] + ], + ]: + """Return which types of usage metadata your model supports.""" + return {"invoke": ["cache_creation_input"], "stream": ["reasoning_output"]} + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + """Return env vars, init args, and expected instance attrs for initializing + from env vars.""" + return {"WRITER_API_KEY": "key"}, {"api_key": "key"}, {"api_key": "key"} diff --git a/libs/community/tests/unit_tests/llms/test_writer.py b/libs/community/tests/unit_tests/llms/test_writer.py new file mode 100644 index 0000000000000..ffdee04db0796 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_writer.py @@ -0,0 +1,202 @@ +from typing import List +from unittest import mock +from unittest.mock import AsyncMock, MagicMock + +import pytest +from langchain_core.callbacks import CallbackManager +from pydantic import SecretStr + +from langchain_community.llms.writer import Writer +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + +"""Classes for mocking Writer responses.""" + + +class Choice: + def __init__(self, text: str): + self.text = text + + +class Completion: + def __init__(self, choices: List[Choice]): + self.choices = choices + + +class StreamingData: + def __init__(self, value: str): + self.value = value + + +@pytest.mark.requires("writerai") +class TestWriterLLM: + """Unit tests for Writer LLM integration.""" + + @pytest.fixture(autouse=True) + def mock_unstreaming_completion(self) -> Completion: + """Fixture providing a mock API response.""" + return Completion(choices=[Choice(text="Hello! How can I help you?")]) + + @pytest.fixture(autouse=True) + def mock_streaming_completion(self) -> List[StreamingData]: + """Fixture providing mock streaming response chunks.""" + return [ + StreamingData(value="Hello! "), + StreamingData(value="How can I"), + StreamingData(value=" help you?"), + ] + + def test_sync_unstream_completion( + self, mock_unstreaming_completion: Completion + ) -> None: + """Test basic llm call with mocked response.""" + mock_client = MagicMock() + mock_client.completions.create.return_value = mock_unstreaming_completion + + llm = Writer(api_key=SecretStr("key")) + + with mock.patch.object(llm, "client", mock_client): + response_text = llm.invoke(input="Hello") + + assert response_text == "Hello! How can I help you?" + + def test_sync_unstream_completion_with_params( + self, mock_unstreaming_completion: Completion + ) -> None: + """Test llm call with passed params with mocked response.""" + mock_client = MagicMock() + mock_client.completions.create.return_value = mock_unstreaming_completion + + llm = Writer(api_key=SecretStr("key"), temperature=1) + + with mock.patch.object(llm, "client", mock_client): + response_text = llm.invoke(input="Hello") + + assert response_text == "Hello! How can I help you?" + + @pytest.mark.asyncio + async def test_async_unstream_completion( + self, mock_unstreaming_completion: Completion + ) -> None: + """Test async chat completion with mocked response.""" + mock_async_client = AsyncMock() + mock_async_client.completions.create.return_value = mock_unstreaming_completion + + llm = Writer(api_key=SecretStr("key")) + + with mock.patch.object(llm, "async_client", mock_async_client): + response_text = await llm.ainvoke(input="Hello") + + assert response_text == "Hello! How can I help you?" + + @pytest.mark.asyncio + async def test_async_unstream_completion_with_params( + self, mock_unstreaming_completion: Completion + ) -> None: + """Test async llm call with passed params with mocked response.""" + mock_async_client = AsyncMock() + mock_async_client.completions.create.return_value = mock_unstreaming_completion + + llm = Writer(api_key=SecretStr("key"), temperature=1) + + with mock.patch.object(llm, "async_client", mock_async_client): + response_text = await llm.ainvoke(input="Hello") + + assert response_text == "Hello! How can I help you?" + + def test_sync_streaming_completion( + self, mock_streaming_completion: List[StreamingData] + ) -> None: + """Test sync streaming.""" + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.__iter__.return_value = mock_streaming_completion + mock_client.completions.create.return_value = mock_response + + llm = Writer(api_key=SecretStr("key")) + + with mock.patch.object(llm, "client", mock_client): + response = llm.stream(input="Hello") + + response_message = "" + for chunk in response: + response_message += chunk + + assert response_message == "Hello! How can I help you?" + + def test_sync_streaming_completion_with_callback_handler( + self, mock_streaming_completion: List[StreamingData] + ) -> None: + """Test sync streaming with callback handler.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.__iter__.return_value = mock_streaming_completion + mock_client.completions.create.return_value = mock_response + + llm = Writer( + api_key=SecretStr("key"), + callback_manager=callback_manager, + ) + + with mock.patch.object(llm, "client", mock_client): + response = llm.stream(input="Hello") + + response_message = "" + for chunk in response: + response_message += chunk + + assert callback_handler.llm_streams == 3 + assert response_message == "Hello! How can I help you?" + + @pytest.mark.asyncio + async def test_async_streaming_completion( + self, mock_streaming_completion: Completion + ) -> None: + """Test async streaming with callback handler.""" + + mock_async_client = AsyncMock() + mock_response = AsyncMock() + mock_response.__aiter__.return_value = mock_streaming_completion + mock_async_client.completions.create.return_value = mock_response + + llm = Writer(api_key=SecretStr("key")) + + with mock.patch.object(llm, "async_client", mock_async_client): + response = llm.astream(input="Hello") + + response_message = "" + async for chunk in response: + response_message += str(chunk) + + assert response_message == "Hello! How can I help you?" + + @pytest.mark.asyncio + async def test_async_streaming_completion_with_callback_handler( + self, mock_streaming_completion: Completion + ) -> None: + """Test async streaming with callback handler.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + + mock_async_client = AsyncMock() + mock_response = AsyncMock() + mock_response.__aiter__.return_value = mock_streaming_completion + mock_async_client.completions.create.return_value = mock_response + + llm = Writer( + api_key=SecretStr("key"), + callback_manager=callback_manager, + ) + + with mock.patch.object(llm, "async_client", mock_async_client): + response = llm.astream(input="Hello") + + response_message = "" + async for chunk in response: + response_message += str(chunk) + + assert callback_handler.llm_streams == 3 + assert response_message == "Hello! How can I help you?"