From 80214a4d2c66799acc0a79fb72f8316e294557cb Mon Sep 17 00:00:00 2001 From: Nikita Mashchenko Date: Wed, 26 Mar 2025 10:12:00 -0500 Subject: [PATCH 1/5] feat: mocked initial version for server --- requirements.txt | 1 + server/main.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 server/main.py diff --git a/requirements.txt b/requirements.txt index e773848aa..b65d80ddf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ shapely klampt==0.9.2 pyyaml dacite +fastapi[standard] # Perception ultralytics diff --git a/server/main.py b/server/main.py new file mode 100644 index 000000000..4e1062324 --- /dev/null +++ b/server/main.py @@ -0,0 +1,90 @@ +from typing import List +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel, Field +import time +import uuid +# import jwt + +SECRET_KEY = "CHANGE_ME_TO_SOMETHING_SECURE" + +app = FastAPI(title="GemStack Car‑Summon API (Mock)") + +### MODELS ### + +class LoginRequest(BaseModel): + username: str + password: str + +class CoordinatesRequest(BaseModel): + lat: float = Field(..., ge=-90, le=90) + lon: float = Field(..., ge=-180, le=180) + +class CoordinatesResponse(BaseModel): + current_position: CoordinatesRequest + optimized_route: List[CoordinatesRequest] + eta: str + +class SummonResponse(BaseModel): + launch_status: str + launch_id: str + +class CancelRequest(BaseModel): + launch_id: str + +class CancelResponse(BaseModel): + launch_id: str + status: str + +class StreamPosition(BaseModel): + current_position: CoordinatesRequest + launch_status: str + eta: str + +### HELPERS ### + +# def create_jwt(username: str) -> str: +# payload = {"sub": username, "jti": str(uuid.uuid4())} +# return jwt.encode(payload, SECRET_KEY, algorithm="HS256") + +### ENDPOINTS ### + +# @app.post("/api/login") +# def login(req: LoginRequest): +# if req.username == "admin" and req.password == "password": +# return {"token": create_jwt(req.username)} +# raise HTTPException(status_code=401, detail="Invalid credentials") + +@app.post("/api/coordinates", response_model=CoordinatesResponse) +def get_coordinates(req: CoordinatesRequest): + # Mock “optimized route” as a straight line of 3 waypoints + route = [ + CoordinatesRequest(lat=req.lat + 0.001 * i, lon=req.lon + 0.001 * i) + for i in range(1, 4) + ] + return CoordinatesResponse( + current_position=CoordinatesRequest(lat=req.lat, lon=req.lon), + optimized_route=route, + eta="5 min", + ) + +@app.post("/api/summon", response_model=SummonResponse) +def summon(req: CoordinatesRequest): + launch_id = str(uuid.uuid4()) + return SummonResponse(launch_status="launched", launch_id=launch_id) + +@app.get("/api/stream_position/{launch_id}") +def stream_position(launch_id: str): + def event_generator(): + lat, lon = 40.0930, -88.2350 + for i in range(5): + time.sleep(1) + lat += 0.0005 + lon += 0.0005 + yield f"data: {StreamPosition(current_position=CoordinatesRequest(lat=lat, lon=lon), launch_status='navigating', eta=f'{5-i} min').json()}\n\n" + yield "data: {\"launch_status\":\"arrived\"}\n\n" + return StreamingResponse(event_generator(), media_type="text/event-stream") + +@app.post("/api/cancel", response_model=CancelResponse) +def cancel(req: CancelRequest): + return CancelResponse(launch_id=req.launch_id, status="cancelled") From 2c1f076d13c97ccd509d3ab5fbd4019f8b26f504 Mon Sep 17 00:00:00 2001 From: injustli Date: Sun, 6 Apr 2025 21:06:08 -0500 Subject: [PATCH 2/5] Add get and post endpoints for inspection --- server/main.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/server/main.py b/server/main.py index 4e1062324..be8661d71 100644 --- a/server/main.py +++ b/server/main.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field import time import uuid + # import jwt SECRET_KEY = "CHANGE_ME_TO_SOMETHING_SECURE" @@ -12,35 +13,48 @@ ### MODELS ### + class LoginRequest(BaseModel): username: str password: str + class CoordinatesRequest(BaseModel): lat: float = Field(..., ge=-90, le=90) lon: float = Field(..., ge=-180, le=180) + class CoordinatesResponse(BaseModel): current_position: CoordinatesRequest optimized_route: List[CoordinatesRequest] eta: str + class SummonResponse(BaseModel): launch_status: str launch_id: str + class CancelRequest(BaseModel): launch_id: str + class CancelResponse(BaseModel): launch_id: str status: str + class StreamPosition(BaseModel): current_position: CoordinatesRequest launch_status: str eta: str + +class Coordinates(BaseModel): + lat: float + lon: float + + ### HELPERS ### # def create_jwt(username: str) -> str: @@ -55,6 +69,7 @@ class StreamPosition(BaseModel): # return {"token": create_jwt(req.username)} # raise HTTPException(status_code=401, detail="Invalid credentials") + @app.post("/api/coordinates", response_model=CoordinatesResponse) def get_coordinates(req: CoordinatesRequest): # Mock “optimized route” as a straight line of 3 waypoints @@ -68,11 +83,13 @@ def get_coordinates(req: CoordinatesRequest): eta="5 min", ) + @app.post("/api/summon", response_model=SummonResponse) def summon(req: CoordinatesRequest): launch_id = str(uuid.uuid4()) return SummonResponse(launch_status="launched", launch_id=launch_id) + @app.get("/api/stream_position/{launch_id}") def stream_position(launch_id: str): def event_generator(): @@ -82,9 +99,25 @@ def event_generator(): lat += 0.0005 lon += 0.0005 yield f"data: {StreamPosition(current_position=CoordinatesRequest(lat=lat, lon=lon), launch_status='navigating', eta=f'{5-i} min').json()}\n\n" - yield "data: {\"launch_status\":\"arrived\"}\n\n" + yield 'data: {"launch_status":"arrived"}\n\n' + return StreamingResponse(event_generator(), media_type="text/event-stream") + @app.post("/api/cancel", response_model=CancelResponse) def cancel(req: CancelRequest): return CancelResponse(launch_id=req.launch_id, status="cancelled") + + +bounding_box = None + + +@app.post("/api/inspect", status_code=201) +def get_bounding_box(coords: list[Coordinates]): + bounding_box = coords + return "Successfully retrieved bounding box coords!" + + +@app.get("/api/inspect", response_model=list[Coordinates], status_code=200) +def send_bounding_box(): + return bounding_box From 94bbcc6b987c69319aa7f95df599b74af0179baa Mon Sep 17 00:00:00 2001 From: injustli Date: Mon, 7 Apr 2025 00:01:31 -0500 Subject: [PATCH 3/5] Create endpoint for inspection --- server/main.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/server/main.py b/server/main.py index be8661d71..804c2cb34 100644 --- a/server/main.py +++ b/server/main.py @@ -1,6 +1,7 @@ from typing import List from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import time import uuid @@ -11,6 +12,14 @@ app = FastAPI(title="GemStack Car‑Summon API (Mock)") +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost", "http://localhost:3000"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + ### MODELS ### @@ -52,7 +61,7 @@ class StreamPosition(BaseModel): class Coordinates(BaseModel): lat: float - lon: float + lng: float ### HELPERS ### @@ -114,10 +123,11 @@ def cancel(req: CancelRequest): @app.post("/api/inspect", status_code=201) def get_bounding_box(coords: list[Coordinates]): + global bounding_box bounding_box = coords return "Successfully retrieved bounding box coords!" -@app.get("/api/inspect", response_model=list[Coordinates], status_code=200) +@app.get("/api/inspect", response_model=list[Coordinates] | None, status_code=200) def send_bounding_box(): return bounding_box From c4bb905ad052a70776deabd13449d7f0678651f8 Mon Sep 17 00:00:00 2001 From: injustli Date: Mon, 7 Apr 2025 00:02:03 -0500 Subject: [PATCH 4/5] Add polling to server and print responses --- GEMstack/onboard/execution/execution.py | 745 +++++++++++++++--------- 1 file changed, 477 insertions(+), 268 deletions(-) diff --git a/GEMstack/onboard/execution/execution.py b/GEMstack/onboard/execution/execution.py index 9af6ee5f3..87baa2913 100644 --- a/GEMstack/onboard/execution/execution.py +++ b/GEMstack/onboard/execution/execution.py @@ -12,85 +12,96 @@ import io import contextlib import sys -from typing import Dict,Tuple,Set,List,Optional +from typing import Dict, Tuple, Set, List, Optional +import requests EXECUTION_PREFIX = "Execution:" EXECUTION_VERBOSITY = 1 # Define the computation graph -COMPONENTS = None +COMPONENTS = None COMPONENT_ORDER = None COMPONENT_SETTINGS = None LOGGING_MANAGER = None # type: LoggingManager -def executor_debug_print(verbosity : int, format : str, *args): + +def executor_debug_print(verbosity: int, format: str, *args): """Top level prints. Will be printed to stdout and logged.""" if EXECUTION_VERBOSITY >= verbosity: s = format.format(*args) - print(EXECUTION_PREFIX,s) + print(EXECUTION_PREFIX, s) if LOGGING_MANAGER is not None: - LOGGING_MANAGER.log_component_stdout('Executor',s.split('\n')) + LOGGING_MANAGER.log_component_stdout("Executor", s.split("\n")) + -def executor_debug_stderr(format : str, *args): +def executor_debug_stderr(format: str, *args): """Top level stderr prints. Will be printed to stderr and logged.""" s = format.format(*args) - print(EXECUTION_PREFIX,s,file=sys.stderr) + print(EXECUTION_PREFIX, s, file=sys.stderr) if LOGGING_MANAGER is not None: - LOGGING_MANAGER.log_component_stderr('Executor',s.split('\n')) + LOGGING_MANAGER.log_component_stderr("Executor", s.split("\n")) + -def executor_debug_exception(e : Exception, format: str, *args): +def executor_debug_exception(e: Exception, format: str, *args): """Top level exceptions. Will be printed to stderr and logged.""" - executor_debug_stderr(format,*args) + executor_debug_stderr(format, *args) import traceback + executor_debug_stderr(traceback.format_exc()) - executor_debug_print(0,format,*args) - executor_debug_print(0,traceback.format_exc()) + executor_debug_print(0, format, *args) + executor_debug_print(0, traceback.format_exc()) -def normalize_computation_graph(components : list) -> List[Dict]: +def normalize_computation_graph(components: list) -> List[Dict]: normalized_components = [] for c in components: - if isinstance(c,str): - normalized_components.append({c:{'inputs':[],'outputs':[]}}) + if isinstance(c, str): + normalized_components.append({c: {"inputs": [], "outputs": []}}) else: - assert isinstance(c,dict), "Component {} is not a string or dict".format(c) + assert isinstance(c, dict), "Component {} is not a string or dict".format(c) assert len(c) == 1, "Component {} dict has more than one key".format(c) k = list(c.keys())[0] v = c[k] - assert isinstance(v,dict), "Component {} value is not a string or dict".format(c) - if 'inputs' not in v: - v['inputs'] = [] - elif isinstance(v['inputs'],str): - v['inputs'] = [v['inputs']] - elif v['inputs'] is None: - v['inputs'] = [] - if 'outputs' not in v: - v['outputs'] = [] - elif isinstance(v['outputs'],str): - v['outputs'] = [v['outputs']] - elif v['outputs'] is None: - v['outputs'] = [] - normalized_components.append({k:v}) + assert isinstance( + v, dict + ), "Component {} value is not a string or dict".format(c) + if "inputs" not in v: + v["inputs"] = [] + elif isinstance(v["inputs"], str): + v["inputs"] = [v["inputs"]] + elif v["inputs"] is None: + v["inputs"] = [] + if "outputs" not in v: + v["outputs"] = [] + elif isinstance(v["outputs"], str): + v["outputs"] = [v["outputs"]] + elif v["outputs"] is None: + v["outputs"] = [] + normalized_components.append({k: v}) return normalized_components + def load_computation_graph(): """Loads the computation graph from settings[run.computation_graph.components] and sets global variables COMPONENTS, COMPONENT_ORDER, and COMPONENT_SETTINGS.""" global COMPONENTS, COMPONENT_ORDER, COMPONENT_SETTINGS - COMPONENTS = normalize_computation_graph(settings.get('run.computation_graph.components')) + COMPONENTS = normalize_computation_graph( + settings.get("run.computation_graph.components") + ) COMPONENT_ORDER = [list(c.keys())[0] for c in COMPONENTS] COMPONENT_SETTINGS = dict(list(c.items())[0] for c in COMPONENTS) + def import_module_dynamic(module_name, parent_module=None): if parent_module is not None: - full_path = parent_module + '.' + module_name + full_path = parent_module + "." + module_name else: full_path = module_name return importlib.import_module(full_path) -def make_class(config_info, component_module, parent_module=None, extra_args = None): +def make_class(config_info, component_module, parent_module=None, extra_args=None): """Creates an object from a config_info dictionary or string. Args: @@ -101,7 +112,7 @@ def make_class(config_info, component_module, parent_module=None, extra_args = N parent_module: if not None, the parent module to import from. extra_args: if provided, a dict of arguments to send to the component's constructor. - + Returns: Component: instance of named class """ @@ -109,39 +120,54 @@ def make_class(config_info, component_module, parent_module=None, extra_args = N extra_args = {} args = () kwargs = {} - if isinstance(config_info,str): - if '.' in config_info: - component_module,class_name = config_info.rsplit('.',1) + if isinstance(config_info, str): + if "." in config_info: + component_module, class_name = config_info.rsplit(".", 1) else: class_name = config_info else: - class_name = config_info['type'] - if '.' in class_name: - component_module,class_name = class_name.rsplit('.',1) - if 'module' in config_info: - component_module = config_info['module'] - if 'args' in config_info: - args = config_info['args'] - if isinstance(args,dict): + class_name = config_info["type"] + if "." in class_name: + component_module, class_name = class_name.rsplit(".", 1) + if "module" in config_info: + component_module = config_info["module"] + if "args" in config_info: + args = config_info["args"] + if isinstance(args, dict): kwargs = args args = () if parent_module is not None: - executor_debug_print(0,"Importing {} from {} to get {}",component_module,parent_module,class_name) + executor_debug_print( + 0, + "Importing {} from {} to get {}", + component_module, + parent_module, + class_name, + ) else: - executor_debug_print(0,"Importing {} to get {}",component_module,class_name) - module = import_module_dynamic(component_module,parent_module) - klass = getattr(module,class_name) + executor_debug_print(0, "Importing {} to get {}", component_module, class_name) + module = import_module_dynamic(component_module, parent_module) + klass = getattr(module, class_name) try: - return klass(*args,**kwargs,**extra_args) + return klass(*args, **kwargs, **extra_args) except TypeError: try: - return klass(*args,**kwargs) + return klass(*args, **kwargs) except TypeError: - executor_debug_print(0,"Unable to launch module {} with class {} and args {} kwargs {}",component_module,class_name,args,kwargs) + executor_debug_print( + 0, + "Unable to launch module {} with class {} and args {} kwargs {}", + component_module, + class_name, + args, + kwargs, + ) raise -def validate_components(components : Dict[str,ComponentExecutor], provided : List = None): +def validate_components( + components: Dict[str, ComponentExecutor], provided: List = None +): """Checks whether the defined components match the known computation graph""" state = asdict(AllState.zero()) if provided is None: @@ -152,35 +178,57 @@ def validate_components(components : Dict[str,ComponentExecutor], provided : Lis for k in COMPONENT_ORDER: if k not in components: continue - possible_inputs = COMPONENT_SETTINGS[k]['inputs'] - required_outputs = COMPONENT_SETTINGS[k]['outputs'] + possible_inputs = COMPONENT_SETTINGS[k]["inputs"] + required_outputs = COMPONENT_SETTINGS[k]["outputs"] c = components[k] inputs = c.c.state_inputs() for i in inputs: - if i == 'all': - assert possible_inputs == ['all'], "Component {} inputs are not provided by previous components".format(k) + if i == "all": + assert possible_inputs == [ + "all" + ], "Component {} inputs are not provided by previous components".format( + k + ) else: - assert provided_all or i in provided, "Component {} input {} is not provided by previous components".format(k,i) + assert ( + provided_all or i in provided + ), "Component {} input {} is not provided by previous components".format( + k, i + ) if i not in state: - executor_debug_print("Component {} input {} does not exist in AllState object",k,i) - if possible_inputs != ['all']: - assert i in possible_inputs, "Component {} is not supposed to receive input {}".format(k,i) + executor_debug_print( + "Component {} input {} does not exist in AllState object", k, i + ) + if possible_inputs != ["all"]: + assert ( + i in possible_inputs + ), "Component {} is not supposed to receive input {}".format(k, i) outputs = c.c.state_outputs() for o in required_outputs: - if o == 'all': - assert outputs == ['all'], "Component {} outputs are not provided by previous components".format(k) + if o == "all": + assert outputs == [ + "all" + ], "Component {} outputs are not provided by previous components".format( + k + ) else: - assert o in outputs, "Component {} doesn't output required output {}".format(k,o) + assert ( + o in outputs + ), "Component {} doesn't output required output {}".format(k, o) for o in outputs: - if 'all' != o: + if "all" != o: provided.add(o) if o not in state: - executor_debug_print("Component {} output {} does not exist in AllState object",k,o) + executor_debug_print( + "Component {} output {} does not exist in AllState object", k, o + ) else: provided_all = True - for k,c in components.items(): - executor_debug_print(0,"Component {} uses implementation {}",k,c.c.__class__.__name__) + for k, c in components.items(): + executor_debug_print( + 0, "Component {} uses implementation {}", k, c.c.__class__.__name__ + ) assert k in COMPONENT_SETTINGS, "Component {} is not known".format(k) return list(provided) @@ -188,38 +236,39 @@ def validate_components(components : Dict[str,ComponentExecutor], provided : Lis class Debugger: """A simple debugging interface that allows components to send debug messages to visualizations and loggers.""" + def __init__(self): self.handlers = [] # type: List[Debugger] - - def add_handler(self, handler : Debugger): + + def add_handler(self, handler: Debugger): self.handlers.append(handler) - - def debug(self, source : str, item : str, value): + + def debug(self, source: str, item: str, value): for h in self.handlers: h.debug(source, item, value) - - def debug_event(self, source : str, label : str): + + def debug_event(self, source: str, label: str): for h in self.handlers: h.debug_event(source, label) class ChildDebugger: - def __init__(self, parent : Debugger, source : str): + def __init__(self, parent: Debugger, source: str): self.parent = parent self.source = source - - def debug(self, item : str, value): + + def debug(self, item: str, value): self.parent.debug(self.source, item, value) - - def debug_event(self, label : str): - self.parent.debug_event(self.source, label) + def debug_event(self, label: str): + self.parent.debug_event(self.source, label) class ComponentExecutor: """Polls for whether a component should be updated, and reads/writes inputs / outputs to the AllState object.""" - def __init__(self, c : Component, essential : bool = True): + + def __init__(self, c: Component, essential: bool = True): self.c = c self.essential = essential self.do_debug = True @@ -231,15 +280,15 @@ def __init__(self, c : Component, essential : bool = True): self.next_update_time = None rate = c.rate() self.had_exception = False - self.dt = 1.0/rate if rate is not None else 0.0 + self.dt = 1.0 / rate if rate is not None else 0.0 self.num_overruns = 0 self.overrun_amount = 0.0 self.do_update = None - + def set_debugger(self, debugger): if self.do_debug: self.c.debugger = ChildDebugger(debugger, self.c.__class__.__name__) - + def healthy(self): return self.c.healthy() and not self.had_exception @@ -249,10 +298,10 @@ def start(self): def stop(self): self.c.cleanup() - def update(self, t : float, state : AllState): + def update(self, t: float, state: AllState): if self.next_update_time is None or t >= self.next_update_time: t0 = time.time() - self.update_now(t,state) + self.update_now(t, state) t1 = time.time() self.last_update_time = t if self.next_update_time is None: @@ -261,17 +310,38 @@ def update(self, t : float, state : AllState): self.next_update_time += self.dt if self.next_update_time < t and self.dt > 0: if t1 - t0 > self.dt: - executor_debug_print(1,"Component {} is running behind, time {} overran dt {} by {} s",self.c.__class__.__name__,t1-t0,self.dt,t-self.next_update_time) + executor_debug_print( + 1, + "Component {} is running behind, time {} overran dt {} by {} s", + self.c.__class__.__name__, + t1 - t0, + self.dt, + t - self.next_update_time, + ) else: - executor_debug_print(1,"Component {} is running behind (pushed back) overran dt {} by {} s",self.c.__class__.__name__,t1-t0,self.dt,t-self.next_update_time) + executor_debug_print( + 1, + "Component {} is running behind (pushed back) overran dt {} by {} s", + self.c.__class__.__name__, + t1 - t0, + self.dt, + t - self.next_update_time, + ) self.num_overruns += 1 self.overrun_amount += t - self.next_update_time self.next_update_time = t + self.dt return True - executor_debug_print(3,"Component {}","not updating at time {}, next update time is {}",self.c.__class__.__name__,t,self.next_update_time) + executor_debug_print( + 3, + "Component {}", + "not updating at time {}, next update time is {}", + self.c.__class__.__name__, + t, + self.next_update_time, + ) return False - def _do_update(self, t:float, *args): + def _do_update(self, t: float, *args): f = io.StringIO() g = io.StringIO() with contextlib.redirect_stdout(f): @@ -282,70 +352,77 @@ def _do_update(self, t:float, *args): else: res = self.c.update(*args) except Exception as e: - executor_debug_exception(e,"Exception in component {}: {}",self.c.__class__.__name__,e) + executor_debug_exception( + e, "Exception in component {}: {}", self.c.__class__.__name__, e + ) self.had_exception = True res = None - self.log_output(f.getvalue(),g.getvalue()) + self.log_output(f.getvalue(), g.getvalue()) return res - def update_now(self, t:float, state : AllState): + def update_now(self, t: float, state: AllState): """Performs the updates for this component, without fussing with the polling scheduling""" - if self.inputs == ['all']: + if self.inputs == ["all"]: args = (state,) else: - args = tuple([getattr(state,i) for i in self.inputs]) - executor_debug_print(2,"Updating {}",self.c.__class__.__name__) - #capture stdout/stderr + args = tuple([getattr(state, i) for i in self.inputs]) + executor_debug_print(2, "Updating {}", self.c.__class__.__name__) + # capture stdout/stderr res = self._do_update(t, *args) - #write result to state + # write result to state if res is not None: if len(self.output) > 1: - assert len(res) == len(self.output), "Component {} output {} does not match expected length {}".format(self.c.__class__.__name__,self.output,len(self.output)) - for (k,v) in zip(self.output,res): - setattr(state,k, v) - setattr(state,k+'_update_time', t) + assert len(res) == len( + self.output + ), "Component {} output {} does not match expected length {}".format( + self.c.__class__.__name__, self.output, len(self.output) + ) + for k, v in zip(self.output, res): + setattr(state, k, v) + setattr(state, k + "_update_time", t) else: - setattr(state,self.output[0], res) - setattr(state,self.output[0]+'_update_time', t) + setattr(state, self.output[0], res) + setattr(state, self.output[0] + "_update_time", t) - def log_output(self,stdout,stderr): + def log_output(self, stdout, stderr): if stdout: - lines = stdout.split('\n') - if len(lines) > 0 and len(lines[-1])==0: + lines = stdout.split("\n") + if len(lines) > 0 and len(lines[-1]) == 0: lines = lines[:-1] if self.print_stdout: - print("------ Component",self.c.__class__.__name__,"stdout ---------") + print("------ Component", self.c.__class__.__name__, "stdout ---------") for l in lines: - print(" ",l) + print(" ", l) print("-------------------------------------------") if LOGGING_MANAGER is not None: LOGGING_MANAGER.log_component_stdout(self.c.__class__.__name__, lines) if stderr: - lines = stderr.split('\n') - if len(lines) > 0 and len(lines[-1])==0: + lines = stderr.split("\n") + if len(lines) > 0 and len(lines[-1]) == 0: lines = lines[:-1] if self.print_stderr: - print("------ Component",self.c.__class__.__name__,"stderr ---------") + print("------ Component", self.c.__class__.__name__, "stderr ---------") for l in lines: - print(" ",l) + print(" ", l) print("-------------------------------------------") if LOGGING_MANAGER is not None: LOGGING_MANAGER.log_component_stderr(self.c.__class__.__name__, lines) - - class ExecutorBase: """Base class for a mission executor. Handles the computation graph setup. Subclasses should implement begin(), update(), done(), and end() methods.""" + def __init__(self, vehicle_interface): self.vehicle_interface = vehicle_interface self.all_components = dict() # type: Dict[str,ComponentExecutor] - self.always_run_components = dict() # type: Dict[str,ComponentExecutor] - self.pipelines = dict() # type: Dict[str,Tuple[Dict[str,ComponentExecutor],Dict[str,ComponentExecutor],Dict[str,ComponentExecutor]]] - self.current_pipeline = 'drive' # type: str - self.state = None # type: Optional[AllState] + self.always_run_components = dict() # type: Dict[str,ComponentExecutor] + self.pipelines = ( + dict() + ) # type: Dict[str,Tuple[Dict[str,ComponentExecutor],Dict[str,ComponentExecutor],Dict[str,ComponentExecutor]]] + self.current_pipeline = "drive" # type: str + self.state = None # type: Optional[AllState] self.logging_manager = LoggingManager() self.debugger = Debugger() self.debugger.add_handler(self.logging_manager) @@ -357,10 +434,11 @@ def begin(self): already been started and sensors will have been validated.""" pass - def update(self, state : AllState) -> Optional[str]: + def update(self, state: AllState) -> Optional[str]: """Override me to implement mission and pipeline switching logic. - - Returns the name of the next pipeline to run, or None to continue the current pipeline""" + + Returns the name of the next pipeline to run, or None to continue the current pipeline + """ return None def done(self): @@ -372,100 +450,136 @@ def end(self): the vehicle is stopped.""" pass - def make_component(self, config_info, component_name, parent_module=None, extra_args = None) -> ComponentExecutor: + def make_component( + self, config_info, component_name, parent_module=None, extra_args=None + ) -> ComponentExecutor: """Creates a component, caching the result. See arguments of :func:`make_class`. If the component was marked as being a replayed component, will return an executor of a LogReplay object. """ - identifier = str((component_name,config_info)) + identifier = str((component_name, config_info)) if identifier in self.all_components: return self.all_components[identifier] else: try: - component = make_class(config_info,component_name,parent_module,extra_args) + component = make_class( + config_info, component_name, parent_module, extra_args + ) except Exception as e: - executor_debug_exception(e,"Exception raised while trying to make component {} from config info:\n {}",component_name,config_info) + executor_debug_exception( + e, + "Exception raised while trying to make component {} from config info:\n {}", + component_name, + config_info, + ) raise - if not isinstance(component,Component): - raise RuntimeError("Component {} is not a subclass of Component".format(component_name)) - replacement = self.logging_manager.component_replayer(component_name, component) + if not isinstance(component, Component): + raise RuntimeError( + "Component {} is not a subclass of Component".format(component_name) + ) + replacement = self.logging_manager.component_replayer( + component_name, component + ) if replacement is not None: - executor_debug_print(1,"Replaying component {} from long {} with outputs {}",component_name,replacement.logfn,component.state_outputs()) + executor_debug_print( + 1, + "Replaying component {} from long {} with outputs {}", + component_name, + replacement.logfn, + component.state_outputs(), + ) component = replacement - if isinstance(config_info,dict) and config_info.get('multiprocess',False): - #wrap component in a multiprocess executor. TODO: not tested yet + if isinstance(config_info, dict) and config_info.get("multiprocess", False): + # wrap component in a multiprocess executor. TODO: not tested yet from .multiprocess_execution import MPComponentExecutor + executor = MPComponentExecutor(component) else: executor = ComponentExecutor(component) - if isinstance(config_info,dict): - executor.essential = config_info.get('essential',True) - if 'rate' in config_info: - executor.dt = 1.0/config_info['rate'] - executor.print_stderr = executor.print_stdout = config_info.get('print',True) - executor.do_debug = config_info.get('debug',True) + if isinstance(config_info, dict): + executor.essential = config_info.get("essential", True) + if "rate" in config_info: + executor.dt = 1.0 / config_info["rate"] + executor.print_stderr = executor.print_stdout = config_info.get( + "print", True + ) + executor.do_debug = config_info.get("debug", True) executor.set_debugger(self.debugger) self.all_components[identifier] = executor return executor - + def always_run(self, component_name, component: ComponentExecutor): """Adds a component the always-run set.""" self.always_run_components[component_name] = component - def add_pipeline(self,name : str, perception : Dict[str,ComponentExecutor], planning : Dict[str,ComponentExecutor], other : Dict[str,ComponentExecutor]): + def add_pipeline( + self, + name: str, + perception: Dict[str, ComponentExecutor], + planning: Dict[str, ComponentExecutor], + other: Dict[str, ComponentExecutor], + ): """Creates a new pipeline with the given components. The pipeline will be executed in the order perception, planning, other. """ output = validate_components(perception) output = validate_components(planning, output) validate_components(other, output) - self.pipelines[name] = (perception,planning,other) + self.pipelines[name] = (perception, planning, other) - def set_log_folder(self,folder : str): + def set_log_folder(self, folder: str): self.logging_manager.set_log_folder(folder) - def set_auto_plot(self,enabled : bool): + def set_auto_plot(self, enabled: bool): self.logging_manager.set_auto_plot(enabled) - - def log_vehicle_interface(self,enabled=True): + + def log_vehicle_interface(self, enabled=True): """Indicates that the vehicle interface should be logged""" if enabled: logger = self.logging_manager.log_vehicle_behavior(self.vehicle_interface) - self.always_run('vehicle_behavior_logger',ComponentExecutor(logger)) + self.always_run("vehicle_behavior_logger", ComponentExecutor(logger)) else: - raise NotImplementedError("Disabling vehicle interface logging not supported yet") - - def log_components(self,components : List[str]): + raise NotImplementedError( + "Disabling vehicle interface logging not supported yet" + ) + + def log_components(self, components: List[str]): """Indicates that the designated component outputs should be logged.""" self.logging_manager.log_components(components) - - def log_state(self,state_attributes : List[str], rate : Optional[float]=None): + + def log_state(self, state_attributes: List[str], rate: Optional[float] = None): """Indicates that the designated state attributes should be logged at the given rate.""" - logger = self.logging_manager.log_state(state_attributes,rate) - self.always_run('state_logger',ComponentExecutor(logger)) + logger = self.logging_manager.log_state(state_attributes, rate) + self.always_run("state_logger", ComponentExecutor(logger)) - def log_ros_topics(self, topics : List[str], rosbag_options : str = '') -> Optional[str]: + def log_ros_topics( + self, topics: List[str], rosbag_options: str = "" + ) -> Optional[str]: """Indicates that the designated ros topics should be logged with the given options.""" - command = self.logging_manager.log_ros_topics(topics,rosbag_options) + command = self.logging_manager.log_ros_topics(topics, rosbag_options) if command: - executor_debug_print(0,"Recording ROS topics with command {}",command) + executor_debug_print(0, "Recording ROS topics with command {}", command) - def replay_components(self, replayed_components : list, replay_folder : str): + def replay_components(self, replayed_components: list, replay_folder: str): """Declare that the given components should be replayed from a log folder. Further make_component calls to this component will be replaced with LogReplay objects. """ - self.logging_manager.replay_components(replayed_components,replay_folder) + self.logging_manager.replay_components(replayed_components, replay_folder) - def event(self,event_description : str, event_print_string : str = None): + def event(self, event_description: str, event_print_string: str = None): """Logs an event to the metadata and prints a message to the console.""" self.logging_manager.event(event_description) if EXECUTION_VERBOSITY >= 1: if event_print_string is None: - event_print_string = event_description if event_description.endswith('.') else event_description + '.' - executor_debug_print(1,event_print_string) + event_print_string = ( + event_description + if event_description.endswith(".") + else event_description + "." + ) + executor_debug_print(1, event_print_string) def set_exit_reason(self, description): """Sets a main loop exit reason""" @@ -474,97 +588,143 @@ def set_exit_reason(self, description): def run(self): """Main entry point. Runs the mission execution loop.""" global LOGGING_MANAGER - LOGGING_MANAGER = self.logging_manager #kludge! should refactor to avoid global variables + LOGGING_MANAGER = ( + self.logging_manager + ) # kludge! should refactor to avoid global variables - #sanity checking + # sanity checking if self.current_pipeline not in self.pipelines: - executor_debug_print(0,"Initial pipeline {} not found",self.current_pipeline) + executor_debug_print( + 0, "Initial pipeline {} not found", self.current_pipeline + ) return - #must have recovery pipeline - if 'recovery' not in self.pipelines: - executor_debug_print(0,"'recovery' pipeline not found") + # must have recovery pipeline + if "recovery" not in self.pipelines: + executor_debug_print(0, "'recovery' pipeline not found") return - #did we ask to replay any components that don't exist in any pipelines? + # did we ask to replay any components that don't exist in any pipelines? for c in self.logging_manager.replayed_components.keys(): found = False - for (name,(perception_components,planning_components,other_components)) in self.pipelines.items(): - if c in perception_components or c in planning_components or c in other_components: + for name, ( + perception_components, + planning_components, + other_components, + ) in self.pipelines.items(): + if ( + c in perception_components + or c in planning_components + or c in other_components + ): found = True break if not found: - raise ValueError("Replay component",c,"not found in any pipeline") + raise ValueError("Replay component", c, "not found in any pipeline") - #start running components - for k,c in self.all_components.items(): + # start running components + for k, c in self.all_components.items(): c.start() - #start running mission + # start running mission self.state = AllState.zero() self.state.mission.type = MissionEnum.IDLE - + validated = False try: validated = self.validate_sensors() if not validated: - self.event("Sensor validation failed","Could not validate sensors, stopping components and exiting") + self.event( + "Sensor validation failed", + "Could not validate sensors, stopping components and exiting", + ) self.set_exit_reason("Sensor validation failed") except KeyboardInterrupt: - self.event("Ctrl+C interrupt during sensor validation","Could not validate sensors, stopping components and exiting") + self.event( + "Ctrl+C interrupt during sensor validation", + "Could not validate sensors, stopping components and exiting", + ) self.set_exit_reason("Sensor validation failed") if time.time() - self.last_loop_time > 0.5: import traceback - executor_debug_print(1,"A component may have hung. Traceback:\n{}",traceback.format_exc()) + executor_debug_print( + 1, + "A component may have hung. Traceback:\n{}", + traceback.format_exc(), + ) + print("validated: ", validated) if validated: self.begin() while True: + self.state.t = self.vehicle_interface.time() self.logging_manager.pipeline_start_event(self.current_pipeline) try: - executor_debug_print(1,"Executing pipeline {}",self.current_pipeline) + executor_debug_print( + 1, "Executing pipeline {}", self.current_pipeline + ) next = self.run_until_switch() if next is None: - #done + # done self.set_exit_reason("normal exit") break if next not in self.pipelines: - executor_debug_print(1,"Pipeline {} not found, switching to recovery",next) - next = 'recovery' - if self.current_pipeline == 'recovery' and next == 'recovery': - executor_debug_print(1,"\ + executor_debug_print( + 1, "Pipeline {} not found, switching to recovery", next + ) + next = "recovery" + if self.current_pipeline == "recovery" and next == "recovery": + executor_debug_print( + 1, + "\ ************************************************\ Recovery pipeline is not working, exiting! \ - ************************************************") + ************************************************", + ) self.set_exit_reason("recovery pipeline not working") break self.current_pipeline = next if not self.validate_sensors(1): - self.event("Sensors in desired pipeline {} are not working, switching to recovery".format(self.current_pipeline)) - self.current_pipeline = 'recovery' + self.event( + "Sensors in desired pipeline {} are not working, switching to recovery".format( + self.current_pipeline + ) + ) + self.current_pipeline = "recovery" except KeyboardInterrupt: - if self.current_pipeline == 'recovery': - executor_debug_print(1,"\ + if self.current_pipeline == "recovery": + executor_debug_print( + 1, + "\ ************************************************\ Ctrl+C interrupt during recovery, exiting! \ - ************************************************") + ************************************************", + ) self.set_exit_reason("Ctrl+C interrupt during recovery") break - self.current_pipeline = 'recovery' + self.current_pipeline = "recovery" self.event("Ctrl+C pressed, switching to recovery mode") if time.time() - self.last_loop_time > 0.5: import traceback - executor_debug_print(1,"A component may have hung. Traceback:\n{}",traceback.format_exc()) - self.end() - #done with mission - self.event("Mission execution ended","Done with mission execution, stopping components and exiting") - #cleanup, whether validated or not - for k,c in self.all_components.items(): - executor_debug_print(2,"Stopping",k) + executor_debug_print( + 1, + "A component may have hung. Traceback:\n{}", + traceback.format_exc(), + ) + self.end() + # done with mission + self.event( + "Mission execution ended", + "Done with mission execution, stopping components and exiting", + ) + # cleanup, whether validated or not + + for k, c in self.all_components.items(): + executor_debug_print(2, "Stopping", k) c.stop() - + self.logging_manager.close() - executor_debug_print(0,"Done with execution loop") + executor_debug_print(0, "Done with execution loop") def check_for_hardware_faults(self): """Handles vehicle fault checking / logging""" @@ -572,8 +732,8 @@ def check_for_hardware_faults(self): new_faults = [] printed_faults = [] for f in faults: - if f == 'disengaged': - if not settings.get('run.require_engaged',False): + if f == "disengaged": + if not settings.get("run.require_engaged", False): continue if not f in self.last_hardware_faults: self.event("Vehicle disengaged") @@ -586,21 +746,27 @@ def check_for_hardware_faults(self): printed_faults.append(f) if printed_faults: if EXECUTION_VERBOSITY >= 1: - fault_strings = [(f + " (new)" if f in new_faults else f) for f in printed_faults] - executor_debug_print(1,"Hardware faults:",'\n '.join(fault_strings)) + fault_strings = [ + (f + " (new)" if f in new_faults else f) for f in printed_faults + ] + executor_debug_print(1, "Hardware faults:", "\n ".join(fault_strings)) elif new_faults: - executor_debug_print(0,"Hardware fault:",", ".join(new_faults)) + executor_debug_print(0, "Hardware fault:", ", ".join(new_faults)) self.last_hardware_faults = set(faults) - def validate_sensors(self,numsteps=None): + def validate_sensors(self, numsteps=None): """Verifies sensors are working""" - (perception_components,planning_components,other_components) = self.pipelines[self.current_pipeline] + (perception_components, planning_components, other_components) = self.pipelines[ + self.current_pipeline + ] if len(perception_components) == 0: return True - components = list(perception_components.values()) + list(self.always_run_components.values()) + components = list(perception_components.values()) + list( + self.always_run_components.values() + ) dt_min = min([c.dt for c in components if c.dt != 0.0]) - looper = TimedLooper(dt_min,name="main executor") + looper = TimedLooper(dt_min, name="main executor") sensors_working = False num_attempts = 0 t0 = time.time() @@ -610,96 +776,133 @@ def validate_sensors(self,numsteps=None): self.logging_manager.set_vehicle_time(self.state.t) self.last_loop_time = time.time() - #check for vehicle faults + # check for vehicle faults self.check_for_hardware_faults() - self.update_components(perception_components,self.state) + self.update_components(perception_components, self.state) sensors_working = all([c.healthy() for c in perception_components.values()]) - self.update_components(self.always_run_components,self.state,force=True) - always_run_working = all([c.healthy() for c in self.always_run_components.values()]) + self.update_components(self.always_run_components, self.state, force=True) + always_run_working = all( + [c.healthy() for c in self.always_run_components.values()] + ) if not always_run_working: - executor_debug_print(1,"Always-run components not working, ignoring") + executor_debug_print(1, "Always-run components not working, ignoring") num_attempts += 1 if numsteps is not None and num_attempts >= numsteps: return False if time.time() > next_print_time: - executor_debug_print(1,"Waiting for sensors to be healthy...") + executor_debug_print(1, "Waiting for sensors to be healthy...") next_print_time += 1.0 return True def run_until_switch(self): """Runs a pipeline until a switch is requested.""" - if self.current_pipeline == 'recovery': + if self.current_pipeline == "recovery": self.state.mission.type = MissionEnum.RECOVERY_STOP - - (perception_components,planning_components,other_components) = self.pipelines[self.current_pipeline] - components = list(perception_components.values()) + list(planning_components.values()) + list(other_components.values()) + list(self.always_run_components.values()) + + (perception_components, planning_components, other_components) = self.pipelines[ + self.current_pipeline + ] + components = ( + list(perception_components.values()) + + list(planning_components.values()) + + list(other_components.values()) + + list(self.always_run_components.values()) + ) dt_min = min([c.dt for c in components if c.dt != 0.0]) - looper = TimedLooper(dt_min,name="main executor") + looper = TimedLooper(dt_min, name="main executor") while looper and not self.done(): + response = requests.get("http://localhost:8000/api/inspect") + print("data: ", response.json()) self.state.t = self.vehicle_interface.time() self.logging_manager.set_vehicle_time(self.state.t) self.last_loop_time = time.time() - #check for vehicle faults + # check for vehicle faults self.check_for_hardware_faults() - - self.update_components(perception_components,self.state) - #check for faults - for name,c in perception_components.items(): + + self.update_components(perception_components, self.state) + # check for faults + for name, c in perception_components.items(): if not c.healthy(): - if c.essential and self.current_pipeline != 'recovery': - executor_debug_print(1,"Sensor {} not working, entering recovery mode",name) - return 'recovery' + if c.essential and self.current_pipeline != "recovery": + executor_debug_print( + 1, "Sensor {} not working, entering recovery mode", name + ) + return "recovery" else: - executor_debug_print(1,"Warning, sensor {} not working, ignoring",name) - + executor_debug_print( + 1, "Warning, sensor {} not working, ignoring", name + ) + next_pipeline = self.update(self.state) if next_pipeline is not None and next_pipeline != self.current_pipeline: - executor_debug_print(0,"update() requests to switch to pipeline {}",next_pipeline) + executor_debug_print( + 0, "update() requests to switch to pipeline {}", next_pipeline + ) return next_pipeline - self.update_components(planning_components,self.state) - #check for faults - for name,c in planning_components.items(): + self.update_components(planning_components, self.state) + # check for faults + for name, c in planning_components.items(): if not c.healthy(): - if c.essential and self.current_pipeline != 'recovery': - executor_debug_print(1,"Planner {} not working, entering recovery mode",name) - return 'recovery' + if c.essential and self.current_pipeline != "recovery": + executor_debug_print( + 1, "Planner {} not working, entering recovery mode", name + ) + return "recovery" else: - executor_debug_print(1,"Warning, planner {} not working, ignoring",name) + executor_debug_print( + 1, "Warning, planner {} not working, ignoring", name + ) - self.update_components(other_components,self.state) - for name,c in other_components.items(): + self.update_components(other_components, self.state) + for name, c in other_components.items(): if not c.healthy(): - if c.essential and self.current_pipeline != 'recovery': - executor_debug_print(1,"Other component {} not working, entering recovery mode",name) - return 'recovery' + if c.essential and self.current_pipeline != "recovery": + executor_debug_print( + 1, + "Other component {} not working, entering recovery mode", + name, + ) + return "recovery" else: - executor_debug_print(1,"Warning, other component {} not working",name) + executor_debug_print( + 1, "Warning, other component {} not working", name + ) - self.update_components(self.always_run_components,self.state,force=True) - for name,c in self.always_run_components.items(): + self.update_components(self.always_run_components, self.state, force=True) + for name, c in self.always_run_components.items(): if not c.healthy(): - if c.essential and self.current_pipeline != 'recovery': - executor_debug_print(1,"Always-run component {} not working, entering recovery mode",name) - return 'recovery' + if c.essential and self.current_pipeline != "recovery": + executor_debug_print( + 1, + "Always-run component {} not working, entering recovery mode", + name, + ) + return "recovery" else: - executor_debug_print(1,"Warning, always-run component {} not working",name) + executor_debug_print( + 1, "Warning, always-run component {} not working", name + ) - - #self.done() returned True + # self.done() returned True return None - - def update_components(self, components : Dict[str,ComponentExecutor], state : AllState, now = False, force = False): + def update_components( + self, + components: Dict[str, ComponentExecutor], + state: AllState, + now=False, + force=False, + ): """Updates the components and performs necessary logging. - + If now = True, all components are run regardless of polling state. - If force = False, only components listed in COMPONENT_ORDER are run. + If force = False, only components listed in COMPONENT_ORDER are run. Otherwise, all components in `components` are run in arbitrary order. """ t = state.t @@ -713,26 +916,32 @@ def update_components(self, components : Dict[str,ComponentExecutor], state : Al for k in order: updated = False if now: - components[k].update_now(t,state) + components[k].update_now(t, state) updated = True else: - updated = components[k].update(t,state) - #log component output if necessary + updated = components[k].update(t, state) + # log component output if necessary if updated: - self.logging_manager.log_component_update(k, state, components[k].output) + self.logging_manager.log_component_update( + k, state, components[k].output + ) class StandardExecutor(ExecutorBase): def __init__(self, vehicle_interface): - ExecutorBase.__init__(self,vehicle_interface) - + ExecutorBase.__init__(self, vehicle_interface) + def done(self): - if self.current_pipeline == 'recovery': - if self.vehicle_interface.last_reading is not None and \ - abs(self.vehicle_interface.last_reading.speed) < 1e-3: - executor_debug_print(1,"Vehicle has stopped, exiting execution loop.") + if self.current_pipeline == "recovery": + if ( + self.vehicle_interface.last_reading is not None + and abs(self.vehicle_interface.last_reading.speed) < 1e-3 + ): + executor_debug_print(1, "Vehicle has stopped, exiting execution loop.") return True - if 'disengaged' in self.vehicle_interface.hardware_faults(): - executor_debug_print(1,"Vehicle has disengaged, exiting execution loop.") + if "disengaged" in self.vehicle_interface.hardware_faults(): + executor_debug_print( + 1, "Vehicle has disengaged, exiting execution loop." + ) return True return False From 3be489dfec7d3ef33c6e1f199d820408bf51fa72 Mon Sep 17 00:00:00 2001 From: Nikita Mashchenko Date: Mon, 7 Apr 2025 01:50:55 -0500 Subject: [PATCH 5/5] wip: updated server --- server/main.py | 310 +++++++++++++++++++++--------------- server/message_constants.py | 33 ++++ 2 files changed, 213 insertions(+), 130 deletions(-) create mode 100644 server/message_constants.py diff --git a/server/main.py b/server/main.py index 804c2cb34..b71f91ad4 100644 --- a/server/main.py +++ b/server/main.py @@ -1,133 +1,183 @@ -from typing import List -from fastapi import FastAPI, HTTPException -from fastapi.responses import JSONResponse, StreamingResponse -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, Field -import time +import asyncio +import json +import logging +import websockets import uuid - -# import jwt - -SECRET_KEY = "CHANGE_ME_TO_SOMETHING_SECURE" - -app = FastAPI(title="GemStack Car‑Summon API (Mock)") - -app.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost", "http://localhost:3000"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], +from datetime import datetime +from message_constants import ClientRole, MessageType, MissionEnum + +# configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + handlers=[logging.StreamHandler()] ) -### MODELS ### - - -class LoginRequest(BaseModel): - username: str - password: str - - -class CoordinatesRequest(BaseModel): - lat: float = Field(..., ge=-90, le=90) - lon: float = Field(..., ge=-180, le=180) - - -class CoordinatesResponse(BaseModel): - current_position: CoordinatesRequest - optimized_route: List[CoordinatesRequest] - eta: str - - -class SummonResponse(BaseModel): - launch_status: str - launch_id: str - - -class CancelRequest(BaseModel): - launch_id: str - - -class CancelResponse(BaseModel): - launch_id: str - status: str - - -class StreamPosition(BaseModel): - current_position: CoordinatesRequest - launch_status: str - eta: str - - -class Coordinates(BaseModel): - lat: float - lng: float - - -### HELPERS ### - -# def create_jwt(username: str) -> str: -# payload = {"sub": username, "jti": str(uuid.uuid4())} -# return jwt.encode(payload, SECRET_KEY, algorithm="HS256") - -### ENDPOINTS ### - -# @app.post("/api/login") -# def login(req: LoginRequest): -# if req.username == "admin" and req.password == "password": -# return {"token": create_jwt(req.username)} -# raise HTTPException(status_code=401, detail="Invalid credentials") - - -@app.post("/api/coordinates", response_model=CoordinatesResponse) -def get_coordinates(req: CoordinatesRequest): - # Mock “optimized route” as a straight line of 3 waypoints - route = [ - CoordinatesRequest(lat=req.lat + 0.001 * i, lon=req.lon + 0.001 * i) - for i in range(1, 4) - ] - return CoordinatesResponse( - current_position=CoordinatesRequest(lat=req.lat, lon=req.lon), - optimized_route=route, - eta="5 min", - ) - - -@app.post("/api/summon", response_model=SummonResponse) -def summon(req: CoordinatesRequest): - launch_id = str(uuid.uuid4()) - return SummonResponse(launch_status="launched", launch_id=launch_id) - - -@app.get("/api/stream_position/{launch_id}") -def stream_position(launch_id: str): - def event_generator(): - lat, lon = 40.0930, -88.2350 - for i in range(5): - time.sleep(1) - lat += 0.0005 - lon += 0.0005 - yield f"data: {StreamPosition(current_position=CoordinatesRequest(lat=lat, lon=lon), launch_status='navigating', eta=f'{5-i} min').json()}\n\n" - yield 'data: {"launch_status":"arrived"}\n\n' - - return StreamingResponse(event_generator(), media_type="text/event-stream") - - -@app.post("/api/cancel", response_model=CancelResponse) -def cancel(req: CancelRequest): - return CancelResponse(launch_id=req.launch_id, status="cancelled") - - -bounding_box = None - - -@app.post("/api/inspect", status_code=201) -def get_bounding_box(coords: list[Coordinates]): - global bounding_box - bounding_box = coords - return "Successfully retrieved bounding box coords!" - - -@app.get("/api/inspect", response_model=list[Coordinates] | None, status_code=200) -def send_bounding_box(): - return bounding_box +# store connected clients with their roles +connected_clients = {} # {websocket: {"id": client_id, "role": role}} +# store active summoning missions +active_missions = {} +# track executed summons to prevent duplicates +executed_summons = set() # set of coordinate tuples (x, y) + +async def handle_client(websocket): + """handle a client connection.""" + client_id = str(uuid.uuid4()) + client_role = ClientRole.UNKNOWN + + # wait for initial registration message to determine role + try: + # set a timeout for registration + registration_message = await asyncio.wait_for(websocket.recv(), timeout=10.0) + data = json.loads(registration_message) + + if "role" in data: + role_str = data["role"].lower() + if role_str == "webapp": + client_role = ClientRole.WEBAPP + elif role_str == "server": + client_role = ClientRole.SERVER + elif role_str == "gemstack": + client_role = ClientRole.GEMSTACK + + logging.info(f"new client connected: {client_id} with role {client_role}") + + # send acknowledgment + await websocket.send(json.dumps({ + "type": MessageType.REGISTRATION_RESPONSE, + "client_id": client_id, + "role": client_role, + "status": "connected" + })) + + # store client with role + connected_clients[websocket] = {"id": client_id, "role": client_role} + + # process messages + async for message in websocket: + logging.info(f"received message from {client_id} ({client_role}): {message}") + + try: + data = json.loads(message) + + # add source role to the message for processing + data["source_role"] = client_role + data["source_id"] = client_id + + # handle different message types + if "type" in data: + msg_type = data["type"] + if msg_type == MessageType.SUMMON: + await handle_summon_request(websocket, data, client_id, client_role) + else: + logging.warning(f"unknown or unimplemented message type: {msg_type}") + else: + logging.warning("message missing 'type' field") + + except json.JSONDecodeError: + logging.error(f"invalid JSON: {message}") + await websocket.send(json.dumps({ + "type": MessageType.ERROR, + "message": "invalid JSON format" + })) + + except asyncio.TimeoutError: + logging.warning(f"client {client_id} registration timed out") + await websocket.send(json.dumps({ + "type": MessageType.ERROR, + "message": "registration timed out. please identify your role." + })) + return + except websockets.exceptions.ConnectionClosed: + logging.info(f"client disconnected during registration: {client_id}") + except Exception as e: + logging.error(f"error during client handling: {str(e)}") + finally: + if websocket in connected_clients: + del connected_clients[websocket] + logging.info(f"client removed: {client_id} ({client_role})") + +async def handle_summon_request(websocket, data, client_id, client_role): + """process a summoning request with coordinates.""" + # note: this is a temporary check to ensure the request is coming from a trusted source + # TODO: remove this check when we have a proper authentication mechanism + if client_role != ClientRole.WEBAPP and client_role != ClientRole.SERVER: + logging.error(f"unauthorized role ({client_role}) for summon request") + await websocket.send(json.dumps({ + "type": MessageType.ERROR, + "message": f"unauthorized role ({client_role}) for summon request" + })) + return + + if "coordinates" not in data: + await websocket.send(json.dumps({ + "type": MessageType.ERROR, + "message": "missing coordinates in summon request" + })) + return + + coords = data["coordinates"] + # validate coordinates format + if not all(k in coords for k in ["x", "y"]): + logging.error(f"invalid coordinates: {coords}") + await websocket.send(json.dumps({ + "type": MessageType.ERROR, + "message": "coordinates must include x and y values" + })) + return + + # check if these coordinates have already been executed + coord_tuple = (coords["x"], coords["y"]) + if coord_tuple in executed_summons: + logging.error(f"summon to coordinates {coord_tuple} was already executed") + await websocket.send(json.dumps({ + "type": MessageType.ERROR, + "message": f"summon to coordinates {coord_tuple} was already executed", + "source_role": ClientRole.SERVER + })) + return + + # create a new mission + mission_id = str(uuid.uuid4()) + timestamp = datetime.now().isoformat() + + mission = { + "id": mission_id, + "client_id": client_id, + "client_role": client_role, + "coordinates": coords, + "mission_enum": MissionEnum.DRIVE.value, + "timestamp": timestamp + } + + active_missions[mission_id] = mission + executed_summons.add(coord_tuple) + + # send response to client + await websocket.send(json.dumps({ + "type": MessageType.SUMMON_RESPONSE, + "mission_id": mission_id, + "mission_enum": MissionEnum.DRIVE.value, + "timestamp": timestamp, + "source_role": ClientRole.SERVER + })) + + # Note: Not implementing broadcast to GEMstack clients yet + logging.info(f"summon request processed for mission {mission_id} to coordinates ({coords['x']}, {coords['y']})") + +async def main(): + """start the websocket server.""" + host = "localhost" + port = 8765 + + # updated serve call for newer websockets versions + async with websockets.serve(handle_client, host, port): + logging.info(f"websocket server started at ws://{host}:{port}") + # keep the server running indefinitely + await asyncio.Future() # run forever + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logging.info("server shutdown initiated by user") diff --git a/server/message_constants.py b/server/message_constants.py new file mode 100644 index 000000000..a94b0dff5 --- /dev/null +++ b/server/message_constants.py @@ -0,0 +1,33 @@ +from enum import Enum + +# note: intentionally redefining the enum here to separately deploy server +class MissionEnum(Enum): + IDLE = 0 # not driving, no mission + DRIVE = 1 # normal driving with routing + DRIVE_ROUTE = 2 # normal driving with a fixed route + TELEOP = 3 # manual teleop control + RECOVERY_STOP = 4 # abnormal condition detected, must stop now + ESTOP = 5 # estop pressed, must stop now + +# note: this is primitive config, should be replaced with a more robust auth system +class ClientRole(str, Enum): + WEBAPP = "webapp" + SERVER = "server" + GEMSTACK = "gemstack" + UNKNOWN = "unknown" + +# define message types +class MessageType(str, Enum): + # client registration + REGISTER = "register" + REGISTRATION_RESPONSE = "registration_response" + + # requests from clients + SUMMON = "summon" + + # responses from server + SUMMON_RESPONSE = "summon_response" + ERROR = "error" + + # events + LAUNCH_EVENT = "launch_event"