Skip to content

Commit

Permalink
v5.0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Apr 29, 2024
1 parent 361c8f0 commit e75816a
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 96 deletions.
80 changes: 69 additions & 11 deletions src/gui_tabs_database_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from utilities import check_preconditions_for_submit_question, my_cprint
from constants import CHAT_MODELS
import module_chat
from module_chat import cleanup_resources

class RefreshingComboBox(QComboBox):
def __init__(self, parent=None):
Expand Down Expand Up @@ -57,6 +58,7 @@ class LocalModelThread(QThread):
citationsSignal = Signal(str)
errorSignal = Signal(str)
finishedSignal = Signal()
modelLoadedSignal = Signal()

def __init__(self, user_question, chunks_only, selected_model, parent=None):
super(LocalModelThread, self).__init__(parent)
Expand Down Expand Up @@ -84,7 +86,18 @@ def run(self):
augmented_query = f"{prepend_string}\n\n---\n\n" + "\n\n---\n\n".join(contexts) + "\n\n-----\n\n" + self.user_question

model_function = getattr(module_chat, CHAT_MODELS[self.selected_model]['function'])
response = model_function(augmented_query, None, None)

# Check if model needs to be loaded or switched
if not hasattr(self.parent(), 'model') or not hasattr(self.parent(), 'tokenizer') or self.selected_model != getattr(self.parent(), 'current_model', None):
if hasattr(self.parent(), 'model') and hasattr(self.parent(), 'tokenizer'):
cleanup_resources(self.parent().model, self.parent().tokenizer)
self.parent().model, self.parent().tokenizer = None, None # Explicitly set to None

self.parent().model, self.parent().tokenizer = model_function(augmented_query)
self.parent().current_model = self.selected_model
self.modelLoadedSignal.emit()

response = model_function(augmented_query, self.parent().model, self.parent().tokenizer)
self.responseSignal.emit(response)

with open('chat_history.txt', 'w', encoding='utf-8') as f:
Expand All @@ -96,6 +109,10 @@ def run(self):
self.finishedSignal.emit()
except Exception as err:
self.errorSignal.emit(str(err))
if hasattr(self.parent(), 'model') and hasattr(self.parent(), 'tokenizer'):
cleanup_resources(self.parent().model, self.parent().tokenizer)
self.parent().model = None
self.parent().tokenizer = None

def format_contexts_and_metadata(self, contexts, metadata_list):
formatted_contexts = []
Expand Down Expand Up @@ -140,27 +157,35 @@ def initWidgets(self):
self.chunks_only_checkbox.setToolTip(CHUNKS_ONLY_TOOLTIP)
hbox1_layout.addWidget(self.chunks_only_checkbox)

disabled_style = "QPushButton:disabled { color: #707070; }"
self.eject_button = QPushButton("Eject Model")
self.eject_button.clicked.connect(self.on_eject_button_clicked)
self.eject_button.setDisabled(True)
self.eject_button.setStyleSheet(disabled_style)
hbox1_layout.addWidget(self.eject_button)

self.model_combo_box = QComboBox()
self.model_combo_box.addItems(model_info['model'] for model_info in CHAT_MODELS.values())
self.model_combo_box.setCurrentText("Neural-Chat - 7b") # Set the default model
self.model_combo_box.setCurrentText("Neural-Chat - 7b") # default model
hbox1_layout.addWidget(self.model_combo_box)

self.local_model_radio = QRadioButton("Local Model")
self.lm_studio_radio = QRadioButton("LM Studio")
self.local_model_radio.toggled.connect(self.handle_model_switch)
self.lm_studio_radio.toggled.connect(self.handle_model_switch)
self.lm_studio_radio.setChecked(True)
self.lm_studio_radio.toggled.connect(lambda checked: not checked and self.local_model_radio.setChecked(True))
self.local_model_radio.toggled.connect(lambda checked: not checked and self.lm_studio_radio.setChecked(True))
self.lm_studio_radio.toggled.connect(lambda checked: self.lm_studio_radio_unchecked(checked))
self.local_model_radio.toggled.connect(lambda checked: self.local_model_radio_unchecked(checked))
hbox1_layout.addWidget(self.local_model_radio)
hbox1_layout.addWidget(self.lm_studio_radio)

