Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions flux_schnell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
from diffusers import (
    ControlNetModel,
    FluxControlNetModel,
    FluxControlNetPipeline,
    FluxPipeline,
)


def load_pipeline(model="black-forest-labs/FLUX.1-schnell"):
    """Build a Flux text-to-image pipeline in bfloat16 with CPU offload enabled.

    Args:
        model: Checkpoint identifier handed to ``FluxPipeline.from_pretrained``.

    Returns:
        The loaded ``FluxPipeline`` after ``enable_model_cpu_offload()`` has
        been called on it.
    """
    flux = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
    # Turn on model CPU offload before handing the pipeline back.
    flux.enable_model_cpu_offload()
    return flux


def load_controlnet(pipe, model="black-forest-labs/flux-controlnet-canny"):
    """Create a Flux ControlNet pipeline that reuses an existing pipeline's parts.

    Args:
        pipe: An already-loaded Flux pipeline; its ``components`` mapping is
            forwarded to the new pipeline so the weights are not loaded twice.
        model: ControlNet checkpoint identifier passed to ``from_pretrained``.
            NOTE(review): confirm this default repo id exists on the Hub.

    Returns:
        A ``FluxControlNetPipeline`` with model CPU offload enabled.
    """
    # FluxControlNetPipeline expects the Flux-specific ControlNet class; the
    # UNet-style ``ControlNetModel`` cannot load or run Flux ControlNet weights.
    controlnet = FluxControlNetModel.from_pretrained(
        model, torch_dtype=torch.bfloat16
    )
    cpipe = FluxControlNetPipeline(controlnet=controlnet, **pipe.components)
    cpipe.enable_model_cpu_offload()
    return cpipe


def generate_image(pipe, prompt, seed=0, steps=4):
    """Run ``pipe`` once for ``prompt`` and return the first resulting image.

    Args:
        pipe: Callable pipeline; invoked with the prompt plus sampling kwargs.
        prompt: Text prompt forwarded as the first positional argument.
        seed: Seed for the CPU ``torch.Generator`` passed to the pipeline.
        steps: Value forwarded as ``num_inference_steps``.

    Returns:
        ``images[0]`` of whatever object the pipeline call returns.
    """
    rng = torch.Generator("cpu")
    rng.manual_seed(seed)

    sampling_kwargs = {
        "guidance_scale": 0.0,
        "num_inference_steps": steps,
        "max_sequence_length": 256,
        "generator": rng,
    }
    result = pipe(prompt, **sampling_kwargs)
    return result.images[0]


if __name__ == "__main__":
    # Demo entry point: load the default pipeline, render one prompt and
    # write the result next to the script.
    generate_image(
        load_pipeline(),
        "A cat holding a sign that says hello world",
    ).save("flux-schnell.png")
432 changes: 124 additions & 308 deletions main.py

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Welcome to Simple Stable Diffusion Server, your go-to solution for AI-powered im
to an R2 bucket via the S3 API, but Google Cloud Storage remains supported.
- **Versatile Applications**: Perfect for AI art generation, style transfer, and image inpainting. Bring any SDXL/diffusers model.
- **Easy to Use**: Simple interface for generating images in Gradio locally and easy to use FastAPI docs/server for advanced users.
- **Prompt Utilities**: Helper functions for trimming and cleaning prompts live in `stable_diffusion_server/prompt_utils.py`.

