Skip to content

Commit

Permalink
Bug fixes and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaGusev committed Jul 8, 2024
1 parent 4dc8ff9 commit 7665f7d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 53 deletions.
12 changes: 12 additions & 0 deletions configs/bot.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"token": "",
"admin_user_name": "",
"admin_user_id": -1,
"temperature_range": [0.0, 0.5, 0.8, 1.0, 1.2],
"top_p_range": [0.8, 0.9, 0.95, 0.98, 1.0],
"timezone": "Europe/Moscow",
"sub_price_rub": 500,
"sub_price_stars": 250,
"sub_duration": 604800,
"output_chunk_size": 3500
}
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash
set -euo pipefail

python3 -m src.bot --bot-token $1 --providers-config-path configs/providers.json --db-path $2 --characters-path configs/characters.json --tools-config-path configs/tools.json --yookassa-config-path configs/yookassa.json --localization-config-path configs/localization.json
python3 -m src.bot --bot-config-path $1 --providers-config-path configs/providers.json --db-path $2 --characters-path configs/characters.json --tools-config-path configs/tools.json --yookassa-config-path configs/yookassa.json --localization-config-path configs/localization.json
125 changes: 73 additions & 52 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
from email.utils import parseaddr
from typing import cast, List, Dict, Any, Optional, Union, Callable, Coroutine, Tuple, BinaryIO
from dataclasses import dataclass

import fire # type: ignore
import tiktoken
Expand Down Expand Up @@ -43,33 +44,48 @@
from src.document_loader import DocumentLoader


TEMPERATURE_RANGE = (0.0, 0.5, 0.8, 1.0, 1.2)
TOP_P_RANGE = (0.8, 0.9, 0.95, 0.98, 1.0)
DALLE_DAILY_LIMIT = 5
ADMIN_USERNAME = "YallenGusev"
IMAGE_PLACEHOLDER = "<image_placeholder>"
TIMEZONE = "Europe/Moscow"

SUB_PRICE_RUB = 500
SUB_PRICE_STARS = 250
SUB_DURATION = 7 * 86400

ChatMessage = Dict[str, Any]
ChatMessages = List[ChatMessage]


@dataclass
class BotConfig:
token: str
admin_user_name: str
admin_user_id: int
temperature_range: List[float]
top_p_range: List[float]
timezone: str = "Europe/Moscow"
sub_price_rub: int = 500
sub_price_stars: int = 250
sub_duration: int = 7 * 86400
output_chunk_size: int = 3500


def _crop_content(content: str) -> str:
if isinstance(content, str):
return content.replace("\n", " ")[:40]
return IMAGE_PLACEHOLDER


class LlmBot:
def __init__(
self,
bot_token: str,
providers_config_path: str,
db_path: str,
bot_config_path: str,
localization_config_path: str,
output_chunk_size: Optional[int],
characters_path: Optional[str],
tools_config_path: Optional[str],
yookassa_config_path: Optional[str],
):
assert os.path.exists(bot_config_path)
with open(bot_config_path) as r:
self.config = BotConfig(**json.load(r))

