Skip to content

Commit a384cb4

Browse files
committed
updated progress reporter to use refcell and atomics
1 parent 3f078ec commit a384cb4

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

tmc-langs-core/src/tmc_core.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ use reqwest::{blocking::Client, Url};
1313
use serde::{Deserialize, Serialize};
1414
use std::collections::HashMap;
1515
use std::io::Write;
16-
use std::path::Path;
17-
use std::path::PathBuf;
16+
use std::path::{Path, PathBuf};
1817
use std::thread;
1918
use std::time::Duration;
2019
use tempfile::NamedTempFile;

tmc-langs-util/src/progress_reporter.rs

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
/// Utility struct for printing progress reports.
22
use serde::Serialize;
3-
use std::cell::Cell;
3+
use std::cell::RefCell;
44
use std::error::Error;
5+
use std::sync::atomic::{AtomicUsize, Ordering};
56
use std::time::Instant;
67

78
/// The format for all status updates. May contain some data.
@@ -19,21 +20,23 @@ pub struct StatusUpdate<T> {
1920
type DynError = Box<dyn Error + Send + Sync + 'static>;
2021
type UpdateClosure<'a, T> = Box<dyn 'a + Fn(StatusUpdate<T>) -> Result<(), DynError>>;
2122

23+
/// The reporter contains a RefCell for the timer, meaning care should be taken when using in a multithreaded context.
24+
/// The first call to progress and each step completion should be called from one thread. Other progress calls can be done from separate threads.
2225
pub struct ProgressReporter<'a, T> {
2326
progress_report: UpdateClosure<'a, T>,
24-
progress_steps_total: Cell<usize>,
25-
progress_steps_done: Cell<usize>,
26-
start_time: Cell<Option<Instant>>,
27+
progress_steps_total: AtomicUsize,
28+
progress_steps_done: AtomicUsize,
29+
start_time: RefCell<Option<Instant>>,
2730
}
2831

2932
impl<'a, T> ProgressReporter<'a, T> {
3033
/// Takes a closure that will be called with all status updates, for example to print it.
3134
pub fn new(progress_report: impl 'a + Fn(StatusUpdate<T>) -> Result<(), DynError>) -> Self {
3235
Self {
3336
progress_report: Box::new(progress_report),
34-
progress_steps_total: Cell::new(1),
35-
progress_steps_done: Cell::new(0),
36-
start_time: Cell::new(None),
37+
progress_steps_total: AtomicUsize::new(1),
38+
progress_steps_done: AtomicUsize::new(0),
39+
start_time: RefCell::new(None),
3740
}
3841
}
3942

@@ -42,13 +45,14 @@ impl<'a, T> ProgressReporter<'a, T> {
4245
/// Should be incremented to its final value before the process starts.
4346
pub fn increment_progress_steps(&self, amount: usize) {
4447
self.progress_steps_total
45-
.set(self.progress_steps_total.get() + amount);
48+
.fetch_add(amount, Ordering::Relaxed);
4649
}
4750

4851
/// Starts the timer if not started yet.
4952
pub fn start_timer(&self) {
50-
if self.start_time.get().is_none() {
51-
self.start_time.set(Some(Instant::now()))
53+
if self.start_time.borrow().is_none() {
54+
let mut time = self.start_time.borrow_mut();
55+
*time = Some(Instant::now())
5256
}
5357
}
5458

@@ -61,30 +65,35 @@ impl<'a, T> ProgressReporter<'a, T> {
6165
) -> Result<(), DynError> {
6266
self.start_timer();
6367

64-
let from_prev_steps = self.progress_steps_done.get() as f64;
65-
let percent_done =
66-
(from_prev_steps + step_percent_done) / self.progress_steps_total.get() as f64;
68+
let from_prev_steps = self.progress_steps_done.load(Ordering::Relaxed) as f64;
69+
let percent_done = (from_prev_steps + step_percent_done)
70+
/ self.progress_steps_total.load(Ordering::Relaxed) as f64;
6771

6872
self.progress_report.as_ref()(StatusUpdate {
6973
finished: false,
7074
message: message.to_string(),
7175
percent_done,
72-
time: self.start_time.get().map(|t| t.elapsed().as_millis()),
76+
time: self.start_time.borrow().map(|t| t.elapsed().as_millis()),
7377
data,
7478
})
7579
}
7680

7781
/// Finish the current step and the whole process if the current step is the last one.
7882
pub fn finish_step(&self, message: impl ToString, data: Option<T>) -> Result<(), DynError> {
79-
self.progress_steps_done
80-
.set(self.progress_steps_done.get() + 1);
81-
if self.progress_steps_done.get() == self.progress_steps_total.get() {
83+
self.progress_steps_done.fetch_add(1, Ordering::Relaxed);
84+
if self.progress_steps_done.load(Ordering::Relaxed)
85+
== self.progress_steps_total.load(Ordering::Relaxed)
86+
{
8287
// all done
8388
let result = self.progress_report.as_ref()(StatusUpdate {
8489
finished: true,
8590
message: message.to_string(),
8691
percent_done: 1.0,
87-
time: self.start_time.take().map(|t| t.elapsed().as_millis()),
92+
time: self
93+
.start_time
94+
.borrow_mut()
95+
.take()
96+
.map(|t| t.elapsed().as_millis()),
8897
data,
8998
});
9099
result

0 commit comments

Comments
 (0)