# Disable the Local Model radio button and model combo box if CUDA is not available
# Disable widgets if CUDA is not available
if not torch.cuda.is_available():
self.local_model_radio.setEnabled(False)
self.model_combo_box.setEnabled(False)
tooltip_text = "The Local Model option is only supported with GPU acceleration."
self.local_model_radio.setToolTip(tooltip_text)
self.model_combo_box.setToolTip(tooltip_text)
# Applying style sheet to grey out disabled widgets
disabled_style = "QRadioButton:disabled, QComboBox:disabled { color: #707070; }"
self.local_model_radio.setStyleSheet(disabled_style)
self.model_combo_box.setStyleSheet(disabled_style)
Expand Down Expand Up @@ -194,7 +219,27 @@ def initWidgets(self):
self.is_recording = False
self.voice_recorder = VoiceRecorder(self)


def lm_studio_radio_unchecked(self, checked):
if not checked:
self.local_model_radio.setChecked(True)

def local_model_radio_unchecked(self, checked):
if not checked:
self.lm_studio_radio.setChecked(True)

def handle_model_switch(self, checked):
if checked:
if self.sender() == self.lm_studio_radio:
# If switching to LM Studio, check if Local Model was previously active and cleanup
self.attempt_cleanup()
# If switching to Local Model, setup or other necessary actions can be handled here

def attempt_cleanup(self):
if hasattr(self, 'submit_thread') and hasattr(self.submit_thread, 'parent'):
cleanup_resources(self.submit_thread.parent().model, self.submit_thread.parent().tokenizer)
self.submit_thread.parent().model, self.submit_thread.parent().tokenizer = None, None
self.eject_button.setDisabled(True)

def on_database_selected(self, index):
selected_database = self.database_pulldown.itemText(index)
self.update_config_selected_database(selected_database)
Expand Down Expand Up @@ -241,8 +286,19 @@ def on_submit_button_clicked(self):
self.submit_thread.citationsSignal.connect(self.display_citations_in_widget)
self.submit_thread.errorSignal.connect(self.show_error_message)
self.submit_thread.finishedSignal.connect(self.on_submission_finished)
self.submit_thread.modelLoadedSignal.connect(self.on_model_loaded)
self.submit_thread.start()

def on_eject_button_clicked(self):
if hasattr(self, 'submit_thread') and hasattr(self.submit_thread, 'parent'):
cleanup_resources(self.submit_thread.parent().model, self.submit_thread.parent().tokenizer)
self.submit_thread.parent().model = None
self.submit_thread.parent().tokenizer = None
self.eject_button.setDisabled(True)

def on_model_loaded(self):
self.eject_button.setEnabled(True)

def display_citations_in_widget(self, citations):
if citations:
self.read_only_text.append("\n\nCitations:\n" + citations)
Expand All @@ -259,7 +315,6 @@ def on_copy_response_clicked(self):
QMessageBox.warning(self, "Warning", "No response to copy.")

def on_bark_button_clicked(self):
# Check if PyTorch with CUDA is available
if not torch.cuda.is_available():
QMessageBox.warning(self, "Error", "Text to Speech is currently only supported with GPU acceleration.")
return
Expand Down Expand Up @@ -314,14 +369,17 @@ def display_citations(self, citations):
if citations:
QMessageBox.information(self, "Citations", f"The following sources were used:\n\n{citations}")
else:
QMessageBox.information(self, "Citations", "No citations found.")
QMessageBox.information(this, "Citations", "No citations found.")

def show_error_message(self, error_message):
QMessageBox.warning(self, "Error", error_message)
self.submit_button.setDisabled(False)
this.submit_button.setDisabled(False)

def on_submission_finished(self):
self.submit_button.setDisabled(False)
if self.lm_studio_radio.isChecked() or self.model_combo_box.currentText() != self.submit_thread.selected_model:
self.submit_thread.request_model_cleanup()
self.eject_button.setDisabled(True)

def update_transcription(self, transcription_text):
self.text_input.setPlainText(transcription_text)
self.text_input.setPlainText(transcription_text)
Loading

0 comments on commit e75816a

Please sign in to comment.