There is a hello_world PyWorker implementation under workers/hello_world. This PyWorker is
created for an LLM model server that runs on port 5001 and has two API endpoints:
- /generate: generates a full response to the prompt and sends it back as a single JSON response
- /generate_stream: streams a response one token at a time
Both of these endpoints take the same JSON payload:

{
    "prompt": String,
    "max_response_tokens": Number | null
}
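For illustration, a request body matching this contract can be assembled as follows. This is a minimal sketch; the `build_payload` helper and the prompt text are ours, not part of the model API:

```python
import json
from typing import Optional

# Illustrative helper (not part of the PyWorker codebase): serialize a
# request body matching the model API contract above.
def build_payload(prompt: str, max_response_tokens: Optional[int]) -> str:
    return json.dumps({"prompt": prompt, "max_response_tokens": max_response_tokens})

body = build_payload("Write a haiku about GPUs", 128)
# max_response_tokens may be null (None in Python):
unbounded = build_payload("Write a haiku about GPUs", None)
```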
We want the PyWorker to also expose two endpoints that correspond to the above endpoints.
All PyWorkers have four files:
.
└── workers
└── hello_world
├── __init__.py
├── data_types.py # contains data types representing model API endpoints
├── server.py # contains endpoint handlers
└── test_load.py # script for load testing
All of the classes follow strict type hinting. It is recommended that you type hint all of your functions.
This will allow your IDE or VSCode with pyright plugin to find any type errors in your implementation.
You can also install pyright with sudo npm install -g pyright and run pyright in the root of the project to find
any type errors.
data_types.py defines the structure of the data your model server expects (its API contract) and, critically, how PyWorker interprets that data for autoscaling purposes. You define Python data classes that mirror the JSON payloads your model's API uses.
These classes must inherit from lib.data_types.ApiPayload. This inheritance is not just for structure; it's how PyWorker knows how to:
- Parse Incoming Requests: Convert JSON from clients into usable Python objects.
- Calculate Workload: Determine the computational cost of a request.
- Generate Test Data: Create realistic inputs for benchmarking.
- Format Requests for the Model Server: Prepare data for the underlying model.
import dataclasses
import inspect
import random
from typing import Any, Dict, Optional

from transformers import OpenAIGPTTokenizer  # used to count tokens in a prompt
import nltk  # used to download a list of words for generating random benchmark prompts
from lib.data_types import ApiPayload, JsonDataException

nltk.download("words")
WORD_LIST = nltk.corpus.words.words()

# you can use any tokenizer that fits your LLM. `openai-gpt` is free to use and is a good fit for most LLMs
tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
@dataclasses.dataclass
class InputData(ApiPayload):
    prompt: str
    max_response_tokens: Optional[int]  # may be null in the API payload

    @classmethod
    def for_test(cls) -> "ApiPayload":
        """defines how to create a payload for load testing"""
        prompt = " ".join(random.choices(WORD_LIST, k=250))
        return cls(prompt=prompt, max_response_tokens=300)

    def generate_payload_json(self) -> Dict[str, Any]:
        """defines how to convert an ApiPayload to JSON that will be sent to the model API"""
        return dataclasses.asdict(self)

    def count_workload(self) -> float:
        """defines how to calculate the workload for a payload"""
        return len(tokenizer.tokenize(self.prompt))

    @classmethod
    def from_json_msg(cls, json_msg: Dict[str, Any]) -> "InputData":
        """
        defines how to transform JSON data into AuthData and a payload type;
        in this case `InputData`, defined above, represents the data sent to the model API.
        AuthData is data generated by the autoscaler in order to authenticate payloads.
        In this case, the transformation is simple and 1:1. That is not always the case;
        see comfyui's PyWorker for more complicated examples.
        """
        errors = {}
        for param in inspect.signature(cls).parameters:
            if param not in json_msg:
                errors[param] = "missing parameter"
        if errors:
            raise JsonDataException(errors)
        return cls(
            **{
                k: v
                for k, v in json_msg.items()
                if k in inspect.signature(cls).parameters
            }
        )

This section guides you through creating the core of your custom model API: the EndpointHandler. Think of EndpointHandler as the bridge between incoming requests from users and your underlying model. It's the key to making your model accessible and scalable.
Why use an EndpointHandler?
- Organized Request Handling: It provides a structured way to handle different types of requests (like generating text, generating images, or performing other model-specific tasks).
- Scalability: By separating request handling from the model itself, you can easily scale your API to handle many concurrent users.
- Flexibility: You can customize how requests are processed, validated, and transformed before being sent to your model.
- Standard Interface: It provides a consistent interface for interacting with your model, regardless of the underlying implementation.
For every model API endpoint you want to expose (e.g., /generate, /generate_stream), you'll implement an EndpointHandler. This class is responsible for receiving client requests, forwarding them to the model server, and returning responses. It achieves this through several key methods:
- Receiving and validating incoming requests (get_data_from_request): ensures the request contains the necessary data (authentication and payload) and is in the correct format. It's the entry point for all requests.
- Defining the endpoint (endpoint): specifies the URL endpoint on the model API server where requests will be sent (e.g., /generate).
- Specifying the payload type (payload_cls): indicates the specific ApiPayload class used for this endpoint, defining the structure of the request data.
- Creating benchmark payloads (make_benchmark_payload): creates payloads specifically for benchmarking the model's performance.
- Handling the model's response (generate_client_response): takes the response from the model API server and transforms it into the format expected by the client making the request to your PyWorker. This allows you to customize the output as needed.
The EndpointHandler class has several abstract functions that you must implement to define the behavior of your specific endpoints. Here, we'll implement two common endpoints: /generate (for synchronous requests) and /generate_stream (for streaming responses):
"""
AuthData is a dataclass that represents Authentication data sent from Autoscaler to client requesting a route.
When a user requests a route from autoscaler, see Vast's Serverless documentation for how routing and AuthData
work.
When a user receives a route for this PyWorker, they'll call PyWorkers API with the following JSON:
{
auth_data: AuthData,
payload : InputData # defined above
}
"""
import dataclasses
import logging
from typing import Any, Dict, Type, Union

from aiohttp import ClientResponse, web

from lib.data_types import EndpointHandler, JsonDataException
from lib.server import start_server

from .data_types import InputData

log = logging.getLogger(__name__)
# This class is the implementer for the '/generate' endpoint of the model API
@dataclasses.dataclass
class GenerateHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        # the model API endpoint
        return "/generate"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        """this function should just return the ApiPayload subclass used by this handler"""
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        """
        defines how to convert the `InputData` defined above to JSON data to be sent
        to the model API. This too is a simple dataclass -> JSON conversion, but it
        can be more complicated; see comfyui for an example
        """
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        """
        defines how to generate an InputData for benchmarking. This needs to be defined in only
        one EndpointHandler, the one passed to the backend as the benchmark handler. Here we use
        the .for_test() method on InputData. However, in some cases you might need to fine-tune
        the InputData used for benchmarking to closely resemble the average request users call
        the endpoint with, in order to get the best autoscaling performance
        """
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        """
        defines how to convert a model API response into a response to the PyWorker client
        """
        _ = client_request
        match model_response.status:
            case 200:
                log.debug("SUCCESS")
                data = await model_response.json()
                return web.json_response(data=data)
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)
We also implement GenerateStreamHandler for streaming responses. It is identical to GenerateHandler except for
the endpoint name and how the web response is created, since the response is streamed:
class GenerateStreamHandler(EndpointHandler[InputData]):
    @property
    def endpoint(self) -> str:
        return "/generate_stream"

    @classmethod
    def payload_cls(cls) -> Type[InputData]:
        return InputData

    def generate_payload_json(self, payload: InputData) -> Dict[str, Any]:
        return dataclasses.asdict(payload)

    def make_benchmark_payload(self) -> InputData:
        return InputData.for_test()

    async def generate_client_response(
        self, client_request: web.Request, model_response: ClientResponse
    ) -> Union[web.Response, web.StreamResponse]:
        match model_response.status:
            case 200:
                log.debug("Streaming response...")
                res = web.StreamResponse()
                res.content_type = "text/event-stream"
                await res.prepare(client_request)
                async for chunk in model_response.content:
                    await res.write(chunk)
                await res.write_eof()
                log.debug("Done streaming response")
                return res
            case code:
                log.debug("SENDING RESPONSE: ERROR: unknown code")
                return web.Response(status=code)
You can now instantiate a Backend and use it to handle requests.
import os

from lib.backend import Backend, LogAction

# the url and port of the model API
MODEL_SERVER_URL = "http://0.0.0.0:5001"
# This is the log line that is emitted once the server has started
MODEL_SERVER_START_LOG_MSG = "server has started"
MODEL_SERVER_ERROR_LOG_MSGS = [
"Exception: corrupted model file" # message in the logs indicating the unrecoverable error
]
backend = Backend(
    model_server_url=MODEL_SERVER_URL,
    # location of the model log file
    model_log_file=os.environ["MODEL_LOG"],
    # for model backends that can only handle one request at a time, be sure to set this to
    # False to let PyWorker handle queueing requests
    allow_parallel_requests=True,
    # give the backend an EndpointHandler instance that is used for benchmarking;
    # the number of benchmark runs and the number of words per random benchmark payload are given
    benchmark_handler=GenerateHandler(benchmark_runs=3, benchmark_words=256),
    # defines how to handle specific log messages. See the docstring of LogAction for details
    log_actions=[
        (LogAction.ModelLoaded, MODEL_SERVER_START_LOG_MSG),
        (LogAction.Info, '"message":"Download'),
        *[
            (LogAction.ModelError, error_msg)
            for error_msg in MODEL_SERVER_ERROR_LOG_MSGS
        ],
    ],
)
# this is a simple ping handler for PyWorker
async def handle_ping(_: web.Request):
    return web.Response(body="pong")

# this is a handler for forwarding a health check to the model API
async def handle_healthcheck(_: web.Request):
    healthcheck_res = await backend.session.get("/healthcheck")
    return web.Response(body=healthcheck_res.content, status=healthcheck_res.status)

routes = [
    web.post("/generate", backend.create_handler(GenerateHandler())),
    web.post("/generate_stream", backend.create_handler(GenerateStreamHandler())),
    web.get("/ping", handle_ping),
    web.get("/healthcheck", handle_healthcheck),
]

if __name__ == "__main__":
    # start the server; called from start_server.sh
    start_server(backend, routes)

Here you can create a script that lets you load test an endpoint group running instances with this PyWorker:
from lib.test_harness import run
from .data_types import InputData

WORKER_ENDPOINT = "/generate"

if __name__ == "__main__":
    run(InputData.for_test(), WORKER_ENDPOINT)

You can then run the following command from the root of this repo to load test the endpoint group:
# sends 1000 requests at the rate of 0.5 requests per second
python3 -m workers.hello_world.test_load -n 1000 -rps 0.5 -k "$API_KEY" -e "$ENDPOINT_GROUP_NAME"
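Putting it together, a client that has received a route from the autoscaler calls the PyWorker with an auth_data + payload body, as described in the server.py docstring above. The sketch below shows only the shape of that request; the AuthData field names are placeholders (consult Vast's Serverless documentation for the real schema), and `build_worker_request` is an illustrative helper, not part of this repo:

```python
import json
from typing import Any, Dict

# Illustrative helper: combine autoscaler-issued auth data with an
# InputData-shaped payload into the JSON body the PyWorker expects.
def build_worker_request(auth_data: Dict[str, Any], prompt: str) -> str:
    return json.dumps(
        {
            "auth_data": auth_data,
            "payload": {"prompt": prompt, "max_response_tokens": 300},
        }
    )

# Placeholder AuthData fields -- the real ones come from the autoscaler.
placeholder_auth = {"signature": "...", "endpoint": "...", "reqnum": 0}
request_body = build_worker_request(placeholder_auth, "Hello, world!")
```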