Skip to content

Commit

Permalink
v3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Jan 10, 2024
1 parent 4ae20f8 commit b46bb5e
Show file tree
Hide file tree
Showing 12 changed files with 351 additions and 119 deletions.
16 changes: 0 additions & 16 deletions src/choose_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,6 @@ def choose_documents_directory():
if user_choice == QDialog.Rejected:
return

def see_documents_directory():
current_dir = Path(__file__).parent.resolve()
docs_folder = current_dir / "Docs_for_DB"
images_folder = current_dir / "Images_for_DB"

os_name = platform.system()
if os_name == 'Windows':
subprocess.Popen(['explorer', str(docs_folder)])
subprocess.Popen(['explorer', str(images_folder)])
elif os_name == 'Darwin':
subprocess.Popen(['open', str(docs_folder)])
subprocess.Popen(['open', str(images_folder)])
elif os_name == 'Linux':
subprocess.Popen(['xdg-open', str(docs_folder)])
subprocess.Popen(['xdg-open', str(images_folder)])

if __name__ == '__main__':
app = QApplication([])
choose_documents_directory()
Expand Down
14 changes: 7 additions & 7 deletions src/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Compute_Device:
database_creation: cpu
database_query: cpu
gpu_brand:
EMBEDDING_MODEL_NAME:
EMBEDDING_MODEL_NAME: null
Platform_Info:
os:
Supported_CTranslate2_Quantizations:
Expand All @@ -23,12 +23,12 @@ Supported_CTranslate2_Quantizations:
bark:
enable_cpu_offload: false
model_precision: float32
size: normal
speaker: v2/en_speaker_7
size: small
speaker: v2/en_speaker_6
use_better_transformer: true
database:
chunk_overlap: 250
chunk_size: 800
chunk_size: 750
contexts: 6
similarity: 0.9
embedding-models:
Expand All @@ -50,7 +50,7 @@ server:
prefix_obsidian_3B: <|im_start|>user\n
prefix_orca2: <|im_start|>user
prefix_phi2: 'Instruct:'
prompt_format_disabled: true
prompt_format_disabled: false
suffix: '### Assistant:'
suffix_chat_ml: <|im_end|>
suffix_llama2_and_mistral: '[/INST]'
Expand All @@ -70,7 +70,7 @@ transcribe_file:
device: cpu
file: null
model: medium.en
quant: float16
quant: float32
timestamps: true
transcriber:
device: cpu
Expand Down Expand Up @@ -108,4 +108,4 @@ vision:
- float32
available_sizes:
- 470m
test_image:
test_image: null
8 changes: 7 additions & 1 deletion src/download_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from PySide6.QtCore import Qt
from PySide6.QtCore import Qt, QObject, Signal
from PySide6.QtWidgets import QDialog, QVBoxLayout, QHBoxLayout, QRadioButton, QPushButton, QButtonGroup, QLabel, QGridLayout
import subprocess
import threading
from pathlib import Path
from constants import AVAILABLE_MODELS

class ModelDownloadedSignal(QObject):
downloaded = Signal(str)

model_downloaded_signal = ModelDownloadedSignal()

class DownloadModelDialog(QDialog):
def __init__(self, parent=None):
super(DownloadModelDialog, self).__init__(parent)
Expand Down Expand Up @@ -105,6 +110,7 @@ def download_embedding_model(parent):
def download_model():
subprocess.run(["git", "clone", model_url, str(target_directory)])
print(f"{selected_model['model']} has been downloaded and is ready to use!")
model_downloaded_signal.downloaded.emit(selected_model['model'])

download_thread = threading.Thread(target=download_model)
download_thread.start()
81 changes: 9 additions & 72 deletions src/gui.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from PySide6.QtWidgets import (
QApplication, QWidget, QPushButton, QVBoxLayout, QTabWidget,
QTextEdit, QSplitter, QFrame, QStyleFactory, QLabel, QGridLayout, QMenuBar, QCheckBox, QHBoxLayout, QMessageBox
QApplication, QWidget, QVBoxLayout, QTabWidget, QTextEdit, QSplitter, QFrame,
QStyleFactory, QLabel, QGridLayout, QMenuBar, QCheckBox, QHBoxLayout, QMessageBox, QPushButton
)
from PySide6.QtCore import Qt, QTimer
import os
Expand All @@ -12,12 +12,8 @@
import threading
from initialize import main as initialize_system
from metrics_bar import MetricsBar
from download_model import download_embedding_model
from select_model import select_embedding_model_directory
from choose_documents import choose_documents_directory, see_documents_directory
import create_database
from gui_tabs import create_tabs
from gui_threads import CreateDatabaseThread, SubmitButtonThread
from gui_threads import SubmitButtonThread
import voice_recorder_module
from utilities import list_theme_files, make_theme_changer, load_stylesheet
from bark_module import BarkAudio
Expand All @@ -31,7 +27,6 @@ def __init__(self):
self.cumulative_response = ""
self.metrics_bar = MetricsBar()
self.compute_device = self.metrics_bar.determine_compute_device()
os_name = self.metrics_bar.get_os_name()
self.submit_button = None
self.init_ui()
self.load_config()
Expand All @@ -41,8 +36,7 @@ def is_nvidia_gpu(self):
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
return "nvidia" in gpu_name.lower()
return False


def load_config(self):
script_dir = Path(__file__).resolve().parent
config_path = os.path.join(script_dir, 'config.yaml')
Expand All @@ -59,25 +53,9 @@ def init_ui(self):
# LEFT FRAME
self.left_frame = QFrame()
grid_layout = QGridLayout()

