From c8f1069eff8ca70aff0c8262233882fc04161aa4 Mon Sep 17 00:00:00 2001 From: efrat99 Date: Mon, 10 Nov 2025 13:33:00 +0200 Subject: [PATCH] Make flink service completely generic --- configs/mapping.yml | 11 + .../models/fruit_defect/__init__.py => docker | 0 docker-compose.yml | 98 +++--- .../fruit_defect/Docs/DATASET_SPEC.md | 0 .../models => }/fruit_defect/Docs/USAGE.md | 0 .../inference => fruit_defect}/__init__.py | 0 .../fruit_defect/configs/fruit_defect.yaml | 0 services/fruit_defect/inference/__init__.py | 0 .../inference/infer_fruit_defect.py | 45 +-- .../models => }/fruit_defect/requirements.txt | 0 .../fruit_defect/scripts/eval_cls.py | 0 .../scripts/export_torchscript.py | 0 .../fruit_defect/scripts/extract_mistakes.py | 0 .../training/train_fruit_defect_cls.py | 0 .../adapters/fruit_defect_runner.py | 51 +-- services/inference_http/app.py | 105 +++--- services/inference_http/model_registry.py | 17 +- services/inference_http/requirements.txt | 2 + .../{ => fruit_defect}/fruit_cls_best.ts | Bin streaming/flink/Dockerfile.flink-py | 2 + streaming/flink/README_Adding_New_Model.md | 231 ++++++++++++ streaming/flink/jobs/http_dispatcher.py | 331 ++++++++---------- 22 files changed, 546 insertions(+), 347 deletions(-) create mode 100644 configs/mapping.yml rename services/inference_http/models/fruit_defect/__init__.py => docker (100%) rename services/{inference_http/models => }/fruit_defect/Docs/DATASET_SPEC.md (100%) rename services/{inference_http/models => }/fruit_defect/Docs/USAGE.md (100%) rename services/{inference_http/models/fruit_defect/inference => fruit_defect}/__init__.py (100%) rename services/{inference_http/models => }/fruit_defect/configs/fruit_defect.yaml (100%) create mode 100644 services/fruit_defect/inference/__init__.py rename services/{inference_http/models => }/fruit_defect/inference/infer_fruit_defect.py (92%) rename services/{inference_http/models => }/fruit_defect/requirements.txt (100%) rename services/{inference_http/models => }/fruit_defect/scripts/eval_cls.py (100%) rename services/{inference_http/models => }/fruit_defect/scripts/export_torchscript.py (100%) rename services/{inference_http/models => }/fruit_defect/scripts/extract_mistakes.py (100%) rename services/{inference_http/models => }/fruit_defect/training/train_fruit_defect_cls.py (100%) rename services/inference_http/weights/{ => fruit_defect}/fruit_cls_best.ts (100%) create mode 100644 streaming/flink/README_Adding_New_Model.md diff --git a/configs/mapping.yml b/configs/mapping.yml new file mode 100644 index 000000000..689be95f1 --- /dev/null +++ b/configs/mapping.yml @@ -0,0 +1,11 @@ +topics: + imagery.new.fruit: + name: fruit_defect + runner: http + +http: + url_template: "http://inference-http:8000/infer_json/{name}" + +kafka: + group_id: "http-dispatcher" + dlq_topic: "dlq.inference.http" diff --git a/services/inference_http/models/fruit_defect/__init__.py b/docker similarity index 100% rename from services/inference_http/models/fruit_defect/__init__.py rename to docker diff --git a/docker-compose.yml b/docker-compose.yml index ea690204e..77bb95bd1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -678,7 +678,6 @@ services: fs.s3a.connection.ssl.enabled: false python.client.executable: /usr/bin/python3 python.executable: /usr/bin/python3 - - HTTP_INFER_URL=http://fruit-inference-http:8000/infer_json volumes: - ./streaming/flink/jobs:/opt/flink/jobs:ro - ./streaming/flink/connectors/flink-json-1.18.1.jar:/opt/flink/lib/flink-json-1.18.1.jar:ro @@ -687,6 +686,7 @@ services: - ./streaming/flink/connectors/kafka-clients-3.2.3.jar:/opt/flink/lib/kafka-clients-3.2.3.jar:ro - ./streaming/flink/connectors/lz4-java-1.8.0.jar:/opt/flink/lib/lz4-java-1.8.0.jar:ro - ./streaming/flink/connectors/snappy-java-1.1.10.5.jar:/opt/flink/lib/snappy-java-1.1.10.5.jar:ro + - ./configs/mapping.yml:/etc/dispatcher/mapping.yml:ro restart: unless-stopped @@ -713,14 +713,15 @@ services: fs.s3a.connection.ssl.enabled: false python.client.executable: /usr/bin/python3 python.executable: /usr/bin/python3 - - HTTP_INFER_URL=http://fruit-inference-http:8000/infer_json volumes: + - ./streaming/flink/jobs:/opt/flink/jobs:ro - ./streaming/flink/connectors/flink-json-1.18.1.jar:/opt/flink/lib/flink-json-1.18.1.jar:ro - ./streaming/flink/connectors/flink-sql-connector-kafka-3.2.0-1.18.jar:/opt/flink/lib/flink-sql-connector-kafka-3.2.0-1.18.jar:ro - ./streaming/flink/connectors/flink-connector-kafka-3.2.0-1.18.jar:/opt/flink/lib/flink-connector-kafka-3.2.0-1.18.jar:ro - ./streaming/flink/connectors/kafka-clients-3.2.3.jar:/opt/flink/lib/kafka-clients-3.2.3.jar:ro - ./streaming/flink/connectors/lz4-java-1.8.0.jar:/opt/flink/lib/lz4-java-1.8.0.jar:ro - ./streaming/flink/connectors/snappy-java-1.1.10.5.jar:/opt/flink/lib/snappy-java-1.1.10.5.jar:ro + - ./configs/mapping.yml:/etc/dispatcher/mapping.yml:ro restart: unless-stopped @@ -728,46 +729,50 @@ services: # -------------------------- # Inference HTTP Service # -------------------------- - fruit-inference-http: + inference-http: build: - context: ./services/inference_http + context: ./services/inference_http dockerfile: Dockerfile environment: - - TEAM=fruit - - WEIGHTS_PATH=/app/weights/fruit_cls_best.ts - - MINIO_ENDPOINT=minio-hot:9000 - - MINIO_ACCESS_KEY=minioadmin - - MINIO_SECRET_KEY=minioadmin123 - - MINIO_SECURE=0 + - WEIGHTS_ROOT=/app/services/inference_http/weights + - MINIO_ENDPOINT=minio-hot:9000 + - MINIO_ACCESS_KEY=minioadmin + - MINIO_SECRET_KEY=minioadmin123 + - MINIO_SECURE=0 volumes: - - ./services/inference_http/weights:/app/weights:ro - container_name: fruit-inference-http + - ./services:/app/services:ro networks: [ag_cloud] ports: - "8011:8000" restart: unless-stopped + # -------------------------- # Flink Jobs # -------------------------- - flink-dispatcher-fruit: + flink-dispatcher: + build: + context: ./streaming/flink + dockerfile: Dockerfile.flink-py image: agcloud-flink-py:1.18 - container_name: flink-dispatcher-fruit + container_name: flink-dispatcher depends_on: + kafka: {condition: service_started} flink-jobmanager: { condition: service_started } flink-taskmanager: { condition: service_started } - fruit-inference-http: { condition: service_started } + inference-http: { condition: service_started } networks: [ag_cloud] environment: - KAFKA_BOOTSTRAP=kafka:9092 - - INPUT_TOPIC=imagery.new.fruit - - TEAM=fruit - - HTTP_URL=http://fruit-inference-http:8000/infer_json - - DLQ_TOPIC=dlq.inference.http - - GROUP_ID=http-dispatcher-fruit - - PARALLELISM=2 + - INPUT_TOPICS=imagery.new.fruit + - CONFIG_PATH=/etc/dispatcher/mapping.yml + - RUNNER=http + - HTTP_URL_TEMPLATE=http://inference-http:8000/infer_json/{name} - PYFLINK_CLIENT_EXECUTABLE=/usr/bin/python3 + - GROUP_ID=http-dispatcher + - DLQ_TOPIC=dlq.inference.http + - OK_TOPIC=inference.dispatched volumes: - ./streaming/flink/jobs:/opt/flink/jobs:ro - ./streaming/flink/connectors/flink-connector-kafka-3.2.0-1.18.jar:/opt/flink/lib/flink-connector-kafka-3.2.0-1.18.jar:ro @@ -776,27 +781,30 @@ services: - ./streaming/flink/connectors/kafka-clients-3.2.3.jar:/opt/flink/lib/kafka-clients-3.2.3.jar:ro - ./streaming/flink/connectors/lz4-java-1.8.0.jar:/opt/flink/lib/lz4-java-1.8.0.jar:ro - ./streaming/flink/connectors/snappy-java-1.1.10.5.jar:/opt/flink/lib/snappy-java-1.1.10.5.jar:ro - command: [ - "bash","-lc", - "set -e; - echo 'Waiting for JobManager to accept commands...'; - until /opt/flink/bin/flink list --jobmanager flink-jobmanager:8081 >/dev/null 2>&1; do - echo 'still waiting...'; sleep 3; - done; - echo 'JobManager is ready!'; - /opt/flink/bin/flink run \ - -Dpython.client.executable=/usr/bin/python3 \ - -Dpython.executable=/usr/bin/python3 \ - -Dpipeline.jars=file:///opt/flink/lib/flink-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/flink-sql-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/flink-json-1.18.1.jar \ - --jobmanager flink-jobmanager:8081 \ - --detached \ - --python /opt/flink/jobs/http_dispatcher.py \ - -- \ - --bootstrap kafka:9092 \ - --input-topic imagery.new.fruit \ - --team fruit \ - --http-url http://fruit-inference-http:8000/infer_json \ - --group-id http-dispatcher-fruit \ - --dlq-topic dlq.inference.http; - tail -f /dev/null" - ] + - ./configs/mapping.yml:/etc/dispatcher/mapping.yml:ro # << חשוב! + command: > + bash -lc ' + set -euo pipefail; + echo "Waiting for JobManager to accept commands..."; + until /opt/flink/bin/flink list --jobmanager flink-jobmanager:8081 >/dev/null 2>&1; do + echo "still waiting..."; sleep 3; + done; + echo "JobManager is ready!"; + /opt/flink/bin/flink run \ + -Dpython.client.executable=/usr/bin/python3 \ + -Dpython.executable=/usr/bin/python3 \ + -Dpipeline.jars=file:///opt/flink/lib/flink-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/flink-sql-connector-kafka-3.2.0-1.18.jar,file:///opt/flink/lib/flink-json-1.18.1.jar \ + --jobmanager flink-jobmanager:8081 \ + --detached \ + --python /opt/flink/jobs/http_dispatcher.py \ + -- \ + --bootstrap "$$KAFKA_BOOTSTRAP" \ + --input-topics "$$INPUT_TOPICS" \ + --config "$$CONFIG_PATH" \ + --runner "$$RUNNER" \ + --http-url-template "$$HTTP_URL_TEMPLATE" \ + --group-id "$$GROUP_ID" \ + --dlq-topic "$$DLQ_TOPIC" \ + --ok-topic "$$OK_TOPIC"; + tail -f /dev/null + ' \ No newline at end of file diff --git a/services/inference_http/models/fruit_defect/Docs/DATASET_SPEC.md b/services/fruit_defect/Docs/DATASET_SPEC.md similarity index 100% rename from services/inference_http/models/fruit_defect/Docs/DATASET_SPEC.md rename to services/fruit_defect/Docs/DATASET_SPEC.md diff --git a/services/inference_http/models/fruit_defect/Docs/USAGE.md b/services/fruit_defect/Docs/USAGE.md similarity index 100% rename from services/inference_http/models/fruit_defect/Docs/USAGE.md rename to services/fruit_defect/Docs/USAGE.md diff --git a/services/inference_http/models/fruit_defect/inference/__init__.py b/services/fruit_defect/__init__.py similarity index 100% rename from services/inference_http/models/fruit_defect/inference/__init__.py rename to services/fruit_defect/__init__.py diff --git a/services/inference_http/models/fruit_defect/configs/fruit_defect.yaml b/services/fruit_defect/configs/fruit_defect.yaml similarity index 100% rename from services/inference_http/models/fruit_defect/configs/fruit_defect.yaml rename to services/fruit_defect/configs/fruit_defect.yaml diff --git a/services/fruit_defect/inference/__init__.py b/services/fruit_defect/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/services/inference_http/models/fruit_defect/inference/infer_fruit_defect.py b/services/fruit_defect/inference/infer_fruit_defect.py similarity index 92% rename from services/inference_http/models/fruit_defect/inference/infer_fruit_defect.py rename to services/fruit_defect/inference/infer_fruit_defect.py index 530146c78..acd4f07ec 100644 --- a/services/inference_http/models/fruit_defect/inference/infer_fruit_defect.py +++ b/services/fruit_defect/inference/infer_fruit_defect.py @@ -24,7 +24,6 @@ except Exception: TQDM_AVAILABLE = False -# --- קונפיג דרך ENV (ברירות מחדל מותאמות לסטאק שלך) --- MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9001") MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "minioadmin") MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "minioadmin123") @@ -32,8 +31,8 @@ BUCKET_INPUT = os.getenv("MINIO_BUCKET_INPUT", "imagery") BUCKET_OUTPUT = os.getenv("MINIO_BUCKET_OUTPUT", "telemetry") -INPUT_PREFIX = os.getenv("MINIO_INPUT_PREFIX", "inputs/batch1/") # מאיפה לקרוא תמונות -OUTPUT_PREFIX = os.getenv("MINIO_OUTPUT_PREFIX", "results/batch1/") # לאן להעלות תוצאות +INPUT_PREFIX = os.getenv("MINIO_INPUT_PREFIX", "inputs/batch1/") +OUTPUT_PREFIX = os.getenv("MINIO_OUTPUT_PREFIX", "results/batch1/") WEIGHTS_BUCKET = os.getenv("MINIO_BUCKET_WEIGHTS","imagery") WEIGHTS_PREFIX = os.getenv("MINIO_WEIGHTS_PREFIX","models/") @@ -41,11 +40,11 @@ LOCAL_WEIGHTS_PT = os.getenv("MODEL_PT_LOCAL", "./outputs/fruit_cls_best.pt") IMG_SIZE = int(os.getenv("IMG_SIZE", "192")) -THRESHOLD = float(os.getenv("CLS_THRESHOLD", "0.5")) # סף בינארי defect/ok +THRESHOLD = float(os.getenv("CLS_THRESHOLD", "0.5")) DL_WORKERS = int(os.getenv("DL_WORKERS", "8")) BATCH_SIZE = int(os.getenv("BATCH_SIZE", "8")) -HEARTBEAT_PERIOD = int(os.getenv("HEARTBEAT_PERIOD", "30")) # שניות +HEARTBEAT_PERIOD = int(os.getenv("HEARTBEAT_PERIOD", "30")) # --- logging config --- logging.basicConfig( @@ -152,21 +151,25 @@ def fetch_weights_if_missing() -> Path: log.info(f"Downloaded PT weights: {pt_local}") return pt_local -def load_model(weights_path: Path): - log.info(f"Loading model from: {weights_path}") - if weights_path.suffix == ".ts": - model = torch.jit.load(str(weights_path), map_location="cpu") - else: - obj = torch.load(str(weights_path), map_location="cpu") - if hasattr(obj, "state_dict"): - model = obj - elif isinstance(obj, dict): - raise RuntimeError("Loaded a state_dict dict but no model class is defined here. Please export TorchScript (.ts).") - else: - model = obj - model.eval() - log.info("Model loaded and set to eval()") - return model +def load_model(weights_path, device=None): + p = Path(weights_path) + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + print(f"Loading model from: {p}") + + if p.suffix == ".ts": + model = torch.jit.load(str(p), map_location=device) + model.eval() + return model + + if p.suffix == ".pt": + model = build_model_architecture() + state = torch.load(str(p), map_location=device) + model.load_state_dict(state) + model.eval() + return model + + raise ValueError(f"Unsupported weights suffix: {p.suffix} for {p}") def get_preprocess(): return transforms.Compose([ @@ -180,7 +183,7 @@ def infer_single(model, img: Image.Image, preprocess, device="cpu") -> Dict: t0 = time.perf_counter() with torch.no_grad(): y = model(x) - dt = (time.perf_counter() - t0) * 1000.0 # ms + dt = (time.perf_counter() - t0) * 1000.0 if isinstance(y, (list, tuple)): y = y[0] diff --git a/services/inference_http/models/fruit_defect/requirements.txt b/services/fruit_defect/requirements.txt similarity index 100% rename from services/inference_http/models/fruit_defect/requirements.txt rename to services/fruit_defect/requirements.txt diff --git a/services/inference_http/models/fruit_defect/scripts/eval_cls.py b/services/fruit_defect/scripts/eval_cls.py similarity index 100% rename from services/inference_http/models/fruit_defect/scripts/eval_cls.py rename to services/fruit_defect/scripts/eval_cls.py diff --git a/services/inference_http/models/fruit_defect/scripts/export_torchscript.py b/services/fruit_defect/scripts/export_torchscript.py similarity index 100% rename from services/inference_http/models/fruit_defect/scripts/export_torchscript.py rename to services/fruit_defect/scripts/export_torchscript.py diff --git a/services/inference_http/models/fruit_defect/scripts/extract_mistakes.py b/services/fruit_defect/scripts/extract_mistakes.py similarity index 100% rename from services/inference_http/models/fruit_defect/scripts/extract_mistakes.py rename to services/fruit_defect/scripts/extract_mistakes.py diff --git a/services/inference_http/models/fruit_defect/training/train_fruit_defect_cls.py b/services/fruit_defect/training/train_fruit_defect_cls.py similarity index 100% rename from services/inference_http/models/fruit_defect/training/train_fruit_defect_cls.py rename to services/fruit_defect/training/train_fruit_defect_cls.py diff --git a/services/inference_http/adapters/fruit_defect_runner.py b/services/inference_http/adapters/fruit_defect_runner.py index 5a1cd00eb..7cd2f3763 100644 --- a/services/inference_http/adapters/fruit_defect_runner.py +++ b/services/inference_http/adapters/fruit_defect_runner.py @@ -1,39 +1,12 @@ -import os, io -from pathlib import Path -from typing import Any, Dict, Optional - -from PIL import Image -import torch - -# Core code imported from your fruit-defect module -from models.fruit_defect.inference.infer_fruit_defect import ( - load_model, get_preprocess, infer_single -) - -# Local weights only -WEIGHTS_PATH = Path(os.getenv("WEIGHTS_PATH", "/app/weights/fruit_cls_best.ts")) - -def _ensure_local_weights(p: Path) -> Path: - if not p.exists(): - raise FileNotFoundError(f"Local weights not found at: {p}") - return p - -class FruitDefectRunner: - def __init__(self, model_tag: Optional[str] = None): - # Allows selecting a different weights file in future via extra/model_tag - weights_path = _ensure_local_weights(WEIGHTS_PATH) - self.model = load_model(weights_path) - self.preprocess = get_preprocess() - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.model = self.model.to(self.device).eval() - - def run(self, image_bytes: bytes, model_tag=None, extra=None) -> Dict[str, Any]: - img = Image.open(io.BytesIO(image_bytes)).convert("RGB") - result = infer_single(self.model, img, self.preprocess, device=self.device) - # Normalize to standard HTTP response structure - return { - "label": result.get("status"), - "score": result.get("prob_defect"), - "confidence": result.get("confidence"), - "latency_ms_model": result.get("latency_ms_model"), - } +from typing import Dict, Type +from adapters.fruit_defect_runner import FruitDefectRunner + +REGISTRY: Dict[str, Type] = { + "fruit_defect": FruitDefectRunner, +} + +def get_adapter(name: str): + key = (name or "").lower() + if key not in REGISTRY: + raise ValueError(f"Unknown model name: {key}") + return REGISTRY[key]() diff --git a/services/inference_http/app.py b/services/inference_http/app.py index 3a490493d..b3e287051 100644 --- a/services/inference_http/app.py +++ b/services/inference_http/app.py @@ -1,82 +1,87 @@ - -import os, time -from fastapi import FastAPI, Header, HTTPException +import os, io, time, importlib +from typing import Optional, Dict, Any +from fastapi import FastAPI, HTTPException, Header, UploadFile, File from pydantic import BaseModel, ConfigDict +from PIL import Image from minio import Minio -from model_registry import get_model_runner - -TEAM = os.getenv("TEAM") -if not TEAM: - raise RuntimeError("Missing TEAM environment variable – please set TEAM=") +WEIGHTS_ROOT = os.getenv("WEIGHTS_ROOT", "/weights") MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio-hot:9000") MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "minioadmin") MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "minioadmin123") MINIO_SECURE = os.getenv("MINIO_SECURE", "0") == "1" -app = FastAPI(title="Fruit Inference HTTP") +app = FastAPI(title="AgCloud Inference Gateway (multi-model)") -class InferRequest(BaseModel): - # Accept only bucket+key; any other fields are rejected (422) +class InferJson(BaseModel): model_config = ConfigDict(extra="forbid") bucket: str key: str +def get_adapter(name: str): + mod_path = f"services.{name}.adapter" + try: + mod = importlib.import_module(mod_path) + except ModuleNotFoundError: + raise HTTPException(status_code=404, detail=f"model '{name}' not found") + cls = getattr(mod, "Adapter", None) + if cls is None: + raise HTTPException(status_code=500, detail=f"adapter for '{name}' missing class Adapter") + return cls(WEIGHTS_ROOT) @app.on_event("startup") def _startup(): - app.state.mc = Minio( - MINIO_ENDPOINT, - access_key=MINIO_ACCESS_KEY, - secret_key=MINIO_SECRET_KEY, - secure=MINIO_SECURE, - ) - # The runner already knows how to read image_uri from S3 - app.state.runner = get_model_runner(TEAM) + app.state.mc = Minio(MINIO_ENDPOINT, access_key=MINIO_ACCESS_KEY, + secret_key=MINIO_SECRET_KEY, secure=MINIO_SECURE) @app.get("/healthz") def healthz(): - return {"ok": True, "team": TEAM} + return {"ok": True, "models": "dynamic"} -@app.post("/infer_json") -def infer_json( - req: InferRequest, - idem_key: str | None = Header(default=None, alias="Idempotency-Key"), - corr_id: str | None = Header(default=None, alias="X-Correlation-ID"), +@app.post("/infer_json/{name}") +def infer_json_by_name( + name: str, + req: InferJson, + idempotency_key: Optional[str] = Header(default=None, alias="Idempotency-Key"), + correlation_id: Optional[str] = Header(default=None, alias="X-Correlation-ID"), ): started = time.perf_counter() - try: - runner = app.state.runner - - # Always build the image URI from bucket and key - s3_uri = f"s3://{req.bucket}/{req.key}" - - # Try to read the image bytes from MinIO + runner = get_adapter(name) obj = app.state.mc.get_object(req.bucket, req.key) try: image_bytes = obj.read() finally: - obj.close() - obj.release_conn() - - # Attempt to run the model with bytes input first - try: - result = runner.run(image_bytes) - except TypeError: - # If the function does not accept bytes, try with URI instead - result = runner.run(s3_uri) - - latency_ms = int((time.perf_counter() - started) * 1000) + obj.close(); obj.release_conn() + result = runner.predict(image_bytes) return { - "ok": True, - "team": TEAM, - "result": result, - "image_uri": s3_uri, - "latency_ms": latency_ms, - "idempotency_key": idem_key, - "correlation_id": corr_id, + "ok": True, "name": name, "image": {"bucket": req.bucket, "key": req.key}, + "result": result, "latency_ms": int((time.perf_counter()-started)*1000), + "idempotency_key": idempotency_key, "correlation_id": correlation_id } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"inference failed: {e}") +@app.post("/infer_upload/{name}") +async def infer_upload_by_name( + name: str, + file: UploadFile = File(...), + idempotency_key: Optional[str] = Header(default=None, alias="Idempotency-Key"), + correlation_id: Optional[str] = Header(default=None, alias="X-Correlation-ID"), +): + started = time.perf_counter() + try: + runner = get_adapter(name) + image_bytes = await file.read() + result = runner.predict(image_bytes) + return { + "ok": True, "name": name, "file": file.filename, + "result": result, "latency_ms": int((time.perf_counter()-started)*1000), + "idempotency_key": idempotency_key, "correlation_id": correlation_id + } + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=f"inference failed: {e}") diff --git a/services/inference_http/model_registry.py b/services/inference_http/model_registry.py index 3c5c0d6df..f7e8eee4d 100644 --- a/services/inference_http/model_registry.py +++ b/services/inference_http/model_registry.py @@ -1,6 +1,11 @@ -from typing import Any, Dict +from typing import Any, Dict, Type from adapters.fruit_defect_runner import FruitDefectRunner + +REGISTRY: Dict[str, Type] = { + "fruit_defect": FruitDefectRunner, +} + class FruitRunner: def __init__(self): self.impl = FruitDefectRunner() @@ -8,8 +13,8 @@ def __init__(self): def run(self, image_bytes: bytes, model_tag=None, extra=None) -> Dict[str, Any]: return self.impl.run(image_bytes, model_tag=model_tag, extra=extra) -def get_model_runner(team: str): - t = (team or "").lower() - if t == "fruit": - return FruitRunner() - raise ValueError(f"unknown TEAM {t}") +def get_adapter(name: str): + key = (name or "").lower() + if key not in REGISTRY: + raise ValueError(f"Unknown model name: {key}") + return REGISTRY[key]() \ No newline at end of file diff --git a/services/inference_http/requirements.txt b/services/inference_http/requirements.txt index 2e65f4b67..e599e7042 100644 --- a/services/inference_http/requirements.txt +++ b/services/inference_http/requirements.txt @@ -4,3 +4,5 @@ minio pillow numpy==1.26.4 pydantic +python-multipart + diff --git a/services/inference_http/weights/fruit_cls_best.ts b/services/inference_http/weights/fruit_defect/fruit_cls_best.ts similarity index 100% rename from services/inference_http/weights/fruit_cls_best.ts rename to services/inference_http/weights/fruit_defect/fruit_cls_best.ts diff --git a/streaming/flink/Dockerfile.flink-py b/streaming/flink/Dockerfile.flink-py index 60664e730..c6d27af36 100644 --- a/streaming/flink/Dockerfile.flink-py +++ b/streaming/flink/Dockerfile.flink-py @@ -13,6 +13,8 @@ RUN python3 -m pip install --no-cache-dir \ protobuf==3.20.3 \ cloudpickle==2.2.1 +RUN pip install --no-cache-dir pyyaml aiohttp + ENV PYFLINK_CLIENT_EXECUTABLE=/usr/bin/python3 ENV FLINK_PYTHON=/usr/bin/python3 diff --git a/streaming/flink/README_Adding_New_Model.md b/streaming/flink/README_Adding_New_Model.md new file mode 100644 index 000000000..f9cb5db73 --- /dev/null +++ b/streaming/flink/README_Adding_New_Model.md @@ -0,0 +1,231 @@ +# Adding a New Model to the Generic Kafka → Flink → HTTP Inference Pipeline + +This README explains how to plug a **new ML model** into the existing **Kafka → Flink → HTTP** dispatcher used in this repo. The design lets multiple teams add models in a uniform way with minimal code: you only add a small *Runner* (adapter) file, register it, map a Kafka input topic, and you’re done. + +--- + +## Architecture + +``` +Kafka (input topics) + │ + ▼ +Flink Dispatcher (http_dispatcher.py) + │ calls HTTP: /infer_json/{name} + ▼ +HTTP Inference Service (FastAPI) + │ loads runner by name (model_registry) + ▼ +Kafka (OK topic) Kafka (DLQ) +``` + +- **Input topics** carry image events +- **Dispatcher** reads from input topics, calls the HTTP service with `name=`, and publishes results to **OK_TOPIC** or errors to **DLQ_TOPIC**. +- **HTTP service** looks up the runner class from a central `REGISTRY` and executes `predict(image_bytes)`. + +--- + +## Prerequisites + +- Docker & docker compose +- Running services: `kafka`, `minio-hot`, `inference-http`, `flink-jobmanager`, `flink-taskmanager`, `flink-dispatcher` +- Topics created in Kafka (auto-create is disabled in this stack) +- Access to MinIO object(s) you want to test + +--- + +## Directory Layout (relevant parts) + +``` +services/ + inference_http/ + adapters/ + fruit_defect_runner.py + _runner.py # ← add yours here + model_registry.py # ← register your runner class here + weights/ + fruit_defect/fruit_cls_best.ts + /.ts # ← add weights here +configs/ + mapping.yml # ← topic→model mapping + http/kafka config +streaming/ + flink/ + jobs/http_dispatcher.py # ← Flink job +``` + +--- + +## Quick Start (existing fruit_defect model) + +1. Ensure topics exist (example): + ```bash + /opt/bitnami/kafka/bin/kafka-topics.sh --bootstrap-server kafka:9092 --create --topic imagery.new.fruit --partitions 3 --replication-factor 1 || true + ``` +2. Bring up services (snippet): + ```bash + docker compose up -d kafka minio-hot inference-http flink-jobmanager flink-taskmanager flink-dispatcher + ``` +3. Send a test event: + ```bash + echo '{"bucket":"imagery","key":"Apple__Healthy/FreshApple (10).jpg"}' | docker exec -i kafka bash -lc "/opt/bitnami/kafka/bin/kafka-console-producer.sh --bootstrap-server kafka:9092 --topic imagery.new.fruit" + ``` +4. Read success output (OK topic): + ```bash + docker exec -it kafka bash -lc "/opt/bitnami/kafka/bin/kafka-console-consumer.sh --bootstrap-server kafka:9092 --topic inference.dispatched --from-beginning --timeout-ms 10000" + ``` +5. Check DLQ if needed: + ```bash + docker exec -it kafka bash -lc "/opt/bitnami/kafka/bin/kafka-console-consumer.sh --bootstrap-server kafka:9092 --topic dlq.inference.http --from-beginning --timeout-ms 10000" + ``` + +--- + +## Add a New Model (Step-by-Step) + +### 1) Add model weights +Place your weights under: +``` +services/inference_http/weights//.ts +``` + +### 2) Create a runner (adapter) +Create: `services/inference_http/adapters/_runner.py` + +Minimal template: +```python +from typing import Dict, Any +from PIL import Image +import io, torch +# replace the following import with your own model code +from services.fruit_defect.inference.infer_fruit_defect import load_model, get_preprocess, infer_single + +class MyModelRunner: + def __init__(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = load_model("/app/services/inference_http/weights/my_model/model.ts").to(self.device).eval() + self.pp = get_preprocess() + + def predict(self, image_bytes: bytes) -> Dict[str, Any]: + img = Image.open(io.BytesIO(image_bytes)).convert("RGB") + r = infer_single(self.model, img, self.pp, device=self.device) + # return structure is flexible; at minimum return a label + return { + "label": r.get("status", "unknown"), + "score": r.get("prob_defect"), # optional + "confidence": r.get("confidence"), # optional + "latency_ms_model": r.get("latency_ms_model") # optional + } +``` + +> **Note**: The runner is a thin bridge. You **do not need to modify the model’s own code**. Keep business logic inside the runner. + +### 3) Register your runner +Edit `services/inference_http/model_registry.py`: +```python +from services.inference_http.adapters.my_model_runner import MyModelRunner +REGISTRY["my_model"] = MyModelRunner +``` + +### 4) Map an input topic to your model +Edit `configs/mapping.yml` and add a topic entry: +```yaml +topics: + imagery.new.fruit: + name: fruit_defect + runner: http + imagery.new.leaf: + name: my_model + runner: http + +http: + url_template: "http://inference-http:8000/infer_json/{name}" + +kafka: + group_id: "http-dispatcher" + dlq_topic: "dlq.inference.http" +``` + +### 5) Include the new topic in the dispatcher environment +In `docker-compose.yml` under `flink-dispatcher`: +```yaml +environment: + - KAFKA_BOOTSTRAP=kafka:9092 + - INPUT_TOPICS=imagery.new.fruit,imagery.new.leaf + - CONFIG_PATH=/etc/dispatcher/mapping.yml + - RUNNER=http + - HTTP_URL_TEMPLATE=http://inference-http:8000/infer_json/{name} + - GROUP_ID=http-dispatcher + - DLQ_TOPIC=dlq.inference.http + - OK_TOPIC=inference.dispatched +``` + +### 6) Create the new Kafka topic (if auto-create is disabled) +```bash +/opt/bitnami/kafka/bin/kafka-topics.sh --bootstrap-server kafka:9092 --create --topic imagery.new.leaf --partitions 3 --replication-factor 1 +``` + +### 7) Rebuild & restart the HTTP service (to load the new runner) +```bash +docker compose build inference-http +docker compose up -d inference-http +``` + +### 8) Restart the dispatcher (to re-read mapping if you changed it) +```bash +docker compose up -d flink-dispatcher +``` + +### 9) Send test messages and verify output +```bash +echo '{"bucket":"imagery","key":"Apple__Healthy/FreshApple (10).jpg"}' | docker exec -i kafka bash -lc "/opt/bitnami/kafka/bin/kafka-console-producer.sh --bootstrap-server kafka:9092 --topic imagery.new.leaf" + +docker exec -it kafka bash -lc "/opt/bitnami/kafka/bin/kafka-console-consumer.sh --bootstrap-server kafka:9092 --topic inference.dispatched --timeout-ms 10000" +``` + +--- + +## Notes & Conventions + +- **OK_TOPIC is shared** across models (recommended): filter by the `name` field in the message. +- **DLQ_TOPIC is shared** as well: you’ll see the HTTP status and detailed error body. +- `predict()` output is **flexible**: at minimum, return a `label`. If your model has scores/latency, add those too. +- Keep the **single source of truth** for model registration in `services/inference_http/model_registry.py`. Do not keep duplicate registries in model packages. +- Be precise with Kafka topic names (no stray spaces) and MinIO keys (object must exist). + +--- + +## Troubleshooting + +| Symptom | Likely Cause | Fix | +|---|---|---| +| No messages in `inference.dispatched` | Dispatcher env not set or wrong input topic | Check `docker logs flink-dispatcher` and ensure `INPUT_TOPICS` includes your topic; verify `mapping.yml` path and content | +| DLQ shows `NoSuchKey` | MinIO object key missing | Verify `bucket/key` is correct and the object exists | +| HTTP returns 500 | Weights path wrong or model failed to load | Check `inference-http` logs and runner paths | +| Flink job not running | Dispatcher didn’t submit or failed | `docker exec -it flink-jobmanager /opt/flink/bin/flink list`, then re-up dispatcher | +| Producer writes but consumer sees nothing | Wrong topic or wrong bootstrap | Verify topic name, bootstrap `kafka:9092`, and create topic if needed | + +--- + +## Example Success Message (from OK topic) + +```json +{ + "event_id": "0831f05c-c514-40d6-9995-c94c4a068422", + "name": "fruit_defect", + "http_url": "http://inference-http:8000/infer_json/fruit_defect", + "ok": true, + "status": 200, + "body": { + "ok": true, + "name": "fruit_defect", + "image": {"bucket": "imagery", "key": "Apple__Healthy/FreshApple (10).jpg"}, + "result": {"label": "ok", "confidence": 0.9999, "latency_ms_model": 233.6} + }, + "topic": "imagery.new.fruit" +} +``` + +--- + +## Credits +AgCloud / VectorDB pipeline (Kafka → Flink → HTTP Inference). diff --git a/streaming/flink/jobs/http_dispatcher.py b/streaming/flink/jobs/http_dispatcher.py index bef7adf4e..f749fb7fe 100644 --- a/streaming/flink/jobs/http_dispatcher.py +++ b/streaming/flink/jobs/http_dispatcher.py @@ -1,226 +1,187 @@ -# -*- coding: utf-8 -*- -""" -Flink 1.18 PyFlink — HTTP dispatcher (Kafka -> HTTP /infer_json) -- DataStream API: KafkaSource / KafkaSink -- WatermarkStrategy from pyflink.common (compatible with 1.18) -- Splits successful requests to inference.dispatched. and errors to DLQ -""" - -import os -import json -import uuid -import asyncio -import argparse -from typing import Any, Dict - -import aiohttp - -# PyFlink core +# Flink 1.18 — Kafka -> HTTP (multi-model via mapping.yml + {name} template) +import os, json, uuid, asyncio, argparse +from typing import Any, Dict, List, Optional +import aiohttp, yaml + from pyflink.datastream import StreamExecutionEnvironment, RuntimeExecutionMode from pyflink.datastream.functions import MapFunction, RuntimeContext - -# WatermarkStrategy/Types – in Flink 1.18 imported from pyflink.common from pyflink.common import WatermarkStrategy, Types from pyflink.common.serialization import SimpleStringSchema - -# Kafka connectors from pyflink.datastream.connectors.kafka import ( - KafkaSource, - KafkaSink, - KafkaOffsetsInitializer, - KafkaRecordSerializationSchema, + KafkaSource, KafkaSink, KafkaOffsetsInitializer, + KafkaRecordSerializationSchema ) -# ----------------------------- -# Args & Config -# ----------------------------- +# ---------- CLI ---------- def parse_args(): - p = argparse.ArgumentParser(description="HTTP dispatcher for imagery events") + p = argparse.ArgumentParser("multi-model HTTP dispatcher (topics list)") p.add_argument("--bootstrap", default=os.getenv("KAFKA_BOOTSTRAP", "kafka:9092")) - p.add_argument("--input-topic", default=os.getenv("INPUT_TOPIC", "imagery.new.fruit")) - p.add_argument("--team", default=os.getenv("TEAM", "fruit")) - p.add_argument("--http-url", - default=os.getenv("HTTP_URL", "http://fruit-inference-http:8000/infer_json")) + p.add_argument("--input-topics", default=os.getenv("INPUT_TOPICS", "")) + p.add_argument("--config", dest="config_path", + default=os.getenv("CONFIG_PATH", "/etc/dispatcher/mapping.yml")) + p.add_argument("--runner", default=os.getenv("RUNNER", "http")) + p.add_argument("--http-url-template", + default=os.getenv("HTTP_URL_TEMPLATE", "http://inference-http:8000/infer_json/{name}")) + p.add_argument("--group-id", default=os.getenv("GROUP_ID", "http-dispatcher")) p.add_argument("--dlq-topic", default=os.getenv("DLQ_TOPIC", "dlq.inference.http")) - p.add_argument("--group-id", default=os.getenv("GROUP_ID", "http-dispatcher-fruit")) - # tuning + p.add_argument("--ok-topic", default=os.getenv("OK_TOPIC", "inference.dispatched")) + p.add_argument("--parallelism", type=int, default=int(os.getenv("PARALLELISM", "2"))) - p.add_argument("--http-connect-timeout", type=float, - default=float(os.getenv("HTTP_CONNECT_TIMEOUT", "2.0"))) - p.add_argument("--http-read-timeout", type=float, - default=float(os.getenv("HTTP_READ_TIMEOUT", "5.0"))) - p.add_argument("--http-max-retries", type=int, - default=int(os.getenv("HTTP_MAX_RETRIES", "5"))) - p.add_argument("--http-retry-backoff-s", type=float, - default=float(os.getenv("HTTP_RETRY_BACKOFF_S", "0.5"))) + p.add_argument("--http-connect-timeout", type=float, default=float(os.getenv("HTTP_CONNECT_TIMEOUT","2.0"))) + p.add_argument("--http-read-timeout", type=float, default=float(os.getenv("HTTP_READ_TIMEOUT","5.0"))) + p.add_argument("--http-max-retries", type=int, default=int(os.getenv("HTTP_MAX_RETRIES","5"))) + p.add_argument("--http-retry-backoff-s", type=float, default=float(os.getenv("HTTP_RETRY_BACKOFF_S","0.5"))) return p.parse_args() - -# ----------------------------- -# MapFunction – HTTP POST with retry -# ----------------------------- +# ---------- Mapping ---------- +class Mapping: + def __init__(self, cfg_path: str, default_template: str): + with open(cfg_path, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) or {} + self.topics: Dict[str, Dict[str, Any]] = (cfg.get("topics") or {}) + http_cfg = cfg.get("http") or {} + self.template = http_cfg.get("url_template") or default_template + + def resolve(self, topic: str) -> Optional[Dict[str, str]]: + m = self.topics.get(topic) + if not m: + return None + name = m.get("name") + if not name: + return None + url = self.template.format(name=name) + return {"name": name, "url": url} + +# ---------- HTTP Map ---------- class HttpMap(MapFunction): - def __init__(self, http_url: str, connect_timeout: float, read_timeout: float, - max_retries: int, retry_backoff_s: float, team: str): - self.http_url = http_url + def __init__(self, mapping: Mapping, connect_timeout: float, read_timeout: float, + max_retries: int, retry_backoff_s: float): + self.mapping = mapping self.connect_timeout = connect_timeout self.read_timeout = read_timeout self.max_retries = max_retries self.retry_backoff_s = retry_backoff_s - self.team = team self.loop = None self.session = None - def open(self, runtime_context: RuntimeContext): + def open(self, ctx: RuntimeContext): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - timeout = aiohttp.ClientTimeout(total=None, - connect=self.connect_timeout, - sock_read=self.read_timeout) - self.session = self.loop.run_until_complete(self._make_session(timeout)) - - async def _make_session(self, timeout: aiohttp.ClientTimeout) -> aiohttp.ClientSession: - return aiohttp.ClientSession(timeout=timeout) - - async def _post_once(self, url: str, payload: Dict[str, Any], headers: Dict[str, str]): - async with self.session.post(url, json=payload, headers=headers) as resp: - return resp.status, await resp.text() - - async def _post_with_retry(self, url: str, payload: Dict[str, Any], headers: Dict[str, str]): - attempt = 0 - while True: - try: - status, text = await self._post_once(url, payload, headers) - if 200 <= status < 300: - return {"ok": True, "status": status, "body": text} - # Do not retry most 4xx except 408/429 - if 400 <= status < 500 and status not in (408, 429): - return {"ok": False, "status": status, "body": text, "retry": False} - attempt += 1 - if attempt > self.max_retries: - return {"ok": False, "status": status, "body": text, "retry": False} - await asyncio.sleep(self.retry_backoff_s * attempt) - except Exception as e: - attempt += 1 - if attempt > self.max_retries: - return {"ok": False, "status": 599, "body": str(e), "retry": False} - await asyncio.sleep(self.retry_backoff_s * attempt) + timeout = aiohttp.ClientTimeout(total=None, connect=self.connect_timeout, sock_read=self.read_timeout) + self.session = self.loop.run_until_complete(aiohttp.ClientSession(timeout=timeout).__aenter__()) def map(self, s: str) -> str: - # 1) Parse JSON + topic = None try: - event = json.loads(s) - except Exception as e: - return json.dumps( - {"ok": False, "status": 422, "body": f"bad json: {e}", "raw": s, "stage": "parse"}, - ensure_ascii=False - ) + env = json.loads(s) + topic = env.get("__topic") + value = env.get("value") + except Exception: + value = s # fallback - # 2) Generate idempotency key - event_id = event.get("event_id") or str(uuid.uuid4()) + m = self.mapping.resolve(topic or "") + if not m: + return json.dumps({"ok": False, "stage": "route", "reason": "no_mapping", "topic": topic, "raw": value}, ensure_ascii=False) + + try: + event = json.loads(value) + except Exception as e: + return json.dumps({"ok": False, "stage": "parse", "status": 422, "body": f"bad json: {e}", "raw": value}, ensure_ascii=False) - # 3) Validate fields: must have bucket+key; image_uri not allowed - if "image_uri" in event: - return json.dumps( - {"ok": False, "status": 422, - "body": "image_uri not supported; use {bucket,key} only", - "event": event, "stage": "validate"}, - ensure_ascii=False - ) - - bucket = event.get("bucket") - key = event.get("key") + bucket, key = event.get("bucket"), event.get("key") if not bucket or not key: - return json.dumps( - {"ok": False, "status": 422, - "body": "missing required fields: bucket and key", - "event": event, "stage": "validate"}, - ensure_ascii=False - ) - - # 4) Prepare headers and payload: send only bucket/key to the inference service - headers = { - "Content-Type": "application/json", - "Idempotency-Key": event_id, - "X-Correlation-ID": event_id, - } + return json.dumps({"ok": False, "stage": "validate", "status": 422, "body": "missing bucket/key", "event": event}, ensure_ascii=False) + + event_id = event.get("event_id") or str(uuid.uuid4()) + headers = {"Content-Type": "application/json", "Idempotency-Key": event_id, "X-Correlation-ID": event_id} payload = {"bucket": bucket, "key": key} - # 5) Execute HTTP POST with retry logic - res = self.loop.run_until_complete(self._post_with_retry(self.http_url, payload, headers)) + async def _once(): + async with self.session.post(m["url"], json=payload, headers=headers) as resp: + return resp.status, await resp.text() + + async def _with_retry(): + attempt = 0 + while True: + try: + st, body = await _once() + if 200 <= st < 300: + return {"ok": True, "status": st, "body": body} + if 400 <= st < 500 and st not in (408, 429): + return {"ok": False, "status": st, "body": body, "retry": False} + except Exception as e: + st, body = 599, str(e) + attempt += 1 + if attempt > self.max_retries: + return {"ok": False, "status": st, "body": body, "retry": False} + await asyncio.sleep(self.retry_backoff_s * attempt) - # 6) Wrap output into a consistent JSON response - out = { - "event_id": event_id, - "team": self.team, - "http_url": self.http_url, - **res, - "event": event - } - return json.dumps(out, ensure_ascii=False) + res = self.loop.run_until_complete(_with_retry()) + return json.dumps( + {"event_id": event_id, "name": m["name"], "http_url": m["url"], **res, "event": event, "topic": topic}, + ensure_ascii=False + ) def close(self): - try: - if self.session is not None: - self.loop.run_until_complete(self.session.close()) - finally: - if self.loop is not None: - self.loop.close() + if self.session: + self.loop.run_until_complete(self.session.__aexit__(None, None, None)) + if self.loop: + self.loop.close() - -# ----------------------------- -# Flink Topology -# ----------------------------- -def build_env(parallelism: int) -> StreamExecutionEnvironment: +# ---------- Source/Sink ---------- +def build_env(par: int): env = StreamExecutionEnvironment.get_execution_environment() env.set_runtime_mode(RuntimeExecutionMode.STREAMING) - env.set_parallelism(parallelism) + env.set_parallelism(par) return env +def build_topic_source(env: StreamExecutionEnvironment, bootstrap: str, group_id: str, topic: str): + src = (KafkaSource.builder() + .set_bootstrap_servers(bootstrap) + .set_group_id(group_id) + .set_topics(topic) + .set_starting_offsets(KafkaOffsetsInitializer.latest()) + .set_value_only_deserializer(SimpleStringSchema()) + .build()) + ds = env.from_source(src, WatermarkStrategy.no_watermarks(), f"kafka-source-{topic}") + return ds.map(lambda v: json.dumps({"__topic": topic, "value": v}, ensure_ascii=False), output_type=Types.STRING()) + +def build_union_source(env: StreamExecutionEnvironment, bootstrap: str, group_id: str, topics: List[str]): + streams = [build_topic_source(env, bootstrap, group_id, t) for t in topics] + if not streams: + raise ValueError("No input topics provided") + out = streams[0] + for s in streams[1:]: + out = out.union(s) + return out + +def sink_to_topic(bootstrap: str, topic: str): + return (KafkaSink.builder() + .set_bootstrap_servers(bootstrap) + .set_record_serializer(KafkaRecordSerializationSchema.builder() + .set_topic(topic) + .set_value_serialization_schema(SimpleStringSchema()) + .build()) + .build()) + +# ---------- main ---------- +def main(): + a = parse_args() -def build_source(env: StreamExecutionEnvironment, bootstrap: str, group_id: str, topic: str): - src = ( - KafkaSource.builder() - .set_bootstrap_servers(bootstrap) - .set_group_id(group_id) - .set_topics(topic) - .set_starting_offsets(KafkaOffsetsInitializer.earliest()) - .set_value_only_deserializer(SimpleStringSchema()) - .build() - ) - return env.from_source(src, WatermarkStrategy.no_watermarks(), "kafka-source") - - -def build_sink(bootstrap: str, topic: str): - return ( - KafkaSink.builder() - .set_bootstrap_servers(bootstrap) - .set_record_serializer( - KafkaRecordSerializationSchema.builder() - .set_topic(topic) - .set_value_serialization_schema(SimpleStringSchema()) - .build() - ) - .build() - ) + topics = [t.strip() for t in (a.input_topics or "").split(",") if t.strip()] + if not topics: + raise SystemExit("must provide --input-topics (comma separated)") + mapping = Mapping(a.config_path, a.http_url_template) + env = build_env(a.parallelism) -def main(): - args = parse_args() - - env = build_env(args.parallelism) - ds = build_source(env, args.bootstrap, args.group_id, args.input_topic) - - mapper = HttpMap( - http_url=args.http_url, - connect_timeout=args.http_connect_timeout, - read_timeout=args.http_read_timeout, - max_retries=args.http_max_retries, - retry_backoff_s=args.http_retry_backoff_s, - team=args.team, + ds = build_union_source(env, a.bootstrap, a.group_id, topics) + + mapped = ds.map( + HttpMap(mapping, a.http_connect_timeout, a.http_read_timeout, a.http_max_retries, a.http_retry_backoff_s), + output_type=Types.STRING() ) - dispatched = ds.map(mapper, output_type=Types.STRING()) def _is_ok(s: str) -> bool: try: @@ -228,15 +189,13 @@ def _is_ok(s: str) -> bool: except Exception: return False - ok_stream = dispatched.filter(_is_ok) - bad_stream = dispatched.filter(lambda s: not _is_ok(s)) - - ok_topic = f"inference.dispatched.{args.team}" - ok_stream.sink_to(build_sink(args.bootstrap, ok_topic)) - bad_stream.sink_to(build_sink(args.bootstrap, args.dlq_topic)) + ok = mapped.filter(_is_ok) + bad = mapped.filter(lambda s: not _is_ok(s)) - env.execute(f"http-dispatcher-{args.team}") + ok.sink_to(sink_to_topic(a.bootstrap, a.ok_topic)) + bad.sink_to(sink_to_topic(a.bootstrap, a.dlq_topic)) + env.execute("http-dispatcher-multimodel") if __name__ == "__main__": main()