1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -158,3 +158,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
/pytag/jars/*.jar
56 changes: 41 additions & 15 deletions README.md
Expand Up @@ -6,37 +6,56 @@
[![twitter](https://img.shields.io/twitter/follow/gameai_qmul?style=social)](https://twitter.com/intent/follow?screen_name=gameai_qmul)
[![](https://img.shields.io/github/stars/martinballa/PyTAG.svg?label=Stars&style=social)](https://github.com/GAIGResearch/TabletopGames)


PyTAG allows interaction with the TAG framework from Python. This repository contains all the python code required to run Reinforcement Learning agents.
The aim of PyTAG is to provide a Reinforcement Learning API for the TAG framework, but it is not limited to RL as using the python-java bridge all public functions and variables are accessible from python.
PyTAG allows interaction with the TAG framework from Python. This repository contains all the Python code required to
run Reinforcement Learning agents.
The aim of PyTAG is to provide a Reinforcement Learning API for the TAG framework, but it is not limited to RL: through
the Python-Java bridge, all public functions and variables are accessible from Python.
If you want to learn more about TAG, please visit the [website](http://tabletopgames.ai).

You may try [this](https://colab.research.google.com/drive/1WMVu9bFkxvwK7evD1sIkxcsrlhdRoY9d?usp=sharing) google colab notebook to try out PyTAG before installing it on your own machine.
You may use [this](https://colab.research.google.com/drive/1WMVu9bFkxvwK7evD1sIkxcsrlhdRoY9d?usp=sharing) Google Colab
notebook to try out PyTAG before installing it on your own machine.

## Setting up
The project requires Java with minimum version 8. We recommend installing pytag in a new virtual environment. To install pytag you may follow the steps below.

TAG requires Java 21 or later. We recommend installing pytag in a new virtual environment. To install pytag, follow
the steps below.

- 1, Clone this repository.
- 2, Install PyTAG as a python package ```pip install -e .```
- 3, Run ```jar_setup.py``` to download the latest jar file for TAG or see the section on "Getting the TAG jar files" below for more options.
- 3, Run ```jar_setup.py``` to download the latest jar file for TAG or see the section on "Getting the TAG jar files"
below for more options.
- 4, (optional) install pytag with the additional dependencies to run the baselines ```pip install -e .[examples]```
- 5, (optional) you may test your installation by running the examples in ```examples/``` for instance ```pt-action-masking.py```.
- 5, (optional) you may test your installation by running the examples in ```examples/``` for instance
```pt-action-masking.py```.

### Getting the TAG jar files
Pytag is looking for the TAG jar files in the ```pytag/jars/``` folder. To get the latest jar files you may run ```jar_setup.py``` which will download the latest jar files and unpack them at the correct location.
Or alternatively you may manually download it from [Google drive](https://drive.google.com/file/d/1uPNoZkdI4rJiFyNyXFVun_VcAlN3QIVQ/view?usp=drive_link) and place the jar files in the ```pytag/jars/``` folder.

In case that you want to make changes to the JAVA framework (i.e.: implementing the RL interfaces for a new game) you need to create new jar files from TAG and place them in the ```pytag/jars/``` folder.
PyTAG looks for the TAG jar file in the ```pytag/jars/``` folder. To get the latest jar file you may run
```jar_setup.py```, which downloads it to the correct location.
Alternatively, you may download it manually
from [Google drive](https://drive.google.com/file/d/1wIM2xPE5tqvVzO931t3xcVYWk7VCr6i8/view?usp=drive_link) and place
it in the ```pytag/jars/``` folder.

If you want to make changes to the Java framework (e.g. implementing the RL interfaces for a new game), you
need to build a new jar file from TAG and place it in the ```pytag/jars/``` folder.
For instructions on building the full TAG.jar file, see the [TAG website](https://tabletopgames.ai/wiki/maven).
In short, run `mvn install` and then copy the generated `target/TAG.jar` file to the `pytag/jars/` folder.
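
Whichever route you take, it can help to confirm that the jar actually landed where PyTAG looks for it before starting
Python. The following is a minimal sketch and not part of PyTAG: the helper name `check_tag_jar` is hypothetical, and it
only assumes the `pytag/jars/TAG.jar` location described above and the fact that a jar file is a zip archive.

```python
import zipfile
from pathlib import Path


def check_tag_jar(repo_root="."):
    """Return the path to TAG.jar, or raise if it is missing or not a valid archive.

    Hypothetical helper: the folder and file name follow the README text above.
    """
    jar_path = Path(repo_root) / "pytag" / "jars" / "TAG.jar"
    if not jar_path.is_file():
        raise FileNotFoundError(f"Expected the TAG jar at {jar_path}")
    # A jar is a zip archive, so a truncated or wrong download fails this check.
    if not zipfile.is_zipfile(jar_path):
        raise ValueError(f"{jar_path} exists but is not a valid jar/zip archive")
    return jar_path
```

Calling `check_tag_jar()` from the repository root raises a descriptive error early, instead of a less obvious failure
once the JVM is started.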

## Getting started

The examples folder provides a few python scripts that may serve as a starting point for using the framework.
```pt-action-masking.py``` demonstrates how the action masking may be used to sample random valid actions manually. ```gym-action-masking.py``` extends this to using the action masking in a gym environment. ```gym-random.py``` shows how the built-in action sampler may be used.
The examples folder provides a few python scripts that may serve as a starting point for using the framework.
```pt-action-masking.py``` demonstrates how the action masking may be used to sample random valid actions manually.
```gym-action-masking.py``` extends this to using the action masking in a gym environment. ```gym-random.py``` shows how
the built-in action sampler may be used.
```ma-random.py``` demonstrates how multiple python agents may be controlled.
The remaining scripts are used to run the PPO baselines from the IEEE CoG 23' paper. ```ppo-eval.py``` allows you to load a trained PPO model for evaluation.
The remaining scripts are used to run the PPO baselines from the IEEE CoG '23 paper. ```ppo-eval.py``` allows you to
load a trained PPO model for evaluation.

## Citing Information

To cite PyTAG in your work, please cite this paper:

```
@article{balla2023pytag,
title={PyTAG: Challenges and Opportunities for Reinforcement Learning in Tabletop Games},
Expand All @@ -47,6 +66,7 @@ To cite PyTAG in your work, please cite this paper:
```

To cite TAG in your work, please cite this paper:

```
@inproceedings{gaina2020tag,
author= {Raluca D. Gaina and Martin Balla and Alexander Dockhorn and Raul Montoliu and Diego Perez-Liebana},
Expand All @@ -58,10 +78,16 @@ To cite TAG in your work, please cite this paper:
```

## Contact and contribute
The main method to contribute to our repository directly with code, or to suggest new features, point out bugs or ask questions about the project is through [creating new Issues on this github repository](https://github.com/GAIGResearch/TabletopGames/issues) or [creating new Pull Requests](https://github.com/GAIGResearch/TabletopGames/pulls). Alternatively, you may contact the authors of the papers listed above.

The main method to contribute to our repository directly with code, or to suggest new features, point out bugs or ask
questions about the project is
through [creating new Issues on this github repository](https://github.com/GAIGResearch/TabletopGames/issues)
or [creating new Pull Requests](https://github.com/GAIGResearch/TabletopGames/pulls). Alternatively, you may contact the
authors of the papers listed above.

You can also find out more about the [QMUL Game AI Group](http://gameai.eecs.qmul.ac.uk/).

## Acknowledgements

This work was partly funded by the EPSRC CDT in Intelligent Games and Game Intelligence (IGGI) EP/L015846/1 and EPSRC research grant EP/T008962/1.
This work was partly funded by the EPSRC CDT in Intelligent Games and Game Intelligence (IGGI) EP/L015846/1 and EPSRC
research grant EP/T008962/1.
22 changes: 17 additions & 5 deletions examples/ppo.py
Expand Up @@ -3,6 +3,7 @@
import os
import random
import time
import jpype
from distutils.util import strtobool

import gymnasium as gym
Expand All @@ -13,7 +14,8 @@

from torch.utils.tensorboard import SummaryWriter

from utils.wrappers import MergeActionMaskWrapper, RecordEpisodeStatistics
import pytag.gym_wrapper
from pytag.utils.wrappers import MergeActionMaskWrapper, RecordEpisodeStatistics
from pytag.utils.common import make_env
from utils.networks import PPONet

Expand Down Expand Up @@ -77,7 +79,7 @@ def parse_args():
parser.add_argument("--target-kl", type=float, default=None,
help="the target KL divergence threshold")
# game related args
parser.add_argument('--opponent', type=str, default='random', choices=["random", "osla", "mcts"])
parser.add_argument('--opponent', type=str, default='random')
parser.add_argument("--n-players", type=int, default=2,
help="the number of players in the env (note some games only support certain number of players)")
parser.add_argument("--framestack", type=int, default=1)
Expand Down Expand Up @@ -153,7 +155,12 @@ def parse_args():
# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
next_obs, next_info = envs.reset()
try:
next_obs, next_info = envs.reset()
except jpype.JException as ex:
print("Java Exception during envs.reset():")
print(ex.stacktrace())
raise
next_obs = torch.tensor(next_obs).to(device)
if args.framestack > 1:
next_obs = next_obs.view(next_obs.shape[0], -1)
Expand Down Expand Up @@ -182,14 +189,19 @@ def parse_args():
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
try:
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
except jpype.JException as ex:
print("Java Exception during envs.step():")
print(ex.stacktrace())
raise
next_masks = torch.from_numpy(info["action_mask"]).to(device)
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
if args.framestack > 1:
next_obs = next_obs.view(next_obs.shape[0], -1)

if "episode" in info: # todo not sure if it's faster than just iterationg over _episode
if "episode" in info: # todo not sure if it's faster than just iterating over _episode
for i in range(args.num_envs):
if info["_episode"][i]:
# print(f"global_step={global_step}, episodic_return={info['episode']['r'][i]}")
Expand Down
14 changes: 7 additions & 7 deletions jar_setup.py
Expand Up @@ -5,15 +5,15 @@
argParser.add_argument("-file-id", "--file-id", default="1uPNoZkdI4rJiFyNyXFVun_VcAlN3QIVQ", help="Google Drive file ID of the TAG jar to download")
args = argParser.parse_args()

filename = "pytag/jars/ModernBoardGame.zip"
file_id = args.file_id #"1uPNoZkdI4rJiFyNyXFVun_VcAlN3QIVQ"
filename = "pytag/jars/TAG.jar"
file_id = args.file_id #"1wIM2xPE5tqvVzO931t3xcVYWk7VCr6i8"
gdown.download(
f"https://drive.google.com/uc?export=download&confirm=pbef&id={file_id}",
filename
)

# unzip the files into
jar_path = "pytag/"
with zipfile.ZipFile(filename, 'r') as zip_file:
zip_file.extractall(jar_path)
os.remove(filename)
# unzip the files into (no longer needed since the jar file is already in the correct location)
#jar_path = "pytag/"
#with zipfile.ZipFile(filename, 'r') as zip_file:
# zip_file.extractall(jar_path)
#os.remove(filename)
4 changes: 3 additions & 1 deletion pytag/gym_wrapper/envs.py
Expand Up @@ -28,7 +28,9 @@ def get_action_tree_shape(self):
def sample_rnd_action(self):
return self._env.sample_rnd_action()

def reset(self):
def reset(self, seed=None, options=None):
if seed is not None:
self._env = PyTAG(agent_ids=self._env.agent_ids, game_id=self._env.game_id, seed=seed, obs_type=self._obs_type)
obs, info = self._env.reset()
return obs, info

Expand Down
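
The `reset` change above accepts the `seed` and `options` arguments of the Gymnasium reset convention and rebuilds the
underlying environment only when a seed is passed. The pattern can be sketched without a JVM as follows; `FakeBackend`
and `SeedableEnv` are hypothetical stand-ins for illustration, not PyTAG classes.

```python
class FakeBackend:
    """Stand-in for the Java-backed environment; it only remembers its settings."""

    def __init__(self, game_id, seed=0):
        self.game_id = game_id
        self.seed = seed

    def reset(self):
        # Return a dummy observation and an info dict exposing the seed in use.
        return [0.0], {"seed": self.seed}


class SeedableEnv:
    def __init__(self, game_id, seed=0):
        self._env = FakeBackend(game_id, seed)

    def reset(self, seed=None, options=None):
        # Rebuild the backend only when the caller asks for a new seed;
        # a plain reset() keeps the existing backend and its seed.
        if seed is not None:
            self._env = FakeBackend(self._env.game_id, seed)
        return self._env.reset()
```

One consequence of this design is that reseeding costs a full environment construction, while an unseeded reset stays
cheap.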
48 changes: 22 additions & 26 deletions pytag/pyTAG.py
Expand Up @@ -7,32 +7,28 @@
import numpy as np
from typing import List
def list_supported_games(as_json=False):
tag_jar = os.path.join(os.path.dirname(__file__), 'jars', 'ModernBoardGame.jar')
jpype.addClassPath(tag_jar)
tag_jar = os.path.join(os.path.dirname(__file__), 'jars', 'TAG.jar')
# 2. Hard Stop if the JAR is missing
if not os.path.exists(tag_jar):
raise FileNotFoundError(f"CRITICAL: Could not find JAR at {tag_jar}. "
f"Check your folder structure!")

# jpype.addClassPath(tag_jar)
if not jpype.isJVMStarted():
jpype.startJVM(convertStrings=False)
print("starting JVM")
print(f"Loading JAR from: {tag_jar}")
print(f"File exists: {os.path.exists(tag_jar)}")
jpype.startJVM(classpath=[tag_jar], convertStrings=False)
# jpype.startJVM("--add-opens=java.base/java.lang=ALL-UNNAMED", classpath=[tag_jar], convertStrings=False)
PyTAGEnv = jpype.JClass("core.PyTAG")
if as_json:
return json.loads(str(PyTAGEnv.getSupportedGamesJSON()))
return PyTAGEnv.getSupportedGames()

def get_agent_class(agent_name):
if agent_name == "random":
return jpype.JClass("players.simple.RandomPlayer")
if agent_name == "mcts":
return jpype.JClass("players.mcts.MCTSPlayer")
if agent_name == "osla":
return jpype.JClass("players.simple.OSLAPlayer")
if agent_name == "python":
return jpype.JClass("players.python.PythonAgent")
return None

def get_mcts_with_params(json_path):
PlayerFactory = jpype.JClass("players.PlayerFactory")
with open(os.path.expanduser(json_path)) as json_file:
json_string = json.load(json_file)
json_string = str(json_string).replace('\'', '\"') # JAVA only uses " for string
return jpype.JClass("players.mcts.MCTSPlayer")(PlayerFactory.fromJSONString(json_string))
def get_agent(data):
if data == "python":
return jpype.JClass("players.python.PythonAgent")()
player_factory = jpype.JClass("players.PlayerFactory")
return player_factory.createPlayer(data)

# create the game registry when PyTAG is loaded
_game_registry = list_supported_games(as_json=True)
Expand All @@ -42,6 +38,9 @@ class PyTAG():
Note that the java jar package is expected to be in the jars folder of the same directory as this file.
"""
def __init__(self, agent_ids: List[str], game_id: str="Diamant", seed: int=0, obs_type:str="vector"):
self.agent_ids = agent_ids
self.game_id = game_id
self.seed = seed
self._last_obs_vector = None
self._last_action_mask = None
self._rnd = random.Random(seed)
Expand All @@ -50,10 +49,10 @@ def __init__(self, agent_ids: List[str], game_id: str="Diamant", seed: int=0, ob
assert game_id in _game_registry, f"Game {game_id} not supported. Supported games are {_game_registry}"
assert _game_registry[game_id][obs_type] == True, f"Game {game_id} does not support observation type {obs_type}"
# start up the JVM
tag_jar = os.path.join(os.path.dirname(__file__), 'jars', 'ModernBoardGame.jar')
tag_jar = os.path.join(os.path.dirname(__file__), 'jars', 'TAG.jar')
jpype.addClassPath(tag_jar)
if not jpype.isJVMStarted():
jpype.startJVM(convertStrings=False)
jpype.startJVM("--add-opens=java.base/java.lang=ALL-UNNAMED", convertStrings=False)

# access to the java classes
PyTAGEnv = jpype.JClass("core.PyTAG")
Expand All @@ -63,10 +62,7 @@ def __init__(self, agent_ids: List[str], game_id: str="Diamant", seed: int=0, ob
# Initialize the java environment
gameType = GameType.valueOf(Utils.getArg([""], "game", game_id))

if agent_ids[0] == "mcts":
agents = [get_mcts_with_params(f"~/data/pyTAG/MCTS_for_{game_id}.json")() for agent_id in agent_ids]
else:
agents = [get_agent_class(agent_id)() for agent_id in agent_ids]
agents = [get_agent(agent_id) for agent_id in agent_ids]
self._playerID = agent_ids.index("python") # if multiple python agents this is the first one
self._java_env = PyTAGEnv(gameType, None, jpype.java.util.ArrayList(agents), seed, True)

Expand Down
17 changes: 2 additions & 15 deletions pytag/utils/common.py
@@ -1,12 +1,12 @@
# various helper functions
import gymnasium as gym
from gymnasium.wrappers.frame_stack import FrameStack
from gymnasium.wrappers import FrameStackObservation as FrameStack
import numpy as np
import torch

import jpype
import jpype.imports
from utils.wrappers import StrategoWrapper, SushiGoWrapper
from pytag.utils.wrappers import StrategoWrapper, SushiGoWrapper


def make_env(env_id, seed, opponent, n_players, framestack=1, obs_type="vector"):
Expand All @@ -25,19 +25,6 @@ def thunk():
env = FrameStack(env, framestack)
return env
return thunk
def get_agent_list():
return ["random", "mcts", "osla", "python"]

def get_agent_class(agent_name):
if agent_name == "random":
return jpype.JClass("players.simple.RandomPlayer")
if agent_name == "mcts":
return jpype.JClass("players.mcts.MCTSPlayer")
if agent_name == "osla":
return jpype.JClass("players.simple.OSLAPlayer")
if agent_name == "python":
return jpype.JClass("players.python.PythonAgent")
return None

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
Expand Down
14 changes: 7 additions & 7 deletions pytag/utils/wrappers.py
Expand Up @@ -8,10 +8,10 @@


import gymnasium as gym
from gymnasium.vector import VectorEnvWrapper
from gymnasium.vector import VectorWrapper


class MergeActionMaskWrapper(VectorEnvWrapper):
class MergeActionMaskWrapper(VectorWrapper):
def reset_wait(self, **kwargs):
obs, infos = self.env.reset_wait(**kwargs)
return obs, self._merge_action_masks(infos)
Expand Down Expand Up @@ -93,7 +93,7 @@ def process_json_obs(self, json_obs, normalise=True):
obs = np.concatenate([score, round, played_cards, cards_in_hand, opp_played_cards, opp_scores])
return obs

class RecordEpisodeStatistics(gym.Wrapper):
class RecordEpisodeStatistics(VectorWrapper):
# Based on RecordEpisodeStatistics from gymnasium, but it checks whether the player has won the game
"""This wrapper will keep track of cumulative rewards and episode lengths.

Expand All @@ -111,7 +111,7 @@ def __init__(self, env: gym.Env, deque_size: int = 100):
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
"""
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
# self.num_envs = getattr(env, "num_envs", 1) # This is already a property of VectorWrapper
self.episode_count = 0
self.episode_start_times: np.ndarray = None
self.episode_returns: Optional[np.ndarray] = None
Expand All @@ -120,7 +120,7 @@ def __init__(self, env: gym.Env, deque_size: int = 100):
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
self.win_queue = deque(maxlen=deque_size)
self.is_vector_env = getattr(env, "is_vector_env", False)
# self.is_vector_env = getattr(env, "is_vector_env", False) # VectorWrapper IS a vector env

def reset(self, **kwargs):
"""Resets the environment using kwargs and resets the episode returns and lengths."""
Expand Down Expand Up @@ -159,14 +159,14 @@ def step(self, action):
infos["episode"] = {
"r": np.where(dones, self.episode_returns, 0.0),
"l": np.where(dones, self.episode_lengths, 0),
"w": [0 if final_inf is None else final_inf["has_won"] for final_inf in infos["final_info"]],
"w": [0 if final_inf is None else final_inf.get("has_won", 0) for final_inf in infos.get("final_info", [None] * self.num_envs)],
"t": np.where(
dones,
np.round(time.perf_counter() - self.episode_start_times, 6),
0.0,
),
}
if self.is_vector_env:
if True: # it is a vector env
infos["_episode"] = np.where(dones, True, False)
self.return_queue.extend(self.episode_returns[dones])
self.length_queue.extend(self.episode_lengths[dones])
Expand Down
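
The `"w"` entry above now tolerates both a missing `final_info` key and final infos without a `has_won` field. A
minimal standalone sketch of that defaulting logic, using a hypothetical helper name `collect_wins` and plain dicts in
place of the real vector-env infos:

```python
def collect_wins(infos, num_envs):
    """Return one win flag per sub-environment, defaulting to 0 when data is absent."""
    # If no episode finished, "final_info" may be missing: treat every env as None.
    final_infos = infos.get("final_info", [None] * num_envs)
    # A finished episode without a "has_won" entry also counts as 0.
    return [0 if fi is None else fi.get("has_won", 0) for fi in final_infos]
```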