From 7d4e47473132853dc828d27f952ebddcbc61a5ca Mon Sep 17 00:00:00 2001
From: Joel Therrien <joel@joeltherrien.ca>
Date: Tue, 29 Sep 2020 13:43:11 -0700
Subject: [PATCH] Finish multi-threaded support for generating puzzles

---
 src/bin/generator.rs | 145 +++++++++++++++++++++++++------------------
 src/generator.rs     |   4 +-
 2 files changed, 88 insertions(+), 61 deletions(-)

diff --git a/src/bin/generator.rs b/src/bin/generator.rs
index 0ae181e..cb6ff39 100644
--- a/src/bin/generator.rs
+++ b/src/bin/generator.rs
@@ -4,9 +4,10 @@ use std::error::Error;
 use std::io::Write;
 use sudoku_solver::solver::{SolveController, SolveStatistics};
 use std::str::FromStr;
-use std::sync::{mpsc};
+use std::sync::{mpsc, Arc};
 use std::process::exit;
 use std::thread;
+use std::sync::atomic::{AtomicBool, Ordering};
 
 /*
 We have to be very careful here because Grid contains lots of Rcs and RefCells which could enable mutability
@@ -125,64 +126,28 @@ fn main() {
 
     let solve_controller = difficulty.map_to_solve_controller();
 
-    let (grid, solve_statistics, num_hints) =
-    if threads < 1 {
-        eprintln!("--threads must be at least 1");
-        exit(1);
-    } else if threads == 1 {
-        let mut rng = SmallRng::from_entropy();
-        let result = get_puzzle_matching_conditions(&mut rng, &difficulty, &solve_controller, max_attempts, max_hints);
-        match result {
-            Some(x) => x,
-            None => {
-                eprintln!("Unable to find an appropriate puzzle in the required amount of attempts");
-                exit(1);
-            }
-        }
-    } else {
-        let mut thread_rng = thread_rng();
-        let (transmitter, receiver) = mpsc::channel();
+    let (result, num_attempts) =
+        if threads < 1 {
+            eprintln!("--threads must be at least 1");
+            exit(1);
+        } else if threads == 1 {
+            let mut rng = SmallRng::from_entropy();
+            get_puzzle_matching_conditions(&mut rng, &difficulty, &solve_controller, max_attempts, max_hints, &AtomicBool::new(false))
+        } else {
+            run_multi_threaded(max_attempts, max_hints, threads, debug, solve_controller, difficulty)
+        };
 
-        for _i in 0..threads {
-            let cloned_transmitter = mpsc::Sender::clone(&transmitter);
-            let mut rng = SmallRng::from_rng(&mut thread_rng).unwrap();
-
-            thread::spawn(move || {
-                if debug {
-                    println!("Thread spawned");
-                }
-
-                let result = get_puzzle_matching_conditions(&mut rng, &difficulty, &solve_controller, max_attempts, max_hints);
-                match result {
-                    Some((grid, solve_statistics, num_hints)) => {
-                        cloned_transmitter.send((SafeGridWrapper(grid), solve_statistics, num_hints)).unwrap();
-                    },
-                    None => {}
-                };
-
-                if debug {
-                    println!("Thread terminated");
-                }
-            });
-        }
-
-        // TODO - fix bug where recv doesn't return if no Grid is found by any threads!
-        match receiver.recv() {
-            Ok((grid, solve_statistics, num_hints)) => (grid.0, solve_statistics, num_hints),
-            Err(e) => {
-                eprintln!("Unable to find an appropriate puzzle in the required amount of attempts");
-                if debug {
-                    eprintln!("Error returned: {:?}", e);
-                }
-
-                exit(1);
-            }
+    let (grid, solve_statistics, num_hints) = match result {
+        Some(x) => x,
+        None => {
+            println!("Unable to find a desired puzzle in {} tries.", num_attempts);
+            return;
         }
     };
 
 
     println!("{}", grid);
-    println!("Puzzle has {} hints", num_hints);
+    println!("Puzzle has {} hints and was found in {} attempts.", num_hints, num_attempts);
 
     if debug {
         println!("Solving this puzzle involves roughly:");
@@ -209,21 +174,83 @@ fn main() {
 
 }
 
-fn get_puzzle_matching_conditions(rng: &mut SmallRng, difficulty: &Difficulty, solve_controller: &SolveController, max_attempts: i32, max_hints: i32) -> Option<(Grid, SolveStatistics, i32)>{
+fn run_multi_threaded(max_attempts: i32, max_hints: i32, threads: i32, debug: bool, solve_controller: SolveController, difficulty: Difficulty) -> (Option<(Grid, SolveStatistics, i32)>, i32){
+    let mut thread_rng = thread_rng();
+    let (transmitter, receiver) = mpsc::channel();
+    let mut remaining_attempts = max_attempts;
+
+    let should_stop = AtomicBool::new(false);
+    let should_stop = Arc::new(should_stop);
+
+    for i in 0..threads {
+        let cloned_transmitter = mpsc::Sender::clone(&transmitter);
+        let mut rng = SmallRng::from_rng(&mut thread_rng).unwrap();
+        let thread_attempts = remaining_attempts / (threads - i);
+        remaining_attempts -= thread_attempts;
+        let should_stop = Arc::clone(&should_stop);
+
+        thread::spawn(move || {
+            if debug {
+                println!("Thread {} spawned with {} max attempts", i, thread_attempts);
+            }
+
+            let should_stop = &*should_stop;
+            let (result, num_attempts) = get_puzzle_matching_conditions(&mut rng, &difficulty, &solve_controller, thread_attempts, max_hints, should_stop);
+
+            let mut result_was_some = false;
+            let result = match result {
+                None => {None}
+                Some((grid, solve_statistics, num_hints)) => {
+                    result_was_some = true;
+                    Some((SafeGridWrapper(grid), solve_statistics, num_hints))
+                }
+            };
+
+            cloned_transmitter.send((result, num_attempts)).unwrap();
+
+            if debug {
+                println!("Thread {}, terminated having run {} attempts; did send result: {}", i, num_attempts, result_was_some);
+            }
+        });
+    }
+
+    let mut threads_running = threads;
+    let mut attempt_count = 0;
+    let mut result_to_return = None;
+
+    while threads_running > 0 {
+        let signal = receiver.recv().unwrap(); // Not sure what errors can result here but they are unexpected and deserve a panic
+        threads_running-=1;
+
+        let (result, attempts) = signal;
+        attempt_count += attempts;
+
+        match result {
+            Some((safe_grid, solve_statistics, num_hints)) => {
+                result_to_return = Some((safe_grid.0, solve_statistics, num_hints));
+                should_stop.store(true, Ordering::Relaxed);
+            }
+            None => {}
+        };
+    }
+
+    return (result_to_return, attempt_count);
+}
+
+fn get_puzzle_matching_conditions(rng: &mut SmallRng, difficulty: &Difficulty, solve_controller: &SolveController, max_attempts: i32, max_hints: i32, should_stop: &AtomicBool) -> (Option<(Grid, SolveStatistics, i32)>, i32){
     let mut num_attempts = 0;
 
-    loop {
-        if num_attempts >= max_attempts {
-            return None;
-        }
+    while num_attempts < max_attempts && !should_stop.load(Ordering::Relaxed){
 
         let (grid, num_hints, solve_statistics) = sudoku_solver::generator::generate_grid(rng, &solve_controller);
         num_attempts += 1;
 
         if difficulty.meets_minimum_requirements(&solve_statistics) && num_hints <= max_hints {
-            return Some((grid, solve_statistics, num_hints));
+            return (Some((grid, solve_statistics, num_hints)), num_attempts);
         }
     }
+
+    return (None, num_attempts);
 }
 
 fn save_grid_csv(grid: &Grid, filename: &str) -> Result<(), Box<dyn Error>>{
diff --git a/src/generator.rs b/src/generator.rs
index 3d98596..5c01d1c 100644
--- a/src/generator.rs
+++ b/src/generator.rs
@@ -251,8 +251,8 @@ mod tests {
     use crate::grid::*;
     use crate::solver::{solve_grid_with_solve_controller, SolveController, Uniqueness, SolveStatus, SolveStatistics};
     use crate::generator::generate_grid;
-    use rand_chacha::SmallRng;
-    use rand_chacha::rand_core::SeedableRng;
+    use rand::prelude::SmallRng;
+    use rand::SeedableRng;
 
     #[test]
     fn test_unique_detection() {