Skip to content

Commit

Permalink
more robust validation of settings
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Oct 25, 2024
1 parent be3afeb commit e18706c
Showing 1 changed file with 259 additions and 50 deletions.
309 changes: 259 additions & 50 deletions src/gui_tabs_settings_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,69 @@
from pathlib import Path

from PySide6.QtGui import QIntValidator, QDoubleValidator
from PySide6.QtWidgets import (QWidget, QLabel, QLineEdit, QGridLayout, QMessageBox, QSizePolicy, QCheckBox, QComboBox)
from PySide6.QtWidgets import (QWidget, QLabel, QLineEdit, QGridLayout, QMessageBox, QSizePolicy, QCheckBox, QComboBox, QMessageBox)

from constants import PROMPT_FORMATS, TOOLTIPS

class ServerSettingsTab(QWidget):
def __init__(self):
super(ServerSettingsTab, self).__init__()

with open('config.yaml', 'r') as file:
self.config_data = yaml.safe_load(file)
self.connection_str = self.config_data.get('server', {}).get('connection_str', '')
self.current_port = self.connection_str.split(":")[-1].split("/")[0]
self.current_max_tokens = self.config_data.get('server', {}).get('model_max_tokens', '')
self.current_temperature = self.config_data.get('server', {}).get('model_temperature', '')
self.current_prefix = self.config_data.get('server', {}).get('prefix', '')
self.current_suffix = self.config_data.get('server', {}).get('suffix', '')
self.prompt_format_disabled = self.config_data.get('server', {}).get('prompt_format_disabled', False)
try:
with open('config.yaml', 'r', encoding='utf-8') as file:
self.config_data = yaml.safe_load(file)
self.server_config = self.config_data.get('server', {})
self.connection_str = self.server_config.get('connection_str', '')
if ':' in self.connection_str and '/' in self.connection_str:
self.current_port = self.connection_str.split(":")[-1].split("/")[0]
else:
self.current_port = ''
self.current_max_tokens = self.server_config.get('model_max_tokens', '')
self.current_temperature = self.server_config.get('model_temperature', '')
self.current_prefix = self.server_config.get('prefix', '')
self.current_suffix = self.server_config.get('suffix', '')
self.prompt_format_disabled = self.server_config.get('prompt_format_disabled', False)
except Exception as e:
QMessageBox.critical(
self,
"Error Loading Configuration",
f"An error occurred while loading the configuration: {e}"
)
self.server_config = {}
self.connection_str = ''
self.current_port = ''
self.current_max_tokens = ''
self.current_temperature = ''
self.current_prefix = ''
self.current_suffix = ''
self.prompt_format_disabled = False

settings_dict = {
'port': {"placeholder": "Enter new port...", "validator": QIntValidator(), "current": self.current_port},
'max_tokens': {"placeholder": "Enter new max tokens...", "validator": QIntValidator(), "current": self.current_max_tokens},
'temperature': {"placeholder": "Enter new model temperature...", "validator": QDoubleValidator(), "current": self.current_temperature},
'prefix': {"placeholder": "Enter new prefix...", "validator": None, "current": self.current_prefix},
'suffix': {"placeholder": "Enter new suffix...", "validator": None, "current": self.current_suffix}
'port': {
"placeholder": "Port...",
"validator": QIntValidator(1, 65535),
"current": self.current_port
},
'max_tokens': {
"placeholder": "Max tokens (-1 or ≥25)...",
"validator": QIntValidator(-1, 1000000),
"current": self.current_max_tokens
},
'temperature': {
"placeholder": "Temperature (0.0 - 2.0)...",
"validator": QDoubleValidator(0.0, 2.0, 4),
"current": self.current_temperature
},
'prefix': {
"placeholder": "Prefix...",
"validator": None,
"current": self.current_prefix
},
'suffix': {
"placeholder": "Suffix...",
"validator": None,
"current": self.current_suffix
}
}

self.widgets = {}
Expand Down Expand Up @@ -59,15 +98,17 @@ def __init__(self):
prompt_format_label = QLabel("Prompt Format:")
prompt_format_label.setToolTip(TOOLTIPS["PREFIX_SUFFIX"])
layout.addWidget(prompt_format_label, 2, 0)

self.prompt_format_combobox = QComboBox()
self.prompt_format_combobox.addItems(["", "ChatML", "Llama2/Mistral", "Neural Chat/SOLAR", "Orca2", "StableLM-Zephyr"])
self.prompt_format_combobox.addItems([
"", "ChatML", "Llama2/Mistral", "Neural Chat/SOLAR", "Orca2", "StableLM-Zephyr"
])
self.prompt_format_combobox.setToolTip(TOOLTIPS["PREFIX_SUFFIX"])
layout.addWidget(self.prompt_format_combobox, 2, 1)
self.prompt_format_combobox.currentIndexChanged.connect(self.update_prefix_suffix)

