diff --git a/Cargo.toml b/Cargo.toml index 00966da..c7bfbd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,13 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ureq = {version = "2.9.6", features = ["json"] } +ureq = { version = "2.9.6", features = ["json"] } serde_json = "1.0.114" rand = "0.8.5" + +[dev-dependencies] +criterion = "0.3.4" + +[[bench]] +name = "othello" +harness = false diff --git a/benches/othello.rs b/benches/othello.rs new file mode 100644 index 0000000..dac330c --- /dev/null +++ b/benches/othello.rs @@ -0,0 +1,34 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rusty_othello_ai::{ + mcts::MCTS, + othello::{simulate_game, State}, +}; +use std::time::Duration; + +pub fn bench_simulate_game(c: &mut Criterion) { + let mut group = c.benchmark_group("simulate_game"); + group + .sample_size(1000) + .measurement_time(Duration::from_secs(10)); + let game_state = rusty_othello_ai::othello::State::new(); + group.bench_function("simulate game 1", |b| { + b.iter(|| simulate_game(black_box(&mut game_state.clone()))) + }); + + group.finish() +} +pub fn bench_mcts_search(c: &mut Criterion) { + let mut group = c.benchmark_group("mcts_search"); + group + .sample_size(1000) + .measurement_time(Duration::from_secs(10)); + let mut mcts = MCTS::new("true", 1.0); + group.bench_function("Monte Carlo Tree Search", |b| { + b.iter(|| mcts.search(State::new(), 10, |_, _, _| {})) + }); + + group.finish() +} + +criterion_group!(game, bench_simulate_game, bench_mcts_search); +criterion_main!(game); diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..bd21b18 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2 @@ +pub mod mcts; +pub mod othello; diff --git a/src/main.rs b/src/main.rs index 4ae081e..9b82776 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,10 @@ -use ureq::Response; +use rusty_othello_ai::mcts::MCTS; +use rusty_othello_ai::othello::{parse_state, Action, State}; use std::thread::current; -use std::usize; -use std::{thread::sleep, borrow::Borrow}; use std::time::Duration; -mod mcts; -mod othello; -use mcts::{MCTS, Node}; -use othello::{State, Action, parse_state}; - - +use std::usize; +use std::{borrow::Borrow, thread::sleep}; +use ureq::Response; const SERVER_URL: &str = "http://localhost:8181"; @@ -18,7 +14,11 @@ fn main() { // If the argument is not recognized, the program will panic let args: Vec = std::env::args().collect(); let ai_color; - match args.get(1).expect("Please specify color to the AI").to_lowercase() { + match args + .get(1) + .expect("Please specify color to the AI") + .to_lowercase() + { x if x == "false" => ai_color = x, x if x == "0" => ai_color = "false".to_string(), x if x == "b" => ai_color = "false".to_string(), @@ -28,7 +28,6 @@ fn main() { x if x == "w" => ai_color = "true".to_string(), x if x == "white" => ai_color = "true".to_string(), _ => panic!("Please pass a proper argument to the AI"), - } // Initialize the game state and the Monte Carlo Tree Search (MCTS) // The MCTS is initialized with a new node that represents the current game state @@ -41,12 +40,12 @@ fn main() { loop { // The AI checks if it's its turn, if so, it gets the current game state and performs a search using MCTS match is_my_turn(ai_color.borrow()) { - Ok(true) => { + Ok(true) => { state = get_game_state(); choice = mcts.search(state, ai_iterations, send_progress); // Gives the ai 2% more iterations every round to balance the game simulations // being shorter - ai_iterations += ai_iterations / 50; + ai_iterations += ai_iterations / 50; // If a valid action is found, it sends the move to the server and updates the game state if choice.is_ok() { @@ -56,16 +55,15 @@ fn main() { // If no valid action is found, it sends a pass move to the server and updates the game state else { let _ = send_move(&ai_color, None); - state.do_action(None); + state.do_action(None); } - - }, + } // If it's not the AI's turn, it performs a search using MCTS and waits Ok(false) => { - let dev_null = |_a: usize, _b: usize, _c: &i8| -> (){}; + let dev_null = |_a: usize, _b: usize, _c: &i8| -> () {}; _ = mcts.search(state, 1000, dev_null); //sleep(Duration::from_secs(1)); - }, + } Err(e) => { eprintln!("Error checking turn: {}", e); sleep(Duration::from_secs(1)); @@ -79,7 +77,7 @@ fn is_my_turn(ai: &String) -> Result> { let mut delay = Duration::from_secs(1); let opponent = match ai { x if x == "true" => "false", - _ => "true" + _ => "true", }; loop { let url = format!("{}/turn", SERVER_URL); @@ -94,10 +92,13 @@ fn is_my_turn(ai: &String) -> Result> { // If the response is anything else, the function returns an error _ => return Err("Unexpected response from server".into()), } - }, + } Err(e) => { // Error occurred, possibly a network issue or server error, wait before trying again - eprintln!("Error checking turn: {}, will retry after {:?} seconds", e, delay); + eprintln!( + "Error checking turn: {}, will retry after {:?} seconds", + e, delay + ); sleep(delay); delay = std::cmp::min(delay.saturating_mul(2), Duration::from_secs(10)); } @@ -112,12 +113,14 @@ fn get_game_state() -> State { let mut delay = Duration::from_secs(3); loop { match get_json() { - Ok(resp) => return parse_state(resp.into_json().expect("Error parsing response to json")), + Ok(resp) => { + return parse_state(resp.into_json().expect("Error parsing response to json")) + } Err(_e) => { sleep(delay); delay *= 2; delay = std::cmp::min(Duration::from_millis(10000), delay); - }, + } } } } @@ -138,7 +141,10 @@ fn send_move(player: &String, ai_move: Option) -> Result) -> Result "false", _ => "true", }; let url = format!("{}/AIStatus/{}/{}/{}", SERVER_URL, current, total, color); _ = ureq::post(&url).call(); - } diff --git a/src/mcts.rs b/src/mcts.rs index ab62006..378888c 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -1,8 +1,6 @@ - -use crate::othello::{State, Action, simulate_game}; +use crate::othello::{simulate_game, Action, State}; use std::collections::HashMap; - #[derive(Debug, Clone)] pub struct Node { state: State, @@ -13,7 +11,7 @@ pub struct Node { } impl Node { - pub fn new (state: State, action: Option, untried_actions: Vec) -> Node { + pub fn new(state: State, action: Option, untried_actions: Vec) -> Node { Node { state, action, @@ -23,7 +21,7 @@ impl Node { } } - pub fn update_node(&mut self, result: (i8, isize)) { + pub fn update_node(&mut self, result: (i8, isize)) { self.visits += 1; if result.0 == self.state.next_turn { self.score += result.1; @@ -33,9 +31,9 @@ impl Node { } // Calculates and returns the Upper Confidence Bound (UCB) for the Node fn calculate_ucb(&self, total_count: usize, explore: f32) -> f32 { - (self.score as f32 / self.visits as f32) + explore * (2.0 * (total_count as f32).ln() / self.visits as f32).sqrt() + (self.score as f32 / self.visits as f32) + + explore * (2.0 * (total_count as f32).ln() / self.visits as f32).sqrt() } - } #[derive()] @@ -71,31 +69,40 @@ impl MCTS { // Performs a Monte Carlo Tree Search from the given state for the given number of iterations // It returns the best action found or an error if no action was found - pub fn search(&mut self, from: State, iterations: usize, send_status: fn(usize, usize, &i8)) -> Result { + pub fn search( + &mut self, + from: State, + iterations: usize, + send_status: fn(usize, usize, &i8), + ) -> Result { if let Some(root) = self.state_map.get(&from).cloned() { for i in 0..iterations { if i % 1000 == 0 { //println!("Progress: {i}/{iterations}"); - _ = send_status(i, iterations, &self.color); + _ = send_status(i, iterations, &self.color); } let node_index = self.select(root.clone()).clone(); let node_index = self.expand(node_index.clone()).clone(); - for index in self.tree.get(node_index).expect("No child nodes to simulate").clone().iter() { + for index in self + .tree + .get(node_index) + .expect("No child nodes to simulate") + .clone() + .iter() + { let result: (i8, isize) = self.simulate(*index); self.backpropagate(*index, result.clone()); } - } Ok(self.get_best_choice(root)?) - } - else { + } else { self.add_node(from.clone(), None, None); - return self.search(from, iterations, send_status) + return self.search(from, iterations, send_status); } } // Adds a new node to the MCTS with the given state, action, and parent - fn add_node(&mut self, state: State, action: Option, parent: Option){ + fn add_node(&mut self, state: State, action: Option, parent: Option) { let new_node = Node::new(state, action, state.get_actions()); self.state_map.insert(state, self.size); self.tree.push(Vec::new()); @@ -110,20 +117,32 @@ impl MCTS { let mut max_index = 0 as usize; let mut node_index = root_index; loop { - if self.tree.get(node_index).expect("Empty child selection").len() == 0 { + if self + .tree + .get(node_index) + .expect("Empty child selection") + .len() + == 0 + { return node_index; - } - else { + } else { for index in self.tree.get(node_index).unwrap().iter() { - let node = self.nodes.get(*index).expect("selected child doesnt exist").clone(); - let node_ucb = node.calculate_ucb(self.nodes.get(node_index).unwrap().visits as usize, self.expl); + let node = self + .nodes + .get(*index) + .expect("selected child doesnt exist") + .clone(); + let node_ucb = node.calculate_ucb( + self.nodes.get(node_index).unwrap().visits as usize, + self.expl, + ); if node_ucb > max_ucb { max_ucb = node_ucb; max_index = index.clone(); } } node_index = max_index; - } + } max_ucb = std::f32::MIN; max_index = 0; } @@ -131,21 +150,32 @@ impl MCTS { // Expands the given node in the MCTS by adding all its untried actions as new nodes fn expand(&mut self, node_index: usize) -> usize { - let mut node = self.nodes.get_mut(node_index).expect("No node to expand").clone(); + let mut node = self + .nodes + .get_mut(node_index) + .expect("No node to expand") + .clone(); if node.untried_actions.len() == 0 { - self.add_node(node.state.clone().do_action(None), - None, - Some(node_index.clone()) + self.add_node( + node.state.clone().do_action(None), + None, + Some(node_index.clone()), ); - self.tree.get_mut(node_index).expect("No node").push(self.size - 1); + self.tree + .get_mut(node_index) + .expect("No node") + .push(self.size - 1); } else { for (_i, action) in node.untried_actions.iter().enumerate() { self.add_node( - node.state.clone().do_action(Some(action.clone())), - Some(action.clone()), - Some(node_index.clone()) + node.state.clone().do_action(Some(action.clone())), + Some(action.clone()), + Some(node_index.clone()), ); - self.tree.get_mut(node_index).expect("No node").push(self.size - 1); + self.tree + .get_mut(node_index) + .expect("No node") + .push(self.size - 1); } while node.untried_actions.len() > 0 { node.untried_actions.pop(); @@ -171,12 +201,18 @@ impl MCTS { // Updates the nodes in the MCTS from the given child node to the root based on the result of a simulated game fn backpropagate(&mut self, child_index: usize, result: (i8, isize)) { let mut current_node: &mut Node; - let mut parent_index: Option = self.parents.get(child_index).unwrap().clone(); + let mut parent_index: Option = self.parents.get(child_index).unwrap().clone(); while parent_index.is_some() { - current_node = self.nodes.get_mut(parent_index.unwrap()).expect("Parent doesn't exist"); + current_node = self + .nodes + .get_mut(parent_index.unwrap()) + .expect("Parent doesn't exist"); current_node.update_node(result); let tmp = parent_index.clone(); - parent_index = *self.parents.get(tmp.unwrap()).expect("Error fetching parent of parent"); + parent_index = *self + .parents + .get(tmp.unwrap()) + .expect("Error fetching parent of parent"); } } @@ -185,8 +221,17 @@ impl MCTS { fn get_best_choice(&self, from_index: usize) -> Result { let mut best_index = 0; let mut max_visits = 0; - for index in self.tree.get(from_index).expect("Empty list of children when getting best choice").iter().clone() { - let node = self.nodes.get(*index).expect("MCST, choice: node index doesnt exists"); + for index in self + .tree + .get(from_index) + .expect("Empty list of children when getting best choice") + .iter() + .clone() + { + let node = self + .nodes + .get(*index) + .expect("MCST, choice: node index doesnt exists"); if node.visits > max_visits { best_index = index.clone(); max_visits = node.visits; @@ -200,10 +245,8 @@ impl MCTS { let from_state = self.nodes.get(from_index).unwrap().clone().state; if from_state.next_turn != best_action.color { return Err(()); - } - else { + } else { Ok(best_action.clone()) } } } - diff --git a/src/othello.rs b/src/othello.rs index f443039..955fd95 100644 --- a/src/othello.rs +++ b/src/othello.rs @@ -1,8 +1,7 @@ -use std::{isize, i16}; +use std::{i16, isize}; use rand::Rng; - const BOARD_SIZE: usize = 8; #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -12,10 +11,9 @@ pub struct State { pub remaining_moves: i16, } impl State { - pub fn new() -> Self{ + pub fn new() -> Self { let mut new = Self { - board: [ - [-1; BOARD_SIZE]; BOARD_SIZE], + board: [[-1; BOARD_SIZE]; BOARD_SIZE], next_turn: 1, remaining_moves: 60, }; @@ -24,28 +22,35 @@ impl State { new.board[4][4] = 0; new.board[4][3] = 1; new - } pub fn get_actions(&self) -> Vec { let mut actions: Vec = Vec::new(); let mut tmp_action = Action::new(self.next_turn, 0, 0); - for (x, row) in self.board.iter().enumerate(){ - for (y, ch) in row.iter().enumerate(){ + for (x, row) in self.board.iter().enumerate() { + for (y, ch) in row.iter().enumerate() { tmp_action.x = x; tmp_action.y = y; if *ch == -1 { - for dir in vec![(0,1), (1,0), (1,1), (0,-1), (-1,0), (-1,-1), (1,-1), (-1,1)] { + for dir in vec![ + (0, 1), + (1, 0), + (1, 1), + (0, -1), + (-1, 0), + (-1, -1), + (1, -1), + (-1, 1), + ] { let mut tmp_state = self.clone(); - if tmp_state.flip_pieces(tmp_action.clone(), dir.0, dir.1){ + if tmp_state.flip_pieces(tmp_action.clone(), dir.0, dir.1) { actions.push(tmp_action.clone()); - break + break; } - } + } } } } - return actions; } @@ -65,7 +70,16 @@ impl State { if action.is_some() { let act = action.unwrap(); new_state.board[act.x][act.y] = act.color.clone(); - for dir in vec![(0,1), (1,0), (1,1), (0,-1), (-1,0), (-1,-1), (1,-1), (-1,1)] { + for dir in vec![ + (0, 1), + (1, 0), + (1, 1), + (0, -1), + (-1, 0), + (-1, -1), + (1, -1), + (-1, 1), + ] { new_state.flip_pieces(act.clone(), dir.0, dir.1); } } @@ -81,9 +95,9 @@ impl State { 0 => 1, _ => 0, }; - loop{ + loop { //Bounds Check - if x_index > BOARD_SIZE - 1 || y_index > BOARD_SIZE - 1 { + if x_index > BOARD_SIZE - 1 || y_index > BOARD_SIZE - 1 { return false; } match self.board[x_index][y_index] { @@ -92,15 +106,14 @@ impl State { to_flip.push((x_index.clone(), y_index.clone())); x_index = (x_index as isize + x1) as usize; y_index = (y_index as isize + y1) as usize; - }, + } _ => return false, } } if to_flip.len() == 0 { return false; - } - else { - for (x,y) in to_flip.iter() { + } else { + for (x, y) in to_flip.iter() { self.board[x.clone()][y.clone()] = action.color; } true @@ -108,7 +121,6 @@ impl State { } } - #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct Action { pub color: i8, @@ -126,7 +138,7 @@ impl Action { } } - +#[inline] pub fn simulate_game(state: &mut State) -> isize { let mut test_state = state.clone(); let mut test_actions = test_state.get_actions(); @@ -134,8 +146,7 @@ pub fn simulate_game(state: &mut State) -> isize { while test_state.remaining_moves > 0 { if test_actions.len() < 1 { do_act = None; - } - else { + } else { let mut rng = rand::thread_rng(); let index = rng.gen_range(0..test_actions.len()); do_act = test_actions.get(index).cloned(); @@ -158,10 +169,10 @@ fn caculate_win(player: i8, state: State) -> isize { for ch in row { if ch == p1 { p1_score += 1; - }else if ch == p2 { + } else if ch == p2 { p2_score += 1; } - } + } } match p1_score - p2_score { x if x > 0 => 1, @@ -171,31 +182,30 @@ fn caculate_win(player: i8, state: State) -> isize { } pub fn parse_state(json: serde_json::Value) -> State { - let mut new_board = [[-1;BOARD_SIZE]; BOARD_SIZE]; + let mut new_board = [[-1; BOARD_SIZE]; BOARD_SIZE]; let mut moves_left: i16 = 0; let next = match json["turn"] { serde_json::Value::Bool(true) => 1, _ => 0, - }; if let Some(board) = json["board"].as_array() { for (x, row) in board.iter().enumerate() { if let Some(row) = row.as_array() { for (y, cell) in row.iter().enumerate() { - match cell.as_i64() { + match cell.as_i64() { Some(1) => new_board[x][y] = 1, Some(0) => new_board[x][y] = 0, Some(-1) => { new_board[x][y] = -1; moves_left += 1; - }, - _ => {}, + } + _ => {} } } } } } - State{ + State { board: new_board, next_turn: next, remaining_moves: moves_left,