tab_widget = create_tabs()
grid_layout.addWidget(tab_widget, 0, 0, 1, 2)

# Buttons data
button_data = [
("Download Embedding Model", lambda: download_embedding_model(self)),
("Choose Embedding Model Directory", select_embedding_model_directory),
("Choose Documents or Images", choose_documents_directory),
("See Currently Chosen Documents", see_documents_directory),
("Create Vector Database", self.on_create_button_clicked)
]
button_positions = [(1, 0), (1, 1), (2, 0), (2, 1), (3, 0)]

# Create and add buttons
for position, (text, handler) in zip(button_positions, button_data):
button = QPushButton(text)
button.clicked.connect(handler)
grid_layout.addWidget(button, *position)

self.left_frame.setLayout(grid_layout)
main_splitter.addWidget(self.left_frame)
Expand Down Expand Up @@ -111,15 +89,15 @@ def init_ui(self):
right_vbox.addLayout(checkbox_button_hbox)

# Create and add button row
button_row_widget = self.create_button_row(self.on_submit_button_clicked)
button_row_widget = self.create_button_row()
right_vbox.addWidget(button_row_widget)

right_frame.setLayout(right_vbox)
main_splitter.addWidget(right_frame)

main_layout = QVBoxLayout(self)
main_layout.addWidget(main_splitter)

# Metrics bar
main_layout.addWidget(self.metrics_bar)
self.metrics_bar.setMaximumHeight(75 if self.is_nvidia_gpu() else 30)
Expand All @@ -139,44 +117,6 @@ def resizeEvent(self, event):
self.left_frame.setMinimumWidth(self.width() * 0.3)
super().resizeEvent(event)

def on_create_button_clicked(self):
script_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(script_dir, 'config.yaml')
with open(config_path, 'r') as file:
config = yaml.safe_load(file)

if platform.system() == "Darwin" and any(images_dir.iterdir()):
QMessageBox.warning(self, "Error",
"Image processing has been disabled for MacOS for the time being until a fix can be implemented. Please remove all files from the 'Images_for_DB' folder and try again.")
return

embedding_model_name = config.get('EMBEDDING_MODEL_NAME')
if not embedding_model_name:
QMessageBox.warning(self, "Error",
"You must first download an embedding model, select it, and choose documents first before proceeding.")
return

documents_dir = Path(script_dir) / "Docs_for_DB"
images_dir = Path(script_dir) / "Images_for_DB"
if not any(documents_dir.iterdir()) and not any(images_dir.iterdir()):
QMessageBox.warning(self, "Error",
"No documents found to process. Please select files to add to the vector database and try again.")
return

# New check for compute device availability
compute_device = config.get('Compute_Device', {}).get('available', [])
database_creation = config.get('Compute_Device', {}).get('database_creation')

if ("cuda" in compute_device or "mps" in compute_device) and database_creation == "cpu":
reply = QMessageBox.question(self, 'Warning',
"GPU-acceleration is available and highly recommended for creating a vector database. Click OK to proceed or Cancel to go back and change the device.",
QMessageBox.Ok | QMessageBox.Cancel)
if reply == QMessageBox.Cancel:
return

self.create_database_thread = CreateDatabaseThread(self)
self.create_database_thread.start()

def on_submit_button_clicked(self):
script_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(script_dir, 'config.yaml')
Expand Down Expand Up @@ -238,14 +178,11 @@ def update_response(self, response):
self.read_only_text.setPlainText(self.cumulative_response)
self.submit_button.setDisabled(False)

def update_transcription(self, text):
self.text_input.setPlainText(text)

def closeEvent(self, event):
self.metrics_bar.stop_metrics_collector()
event.accept()

def create_button_row(self, submit_handler):
def create_button_row(self):
voice_recorder = voice_recorder_module.VoiceRecorder(self)

def start_recording():
Expand Down
16 changes: 16 additions & 0 deletions src/gui_tabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pathlib import Path
from gui_tabs_settings import GuiSettingsTab
from gui_tabs_tools import GuiSettingsTab as ToolsSettingsTab
from gui_tabs_databases import DatabasesTab
from gui_tabs_vector_models import VectorModelsTab

def load_url(view, url):
view.setUrl(QUrl.fromLocalFile(url))
Expand Down Expand Up @@ -41,8 +43,14 @@ def create_tabs():
user_guide_view = QWebEngineView()
user_guide_view.setHtml('<body style="background-color: #161b22;"></body>')

default_url = user_manual_folder / 'tips.html'
load_url(user_guide_view, str(default_url))

light_blue_style = "QPushButton { background-color: #3498db; color: white; }"

for button_name, html_file in buttons_dict.items():
button = QPushButton(button_name)
button.setStyleSheet(light_blue_style)
button_url = user_manual_folder / html_file
button.clicked.connect(partial(load_url, user_guide_view, str(button_url)))
menu_layout.addWidget(button)
Expand All @@ -60,6 +68,14 @@ def create_tabs():
tools_tab = ToolsSettingsTab()
tab_widget.addTab(tools_tab, 'Tools')

# DATABASES TAB
databases_tab = DatabasesTab()
tab_widget.addTab(databases_tab, 'Databases')

# VECTOR MODELS TAB
vector_models_tab = VectorModelsTab()
tab_widget.addTab(vector_models_tab, 'Vector Models')

return tab_widget

if __name__ == '__main__':
Expand Down
Loading

0 comments on commit b46bb5e

Please sign in to comment.