|
| 1 | +import time |
| 2 | +from typing import List, Tuple |
| 3 | +from datetime import datetime |
| 4 | +from llama_cpp import Llama |
| 5 | +from pydantic import BaseModel, Field, field_validator |
| 6 | +from rich.console import Console |
| 7 | +from rich.panel import Panel |
| 8 | +from rich.progress import track |
1 | 9 | import typer |
2 | | -import subprocess |
| 10 | +import requests |
3 | 11 | import json |
4 | 12 |
|
5 | | -def benchmark(model: str, output_format: str = "md"): |
6 | | - """ |
7 | | - Runs Llama-Bench with specified parameters. |
8 | | - """ |
9 | | - typer.echo(f"⏳ Running Llama-Bench for model: {model}") |
| 13 | +console = Console() |
10 | 14 |
|
11 | | - command = [ |
12 | | - "./llama-bench", |
13 | | - "-m", model, |
14 | | - "-o", output_format |
15 | | - ] |
class Message(BaseModel):
    """A single chat message exchanged with the model."""

    # Speaker of the message, e.g. "assistant" or "system".
    role: str
    # Raw text content of the message.
    content: str
16 | 18 |
|
class LlamaResponse(BaseModel):
    """Normalized benchmark result for one model invocation."""

    # Name of the model that produced the response.
    model: str
    # Timestamp at which this result object was created.
    created_at: datetime
    # The generated message (role + content).
    message: Message
    # Whether generation completed.
    done: bool
    # Total wall-clock time in seconds (load + evaluation).
    total_duration: float
    # Model load time in seconds; defaults to 0.0 for remote servers.
    load_duration: float = 0.0
    # NOTE(review): populated upstream via a whitespace word count, not true tokens.
    eval_count: int
    # Generation time in seconds.
    eval_duration: float
| 28 | + |
def load_model(model_path: str) -> Tuple[Llama, float]:
    """Load a llama.cpp model from disk and measure how long it takes.

    Args:
        model_path: Filesystem path to the model file.

    Returns:
        A tuple of (loaded ``Llama`` instance, load duration in seconds).
    """
    console.print(Panel.fit(f"[cyan]Loading model: {model_path}[/]", title="[bold magenta]Solo Server[/]"))
    # perf_counter is monotonic; time.time() can jump (e.g. NTP sync)
    # and corrupt the measured duration.
    start_time = time.perf_counter()
    model = Llama(model_path=model_path)
    load_duration = time.perf_counter() - start_time
    return model, load_duration
| 35 | + |
def api_response(model: str, prompt: str, url: str, server_type: str = None, timeout: float = 120.0) -> dict:
    """Send a single completion request to a local inference server.

    Args:
        model: Model name to request.
        prompt: Prompt text to complete.
        url: Full endpoint URL of the server.
        server_type: Server flavor; "ollama" triggers ollama-specific payload tweaks.
        timeout: Request timeout in seconds. Fixes the original hang risk:
            ``requests`` has no default timeout, so a stalled server would
            block the benchmark forever.

    Returns:
        The decoded JSON response, guaranteed to contain "eval_duration"
        (wall-clock seconds as a fallback), or ``{"error": ...}`` on failure.
    """
    payload = {
        "model": model,
        "prompt": prompt,
    }

    if server_type == "ollama":
        # Ollama expects lowercase model names; disable streaming so one
        # response object carries the full completion and timing stats.
        payload["model"] = model.lower()
        payload["stream"] = False
    headers = {"Content-Type": "application/json"}
    # Monotonic clock: immune to system clock adjustments mid-request.
    start_time = time.perf_counter()
    try:
        # json= serializes the payload and is equivalent to the original
        # data=json.dumps(payload) with the explicit JSON content type.
        response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        data = response.json()
        # Add eval_duration if not present
        if "eval_duration" not in data:
            data["eval_duration"] = time.perf_counter() - start_time
        return data
    except requests.exceptions.RequestException as e:
        return {"error": str(e)}
