Skip to content

Commit

Permalink
feat(gui): add user model management and flux models
Browse files Browse the repository at this point in the history
- Implement user-added model functionality with add/delete options
- Replace recent models with Flux fine-tune models
- Refactor ImageGenerator to ReplicateAPI for clarity
- Remove token counter and update UI layout
  • Loading branch information
rtuszik committed Aug 25, 2024
1 parent 9f2aa5c commit 235bcd8
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 45 deletions.
96 changes: 54 additions & 42 deletions src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ def __init__(self, image_generator):
self.image_generator = image_generator
self.settings_file = "settings.json"
self.load_settings()
self.recent_replicate_models = self.load_recent_replicate_models()
self.flux_fine_tune_models = self.image_generator.get_flux_fine_tune_models()
self.user_added_models = self.settings.get("user_added_models", [])
self.setup_ui()
logger.info("ImageGeneratorGUI initialized")

def setup_ui(self):
ui.dark_mode().enable()

with ui.column().classes("w-full max-w-7xl mx-auto p-4 space-y-4"):
with ui.column().classes("w-full max-w-full mx-auto p-4 space-y-4"):
with ui.card().classes("w-full"):
ui.label("Image Generator").classes("text-2xl font-bold mb-4")
with ui.row().classes("w-full"):
Expand All @@ -77,13 +78,28 @@ def setup_left_panel(self):
).classes("w-full")
self.replicate_model_input.on("change", self.update_replicate_model)

self.recent_models_select = ui.select(
options=self.recent_replicate_models,
label="Recent Models",
self.flux_models_select = ui.select(
options=self.flux_fine_tune_models,
label="Flux Fine-Tune Models",
value=None,
on_change=self.select_recent_model,
on_change=self.select_flux_model,
).classes("w-full")

with ui.row().classes("w-full"):
self.new_model_input = ui.input(label="New Model").classes("w-3/4")
ui.button("Add Model", on_click=self.add_user_model).classes("w-1/4")

self.user_models_select = ui.select(
options=self.user_added_models,
label="User Added Models",
value=None,
on_change=self.select_user_model,
).classes("w-full")

ui.button("Delete Selected Model", on_click=self.delete_user_model).classes(
"w-full"
)

self.folder_path = self.settings.get(
"output_folder", str(Path.home() / "Downloads")
)
Expand Down Expand Up @@ -221,6 +237,34 @@ def setup_left_panel(self):
.bind_value(self, "disable_safety_checker")
)

def add_user_model(self):
new_model = self.new_model_input.value
if new_model and new_model not in self.user_added_models:
self.user_added_models.append(new_model)
self.user_models_select.options = self.user_added_models
self.new_model_input.value = ""
self.save_settings()
ui.notify(f"Model '{new_model}' added successfully", type="success")
else:
ui.notify("Invalid model name or model already exists", type="error")

def delete_user_model(self):
selected_model = self.user_models_select.value
if selected_model in self.user_added_models:
self.user_added_models.remove(selected_model)
self.user_models_select.options = self.user_added_models
self.user_models_select.value = None
self.save_settings()
ui.notify(f"Model '{selected_model}' deleted successfully", type="success")
else:
ui.notify("No model selected for deletion", type="error")

def select_user_model(self, e):
if e.value:
self.replicate_model_input.value = e.value
self.update_replicate_model(e)
self.user_models_select.value = None

def update_folder_path(self, e):
new_path = e.value
if os.path.isdir(new_path):
Expand Down Expand Up @@ -248,53 +292,28 @@ def setup_bottom_panel(self):
.classes("w-full")
.bind_value(self, "prompt")
)
self.token_counter = ui.label("Tokens: 0").classes("text-sm text-gray-500")
self.prompt_input.on("input", self.update_token_count)
self.generate_button = ui.button(
"Generate Images", on_click=self.start_generation
).classes(
"w-full bg-blue-500 hover:bg-blue-600 text-white font-bold py-2 px-4 rounded"
)

