diff --git a/src/bin/generator.rs b/src/bin/generator.rs index 0ae181e..cb6ff39 100644 --- a/src/bin/generator.rs +++ b/src/bin/generator.rs @@ -4,9 +4,10 @@ use std::error::Error; use std::io::Write; use sudoku_solver::solver::{SolveController, SolveStatistics}; use std::str::FromStr; -use std::sync::{mpsc}; +use std::sync::{mpsc, Arc}; use std::process::exit; use std::thread; +use std::sync::atomic::{AtomicBool, Ordering}; /* We have to be very careful here because Grid contains lots of Rcs and RefCells which could enable mutability @@ -125,64 +126,28 @@ fn main() { let solve_controller = difficulty.map_to_solve_controller(); - let (grid, solve_statistics, num_hints) = - if threads < 1 { - eprintln!("--threads must be at least 1"); - exit(1); - } else if threads == 1 { - let mut rng = SmallRng::from_entropy(); - let result = get_puzzle_matching_conditions(&mut rng, &difficulty, &solve_controller, max_attempts, max_hints); - match result { - Some(x) => x, - None => { - eprintln!("Unable to find an appropriate puzzle in the required amount of attempts"); - exit(1); - } - } - } else { - let mut thread_rng = thread_rng(); - let (transmitter, receiver) = mpsc::channel(); + let (result, num_attempts) = + if threads < 1 { + eprintln!("--threads must be at least 1"); + exit(1); + } else if threads == 1 { + let mut rng = SmallRng::from_entropy(); + get_puzzle_matching_conditions(&mut rng, &difficulty, &solve_controller, max_attempts, max_hints, &AtomicBool::new(false)) + } else { + run_multi_threaded(max_attempts, max_hints, threads, debug, solve_controller, difficulty) + }; - for _i in 0..threads { - let cloned_transmitter = mpsc::Sender::clone(&transmitter); - let mut rng = SmallRng::from_rng(&mut thread_rng).unwrap(); - - thread::spawn(move || { - if debug { - println!("Thread spawned"); - } - - let result = get_puzzle_matching_conditions(&mut rng, &difficulty, &solve_controller, max_attempts, max_hints); - match result { - Some((grid, solve_statistics, num_hints)) => { - cloned_transmitter.send((SafeGridWrapper(grid), solve_statistics, num_hints)).unwrap(); - }, - None => {} - }; - - if debug { - println!("Thread terminated"); - } - }); - } - - // TODO - fix bug where recv doesn't return if no Grid is found by any threads! - match receiver.recv() { - Ok((grid, solve_statistics, num_hints)) => (grid.0, solve_statistics, num_hints), - Err(e) => { - eprintln!("Unable to find an appropriate puzzle in the required amount of attempts"); - if debug { - eprintln!("Error returned: {:?}", e); - } - - exit(1); - } + let (grid, solve_statistics, num_hints) = match result { + Some(x) => x, + None => { + println!("Unable to find a desired puzzle in {} tries.", num_attempts); + return; } }; println!("{}", grid); - println!("Puzzle has {} hints", num_hints); + println!("Puzzle has {} hints and was found in {} attempts.", num_hints, num_attempts); if debug { println!("Solving this puzzle involves roughly:"); @@ -209,21 +174,83 @@ fn main() { } -fn get_puzzle_matching_conditions(rng: &mut SmallRng, difficulty: &Difficulty, solve_controller: &SolveController, max_attempts: i32, max_hints: i32) -> Option<(Grid, SolveStatistics, i32)>{ +fn run_multi_threaded(max_attempts: i32, max_hints: i32, threads: i32, debug: bool, solve_controller: SolveController, difficulty: Difficulty) -> (Option<(Grid, SolveStatistics, i32)>, i32){ + let mut thread_rng = thread_rng(); + let (transmitter, receiver) = mpsc::channel(); + let mut remaining_attempts = max_attempts; + + let should_stop = AtomicBool::new(false); + let should_stop = Arc::new(should_stop); + + for i in 0..threads { + let cloned_transmitter = mpsc::Sender::clone(&transmitter); + let mut rng = SmallRng::from_rng(&mut thread_rng).unwrap(); + let thread_attempts = remaining_attempts / (threads - i); + remaining_attempts -= thread_attempts; + let should_stop = Arc::clone(&should_stop); + + thread::spawn(move || { + if debug { + println!("Thread {} spawned with {} max attempts", i, thread_attempts); + } + + let should_stop = &*should_stop; + let (result, num_attempts) = get_puzzle_matching_conditions(&mut rng, &difficulty, &solve_controller, thread_attempts, max_hints, should_stop); + + let mut result_was_some = false; + let result = match result { + None => {None} + Some((grid, solve_statistics, num_hints)) => { + result_was_some = true; + Some((SafeGridWrapper(grid), solve_statistics, num_hints)) + } + }; + + cloned_transmitter.send((result, num_attempts)).unwrap(); + + if debug { + println!("Thread {}, terminated having run {} attempts; did send result: {}", i, num_attempts, result_was_some); + } + }); + } + + let mut threads_running = threads; + let mut attempt_count = 0; + let mut result_to_return = None; + + while threads_running > 0 { + let signal = receiver.recv().unwrap(); // Not sure what errors can result here but they are unexpected and deserve a panic + threads_running-=1; + + let (result, attempts) = signal; + attempt_count += attempts; + + match result { + Some((safe_grid, solve_statistics, num_hints)) => { + result_to_return = Some((safe_grid.0, solve_statistics, num_hints)); + should_stop.store(true, Ordering::Relaxed); + } + None => {} + }; + } + + return (result_to_return, attempt_count); +} + +fn get_puzzle_matching_conditions(rng: &mut SmallRng, difficulty: &Difficulty, solve_controller: &SolveController, max_attempts: i32, max_hints: i32, should_stop: &AtomicBool) -> (Option<(Grid, SolveStatistics, i32)>, i32){ let mut num_attempts = 0; - loop { - if num_attempts >= max_attempts { - return None; - } + while num_attempts < max_attempts && !should_stop.load(Ordering::Relaxed){ let (grid, num_hints, solve_statistics) = sudoku_solver::generator::generate_grid(rng, &solve_controller); num_attempts += 1; if difficulty.meets_minimum_requirements(&solve_statistics) && num_hints <= max_hints { - return Some((grid, solve_statistics, num_hints)); + return (Some((grid, solve_statistics, num_hints)), num_attempts); } } + + return (None, num_attempts); } fn save_grid_csv(grid: &Grid, filename: &str) -> Result<(), Box>{ diff --git a/src/generator.rs b/src/generator.rs index 3d98596..5c01d1c 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -251,8 +251,8 @@ mod tests { use crate::grid::*; use crate::solver::{solve_grid_with_solve_controller, SolveController, Uniqueness, SolveStatus, SolveStatistics}; use crate::generator::generate_grid; - use rand_chacha::SmallRng; - use rand_chacha::rand_core::SeedableRng; + use rand::prelude::SmallRng; + use rand::SeedableRng; #[test] fn test_unique_detection() {