| 57 | + |
def run_benchmark(server_type: str, model: object, model_name: str, prompt: str, load_duration: float) -> LlamaResponse:
    """Run one prompt against the chosen backend and normalize the result.

    Args:
        server_type: One of "llama.cpp", "ollama", or "vllm".
        model: Loaded ``Llama`` instance (llama.cpp only; unused otherwise).
        model_name: Model identifier to report / send to the server.
        prompt: Prompt text to benchmark.
        load_duration: Model load time in seconds (llama.cpp path;
            overwritten from the server response on the ollama path).

    Returns:
        A ``LlamaResponse`` with all timings in seconds.
    """
    content = ""
    # Default so an error response still produces a well-formed result.
    eval_duration = 0.0
    if server_type == "llama.cpp":
        start_time = time.time()
        response = model(prompt, stop=["\n"], echo=False)
        eval_duration = time.time() - start_time
        content = response["choices"][0]["text"]
    else:
        url = "http://localhost:11434/api/generate" if server_type == "ollama" else "http://localhost:8000/v1/completions"
        response = api_response(model_name, prompt, url, server_type)

        if server_type == "vllm":
            # Guard against error responses ({"error": ...}) which carry no
            # "choices" key — the original code raised KeyError here.
            choices = response.get("choices")
            if choices:
                first = choices[0]
                # Chat-style responses nest text under "message"; plain
                # completion responses expose it as "text".
                if "message" in first:
                    content = first["message"]["content"]
                else:
                    content = first.get("text", "")
            eval_duration = response.get("eval_duration", 0.0)
        else:
            content = response.get("response", "")
            load_duration = response.get("load_duration", 0.0) * 1e-9  # Convert nanoseconds to seconds
            eval_duration = response.get("eval_duration", 0.0) * 1e-9  # Convert nanoseconds to seconds

    message = Message(role="assistant", content=content)

    return LlamaResponse(
        model=model_name,
        created_at=datetime.now(),
        message=message,
        done=True,
        load_duration=load_duration,
        total_duration=load_duration + eval_duration,
        # NOTE(review): whitespace word count used as a token-count proxy.
        eval_count=len(content.split()),
        eval_duration=eval_duration,
    )
| 92 | + |
def inference_stats(model_response: LlamaResponse):
    """Render throughput and timing stats for one benchmark run."""

    def _rate(tokens: int, seconds: float) -> float:
        # Avoid ZeroDivisionError for runs that recorded no duration.
        return tokens / seconds if seconds else 0.0

    response_ts = _rate(model_response.eval_count, model_response.eval_duration)
    total_ts = _rate(model_response.eval_count, model_response.total_duration)

    report_lines = [
        f"[bold magenta]{model_response.model}[/]",
        f"[green]Response:[/] {response_ts:.2f} tokens/s",
        f"[blue]Total:[/] {total_ts:.2f} tokens/s",
        "",
        "[yellow]Stats:[/]",
        f" - Response tokens: {model_response.eval_count}",
        f" - Model load time: {model_response.load_duration:.2f}s",
        f" - Response time: {model_response.eval_duration:.2f}s",
        f" - Total time: {model_response.total_duration:.2f}s",
    ]
    console.print(Panel.fit("\n".join(report_lines), title="[bold cyan]Benchmark Results[/]"))
| 111 | + |
def average_stats(responses: List[LlamaResponse]):
    """Summarize a set of runs as one synthetic averaged result and print it."""
    count = len(responses)
    if count == 0:
        console.print("[red]No stats to average.[/]")
        return

    def _mean(values: List[float]) -> float:
        return sum(values) / count

    summary = LlamaResponse(
        model=responses[0].model,
        created_at=datetime.now(),
        message=Message(role="system", content=f"Average stats across {count} runs"),
        done=True,
        total_duration=_mean([r.total_duration for r in responses]),
        load_duration=_mean([r.load_duration for r in responses]),
        # Integer division keeps the token count a whole number.
        eval_count=sum(r.eval_count for r in responses) // count,
        eval_duration=_mean([r.eval_duration for r in responses]),
    )
    inference_stats(summary)
| 128 | + |
def benchmark(
    server_type: str = typer.Option(None, "-s", help="Type of server (e.g., ollama, vllm, llama.cpp)."),
    model_name: str = typer.Option(None, "-m", help="Name of the model."),
    prompts: List[str] = typer.Option(["Why is the sky blue?", "Write a report on the financials of Apple Inc.",
                                       "Tell me about San Francisco"], "-p", help="List of prompts to use for benchmarking."),
):
    """Benchmark a local inference server across a list of prompts."""
    # Fall back to interactive prompts for anything not given on the CLI.
    server_type = server_type or typer.prompt("Enter server type (ollama, vllm, llama.cpp)")
    model_name = model_name or typer.prompt("Enter model name")

    console.print(f"\n[bold cyan]Starting Solo Server Benchmark for {server_type} with model {model_name}...[/]")

    # Only llama.cpp runs in-process and needs an up-front model load;
    # remote servers report their own load time per response.
    model, load_duration = None, 0.0
    if server_type == "llama.cpp":
        model, load_duration = load_model(model_name)

    results: List[LlamaResponse] = []
    for current_prompt in track(prompts, description="[cyan]Running benchmarks..."):
        run = run_benchmark(server_type, model, model_name, current_prompt, load_duration)
        results.append(run)
        inference_stats(run)

    average_stats(results)
0 commit comments