diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index a297f12c6c455..fd64b824a8df7 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -307,7 +307,7 @@ class ChatAnthropic(BaseChatModel):
     Key init args — client params:
         timeout: Optional[float]
             Timeout for requests.
-        max_retries: Optional[int]
+        max_retries: int
            Max number of retries if a request fails.
         api_key: Optional[str]
             Anthropic API key. If not passed in will be read from env var
             ANTHROPIC_API_KEY.
@@ -558,7 +558,8 @@ class Joke(BaseModel):
     default_request_timeout: Optional[float] = Field(None, alias="timeout")
     """Timeout for requests to Anthropic Completion API."""

-    max_retries: Optional[int] = None
+    # sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries
+    max_retries: int = 2
     """Number of retries allowed for requests sent to the Anthropic Completion API."""

     stop_sequences: Optional[List[str]] = Field(None, alias="stop")
@@ -661,10 +662,9 @@ def _client_params(self) -> Dict[str, Any]:
         client_params: Dict[str, Any] = {
             "api_key": self.anthropic_api_key.get_secret_value(),
             "base_url": self.anthropic_api_url,
+            "max_retries": self.max_retries,
             "default_headers": (self.default_headers or None),
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
         # value <= 0 indicates the param should be ignored. None is a meaningful value
         # for Anthropic client and treated differently than not specifying the param at
         # all.
diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py
index 02ba31b47e723..9297710470612 100644
--- a/libs/partners/fireworks/langchain_fireworks/chat_models.py
+++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -316,7 +316,7 @@ def is_lc_serializable(cls) -> bool:
         default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
     )
     """Model name to use."""
-    temperature: Optional[float] = None
+    temperature: float = 0.0
     """What sampling temperature to use."""
     stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
diff --git a/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
index da33d819cd30c..4375bf55ff02a 100644
--- a/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -22,7 +22,6 @@
       'request_timeout': 60.0,
       'stop': list([
       ]),
-      'temperature': 0.0,
     }),
     'lc': 1,
     'name': 'ChatFireworks',
diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index 838867dc37bf4..5868e9cc6a3a4 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -119,7 +119,7 @@ class ChatGroq(BaseChatModel):
     Key init args — client params:
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: Optional[int]
+        max_retries: int
            Max number of retries.
         api_key: Optional[str]
             Groq API key. If not passed in will be read from env var GROQ_API_KEY.
@@ -303,7 +303,7 @@ class Joke(BaseModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
     """Model name to use."""
-    temperature: Optional[float] = None
+    temperature: float = 0.7
     """What sampling temperature to use."""
     stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -327,11 +327,11 @@ class Joke(BaseModel):
     )
     """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
         None."""
-    max_retries: Optional[int] = None
+    max_retries: int = 2
     """Maximum number of retries to make when generating."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: Optional[int] = None
+    n: int = 1
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
@@ -379,11 +379,10 @@ def build_extra(cls, values: Dict[str, Any]) -> Any:
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n is not None and self.n < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        elif self.n is not None and self.n > 1 and self.streaming:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
-
         if self.temperature == 0:
             self.temperature = 1e-8
@@ -393,11 +392,10 @@ def validate_environment(self) -> Self:
             ),
             "base_url": self.groq_api_base,
             "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries

         try:
             import groq
diff --git a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
index 919d2a5c3d3c0..741d2c847455d 100644
--- a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -17,6 +17,7 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'mixtral-8x7b-32768',
+      'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),
diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 63edab1f29a6b..686e4a7e6a8dd 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -95,11 +95,8 @@ def _create_retry_decorator(
     """Returns a tenacity retry decorator, preconfigured to handle exceptions"""

     errors = [httpx.RequestError, httpx.StreamError]
-    kwargs: dict = dict(
-        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
-    )
     return create_base_retry_decorator(
-        **{k: v for k, v in kwargs.items() if v is not None}
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
@@ -383,13 +380,13 @@ class ChatMistralAI(BaseChatModel):
         default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
     )
     endpoint: Optional[str] = Field(default=None, alias="base_url")
-    max_retries: Optional[int] = None
-    timeout: Optional[int] = None
-    max_concurrent_requests: Optional[int] = None
+    max_retries: int = 5
+    timeout: int = 120
+    max_concurrent_requests: int = 64
     model: str = Field(default="mistral-small", alias="model_name")
-    temperature: Optional[float] = None
+    temperature: float = 0.7
     max_tokens: Optional[int] = None
-    top_p: Optional[float] = None
+    top_p: float = 1
     """Decode using nucleus sampling: consider the smallest set of tokens whose
        probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
     random_seed: Optional[int] = None
diff --git a/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr
index 07e4f33f3ce04..f7986097c47e3 100644
--- a/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -9,6 +9,7 @@
       ]),
       'kwargs': dict({
         'endpoint': 'boo',
+        'max_concurrent_requests': 64,
         'max_retries': 2,
         'max_tokens': 100,
         'mistral_api_key': dict({
@@ -21,6 +22,7 @@
         'model': 'mistral-small',
         'temperature': 0.0,
         'timeout': 60,
+        'top_p': 1,
       }),
     'lc': 1,
     'name': 'ChatMistralAI',
diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py
index c2de17988cb5e..2e1e5f8abfe03 100644
--- a/libs/partners/openai/langchain_openai/chat_models/azure.py
+++ b/libs/partners/openai/langchain_openai/chat_models/azure.py
@@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
             https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: Optional[int]
+        max_retries: int
            Max number of retries.
         organization: Optional[str]
             OpenAI organization ID. If not passed in will be read from env
@@ -586,9 +586,9 @@ def is_lc_serializable(cls) -> bool:
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n is not None and self.n < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        elif self.n is not None and self.n > 1 and self.streaming:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         if self.disabled_params is None:
@@ -641,11 +641,10 @@ def validate_environment(self) -> Self:
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
         if not self.client:
             sync_specific = {"http_client": self.http_client}
             self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 546a33c720e8b..142e7eca1a84b 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
-    temperature: Optional[float] = None
+    temperature: float = 0.7
     """What sampling temperature to use."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
@@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
         None."""
-    max_retries: Optional[int] = None
+    max_retries: int = 2
     """Maximum number of retries to make when generating."""
     presence_penalty: Optional[float] = None
     """Penalizes repeated tokens."""
@@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
     """Modify the likelihood of specified tokens appearing in the completion."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: Optional[int] = None
+    n: int = 1
     """Number of chat completions to generate for each prompt."""
     top_p: Optional[float] = None
     """Total probability mass of tokens to consider at each step."""
@@ -532,9 +532,9 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n is not None and self.n < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        elif self.n is not None and self.n > 1 and self.streaming:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         # Check OPENAI_ORGANIZATION for backwards compatibility.
@@ -551,12 +551,10 @@ def validate_environment(self) -> Self:
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
-
         if self.openai_proxy and (self.http_client or self.http_async_client):
             openai_proxy = self.openai_proxy
             http_client = self.http_client
@@ -611,14 +609,14 @@ def _default_params(self) -> Dict[str, Any]:
             "stop": self.stop or None,  # also exclude empty list for this
             "max_tokens": self.max_tokens,
             "extra_body": self.extra_body,
-            "n": self.n,
-            "temperature": self.temperature,
             "reasoning_effort": self.reasoning_effort,
         }
         params = {
             "model": self.model_name,
             "stream": self.streaming,
+            "n": self.n,
+            "temperature": self.temperature,
             **{k: v for k, v in exclude_if_none.items() if v is not None},
             **self.model_kwargs,
         }
@@ -1567,7 +1565,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: Optional[int]
+        max_retries: int
            Max number of retries.
         api_key: Optional[str]
             OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
index 2060512958a9f..2b8c3563b9443 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
+++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
@@ -15,6 +15,7 @@
       }),
       'max_retries': 2,
       'max_tokens': 100,
+      'n': 1,
       'openai_api_key': dict({
         'id': list([
           'AZURE_OPENAI_API_KEY',
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
index e7307c6158fbc..b7ab1ce9c072c 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
+++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
@@ -11,6 +11,7 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'gpt-3.5-turbo',
+      'n': 1,
       'openai_api_key': dict({
         'id': list([
           'OPENAI_API_KEY',
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 5eac32c0447dd..2e6cca0cd2d96 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -877,6 +877,8 @@ def test__get_request_payload() -> None:
         ],
         "model": "gpt-4o-2024-08-06",
         "stream": False,
+        "n": 1,
+        "temperature": 0.7,
     }
     payload = llm._get_request_payload(messages)
     assert payload == expected
diff --git a/libs/partners/xai/Makefile b/libs/partners/xai/Makefile
index 6859cc789a179..1626a01bc493c 100644
--- a/libs/partners/xai/Makefile
+++ b/libs/partners/xai/Makefile
@@ -8,13 +8,7 @@ TEST_FILE ?= tests/unit_tests/

 integration_test integration_tests: TEST_FILE=tests/integration_tests/

-test tests:
-	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
-
-test_watch:
-	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
-
-integration_test integration_tests:
+test tests integration_test integration_tests:
 	poetry run pytest $(TEST_FILE)

 ######################
diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py
index a854be5487d4c..775d22740cd4e 100644
--- a/libs/partners/xai/langchain_xai/chat_models.py
+++ b/libs/partners/xai/langchain_xai/chat_models.py
@@ -320,9 +320,9 @@ def _get_ls_params(
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n is not None and self.n < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n is not None and self.n > 1 and self.streaming:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         client_params: dict = {
@@ -331,11 +331,10 @@ def validate_environment(self) -> Self:
             ),
             "base_url": self.xai_api_base,
             "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries

         if client_params["api_key"] is None:
             raise ValueError(
diff --git a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
index 4cd1261555c90..5c6f113f2174a 100644
--- a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
+++ b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
@@ -10,6 +10,7 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'grok-beta',
+      'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),
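
Taken together, the diff replaces `Optional` defaults with concrete values (e.g. `temperature=0.7`, `n=1`, `max_retries=2` for the OpenAI-compatible models) and forwards `max_retries` to the underlying client unconditionally. A minimal sketch of the resulting behavior, assuming `langchain-openai` at this revision is installed and `OPENAI_API_KEY` is set; the model name is illustrative, and `_get_request_payload` is a private helper, used here only to mirror the updated unit test:

```python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-2024-08-06")

# Defaults are now concrete values rather than None:
assert llm.temperature == 0.7
assert llm.n == 1
assert llm.max_retries == 2  # always passed through to the OpenAI client

# n and temperature are now always present in the request payload,
# matching the updated expectations in test__get_request_payload:
payload = llm._get_request_payload([HumanMessage("hello")])
assert payload["n"] == 1
assert payload["temperature"] == 0.7
```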