Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions lightllm/server/multimodal_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import librosa
import base64
import numpy as np
from typing import List
from typing import List, Tuple
from io import BytesIO
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from PIL import Image, ImageFile
from fastapi import Request
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.utils.error_utils import ClientDisconnected
Expand Down Expand Up @@ -146,9 +147,11 @@ async def preload(self, request: Request):
else:
raise ValueError(f"cannot read image which type is {self._type}!")

with Image.open(BytesIO(img_data)) as image:
self.image_w, self.image_h = image.size
image.verify() # verify后会失效
# Do pixel-level decoding verification in a thread pool to avoid blocking the event loop;
# Decoding is mainly done in the C libraries (libjpeg/libpng/libwebp), which releases the GIL,
# and multiple threads can achieve true parallelism.
loop = asyncio.get_running_loop()
self.image_w, self.image_h = await loop.run_in_executor(_IMAGE_VERIFY_POOL, _verify_image_bytes, img_data)

self._preload_data = img_data
return
Expand Down Expand Up @@ -220,3 +223,25 @@ def to_origin_dict(self):
ret["images"] = [i.to_origin_dict() for i in self.images]
ret["audios"] = [a.to_origin_dict() for a in self.audios]
return ret


# Shared worker pool for pixel-level image verification. Decoding happens in
# C libraries (libjpeg/libpng/libwebp) that release the GIL, so a small thread
# pool gives real parallelism without blocking the event loop.
_IMAGE_VERIFY_POOL = ThreadPoolExecutor(
    # Default kept as a string: os.getenv returns str for set variables, so
    # int() always parses the same type regardless of the env state.
    max_workers=int(os.getenv("LIGHTLLM_IMAGE_VERIFY_WORKERS", "4")),
    thread_name_prefix="img-verify",
)


def _verify_image_bytes(img_data: bytes) -> Tuple[int, int]:
    """Fully decode *img_data* to catch truncated or corrupted images.

    ``image.verify()`` only performs header-level checks and misses
    truncation; ``image.load()`` decodes the entire pixel data and raises
    ``OSError`` when the stream is incomplete.

    Returns:
        (width, height) of the image.
    """
    # Keep PIL strict: with LOAD_TRUNCATED_IMAGES disabled, load() raises
    # OSError on truncated data, so the frontend can reject the request
    # instead of crashing later in the encode/preprocess stage.
    ImageFile.LOAD_TRUNCATED_IMAGES = False

    with Image.open(BytesIO(img_data)) as img:
        width, height = img.size
        img.load()
        return width, height
128 changes: 128 additions & 0 deletions test/performance/bench_image_verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Benchmark: find the right LIGHTLLM_IMAGE_VERIFY_WORKERS value.