self.providers: Dict[str, LLMProvider] = dict()
with open(providers_config_path) as r:
providers_config = json.load(r)
Expand All @@ -91,8 +107,6 @@ def __init__(
with open(characters_path) as r:
self.characters = json.load(r)

self.output_chunk_size = output_chunk_size

self.db = Database(db_path)

self.document_loader = DocumentLoader()
Expand All @@ -112,17 +126,17 @@ def __init__(
self.likes_kb.add(InlineKeyboardButton(text="👎", callback_data="feedback:dislike"))

self.temperature_kb = InlineKeyboardBuilder()
for value in TEMPERATURE_RANGE:
for value in self.config.temperature_range:
self.temperature_kb.add(InlineKeyboardButton(text=str(value), callback_data=f"settemperature:{value}"))

self.top_p_kb = InlineKeyboardBuilder()
for value in TOP_P_RANGE:
for value in self.config.top_p_range:
self.top_p_kb.add(InlineKeyboardButton(text=str(value), callback_data=f"settopp:{value}"))

self.buy_kb = InlineKeyboardBuilder()
self.buy_kb.add(InlineKeyboardButton(text=self.localization.BUY_WITH_STARS, callback_data="buy:stars"))

self.bot = Bot(token=bot_token, default=DefaultBotProperties(parse_mode=None))
self.bot = Bot(token=self.config.token, default=DefaultBotProperties(parse_mode=None))
self.bot_info: Optional[User] = None

self.dp = Dispatcher()
Expand Down Expand Up @@ -178,7 +192,7 @@ def __init__(
self.yookassa = YookassaHandler(**config)

async def start_polling(self) -> None:
self.scheduler = AsyncIOScheduler(timezone=TIMEZONE)
self.scheduler = AsyncIOScheduler(timezone=self.config.timezone)
if self.yookassa is not None:
self.scheduler.add_job(self.yookassa_check_payments, trigger="interval", seconds=30)
self.scheduler.start()
Expand All @@ -200,7 +214,10 @@ async def start(self, message: Message) -> None:
limits = {name: provider.limits for name, provider in self.providers.items()}
sub_limits = self.localization.LIMITS.render(limits=limits, mode=mode).strip()
content = self.localization.HELP.render(
model=model, message_count=remaining_count, sub_limits=sub_limits, admin_username=ADMIN_USERNAME
model=model,
message_count=remaining_count,
sub_limits=sub_limits,
admin_username=self.config.admin_user_name,
)
await message.reply(content, parse_mode=ParseMode.MARKDOWN)

Expand Down Expand Up @@ -294,10 +311,23 @@ def _merge_messages(messages: ChatMessages) -> ChatMessages:
role = m["role"]
if role == prev_role and role != "tool":
is_current_str = isinstance(content, str)
is_prev_str = isinstance(new_messages[-1]["content"], str)
is_current_list = isinstance(content, list)
prev_content = new_messages[-1]["content"]
is_prev_str = isinstance(prev_content, str)
is_prev_list = isinstance(prev_content, list)
if is_current_str and is_prev_str:
new_messages[-1]["content"] += "\n\n" + content
continue
elif is_current_str and is_prev_list:
prev_content.append({"type": "text", "text": content})
continue
elif is_prev_str and is_current_list:
content.insert(0, {"type": "text", "text": prev_content})
new_messages[-1]["content"] = content
continue
elif is_current_list and is_prev_list:
prev_content.extend(content)
continue
prev_role = role
new_messages.append(m)
return new_messages
Expand Down Expand Up @@ -465,7 +495,7 @@ async def sub_buy(self, message: Message) -> None:

limits = {name: provider.limits for name, provider in self.providers.items()}
sub_limits = self.localization.LIMITS.render(limits=limits, mode="subscribed").strip()
description = self.localization.SUB_DESCRIPTION.render(sub_limits=sub_limits, price=SUB_PRICE_RUB)
description = self.localization.SUB_DESCRIPTION.render(sub_limits=sub_limits, price=self.config.sub_price_rub)
await message.reply(description, parse_mode=ParseMode.MARKDOWN, reply_markup=self.buy_kb.as_markup())

async def stars_sub_buy_proceed(self, callback: CallbackQuery) -> None:
Expand All @@ -490,7 +520,7 @@ async def stars_sub_buy_proceed(self, callback: CallbackQuery) -> None:
chat_id,
title=title,
description=description,
prices=[LabeledPrice(label=title, amount=SUB_PRICE_STARS)],
prices=[LabeledPrice(label=title, amount=self.config.sub_price_stars)],
provider_token="",
currency="XTR",
payload=str(user_id),
Expand All @@ -515,7 +545,7 @@ async def successful_payment_handler(self, message: Message) -> None:
charge_id = successful_payment.telegram_payment_charge_id
self.db.add_charge(user_id, charge_id)
assert user_id == int(payload)
self.db.subscribe_user(user_id, SUB_DURATION)
self.db.subscribe_user(user_id, self.config.sub_duration)
await self.bot.send_message(chat_id, self.localization.SUB_SUCCESS)

async def yookassa_sub_buy_proceed(self, callback: CallbackQuery) -> None:
Expand Down Expand Up @@ -545,7 +575,7 @@ async def yookassa_sub_buy_proceed(self, callback: CallbackQuery) -> None:
assert self.bot_info
assert self.bot_info.username
payment_data = self.yookassa.create_payment(
SUB_PRICE_RUB, title, email=email, bot_username=self.bot_info.username
self.config.sub_price_rub, title, email=email, bot_username=self.bot_info.username
)
payment_id = payment_data["id"]
try:
Expand All @@ -569,7 +599,7 @@ async def yookassa_check_payments(self) -> None:
payment_id=payment.payment_id, status=status, internal_status=payment.internal_status
)
if status == YookassaStatus.SUCCEEDED:
self.db.subscribe_user(payment.user_id, SUB_DURATION)
self.db.subscribe_user(payment.user_id, self.config.sub_duration)
await self.bot.send_message(chat_id=payment.chat_id, text=self.localization.SUB_SUCCESS)
self.db.set_payment_status(payment.payment_id, status=status.value, internal_status="completed")
elif status == YookassaStatus.CANCELED:
Expand Down Expand Up @@ -813,19 +843,17 @@ async def generate(self, message: Message) -> None:
await message.reply(self.localization.LIMIT_EXCEEDED.format(model=model))
return

conv_id = self.db.get_current_conv_id(chat_id)
history = self.db.fetch_conversation(conv_id)
params = self.db.get_parameters(chat_id)
params = provider.params if params is None else params
system_prompt = self.db.get_system_prompt(chat_id)
system_prompt = provider.system_prompt if system_prompt is None else system_prompt
content = await self._build_content(message)

if "claude" in model and params["temperature"] > 1.0:
await message.reply(self.localization.CLAUDE_HIGH_TEMPERATURE)
return

conv_id = self.db.get_current_conv_id(chat_id)
history = self.db.fetch_conversation(conv_id)
system_prompt = self.db.get_system_prompt(chat_id)
if system_prompt is None:
system_prompt = provider.system_prompt

content = await self._build_content(message)
if not isinstance(content, str) and not provider.can_handle_images:
await message.reply(self.localization.CONTENT_NOT_SUPPORTED_BY_MODEL)
return
Expand Down Expand Up @@ -859,9 +887,9 @@ async def generate(self, message: Message) -> None:
history = self._fix_broken_tool_calls(history)
if tools and "gpt" not in model:
params["tools"] = tools
answer = await self._query_api(model=model, messages=history, system_prompt=system_prompt, **params)
answer = await self._query_api(provider=provider, messages=history, system_prompt=system_prompt, **params)

output_chunk_size = self.output_chunk_size
output_chunk_size = self.config.output_chunk_size
if output_chunk_size is not None:
answer_parts = [answer[i : i + output_chunk_size] for i in range(0, len(answer), output_chunk_size)]
else:
Expand Down Expand Up @@ -890,23 +918,23 @@ async def generate(self, message: Message) -> None:

except Exception:
traceback.print_exc()
text = self.localization.ERROR.format(admin_username=ADMIN_USERNAME, chat_id=chat_id)
text = self.localization.ERROR.format(admin_username=self.config.admin_user_name, chat_id=chat_id)
await placeholder.edit_text(text)

async def _query_api(self, model: str, messages: ChatMessages, system_prompt: str, **kwargs: Any) -> str:
@staticmethod
async def _query_api(provider: LLMProvider, messages: ChatMessages, system_prompt: str, **kwargs: Any) -> str:
assert messages
if messages[0]["role"] != "system" and system_prompt.strip():
messages.insert(0, {"role": "system", "content": system_prompt})

print(
model,
provider.model_name,
"####",
len(messages),
"####",
self._crop_content(messages[-1]["content"]),
_crop_content(messages[-1]["content"]),
)
casted_messages = [cast(ChatCompletionMessageParam, message) for message in messages]
provider = self.providers[model]
chat_completion = await provider.api.chat.completions.create(
model=provider.model_name, messages=casted_messages, **kwargs
)
Expand All @@ -915,13 +943,13 @@ async def _query_api(self, model: str, messages: ChatMessages, system_prompt: st
assert isinstance(chat_completion.choices[0].message.content, str), str(chat_completion)
answer: str = chat_completion.choices[0].message.content
print(
model,
provider.model_name,
"####",
len(messages),
"####",
self._crop_content(messages[-1]["content"]),
_crop_content(messages[-1]["content"]),
"####",
self._crop_content(answer),
_crop_content(answer),
)
return answer

Expand Down Expand Up @@ -1052,11 +1080,6 @@ def _fix_broken_tool_calls(messages: ChatMessages) -> ChatMessages:
def _get_user_name(self, user: User) -> str:
return str(user.full_name) if user.full_name else str(user.username)

def _crop_content(self, content: str) -> str:
if isinstance(content, str):
return content.replace("\n", " ")[:40]
return IMAGE_PLACEHOLDER

def _is_image_content(self, content: Any) -> bool:
return isinstance(content, list) and content[-1]["type"] == "image_url"

Expand All @@ -1077,26 +1100,24 @@ def _replace_images(self, messages: ChatMessages) -> ChatMessages:
return messages

def _truncate_text(self, text: str) -> str:
if self.output_chunk_size and len(text) > self.output_chunk_size:
text = text[: self.output_chunk_size] + "... truncated"
if self.config.output_chunk_size and len(text) > self.config.output_chunk_size:
text = text[: self.config.output_chunk_size] + "... truncated"
return text


def main(
bot_token: str,
bot_config_path: str,
providers_config_path: str,
db_path: str,
localization_config_path: str,
output_chunk_size: Optional[int] = 3500,
characters_path: Optional[str] = None,
tools_config_path: Optional[str] = None,
yookassa_config_path: Optional[str] = None,
) -> None:
bot = LlmBot(
bot_token=bot_token,
bot_config_path=bot_config_path,
providers_config_path=providers_config_path,
db_path=db_path,
output_chunk_size=output_chunk_size,
characters_path=characters_path,
tools_config_path=tools_config_path,
yookassa_config_path=yookassa_config_path,
Expand Down
11 changes: 11 additions & 0 deletions src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,17 @@ def create_conv_id(self, user_id: int) -> str:
session.commit()
return conv_id

def get_user_id_by_conv_id(self, conv_id: str) -> int:
with self.Session() as session:
conv = (
session.query(Conversation)
.filter(Conversation.conv_id == conv_id)
.order_by(Conversation.timestamp.desc())
.first()
)
assert conv
return conv.user_id

def get_current_conv_id(self, user_id: int) -> str:
with self.Session() as session:
conv = (
Expand Down

0 comments on commit 7665f7d

Please sign in to comment.