diff --git a/flash/create-endpoints.mdx b/flash/create-endpoints.mdx
index 04b2b10b..b935c461 100644
--- a/flash/create-endpoints.mdx
+++ b/flash/create-endpoints.mdx
@@ -65,12 +65,13 @@ from runpod_flash import Endpoint, GpuType
 api = Endpoint(
     name="inference-api",
     gpu=GpuType.NVIDIA_GEFORCE_RTX_4090,
-    workers=(1, 5)
+    workers=(1, 5),
+    dependencies=["torch"]
 )
 
 @api.post("/predict")
 async def predict(data: dict) -> dict:
-    import torch
+    import torch  # Import inside the function body
 
     # Run inference
     return {"prediction": "result"}