diff --git a/src/bin/ai-test.rs b/src/bin/ai-test.rs new file mode 100644 index 0000000..346f86a --- /dev/null +++ b/src/bin/ai-test.rs @@ -0,0 +1,50 @@ +use rusty_othello_ai::mcts::MCTS; +use rusty_othello_ai::othello::{caculate_win, Color, State}; +use std::isize; + +pub fn main() { + let args: Vec = std::env::args().collect(); + let mut win_balance: isize = 0; + let a: f32 = args + .get(1) + .expect("Missing value for A") + .parse() + .expect("Not a valid floatingpoint number"); + let b: f32 = args + .get(2) + .expect("Missing value for A") + .parse() + .expect("Not a valid floatingpoint number"); + + let mut state = State::new(); + let mut mcts = MCTS::new("true", a); + let mut mcts2 = MCTS::new("false", b); + let mut ai_iterations = 500; + loop { + state = ai_turn(&mut mcts, state.clone(), ai_iterations); + if state.remaining_moves == 0 { + break; + } + state = ai_turn(&mut mcts2, state.clone(), ai_iterations); + if state.remaining_moves == 0 { + break; + } + ai_iterations += ai_iterations / 100; + } + win_balance += match caculate_win(state) { + Some(Color::WHITE) => 1, + Some(Color::BLACK) => -1, + None => 0, + }; + println!("{win_balance}") +} + +fn ai_turn(mcts: &mut MCTS, state: State, iterations: usize) -> State { + let dev_null = |_a: usize, _b: usize, _c: &Color| -> () {}; + let action = mcts.search(state.clone(), iterations, dev_null); + if action.is_ok() { + state.clone().do_action(Some(action.unwrap().clone())) + } else { + state.clone().do_action(None) + } +} diff --git a/src/console_game.rs b/src/console_game.rs new file mode 100644 index 0000000..774bb07 --- /dev/null +++ b/src/console_game.rs @@ -0,0 +1,130 @@ +use std::io::Write; +use std::isize; +use std::process::exit; + +use crate::mcts::MCTS; +use crate::othello::{caculate_win, print_state, Action, Color, Position, State}; + +enum GameCommand { + SKIP, + QUIT, + INVALID, + MOVE(usize, usize), +} + +pub fn console_game() { + let mut win_balance: isize = 0; + let a = 1.0; + println!("Game mode: player vs AI\n"); + let mut state = State::new(); + let mut mcts = MCTS::new("true", a); + _ = std::io::stdout().flush(); + let mut ai_iterations = 20000; + loop { + print_state(state); + state = player_turn(state.clone()); + if state.remaining_moves == 0 { + break; + } + print_state(state); + state = ai_turn(&mut mcts, state.clone(), ai_iterations); + ai_iterations += ai_iterations / 100; + + if state.remaining_moves == 0 { + break; + } + } + //print_state(state); + win_balance += match caculate_win(state) { + Some(Color::WHITE) => { + println!("White wins!"); + 1 + } + Some(Color::BLACK) => { + println!("Black wins!"); + -1 + } + None => { + println!("Draw."); + 0 + } + }; + //println!("\nGAME OVER\n"); + println!("\nResult: {win_balance}") +} + +fn ai_turn(mcts: &mut MCTS, state: State, iterations: usize) -> State { + let dev_null = |_a: usize, _b: usize, _c: &Color| -> () { /*println!("Progress: {a}/{b}")*/ }; + let action = mcts.search(state.clone(), iterations, dev_null); + if action.is_ok() { + println!("{:?}", action.clone().unwrap().position); + state.clone().do_action(Some(action.unwrap().clone())) + } else { + state.clone().do_action(None) + } +} + +fn player_turn(state: State) -> State { + let mut player_choice; + loop { + print!("Enter coordinates for desired move: "); + let _ = std::io::stdout().flush(); + let cmd = read_command(); + match cmd { + GameCommand::QUIT => exit(0), + GameCommand::INVALID => { + println!("Please provide a valid command 'quit' 'skip' or 'x,y'") + } + GameCommand::SKIP => { + player_choice = None; + break; + } + GameCommand::MOVE(x_index, y_index) => { + player_choice = Some(Action { + color: Color::BLACK, + position: Position { + x: x_index, + y: y_index, + }, + }); + if state + .get_actions() + .contains(&player_choice.clone().unwrap()) + { + break; + } else { + println!("Invalid move."); + let pos: Vec<(usize, usize)> = state + .get_actions() + .iter() + .map(|a| (a.position.y, a.position.x)) + .collect(); + println!("Valid moves: {:?}", pos); + print_state(state); + } + } + } + } + state.clone().do_action(player_choice) +} + +fn read_command() -> GameCommand { + let mut buf = String::new(); + let _ = std::io::stdin().read_line(&mut buf); + match buf.to_lowercase().as_str().trim() { + "quit" => GameCommand::QUIT, + "skip" => GameCommand::SKIP, + line => { + let cmd: Vec<&str> = line.trim().split(",").clone().collect(); + match (cmd.get(0), cmd.get(1)) { + (Some(cmd_1), Some(cmd_2)) => { + match (cmd_1.parse::(), cmd_2.parse::()) { + (Ok(y_index), Ok(x_index)) => GameCommand::MOVE(x_index, y_index), + _ => GameCommand::INVALID, + } + } + _ => GameCommand::INVALID, + } + } + } +} diff --git a/src/main.rs b/src/main.rs index 9b82776..45f4f9e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,14 @@ -use rusty_othello_ai::mcts::MCTS; -use rusty_othello_ai::othello::{parse_state, Action, State}; -use std::thread::current; +use std::process::exit; use std::time::Duration; use std::usize; use std::{borrow::Borrow, thread::sleep}; use ureq::Response; +mod console_game; +mod mcts; +mod othello; +use console_game::console_game; +use mcts::MCTS; +use othello::{parse_state, Action, Color, State}; const SERVER_URL: &str = "http://localhost:8181"; @@ -27,6 +31,10 @@ fn main() { x if x == "1" => ai_color = "true".to_string(), x if x == "w" => ai_color = "true".to_string(), x if x == "white" => ai_color = "true".to_string(), + x if x == "console" => { + console_game(); + exit(0); + } _ => panic!("Please pass a proper argument to the AI"), } // Initialize the game state and the Monte Carlo Tree Search (MCTS) @@ -60,7 +68,7 @@ fn main() { } // If it's not the AI's turn, it performs a search using MCTS and waits Ok(false) => { - let dev_null = |_a: usize, _b: usize, _c: &i8| -> () {}; + let dev_null = |_a: usize, _b: usize, _c: &Color| -> () {}; _ = mcts.search(state, 1000, dev_null); //sleep(Duration::from_secs(1)); } @@ -143,7 +151,7 @@ fn send_move(player: &String, ai_move: Option) -> Result) -> Result "false", - _ => "true", + Color::BLACK => "false", + Color::WHITE => "true", }; let url = format!("{}/AIStatus/{}/{}/{}", SERVER_URL, current, total, color); _ = ureq::post(&url).call(); diff --git a/src/mcts.rs b/src/mcts.rs index 378888c..3de11d2 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -1,4 +1,5 @@ -use crate::othello::{simulate_game, Action, State}; +use crate::othello::{simulate_game, Action, Color, State}; +use rand::Rng; use std::collections::HashMap; #[derive(Debug, Clone)] @@ -21,7 +22,7 @@ impl Node { } } - pub fn update_node(&mut self, result: (i8, isize)) { + pub fn update_node(&mut self, result: (Color, isize)) { self.visits += 1; if result.0 == self.state.next_turn { self.score += result.1; @@ -39,7 +40,7 @@ impl Node { #[derive()] pub struct MCTS { pub size: usize, - color: i8, + color: Color, expl: f32, nodes: Vec, tree: Vec>, @@ -49,13 +50,11 @@ pub struct MCTS { impl MCTS { pub fn new(col: &str, explore: f32) -> Self { - let ai_color: i8; + let ai_color: Color; match col { - b if b == "false".to_string() => ai_color = 0, - _ => ai_color = 1, + b if b == "false".to_string() => ai_color = Color::BLACK, + _ => ai_color = Color::WHITE, }; - //let mut map = HashMap::new(); - //map.insert(node.state, 0 as usize); Self { tree: Vec::new(), color: ai_color, @@ -73,28 +72,19 @@ impl MCTS { &mut self, from: State, iterations: usize, - send_status: fn(usize, usize, &i8), + send_status: fn(usize, usize, &Color), ) -> Result { if let Some(root) = self.state_map.get(&from).cloned() { for i in 0..iterations { if i % 1000 == 0 { - //println!("Progress: {i}/{iterations}"); _ = send_status(i, iterations, &self.color); } - let node_index = self.select(root.clone()).clone(); - let node_index = self.expand(node_index.clone()).clone(); - for index in self - .tree - .get(node_index) - .expect("No child nodes to simulate") - .clone() - .iter() - { - let result: (i8, isize) = self.simulate(*index); - self.backpropagate(*index, result.clone()); - } + let selected_node = self.select(root); + let expanded_node = self.expand(selected_node); + let result: (Color, isize) = self.simulate(expanded_node); + self.backpropagate(expanded_node, result); } - Ok(self.get_best_choice(root)?) + return self.get_best_choice(root); } else { self.add_node(from.clone(), None, None); return self.search(from, iterations, send_status); @@ -116,76 +106,75 @@ impl MCTS { let mut max_ucb = std::f32::MIN; let mut max_index = 0 as usize; let mut node_index = root_index; + let mut depth = 0; loop { - if self - .tree - .get(node_index) - .expect("Empty child selection") - .len() - == 0 - { + // Failsafe to avoid tree becoming too deep + if depth > 100 { + return node_index; + } + let children = &self.tree[node_index]; + if children.is_empty() { + return node_index; + } + if !self.nodes[node_index].untried_actions.is_empty() { return node_index; - } else { - for index in self.tree.get(node_index).unwrap().iter() { - let node = self - .nodes - .get(*index) - .expect("selected child doesnt exist") - .clone(); - let node_ucb = node.calculate_ucb( - self.nodes.get(node_index).unwrap().visits as usize, - self.expl, - ); - if node_ucb > max_ucb { - max_ucb = node_ucb; - max_index = index.clone(); - } + } + let parent_visits = self.nodes[node_index].visits; + for &child_index in children { + let child = &self.nodes[child_index]; + let ucb = child.calculate_ucb(parent_visits, self.expl); + + if ucb > max_ucb { + max_ucb = ucb; + max_index = child_index; } - node_index = max_index; } + if max_index == node_index { + return node_index; + } + node_index = max_index; max_ucb = std::f32::MIN; - max_index = 0; + depth += 1; } } // Expands the given node in the MCTS by adding all its untried actions as new nodes fn expand(&mut self, node_index: usize) -> usize { - let mut node = self - .nodes - .get_mut(node_index) - .expect("No node to expand") - .clone(); - if node.untried_actions.len() == 0 { - self.add_node( - node.state.clone().do_action(None), - None, - Some(node_index.clone()), - ); - self.tree - .get_mut(node_index) - .expect("No node") - .push(self.size - 1); + // Get the node not a clone of it + let untried_actions = self.nodes[node_index].untried_actions.clone(); + + if untried_actions.is_empty() { + // No actions to try add skip node + let new_state = self.nodes[node_index].state.clone().do_action(None); + self.add_node(new_state, None, Some(node_index)); + self.tree[node_index].push(self.size - 1); + + // Return the new node's index + return self.size - 1; } else { - for (_i, action) in node.untried_actions.iter().enumerate() { - self.add_node( - node.state.clone().do_action(Some(action.clone())), - Some(action.clone()), - Some(node_index.clone()), - ); - self.tree - .get_mut(node_index) - .expect("No node") - .push(self.size - 1); - } - while node.untried_actions.len() > 0 { - node.untried_actions.pop(); - } + // Pick one random action to expand (not all at once) + let mut rng = rand::thread_rng(); + let action_index = rng.gen_range(0..untried_actions.len()); + let action = untried_actions[action_index].clone(); + + // Remove this action from untried_actions in the original node + self.nodes[node_index].untried_actions.remove(action_index); + + // Create a new node with this action + let new_state = self.nodes[node_index] + .state + .clone() + .do_action(Some(action.clone())); + self.add_node(new_state, Some(action), Some(node_index)); + self.tree[node_index].push(self.size - 1); + + // Return the new node's index + return self.size - 1; } - node_index } // Simulates a game from the given node and returns the result - fn simulate(&mut self, node_index: usize) -> (i8, isize) { + fn simulate(&mut self, node_index: usize) -> (Color, isize) { if let Some(node) = self.nodes.get_mut(node_index) { let mut node_state = node.state.clone(); let mut score = simulate_game(&mut node_state); @@ -195,11 +184,11 @@ impl MCTS { node.update_node((node.state.next_turn, score)); return (node_state.next_turn, score); } - (-1, 0) + panic!("Node not found"); } // Updates the nodes in the MCTS from the given child node to the root based on the result of a simulated game - fn backpropagate(&mut self, child_index: usize, result: (i8, isize)) { + fn backpropagate(&mut self, child_index: usize, result: (Color, isize)) { let mut current_node: &mut Node; let mut parent_index: Option = self.parents.get(child_index).unwrap().clone(); while parent_index.is_some() { diff --git a/src/othello.rs b/src/othello.rs index 955fd95..c442a2d 100644 --- a/src/othello.rs +++ b/src/othello.rs @@ -1,202 +1,479 @@ -use std::{i16, isize}; - use rand::Rng; +use std::{fmt, isize, u16, usize}; const BOARD_SIZE: usize = 8; +const FIELD_SIZE: usize = 2; + +#[derive(Debug, Clone)] +struct EmptyFieldError; +impl fmt::Display for EmptyFieldError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Empty Fields can't be flipped") + } +} +#[derive(Debug, Clone)] +struct OccupiedFieldError; +impl fmt::Display for OccupiedFieldError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Occupied Fields can't be Set") + } +} #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct State { - pub board: [[i8; BOARD_SIZE]; BOARD_SIZE], - pub next_turn: i8, - pub remaining_moves: i16, +pub enum Color { + BLACK, + WHITE, } -impl State { - pub fn new() -> Self { - let mut new = Self { - board: [[-1; BOARD_SIZE]; BOARD_SIZE], - next_turn: 1, - remaining_moves: 60, - }; - new.board[3][3] = 0; - new.board[3][4] = 1; - new.board[4][4] = 0; - new.board[4][3] = 1; - new +impl Color { + fn bitmask(&self) -> u16 { + match *self { + Color::BLACK => 0b010, + Color::WHITE => 0b001, + } } - pub fn get_actions(&self) -> Vec { - let mut actions: Vec = Vec::new(); - let mut tmp_action = Action::new(self.next_turn, 0, 0); - for (x, row) in self.board.iter().enumerate() { - for (y, ch) in row.iter().enumerate() { - tmp_action.x = x; - tmp_action.y = y; - if *ch == -1 { - for dir in vec![ - (0, 1), - (1, 0), - (1, 1), - (0, -1), - (-1, 0), - (-1, -1), - (1, -1), - (-1, 1), - ] { - let mut tmp_state = self.clone(); - if tmp_state.flip_pieces(tmp_action.clone(), dir.0, dir.1) { - actions.push(tmp_action.clone()); - break; +} +#[derive(Debug, Clone, Copy)] +pub enum Direction { + Left, + Right, + Up, + Down, + UpLeft, + UpRight, + DownLeft, + DownRight, +} +impl Direction { + const VALUES: [Self; 8] = [ + Self::Left, + Self::Right, + Self::Up, + Self::Down, + Self::UpLeft, + Self::UpRight, + Self::DownLeft, + Self::DownRight, + ]; +} +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct Position { + pub x: usize, + pub y: usize, +} +impl Position { + fn new(x_coordinate: usize, y_coordinate: usize) -> Option { + match (x_coordinate, y_coordinate) { + (x, y) if x >= BOARD_SIZE || y >= BOARD_SIZE => None, + (_, _) => Some(Self { + x: x_coordinate, + y: y_coordinate, + }), + } + } + fn shift(self, dir: Direction) -> Option { + let x = self.x; + let y = self.y; + match dir { + Direction::Up => match y { + 0 => None, + _ => Position::new(x, y - 1), + }, + Direction::Down => match y + 1 { + BOARD_SIZE => None, + _ => Position::new(x, y + 1), + }, + Direction::Left => match x { + 0 => None, + _ => Position::new(x - 1, y), + }, + Direction::Right => match x + 1 { + BOARD_SIZE => None, + _ => Position::new(x + 1, y), + }, + Direction::UpLeft => match (x, y) { + (0, _) => None, + (_, 0) => None, + (_, _) => Position::new(x - 1, y - 1), + }, + Direction::UpRight => match (x + 1, y) { + (BOARD_SIZE, _) => None, + (_, 0) => None, + (_, _) => Position::new(x + 1, y - 1), + }, + Direction::DownLeft => match (x, y + 1) { + (0, _) => None, + (_, BOARD_SIZE) => None, + (_, _) => Position::new(x - 1, y + 1), + }, + Direction::DownRight => match (x, y) { + (BOARD_SIZE, _) => None, + (_, BOARD_SIZE) => None, + (_, _) => Position::new(x + 1, y + 1), + }, + } + } +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +struct Row { + value: u16, +} +impl Row { + fn new(val: u16) -> Row { + Self { value: val } + } + fn get_pos(&self, pos: usize) -> Option { + let mask = 0b011 << (pos * FIELD_SIZE); + let field = (self.value & mask) >> (pos * FIELD_SIZE); + match field { + w if w == Color::WHITE.bitmask() => Some(Color::WHITE), + b if b == Color::BLACK.bitmask() => Some(Color::BLACK), + _ => None, + } + } + fn set_pos(&self, color: Color, pos: usize) -> Result { + let color_mask = color.bitmask() << (pos * FIELD_SIZE); + let check_mask = 0b011 << (pos * FIELD_SIZE); + match self.value & check_mask { + 0 => Ok(Row { + value: self.value ^ color_mask, + }), + _ => Err(OccupiedFieldError), + } + } + fn flip_pos(&self, pos: usize) -> Result { + let flip_mask = 0b011 << (pos * FIELD_SIZE); + match self.value & flip_mask { + 0 => Err(EmptyFieldError), + _ => Ok(Row { + value: self.value ^ flip_mask, + }), + } + } + fn count_colors(&self) -> (isize, isize) { + let mut w_score = 0; + let mut b_score = 0; + let mut row = self.value.clone(); + for _ in 0..BOARD_SIZE { + if row & Color::WHITE.bitmask() > 0 { + w_score += 1; + } + if row & Color::BLACK.bitmask() > 0 { + b_score += 1; + } + row = row >> FIELD_SIZE; + } + return (w_score, b_score); + } +} +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +struct Board { + rows: [Row; BOARD_SIZE as usize], +} +impl Board { + fn new() -> Board { + let mut new_rows = [Row::new(0); BOARD_SIZE as usize]; + let center = (BOARD_SIZE / 2) - 1; + new_rows[center as usize] = Row::new(0b1001 << (center * FIELD_SIZE)); + new_rows[(center + 1) as usize] = Row::new(0b0110 << (center * FIELD_SIZE)); + Self { rows: new_rows } + } + fn blank() -> Board { + let new_rows = [Row::new(0); BOARD_SIZE as usize]; + Self { rows: new_rows } + } + fn flip_pieces(&self, action: Action, position: Position, dir: Direction) -> Option { + let mut to_flip = Vec::new(); + let mut current_pos = position; + + // Move in the specified direction, collecting opponent pieces + while let Some(next_pos) = current_pos.shift(dir) { + match self.rows[next_pos.y].get_pos(next_pos.x) { + Some(color) if color != action.color => { + // Found an opponent's piece add it to list + to_flip.push(next_pos); + current_pos = next_pos; + } + Some(color) if color == action.color => { + // Found own piece flip all the pieces collected + if !to_flip.is_empty() { + // Create new board with the flipped pieces + let mut new_board = self.clone(); + + // Flip all pieces in between + for pos in to_flip { + new_board.rows[pos.y] = new_board.rows[pos.y] + .flip_pos(pos.x) + .expect("Should be able to flip occupied positions"); } + + return Some(new_board); } + return None; + } + _ => { + // Empty space or board edge, can't flip in this direction + return None; } } } - - return actions; + None } - - pub fn do_action(&mut self, action: Option) -> State { - let next_turn = match self.next_turn { - 0 => 1, - 1 => 0, - _ => -1, - }; - - let mut new_state = State { - next_turn: next_turn.clone(), - board: self.board.clone(), - remaining_moves: (self.remaining_moves.clone() - 1), + fn get_empty_positions(&self) -> Vec { + let mut positions = Vec::new(); + for (y, row) in self.into_iter().enumerate() { + for x in 0..BOARD_SIZE { + match row.get_pos(x) { + None => { + positions.push(Position::new(x, y).expect( + "Iterating through board shouldn't be able to get out of bounds", + )) + } + Some(_) => (), + } + } + } + return positions; + } + fn would_flip_pieces(&self, action: Action, position: Position, dir: Direction) -> bool { + match position.shift(dir) { + Some(pos_1) => match self.rows[pos_1.y].get_pos(pos_1.x) { + Some(color) if color != action.color => { + // Found an opponent's piece in this direction + let mut current_pos = pos_1; + while let Some(next_pos) = current_pos.shift(dir) { + match self.rows[next_pos.y].get_pos(next_pos.x) { + Some(color) if color == action.color => { + // Found our own piece on the other side + return true; + } + Some(_) => { + // Another opponent piece keep checking + current_pos = next_pos; + } + None => { + // Empty space can't flip + return false; + } + } + } + false // Reached edge of board without finding own piece + } + _ => false, // Either empty or same color + }, + None => false, // Can't go in this direction + } + } +} +impl IntoIterator for Board { + type Item = Row; + type IntoIter = BoardIntoIterator; + fn into_iter(self) -> Self::IntoIter { + BoardIntoIterator { + board: self.clone(), + index: 0, + } + } +} +struct BoardIntoIterator { + board: Board, + index: usize, +} +impl Iterator for BoardIntoIterator { + type Item = Row; + fn next(&mut self) -> Option { + let result = match self.index { + x if x < BOARD_SIZE as usize => self.board.rows[x], + _ => return None, }; + self.index += 1; + Some(result) + } +} - if action.is_some() { - let act = action.unwrap(); - new_state.board[act.x][act.y] = act.color.clone(); - for dir in vec![ - (0, 1), - (1, 0), - (1, 1), - (0, -1), - (-1, 0), - (-1, -1), - (1, -1), - (-1, 1), - ] { - new_state.flip_pieces(act.clone(), dir.0, dir.1); +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct State { + board: Board, + pub next_turn: Color, + pub remaining_moves: u8, + pub prev_player_skipped: bool, +} +impl State { + pub fn new() -> Self { + Self { + board: Board::new(), + next_turn: Color::BLACK, + remaining_moves: 121, + prev_player_skipped: false, + } + } + pub fn get_actions(&self) -> Vec { + let empty_spots = self.board.get_empty_positions(); + let mut actions = Vec::new(); + if empty_spots.len() == 0 { + return actions; + } + for pos in empty_spots { + let action = Action::new(self.next_turn.clone(), pos); + if self.is_valid_action(action.clone()) { + actions.push(action); } } - return new_state; + return actions; } - - fn flip_pieces(&mut self, action: Action, x1: isize, y1: isize) -> bool { - let mut to_flip = Vec::new(); - let mut x_index = (action.x as isize + x1) as usize; - let mut y_index = (action.y as isize + y1) as usize; - let own_color = action.color.clone(); - let opponent = match action.color { - 0 => 1, - _ => 0, - }; - loop { - //Bounds Check - if x_index > BOARD_SIZE - 1 || y_index > BOARD_SIZE - 1 { - return false; + fn is_valid_action(&self, action: Action) -> bool { + for dir in Direction::VALUES { + if self + .board + .would_flip_pieces(action.clone(), action.position.clone(), dir) + { + return true; } - match self.board[x_index][y_index] { - x if x == own_color => break, - k if k == opponent => { - to_flip.push((x_index.clone(), y_index.clone())); - x_index = (x_index as isize + x1) as usize; - y_index = (y_index as isize + y1) as usize; + } + false + } + + pub fn do_action(&self, action: Option) -> State { + let mut new_state = self.clone(); + match action { + Some(act) => { + if new_state.flip_directions(act) { + new_state.remaining_moves -= 1; + new_state.prev_player_skipped = false; + } else { + new_state.prev_player_skipped = true; } - _ => return false, + } + None => { + new_state.prev_player_skipped = true; } } - if to_flip.len() == 0 { - return false; + // If both players had to skip end the game + if new_state.prev_player_skipped && self.prev_player_skipped { + new_state.remaining_moves = 0; + } + new_state.next_turn = match self.next_turn { + Color::BLACK => Color::WHITE, + Color::WHITE => Color::BLACK, + }; + new_state + } + fn flip_directions(&mut self, action: Action) -> bool { + let mut any_flipped = false; + let mut new_board = self.board.clone(); + + // Set the piece at the action position + if let Ok(row) = new_board.rows[action.position.y].set_pos(action.color, action.position.x) + { + new_board.rows[action.position.y] = row; } else { - for (x, y) in to_flip.iter() { - self.board[x.clone()][y.clone()] = action.color; + return false; + } + // Check each direction for pieces to flip + for dir in Direction::VALUES { + if let Some(updated_board) = new_board.flip_pieces(action.clone(), action.position, dir) + { + new_board = updated_board; + any_flipped = true; } - true } + if any_flipped { + self.board = new_board; + self.remaining_moves -= 1; + } + any_flipped } } #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct Action { - pub color: i8, - pub x: usize, - pub y: usize, + pub color: Color, + pub position: Position, } impl Action { - pub fn new(player: i8, x1: usize, y1: usize) -> Self { + pub fn new(player: Color, pos: Position) -> Self { Self { color: player, - x: x1, - y: y1, + position: pos, } } } #[inline] -pub fn simulate_game(state: &mut State) -> isize { +pub fn simulate_game(state: &State) -> isize { let mut test_state = state.clone(); - let mut test_actions = test_state.get_actions(); - let mut do_act: Option; - while test_state.remaining_moves > 0 { - if test_actions.len() < 1 { - do_act = None; + let mut consecutive_skips = 0; + + // Maximum number of moves to prevent infinite loops + let max_iterations = 100; + let mut iterations = 0; + + while test_state.remaining_moves > 0 && consecutive_skips < 2 && iterations < max_iterations { + iterations += 1; + + let test_actions = test_state.get_actions(); + let current_action; + + if test_actions.is_empty() { + current_action = None; + consecutive_skips += 1; } else { let mut rng = rand::thread_rng(); let index = rng.gen_range(0..test_actions.len()); - do_act = test_actions.get(index).cloned(); + current_action = Some(test_actions[index].clone()); + consecutive_skips = 0; + } + + test_state = test_state.do_action(current_action); + + // If both players had to skip end the game + if consecutive_skips >= 2 { + break; } - test_state = test_state.do_action(do_act); - test_actions = test_state.get_actions(); } - caculate_win(state.next_turn, test_state) + match caculate_win(test_state) { + Some(Color::WHITE) => 1, + Some(Color::BLACK) => -1, + None => 0, + } } -fn caculate_win(player: i8, state: State) -> isize { - let p1 = player; - let p2 = match p1 { - 1 => 0, - _ => 1, - }; - let mut p1_score: isize = 0; - let mut p2_score: isize = 0; - for row in state.board { - for ch in row { - if ch == p1 { - p1_score += 1; - } else if ch == p2 { - p2_score += 1; - } - } +pub fn caculate_win(state: State) -> Option { + let mut w_score: isize = 0; + let mut b_score: isize = 0; + for row in state.board.rows { + let (w, b) = row.count_colors(); + w_score += w; + b_score += b; } - match p1_score - p2_score { - x if x > 0 => 1, - x if x < 0 => -1, - _ => 0, + match w_score - b_score { + x if x > 0 => Some(Color::WHITE), + x if x < 0 => Some(Color::BLACK), + _ => None, } } pub fn parse_state(json: serde_json::Value) -> State { - let mut new_board = [[-1; BOARD_SIZE]; BOARD_SIZE]; - let mut moves_left: i16 = 0; + //todo!("Fix parse_state") + let mut new_board = Board::blank(); + let mut moves_left: u8 = 0; let next = match json["turn"] { - serde_json::Value::Bool(true) => 1, - _ => 0, + serde_json::Value::Bool(true) => Color::BLACK, + _ => Color::WHITE, }; if let Some(board) = json["board"].as_array() { for (x, row) in board.iter().enumerate() { if let Some(row) = row.as_array() { for (y, cell) in row.iter().enumerate() { match cell.as_i64() { - Some(1) => new_board[x][y] = 1, - Some(0) => new_board[x][y] = 0, + Some(1) => { + new_board.rows[y] = new_board.rows[y].set_pos(Color::WHITE, x).unwrap() + } + Some(0) => { + new_board.rows[y] = new_board.rows[y].set_pos(Color::BLACK, x).unwrap() + } Some(-1) => { - new_board[x][y] = -1; moves_left += 1; } _ => {} @@ -209,12 +486,61 @@ pub fn parse_state(json: serde_json::Value) -> State { board: new_board, next_turn: next, remaining_moves: moves_left, + prev_player_skipped: false, } } pub fn print_state(state: State) { - for i in state.board { - println!("{:?}", i); + println!(" 0 1 2 3 4 5 6 7"); + let black_comp = Color::BLACK.bitmask(); // << ((BOARD_SIZE - 1) * FIELD_SIZE); + let white_comp = Color::WHITE.bitmask(); // << ((BOARD_SIZE - 1) * FIELD_SIZE); + for (i, row) in state.board.into_iter().enumerate() { + print!("{i} "); + for f in 0..BOARD_SIZE { + let c = { + if row.value & (black_comp << (f * FIELD_SIZE)) != 0 { + 'B' + } else if row.value & (white_comp << (f * FIELD_SIZE)) != 0 { + 'W' + } else { + '_' + } + }; + print!("|{}", c); + } + print!("|\n"); + } + let next = match state.next_turn { + Color::BLACK => "Black", + Color::WHITE => "White", + }; + println!("Next: {}", next) +} + +#[cfg(test)] +mod othello_tests { + use super::*; + + #[test] + fn test_board_empty_spaces() { + let board = Board::new(); + assert_eq!(board.get_empty_positions().len(), 60); + } + #[test] + fn test_row_get_pos() { + let board = Board::new(); + assert_eq!(board.rows[3].get_pos(3), Some(Color::WHITE)); + assert_eq!(board.rows[3].get_pos(4), Some(Color::BLACK)); + assert_eq!(board.rows[4].get_pos(4), Some(Color::WHITE)); + assert_eq!(board.rows[4].get_pos(3), Some(Color::BLACK)); + assert_eq!(board.rows[1].get_pos(3), None); + assert_eq!(board.rows[2].get_pos(2), None); + assert_eq!(board.rows[2].get_pos(4), None); + } + #[test] + fn test_row_set_pos() { + let board = Board::new(); + assert!(board.rows[3].set_pos(Color::BLACK, 4).is_err()); + assert!(board.rows[3].set_pos(Color::WHITE, 3).is_err()); } - println!("next: {}", state.next_turn) }