Skip to content

Commit 02f922d

Browse files
Merge pull request #24 from Association-INTech/HL
HL IS FINISHED
2 parents eb6500d + 40423fd commit 02f922d

18 files changed

Lines changed: 449 additions & 393 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ Debug_Wayfinding
1010
.venv
1111
*.onnx
1212
checkpoints
13+
*.wbproj

scripts/lanch_one_simu.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import sys
3+
4+
from typing import *
5+
import numpy as np
6+
import onnxruntime as ort
7+
8+
simu_path = __file__.rsplit('/', 2)[0] + '/src/Simulateur'
9+
if simu_path not in sys.path:
10+
sys.path.insert(0, simu_path)
11+
12+
from onnx_utils import run_onnx_model
13+
from config import *
14+
from WebotsSimulationGymEnvironment import WebotsSimulationGymEnvironment
15+
from TemporalResNetExtractor import TemporalResNetExtractor
16+
from CNN1DResNetExtractor import CNN1DResNetExtractor
17+
# -------------------------------------------------------------------------
18+
19+
20+
21+
# --- Chemin vers le fichier ONNX ---
22+
23+
ONNX_MODEL_PATH = "model.onnx"
24+
25+
# --- Initialisation du moteur d'inférence ONNX Runtime (ORT) ---
26+
def init_onnx_runtime_session(onnx_path: str) -> ort.InferenceSession:
27+
if not os.path.exists(onnx_path):
28+
raise FileNotFoundError(f"Le fichier ONNX est introuvable à : {onnx_path}. Veuillez l'exporter d'abord.")
29+
30+
# Crée la session d'inférence
31+
return ort.InferenceSession(onnx_path) #On peut modifier le providers afin de mettre une CUDA
32+
33+
34+
if __name__ == "__main__":
35+
if not os.path.exists("/tmp/autotech/"):
36+
os.mkdir("/tmp/autotech/")
37+
38+
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
39+
40+
41+
# 2. Initialisation de la session ONNX Runtime
42+
try:
43+
ort_session = init_onnx_runtime_session(ONNX_MODEL_PATH)
44+
input_name = ort_session.get_inputs()[0].name
45+
output_name = ort_session.get_outputs()[0].name
46+
print(f"Modèle ONNX chargé depuis {ONNX_MODEL_PATH}")
47+
print(f"Input Name: {input_name}, Output Name: {output_name}")
48+
except FileNotFoundError as e:
49+
print(f"ERREUR : {e}")
50+
print(
51+
"Veuillez vous assurer que vous avez exécuté une fois le script d'entraînement pour exporter 'model.onnx'.")
52+
sys.exit(1)
53+
54+
# 3. Boucle d'inférence (Test)
55+
env = WebotsSimulationGymEnvironment(0,0)
56+
obs = env.reset()
57+
print("Début de la simulation en mode inférence...")
58+
59+
max_steps = 5000
60+
step_count = 0
61+
62+
while True:
63+
64+
action = run_onnx_model(ort_session,obs)
65+
66+
# 4. Exécuter l'action dans l'environnement
67+
obs, reward, done, info = env.step(action)
68+
69+
# Note: L'environnement Webots gère généralement son propre affichage
70+
# env.render() # Décommenter si votre env supporte le rendu externe
71+
72+
# Gestion des fins d'épisodes
73+
if done:
74+
print(f"Épisode(s) terminé(s) après {step_count} étapes.")
75+
obs = env.reset()
76+
77+
78+
79+
# Fermeture propre (très important pour les processus parallèles SubprocVecEnv)
80+
envs.close()
81+
print("Simulation terminée. Environnements fermés.")
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import logging
2+
import os
3+
import sys
4+
5+
from typing import *
6+
7+
import torch.nn as nn
8+
9+
from stable_baselines3 import PPO
10+
from stable_baselines3.common.vec_env import SubprocVecEnv
11+
12+
simu_path = __file__.rsplit('/', 2)[0] + '/src/Simulateur'
13+
if simu_path not in sys.path:
14+
sys.path.insert(0, simu_path)
15+
16+
from Simulateur.config import LOG_LEVEL
17+
from config import *
18+
from TemporalResNetExtractor import TemporalResNetExtractor
19+
from CNN1DResNetExtractor import CNN1DResNetExtractor
20+
from onnx_utils import *
21+
22+
from WebotsSimulationGymEnvironment import WebotsSimulationGymEnvironment
23+
if LOG_LEVEL == logging.DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
24+
25+
26+
if __name__ == "__main__":
27+
28+
if not os.path.exists("/tmp/autotech/"):
29+
os.mkdir("/tmp/autotech/")
30+
31+
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
32+
33+
34+
def make_env(simulation_rank: int, vehicle_rank: int):
35+
if LOG_LEVEL == logging.DEBUG:
36+
print("CAREFUL !!! created an SERVER env with {simulation_rank}_{vehicle_rank}")
37+
return WebotsSimulationGymEnvironment(simulation_rank, vehicle_rank)
38+
39+
envs = SubprocVecEnv([lambda simulation_rank=simulation_rank, vehicle_rank=vehicle_rank : make_env(simulation_rank, vehicle_rank) for vehicle_rank in range(n_vehicles) for simulation_rank in range(n_simulations)])
40+
41+
ExtractorClass = CNN1DResNetExtractor
42+
43+
policy_kwargs = dict(
44+
features_extractor_class=ExtractorClass,
45+
features_extractor_kwargs=dict(
46+
context_size=context_size,
47+
lidar_horizontal_resolution=lidar_horizontal_resolution,
48+
camera_horizontal_resolution=camera_horizontal_resolution,
49+
device=device
50+
),
51+
activation_fn=nn.ReLU,
52+
net_arch=[512, 512, 512],
53+
)
54+
55+
56+
ppo_args = dict(
57+
n_steps=4096,
58+
n_epochs=10,
59+
batch_size=256,
60+
learning_rate=3e-4,
61+
gamma=0.99,
62+
verbose=1,
63+
normalize_advantage=True,
64+
device=device
65+
)
66+
67+
68+
save_path = __file__.rsplit("/", 1)[0] + "/checkpoints/" + ExtractorClass.__name__ + "/"
69+
os.makedirs(save_path, exist_ok=True)
70+
71+
72+
print(save_path)
73+
print(os.listdir(save_path))
74+
75+
valid_files = [x for x in os.listdir(save_path) if x.rstrip(".zip").isnumeric()]
76+
77+
if valid_files:
78+
model_name = max(
79+
valid_files,
80+
key=lambda x : int(x.rstrip(".zip"))
81+
)
82+
print(f"Loading model {save_path + model_name}")
83+
model = PPO.load(
84+
save_path + model_name,
85+
envs,
86+
**ppo_args,
87+
policy_kwargs=policy_kwargs
88+
)
89+
i = int(model_name.rstrip(".zip")) + 1
90+
print(f"----- Model found, loading {model_name} -----")
91+
92+
else:
93+
model = PPO(
94+
"MlpPolicy",
95+
envs,
96+
**ppo_args,
97+
policy_kwargs=policy_kwargs
98+
)
99+
100+
i = 0
101+
print("----- Model not found, creating a new one -----")
102+
103+
print("MODEL HAS HYPER PARAMETERS:")
104+
print(f"{model.learning_rate=}")
105+
print(f"{model.gamma=}")
106+
print(f"{model.verbose=}")
107+
print(f"{model.n_steps=}")
108+
print(f"{model.n_epochs=}")
109+
print(f"{model.batch_size=}")
110+
print(f"{model.device=}")
111+
112+
print("SERVER : finished executing")
113+
114+
# obs = envs.reset()
115+
# while True:
116+
# action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation
117+
# obs, reward, done, info = envs.step(action)
118+
# envs.render() # Optional: visualize the environment
119+
120+
121+
while True:
122+
export_onnx(model)
123+
test_onnx(model)
124+
125+
if LOG_LEVEL <= logging.DEBUG:
126+
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
127+
else:
128+
model.learn(total_timesteps=500_000)
129+
130+
print("iteration over")
131+
132+
model.save(save_path + str(i))
133+
134+
i += 1
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import os
2+
from typing import *
3+
import numpy as np
4+
import gymnasium as gym
5+
6+
from config import *
7+
8+
9+
class WebotsSimulationGymEnvironment(gym.Env):
10+
"""
11+
One environment for each vehicle
12+
13+
n: index of the vehicle
14+
supervisor: the supervisor of the simulation
15+
"""
16+
17+
def __init__(self, simulation_rank: int, vehicle_rank: int):
18+
19+
20+
super().__init__()
21+
self.simulation_rank = simulation_rank
22+
self.vehicle_rank = vehicle_rank
23+
24+
self.handler = logging.FileHandler(f"/tmp/autotech/Voiture_{self.simulation_rank}_{self.vehicle_rank}.log")
25+
self.handler.setFormatter(FORMATTER)
26+
self.log = logging.getLogger("SERVER")
27+
self.log.setLevel(level=LOG_LEVEL)
28+
self.log.addHandler(self.handler)
29+
30+
31+
self.log.info("Initialisation started")
32+
33+
# this is only true if lidar_horizontal_resolution = camera_horizontal_resolution
34+
box_min = np.zeros([2, context_size, lidar_horizontal_resolution], dtype=np.float32)
35+
box_max = np.ones([2, context_size, lidar_horizontal_resolution], dtype=np.float32) * 30
36+
37+
self.observation_space = gym.spaces.Box(box_min, box_max, dtype=np.float32)
38+
self.action_space = gym.spaces.MultiDiscrete([n_actions_steering, n_actions_speed])
39+
40+
if not os.path.exists("/tmp/autotech"):
41+
os.mkdir("/tmp/autotech")
42+
43+
self.log.debug(f"Creation of the pipes")
44+
45+
os.mkfifo(f"/tmp/autotech/{simulation_rank}_{vehicle_rank}toserver.pipe")
46+
os.mkfifo(f"/tmp/autotech/serverto{simulation_rank}_{vehicle_rank}.pipe")
47+
os.mkfifo(f"/tmp/autotech/{simulation_rank}_{vehicle_rank}tosupervisor.pipe")
48+
49+
# --mode=fast --minimize --no-rendering --batch --stdout
50+
if vehicle_rank == 0 :
51+
os.system(f"""
52+
webots {__file__.rsplit('/', 1)[0]}/worlds/piste{simulation_rank % n_map}.wbt --mode=fast --minimize --batch --stdout &
53+
echo $! {simulation_rank}_{vehicle_rank} >>/tmp/autotech/simulationranks
54+
""")
55+
56+
self.log.debug("Connection to the vehicle")
57+
self.fifo_w = open(f"/tmp/autotech/serverto{simulation_rank}_{vehicle_rank}.pipe", "wb")
58+
self.log.debug("Connection to the supervisor")
59+
self.fifo_r = open(f"/tmp/autotech/{simulation_rank}_{vehicle_rank}toserver.pipe", "rb")
60+
61+
self.log.info("Initialisation finished\n")
62+
63+
def reset(self, seed=0):
64+
# basically useless function
65+
66+
# lidar data
67+
# this is true for lidar_horizontal_resolution = camera_horizontal_resolution
68+
self.context = obs = np.zeros([2, context_size, lidar_horizontal_resolution], dtype=np.float32)
69+
info = {}
70+
self.log.info(f"reset finished\n")
71+
return obs, info
72+
73+
def step(self, action):
74+
75+
self.log.info("Starting step")
76+
self.log.info(f"sending {action=}")
77+
self.fifo_w.write(action.tobytes())
78+
self.fifo_w.flush()
79+
80+
# communication with the supervisor
81+
self.log.debug("trying to get info from supervisor")
82+
cur_state = np.frombuffer(self.fifo_r.read(np.dtype(np.float32).itemsize * (n_sensors + lidar_horizontal_resolution + camera_horizontal_resolution)), dtype=np.float32)
83+
self.log.info(f"received {cur_state=}")
84+
reward = np.frombuffer(self.fifo_r.read(np.dtype(np.float32).itemsize), dtype=np.float32)[0] # scalar
85+
self.log.info(f"received {reward=}")
86+
done = np.frombuffer(self.fifo_r.read(np.dtype(np.bool).itemsize), dtype=np.bool)[0] # scalar
87+
self.log.info(f"received {done=}")
88+
truncated = np.frombuffer(self.fifo_r.read(np.dtype(np.bool).itemsize), dtype=np.bool)[0] # scalar
89+
self.log.info(f"received {truncated=}")
90+
info = {}
91+
92+
cur_state = np.nan_to_num(cur_state[n_sensors:], nan=0., posinf=30.)
93+
94+
lidar_obs = cur_state[:lidar_horizontal_resolution]
95+
camera_obs = cur_state[lidar_horizontal_resolution:]
96+
97+
self.context = obs = np.concatenate([
98+
self.context[:, 1:],
99+
[lidar_obs[None], camera_obs[None]]
100+
], axis=1)
101+
102+
self.log.info("step over")
103+
104+
return obs, reward, done, truncated, info
105+
106+

src/Simulateur/__init__.py

Whitespace-only changes.

src/Simulateur/config.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
# just a file that lets us define some constants that are used in multiple files the simulation
22
from torch.cuda import is_available
3+
import logging
4+
35

46
n_map = 2
5-
n_simulations = 8
6-
n_vehicles = 1
7+
n_simulations = 2
8+
n_vehicles = 2
79
n_stupid_vehicles = 0
810
n_actions_steering = 16
911
n_actions_speed = 16
1012
n_sensors = 1
1113
lidar_max_range = 12.0
1214
device = "cuda" if is_available() else "cpu"
1315

14-
context_size = 128
15-
lidar_horizontal_resolution = 128 # DON'T CHANGE THIS VALUE PLS
16-
camera_horizontal_resolution = 128 # DON'T CHANGE THIS VALUE PLS
16+
context_size = 1
17+
lidar_horizontal_resolution = 1024 # DON'T CHANGE THIS VALUE PLS
18+
camera_horizontal_resolution = 1024 # DON'T CHANGE THIS VALUE PLS
1719

18-
B_DEBUG = False
20+
LOG_LEVEL = logging.DEBUG
21+
FORMATTER = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

0 commit comments

Comments
 (0)