Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ureq = {version = "2.9.6", features = ["json"] }
ureq = { version = "2.9.6", features = ["json"] }
serde_json = "1.0.114"
rand = "0.8.5"

[dev-dependencies]
criterion = "0.3.4"

[[bench]]
name = "othello"
harness = false
34 changes: 34 additions & 0 deletions benches/othello.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rusty_othello_ai::{
mcts::MCTS,
othello::{simulate_game, State},
};
use std::time::Duration;

pub fn bench_simulate_game(c: &mut Criterion) {
let mut group = c.benchmark_group("simulate_game");
group
.sample_size(1000)
.measurement_time(Duration::from_secs(10));
let game_state = rusty_othello_ai::othello::State::new();
group.bench_function("simulate game 1", |b| {
b.iter(|| simulate_game(black_box(&mut game_state.clone())))
});

group.finish()
}
pub fn bench_mcts_search(c: &mut Criterion) {
let mut group = c.benchmark_group("mcts_search");
group
.sample_size(1000)
.measurement_time(Duration::from_secs(10));
let mut mcts = MCTS::new("true", 1.0);
group.bench_function("Monte Carlo Tree Search", |b| {
b.iter(|| mcts.search(State::new(), 10, |_, _, _| {}))
});

group.finish()
}

criterion_group!(game, bench_simulate_game, bench_mcts_search);
criterion_main!(game);
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod mcts;
pub mod othello;
57 changes: 31 additions & 26 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use ureq::Response;
use rusty_othello_ai::mcts::MCTS;
use rusty_othello_ai::othello::{parse_state, Action, State};
use std::thread::current;
use std::usize;
use std::{thread::sleep, borrow::Borrow};
use std::time::Duration;
mod mcts;
mod othello;
use mcts::{MCTS, Node};
use othello::{State, Action, parse_state};


use std::usize;
use std::{borrow::Borrow, thread::sleep};
use ureq::Response;

const SERVER_URL: &str = "http://localhost:8181";

Expand All @@ -18,7 +14,11 @@ fn main() {
// If the argument is not recognized, the program will panic
let args: Vec<String> = std::env::args().collect();
let ai_color;
match args.get(1).expect("Please specify color to the AI").to_lowercase() {
match args
.get(1)
.expect("Please specify color to the AI")
.to_lowercase()
{
x if x == "false" => ai_color = x,
x if x == "0" => ai_color = "false".to_string(),
x if x == "b" => ai_color = "false".to_string(),
Expand All @@ -28,7 +28,6 @@ fn main() {
x if x == "w" => ai_color = "true".to_string(),
x if x == "white" => ai_color = "true".to_string(),
_ => panic!("Please pass a proper argument to the AI"),

}
// Initialize the game state and the Monte Carlo Tree Search (MCTS)
// The MCTS is initialized with a new node that represents the current game state
Expand All @@ -41,12 +40,12 @@ fn main() {
loop {
// The AI checks if it's its turn, if so, it gets the current game state and performs a search using MCTS
match is_my_turn(ai_color.borrow()) {
Ok(true) => {
Ok(true) => {
state = get_game_state();
choice = mcts.search(state, ai_iterations, send_progress);
// Gives the ai 2% more iterations every round to balance the game simulations
// being shorter
ai_iterations += ai_iterations / 50;
ai_iterations += ai_iterations / 50;

// If a valid action is found, it sends the move to the server and updates the game state
if choice.is_ok() {
Expand All @@ -56,16 +55,15 @@ fn main() {
// If no valid action is found, it sends a pass move to the server and updates the game state
else {
let _ = send_move(&ai_color, None);
state.do_action(None);
state.do_action(None);
}

},
}
// If it's not the AI's turn, it performs a search using MCTS and waits
Ok(false) => {
let dev_null = |_a: usize, _b: usize, _c: &i8| -> (){};
let dev_null = |_a: usize, _b: usize, _c: &i8| -> () {};
_ = mcts.search(state, 1000, dev_null);
//sleep(Duration::from_secs(1));
},
}
Err(e) => {
eprintln!("Error checking turn: {}", e);
sleep(Duration::from_secs(1));
Expand All @@ -79,7 +77,7 @@ fn is_my_turn(ai: &String) -> Result<bool, Box<dyn std::error::Error>> {
let mut delay = Duration::from_secs(1);
let opponent = match ai {
x if x == "true" => "false",
_ => "true"
_ => "true",
};
loop {
let url = format!("{}/turn", SERVER_URL);
Expand All @@ -94,10 +92,13 @@ fn is_my_turn(ai: &String) -> Result<bool, Box<dyn std::error::Error>> {
// If the response is anything else, the function returns an error
_ => return Err("Unexpected response from server".into()),
}
},
}
Err(e) => {
// Error occurred, possibly a network issue or server error, wait before trying again
eprintln!("Error checking turn: {}, will retry after {:?} seconds", e, delay);
eprintln!(
"Error checking turn: {}, will retry after {:?} seconds",
e, delay
);
sleep(delay);
delay = std::cmp::min(delay.saturating_mul(2), Duration::from_secs(10));
}
Expand All @@ -112,12 +113,14 @@ fn get_game_state() -> State {
let mut delay = Duration::from_secs(3);
loop {
match get_json() {
Ok(resp) => return parse_state(resp.into_json().expect("Error parsing response to json")),
Ok(resp) => {
return parse_state(resp.into_json().expect("Error parsing response to json"))
}
Err(_e) => {
sleep(delay);
delay *= 2;
delay = std::cmp::min(Duration::from_millis(10000), delay);
},
}
}
}
}
Expand All @@ -138,7 +141,10 @@ fn send_move(player: &String, ai_move: Option<Action>) -> Result<Response, ureq:
// The setChoice endpoint requires the x and y coordinates of the move and the player
if ai_move.is_some() {
let ai_choice = ai_move.unwrap();
url = format!("{}/setChoice/{}/{}/{}",SERVER_URL, ai_choice.x, ai_choice.y, player);
url = format!(
"{}/setChoice/{}/{}/{}",
SERVER_URL, ai_choice.x, ai_choice.y, player
);
}
// If the AI does not have a move, format the URL for the skipTurn endpoint
// The skipTurn endpoint requires the player
Expand All @@ -148,12 +154,11 @@ fn send_move(player: &String, ai_move: Option<Action>) -> Result<Response, ureq:
resp = ureq::get(&url).call()?;
Ok(resp)
}
fn send_progress(current: usize, total: usize, ai_color: &i8) {
fn send_progress(current: usize, total: usize, ai_color: &i8) {
let color = match ai_color {
1 => "false",
_ => "true",
};
let url = format!("{}/AIStatus/{}/{}/{}", SERVER_URL, current, total, color);
_ = ureq::post(&url).call();

}
Loading
Loading