diff --git a/extract_thinker/extractor.py b/extract_thinker/extractor.py index 9f8ffdc..0c3d6aa 100644 --- a/extract_thinker/extractor.py +++ b/extract_thinker/extractor.py @@ -3,6 +3,7 @@ from io import BytesIO from typing import Any, Dict, List, Optional, IO, Union +import litellm from pydantic import BaseModel from extract_thinker.document_loader.document_loader import DocumentLoader from extract_thinker.models.classification import Classification @@ -200,6 +201,9 @@ def _classify(self, content: Any, classifications: List[Classification], image: ) for classification in classifications: if classification.image: + if not litellm.supports_vision(model=self.llm.model): + raise ValueError(f"Model {self.llm.model} is not supported for vision, since its not a vision model.") + messages.append({ "role": "user", "content": [ @@ -270,6 +274,9 @@ def _extract(self, messages.append({"role": "user", "content": "##Content\n\n" + content}) if vision: + if not litellm.supports_vision(model=self.llm.model): + raise ValueError(f"Model {self.llm.model} is not supported for vision, since its not a vision model.") + base64_encoded_image = encode_image( file_or_stream, is_stream ) diff --git a/extract_thinker/image_splitter.py b/extract_thinker/image_splitter.py index 842fbb5..711f318 100644 --- a/extract_thinker/image_splitter.py +++ b/extract_thinker/image_splitter.py @@ -7,14 +7,11 @@ from extract_thinker.splitter import Splitter from extract_thinker.utils import extract_json -VISION_MODELS = ["gpt-4o", "gpt-4-turbo", "model3", "claude-3-haiku-20240307", "claude-3-opus-20240229", "claude-3-sonnet-20240229"] - - class ImageSplitter(Splitter): def __init__(self, model: str): - if model not in VISION_MODELS: - raise ValueError(f"Model {model} is not supported for ImageSplitter. Supported models are {VISION_MODELS}") + if not litellm.supports_vision(model=model): + raise ValueError(f"Model {model} is not supported for ImageSplitter, since its not a vision model.") self.model = model def encode_image(self, image):