Description
Is this a new feature, an improvement, or a change to existing functionality?
New Feature
How would you describe the priority of this feature request
Medium
Please provide a clear description of the problem this feature solves
The current huggingface LLM provider only supports local model execution via the transformers library. Extending it to support remote inference would enable connecting to HuggingFace's Serverless Inference API, dedicated Inference Endpoints, and self-hosted TGI servers, so workflows could use HuggingFace models without requiring local GPU resources.
Describe your ideal solution
Create a new huggingface_inference LLM provider type with its own config and client implementation. The config should include model_id, api_key, and endpoint_url fields to support connecting to HuggingFace's Serverless Inference API, dedicated Inference Endpoints, and self-hosted TGI servers. The provider should use huggingface_hub.InferenceClient for HTTP-based inference. A LangChain-compatible wrapper class should be added along with the corresponding client registration in the nvidia_nat_langchain plugin.
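For reference, this is roughly the kind of huggingface_hub.InferenceClient call the new provider would wrap (standalone sketch; the model ID, token, and endpoint URL below are placeholders):

from huggingface_hub import InferenceClient

# Serverless Inference API: address the model by its Hub ID.
client = InferenceClient(model="meta-llama/Llama-3.1-8B-Instruct", token="hf_...")

# Dedicated Inference Endpoint or self-hosted TGI: address the server by URL instead.
# client = InferenceClient(model="https://my-tgi-host:8080", token="hf_...")

response = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=128,
)
print(response.choices[0].message.content)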
Config:
class HuggingFaceInferenceConfig(LLMBaseConfig, name="huggingface_inference"):
    model_id: str = Field(description="HuggingFace model ID")
    api_key: OptionalSecretStr = Field(default=None, description="HuggingFace API token")
    endpoint_url: str | None = Field(default=None, description="TGI or Inference Endpoint URL")
    max_new_tokens: int = Field(default=128)
    temperature: float = Field(default=0.0)
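The same config would cover all three deployment targets (illustrative values only):

# Serverless Inference API: a model ID plus token is enough.
serverless = HuggingFaceInferenceConfig(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    api_key="hf_...",
)

# Dedicated Inference Endpoint or self-hosted TGI: point at the server URL.
self_hosted = HuggingFaceInferenceConfig(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    api_key="hf_...",
    endpoint_url="https://my-tgi-host:8080",
)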
Provider:
@register_llm_provider(config_type=HuggingFaceInferenceConfig)
async def huggingface_inference_provider(config, builder):
    yield LLMProviderInfo(config=config, description=f"HuggingFace Inference: {config.model_id}")
Client:
class HuggingFaceInferenceModel(BaseChatModel):
    """LangChain-compatible wrapper for the HuggingFace Inference API."""

    def __init__(self, config: HuggingFaceInferenceConfig):
        super().__init__()
        from huggingface_hub import InferenceClient
        # InferenceClient treats `model` and `base_url` as mutually exclusive, so
        # target the explicit endpoint when one is configured, else the Hub model ID.
        self._client = InferenceClient(
            model=config.endpoint_url or config.model_id,
            token=config.api_key,
        )
        self._config = config

    async def _agenerate(self, messages, **kwargs):
        # NOTE: LangChain message objects still need converting to role/content dicts here.
        response = self._client.chat_completion(
            messages=messages,
            max_tokens=self._config.max_new_tokens,
            temperature=self._config.temperature,
        )
        # ... build and return a ChatResult from response.choices[0].message
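A matching LangChain client registration in the nvidia_nat_langchain plugin could then look roughly like the sketch below. The decorator name and wrapper_type argument are assumptions modeled on the provider-side registration above; the plugin's actual registration API may differ:

# NOTE: register_llm_client and LLMFrameworkEnum are assumed names mirroring the
# existing provider registration pattern; adjust to the plugin's actual API.
@register_llm_client(config_type=HuggingFaceInferenceConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
async def huggingface_inference_langchain_client(config: HuggingFaceInferenceConfig, builder):
    # Hand the LangChain-compatible wrapper defined above to the framework.
    yield HuggingFaceInferenceModel(config)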
Additional context
No response
Code of Conduct
- I agree to follow this project's Code of Conduct
- I have searched the open feature requests and have found no duplicates for this feature request