def select_folder(self):
def on_folder_selected(e):
if e.value:
self.folder_path = e.value
self.folder_input.value = self.folder_path
self.save_settings()
logger.info(f"Output folder set to: {self.folder_path}")

ui.open_directory_dialog(on_folder_selected)

def update_replicate_model(self, e):
new_model = e.value
new_model = self.replicate_model_input.value
if new_model:
self.image_generator.set_model(new_model)
self.save_settings()
self.add_recent_replicate_model(new_model)
logger.info(f"Replicate model updated to: {new_model}")
self.generate_button.enable()
else:
logger.warning("Empty Replicate model provided")
self.generate_button.disable()

def select_recent_model(self, e):
def select_flux_model(self, e):
if e.value:
self.replicate_model_input.value = e.value
self.update_replicate_model(e)
self.recent_models_select.value = None

def add_recent_replicate_model(self, model):
if model not in self.recent_replicate_models:
self.recent_replicate_models.insert(0, model)
self.recent_replicate_models = self.recent_replicate_models[
:5
] # Keep only the last 5
self.save_settings()
self.recent_models_select.options = self.recent_replicate_models

def load_recent_replicate_models(self):
return self.settings.get("recent_replicate_models", [])
self.flux_models_select.value = None

def toggle_custom_dimensions(self, e):
if e.value == "custom":
Expand All @@ -306,13 +325,6 @@ def toggle_custom_dimensions(self, e):
self.save_settings()
logger.info(f"Custom dimensions toggled: {e.value}")

def update_token_count(self, e):
token_count = len(e.value.split())
self.token_counter.text = f"Tokens: {token_count}"
if token_count > 77:
ui.notify("Warning: Tokens beyond 77 will be ignored", type="warning")
self.save_settings()

async def start_generation(self):
if not self.replicate_model_input.value:
ui.notify(
Expand Down Expand Up @@ -407,6 +419,7 @@ def load_settings(self):

def save_settings(self):
settings_to_save = {
"user_added_models": self.user_added_models,
"replicate_model": self.replicate_model_input.value,
"output_folder": self.folder_path,
"flux_model": self.flux_model,
Expand All @@ -422,7 +435,6 @@ def save_settings(self):
"output_quality": self.output_quality,
"disable_safety_checker": self.disable_safety_checker,
"prompt": self.prompt,
"recent_replicate_models": self.recent_replicate_models,
}
with open(self.settings_file, "w") as f:
json.dump(settings_to_save, f)
Expand Down
2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys

from gui import create_gui
from image_generator import ImageGenerator
from loguru import logger
from nicegui import ui
from replicate_api import ImageGenerator

# Configure Loguru
logger.remove() # Remove the default handler
Expand Down
13 changes: 11 additions & 2 deletions src/image_generator.py → src/replicate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
load_dotenv()

# Configure Loguru
logger.remove() # Remove the default handler
logger.remove()
logger.add(
sys.stderr, format="{time} {level} {message}", filter="my_module", level="INFO"
)
logger.add(
"image_generator.log",
"replicate.log",
rotation="10 MB",
format="{time} {level} {message}",
level="INFO",
Expand Down Expand Up @@ -58,6 +58,15 @@ def generate_images(self, params):
logger.exception(error_message)
raise ImageGenerationError(error_message)

def get_flux_fine_tune_models(self):
try:
collection = replicate.collections.get("flux-fine-tunes")
models = collection.models
return [model.name for model in models]
except Exception as e:
logger.error(f"Error fetching flux-fine-tunes models: {str(e)}")
return []


class ImageGenerationError(Exception):
pass
12 changes: 12 additions & 0 deletions src/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import replicate
from dotenv import load_dotenv

load_dotenv()
print(replicate.paginate(replicate.collections.list))
collections = [
collection
for page in replicate.paginate(replicate.collections.list)
for collection in page
]

print(replicate.collections.get("flux-fine-tunes").models)

0 comments on commit 235bcd8

Please sign in to comment.