Methodology:
- Generate N independent JPEGs once (random pixels so libjpeg can't cheat).
- For each candidate pool size, create a FRESH ThreadPoolExecutor of that size,
submit all N decodes concurrently (no semaphore), measure wall time.
- This faithfully simulates production: at peak, many requests pile into
run_in_executor at once and the pool size is the real bottleneck.

This lets us compare different LIGHTLLM_IMAGE_VERIFY_WORKERS settings in one run.

Usage:
python test/performance/bench_image_verify.py
python test/performance/bench_image_verify.py --size 4096 --num 128 --pool_sizes 1,2,4,8,16,32,64
"""
import argparse
import asyncio
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import List

import numpy as np
from PIL import Image

from lightllm.server.multimodal_params import _verify_image_bytes


def make_big_jpeg(size: int, seed: int) -> bytes:
    """Encode a size x size random-noise JPEG (noise keeps decode cost realistic)."""
    noise = np.random.default_rng(seed).integers(0, 256, (size, size, 3), dtype=np.uint8)
    out = BytesIO()
    Image.fromarray(noise).save(out, format="JPEG", quality=85)
    return out.getvalue()


def bench_serial(images: List[bytes]) -> float:
    """Decode every image one after another; return total wall time in seconds."""
    start = time.perf_counter()
    for blob in images:
        _verify_image_bytes(blob)
    return time.perf_counter() - start


def bench_pool(images: List[bytes], pool_size: int) -> float:
    """Decode all *images* concurrently on a fresh pool of *pool_size* threads.

    Returns the wall time (seconds) to finish every decode. A fresh pool per
    call keeps runs independent so different pool sizes are comparable.
    """
    pool = ThreadPoolExecutor(max_workers=pool_size, thread_name_prefix=f"bench-{pool_size}")
    try:
        # Pre-warm so we don't time thread spawn-up. ThreadPoolExecutor spawns
        # threads lazily, and instant no-op tasks can all be serviced by one
        # already-idle thread; a Barrier makes each warm-up task block until
        # pool_size tasks run simultaneously, forcing all threads to exist.
        barrier = threading.Barrier(pool_size)
        warmers = [pool.submit(barrier.wait) for _ in range(pool_size)]
        for fut in warmers:
            fut.result()

        async def run():
            loop = asyncio.get_running_loop()
            futs = [loop.run_in_executor(pool, _verify_image_bytes, img) for img in images]
            await asyncio.gather(*futs)

        t0 = time.perf_counter()
        asyncio.run(run())
        return time.perf_counter() - t0
    finally:
        pool.shutdown(wait=True)


def main():
    """CLI entry point: sweep pool sizes and report throughput/scaling.

    Prints a serial baseline, then a table of (pool size, wall time,
    throughput, speedup, efficiency), and finally a recommended
    LIGHTLLM_IMAGE_VERIFY_WORKERS value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--size", type=int, default=2048, help="image edge length, e.g. 2048/4096")
    parser.add_argument("--num", type=int, default=64, help="total images to decode per run")
    parser.add_argument(
        "--pool_sizes",
        default="1,2,4,8,16,32,64",
        help="comma-separated pool sizes (LIGHTLLM_IMAGE_VERIFY_WORKERS candidates)",
    )
    parser.add_argument("--warmup", type=int, default=4)
    parser.add_argument("--repeat", type=int, default=2, help="repeats per pool size, takes the best")
    args = parser.parse_args()

    print(f"CPU count : {os.cpu_count()}")
    print(f"Image size : {args.size}x{args.size}")
    print(f"Images per run : {args.num}")
    print(f"Pool sizes to test : {args.pool_sizes}")
    print(f"Repeats per pool : {args.repeat} (best time wins)\n")

    print("Generating distinct test images ...")
    images = [make_big_jpeg(args.size, seed=i) for i in range(args.num)]
    avg_kb = sum(len(b) for b in images) / len(images) / 1024
    print(f" per-image encoded size ~ {avg_kb:.1f} KB\n")

    # Warmup libjpeg / page faults
    for _ in range(args.warmup):
        _verify_image_bytes(images[0])

    # Baseline: sequential decode on the main thread.
    serial_times = [bench_serial(images) for _ in range(args.repeat)]
    serial_t = min(serial_times)
    print(
        f"[serial] {args.num} images in {serial_t * 1000:.1f} ms "
        f"=> {args.num / serial_t:.1f} img/s, {serial_t / args.num * 1000:.2f} ms/img\n"
    )

    # Sweep pool size
    print("[threaded] — vary LIGHTLLM_IMAGE_VERIFY_WORKERS")
    print(f" {'pool':>6} | {'time(ms)':>10} | {'img/s':>8} | {'speedup':>8} | {'efficiency':>10}")
    print(f" {'-' * 6}-+-{'-' * 10}-+-{'-' * 8}-+-{'-' * 8}-+-{'-' * 10}")
    rows = []
    for p in [int(x) for x in args.pool_sizes.split(",")]:
        times = [bench_pool(images, p) for _ in range(args.repeat)]
        t = min(times)
        ips = args.num / t
        speedup = serial_t / t
        eff = speedup / p
        rows.append((p, t, ips, speedup, eff))
        print(f" {p:>6} | {t * 1000:>10.1f} | {ips:>8.1f} | {speedup:>7.2f}x | {eff * 100:>9.1f}%")

    # Best absolute throughput across the sweep (row = (pool, t, ips, speedup, eff)).
    best = max(rows, key=lambda r: r[3])
    # Knee = the LAST pool size whose efficiency is still >= 50%. The previous
    # code selected the FIRST row BELOW 50% — a pool size already past the
    # knee — so the recommendation itself was inefficient.
    knee = rows[-1]
    for i, row in enumerate(rows):
        if row[4] < 0.5:
            knee = rows[max(i - 1, 0)]
            break
    print(f"\nBest absolute throughput : pool={best[0]} ({best[2]:.1f} img/s, {best[3]:.2f}x)")
    print(f"Diminishing-returns knee : pool={knee[0]} (efficiency drops <50% beyond here)")
    print("\nHints:")
    print(" - efficiency = speedup / pool_size. ~100% means perfect linear scaling.")
    print(" - You usually want the smallest pool size that still gets >80% of peak throughput,")
    print(" since extra threads only add scheduling + memory pressure.")
    print(f" - Recommended: export LIGHTLLM_IMAGE_VERIFY_WORKERS={knee[0]}")


# Script entry point: run the pool-size sweep when executed directly.
if __name__ == "__main__":
    main()
72 changes: 72 additions & 0 deletions test/test_api/test_image_verify_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""验证残缺图片在 OpenAI /v1/chat/completions 接口被前端拦截为 4xx。

启动 server:
python -m lightllm.server.api_server --port 8000 --model_dir <your_vlm> --tp 1

运行:
python test/test_api/test_image_verify_api.py
"""
Comment on lines +1 to +8
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the rest of the codebase and to make it more accessible to non-Chinese speakers, please consider translating the docstring and comments into English.

Suggested change
"""验证残缺图片在 OpenAI /v1/chat/completions 接口被前端拦截为 4xx
启动 server
python -m lightllm.server.api_server --port 8000 --model_dir <your_vlm> --tp 1
运行
python test/test_api/test_image_verify_api.py
"""
"""Verify that truncated images are rejected with a 4xx error by the frontend
at the OpenAI /v1/chat/completions endpoint.
To run this test:
1. Start the server:
python -m lightllm.server.api_server --port 8000 --model_dir <your_vlm> --tp 1
2. Run the test script:
python test/test_api/test_image_verify_api.py
"""

import argparse
import base64
import os
from io import BytesIO

import requests
from PIL import Image


def make_jpeg(w=512, h=512) -> bytes:
    """Return a valid solid-red JPEG of the given dimensions."""
    out = BytesIO()
    img = Image.new("RGB", (w, h), color=(255, 0, 0))
    img.save(out, format="JPEG", quality=85)
    return out.getvalue()


def truncate(data: bytes, ratio: float = 0.3) -> bytes:
    """Drop the trailing *ratio* fraction of *data* to simulate a truncated upload."""
    keep = int(len(data) * (1 - ratio))
    return data[:keep]


def data_url(img_bytes: bytes) -> str:
    """Encode raw JPEG bytes as an OpenAI-style base64 data URL."""
    b64 = base64.b64encode(img_bytes).decode("ascii")
    return f"data:image/jpeg;base64,{b64}"


def call(url: str, model: str, img_bytes: bytes):
    """POST one image+text chat-completion request; return the HTTP response."""
    message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": data_url(img_bytes)}},
            {"type": "text", "text": "Describe this image."},
        ],
    }
    payload = {
        "model": model,
        "messages": [message],
        "max_tokens": 16,
        "temperature": 0.0,
    }
    return requests.post(f"{url}/v1/chat/completions", json=payload, timeout=30)


def main():
    """Check that the server accepts a valid image and 4xx-rejects corrupt ones."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default="http://127.0.0.1:8000")
    parser.add_argument("--model", default="your_model_name")
    args = parser.parse_args()

    intact = make_jpeg()
    damaged = truncate(make_jpeg(1024, 1024), 0.3)
    noise = os.urandom(4096)

    # (label, image bytes, expected HTTP status)
    cases = [
        ("intact JPEG", intact, 200),
        ("truncated JPEG", damaged, 400),
        ("garbage bytes", noise, 400),
    ]

    for name, img, expected in cases:
        resp = call(args.url, args.model, img)
        status = resp.status_code
        passed = status == expected
        print(f"[{'OK' if passed else 'FAIL'}] {name:18s} -> {status} (expected {expected})")
        if not passed:
            print(f" body: {resp.text[:200]}")
            raise SystemExit(1)


# Script entry point: run the image-verification API checks when executed directly.
if __name__ == "__main__":
    main()
Loading