diff --git a/run_training.py b/run_training.py index f8be872..d57afaf 100644 --- a/run_training.py +++ b/run_training.py @@ -2,9 +2,10 @@ import os import pathlib import shutil -import tensorflow import sys +import tensorflow + from simpletextgenerator import training, jobs_util from simpletextgenerator.jobs_util import resource_path from simpletextgenerator.logging_setup import setup_logging @@ -19,7 +20,6 @@ except IndexError as ignored: pass - project_dir = "./projects" to_run_dir = "to_run" output_dir = "output" diff --git a/simple-text-generator-ui.py b/simple-text-generator-ui.py index f1126d8..2b43cc0 100644 --- a/simple-text-generator-ui.py +++ b/simple-text-generator-ui.py @@ -1,9 +1,9 @@ import hashlib import logging +import os import pathlib import subprocess import sys -import os from simpletextgenerator.logging_setup import setup_logging from simpletextgenerator.ui.menu import draw_main_menu diff --git a/simpletextgenerator/jobs_util.py b/simpletextgenerator/jobs_util.py index 5def4c5..6a53498 100644 --- a/simpletextgenerator/jobs_util.py +++ b/simpletextgenerator/jobs_util.py @@ -1,7 +1,8 @@ +import logging import os import sys + import yaml -import logging from simpletextgenerator.models import job diff --git a/simpletextgenerator/models/config.py b/simpletextgenerator/models/config.py index cdb0cb8..40682a1 100644 --- a/simpletextgenerator/models/config.py +++ b/simpletextgenerator/models/config.py @@ -1,8 +1,8 @@ +import logging from dataclasses import dataclass from typing import IO -import chevron -import logging +import chevron logger = logging.getLogger("ui") diff --git a/simpletextgenerator/models/job.py b/simpletextgenerator/models/job.py index 2a6fd9d..6501bf7 100644 --- a/simpletextgenerator/models/job.py +++ b/simpletextgenerator/models/job.py @@ -1,7 +1,7 @@ -from simpletextgenerator.models import config, state - import logging +from simpletextgenerator.models import config, state + logger = logging.getLogger("ui") diff --git a/simpletextgenerator/models/state.py b/simpletextgenerator/models/state.py index b2600b1..62cbdef 100644 --- a/simpletextgenerator/models/state.py +++ b/simpletextgenerator/models/state.py @@ -1,7 +1,8 @@ +import logging from dataclasses import dataclass from typing import IO + import chevron -import logging logger = logging.getLogger("ui") diff --git a/simpletextgenerator/models/worklist.py b/simpletextgenerator/models/worklist.py index dc58ce7..7c07b57 100644 --- a/simpletextgenerator/models/worklist.py +++ b/simpletextgenerator/models/worklist.py @@ -6,6 +6,7 @@ class WorkItem: """ Interface for the 2 classes below """ + def __init__(self, items_to_complete: int): self._items_to_complete = items_to_complete diff --git a/simpletextgenerator/training.py b/simpletextgenerator/training.py index 4dbd655..fee3c8a 100644 --- a/simpletextgenerator/training.py +++ b/simpletextgenerator/training.py @@ -1,14 +1,13 @@ -import logging import os from textgenrnn import textgenrnn + from simpletextgenerator.jobs_util import resource_path from simpletextgenerator.logging_setup import setup_logging from simpletextgenerator.models.job import Job from simpletextgenerator.training_status import TrainingStatus - class Train: def __init__(self, job: Job, logger): self.logger = logger @@ -85,7 +84,8 @@ def save_final_model(self): def generate_final_text(self): self.logger.info("Generating final text") for temperature in self.job.config.temperatures_to_generate: - generated = self.textgen.generate(n=self.job.config.items_to_generate_at_end, return_as_list=True, temperature=temperature) + generated = self.textgen.generate(n=self.job.config.items_to_generate_at_end, return_as_list=True, + temperature=temperature) self.save_lines_to_file("last", temperature, generated) def save_model(self, model_name: str) -> None: @@ -109,7 +109,8 @@ def generate_text(self, i) -> None: if (i % self.job.config.generate_every_n_generations) == 0: for temperature in self.job.config.temperatures_to_generate: try: - generated = self.textgen.generate(n=self.job.config.items_to_generate_each_generation, return_as_list=True, temperature=temperature) + generated = self.textgen.generate(n=self.job.config.items_to_generate_each_generation, + return_as_list=True, temperature=temperature) except KeyError: continue self.save_lines_to_file(i * self.job.config.generate_every_n_generations, temperature, generated) diff --git a/simpletextgenerator/ui/archive_delete.py b/simpletextgenerator/ui/archive_delete.py index 6bf3efc..7dbe376 100644 --- a/simpletextgenerator/ui/archive_delete.py +++ b/simpletextgenerator/ui/archive_delete.py @@ -1,8 +1,7 @@ +import logging import os import shutil import tkinter as tk - -import logging from tkinter.messagebox import askyesno logger = logging.getLogger("ui") @@ -40,7 +39,8 @@ def option_selected(self, selected_value): def archive_project(self): if self.selected_project_name.get() == '' or self.selected_project_name.get() == 'Select a project': return - shutil.move(f"projects/{self.selected_project_name.get()}", f"projects/archive/{self.selected_project_name.get()}") + shutil.move(f"projects/{self.selected_project_name.get()}", + f"projects/archive/{self.selected_project_name.get()}") self.selected_project_name.set("Select a project") self.destroy_window() @@ -54,7 +54,6 @@ def delete_project(self): return self.archive_delete_window.lift() - def draw_archive_delete_window(self): if self.archive_delete_window is not None: self.archive_delete_window = None diff --git a/simpletextgenerator/ui/edit_job.py b/simpletextgenerator/ui/edit_job.py index 9c0a468..bf8ae0c 100644 --- a/simpletextgenerator/ui/edit_job.py +++ b/simpletextgenerator/ui/edit_job.py @@ -1,16 +1,19 @@ +import logging import os +import tkinter as tk +from pathlib import Path from shutil import copyfile from tkinter import filedialog -import tkinter as tk + import chevron -from pathlib import Path + from simpletextgenerator.jobs_util import create_job, resource_path from simpletextgenerator.models.config import Config from simpletextgenerator.training_status import TrainingStatus -import logging logger = logging.getLogger("ui") + def draw_edit_existing_job_window(): edit_existing_job_window = EditJobWindow() edit_existing_job_window.draw_edit_existing_job_window() diff --git a/simpletextgenerator/ui/generating_only_job.py b/simpletextgenerator/ui/generating_only_job.py index 8ca9a83..23186e6 100644 --- a/simpletextgenerator/ui/generating_only_job.py +++ b/simpletextgenerator/ui/generating_only_job.py @@ -1,18 +1,18 @@ +import logging import os -import chevron import tkinter as tk -from tkinter import filedialog -from pathlib import Path from os.path import splitext -from tkinter import messagebox +from pathlib import Path from shutil import copyfile +from tkinter import filedialog +from tkinter import messagebox + +import chevron from simpletextgenerator.jobs_util import resource_path from simpletextgenerator.models.config import Config from simpletextgenerator.training_status import TrainingStatus -import logging - logger = logging.getLogger("ui") @@ -193,7 +193,7 @@ def create_state_file(self, path): f.write(self.render_state_file_text()) def copy_training_file(self, project_path): - #todo delete + # todo delete pass # copyfile(self.training_file_origin_path, f"./{project_path}/{self.training_file}") diff --git a/simpletextgenerator/ui/menu.py b/simpletextgenerator/ui/menu.py index d46d524..d76cb92 100644 --- a/simpletextgenerator/ui/menu.py +++ b/simpletextgenerator/ui/menu.py @@ -1,9 +1,8 @@ +import logging import tkinter as tk from simpletextgenerator.ui import new_job, edit_job, training, archive_delete, generating_only_job -import logging - logger = logging.getLogger("ui") diff --git a/simpletextgenerator/ui/new_job.py b/simpletextgenerator/ui/new_job.py index 26e62e5..29c63c5 100644 --- a/simpletextgenerator/ui/new_job.py +++ b/simpletextgenerator/ui/new_job.py @@ -1,18 +1,18 @@ +import logging import os -import chevron import tkinter as tk -from tkinter import filedialog -from pathlib import Path from os.path import splitext -from tkinter import messagebox +from pathlib import Path from shutil import copyfile +from tkinter import filedialog +from tkinter import messagebox + +import chevron from simpletextgenerator.jobs_util import resource_path from simpletextgenerator.models.config import Config from simpletextgenerator.training_status import TrainingStatus -import logging - logger = logging.getLogger("ui") diff --git a/simpletextgenerator/ui/training.py b/simpletextgenerator/ui/training.py index 4763da3..90bb2bf 100644 --- a/simpletextgenerator/ui/training.py +++ b/simpletextgenerator/ui/training.py @@ -1,17 +1,16 @@ import asyncio import contextlib +import logging import os import re import threading import tkinter import tkinter as tk from datetime import datetime -from tkinter import ttk, LEFT, BOTTOM, RIGHT, TOP from queue import Queue +from tkinter import ttk, LEFT from tkinter.scrolledtext import ScrolledText -import logging - from simpletextgenerator.utility.running_mean import RunningMean logger = logging.getLogger("ui") diff --git a/tests/ui/test_training.py b/tests/ui/test_training.py index d5dff3d..271c0b2 100644 --- a/tests/ui/test_training.py +++ b/tests/ui/test_training.py @@ -20,4 +20,3 @@ def test_get_percentage_bar2(): training_window = TrainingWindow(None, None, None) result = training_window.get_percentage_bar2("80%|████████ | 4/5 [00:13<00:03, 3.12it/s]") assert result == "80" - diff --git a/tests/utility/test_running_mean.py b/tests/utility/test_running_mean.py index ccab5c0..5433160 100644 --- a/tests/utility/test_running_mean.py +++ b/tests/utility/test_running_mean.py @@ -10,6 +10,7 @@ def test_3_average(): assert running_mean.mean() == 2 + def test_10_average(): running_mean = RunningMean(10)