From c2bc691a19b2cf7ba946e78f872ad2fe1fa38baa Mon Sep 17 00:00:00 2001
From: chenzihong <522023320011@smail.nju.edu.cn>
Date: Fri, 6 Feb 2026 18:51:55 +0800
Subject: [PATCH] feat: add ray_serve as llm provider

---
 graphgen/common/init_llm.py                  |  4 +
 graphgen/models/llm/api/ray_serve_client.py  | 88 +++++++++++++++++++
 .../models/llm/local/ray_serve_deployment.py | 84 ++++++++++++++++++
 3 files changed, 176 insertions(+)
 create mode 100644 graphgen/models/llm/api/ray_serve_client.py
 create mode 100644 graphgen/models/llm/local/ray_serve_deployment.py

diff --git a/graphgen/common/init_llm.py b/graphgen/common/init_llm.py
index 56bffedf..bb022f76 100644
--- a/graphgen/common/init_llm.py
+++ b/graphgen/common/init_llm.py
@@ -46,6 +46,10 @@ def __init__(self, backend: str, config: Dict[str, Any]):
             from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
 
             self.llm_instance = VLLMWrapper(**config)
+        elif backend == "ray_serve":
+            from graphgen.models.llm.api.ray_serve_client import RayServeClient
+
+            self.llm_instance = RayServeClient(**config)
         else:
             raise NotImplementedError(f"Backend {backend} is not implemented yet.")
 
diff --git a/graphgen/models/llm/api/ray_serve_client.py b/graphgen/models/llm/api/ray_serve_client.py
new file mode 100644
index 00000000..1fa38271
--- /dev/null
+++ b/graphgen/models/llm/api/ray_serve_client.py
@@ -0,0 +1,88 @@
+from typing import Any, List, Optional
+
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
+
+
+class RayServeClient(BaseLLMWrapper):
+    """
+    A client to interact with a Ray Serve deployment.
+    """
+
+    def __init__(
+        self,
+        *,
+        app_name: Optional[str] = None,
+        deployment_name: Optional[str] = None,
+        serve_backend: Optional[str] = None,
+        **kwargs: Any,
+    ):
+        try:
+            from ray import serve
+        except ImportError as e:
+            raise ImportError(
+                "Ray is not installed. Please install it with `pip install ray[serve]`."
+            ) from e
+
+        super().__init__(**kwargs)
+
+        # Try to get an existing handle first
+        self.handle = None
+        if app_name:
+            try:
+                self.handle = serve.get_app_handle(app_name)
+            except Exception:
+                pass
+        elif deployment_name:
+            try:
+                self.handle = serve.get_deployment(deployment_name).get_handle()
+            except Exception:
+                pass
+
+        # If no handle was found, deploy the app if serve_backend is provided
+        if self.handle is None:
+            if serve_backend:
+                if not app_name:
+                    import uuid
+
+                    app_name = f"llm_app_{serve_backend}_{uuid.uuid4().hex[:8]}"
+
+                print(
+                    f"Deploying Ray Serve app '{app_name}' with backend '{serve_backend}'..."
+                )
+                from graphgen.models.llm.local.ray_serve_deployment import LLMDeployment
+
+                # Filter kwargs to drop unrelated arguments if that ever becomes
+                # necessary; for now LLMDeployment accepts the full config.
+                # Note: kwargs are passed through as the config dict.
+                deployment = LLMDeployment.bind(backend=serve_backend, config=kwargs)
+                serve.run(deployment, name=app_name, route_prefix=f"/{app_name}")
+                self.handle = serve.get_app_handle(app_name)
+            elif app_name or deployment_name:
+                raise ValueError(
+                    f"Ray Serve app/deployment '{app_name or deployment_name}' "
+                    "not found and 'serve_backend' not provided to deploy it."
+                )
+            else:
+                raise ValueError(
+                    "Either 'app_name', 'deployment_name' or 'serve_backend' "
+                    "must be provided for RayServeClient."
+                )
+
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        """Generate answer from the model."""
+        return await self.handle.generate_answer.remote(text, history, **extra)
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate top-k tokens for the next token prediction."""
+        return await self.handle.generate_topk_per_token.remote(text, history, **extra)
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate probabilities for each token in the input."""
+        return await self.handle.generate_inputs_prob.remote(text, history, **extra)
diff --git a/graphgen/models/llm/local/ray_serve_deployment.py b/graphgen/models/llm/local/ray_serve_deployment.py
new file mode 100644
index 00000000..0607d7eb
--- /dev/null
+++ b/graphgen/models/llm/local/ray_serve_deployment.py
@@ -0,0 +1,84 @@
+import os
+from typing import Any, Dict, List, Optional
+
+from ray import serve
+from starlette.requests import Request
+
+from graphgen.bases.datatypes import Token
+from graphgen.models.tokenizer import Tokenizer
+
+
+@serve.deployment
+class LLMDeployment:
+    def __init__(self, backend: str, config: Dict[str, Any]):
+        self.backend = backend
+
+        # Initialize tokenizer if needed
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        if "tokenizer" not in config:
+            tokenizer = Tokenizer(model_name=tokenizer_model)
+            config["tokenizer"] = tokenizer
+
+        if backend == "vllm":
+            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
+
+            self.llm_instance = VLLMWrapper(**config)
+        elif backend == "huggingface":
+            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper
+
+            self.llm_instance = HuggingFaceWrapper(**config)
+        elif backend == "sglang":
+            from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper
+
+            self.llm_instance = SGLangWrapper(**config)
+        else:
+            raise NotImplementedError(
+                f"Backend {backend} is not implemented for Ray Serve yet."
+            )
+
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        return await self.llm_instance.generate_answer(text, history, **extra)
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        return await self.llm_instance.generate_topk_per_token(text, history, **extra)
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        return await self.llm_instance.generate_inputs_prob(text, history, **extra)
+
+    async def __call__(self, request: Request) -> Dict:
+        try:
+            data = await request.json()
+            text = data.get("text")
+            history = data.get("history")
+            method = data.get("method", "generate_answer")
+            kwargs = data.get("kwargs", {})
+
+            if method == "generate_answer":
+                result = await self.generate_answer(text, history, **kwargs)
+            elif method == "generate_topk_per_token":
+                result = await self.generate_topk_per_token(text, history, **kwargs)
+            elif method == "generate_inputs_prob":
+                result = await self.generate_inputs_prob(text, history, **kwargs)
+            else:
+                return {"error": f"Method {method} not supported"}
+
+            return {"result": result}
+        except Exception as e:
+            return {"error": str(e)}
+
+
+def app_builder(args: Dict[str, str]) -> Any:
+    """
+    Builder function for 'serve run'.
+    Usage: serve run graphgen.models.llm.local.ray_serve_deployment:app_builder backend=vllm model=...
+    """
+    # args come from the command-line key=value pairs
+    backend = args.pop("backend", "vllm")
+    # remaining args are treated as config
+    return LLMDeployment.bind(backend=backend, config=args)
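Usage note: the sketch below shows one way the new provider could be exercised end to end; it is not part of the patch itself. The app name and the "model" keyword are illustrative assumptions, since RayServeClient simply forwards any extra keyword arguments to the selected backend wrapper as its config.

    # Hedged sketch: reuse an existing Serve app if one is running, otherwise let
    # RayServeClient deploy LLMDeployment with a vLLM backend behind it.
    import asyncio

    from graphgen.models.llm.api.ray_serve_client import RayServeClient


    async def main() -> None:
        client = RayServeClient(
            app_name="llm_app",       # reused if this app is already deployed
            serve_backend="vllm",     # used to deploy LLMDeployment when no app is found
            model="<path-or-hf-id>",  # assumption: forwarded to the vLLM wrapper via config
        )
        print(await client.generate_answer("What does GraphGen do?"))


    asyncio.run(main())

Alternatively, the deployment can be started on its own with "serve run graphgen.models.llm.local.ray_serve_deployment:app_builder backend=vllm model=..." and queried over HTTP by POSTing the JSON body that LLMDeployment.__call__ expects ("method", "text", "history", "kwargs").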