This State Looks Like That: Self-Interpretable Reinforcement Learning Agents using Prototype Soft Actor-Critic
This is the implementation of ProtoSAC, a novel deep RL architecture that integrates a prototype-based actor into the Soft Actor-Critic (SAC) algorithm, enabling intrinsic interpretability in continuous action spaces.
Our method learns a set of prototypes that represent interpretable state clusters, each associated with a Gaussian action distribution. Actions are generated as a similarity-weighted mixture over these prototypes, which makes decision-making easier to inspect and decompose without sacrificing performance compared to standard SAC.
The images below show examples of prototypes learned by the model during training.
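To make the "similarity-weighted mixture" concrete, here is a minimal, dependency-free sketch of how a prototype actor could turn a latent state into an action. It is not the repository's code: the `log_similarity` function follows the ProtoPNet convention, and the mixture is collapsed into one Gaussian by weighting each prototype's mean and log-std (sampling a single prototype according to the weights would be another reading of "mixture"). All function names here are hypothetical.

```python
import math
import random

# Hypothetical sketch of a prototype-based actor head. Assumptions (not taken
# from the ProtoSAC source): latent states and prototypes are plain vectors,
# similarity is the ProtoPNet-style log similarity, and the mixture is
# collapsed into a single Gaussian by similarity-weighting each prototype's
# mean and log-std, then tanh-squashed as in standard SAC.

def log_similarity(z, p, eps=1e-4):
    """High when the latent state z is close to prototype p (squared L2 distance)."""
    d = sum((zi - pi) ** 2 for zi, pi in zip(z, p))
    return math.log((d + 1.0) / (d + eps))

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def prototype_policy(z, prototypes, means, log_stds):
    """Return mixture weights and the weighted Gaussian mean/std for each action dim."""
    weights = softmax([log_similarity(z, p) for p in prototypes])
    dims = range(len(means[0]))
    mean = [sum(w * mu[j] for w, mu in zip(weights, means)) for j in dims]
    std = [math.exp(sum(w * ls[j] for w, ls in zip(weights, log_stds))) for j in dims]
    return weights, mean, std

def sample_action(mean, std, rng=random):
    """Reparameterized Gaussian sample squashed into [-1, 1] with tanh."""
    return [math.tanh(m + s * rng.gauss(0.0, 1.0)) for m, s in zip(mean, std)]

# Usage: two prototypes in a 2-D latent space, one action dimension.
prototypes = [[0.0, 0.0], [1.0, 1.0]]
means = [[-0.5], [0.5]]
log_stds = [[-1.0], [-1.0]]
weights, mean, std = prototype_policy([0.9, 1.0], prototypes, means, log_stds)
action = sample_action(mean, std, random.Random(0))
```

Because the state `[0.9, 1.0]` lies close to the second prototype, almost all of the weight goes to it, and the action mean is pulled toward that prototype's mean. The weights themselves are what makes the decision inspectable.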
- Installation
- Usage
- Available Environments
- Model Configuration
- Training
- Saving and Loading Models
- Prototype Visualization
## Installation

To set up the project environment, you will need to install the required dependencies. You can create a Conda environment by following these steps:

- Clone the repository: `git clone https://github.com/KRLGroup/PrototypeSAC.git`, then `cd PrototypeSAC`
- Create the environment from the `env.yml` file: `conda env create -f env.yml`
- Activate the environment: `conda activate myenv`
## Usage

After activating the environment, you can run the training script. The script lets you choose the environment, set the number of training episodes, and decide whether to use the baseline SAC model or the ProtoSAC model.

You can run the training process with the following command:

`python main.py --environment 0 --episodes 30000 --baseline False`

This trains the model for 30,000 episodes in the Pendulum-v1 environment using ProtoSAC.
The script accepts the following arguments:

- `--environment`: the environment to train on:
  - `0`: Pendulum-v1 (default)
  - `1`: LunarLanderContinuous-v3
  - `2`: MountainCarContinuous-v0
  - `3`: HalfCheetah-v5
  - `4`: Humanoid-v5
  - `5`: Hopper-v5
  - `6`: CarRacing-v3
- `--episodes`: the number of episodes to run in the selected environment (default: `30000`).
- `--baseline`: use the baseline SAC model if set to `True`, or ProtoSAC if set to `False` (default: `False`).
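We have not checked how the training script parses these flags; the following is a hypothetical `argparse` reconstruction of the interface described above. One detail is worth copying regardless of the real implementation: `argparse` with `type=bool` would turn the string `"False"` into `True` (any non-empty string is truthy), so the sketch converts `--baseline` with an explicit helper. The names `ENVIRONMENTS`, `str2bool`, and `build_parser` are illustrative only.

```python
import argparse

# Hypothetical reconstruction of the training CLI; the real script may differ.
# Index -> environment ID, following the --environment list above.
ENVIRONMENTS = {
    0: "Pendulum-v1",
    1: "LunarLanderContinuous-v3",
    2: "MountainCarContinuous-v0",
    3: "HalfCheetah-v5",
    4: "Humanoid-v5",
    5: "Hopper-v5",
    6: "CarRacing-v3",
}

def str2bool(value):
    """Parse True/False explicitly: argparse's type=bool would turn "False" into True."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")

def build_parser():
    parser = argparse.ArgumentParser(description="Train ProtoSAC or baseline SAC (sketch)")
    parser.add_argument("--environment", type=int, default=0, choices=sorted(ENVIRONMENTS))
    parser.add_argument("--episodes", type=int, default=30000)
    parser.add_argument("--baseline", type=str2bool, default=False)
    return parser

# Usage: equivalent to `--environment 1 --baseline False` on the command line.
args = build_parser().parse_args(["--environment", "1", "--baseline", "False"])
env_id = ENVIRONMENTS[args.environment]
algorithm = "SAC" if args.baseline else "ProtoSAC"
```

With `choices`, an out-of-range index such as `--environment 9` is rejected by `argparse` before training starts.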
Examples:

- Train on Pendulum-v1 with the baseline model: `python train.py --environment 0 --episodes 50000 --baseline True`
- Train on LunarLanderContinuous-v3 with ProtoSAC: `python train.py --environment 1 --episodes 100000 --baseline False`
## Available Environments

The project currently supports the following environments:
- Pendulum-v1: A classic continuous control task where the goal is to balance a pendulum in an upright position.
- LunarLanderContinuous-v3: A continuous control task where the agent must land a lunar module safely on the moon’s surface.
- MountainCarContinuous-v0: A task where the agent must drive a car up a mountain to reach a goal position.
- HalfCheetah-v5: Control a simulated 2D cheetah robot to run forward efficiently.
- Humanoid-v5: Control a full-body humanoid robot to walk or run; highly complex.
- Hopper-v5: Control a 2D one-legged robot to hop forward without falling.
- CarRacing-v3: Vision-based racing task where the agent drives a car on random tracks.
You can select any of these environments by passing its index to the `--environment` argument.
## Model Configuration

The script uses the ProtoSAC algorithm by default; set `--baseline True` to use the standard SAC (Soft Actor-Critic) algorithm instead.
## Training

Training is handled by the `model.learn()` method. By default, training runs for 30,000 episodes; this can be adjusted with the `--episodes` argument. During training, the model's performance is monitored, and video recordings of the agent's behavior are saved in the `videos/` directory.
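`model.learn()` hides the update step. For reference, the sketch below shows the per-sample policy objective of standard SAC with a tanh-squashed Gaussian policy. In a prototype actor, `mean` and `std` would presumably come from the prototype mixture; whether ProtoSAC adds further loss terms (for example, prototype regularizers) is not covered here. The function names are hypothetical, and `eps=1e-6` is a common numerical safeguard rather than a value taken from this repository.

```python
import math

# Standard SAC policy objective for a tanh-squashed Gaussian policy.
# In a prototype actor, mean and std would come from the prototype mixture;
# any ProtoSAC-specific extra loss terms are not modelled here.

def squashed_gaussian_log_prob(u, mean, std, eps=1e-6):
    """log pi(a|s) for a = tanh(u) with u ~ N(mean, std^2), summed over action dims."""
    total = 0.0
    for uj, mj, sj in zip(u, mean, std):
        gaussian = -0.5 * ((uj - mj) / sj) ** 2 - math.log(sj) - 0.5 * math.log(2.0 * math.pi)
        # Change-of-variables correction for the tanh squashing.
        total += gaussian - math.log(1.0 - math.tanh(uj) ** 2 + eps)
    return total

def actor_loss(log_prob, q_value, alpha):
    """Per-sample SAC actor loss: alpha * log pi(a|s) - Q(s, a), minimized by the policy."""
    return alpha * log_prob - q_value

# Usage with made-up numbers: pre-squash sample u, policy parameters, critic value.
lp = squashed_gaussian_log_prob([0.0], [0.0], [1.0])
loss = actor_loss(lp, q_value=2.0, alpha=0.2)
```

The entropy temperature `alpha` trades off exploration against the critic's value estimate. Many SAC implementations tune it automatically instead of fixing it as in this sketch.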
## Saving and Loading Models

At the end of the training process, the model is saved to the `models/` directory with `model.save(f"models/{name_env}")`. You can load it later for evaluation or continued training with `model = SAC.load(f"models/{name_env}")`.
## Prototype Visualization

Prototypes can be visualized during the evaluation phase. The evaluation script runs a trained ProtoSAC model on a specified environment, visualizes the learned prototypes and action distributions, and saves these visualizations as images and videos. Use it to see how the model behaves and what each learned prototype represents.
To evaluate the model and visualize the prototypes, run:

`python protosac_test.py --environment "Pendulum-v1" --model_path "path_to_trained_model" --episodes 50 --save_dir "evaluation_results"`

This will:

- Evaluate the model stored at `"path_to_trained_model"` on the `"Pendulum-v1"` environment.
- Run for 50 episodes.
- Save the evaluation results, prototype images, and video in the `"evaluation_results"` directory.
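We have not checked how `protosac_test.py` chooses the image shown for each prototype. A common approach in prototype networks is to display the observation whose latent embedding is most similar to that prototype. The sketch below assumes a list of hypothetical per-frame embeddings collected during evaluation, where `embeddings[i]` belongs to recorded frame `i`; `nearest_frames` is an illustrative name, not part of the repository.

```python
import math

# Hypothetical sketch: pick, for each prototype, the evaluation frames whose
# latent embedding is most similar to it. Embeddings and prototypes are assumed
# to be plain vectors; frame i corresponds to embeddings[i].

def log_similarity(z, p, eps=1e-4):
    """ProtoPNet-style similarity, high when z is close to p (squared L2 distance)."""
    d = sum((zi - pi) ** 2 for zi, pi in zip(z, p))
    return math.log((d + 1.0) / (d + eps))

def nearest_frames(embeddings, prototypes, top_k=1):
    """Return, for each prototype, the indices of its top_k most similar frames."""
    result = []
    for p in prototypes:
        scores = [log_similarity(z, p) for z in embeddings]
        ranked = sorted(range(len(embeddings)), key=lambda i: scores[i], reverse=True)
        result.append(ranked[:top_k])
    return result

# Usage: three recorded frames, two prototypes in a 2-D latent space.
embeddings = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
prototypes = [[0.9, 0.9], [4.0, 4.0]]
best = nearest_frames(embeddings, prototypes)
```

The returned indices can then be used to save the matching frames as that prototype's images ("this state looks like that").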