diff --git a/.gitignore b/.gitignore index 68bc17f9..28f6a920 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +/pytag/jars/*.jar diff --git a/README.md b/README.md index 9b461b31..bcff2da6 100644 --- a/README.md +++ b/README.md @@ -6,37 +6,56 @@ [![twitter](https://img.shields.io/twitter/follow/gameai_qmul?style=social)](https://twitter.com/intent/follow?screen_name=gameai_qmul) [![](https://img.shields.io/github/stars/martinballa/PyTAG.svg?label=Stars&style=social)](https://github.com/GAIGResearch/TabletopGames) - -PyTAG allows interaction with the TAG framework from Python. This repository contains all the python code required to run Reinforcement Learning agents. -The aim of PyTAG is to provide a Reinforcement Learning API for the TAG framework, but it is not limited to RL as using the python-java bridge all public functions and variables are accessible from python. +PyTAG allows interaction with the TAG framework from Python. This repository contains all the python code required to +run Reinforcement Learning agents. +The aim of PyTAG is to provide a Reinforcement Learning API for the TAG framework, but it is not limited to RL as using +the python-java bridge all public functions and variables are accessible from python. If you want to learn more about TAG, please visit the [website](http://tabletopgames.ai). -You may try [this](https://colab.research.google.com/drive/1WMVu9bFkxvwK7evD1sIkxcsrlhdRoY9d?usp=sharing) google colab notebook to try out PyTAG before installing it on your own machine. +You may try [this](https://colab.research.google.com/drive/1WMVu9bFkxvwK7evD1sIkxcsrlhdRoY9d?usp=sharing) google colab +notebook to try out PyTAG before installing it on your own machine. ## Setting up -The project requires Java with minimum version 8. 
We recommend installing pytag in a new virtual environment. To install pytag you may follow the steps below. + +TAG requires Java with minimum version 21. We recommend installing pytag in a new virtual environment. To +install +pytag you may follow the steps below. + - 1, Clone this repository. - 2, Install PyTAG as a python package ```pip install -e .``` -- 3, Run ```jar_setup.py``` to download the latest jar file for TAG or see the section on "Getting the TAG jar files" below for more options. +- 3, Run ```jar_setup.py``` to download the latest jar file for TAG or see the section on "Getting the TAG jar files" + below for more options. - 4, (optional) install pytag with the additional dependencies to run the baselines ```pip install -e .[examples]``` -- 5, (optional) you may test your installation by running the examples in ```examples/``` for instance ```pt-action-masking.py```. +- 5, (optional) you may test your installation by running the examples in ```examples/``` for instance + ```pt-action-masking.py```. ### Getting the TAG jar files -Pytag is looking for the TAG jar files in the ```pytag/jars/``` folder. To get the latest jar files you may run ```jar_setup.py``` which will download the latest jar files and unpack them at the correct location. -Or alternatively you may manually download it from [Google drive](https://drive.google.com/file/d/1uPNoZkdI4rJiFyNyXFVun_VcAlN3QIVQ/view?usp=drive_link) and place the jar files in the ```pytag/jars/``` folder. -In case that you want to make changes to the JAVA framework (i.e.: implementing the RL interfaces for a new game) you need to create new jar files from TAG and place them in the ```pytag/jars/``` folder. +Pytag is looking for the TAG jar file in the ```pytag/jars/``` folder. To get the latest jar file you may run +```jar_setup.py``` which will download the latest jar file and save it at the correct location. 
+Alternatively, you may manually download it +from [Google drive](https://drive.google.com/file/d/1wIM2xPE5tqvVzO931t3xcVYWk7VCr6i8/view?usp=drive_link) and place +the jar file in the ```pytag/jars/``` folder. + +In case that you want to make changes to the JAVA framework (e.g. implementing the RL interfaces for a new game) you +need to create a new jar file from TAG and place it in the ```pytag/jars/``` folder. +For instructions on building the full TAG.jar file, see the [TAG website](https://tabletopgames.ai/wiki/maven); +in short, run `mvn install` and then copy the generated `target/TAG.jar` file to the `pytag/jars/` folder. ## Getting started -The examples folder provides a few python scripts that may serve as a starting point for using the framework. -```pt-action-masking.py``` demonstrates how the action masking may be used to sample random valid actions manually. ```gym-action-masking.py``` extends this to using the action masking in a gym environment. ```gym-random.py``` shows how the built-in action sampler may be used. +The examples folder provides a few python scripts that may serve as a starting point for using the framework. +```pt-action-masking.py``` demonstrates how the action masking may be used to sample random valid actions manually. +```gym-action-masking.py``` extends this to using the action masking in a gym environment. ```gym-random.py``` shows how +the built-in action sampler may be used. ```ma-random.py``` demonstrates how multiple python agents may be controlled. -The remaining scripts are used to run the PPO baselines from the IEEE CoG 23' paper. ```ppo-eval.py``` allows you to load a trained PPO model for evaluation. +The remaining scripts are used to run the PPO baselines from the IEEE CoG 23' paper. ```ppo-eval.py``` allows you to +load a trained PPO model for evaluation. 
## Citing Information To cite PyTAG in your work, please cite this paper: + ``` @article{balla2023pytag, title={PyTAG: Challenges and Opportunities for Reinforcement Learning in Tabletop Games}, @@ -47,6 +66,7 @@ To cite PyTAG in your work, please cite this paper: ``` To cite TAG in your work, please cite this paper: + ``` @inproceedings{gaina2020tag, author= {Raluca D. Gaina and Martin Balla and Alexander Dockhorn and Raul Montoliu and Diego Perez-Liebana}, @@ -58,10 +78,16 @@ To cite TAG in your work, please cite this paper: ``` ## Contact and contribute -The main method to contribute to our repository directly with code, or to suggest new features, point out bugs or ask questions about the project is through [creating new Issues on this github repository](https://github.com/GAIGResearch/TabletopGames/issues) or [creating new Pull Requests](https://github.com/GAIGResearch/TabletopGames/pulls). Alternatively, you may contact the authors of the papers listed above. + +The main method to contribute to our repository directly with code, or to suggest new features, point out bugs or ask +questions about the project is +through [creating new Issues on this github repository](https://github.com/GAIGResearch/TabletopGames/issues) +or [creating new Pull Requests](https://github.com/GAIGResearch/TabletopGames/pulls). Alternatively, you may contact the +authors of the papers listed above. You can also find out more about the [QMUL Game AI Group](http://gameai.eecs.qmul.ac.uk/). ## Acknowledgements -This work was partly funded by the EPSRC CDT in Intelligent Games and Game Intelligence (IGGI) EP/L015846/1 and EPSRC research grant EP/T008962/1. +This work was partly funded by the EPSRC CDT in Intelligent Games and Game Intelligence (IGGI) EP/L015846/1 and EPSRC +research grant EP/T008962/1. 
diff --git a/examples/ppo.py b/examples/ppo.py index 413ac693..dd64618a 100644 --- a/examples/ppo.py +++ b/examples/ppo.py @@ -3,6 +3,7 @@ import os import random import time +import jpype from distutils.util import strtobool import gymnasium as gym @@ -13,7 +14,8 @@ from torch.utils.tensorboard import SummaryWriter -from utils.wrappers import MergeActionMaskWrapper, RecordEpisodeStatistics +import pytag.gym_wrapper +from pytag.utils.wrappers import MergeActionMaskWrapper, RecordEpisodeStatistics from pytag.utils.common import make_env from utils.networks import PPONet @@ -77,7 +79,7 @@ def parse_args(): parser.add_argument("--target-kl", type=float, default=None, help="the target KL divergence threshold") # game related args - parser.add_argument('--opponent', type=str, default='random', choices=["random", "osla", "mcts"]) + parser.add_argument('--opponent', type=str, default='random') parser.add_argument("--n-players", type=int, default=2, help="the number of players in the env (note some games only support certain number of players)") parser.add_argument("--framestack", type=int, default=1) @@ -153,7 +155,12 @@ def parse_args(): # TRY NOT TO MODIFY: start the game global_step = 0 start_time = time.time() - next_obs, next_info = envs.reset() + try: + next_obs, next_info = envs.reset() + except jpype.JException as ex: + print("Java Exception during envs.reset():") + print(ex.stacktrace()) + raise next_obs = torch.tensor(next_obs).to(device) if args.framestack > 1: next_obs = next_obs.view(next_obs.shape[0], -1) @@ -182,14 +189,19 @@ def parse_args(): logprobs[step] = logprob # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy()) + try: + next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy()) + except jpype.JException as ex: + print("Java Exception during envs.step():") + print(ex.stacktrace()) + raise next_masks = torch.from_numpy(info["action_mask"]).to(device) rewards[step] = torch.tensor(reward).to(device).view(-1) next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) if args.framestack > 1: next_obs = next_obs.view(next_obs.shape[0], -1) - if "episode" in info: # todo not sure if it's faster than just iterationg over _episode + if "episode" in info: # todo not sure if it's faster than just iterating over _episode for i in range(args.num_envs): if info["_episode"][i]: # print(f"global_step={global_step}, episodic_return={info['episode']['r'][i]}") diff --git a/jar_setup.py b/jar_setup.py index 2c1a91a1..e4412579 100644 --- a/jar_setup.py +++ b/jar_setup.py @@ -5,15 +5,15 @@ -argParser.add_argument("-file-id", "--file-id", default="1uPNoZkdI4rJiFyNyXFVun_VcAlN3QIVQ", help="your name") +argParser.add_argument("-file-id", "--file-id", default="1wIM2xPE5tqvVzO931t3xcVYWk7VCr6i8", help="Google Drive file id of the TAG jar to download") args = argParser.parse_args() -filename = "pytag/jars/ModernBoardGame.zip" -file_id = args.file_id #"1uPNoZkdI4rJiFyNyXFVun_VcAlN3QIVQ" +filename = "pytag/jars/TAG.jar" +file_id = args.file_id #"1wIM2xPE5tqvVzO931t3xcVYWk7VCr6i8" gdown.download( f"https://drive.google.com/uc?export=download&confirm=pbef&id={file_id}", filename ) -# unzip the files into -jar_path = "pytag/" -with zipfile.ZipFile(filename, 'r') as zip_file: - zip_file.extractall(jar_path) -os.remove(filename) \ No newline at end of file +# unzip the files into (no longer needed since the jar file is already in the correct location) +#jar_path = "pytag/" +#with zipfile.ZipFile(filename, 'r') as zip_file: +# zip_file.extractall(jar_path) +#os.remove(filename) \ No newline at end of file diff --git a/pytag/gym_wrapper/envs.py b/pytag/gym_wrapper/envs.py index 8d6b92a3..b0758d92 100644 --- a/pytag/gym_wrapper/envs.py +++ 
b/pytag/gym_wrapper/envs.py @@ -28,7 +28,9 @@ def get_action_tree_shape(self): def sample_rnd_action(self): return self._env.sample_rnd_action() - def reset(self): + def reset(self, seed=None, options=None): + if seed is not None: + self._env = PyTAG(agent_ids=self._env.agent_ids, game_id=self._env.game_id, seed=seed, obs_type=self._obs_type) obs, info = self._env.reset() return obs, info diff --git a/pytag/pyTAG.py b/pytag/pyTAG.py index 127a9491..c6b29e67 100644 --- a/pytag/pyTAG.py +++ b/pytag/pyTAG.py @@ -7,32 +7,28 @@ import numpy as np from typing import List def list_supported_games(as_json=False): - tag_jar = os.path.join(os.path.dirname(__file__), 'jars', 'ModernBoardGame.jar') - jpype.addClassPath(tag_jar) + tag_jar = os.path.join(os.path.dirname(__file__), 'jars', 'TAG.jar') + # 2. Hard Stop if the JAR is missing + if not os.path.exists(tag_jar): + raise FileNotFoundError(f"CRITICAL: Could not find JAR at {tag_jar}. " + f"Check your folder structure!") + + # jpype.addClassPath(tag_jar) if not jpype.isJVMStarted(): - jpype.startJVM(convertStrings=False) + print("starting JVM") + print(f"Loading JAR from: {tag_jar}\nFile exists: {os.path.exists(tag_jar)}") + jpype.startJVM(classpath=[tag_jar], convertStrings=False) + # jpype.startJVM("--add-opens=java.base/java.lang=ALL-UNNAMED", classpath=[tag_jar], convertStrings=False) PyTAGEnv = jpype.JClass("core.PyTAG") if as_json: return json.loads(str(PyTAGEnv.getSupportedGamesJSON())) return PyTAGEnv.getSupportedGames() -def get_agent_class(agent_name): - if agent_name == "random": - return jpype.JClass("players.simple.RandomPlayer") - if agent_name == "mcts": - return jpype.JClass("players.mcts.MCTSPlayer") - if agent_name == "osla": - return jpype.JClass("players.simple.OSLAPlayer") - if agent_name == "python": - return jpype.JClass("players.python.PythonAgent") - return None - -def get_mcts_with_params(json_path): - PlayerFactory = jpype.JClass("players.PlayerFactory") - with 
open(os.path.expanduser(json_path)) as json_file: - json_string = json.load(json_file) - json_string = str(json_string).replace('\'', '\"') # JAVA only uses " for string - return jpype.JClass("players.mcts.MCTSPlayer")(PlayerFactory.fromJSONString(json_string)) +def get_agent(data): + if data == "python": + return jpype.JClass("players.python.PythonAgent")() + player_factory = jpype.JClass("players.PlayerFactory") + return player_factory.createPlayer(data) # create the game registry when PyTAG is loaded _game_registry = list_supported_games(as_json=True) @@ -42,6 +38,9 @@ class PyTAG(): Note that the java jar package is expected to be in the jars folder of the same directory as this file. """ def __init__(self, agent_ids: List[str], game_id: str="Diamant", seed: int=0, obs_type:str="vector"): + self.agent_ids = agent_ids + self.game_id = game_id + self.seed = seed self._last_obs_vector = None self._last_action_mask = None self._rnd = random.Random(seed) @@ -50,10 +49,10 @@ def __init__(self, agent_ids: List[str], game_id: str="Diamant", seed: int=0, ob assert game_id in _game_registry, f"Game {game_id} not supported. 
Supported games are {_game_registry}" assert _game_registry[game_id][obs_type] == True, f"Game {game_id} does not support observation type {obs_type}" # start up the JVM - tag_jar = os.path.join(os.path.dirname(__file__), 'jars', 'ModernBoardGame.jar') + tag_jar = os.path.join(os.path.dirname(__file__), 'jars', 'TAG.jar') jpype.addClassPath(tag_jar) if not jpype.isJVMStarted(): - jpype.startJVM(convertStrings=False) + jpype.startJVM("--add-opens=java.base/java.lang=ALL-UNNAMED", convertStrings=False) # access to the java classes PyTAGEnv = jpype.JClass("core.PyTAG") @@ -63,10 +62,7 @@ def __init__(self, agent_ids: List[str], game_id: str="Diamant", seed: int=0, ob # Initialize the java environment gameType = GameType.valueOf(Utils.getArg([""], "game", game_id)) - if agent_ids[0] == "mcts": - agents = [get_mcts_with_params(f"~/data/pyTAG/MCTS_for_{game_id}.json")() for agent_id in agent_ids] - else: - agents = [get_agent_class(agent_id)() for agent_id in agent_ids] + agents = [get_agent(agent_id) for agent_id in agent_ids] self._playerID = agent_ids.index("python") # if multiple python agents this is the first one self._java_env = PyTAGEnv(gameType, None, jpype.java.util.ArrayList(agents), seed, True) diff --git a/pytag/utils/common.py b/pytag/utils/common.py index 79bc31dc..da1627ed 100644 --- a/pytag/utils/common.py +++ b/pytag/utils/common.py @@ -1,12 +1,12 @@ # various helper functions import gymnasium as gym -from gymnasium.wrappers.frame_stack import FrameStack +from gymnasium.wrappers import FrameStackObservation as FrameStack import numpy as np import torch import jpype import jpype.imports -from utils.wrappers import StrategoWrapper, SushiGoWrapper +from pytag.utils.wrappers import StrategoWrapper, SushiGoWrapper def make_env(env_id, seed, opponent, n_players, framestack=1, obs_type="vector"): @@ -25,19 +25,6 @@ def thunk(): env = FrameStack(env, framestack) return env return thunk -def get_agent_list(): - return ["random", "mcts", "osla", "python"] - -def 
get_agent_class(agent_name): - if agent_name == "random": - return jpype.JClass("players.simple.RandomPlayer") - if agent_name == "mcts": - return jpype.JClass("players.mcts.MCTSPlayer") - if agent_name == "osla": - return jpype.JClass("players.simple.OSLAPlayer") - if agent_name == "python": - return jpype.JClass("players.python.PythonAgent") - return None def layer_init(layer, std=np.sqrt(2), bias_const=0.0): torch.nn.init.orthogonal_(layer.weight, std) diff --git a/pytag/utils/wrappers.py b/pytag/utils/wrappers.py index b48c4e6d..2dcf624b 100644 --- a/pytag/utils/wrappers.py +++ b/pytag/utils/wrappers.py @@ -8,10 +8,10 @@ import gymnasium as gym -from gymnasium.vector import VectorEnvWrapper +from gymnasium.vector import VectorWrapper -class MergeActionMaskWrapper(VectorEnvWrapper): +class MergeActionMaskWrapper(VectorWrapper): def reset_wait(self, **kwargs): obs, infos = self.env.reset_wait(**kwargs) return obs, self._merge_action_masks(infos) @@ -93,7 +93,7 @@ def process_json_obs(self, json_obs, normalise=True): obs = np.concatenate([score, round, played_cards, cards_in_hand, opp_played_cards, opp_scores]) return obs -class RecordEpisodeStatistics(gym.Wrapper): +class RecordEpisodeStatistics(VectorWrapper): # Based on RecordEpisodeStatistics from gymnasium, but it checks whether the player has won the game """This wrapper will keep track of cumulative rewards and episode lengths. 
@@ -111,7 +111,7 @@ def __init__(self, env: gym.Env, deque_size: int = 100): deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` """ super().__init__(env) - self.num_envs = getattr(env, "num_envs", 1) + # self.num_envs = getattr(env, "num_envs", 1) # This is already a property of VectorWrapper self.episode_count = 0 self.episode_start_times: np.ndarray = None self.episode_returns: Optional[np.ndarray] = None @@ -120,7 +120,7 @@ def __init__(self, env: gym.Env, deque_size: int = 100): self.return_queue = deque(maxlen=deque_size) self.length_queue = deque(maxlen=deque_size) self.win_queue = deque(maxlen=deque_size) - self.is_vector_env = getattr(env, "is_vector_env", False) + # self.is_vector_env = getattr(env, "is_vector_env", False) # VectorWrapper IS a vector env def reset(self, **kwargs): """Resets the environment using kwargs and resets the episode returns and lengths.""" @@ -159,14 +159,14 @@ def step(self, action): infos["episode"] = { "r": np.where(dones, self.episode_returns, 0.0), "l": np.where(dones, self.episode_lengths, 0), - "w": [0 if final_inf is None else final_inf["has_won"] for final_inf in infos["final_info"]], + "w": [0 if final_inf is None else final_inf.get("has_won", 0) for final_inf in infos.get("final_info", [None] * self.num_envs)], "t": np.where( dones, np.round(time.perf_counter() - self.episode_start_times, 6), 0.0, ), } - if self.is_vector_env: + if True: # it is a vector env infos["_episode"] = np.where(dones, True, False) self.return_queue.extend(self.episode_returns[dones]) self.length_queue.extend(self.episode_lengths[dones])