From 7232c3eacee0b1428636607c58507b72698afd19 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sat, 5 Apr 2025 10:18:09 +0000 Subject: [PATCH 1/6] [deps] Add gradio & slugify --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6de80fc..8ae078d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,12 +19,13 @@ dependencies = [ "sentencepiece==0.2.0", "transformers~=4.49", "wandb~=0.19.8", - "fastapi[standard]~=0.115.11", "fastapi_cli~=0.0.7", "openai~=1.67", "pydantic_settings~=2.8.1", "python-dotenv~=1.0", + "gradio~=5.23", + "python-slugify>=8.0.4", ] [project.optional-dependencies] From a4b1275f8dcfdd30c6b0732dbed41a8752edb8e3 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 7 Apr 2025 08:14:40 +0000 Subject: [PATCH 2/6] [chore] Add gradio related rules --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 717e529..646fea7 100644 --- a/.gitignore +++ b/.gitignore @@ -263,3 +263,6 @@ tmp/ webdoc/ **/wandb/ + +**/lora_checkpoints +**/.gradio From 87a89fdaab5a213d010779a48da45c14cde930d9 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 7 Apr 2025 08:15:30 +0000 Subject: [PATCH 3/6] [feat] Add gradio demo for inference & lora fintuning --- gradio/configs/t2i.yaml | 45 +++++ gradio/configs/t2v.yaml | 45 +++++ gradio/gradio_infer_demo.py | 357 ++++++++++++++++++++++++++++++++++ gradio/gradio_lora_demo.py | 372 ++++++++++++++++++++++++++++++++++++ gradio/gradio_ui.py | 10 + gradio/styles/mono.py | 18 ++ gradio/utils/__init__.py | 26 +++ gradio/utils/io.py | 96 ++++++++++ gradio/utils/logging.py | 52 +++++ gradio/utils/misc.py | 53 +++++ gradio/utils/task.py | 185 ++++++++++++++++++ 11 files changed, 1259 insertions(+) create mode 100644 gradio/configs/t2i.yaml create mode 100644 gradio/configs/t2v.yaml create mode 100644 gradio/gradio_infer_demo.py create mode 100644 gradio/gradio_lora_demo.py create mode 100644 gradio/gradio_ui.py create mode 100644 gradio/styles/mono.py create mode 100644 gradio/utils/__init__.py create mode 100644 gradio/utils/io.py create mode 100644 gradio/utils/logging.py create mode 100644 gradio/utils/misc.py create mode 100644 gradio/utils/task.py diff --git a/gradio/configs/t2i.yaml b/gradio/configs/t2i.yaml new file mode 100644 index 0000000..77823e7 --- /dev/null +++ b/gradio/configs/t2i.yaml @@ -0,0 +1,45 @@ +# CogView4 Configuration + +# Model Configuration +model: + model_path: "THUDM/CogView4-6B" # Path to the pre-trained model + model_name: "cogview4-6b" # Model name (options: "cogview4-6b") + model_type: "t2i" # Model type (text-to-image) + training_type: "lora" # Training type + +# Output Configuration +output: + output_dir: "/path/to/output" # Directory to save outputs + report_to: "tensorboard" # Logging framework + +# Data Configuration +data: + data_root: "/path/to/data" # Path to training data + +# Training Configuration +training: + seed: 42 # Random seed for reproducibility + train_epochs: 1 # Number of training epochs + batch_size: 1 # Batch size per GPU + gradient_accumulation_steps: 1 # Number of gradient accumulation steps + mixed_precision: "bf16" # Mixed precision mode (options: "no", "fp16", "bf16") + learning_rate: 2.0e-5 # Learning rate + + # Note: For CogView4 series models, height and width should be **32N** (multiple of 32) + train_resolution: "1024x1024" # Training resolution (height x width) + +# System Configuration +system: + num_workers: 8 # Number of dataloader workers + pin_memory: true # Whether to pin memory in dataloader + nccl_timeout: 1800 # NCCL timeout in seconds + +# Checkpointing Configuration +checkpoint: + checkpointing_steps: 10 # Save checkpoint every x steps + checkpointing_limit: 2 # Maximum number of checkpoints to keep + +# Validation Configuration +validation: + do_validation: true # Whether to perform validation + validation_steps: 10 # Validate every x steps (should be multiple of checkpointing_steps) diff --git a/gradio/configs/t2v.yaml b/gradio/configs/t2v.yaml new file mode 100644 index 0000000..a8e491c --- /dev/null +++ b/gradio/configs/t2v.yaml @@ -0,0 +1,45 @@ +# CogView4 Configuration + +# Model Configuration +model: + model_path: "THUDM/CogVideoX1.5-5B" # Path to the pre-trained model + model_name: "cogvideox1.5-t2v" # Model name (options: "cogview4-6b") + model_type: "t2v" # Model type (text-to-video) + training_type: "lora" # Training type + +# Output Configuration +output: + output_dir: "/path/to/output" # Directory to save outputs + report_to: "tensorboard" # Logging framework + +# Data Configuration +data: + data_root: "/path/to/data" # Path to training data + +# Training Configuration +training: + seed: 42 # Random seed for reproducibility + train_epochs: 1 # Number of training epochs + batch_size: 1 # Batch size per GPU + gradient_accumulation_steps: 1 # Number of gradient accumulation steps + mixed_precision: "bf16" # Mixed precision mode (options: "no", "fp16", "bf16") + learning_rate: 2.0e-5 # Learning rate + + # Note: For CogView4 series models, height and width should be **32N** (multiple of 32) + train_resolution: "81x768x1360" # Training resolution (height x width) + +# System Configuration +system: + num_workers: 8 # Number of dataloader workers + pin_memory: true # Whether to pin memory in dataloader + nccl_timeout: 1800 # NCCL timeout in seconds + +# Checkpointing Configuration +checkpoint: + checkpointing_steps: 10 # Save checkpoint every x steps + checkpointing_limit: 2 # Maximum number of checkpoints to keep + +# Validation Configuration +validation: + do_validation: true # Whether to perform validation + validation_steps: 10 # Validate every x steps (should be multiple of checkpointing_steps) diff --git a/gradio/gradio_infer_demo.py b/gradio/gradio_infer_demo.py new file mode 100644 index 0000000..a4a3c3d --- /dev/null +++ b/gradio/gradio_infer_demo.py @@ -0,0 +1,357 @@ +import os +import tempfile +import uuid +from pathlib import Path +from typing import List, Tuple + +import torch +from utils import ( + get_logger, + get_lora_checkpoint_dirs, + get_lora_checkpoint_rootdir, +) + +import gradio as gr +from cogkit import ( + GenerationMode, + generate_image, + generate_video, + guess_generation_mode, + load_lora_checkpoint, + load_pipeline, + unload_lora_checkpoint, +) +from diffusers.utils import export_to_video + +# ======================= global state ==================== + +logger = get_logger(__name__) + +task: GenerationMode | None = None +checkpoint_rootdir: str = "" +checkpoint_dirs: List[str] = [] +pipeline = None + +prev_model_id: str | None = None +resolution: str | None = None + +# ========================= hooks ========================= + + +def update_task(hf_model_id: str) -> Tuple[gr.Dropdown, gr.Component]: + """Update the task based on model selection and load available LoRA checkpoints.""" + global task, checkpoint_rootdir, checkpoint_dirs, resolution + + task = guess_generation_mode("THUDM/" + hf_model_id) + checkpoint_rootdir = get_lora_checkpoint_rootdir(task) + + # Get all available checkpoints for the selected task + checkpoint_dirs = get_lora_checkpoint_dirs(task) + + # Add a "None" option at the beginning for no LoRA + checkpoint_options = ["None"] + checkpoint_dirs + + logger.info(f"Current task: {task}") + logger.info(f"Checkpoint root dir: {checkpoint_rootdir}") + logger.info(f"Available checkpoints: {checkpoint_options}") + + updated_lora_dropdown = gr.Dropdown( + choices=checkpoint_options, + label="LoRA Checkpoint", + info="Select a LoRA checkpoint to use for inference", + interactive=True, + value="None", + ) + + # Reset the subcheckpoint dropdown + updated_subcheckpoint_dropdown = gr.Dropdown( + choices=[], + label="Checkpoint Version", + info="Select a specific checkpoint version", + interactive=False, + value=None, + visible=False, + ) + + # Configure resolution dropdown based on task + if task == GenerationMode.TextToImage: + resolution_list = [ + "512x512", + "512x768", + "512x1024", + "720x1280", + "768x768", + "1024x1024", + "1080x1920", + ] + default_resolution = resolution_list[0] + resolution_info = "Height x Width" + elif task == GenerationMode.TextToVideo: + resolution_list = [ + "49x480x720", + "81x768x1360", + ] + default_resolution = resolution_list[0] + resolution_info = "Frames x Height x Width" + else: + resolution_list = [] + default_resolution = None + resolution_info = "" + + resolution = default_resolution + + updated_resolution_dropdown = gr.Dropdown( + choices=resolution_list, + label="Resolution", + info=resolution_info, + interactive=True, + value=default_resolution, + ) + + # Return appropriate output component based on task + if task == GenerationMode.TextToImage: + output_component = gr.Image(label="Generated Image", type="pil", visible=True) + video_component = gr.Video(label="Generated Video", visible=False) + else: # TextToVideo + output_component = gr.Image(label="Generated Image", type="pil", visible=False) + video_component = gr.Video(label="Generated Video", visible=True) + + # Return updated UI components + return ( + updated_lora_dropdown, + updated_subcheckpoint_dropdown, + updated_resolution_dropdown, + output_component, + video_component, + ) + + +def update_subcheckpoints(checkpoint_dir): + """Get subdirectories for the selected checkpoint directory.""" + if checkpoint_dir == "None": + return gr.Dropdown(choices=[], interactive=False, visible=False) + + # Get the full path to the checkpoint directory + full_checkpoint_path = os.path.join(checkpoint_rootdir, checkpoint_dir) + + # Get all subdirectories + try: + subdirs = [ + d + for d in os.listdir(full_checkpoint_path) + if os.path.isdir(os.path.join(full_checkpoint_path, d)) and d.startswith("checkpoint-") + ] + subdirs.sort() # Sort to get a consistent order + except Exception as e: + logger.error(f"Error loading subdirectories: {str(e)}") + subdirs = [] + + if not subdirs: + # If there are no subdirectories, hide the dropdown + return gr.Dropdown(choices=[], interactive=False, visible=False) + + # Show dropdown with available subdirectories + return gr.Dropdown( + choices=subdirs, + label="Checkpoint Version", + info="Select a specific checkpoint version", + value=subdirs[-1] if subdirs else None, # Select the last checkpoint by default + interactive=True, + visible=True, + ) + + +def load_model_and_generate( + prompt, + model_type, + lora_checkpoint, + subcheckpoint, + num_inference_steps, + guidance_scale, + resolution, +): + """Load the model with optional LoRA and generate content based on task type.""" + global pipeline, task, prev_model_id + + if not model_type: + raise gr.Error("Please select a model first") + + if not prompt or prompt.strip() == "": + raise gr.Error("Please enter a prompt") + + # Create progress tracking + progress = gr.Progress() + progress(0, desc="Loading model...") + + # Load the base model + model_id = "THUDM/" + model_type + if model_id != prev_model_id: + prev_model_id = model_id + pipeline = load_pipeline( + model_id, + dtype=torch.bfloat16, + ) + + # Load LoRA weights if selected + if lora_checkpoint != "None": + progress(0.3, desc="Loading LoRA weights...") + # Construct the full path to the specific checkpoint + if subcheckpoint and subcheckpoint.strip(): + lora_path = os.path.join(lora_checkpoint, subcheckpoint) + else: + lora_path = lora_checkpoint + logger.info(f"Loading LoRA weights from {lora_path}") + load_lora_checkpoint(pipeline, lora_path) + else: + unload_lora_checkpoint(pipeline) + + # Generate content based on task + progress(0.5, desc="Generating content...") + + try: + if task == GenerationMode.TextToImage: + height, width = map(int, resolution.split("x")) + outputs = generate_image( + prompt=prompt, + pipeline=pipeline, + num_images_per_prompt=1, + output_type="pil", + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + ) + # For image output, return the PIL image and None for video + return outputs[0], None + + elif task == GenerationMode.TextToVideo: + frames, height, width = map(int, resolution.split("x")) + outputs, fps = generate_video( + prompt=prompt, + pipeline=pipeline, + num_videos_per_prompt=1, + output_type="pil", + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_frames=frames, + height=height, + width=width, + ) + + # Create temporary file to save the video + temp_dir = Path(tempfile.gettempdir()) / "cogkit_videos" + os.makedirs(temp_dir, exist_ok=True) + video_path = str(temp_dir / f"{uuid.uuid4()}.mp4") + + # Export video frames to a video file + export_to_video(outputs[0], video_path, fps=fps) + + # Return None for image and the video path + return None, video_path + + else: + raise gr.Error(f"Unsupported task type: {task}") + + except Exception as e: + logger.error(f"Error during generation: {str(e)}") + raise gr.Error(f"Generation failed: {str(e)}") + finally: + progress(1.0, desc="Generation completed!") + + +# =========================== UI =========================== + +with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + with gr.Row(): + model_type = gr.Dropdown( + choices=["CogView4-6B", "CogVideoX1.5-5B"], + label="Model", + info="Select the model to use", + interactive=True, + value=None, + ) + lora_dropdown = gr.Dropdown( + choices=["None"], + label="LoRA Checkpoint", + info="Select a LoRA checkpoint to use for inference", + interactive=False, + value="None", + ) + + subcheckpoint_dropdown = gr.Dropdown( + choices=[], + label="Checkpoint Version", + info="Select a specific checkpoint version", + interactive=False, + visible=False, + ) + + prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3) + + with gr.Row(): + resolution_dropdown = gr.Dropdown( + choices=[], + label="Resolution", + info="Select resolution for generation", + interactive=False, + ) + + num_inference_steps = gr.Slider( + minimum=1, + maximum=100, + value=50, + step=1, + label="Inference Steps", + info="Higher values give better quality but take longer", + ) + + guidance_scale = gr.Slider( + minimum=1.0, + maximum=15.0, + value=6.0, + step=0.1, + label="Guidance Scale", + info="Higher values increase prompt adherence", + ) + + generate_btn = gr.Button("Generate") + + with gr.Column(scale=1): + image_output = gr.Image(label="Generated Image", type="pil") + video_output = gr.Video(label="Generated Video", visible=False) + + # Set up event handlers + model_type.change( + fn=update_task, + inputs=[model_type], + outputs=[ + lora_dropdown, + subcheckpoint_dropdown, + resolution_dropdown, + image_output, + video_output, + ], + ) + + lora_dropdown.change( + fn=update_subcheckpoints, inputs=[lora_dropdown], outputs=[subcheckpoint_dropdown] + ) + + generate_btn.click( + fn=load_model_and_generate, + inputs=[ + prompt, + model_type, + lora_dropdown, + subcheckpoint_dropdown, + num_inference_steps, + guidance_scale, + resolution_dropdown, + ], + outputs=[image_output, video_output], + ) + +if __name__ == "__main__": + demo.launch(share=True, show_error=True) diff --git a/gradio/gradio_lora_demo.py b/gradio/gradio_lora_demo.py new file mode 100644 index 0000000..609df9e --- /dev/null +++ b/gradio/gradio_lora_demo.py @@ -0,0 +1,372 @@ +import os +import sys +import tempfile +import uuid +from pathlib import Path +from typing import Any, Dict, Iterator, List, Tuple + +import torch +from datasets import Dataset +from slugify import slugify +from torchvision.io import write_video +from utils import ( + BaseTask, + flatten_dict, + get_dataset_dirs, + get_logger, + get_lora_checkpoint_rootdir, + get_resolutions, + get_training_script, + load_config_template, + load_data, + resolve_path, +) + +import gradio as gr +from cogkit import GenerationMode, guess_generation_mode + +# ======================= global state ==================== + +logger = get_logger(__name__) + +data_dirs: List[str] = get_dataset_dirs() +checkpoint_rootdir: str = "" +checkpoint_name: str = "" +checkpoint_dir: str | None = None + +task: GenerationMode | None = None +task_config: Dict[str, Any] = {} + +train_data: Dataset | None = None +test_data: Dataset | None = None + +resolution: str | None = None + +current_training_task = None + +# ========================= hooks ========================= + + +def update_lora_name(name: str) -> Tuple[gr.Textbox]: + """Update the lora_name when the text field changes.""" + global checkpoint_dir, checkpoint_name, lora_name + lora_name.value = name + checkpoint_name = slugify(name) + + checkpoint_dir = checkpoint_rootdir + "/" + checkpoint_name + updated_checkpoint_dir = gr.Textbox( + label="Checkpoint Directory", + info="Path to the checkpoint directory", + interactive=False, + value=checkpoint_dir, + ) + return updated_checkpoint_dir + + +def update_task(hf_model_id: str) -> Tuple[gr.Dropdown]: + global task, task_config, checkpoint_rootdir, checkpoint_dir, model_type, resolution + model_type.value = hf_model_id + task = guess_generation_mode("THUDM/" + hf_model_id) + task_config = load_config_template(task) + + checkpoint_rootdir = get_lora_checkpoint_rootdir(task) + checkpoint_dir = checkpoint_rootdir + "/" + checkpoint_name + + if task == GenerationMode.TextToImage: + resolution_list = get_resolutions(task) + default_value = resolution_list[0] + info = "Height x Width" + + elif task == GenerationMode.TextToVideo: + resolution_list = get_resolutions(task) + default_value = resolution_list[0] + info = "Frames x Height x Width" + + logger.info(f"Current task: {task}") + logger.info(f"lora_checkpoint_rootdir: {checkpoint_rootdir}") + + updated_checkpoint_dir = gr.Textbox( + label="Checkpoint Directory", + info="Path to the checkpoint directory", + interactive=False, + value=checkpoint_dir, + ) + + resolution = default_value + updated_resolution_dropdown = gr.Dropdown( + choices=resolution_list, + label="Training Resolution", + info=info, + interactive=True, + value=default_value, + ) + + # Return the resolution list for updating train_resolution choices + return updated_checkpoint_dir, updated_resolution_dropdown + + +def update_do_validation(user_input_do_validation: bool) -> None: + global do_validation + do_validation.value = user_input_do_validation + + +def update_train_data(user_input_data_dir: str) -> List[Tuple[str, str]]: + global train_data, data_dir + assert task is not None + + data_dir.value = user_input_data_dir + + progress = gr.Progress() + progress(0, desc="Loading dataset...") + train_data = load_data(user_input_data_dir, task) + progress(1, desc="Dataset loaded successfully!") + logger.info(f"Loaded training data from {user_input_data_dir}") + logger.info(f"Train data: {train_data}") + + ###### Prepare data for display in the gallery component + if task == GenerationMode.TextToImage: + # num_samples = min(10, len(train_data)) + num_samples = len(train_data) + sample_images = train_data["image"][:num_samples] + sample_captions = train_data["prompt"][:num_samples] + return [(img, cap) for img, cap in zip(sample_images, sample_captions)] + + elif task == GenerationMode.TextToVideo: + # Create a temporary directory to store video files + temp_dir = Path(tempfile.gettempdir()) / "cogkit_videos" + os.makedirs(temp_dir, exist_ok=True) + + num_samples = min(50, len(train_data)) + sample_videos = [] + sample_captions = train_data["prompt"][:num_samples] + + # Save videos to temporary files + for i, video in enumerate(train_data["video"][:num_samples]): + video_path = str(temp_dir / f"{uuid.uuid4()}.mp4") + + # Get frames from VideoReader and convert to tensor + frames = [] + for frame in video: + frames.append(frame["data"]) + + # Stack frames and save as video + if frames: + video_tensor = torch.stack(frames) + # Change from (T, C, H, W) to (T, H, W, C) format + video_tensor = video_tensor.permute(0, 2, 3, 1) + fps = video.get_metadata().get("fps", 30) # Default to 30 fps if not available + write_video(video_path, video_tensor, fps=fps) + sample_videos.append(video_path) + + return [(video_path, cap) for video_path, cap in zip(sample_videos, sample_captions)] + + return [] + + +def update_training_config() -> None: + assert model_type.value is not None + assert task_config["model"]["model_type"] == task.value + assert task_config["model"]["training_type"] == "lora" + + if lora_name.value is None or lora_name.value == "": + raise gr.Error("Lora name cannot be empty") + + out_dir = Path(checkpoint_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + ###### Rewrite configs + task_config["model"]["model_path"] = "THUDM/" + model_type.value + + task_config["output"]["output_dir"] = resolve_path(out_dir) + + task_config["data"]["data_root"] = data_dir.value + + task_config["training"]["train_epochs"] = epochs.value + task_config["training"]["batch_size"] = batch_size.value + task_config["training"]["learning_rate"] = learning_rate.value + task_config["training"]["train_resolution"] = resolution + + task_config["checkpoint"]["checkpointing_steps"] = checkpointing_step.value + task_config["checkpoint"]["checkpointing_limit"] = checkpointing_limit.value + + task_config["validation"]["do_validation"] = do_validation.value + task_config["validation"]["validation_steps"] = checkpointing_step.value + + logger.info(f"task config: {task_config}") + + +def run_training() -> BaseTask: + """Run the training process using accelerate launch and the configured parameters.""" + logger.info("Starting training process...") + + # Verify task is initialized + if task is None: + raise gr.Error("Error: No model type selected") + + # Verify data directory is set + if not train_data: + raise gr.Error("Error: No training data loaded") + + cmd_args = [sys.executable, get_training_script()] + + # Flatten command line dict + flat_config = flatten_dict(task_config, ignore_none=True) + + for param_name, param_value in flat_config.items(): + cmd_args.extend([f"--{param_name}", str(param_value)]) + + # Create and run the task + training_task = BaseTask(cmd_args) + training_task.run() + + return training_task + + +def start_training_process() -> Iterator[str]: + """Update the training config and start the training process.""" + global current_training_task + + # First update the training configuration + update_training_config() + + # Then run the training + current_training_task = run_training() + + gr.Info(f"Training process started with PID: {current_training_task.get_pid()}") + + # Stream output from the task + output_text = "" + for line in current_training_task.iter_output(): + output_text += line + "\n" + yield output_text + + # Final update after process completes + yield output_text + "\nTraining process completed!" + + gr.Info("Training process completed!") + + +def update_training_resolution(user_input_resolution: str) -> None: + """Update the training resolution in the task config.""" + global resolution + resolution = user_input_resolution + logger.info(f"Updating training resolution: {resolution}") + + +# =========================== UI =========================== + +with gr.Blocks() as demo: + # gr.Markdown("""# LoRA Ease for CogView 🧞‍♂️""") + + with gr.Row(): + lora_name = gr.Textbox( + label="Name of your LoRA checkpoint", + info="This has to be a unique name", + placeholder="e.g.: Persian Miniature Painting style, Cat Toy", + value="", + ) + + model_type = gr.Dropdown( + choices=["CogView4-6B", "CogVideoX1.5-5B"], + label="Model", + info="Select the model to use", + interactive=True, + value=None, + ) + + data_dir = gr.Dropdown( + choices=data_dirs, + label="Dataset Directory", + info="Select the dataset directory to use", + interactive=True, + value=None, + ) + + checkpoint_dir = gr.Textbox( + label="Checkpoint Directory", + info="Path to the checkpoint directory", + interactive=False, + value=None, + ) + + # Add a section to display training data samples + with gr.Row(): + sample_gallery = gr.Gallery( + label="Training Data Preview", + show_label=True, + elem_id="gallery", + columns=5, + object_fit="contain", + height="auto", + ) + + gr.Markdown("### Training Configuration") + with gr.Column(): + with gr.Row(): + train_resolution = gr.Dropdown( + choices=[], + label="Training Resolution", + info="Resolution for training", + interactive=True, + ) + batch_size = gr.Number( + value=1, + label="Batch Size", + info="Number of samples per training batch", + interactive=True, + ) + epochs = gr.Number( + value=1, + label="Epochs", + info="Number of training epochs", + interactive=True, + ) + learning_rate = gr.Number( + value=2e-5, + step=1e-5, + label="Learning Rate", + info="Training learning rate", + interactive=True, + ) + checkpointing_step = gr.Number( + value=10, + label="Checkpointing Step", + info="Number of training steps between checkpoints", + interactive=True, + ) + checkpointing_limit = gr.Number( + value=2, + label="Checkpointing Limit", + info="Maximum number of checkpoints to keep", + interactive=True, + ) + + do_validation = gr.Checkbox( + value=False, + label="Enable Validation", + info="Whether to perform validation during training", + ) + + start_train = gr.Button("Start training") + + training_output = gr.Textbox( + label="Training Output", + placeholder="Training output will appear here...", + interactive=False, + lines=20, + autoscroll=True, + ) + + ###### Binding hooks + lora_name.change(fn=update_lora_name, inputs=[lora_name], outputs=[checkpoint_dir]) + model_type.change( + fn=update_task, inputs=[model_type], outputs=[checkpoint_dir, train_resolution] + ) + train_resolution.change(fn=update_training_resolution, inputs=[train_resolution]) + data_dir.change(fn=update_train_data, inputs=[data_dir], outputs=[sample_gallery]) + do_validation.change(fn=update_do_validation, inputs=[do_validation]) + start_train.click(fn=start_training_process, inputs=None, outputs=[training_output]) + + +if __name__ == "__main__": + demo.launch(share=True, show_error=True) diff --git a/gradio/gradio_ui.py b/gradio/gradio_ui.py new file mode 100644 index 0000000..b264d20 --- /dev/null +++ b/gradio/gradio_ui.py @@ -0,0 +1,10 @@ +from gradio_infer_demo import demo as demo1 +from gradio_lora_demo import demo as demo2 +from styles.mono import CSS, THEME + +import gradio as gr + + +if __name__ == "__main__": + demo = gr.TabbedInterface([demo1, demo2], ["Inference", "Train"], theme=THEME, css=CSS) + demo.launch() diff --git a/gradio/styles/mono.py b/gradio/styles/mono.py new file mode 100644 index 0000000..2f8bce9 --- /dev/null +++ b/gradio/styles/mono.py @@ -0,0 +1,18 @@ +import gradio as gr + + +THEME = gr.themes.Monochrome( + text_size=gr.themes.Size( + lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px" + ), + font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"], +) + +CSS = """ +h1{font-size: 2em} +h3{margin-top: 0} +#component-1{text-align:center} +.main_ui_logged_out{opacity: 0.3; pointer-events: none} +.tabitem{border: 0px} +.group_padding{padding: .55em} +""" diff --git a/gradio/utils/__init__.py b/gradio/utils/__init__.py new file mode 100644 index 0000000..9b92bad --- /dev/null +++ b/gradio/utils/__init__.py @@ -0,0 +1,26 @@ +from .io import ( + get_dataset_dirs, + get_lora_checkpoint_dirs, + get_lora_checkpoint_rootdir, + get_training_script, + load_config_template, + load_data, + resolve_path, +) +from .logging import get_logger +from .misc import flatten_dict, get_resolutions +from .task import BaseTask + +__all__ = [ + "get_dataset_dirs", + "get_training_script", + "get_lora_checkpoint_dirs", + "get_lora_checkpoint_rootdir", + "load_config_template", + "load_data", + "get_logger", + "resolve_path", + "BaseTask", + "get_resolutions", + "flatten_dict", +] diff --git a/gradio/utils/io.py b/gradio/utils/io.py new file mode 100644 index 0000000..5b6ec0b --- /dev/null +++ b/gradio/utils/io.py @@ -0,0 +1,96 @@ +from pathlib import Path +from typing import Any, Dict, List, Literal + +import yaml +from datasets import Dataset, load_dataset + +from cogkit import GenerationMode + +_QUICKSTART_ROOT_DIR = Path(__file__).parent.parent.parent / "quickstart" +_GRADIO_ROOT_DIR = Path(__file__).parent.parent + +_DATASET_ROOT_DIR = _QUICKSTART_ROOT_DIR / "data" +_TRAINING_SCRIPT_FILE = _QUICKSTART_ROOT_DIR / "scripts" / "train.py" +_LORA_CHECKPOINT_ROOT_DIR = _GRADIO_ROOT_DIR / "lora_checkpoints" +_CONFIG_ROOT_DIR = _GRADIO_ROOT_DIR / "configs" + + +def resolve_path(path: str | Path) -> str: + return str(Path(path).expanduser().resolve()) + + +def get_dirs(dir_path: Path) -> List[str]: + dir_path.mkdir(exist_ok=True) + return [resolve_path(d) for d in dir_path.iterdir() if d.is_dir()] + + +def get_dataset_dirs() -> List[str]: + """Get all dataset directories from quickstart/data.""" + return get_dirs(_DATASET_ROOT_DIR) + + +def get_training_script() -> str: + return resolve_path(_TRAINING_SCRIPT_FILE) + + +def get_lora_checkpoint_dirs(task: GenerationMode) -> List[str]: + """Get all lora checkpoint directories from lora_checkpoints.""" + return get_dirs(_LORA_CHECKPOINT_ROOT_DIR / task.value) + + +def get_lora_checkpoint_rootdir(task: GenerationMode) -> str: + return resolve_path(_LORA_CHECKPOINT_ROOT_DIR / task.value) + + +def load_config_template(generation_task: GenerationMode) -> Dict[str, Any]: + """ + Read YAML configuration template based on generation task. + + Args: + generation_task: Task type (e.g., 't2i', 't2v') + + Returns: + Parsed YAML as dictionary + """ + config_file = f"{generation_task.value}.yaml" + config_path = _CONFIG_ROOT_DIR / config_file + + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found at {config_path}") + + with open(config_path, "r") as f: + yaml_dict = yaml.safe_load(f) + + return yaml_dict + + +def load_data( + data_dir: str, task: GenerationMode, split: Literal["train", "test"] = "train" +) -> Dataset: + data_dir = Path(data_dir) + train_dir = data_dir / "train" + test_dir = data_dir / "test" + assert split in ["train", "test"], f"Invalid split: {split}" + if split == "train": + assert train_dir.exists(), f"Train directory {train_dir} does not exist" + else: + assert test_dir.exists(), f"Test directory {test_dir} does not exist" + + match task: + case GenerationMode.TextToImage: + if split == "train": + return load_dataset("imagefolder", data_dir=train_dir, split="train") + else: + return load_dataset("json", data_dir=test_dir, split="test") + + case GenerationMode.TextToVideo: + if split == "train": + return load_dataset("videofolder", data_dir=train_dir, split="train") + else: + return load_dataset("json", data_dir=test_dir, split="test") + + case GenerationMode.ImageToVideo: + raise NotImplementedError("Image to video is not implemented") + + case _: + raise ValueError(f"Unsupported task: {task}") diff --git a/gradio/utils/logging.py b/gradio/utils/logging.py new file mode 100644 index 0000000..262cb81 --- /dev/null +++ b/gradio/utils/logging.py @@ -0,0 +1,52 @@ +import logging +import sys +from typing import Optional + + +def get_logger(name: str, level: Optional[int] = None) -> logging.Logger: + """ + Get a logger instance with proper formatting and configuration. + + Args: + name: The name of the logger, typically __name__ from the calling module + level: Optional logging level. If not provided, defaults to INFO + + Returns: + logging.Logger: Configured logger instance + """ + logger = logging.getLogger(name) + + # If logger already has handlers, return it to avoid duplicate handlers + if logger.handlers: + return logger + + # Set default level to INFO if not specified + if level is None: + level = logging.INFO + + logger.setLevel(level) + + # Create console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(level) + + # Create formatter with color support + class ColoredFormatter(logging.Formatter): + def format(self, record): + if record.levelno == logging.INFO: + record.levelname = f"\033[32m{record.levelname}\033[0m" # Green color for INFO + record.msg = f"\033[32m{record.msg}\033[0m" # Green color for message + elif record.levelno in (logging.WARNING, logging.ERROR): + record.levelname = ( + f"\033[31m{record.levelname}\033[0m" # Red color for WARNING and ERROR + ) + record.msg = f"\033[31m{record.msg}\033[0m" # Red color for message + return super().format(record) + + formatter = ColoredFormatter("\n- %(levelname)s -\n%(message)s\n") + console_handler.setFormatter(formatter) + + # Add handler to logger + logger.addHandler(console_handler) + + return logger diff --git a/gradio/utils/misc.py b/gradio/utils/misc.py new file mode 100644 index 0000000..b548d29 --- /dev/null +++ b/gradio/utils/misc.py @@ -0,0 +1,53 @@ +from typing import Any, Dict, List + +from cogkit import GenerationMode + + +def get_resolutions(task: GenerationMode) -> List[str]: + if task == GenerationMode.TextToImage: + return [ + "512x512", + "512x768", + "512x1024", + "720x1280", + "768x768", + "1024x1024", + "1080x1920", + ] + elif task == GenerationMode.TextToVideo: + return [ + "49x480x720", + "81x768x1360", + ] + + +def flatten_dict(d: Dict[str, Any], ignore_none: bool = False) -> Dict[str, Any]: + """ + Flattens a nested dictionary into a single layer dictionary. + + Args: + d: The dictionary to flatten + ignore_none: If True, keys with None values will be omitted + + Returns: + A flattened dictionary + + Raises: + ValueError: If there are duplicate keys across nested dictionaries + """ + result = {} + + def _flatten(current_dict, result_dict): + for key, value in current_dict.items(): + if value is None and ignore_none: + continue + + if isinstance(value, dict): + _flatten(value, result_dict) + else: + if key in result_dict: + raise ValueError(f"Duplicate key '{key}' found in nested dictionary") + result_dict[key] = value + + _flatten(d, result) + return result diff --git a/gradio/utils/task.py b/gradio/utils/task.py new file mode 100644 index 0000000..13292c2 --- /dev/null +++ b/gradio/utils/task.py @@ -0,0 +1,185 @@ +import logging +import subprocess +import threading +import time +from queue import Empty, Queue +from typing import Iterator, List, Optional +from .logging import get_logger + + +class BaseTask: + """ + A class to manage command execution as a subprocess with output streaming. + + This class handles running a command as a subprocess, capturing its output, + and providing methods to access that output either line by line or as a stream. + Each task can only be started once and can have only one consumer of its output. + """ + + def __init__(self, command: List[str], logger: Optional[logging.Logger] = None): + """ + Initialize a new BaseTask. + + Args: + command: List of command arguments to execute + logger: Optional logger for task status messages + """ + self.command = command + self.logger = logger or get_logger(__name__) + + # Process information + self.pid: Optional[int] = None + self.process: Optional[subprocess.Popen] = None + + # State tracking + self.started = False + self.finished = False + self.error = None + self.return_code: Optional[int] = None + + # Output handling + self.output_queue = Queue() + self.complete_output = [] + self.output_thread = None + self.process_completed = threading.Event() + + def run(self) -> None: + """ + Run the task as a subprocess and start capturing its output. + + Raises: + RuntimeError: If the task has already been started + subprocess.SubprocessError: If the subprocess fails to start + """ + if self.started: + raise RuntimeError("Task has already been started") + + self.started = True + + command_str = " ".join(self.command) + self.logger.info(f"Running command:\n{command_str}") + + try: + # Start the process + self.process = subprocess.Popen( + self.command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True, + ) + + self.pid = self.process.pid + self.logger.info(f"Started process with PID: {self.pid}") + + # Start a thread to read the output + self.output_thread = threading.Thread( + target=self._read_output, + args=(self.process, self.output_queue, self.process_completed), + ) + self.output_thread.daemon = True + self.output_thread.start() + + except Exception as e: + self.error = e + self.finished = True + self.logger.error(f"Failed to start process: {e}") + raise + + def _read_output( + self, process: subprocess.Popen, queue: Queue, completed_event: threading.Event + ) -> None: + """ + Read output from process and put it in the queue. + + Args: + process: The subprocess to read from + queue: Queue to store output lines + completed_event: Event to signal when process is complete + """ + for line in iter(process.stdout.readline, ""): + if line: + line = line.strip() + self.complete_output.append(line) + queue.put(line) + + # Process has completed + self.return_code = process.wait() + self.finished = True + queue.put("") # Empty string signals end of output + self.logger.info(f"Process completed with return code: {self.return_code}") + completed_event.set() + + def join(self) -> None: + """Wait for the task to complete.""" + if not self.started: + raise RuntimeError("Task has not been started") + + self.process_completed.wait() + if self.output_thread: + self.output_thread.join() + + def getline(self) -> str: + """ + Get a single line of output from the task. + + Returns: + A line of output, or empty string if task has finished and queue is empty + """ + if not self.started: + raise RuntimeError("Task has not been started") + + if self.finished and self.output_queue.empty(): + return "" + + try: + return self.output_queue.get(timeout=0.1) + except Empty: + return "" if self.finished else self.getline() + + def iter_output(self) -> Iterator[str]: + """ + Iterate over the output of the task. + + Yields: + Lines of output from the task + """ + if not self.started: + raise RuntimeError("Task has not been started") + + while not self.process_completed.is_set() or not self.output_queue.empty(): + try: + line = self.output_queue.get(timeout=0.1) + if line: # Skip empty string which signals end + yield line + except Empty: + pass + time.sleep(0.01) + + def get_output(self) -> str: + """ + Get the complete output of the task. + + Returns: + The complete output as a string + + Raises: + RuntimeError: If the task has not finished + """ + if not self.finished: + self.join() + + return "\n".join(self.complete_output) + + def is_finished(self) -> bool: + """Check if the task has finished.""" + return self.finished + + def get_pid(self) -> Optional[int]: + """Get the PID of the subprocess.""" + return self.pid + + def get_return_code(self) -> Optional[int]: + """Get the return code of the subprocess.""" + return self.return_code From a60f94dd7714d0c7fe2dc43aba04f09e8e27ee06 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 7 Apr 2025 08:54:09 +0000 Subject: [PATCH 4/6] [deps] Change slugify version requirement --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8ae078d..12fb154 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "pydantic_settings~=2.8.1", "python-dotenv~=1.0", "gradio~=5.23", - "python-slugify>=8.0.4", + "python-slugify~=8.0", ] [project.optional-dependencies] From 37351e64f0537bca5038d0c3fd553e42b2e80968 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 7 Apr 2025 09:00:48 +0000 Subject: [PATCH 5/6] [chore] Remove misleading annotations --- gradio/configs/t2i.yaml | 2 -- gradio/configs/t2v.yaml | 2 -- 2 files changed, 4 deletions(-) diff --git a/gradio/configs/t2i.yaml b/gradio/configs/t2i.yaml index 77823e7..e8efe40 100644 --- a/gradio/configs/t2i.yaml +++ b/gradio/configs/t2i.yaml @@ -1,5 +1,3 @@ -# CogView4 Configuration - # Model Configuration model: model_path: "THUDM/CogView4-6B" # Path to the pre-trained model diff --git a/gradio/configs/t2v.yaml b/gradio/configs/t2v.yaml index a8e491c..532e522 100644 --- a/gradio/configs/t2v.yaml +++ b/gradio/configs/t2v.yaml @@ -1,5 +1,3 @@ -# CogView4 Configuration - # Model Configuration model: model_path: "THUDM/CogVideoX1.5-5B" # Path to the pre-trained model From c4121acd4a78a10ed213a89ca2e06e15e97b7344 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 7 Apr 2025 09:03:42 +0000 Subject: [PATCH 6/6] [chore] Reuse `get_resolutions` --- gradio/gradio_infer_demo.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/gradio/gradio_infer_demo.py b/gradio/gradio_infer_demo.py index a4a3c3d..5b0b225 100644 --- a/gradio/gradio_infer_demo.py +++ b/gradio/gradio_infer_demo.py @@ -9,6 +9,7 @@ get_logger, get_lora_checkpoint_dirs, get_lora_checkpoint_rootdir, + get_resolutions, ) import gradio as gr @@ -75,22 +76,11 @@ def update_task(hf_model_id: str) -> Tuple[gr.Dropdown, gr.Component]: # Configure resolution dropdown based on task if task == GenerationMode.TextToImage: - resolution_list = [ - "512x512", - "512x768", - "512x1024", - "720x1280", - "768x768", - "1024x1024", - "1080x1920", - ] + resolution_list = get_resolutions(task) default_resolution = resolution_list[0] resolution_info = "Height x Width" elif task == GenerationMode.TextToVideo: - resolution_list = [ - "49x480x720", - "81x768x1360", - ] + resolution_list = get_resolutions(task) default_resolution = resolution_list[0] resolution_info = "Frames x Height x Width" else: