diff --git a/configs/bot.json b/configs/bot.json new file mode 100644 index 0000000..d7bb064 --- /dev/null +++ b/configs/bot.json @@ -0,0 +1,12 @@ +{ + "token": "", + "admin_user_name": "", + "admin_user_id": -1, + "temperature_range": [0.0, 0.5, 0.8, 1.0, 1.2], + "top_p_range": [0.8, 0.9, 0.95, 0.98, 1.0], + "timezone": "Europe/Moscow", + "sub_price_rub": 500, + "sub_price_stars": 250, + "sub_duration": 604800, + "output_chunk_size": 3500 +} diff --git a/run.sh b/run.sh index b274161..7fd2789 100644 --- a/run.sh +++ b/run.sh @@ -1,4 +1,4 @@ #!/bin/bash set -euo pipefail -python3 -m src.bot --bot-token $1 --providers-config-path configs/providers.json --db-path $2 --characters-path configs/characters.json --tools-config-path configs/tools.json --yookassa-config-path configs/yookassa.json --localization-config-path configs/localization.json +python3 -m src.bot --bot-config-path $1 --providers-config-path configs/providers.json --db-path $2 --characters-path configs/characters.json --tools-config-path configs/tools.json --yookassa-config-path configs/yookassa.json --localization-config-path configs/localization.json diff --git a/src/bot.py b/src/bot.py index c4040e6..a79fc18 100644 --- a/src/bot.py +++ b/src/bot.py @@ -8,6 +8,7 @@ import re from email.utils import parseaddr from typing import cast, List, Dict, Any, Optional, Union, Callable, Coroutine, Tuple, BinaryIO +from dataclasses import dataclass import fire # type: ignore import tiktoken @@ -43,33 +44,48 @@ from src.document_loader import DocumentLoader -TEMPERATURE_RANGE = (0.0, 0.5, 0.8, 1.0, 1.2) -TOP_P_RANGE = (0.8, 0.9, 0.95, 0.98, 1.0) DALLE_DAILY_LIMIT = 5 -ADMIN_USERNAME = "YallenGusev" IMAGE_PLACEHOLDER = "" -TIMEZONE = "Europe/Moscow" - -SUB_PRICE_RUB = 500 -SUB_PRICE_STARS = 250 -SUB_DURATION = 7 * 86400 ChatMessage = Dict[str, Any] ChatMessages = List[ChatMessage] +@dataclass +class BotConfig: + token: str + admin_user_name: str + admin_user_id: int + temperature_range: List[float] + top_p_range: List[float] + timezone: str = "Europe/Moscow" + sub_price_rub: int = 500 + sub_price_stars: int = 250 + sub_duration: int = 7 * 86400 + output_chunk_size: int = 3500 + + +def _crop_content(content: str) -> str: + if isinstance(content, str): + return content.replace("\n", " ")[:40] + return IMAGE_PLACEHOLDER + + class LlmBot: def __init__( self, - bot_token: str, providers_config_path: str, db_path: str, + bot_config_path: str, localization_config_path: str, - output_chunk_size: Optional[int], characters_path: Optional[str], tools_config_path: Optional[str], yookassa_config_path: Optional[str], ): + assert os.path.exists(bot_config_path) + with open(bot_config_path) as r: + self.config = BotConfig(**json.load(r)) + self.providers: Dict[str, LLMProvider] = dict() with open(providers_config_path) as r: providers_config = json.load(r) @@ -91,8 +107,6 @@ def __init__( with open(characters_path) as r: self.characters = json.load(r) - self.output_chunk_size = output_chunk_size - self.db = Database(db_path) self.document_loader = DocumentLoader() @@ -112,17 +126,17 @@ def __init__( self.likes_kb.add(InlineKeyboardButton(text="👎", callback_data="feedback:dislike")) self.temperature_kb = InlineKeyboardBuilder() - for value in TEMPERATURE_RANGE: + for value in self.config.temperature_range: self.temperature_kb.add(InlineKeyboardButton(text=str(value), callback_data=f"settemperature:{value}")) self.top_p_kb = InlineKeyboardBuilder() - for value in TOP_P_RANGE: + for value in self.config.top_p_range: self.top_p_kb.add(InlineKeyboardButton(text=str(value), callback_data=f"settopp:{value}")) self.buy_kb = InlineKeyboardBuilder() self.buy_kb.add(InlineKeyboardButton(text=self.localization.BUY_WITH_STARS, callback_data="buy:stars")) - self.bot = Bot(token=bot_token, default=DefaultBotProperties(parse_mode=None)) + self.bot = Bot(token=self.config.token, default=DefaultBotProperties(parse_mode=None)) self.bot_info: Optional[User] = None self.dp = Dispatcher() @@ -178,7 +192,7 @@ def __init__( self.yookassa = YookassaHandler(**config) async def start_polling(self) -> None: - self.scheduler = AsyncIOScheduler(timezone=TIMEZONE) + self.scheduler = AsyncIOScheduler(timezone=self.config.timezone) if self.yookassa is not None: self.scheduler.add_job(self.yookassa_check_payments, trigger="interval", seconds=30) self.scheduler.start() @@ -200,7 +214,10 @@ async def start(self, message: Message) -> None: limits = {name: provider.limits for name, provider in self.providers.items()} sub_limits = self.localization.LIMITS.render(limits=limits, mode=mode).strip() content = self.localization.HELP.render( - model=model, message_count=remaining_count, sub_limits=sub_limits, admin_username=ADMIN_USERNAME + model=model, + message_count=remaining_count, + sub_limits=sub_limits, + admin_username=self.config.admin_user_name, ) await message.reply(content, parse_mode=ParseMode.MARKDOWN) @@ -294,10 +311,23 @@ def _merge_messages(messages: ChatMessages) -> ChatMessages: role = m["role"] if role == prev_role and role != "tool": is_current_str = isinstance(content, str) - is_prev_str = isinstance(new_messages[-1]["content"], str) + is_current_list = isinstance(content, list) + prev_content = new_messages[-1]["content"] + is_prev_str = isinstance(prev_content, str) + is_prev_list = isinstance(prev_content, list) if is_current_str and is_prev_str: new_messages[-1]["content"] += "\n\n" + content continue + elif is_current_str and is_prev_list: + prev_content.append({"type": "text", "text": content}) + continue + elif is_prev_str and is_current_list: + content.insert(0, {"type": "text", "text": prev_content}) + new_messages[-1]["content"] = content + continue + elif is_current_list and is_prev_list: + prev_content.extend(content) + continue prev_role = role new_messages.append(m) return new_messages @@ -465,7 +495,7 @@ async def sub_buy(self, message: Message) -> None: limits = {name: provider.limits for name, provider in self.providers.items()} sub_limits = self.localization.LIMITS.render(limits=limits, mode="subscribed").strip() - description = self.localization.SUB_DESCRIPTION.render(sub_limits=sub_limits, price=SUB_PRICE_RUB) + description = self.localization.SUB_DESCRIPTION.render(sub_limits=sub_limits, price=self.config.sub_price_rub) await message.reply(description, parse_mode=ParseMode.MARKDOWN, reply_markup=self.buy_kb.as_markup()) async def stars_sub_buy_proceed(self, callback: CallbackQuery) -> None: @@ -490,7 +520,7 @@ async def stars_sub_buy_proceed(self, callback: CallbackQuery) -> None: chat_id, title=title, description=description, - prices=[LabeledPrice(label=title, amount=SUB_PRICE_STARS)], + prices=[LabeledPrice(label=title, amount=self.config.sub_price_stars)], provider_token="", currency="XTR", payload=str(user_id), @@ -515,7 +545,7 @@ async def successful_payment_handler(self, message: Message) -> None: charge_id = successful_payment.telegram_payment_charge_id self.db.add_charge(user_id, charge_id) assert user_id == int(payload) - self.db.subscribe_user(user_id, SUB_DURATION) + self.db.subscribe_user(user_id, self.config.sub_duration) await self.bot.send_message(chat_id, self.localization.SUB_SUCCESS) async def yookassa_sub_buy_proceed(self, callback: CallbackQuery) -> None: @@ -545,7 +575,7 @@ async def yookassa_sub_buy_proceed(self, callback: CallbackQuery) -> None: assert self.bot_info assert self.bot_info.username payment_data = self.yookassa.create_payment( - SUB_PRICE_RUB, title, email=email, bot_username=self.bot_info.username + self.config.sub_price_rub, title, email=email, bot_username=self.bot_info.username ) payment_id = payment_data["id"] try: @@ -569,7 +599,7 @@ async def yookassa_check_payments(self) -> None: payment_id=payment.payment_id, status=status, internal_status=payment.internal_status ) if status == YookassaStatus.SUCCEEDED: - self.db.subscribe_user(payment.user_id, SUB_DURATION) + self.db.subscribe_user(payment.user_id, self.config.sub_duration) await self.bot.send_message(chat_id=payment.chat_id, text=self.localization.SUB_SUCCESS) self.db.set_payment_status(payment.payment_id, status=status.value, internal_status="completed") elif status == YookassaStatus.CANCELED: @@ -813,19 +843,17 @@ async def generate(self, message: Message) -> None: await message.reply(self.localization.LIMIT_EXCEEDED.format(model=model)) return + conv_id = self.db.get_current_conv_id(chat_id) + history = self.db.fetch_conversation(conv_id) params = self.db.get_parameters(chat_id) params = provider.params if params is None else params + system_prompt = self.db.get_system_prompt(chat_id) + system_prompt = provider.system_prompt if system_prompt is None else system_prompt + content = await self._build_content(message) + if "claude" in model and params["temperature"] > 1.0: await message.reply(self.localization.CLAUDE_HIGH_TEMPERATURE) return - - conv_id = self.db.get_current_conv_id(chat_id) - history = self.db.fetch_conversation(conv_id) - system_prompt = self.db.get_system_prompt(chat_id) - if system_prompt is None: - system_prompt = provider.system_prompt - - content = await self._build_content(message) if not isinstance(content, str) and not provider.can_handle_images: await message.reply(self.localization.CONTENT_NOT_SUPPORTED_BY_MODEL) return @@ -859,9 +887,9 @@ async def generate(self, message: Message) -> None: history = self._fix_broken_tool_calls(history) if tools and "gpt" not in model: params["tools"] = tools - answer = await self._query_api(model=model, messages=history, system_prompt=system_prompt, **params) + answer = await self._query_api(provider=provider, messages=history, system_prompt=system_prompt, **params) - output_chunk_size = self.output_chunk_size + output_chunk_size = self.config.output_chunk_size if output_chunk_size is not None: answer_parts = [answer[i : i + output_chunk_size] for i in range(0, len(answer), output_chunk_size)] else: @@ -890,23 +918,23 @@ async def generate(self, message: Message) -> None: except Exception: traceback.print_exc() - text = self.localization.ERROR.format(admin_username=ADMIN_USERNAME, chat_id=chat_id) + text = self.localization.ERROR.format(admin_username=self.config.admin_user_name, chat_id=chat_id) await placeholder.edit_text(text) - async def _query_api(self, model: str, messages: ChatMessages, system_prompt: str, **kwargs: Any) -> str: + @staticmethod + async def _query_api(provider: LLMProvider, messages: ChatMessages, system_prompt: str, **kwargs: Any) -> str: assert messages if messages[0]["role"] != "system" and system_prompt.strip(): messages.insert(0, {"role": "system", "content": system_prompt}) print( - model, + provider.model_name, "####", len(messages), "####", - self._crop_content(messages[-1]["content"]), + _crop_content(messages[-1]["content"]), ) casted_messages = [cast(ChatCompletionMessageParam, message) for message in messages] - provider = self.providers[model] chat_completion = await provider.api.chat.completions.create( model=provider.model_name, messages=casted_messages, **kwargs ) @@ -915,13 +943,13 @@ async def _query_api(self, model: str, messages: ChatMessages, system_prompt: st assert isinstance(chat_completion.choices[0].message.content, str), str(chat_completion) answer: str = chat_completion.choices[0].message.content print( - model, + provider.model_name, "####", len(messages), "####", - self._crop_content(messages[-1]["content"]), + _crop_content(messages[-1]["content"]), "####", - self._crop_content(answer), + _crop_content(answer), ) return answer @@ -1052,11 +1080,6 @@ def _fix_broken_tool_calls(messages: ChatMessages) -> ChatMessages: def _get_user_name(self, user: User) -> str: return str(user.full_name) if user.full_name else str(user.username) - def _crop_content(self, content: str) -> str: - if isinstance(content, str): - return content.replace("\n", " ")[:40] - return IMAGE_PLACEHOLDER - def _is_image_content(self, content: Any) -> bool: return isinstance(content, list) and content[-1]["type"] == "image_url" @@ -1077,26 +1100,24 @@ def _replace_images(self, messages: ChatMessages) -> ChatMessages: return messages def _truncate_text(self, text: str) -> str: - if self.output_chunk_size and len(text) > self.output_chunk_size: - text = text[: self.output_chunk_size] + "... truncated" + if self.config.output_chunk_size and len(text) > self.config.output_chunk_size: + text = text[: self.config.output_chunk_size] + "... truncated" return text def main( - bot_token: str, + bot_config_path: str, providers_config_path: str, db_path: str, localization_config_path: str, - output_chunk_size: Optional[int] = 3500, characters_path: Optional[str] = None, tools_config_path: Optional[str] = None, yookassa_config_path: Optional[str] = None, ) -> None: bot = LlmBot( - bot_token=bot_token, + bot_config_path=bot_config_path, providers_config_path=providers_config_path, db_path=db_path, - output_chunk_size=output_chunk_size, characters_path=characters_path, tools_config_path=tools_config_path, yookassa_config_path=yookassa_config_path, diff --git a/src/database.py b/src/database.py index d84b9d4..d329884 100644 --- a/src/database.py +++ b/src/database.py @@ -175,6 +175,17 @@ def create_conv_id(self, user_id: int) -> str: session.commit() return conv_id + def get_user_id_by_conv_id(self, conv_id: str) -> int: + with self.Session() as session: + conv = ( + session.query(Conversation) + .filter(Conversation.conv_id == conv_id) + .order_by(Conversation.timestamp.desc()) + .first() + ) + assert conv + return conv.user_id + def get_current_conv_id(self, user_id: int) -> str: with self.Session() as session: conv = (