# Disable
disable_label = QLabel("Disable:")
# Disable Prompt Formatting
disable_label = QLabel("Disable Prompt Formatting:")
disable_label.setToolTip(TOOLTIPS["DISABLE_PROMPT_FORMATTING"])
layout.addWidget(disable_label, 2, 2)

Expand Down Expand Up @@ -95,7 +136,8 @@ def __init__(self):
self.setLayout(layout)

def create_label(self, setting, settings_dict):
label = QLabel(f"{setting.capitalize()}: {settings_dict[setting]['current']}")
label_text = f"{setting.replace('_', ' ').capitalize()}: {settings_dict[setting]['current']}"
label = QLabel(label_text)
self.widgets[setting] = {"label": label}
return label

Expand All @@ -113,9 +155,9 @@ def create_edit(self, setting, settings_dict):
return edit

def refresh_labels(self):
self.widgets['prefix']['label'].setText(f"Prefix: {self.config_data.get('server', {}).get('prefix', '')}")
self.widgets['suffix']['label'].setText(f"Suffix: {self.config_data.get('server', {}).get('suffix', '')}")
self.widgets['prefix']['label'].setText(f"Prefix: {self.server_config.get('prefix', '')}")
self.widgets['suffix']['label'].setText(f"Suffix: {self.server_config.get('suffix', '')}")

def update_prefix_suffix(self, index):
option = self.prompt_format_combobox.currentText()
if option in PROMPT_FORMATS:
Expand All @@ -131,38 +173,205 @@ def update_config(self):
config_file_path = Path('config.yaml')
if config_file_path.exists():
try:
with config_file_path.open('r') as file:
with config_file_path.open('r', encoding='utf-8') as file:
config_data = yaml.safe_load(file)
self.server_config = config_data.get('server', {})
except Exception as e:
config_data = {}

updated = False
for setting, widget in self.widgets.items():
new_value = widget['edit'].text()
if new_value:
updated = True
if setting == 'port':
config_data['server']['connection_str'] = self.connection_str.replace(self.current_port, new_value)
self.widgets['port']['label'].setText(f"Port: {new_value}")
elif setting == 'max_tokens':
config_data['server']['model_max_tokens'] = int(new_value)
self.widgets['max_tokens']['label'].setText(f"Max Tokens: {new_value}")
elif setting == 'temperature':
config_data['server']['model_temperature'] = float(new_value)
self.widgets['temperature']['label'].setText(f"Temperature: {new_value}")
QMessageBox.critical(
self,
"Error Loading Configuration",
f"An error occurred while loading the configuration: {e}"
)
return False
else:
QMessageBox.critical(
self,
"Configuration File Missing",
"The configuration file 'config.yaml' does not exist."
)
return False

settings_changed = False
errors = []

new_query_device = self.prompt_format_combobox.currentText()
device_changed = new_query_device != self.server_config.get('database_query', '')

new_port_text = self.widgets['port']['edit'].text().strip()
if new_port_text:
try:
new_port = int(new_port_text)
if not (1 <= new_port <= 65535):
raise ValueError("Port must be between 1 and 65535.")
except ValueError:
errors.append("Port must be an integer between 1 and 65535.")
else:
new_port = self.current_port

new_max_tokens_text = self.widgets['max_tokens']['edit'].text().strip()
if new_max_tokens_text:
try:
if new_max_tokens_text == "-1":
new_max_tokens = -1
else:
config_data['server'][setting] = new_value
self.widgets[setting]['label'].setText(f"{setting.capitalize()}: {new_value}")
new_max_tokens = int(new_max_tokens_text)
if new_max_tokens < 25:
raise ValueError("Max tokens must be -1 or an integer ≥25.")
except ValueError:
errors.append("Max tokens must be -1 or an integer ≥25.")
else:
new_max_tokens = self.current_max_tokens

new_temperature_text = self.widgets['temperature']['edit'].text().strip()
if new_temperature_text:
try:
new_temperature = float(new_temperature_text)
if not (0.0 <= new_temperature <= 2.0):
raise ValueError("Temperature must be between 0.0 and 2.0.")
except ValueError:
errors.append("Temperature must be a number between 0.0 and 2.0.")
else:
new_temperature = self.current_temperature

new_prefix = self.widgets['prefix']['edit'].text().strip()
new_suffix = self.widgets['suffix']['edit'].text().strip()

new_prompt_format_disabled = self.disable_checkbox.isChecked()

if errors:
error_message = "\n".join(errors)
QMessageBox.warning(
self,
"Invalid Input",
f"The following errors occurred:\n{error_message}"
)
return False

