From ee8fcfbc1b38d1ef5f4a023bffeef50043566918 Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Mon, 7 Jul 2025 20:00:23 +1200 Subject: [PATCH] Add package init and log_time test --- flux_schnell.py | 35 ++ main.py | 432 +++++++----------------- readme.md | 11 +- requirements.in | 2 +- requirements.txt | 14 +- stable_diffusion_server/__init__.py | 17 + stable_diffusion_server/prompt_utils.py | 36 ++ stable_diffusion_server/utils.py | 2 +- tests/unit/test_prompt_utils.py | 23 ++ tests/unit/test_utils.py | 15 + 10 files changed, 266 insertions(+), 321 deletions(-) create mode 100644 flux_schnell.py create mode 100644 stable_diffusion_server/__init__.py create mode 100644 stable_diffusion_server/prompt_utils.py create mode 100644 tests/unit/test_prompt_utils.py create mode 100644 tests/unit/test_utils.py diff --git a/flux_schnell.py b/flux_schnell.py new file mode 100644 index 0000000..2442ef9 --- /dev/null +++ b/flux_schnell.py @@ -0,0 +1,35 @@ +import torch +from diffusers import FluxPipeline, FluxControlNetPipeline, ControlNetModel + + +def load_pipeline(model="black-forest-labs/FLUX.1-schnell"): + """Load the Flux pipeline with optional controlnet.""" + pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + return pipe + + +def load_controlnet(pipe, model="black-forest-labs/flux-controlnet-canny"): + controlnet = ControlNetModel.from_pretrained(model, torch_dtype=torch.bfloat16) + cpipe = FluxControlNetPipeline(controlnet=controlnet, **pipe.components) + cpipe.enable_model_cpu_offload() + return cpipe + + +def generate_image(pipe, prompt, seed=0, steps=4): + generator = torch.Generator("cpu").manual_seed(seed) + image = pipe( + prompt, + guidance_scale=0.0, + num_inference_steps=steps, + max_sequence_length=256, + generator=generator, + ).images[0] + return image + + +if __name__ == "__main__": + prompt = "A cat holding a sign that says hello world" + pipe = load_pipeline() + image = generate_image(pipe, prompt) + image.save("flux-schnell.png") diff --git a/main.py b/main.py index ad41056..f072743 100644 --- a/main.py +++ b/main.py @@ -30,6 +30,8 @@ ControlNetModel, StableDiffusionXLControlNetPipeline, AutoPipelineForImage2Image, + FluxPipeline, + FluxControlNetPipeline, ) from diffusers.utils import load_image from fastapi import FastAPI @@ -45,6 +47,12 @@ from stable_diffusion_server.bumpy_detection import detect_too_bumpy from stable_diffusion_server.image_processing import process_image_for_stable_diffusion from stable_diffusion_server.utils import log_time +from stable_diffusion_server.prompt_utils import ( + shorten_too_long_text, + shorten_prompt_for_retry, + remove_stopwords, + stopwords, +) try: import pillow_avif @@ -89,6 +97,24 @@ pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl", adapter_name="lcm") pipe.set_adapters(["lcm"], adapter_weights=[1.0]) +# Load Flux Schnell pipeline for efficient text-to-image +flux_pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 +) +flux_pipe.enable_model_cpu_offload() + +try: + flux_controlnet = ControlNetModel.from_pretrained( + "black-forest-labs/flux-controlnet-canny", torch_dtype=torch.bfloat16 + ) + flux_controlnetpipe = FluxControlNetPipeline( + controlnet=flux_controlnet, **flux_pipe.components + ) + flux_controlnetpipe.enable_model_cpu_offload() +except Exception as e: + logger.error(f"Failed to load Flux ControlNet: {e}") + flux_controlnetpipe = None + # quantizing from optimum.quanto import freeze, qfloat8, quantize @@ -426,7 +452,6 @@ async def healthz(): return {"status": "ok"} -stopwords = nltk.corpus.stopwords.words("english") negative = "3 or 4 ears, never BUT ONE EAR, blurry, unclear, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers, mangled teeth, weird teeth, poorly drawn eyes, blurry eyes, tan skin, oversaturated, teeth, poorly drawn, ugly, closed eyes, 3D, weird neck, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, extra limbs, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, text, logo, wordmark, writing, signature, blurry, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers, Removed From Image Removed From Image flowers, Deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, long body, ((((mutated hands and fingers)))), cartoon, 3d ((disfigured)), ((bad art)), ((deformed)), ((extra limbs)), ((dose up)), ((b&w)), Wierd colors, blurry, (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), (poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), (extra limbs)), cloned face, (((disfigured))), out of frame ugly, extra limbs (bad anatomy), gross proportions (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck))), Photoshop, videogame, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured deformed cross-eye, ((body out of )), blurry, bad art, bad anatomy, 3d render, two faces, duplicate, coppy, multi, two, disfigured, kitsch, ugly, oversaturated, grain, low-res, Deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, ugly, disgusting, poorly drawn, childish, mutilated, mangled, old ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draf, blurry, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers" negative2 = "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers" extra_pipe_args = { @@ -632,11 +657,9 @@ def style_transfer_image_from_prompt( retries=3, ): prompt = shorten_too_long_text(prompt) - # image = pipe(guidance_scale=7,prompt=prompt).images[0] if not is_defined(input_pil): input_pil = load_image(image_url).convert("RGB") - # resize to nice size input_pil = process_image_for_stable_diffusion(input_pil) canny_image = None if canny: @@ -646,99 +669,40 @@ def style_transfer_image_from_prompt( in_image = in_image[:, :, None] in_image = np.concatenate([in_image, in_image, in_image], axis=2) canny_image = Image.fromarray(in_image) - # reset seed to be more deterministic? set_seed(42) - try: - if canny: - # generate image - image = controlnetpipe( - prompt, - controlnet_conditioning_scale=controlnet_conditioning_scale, - image=canny_image, - num_inference_steps=n_steps, - **extra_pipe_args, - ).images[0] - else: - image = img2img( - prompt=prompt, - image=input_pil, - num_inference_steps=n_steps, - strength=strength, - **extra_pipe_args, - ).images[0] - except Exception as err: - # try rm stopwords + half the prompt - # todo try prompt permutations - logger.error(err) - logger.info(f"trying to shorten prompt of length {len(prompt)}") - - prompt = " ".join((word for word in prompt if word not in stopwords)) - prompts = prompt.split() - - prompt = " ".join(prompts[: len(prompts) // 2]) - logger.info(f"shortened prompt to: {len(prompt)}") - image = None - if prompt: - try: - if canny: - # generate image - image = controlnetpipe( - prompt, - controlnet_conditioning_scale=controlnet_conditioning_scale, - image=canny_image, - num_inference_steps=n_steps, - **extra_pipe_args, - ).images[0] - else: - image = img2img( - prompt=prompt, - image=input_pil, - num_inference_steps=n_steps, - strength=strength, - **extra_pipe_args, - ).images[0] - except Exception as err: - # logger.info("trying to permute prompt") - # # try two swaps of the prompt/permutations - # prompt = prompt.split() - # prompt = ' '.join(permutations(prompt, 2).__next__()) - logger.info(f"trying to shorten prompt of length {len(prompt)}") - - prompt = " ".join((word for word in prompt if word not in stopwords)) - prompts = prompt.split() - - prompt = " ".join(prompts[: len(prompts) // 2]) - logger.info(f"shortened prompt to: {len(prompt)}") - - try: - if canny: - # generate image - image = controlnetpipe( - prompt, - controlnet_conditioning_scale=controlnet_conditioning_scale, - image=canny_image, - num_inference_steps=n_steps, - **extra_pipe_args, - ).images[0] - else: - image = img2img( - prompt=prompt, - image=input_pil, - num_inference_steps=n_steps, - strength=strength, - **extra_pipe_args, - ).images[0] - except Exception as inner_error: - # just error out - traceback.print_exc() - raise inner_error - # logger.info("restarting server to fix cuda issues (device side asserts)") - # todo fix device side asserts instead of restart to fix - # todo only restart the correct gunicorn - # this could be really annoying if your running other gunicorns on your machine which also get restarted - # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`") - # os.system("kill -1 `pgrep gunicorn`") + generator = torch.Generator("cpu").manual_seed(0) + for attempt in range(retries + 1): + try: + if canny and flux_controlnetpipe: + image = flux_controlnetpipe( + prompt=prompt, + image=canny_image, + num_inference_steps=n_steps, + guidance_scale=0.0, + generator=generator, + max_sequence_length=256, + ).images[0] + else: + image = flux_pipe( + prompt=prompt, + width=input_pil.width, + height=input_pil.height, + guidance_scale=0.0, + num_inference_steps=n_steps, + generator=generator, + max_sequence_length=256, + ).images[0] + break + except Exception as err: + if attempt >= retries: + raise + logger.warning( + f"Flux style transfer failed on attempt {attempt + 1}/{retries}: {err}" + ) + prompt = remove_stopwords(prompt) if attempt == 0 else shorten_prompt_for_retry(prompt) + if not prompt: + raise err # todo refine # if image != None and use_refiner: # image = refiner( @@ -791,155 +755,68 @@ def style_transfer_image_from_prompt( return image_to_bytes(image) -# multiprocessing.set_start_method('spawn', True) -# processes_pool = Pool(1) # cant do too much at once or OOM errors happen -# def create_image_from_prompt_sync(prompt): -# """have to call this sync to avoid OOM errors""" -# return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait() - - def create_image_from_prompt( - prompt, width, height, n_steps=5, extra_args={}, retries=3 + prompt, width, height, n_steps=5, extra_args=None, retries=3 ): - # round width and height down to multiple of 64 + """Generate an image using the Flux Schnell pipeline with retries.""" + if extra_args is None: + extra_args = {} + block_width = width - (width % 64) block_height = height - (height % 64) prompt = shorten_too_long_text(prompt) - extra_total_args = {**extra_pipe_args, **extra_args} - # image = pipe(guidance_scale=7,prompt=prompt).images[0] - try: - image = pipe( - prompt=prompt, - # guidance_scale=7, - width=block_width, - height=block_height, - # denoising_end=high_noise_frac, - output_type="latent" if use_refiner else "pil", - # height=512, - # width=512, - num_inference_steps=n_steps, - **extra_total_args, - ).images[0] - except Exception as e: - # try rm stopwords + half the prompt - # todo try prompt permutations - logger.info(f"trying to shorten prompt of length {len(prompt)}") - - prompt = " ".join((word for word in prompt if word not in stopwords)) - prompts = prompt.split() - - prompt = " ".join(prompts[: len(prompts) // 2]) - logger.info(f"shortened prompt to: {len(prompt)}") - image = None - if prompt: - try: - image = pipe( - prompt=prompt, - # guidance_scale=7, - negative_prompt=negative, - width=block_width, - height=block_height, - # denoising_end=high_noise_frac, - output_type="latent" if use_refiner else "pil", - # height=512, - # width=512, - num_inference_steps=n_steps, - **extra_total_args, - ).images[0] - except Exception as e: - # logger.info("trying to permute prompt") - # # try two swaps of the prompt/permutations - # prompt = prompt.split() - # prompt = ' '.join(permutations(prompt, 2).__next__()) - logger.info(f"trying to shorten prompt of length {len(prompt)}") - - prompt = " ".join((word for word in prompt if word not in stopwords)) - prompts = prompt.split() - - prompt = " ".join(prompts[: len(prompts) // 2]) - logger.info(f"shortened prompt to: {len(prompt)}") - - try: - image = pipe( - prompt=prompt, - # guidance_scale=7, - negative_prompt=negative, - width=block_width, - height=block_height, - # denoising_end=high_noise_frac, - output_type=( - "latent" if use_refiner else "pil" - ), # dont need latent yet - we refine the image at full res - # height=512, - # width=512, - num_inference_steps=n_steps, - ).images[0] - except Exception as e: - # just error out - traceback.print_exc() - raise e - # logger.info("restarting server to fix cuda issues (device side asserts)") - # todo fix device side asserts instead of restart to fix - # todo only restart the correct gunicorn - # this could be really annoying if your running other gunicorns on your machine which also get restarted - # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`") - # os.system("kill -1 `pgrep gunicorn`") - # todo refine - if image != None and use_refiner: - # todo depend on q length? - # refiner.set_adapters(["lcm"], adapter_weights=[0]) # turn lcm off temporarily - image = refiner( - prompt=prompt, - num_inference_steps=8, - # guidance_scale=7, - # width=block_width, - # height=block_height, - # num_inference_steps=n_steps, # default - # denoising_start=high_noise_frac, - image=image, - ).images[0] - # pipe.set_adapters(["lcm"], adapter_weights=[1.0]) # turn lcm back on + generator = torch.Generator("cpu").manual_seed(extra_args.get("seed", 0)) + + for attempt in range(retries + 1): + try: + image = flux_pipe( + prompt=prompt, + width=block_width, + height=block_height, + guidance_scale=0.0, + num_inference_steps=n_steps, + generator=generator, + max_sequence_length=256, + ).images[0] + break + except Exception as err: # pragma: no cover - hardware/oom errors + if attempt >= retries: + raise + logger.warning( + f"Flux generation failed on attempt {attempt + 1}/{retries}: {err}" + ) + if attempt == 0: + prompt = remove_stopwords(prompt) + else: + prompt = shorten_prompt_for_retry(prompt) + if not prompt: + raise err + if width != block_width or height != block_height: - # resize to original size width/height - # find aspect ratio to scale up to that covers the original img input width/height scale_up_ratio = max(width / block_width, height / block_height) image = image.resize( - ( - math.ceil(block_width * scale_up_ratio), - math.ceil(height * scale_up_ratio), - ) + (math.ceil(block_width * scale_up_ratio), math.ceil(height * scale_up_ratio)) ) - # crop image to original size image = image.crop((0, 0, width, height)) - # try: - # # gc.collect() - # torch.cuda.empty_cache() - # except Exception as e: - # traceback.print_exc() - # logger.info("restarting server to fix cuda issues (device side asserts)") - # # todo fix device side asserts instead of restart to fix - # # todo only restart the correct gunicorn - # # this could be really annoying if your running other gunicorns on your machine which also get restarted - # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`") - # os.system("kill -1 `pgrep gunicorn`") - # save as bytesio - - # touch progress.txt file - if we dont do this we get restarted by supervisor/other processes for reliability - with open("progress.txt", "w") as f: - current_time = datetime.now().strftime("%H:%M:%S") - f.write(f"{current_time}") if detect_too_bumpy(image): if retries <= 0: - raise Exception( - "image too bumpy, retrying failed" - ) # todo fix and just accept it? + raise Exception("image too bumpy, retrying failed") logger.info("image too bumpy, retrying once w different prompt detailed") return create_image_from_prompt( prompt + " detail", width, height, n_steps + 1, extra_args, retries - 1 ) + return image_to_bytes(image) +# multiprocessing.set_start_method('spawn', True) +# processes_pool = Pool(1) # cant do too much at once or OOM errors happen +# def create_image_from_prompt_sync(prompt): +# """have to call this sync to avoid OOM errors""" +# return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait() + + + def image_to_bytes(image): bs = BytesIO() @@ -963,7 +840,7 @@ def image_to_bytes(image): return bio -def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str): +def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str, retries=3): prompt = shorten_too_long_text(prompt) # image = pipe(guidance_scale=7,prompt=prompt).images[0] @@ -973,73 +850,28 @@ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str): num_inference_steps = 40 high_noise_frac = 0.7 - try: - image = inpaintpipe( - prompt=prompt, - # guidance_scale=7, - image=init_image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - denoising_start=high_noise_frac, - output_type="latent", - ).images[ - 0 - ] # normally uses 50 steps - except Exception as e: - # try rm stopwords + half the prompt - # todo try prompt permutations - logger.info(f"trying to shorten prompt of length {len(prompt)}") - - prompt = " ".join((word for word in prompt if word not in stopwords)) - prompts = prompt.split() - - prompt = " ".join(prompts[: len(prompts) // 2]) - logger.info(f"shortened prompt to: {len(prompt)}") - image = None - if prompt: - try: - image = pipe( - prompt=prompt, - image=init_image, - # guidance_scale=7, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - denoising_start=high_noise_frac, - output_type="latent", - ).images[0] - except Exception as e: - # logger.info("trying to permute prompt") - # # try two swaps of the prompt/permutations - # prompt = prompt.split() - # prompt = ' '.join(permutations(prompt, 2).__next__()) - logger.info(f"trying to shorten prompt of length {len(prompt)}") - - prompt = " ".join((word for word in prompt if word not in stopwords)) - prompts = prompt.split() - - prompt = " ".join(prompts[: len(prompts) // 2]) - logger.info(f"shortened prompt to: {len(prompt)}") - - try: - image = inpaintpipe( - prompt=prompt, - # guidance_scale=7, - image=init_image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - denoising_start=high_noise_frac, - output_type="latent", - ).images[0] - except Exception as e: - # just error out - traceback.print_exc() - raise e - # logger.info("restarting server to fix cuda issues (device side asserts)") - # todo fix device side asserts instead of restart to fix - # todo only restart the correct gunicorn - # this could be really annoying if your running other gunicorns on your machine which also get restarted - # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`") - # os.system("kill -1 `pgrep gunicorn`") + generator = torch.Generator("cpu").manual_seed(0) + for attempt in range(retries + 1): + try: + image = inpaintpipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + denoising_start=high_noise_frac, + output_type="latent", + ).images[0] + break + except Exception as e: + if attempt >= retries: + traceback.print_exc() + raise + logger.warning( + f"Inpainting failed on attempt {attempt + 1}/{retries}: {e}" + ) + prompt = remove_stopwords(prompt) if attempt == 0 else shorten_prompt_for_retry(prompt) + if not prompt: + raise e if image != None: image = inpaint_refiner( prompt=prompt, @@ -1067,19 +899,3 @@ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str): return image_to_bytes(image) -def shorten_too_long_text(prompt): - if len(prompt) > 200: - # remove stopwords - prompt = prompt.split() # todo also split hyphens - prompt = " ".join((word for word in prompt if word not in stopwords)) - if len(prompt) > 200: - prompt = prompt[:200] - return prompt - - -# image = pipe(guidance_scale=7,prompt=prompt).images[0] -# -# image.save("test.png") -# save all images -# for i, image in enumerate(images): -# image.save(f"{i}.png") diff --git a/readme.md b/readme.md index cb8a9b2..150c61c 100644 --- a/readme.md +++ b/readme.md @@ -11,6 +11,7 @@ Welcome to Simple Stable Diffusion Server, your go-to solution for AI-powered im to an R2 bucket via the S3 API, but Google Cloud Storage remains supported. - **Versatile Applications**: Perfect for AI art generation, style transfer, and image inpainting. Bring any SDXL/diffusers model. - **Easy to Use**: Simple interface for generating images in Gradio locally and easy to use FastAPI docs/server for advanced users. +- **Prompt Utilities**: Helper functions for trimming and cleaning prompts live in `stable_diffusion_server/prompt_utils.py`. For a hosted AI Art Generation experience, check out our [AI Art Generator and Search Engine](https://aiart-generator.art), which offers advanced features like video creation and 2K upscaled images. @@ -52,9 +53,17 @@ Launch the user-friendly Gradio interface: ``` python gradio_ui.py ``` -Go to +Go to http://127.0.0.1:7860 +### Flux Schnell Example +The server now uses the lightweight Flux Schnell model by default. You can quickly +test the model with the helper script: +```bash +python flux_schnell.py +``` +This will generate `flux-schnell.png` using bf16 precision. + ![gradio demo](gradioimg.png) diff --git a/requirements.in b/requirements.in index def82e9..6437b66 100644 --- a/requirements.in +++ b/requirements.in @@ -40,7 +40,7 @@ tokenizers --extra-index-url https://download.pytorch.org/whl/cu122 torch tqdm -transformers +transformers==4.38.0 #triton typing_extensions urllib3 diff --git a/requirements.txt b/requirements.txt index c1da90e..1ff99b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -57,7 +57,6 @@ click==8.1.7 # via # -r requirements.in # nltk - # sacremoses # streamlit # uvicorn cmake==3.31.1 @@ -233,9 +232,7 @@ jmespath==1.0.1 # boto3 # botocore joblib==1.4.2 - # via - # nltk - # sacremoses + # via nltk jsonschema==4.23.0 # via altair jsonschema-specifications==2024.10.1 @@ -467,7 +464,6 @@ regex==2024.11.6 # -r requirements.in # diffusers # nltk - # sacremoses # transformers requests==2.32.3 # via @@ -490,8 +486,6 @@ rsa==4.9 # via google-auth s3transfer==0.13.0 # via boto3 -sacremoses==0.1.1 - # via transformers safetensors==0.4.5 # via # -r requirements.in @@ -499,6 +493,7 @@ safetensors==0.4.5 # diffusers # optimum-quanto # peft + # transformers semantic-version==2.10.0 # via gradio setuptools==75.6.0 @@ -525,7 +520,7 @@ sympy==1.13.3 # torch tenacity==9.0.0 # via streamlit -tokenizers==0.21.0 +tokenizers==0.15.2 # via # -r requirements.in # transformers @@ -551,9 +546,8 @@ tqdm==4.67.1 # huggingface-hub # nltk # peft - # sacremoses # transformers -transformers==4.17.0 +transformers==4.38.0 # via # -r requirements.in # deepcache diff --git a/stable_diffusion_server/__init__.py b/stable_diffusion_server/__init__.py new file mode 100644 index 0000000..0aaac71 --- /dev/null +++ b/stable_diffusion_server/__init__.py @@ -0,0 +1,17 @@ +"""Utilities for stable diffusion server.""" + +from .utils import log_time +from .prompt_utils import ( + remove_stopwords, + shorten_too_long_text, + shorten_prompt_for_retry, + stopwords, +) + +__all__ = [ + "log_time", + "remove_stopwords", + "shorten_too_long_text", + "shorten_prompt_for_retry", + "stopwords", +] diff --git a/stable_diffusion_server/prompt_utils.py b/stable_diffusion_server/prompt_utils.py new file mode 100644 index 0000000..26992ca --- /dev/null +++ b/stable_diffusion_server/prompt_utils.py @@ -0,0 +1,36 @@ +import nltk + +try: + stopwords = nltk.corpus.stopwords.words("english") +except LookupError: # pragma: no cover - external data + nltk.download("stopwords", quiet=True) + stopwords = nltk.corpus.stopwords.words("english") + + +def remove_stopwords(prompt: str) -> str: + """Return the prompt without stopwords.""" + return " ".join(word for word in prompt.split() if word not in stopwords) + + +def shorten_too_long_text(prompt: str) -> str: + """Trim prompts longer than 200 characters.""" + if len(prompt) > 200: + tokens = [w for w in prompt.split() if w not in stopwords] + prompt = " ".join(tokens) + if len(prompt) > 200: + prompt = prompt[:200] + return prompt + + +def shorten_prompt_for_retry(prompt: str) -> str: + """Remove stopwords and return roughly half of the words for a retry.""" + tokens = [w for w in prompt.split() if w not in stopwords] + return " ".join(tokens[: len(tokens) // 2]) + + +__all__ = [ + "stopwords", + "remove_stopwords", + "shorten_too_long_text", + "shorten_prompt_for_retry", +] diff --git a/stable_diffusion_server/utils.py b/stable_diffusion_server/utils.py index 40e0430..376705b 100644 --- a/stable_diffusion_server/utils.py +++ b/stable_diffusion_server/utils.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from datetime import datetime -from loguru import logger +from loguru import logger @contextmanager diff --git a/tests/unit/test_prompt_utils.py b/tests/unit/test_prompt_utils.py new file mode 100644 index 0000000..3625b23 --- /dev/null +++ b/tests/unit/test_prompt_utils.py @@ -0,0 +1,23 @@ +import string +import pytest +from stable_diffusion_server import prompt_utils as sutils + + +def test_shorten_too_long_text(): + long_prompt = " ".join(["word" + str(i) for i in range(250)]) + shortened = sutils.shorten_too_long_text(long_prompt) + assert len(shortened) <= 200 + + +def test_shorten_prompt_for_retry_removes_stopwords(): + prompt = "the quick brown fox jumps over the lazy dog" * 3 + shortened = sutils.shorten_prompt_for_retry(prompt) + # ensure length reduced by at least half + assert len(shortened.split()) <= len(prompt.split()) // 2 + for stopword in sutils.stopwords: + assert stopword not in shortened.split() + + +def test_remove_stopwords(): + prompt = "the quick brown fox" + assert sutils.remove_stopwords(prompt) == "quick brown fox" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 0000000..914f3d2 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,15 @@ +from stable_diffusion_server import utils +from loguru import logger + + +def test_log_time(): + messages = [] + sink_id = logger.add(lambda m: messages.append(m), format="{message}") + try: + with utils.log_time("test"): + pass + finally: + logger.remove(sink_id) + + assert any("start" in m for m in messages) + assert any("end" in m for m in messages)