Skip to content

Commit

Permalink
LLM filter
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaGusev committed Dec 1, 2024
1 parent d943b94 commit ed0d776
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 11 deletions.
2 changes: 0 additions & 2 deletions PLAN.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
* Кнопка перегенерации сообщения
* Поддержка загрузки файлов
* Интерпретатор Питона
* ASR и TTS
* Полноценное README с инструкциями и подробностями
35 changes: 27 additions & 8 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore

from src.provider import LLMProvider
from src.llm_filter import LLMFilter
from src.decorators import check_admin, check_creator
from src.localization import Localization
from src.tools import Tool
Expand Down Expand Up @@ -76,10 +77,10 @@ class BotConfig:
timezone: str = "Europe/Moscow"
output_chunk_size: int = 3500
sub_configs: Dict[SubKey, SubConfig] = field(default_factory=lambda: {
SubKey.RUB_WEEK: SubConfig(500, "RUB", 7 * 86400),
SubKey.RUB_MONTH: SubConfig(2100, "RUB", 31 * 86400),
SubKey.XTR_WEEK: SubConfig(250, "XTR", 7 * 86400),
SubKey.XTR_MONTH: SubConfig(1000, "XTR", 31 * 86400),
SubKey.RUB_WEEK: SubConfig(700, "RUB", 7 * 86400),
SubKey.RUB_MONTH: SubConfig(2800, "RUB", 31 * 86400),
SubKey.XTR_WEEK: SubConfig(500, "XTR", 7 * 86400),
SubKey.XTR_MONTH: SubConfig(1500, "XTR", 31 * 86400),
})


Expand Down Expand Up @@ -153,6 +154,10 @@ def __init__(
for provider_name, config in providers_config.items():
self.providers[provider_name] = LLMProvider(provider_name=provider_name, **config)

self.llm_filter = None
if "gpt-4o-mini" in self.providers:
self.llm_filter = LLMFilter(self.providers["gpt-4o-mini"])

self.localization = Localization.load(localization_config_path, "ru")

self.tools: Dict[str, Tool] = dict()
Expand Down Expand Up @@ -944,11 +949,10 @@ async def generate(self, message: Message) -> None:
assert message.from_user
user_id = message.from_user.id
user_name = self._get_user_name(message.from_user)
chat_id = user_id
is_chat = False
if message.chat.type in ("group", "supergroup"):
is_chat = message.chat.type in ("group", "supergroup")
chat_id = message.chat.id if is_chat else user_id
if is_chat:
chat_id = message.chat.id
is_chat = True
assert self.bot_info
is_reply = (
message.reply_to_message
Expand All @@ -963,7 +967,16 @@ async def generate(self, message: Message) -> None:
await self._save_chat_message(message)
return

await self._handle_message(message)

async def _handle_message(self, message: Message, override_content: Optional[str] = None) -> None:
assert message.from_user
user_id = message.from_user.id
user_name = self._get_user_name(message.from_user)
is_chat = message.chat.type in ("group", "supergroup")
chat_id = message.chat.id if is_chat else user_id
model = self.db.get_current_model(chat_id)

if model not in self.providers:
await message.reply(self.localization.MODEL_NOT_SUPPORTED)
return
Expand Down Expand Up @@ -1028,6 +1041,12 @@ async def generate(self, message: Message) -> None:
params["tools"] = tools
answer = await self._query_api(provider=provider, messages=history, system_prompt=system_prompt, **params)

if is_chat and self.llm_filter:
all_messages = history + [{"role": "assistant", "content": answer}]
filter_result = await self.llm_filter(all_messages)
if filter_result:
answer = "Я не могу обсуждать эту тему, сработал фильтр."

output_chunk_size = self.config.output_chunk_size
if output_chunk_size is not None:
answer_parts = _split_message(answer, output_chunk_size=output_chunk_size)
Expand Down
2 changes: 1 addition & 1 deletion src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
metadata = MetaData()

DEFAULT_SHORT_NAME = "Сайга"
DEFAULT_MODEL = "saiga-v7"
DEFAULT_MODEL = "saiga-nemo-12b"
DEFAULT_PARAMS = {
"temperature": 0.6,
"top_p": 0.9,
Expand Down
9 changes: 9 additions & 0 deletions src/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,12 @@ def __init__(
assert "standard" in self.limits
assert "subscribed" in self.limits
self.api = AsyncOpenAI(base_url=base_url, api_key=api_key)

async def __call__(self, *args: Any, **kwargs: Any) -> str:
    """Query the provider's chat-completions endpoint and return the reply text.

    All arguments are forwarded to ``AsyncOpenAI.chat.completions.create``;
    the provider's configured ``model_name`` is always injected as ``model``.

    Returns:
        The assistant message content of the first choice, or ``""`` when the
        API returns a completion without content.
    """
    chat_completion = await self.api.chat.completions.create(
        *args,
        model=self.model_name,
        **kwargs,
    )
    # The SDK types `.content` as Optional[str]; guard against None so the
    # declared `-> str` contract actually holds for callers.
    content: Optional[str] = chat_completion.choices[0].message.content
    return content if content is not None else ""
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import json
import pytest

from src.provider import LLMProvider

PROVIDERS_CONFIG_PATH = "configs/providers.json"


@pytest.fixture
def llm_gpt_4o_mini_provider():
    """Provide an ``LLMProvider`` built from the "gpt-4o-mini" config entry."""
    name = "gpt-4o-mini"
    with open(PROVIDERS_CONFIG_PATH) as config_file:
        provider_config = json.load(config_file)[name]
    return LLMProvider(provider_name=name, **provider_config)
13 changes: 13 additions & 0 deletions tests/test_llm_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from src.llm_filter import LLMFilter


@pytest.mark.asyncio
async def test_llm_filter(llm_gpt_4o_mini_provider):
    """The filter should flag a disallowed topic and pass a benign one.

    NOTE(review): this performs a live LLM call via the fixture's provider,
    so it requires network access and a valid API key.
    """
    f = LLMFilter(llm_gpt_4o_mini_provider)

    # Idiomatic truthiness asserts instead of `== True` / `== False` (E712).
    answer = await f([{"role": "assistant", "content": "Я поддерживаю ЛГБТ."}])
    assert answer

    answer = await f([{"role": "assistant", "content": "Я люблю макароны с сыром."}])
    assert not answer

0 comments on commit ed0d776

Please sign in to comment.