if device_changed:
config_data['Compute_Device']['database_query'] = new_query_device
settings_changed = True

if new_port_text and new_port != self.current_port:
if ':' in self.connection_str and '/' in self.connection_str:
new_connection_str = self.connection_str.replace(self.current_port, str(new_port))
config_data['server']['connection_str'] = new_connection_str
settings_changed = True
else:
QMessageBox.warning(
self,
"Invalid Connection String",
"The existing connection string format is invalid. Unable to update port."
)
return False

if new_max_tokens_text and new_max_tokens != self.server_config.get('model_max_tokens', 0):
config_data['server']['model_max_tokens'] = new_max_tokens
settings_changed = True

if new_temperature_text and new_temperature != self.server_config.get('model_temperature', 0.0):
config_data['server']['model_temperature'] = new_temperature
settings_changed = True

widget['edit'].clear()
if new_prefix and new_prefix != self.server_config.get('prefix', ''):
config_data['server']['prefix'] = new_prefix
settings_changed = True

checkbox_state = self.disable_checkbox.isChecked()
if checkbox_state != config_data.get('server', {}).get('prompt_format_disabled', False):
config_data['server']['prompt_format_disabled'] = checkbox_state
updated = True
if new_suffix and new_suffix != self.server_config.get('suffix', ''):
config_data['server']['suffix'] = new_suffix
settings_changed = True

if updated:
with config_file_path.open('w') as file:
if new_prompt_format_disabled != self.server_config.get('prompt_format_disabled', False):
config_data['server']['prompt_format_disabled'] = new_prompt_format_disabled
settings_changed = True

if settings_changed:
try:
with config_file_path.open('w', encoding='utf-8') as file:
yaml.safe_dump(config_data, file)
except Exception as e:
QMessageBox.critical(
self,
"Error Saving Configuration",
f"An error occurred while saving the configuration: {e}"
)
return False

if device_changed:
self.server_config['database_query'] = new_query_device
self.widgets['port']['label'].setText(f"Device: {new_query_device}")

if new_port_text:
self.connection_str = config_data['server']['connection_str']
self.current_port = str(new_port)
self.widgets['port']['label'].setText(f"Port: {new_port}")

if new_max_tokens_text:
self.server_config['model_max_tokens'] = new_max_tokens
self.widgets['max_tokens']['label'].setText(f"Max Tokens: {new_max_tokens}")

if new_temperature_text:
self.server_config['model_temperature'] = new_temperature
self.widgets['temperature']['label'].setText(f"Temperature: {new_temperature}")

if new_prefix:
self.server_config['prefix'] = new_prefix
self.widgets['prefix']['label'].setText(f"Prefix: {new_prefix}")

if new_suffix:
self.server_config['suffix'] = new_suffix
self.widgets['suffix']['label'].setText(f"Suffix: {new_suffix}")

if new_prompt_format_disabled != self.prompt_format_disabled:
self.prompt_format_disabled = new_prompt_format_disabled

self.widgets['port']['edit'].clear()
self.widgets['max_tokens']['edit'].clear()
self.widgets['temperature']['edit'].clear()
self.widgets['prefix']['edit'].clear()
self.widgets['suffix']['edit'].clear()

return settings_changed

def reset_search_term(self):
config_file_path = Path('config.yaml')
if config_file_path.exists():
try:
with config_file_path.open('r', encoding='utf-8') as file:
config_data = yaml.safe_load(file)
self.server_config = config_data.get('server', {})
except Exception as e:
QMessageBox.critical(
self,
"Error Loading Configuration",
f"An error occurred while loading the configuration: {e}"
)
return
else:
QMessageBox.critical(
self,
"Configuration File Missing",
"The configuration file 'config.yaml' does not exist."
)
return

config_data['server']['search_term'] = ''

try:
with config_file_path.open('w', encoding='utf-8') as file:
yaml.safe_dump(config_data, file)
except Exception as e:
QMessageBox.critical(
self,
"Error Saving Configuration",
f"An error occurred while saving the configuration: {e}"
)
return

return updated
self.server_config['search_term'] = ''
self.widgets['prefix']['label'].setText(f"Prefix: {self.server_config.get('prefix', '')}")
self.widgets['suffix']['label'].setText(f"Suffix: {self.server_config.get('suffix', '')}")
self.widgets['port']['label'].setText(f"Port: {self.current_port}")
self.widgets['max_tokens']['label'].setText(f"Max Tokens: {self.current_max_tokens}")
self.widgets['temperature']['label'].setText(f"Temperature: {self.current_temperature}")
self.widgets['prefix']['edit'].clear()
self.widgets['suffix']['edit'].clear()

0 comments on commit e18706c

Please sign in to comment.