For a hosted AI Art Generation experience, check out our [AI Art Generator and Search Engine](https://aiart-generator.art), which offers advanced features like video creation and 2K upscaled images.

Expand Down Expand Up @@ -52,9 +53,17 @@ Launch the user-friendly Gradio interface:
```
python gradio_ui.py
```
Go to
Go to
http://127.0.0.1:7860

### Flux Schnell Example
The server now uses the lightweight Flux Schnell model by default. You can quickly
test the model with the helper script:
```bash
python flux_schnell.py
```
This will generate `flux-schnell.png` using bf16 precision.


![gradio demo](gradioimg.png)

Expand Down
2 changes: 1 addition & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ tokenizers
--extra-index-url https://download.pytorch.org/whl/cu122
torch
tqdm
transformers
transformers==4.38.0
#triton
typing_extensions
urllib3
Expand Down
14 changes: 4 additions & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ click==8.1.7
# via
# -r requirements.in
# nltk
# sacremoses
# streamlit
# uvicorn
cmake==3.31.1
Expand Down Expand Up @@ -233,9 +232,7 @@ jmespath==1.0.1
# boto3
# botocore
joblib==1.4.2
# via
# nltk
# sacremoses
# via nltk
jsonschema==4.23.0
# via altair
jsonschema-specifications==2024.10.1
Expand Down Expand Up @@ -467,7 +464,6 @@ regex==2024.11.6
# -r requirements.in
# diffusers
# nltk
# sacremoses
# transformers
requests==2.32.3
# via
Expand All @@ -490,15 +486,14 @@ rsa==4.9
# via google-auth
s3transfer==0.13.0
# via boto3
sacremoses==0.1.1
# via transformers
safetensors==0.4.5
# via
# -r requirements.in
# accelerate
# diffusers
# optimum-quanto
# peft
# transformers
semantic-version==2.10.0
# via gradio
setuptools==75.6.0
Expand All @@ -525,7 +520,7 @@ sympy==1.13.3
# torch
tenacity==9.0.0
# via streamlit
tokenizers==0.21.0
tokenizers==0.15.2
# via
# -r requirements.in
# transformers
Expand All @@ -551,9 +546,8 @@ tqdm==4.67.1
# huggingface-hub
# nltk
# peft
# sacremoses
# transformers
transformers==4.17.0
transformers==4.38.0
# via
# -r requirements.in
# deepcache
Expand Down
17 changes: 17 additions & 0 deletions stable_diffusion_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Utilities for stable diffusion server."""

from .utils import log_time
from .prompt_utils import (
remove_stopwords,
shorten_too_long_text,
shorten_prompt_for_retry,
stopwords,
)

__all__ = [
"log_time",
"remove_stopwords",
"shorten_too_long_text",
"shorten_prompt_for_retry",
"stopwords",
]
36 changes: 36 additions & 0 deletions stable_diffusion_server/prompt_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import nltk

# Module-level English stopword list shared by the helpers in this module.
try:
    # Read the stopword corpus if it is already present locally.
    stopwords = nltk.corpus.stopwords.words("english")
except LookupError:  # pragma: no cover - external data
    # The lookup failed: download the "stopwords" data quietly, then retry.
    nltk.download("stopwords", quiet=True)
    stopwords = nltk.corpus.stopwords.words("english")


def remove_stopwords(prompt: str) -> str:
    """Return the prompt with stopwords removed.

    The prompt is split on whitespace and each word is compared to the
    stopword list case-insensitively, so a capitalised stopword such as
    ``"The"`` at the start of a prompt is dropped too. Surviving words keep
    their original casing and are re-joined with single spaces.
    """
    # Build a set once per call so membership does not rescan the list per word.
    blocked = set(stopwords)
    return " ".join(word for word in prompt.split() if word.lower() not in blocked)


def shorten_too_long_text(prompt: str, max_length: int = 200) -> str:
    """Trim prompts longer than ``max_length`` characters.

    Prompts that already fit are returned unchanged. Longer prompts first have
    their stopwords removed (case-insensitive match, whitespace collapsed to
    single spaces); if the result is still too long it is cut to the first
    ``max_length`` characters.

    Args:
        prompt: The text to shorten.
        max_length: Maximum number of characters to return. Defaults to 200.

    Returns:
        A string of at most ``max_length`` characters.
    """
    if len(prompt) <= max_length:
        return prompt

    blocked = set(stopwords)
    prompt = " ".join(w for w in prompt.split() if w.lower() not in blocked)
    # Slicing is a no-op when stopword removal was already enough.
    return prompt[:max_length]


def shorten_prompt_for_retry(prompt: str) -> str:
    """Remove stopwords and return roughly the first half of the words.

    Stopwords are matched case-insensitively. When at least one non-stopword
    remains, at least one word is kept, so a prompt with a single meaningful
    word does not turn into an empty retry prompt. An empty string is returned
    only when nothing but stopwords (or nothing at all) was supplied.
    """
    blocked = set(stopwords)
    tokens = [w for w in prompt.split() if w.lower() not in blocked]
    if not tokens:
        return ""
    # Integer halving would yield 0 for a single remaining word; keep one.
    keep = max(1, len(tokens) // 2)
    return " ".join(tokens[:keep])


__all__ = [
"stopwords",
"remove_stopwords",
"shorten_too_long_text",
"shorten_prompt_for_retry",
]
2 changes: 1 addition & 1 deletion stable_diffusion_server/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager
from datetime import datetime

from loguru import logger
from loguru import logger


@contextmanager
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_prompt_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import string
import pytest
from stable_diffusion_server import prompt_utils as sutils


def test_shorten_too_long_text():
    """A prompt far above the limit comes back at 200 characters or fewer."""
    words = [f"word{i}" for i in range(250)]
    result = sutils.shorten_too_long_text(" ".join(words))
    assert len(result) <= 200


def test_shorten_prompt_for_retry_removes_stopwords():
    """The retry prompt has at most half the words and contains no stopwords."""
    sentence = "the quick brown fox jumps over the lazy dog"
    # Join with spaces: plain string repetition would fuse "dog" and "the"
    # into a single bogus word at each seam.
    prompt = " ".join([sentence] * 3)
    shortened = sutils.shorten_prompt_for_retry(prompt)
    kept = shortened.split()
    # ensure the word count is reduced by at least half
    assert len(kept) <= len(prompt.split()) // 2
    assert not set(kept) & set(sutils.stopwords)


def test_remove_stopwords():
    """A leading stopword is stripped and the other words stay in order."""
    cleaned = sutils.remove_stopwords("the quick brown fox")
    assert cleaned == "quick brown fox"
15 changes: 15 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from stable_diffusion_server import utils
from loguru import logger


def test_log_time():
    """``log_time`` emits messages containing "start" and "end"."""
    captured = []
    # Collect formatted log messages through a temporary sink.
    handler_id = logger.add(captured.append, format="{message}")
    try:
        with utils.log_time("test"):
            pass
    finally:
        # Always detach the sink so other tests are unaffected.
        logger.remove(handler_id)

    for marker in ("start", "end"):
        assert any(marker in entry for entry in captured)
Loading