diff --git a/components/spider-storage/src/cache/job.rs b/components/spider-storage/src/cache/job.rs index 21997b06..a6e3bf98 100644 --- a/components/spider-storage/src/cache/job.rs +++ b/components/spider-storage/src/cache/job.rs @@ -22,7 +22,7 @@ use crate::{ job_submission::ValidatedJobSubmission, task::TaskGraph, }, - db::InternalJobOrchestration, + db::{InternalJobOrchestration, RecoverableJob}, ready_queue::ReadyQueueSender, task_instance_pool::{TaskInstanceMetadata, TaskInstancePoolConnector}, }; @@ -93,6 +93,94 @@ impl< }) } + /// Recovers a job control block from persistent database state. + /// + /// This constructor does not mutate the database. It rebuilds enough cache state to resume + /// scheduling: + /// + /// * [`JobState::Running`] jobs enqueue their initially-ready regular tasks. + /// * [`JobState::CommitReady`] jobs enqueue the commit task. + /// * [`JobState::CleanupReady`] jobs enqueue the cleanup task. + /// + /// # Returns + /// + /// The recovered [`SharedJobControlBlock`] on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`InternalError::UnexpectedJobState`] if `state` is not recoverable. + /// * [`InternalError::TaskGraphCorrupted`] if a commit-ready job has no persisted outputs. + /// * Forwards [`TaskGraph::create`]'s return values on failure. + /// * Forwards [`TaskGraph::restore_outputs`]'s return values on failure. + /// * Forwards [`SharedJobControlBlock::resend_ready_tasks`]'s return values on failure. + pub async fn recover( + recoverable_job: RecoverableJob, + ready_queue_sender: ReadyQueueSenderType, + db_connector: DbConnectorType, + task_instance_pool_connector: TaskInstancePoolConnectorType, + ) -> Result { + let RecoverableJob { + id, + resource_group_id, + state, + job_submission, + job_outputs, + } = recoverable_job; + if !matches!( + state, + JobState::Running | JobState::CommitReady | JobState::CleanupReady + ) { + return Err(UnexpectedJobState { + current: state, + expected: JobState::Running, + } + .into()); + } + + let num_tasks = job_submission.task_graph().get_num_tasks(); + let mut task_graph = TaskGraph::create(job_submission).await?; + if matches!(state, JobState::CommitReady) && job_outputs.is_none() { + return Err(InternalError::TaskGraphCorrupted( + "commit-ready job has no persisted outputs".to_owned(), + ) + .into()); + } + if let Some(outputs) = job_outputs { + task_graph.restore_outputs(outputs).await?; + } + let num_incomplete_tasks = if matches!(state, JobState::CommitReady) { + 0 + } else { + num_tasks + }; + + if matches!(state, JobState::CleanupReady) { + task_graph.cancel_non_terminal().await; + } + + let job_execution_state = JobExecutionState { + state, + task_graph, + num_incomplete_tasks: AtomicUsize::new(num_incomplete_tasks), + ready_queue_sender, + db_connector, + task_instance_pool_connector, + }; + let recovered = Self { + inner: Arc::new(JobControlBlock { + id, + owner_id: resource_group_id, + job_execution_state: JobExecutionStateHandle { + inner: tokio::sync::RwLock::new(job_execution_state), + }, + }), + }; + recovered.resend_ready_tasks().await?; + Ok(recovered) + } + /// Returns the job ID. #[must_use] pub fn id(&self) -> JobId { diff --git a/components/spider-storage/src/cache/sync.rs b/components/spider-storage/src/cache/sync.rs index 0fc03448..4d1847a9 100644 --- a/components/spider-storage/src/cache/sync.rs +++ b/components/spider-storage/src/cache/sync.rs @@ -17,6 +17,13 @@ impl Reader { Self { inner } } + /// # Returns + /// + /// A writer for the same shared data. + pub(crate) fn writer(&self) -> Writer { + Writer::new(self.inner.clone()) + } + /// # Returns /// /// A guard that allows read access to the shared data. The guard will be released when it goes diff --git a/components/spider-storage/src/cache/task.rs b/components/spider-storage/src/cache/task.rs index 5ce7ff30..75e1bba3 100644 --- a/components/spider-storage/src/cache/task.rs +++ b/components/spider-storage/src/cache/task.rs @@ -172,6 +172,30 @@ impl TaskGraph { &self.outputs } + /// Restores graph outputs from persisted job outputs. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`InternalError::TaskOutputsLengthMismatch`] if the number of persisted outputs does not + /// match the number of graph outputs. + pub async fn restore_outputs( + &self, + persisted_outputs: Vec, + ) -> Result<(), InternalError> { + if persisted_outputs.len() != self.outputs.len() { + return Err(InternalError::TaskOutputsLengthMismatch( + self.outputs.len(), + persisted_outputs.len(), + )); + } + for (output_reader, output) in self.outputs.iter().zip(persisted_outputs) { + *output_reader.writer().write().await = Some(output); + } + Ok(()) + } + #[must_use] pub const fn has_commit_task(&self) -> bool { self.commit_task.is_some() diff --git a/components/spider-storage/src/db.rs b/components/spider-storage/src/db.rs index c5152e2f..8ea9a5f0 100644 --- a/components/spider-storage/src/db.rs +++ b/components/spider-storage/src/db.rs @@ -9,6 +9,7 @@ pub use protocol::{ ExecutionManagerLivenessManagement, ExternalJobOrchestration, InternalJobOrchestration, + RecoverableJob, ResourceGroupManagement, SessionManagement, }; diff --git a/components/spider-storage/src/db/error.rs b/components/spider-storage/src/db/error.rs index 62b6434b..3bce5386 100644 --- a/components/spider-storage/src/db/error.rs +++ b/components/spider-storage/src/db/error.rs @@ -40,6 +40,9 @@ pub enum DbError { #[error("Task graph serialization failure: {0}")] TaskGraphSerializationFailure(#[source] Box), + #[error("Task graph deserialization failure: {0}")] + TaskGraphDeserializationFailure(#[source] Box), + #[error("Value serialization failure: {0}")] ValueSerializationFailure(#[source] Box), @@ -57,6 +60,12 @@ impl DbError { Self::TaskGraphSerializationFailure(Box::new(e)) } + pub fn task_graph_de( + e: DeserializationError, + ) -> Self { + Self::TaskGraphDeserializationFailure(Box::new(e)) + } + pub fn value_ser( e: SerializationError, ) -> Self { diff --git a/components/spider-storage/src/db/mariadb.rs b/components/spider-storage/src/db/mariadb.rs index 6bd7017c..4cb11320 100644 --- a/components/spider-storage/src/db/mariadb.rs +++ b/components/spider-storage/src/db/mariadb.rs @@ -5,9 +5,10 @@ use const_format::formatcp; use secrecy::ExposeSecret; use spider_core::{ job::JobState, + task::TaskGraph, types::{ id::{ExecutionManagerId, JobId, ResourceGroupId, SessionId}, - io::TaskOutput, + io::{TaskInput, TaskOutput}, }, }; use spider_derive::MySqlEnum; @@ -22,6 +23,7 @@ use crate::{ ExecutionManagerLivenessManagement, ExternalJobOrchestration, InternalJobOrchestration, + RecoverableJob, ResourceGroupManagement, SessionManagement, error::ExpectedStates, @@ -380,6 +382,63 @@ impl InternalJobOrchestration for MariaDbStorageConnector { tx.commit().await?; Ok(deleted_job_ids) } + + async fn get_recoverable_jobs(&self) -> Result, DbError> { + const SELECT_QUERY: &str = formatcp!( + "SELECT `id`, `resource_group_id`, `state`, `serialized_task_graph`, \ + `serialized_job_inputs`, `serialized_job_outputs` FROM `{table}` WHERE `state` IN \ + ('{running_state}','{commit_ready_state}','{cleanup_ready_state}');", + table = JOBS_TABLE_NAME, + running_state = JobState::Running.as_str(), + commit_ready_state = JobState::CommitReady.as_str(), + cleanup_ready_state = JobState::CleanupReady.as_str(), + ); + + let rows = sqlx::query_as::< + _, + ( + JobId, + ResourceGroupId, + JobState, + String, + Vec, + Option>, + ), + >(SELECT_QUERY) + .fetch_all(&self.pool) + .await?; + + rows.into_iter() + .map( + |( + id, + resource_group_id, + state, + serialized_task_graph, + serialized_job_inputs, + serialized_job_outputs, + )| { + let task_graph = TaskGraph::from_json(&serialized_task_graph) + .map_err(DbError::task_graph_de)?; + let job_inputs: Vec = + rmp_serde::from_slice(&serialized_job_inputs).map_err(DbError::value_de)?; + let job_submission = ValidatedJobSubmission::create(task_graph, job_inputs) + .map_err(|e| DbError::CorruptedDbState(e.to_string()))?; + let job_outputs = serialized_job_outputs + .map(|outputs| rmp_serde::from_slice(&outputs).map_err(DbError::value_de)) + .transpose()?; + + Ok(RecoverableJob { + id, + resource_group_id, + state, + job_submission, + job_outputs, + }) + }, + ) + .collect() + } } #[async_trait] diff --git a/components/spider-storage/src/db/protocol.rs b/components/spider-storage/src/db/protocol.rs index 0b9e297f..2f9be3fd 100644 --- a/components/spider-storage/src/db/protocol.rs +++ b/components/spider-storage/src/db/protocol.rs @@ -11,6 +11,23 @@ use spider_core::{ use crate::{cache::job_submission::ValidatedJobSubmission, db::error::DbError}; +/// A job persisted in the database that should be rebuilt in the storage cache on startup. +/// +/// Only jobs that have already started execution are recoverable. [`JobState::Ready`] jobs remain +/// database-only until a client starts them. +pub struct RecoverableJob { + /// The persisted job ID. + pub id: JobId, + /// The owning resource group. + pub resource_group_id: ResourceGroupId, + /// The source-of-truth database state. + pub state: JobState, + /// The original job submission. + pub job_submission: ValidatedJobSubmission, + /// The committed job outputs, if the job has reached the commit phase. + pub job_outputs: Option>, +} + /// The database storage interface. A database storage must implement the following traits: /// /// * [`ExternalJobOrchestration`] @@ -244,6 +261,22 @@ pub trait InternalJobOrchestration: Clone + Send + Sync { &self, expire_after_sec: u64, ) -> Result, DbError>; + + /// Gets all jobs that should be recovered into the cache. + /// + /// # Returns + /// + /// All persisted jobs in [`JobState::Running`], [`JobState::CommitReady`], or + /// [`JobState::CleanupReady`] on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`DbError::TaskGraphDeserializationFailure`] if a persisted task graph is invalid. + /// * [`DbError::ValueDeserializationFailure`] if persisted inputs or outputs are invalid. + /// * Forwards [`sqlx::error::Error`] on DB operation failure. + async fn get_recoverable_jobs(&self) -> Result, DbError>; } /// Defines the storage interface for resource group management in the database. diff --git a/components/spider-storage/src/state/runtime.rs b/components/spider-storage/src/state/runtime.rs index 5bda0d7a..9350bce1 100644 --- a/components/spider-storage/src/state/runtime.rs +++ b/components/spider-storage/src/state/runtime.rs @@ -4,7 +4,10 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use crate::{ - cache::error::{CacheError, InternalError}, + cache::{ + error::{CacheError, InternalError}, + job::SharedJobControlBlock, + }, config::DatabaseConfig, db::{DbStorage, MariaDbStorageConnector, SessionManagement}, ready_queue::{ReadyQueueConfig, ReadyQueueSender, ReadyQueueSenderHandle, create_ready_queue}, @@ -121,11 +124,16 @@ pub async fn create_runtime( ) .map_err(CacheError::from)?; - // TODO: Recover jobs from the database. + let job_cache = recover_job_cache( + &db, + ready_queue_sender.clone(), + task_instance_pool_connector.clone(), + ) + .await?; let service_state = ServiceState::new( db, session_id, - JobCache::new(), + job_cache, ready_queue_sender, ready_queue_receiver, task_instance_pool_connector, @@ -144,6 +152,52 @@ pub async fn create_runtime( const STOP_BACKGROUND_TASKS_TIMEOUT_SEC: u64 = 30; +/// Recovers jobs from persistent storage into the cache. +/// +/// # Returns +/// +/// A [`JobCache`] containing all recoverable jobs on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`DbStorage::get_recoverable_jobs`]'s return values on failure. +/// * Forwards [`SharedJobControlBlock::recover`]'s return values on failure. +/// * Forwards [`JobCache::insert`]'s return values on failure. +async fn recover_job_cache< + ReadyQueueSenderType: ReadyQueueSender, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +>( + db: &DbConnectorType, + ready_queue_sender: ReadyQueueSenderType, + task_instance_pool_connector: TaskInstancePoolConnectorType, +) -> Result< + JobCache, + StorageServerError, +> { + let job_cache = JobCache::new(); + for recoverable_job in db.get_recoverable_jobs().await? { + let id = recoverable_job.id; + let state = recoverable_job.state; + let jcb = SharedJobControlBlock::recover( + recoverable_job, + ready_queue_sender.clone(), + db.clone(), + task_instance_pool_connector.clone(), + ) + .await?; + job_cache.insert(jcb).await?; + tracing::info!( + job_id = ? id, + job_state = ? state, + "Job recovered into cache.", + ); + } + Ok(job_cache) +} + #[cfg(test)] mod tests { use std::time::Duration; diff --git a/components/spider-storage/src/state/service.rs b/components/spider-storage/src/state/service.rs index ac257e77..a198fb3c 100644 --- a/components/spider-storage/src/state/service.rs +++ b/components/spider-storage/src/state/service.rs @@ -83,6 +83,14 @@ impl< } } + /// # Returns + /// + /// The storage session ID owned by this service state. + #[must_use] + pub fn session_id(&self) -> SessionId { + self.inner.session_id + } + /// Registers a job in the database and inserts its control block into the cache. /// /// # Returns diff --git a/components/spider-storage/src/state/test_utils.rs b/components/spider-storage/src/state/test_utils.rs index a2536d6c..52dcf383 100644 --- a/components/spider-storage/src/state/test_utils.rs +++ b/components/spider-storage/src/state/test_utils.rs @@ -27,6 +27,7 @@ use crate::{ ExecutionManagerLivenessManagement, ExternalJobOrchestration, InternalJobOrchestration, + RecoverableJob, ResourceGroupManagement, SessionManagement, }, @@ -166,6 +167,10 @@ impl InternalJobOrchestration for MockDbConnector { ) -> Result, DbError> { Ok(Vec::new()) } + + async fn get_recoverable_jobs(&self) -> Result, DbError> { + Ok(Vec::new()) + } } #[async_trait::async_trait] diff --git a/components/spider-storage/tests/mariadb_infra.rs b/components/spider-storage/tests/mariadb_infra.rs index 0772ec04..299ec1fb 100644 --- a/components/spider-storage/tests/mariadb_infra.rs +++ b/components/spider-storage/tests/mariadb_infra.rs @@ -16,6 +16,23 @@ use spider_storage::{ /// Panics if any required environment variable (`MARIADB_PORT`, `MARIADB_DATABASE`, /// `MARIADB_USERNAME`, `MARIADB_PASSWORD`) is missing or if the connection fails. pub async fn create_mariadb_connector() -> MariaDbStorageConnector { + MariaDbStorageConnector::connect(&create_mariadb_config()) + .await + .expect("connect failed") +} + +/// Creates a [`DatabaseConfig`] from environment variables. +/// +/// # Returns +/// +/// A [`DatabaseConfig`] configured from environment variables. +/// +/// # Panics +/// +/// Panics if any required environment variable (`MARIADB_PORT`, `MARIADB_DATABASE`, +/// `MARIADB_USERNAME`, `MARIADB_PASSWORD`) is missing or if `MARIADB_PORT` is invalid. +#[must_use] +pub fn create_mariadb_config() -> DatabaseConfig { let port: u16 = std::env::var("MARIADB_PORT") .expect("MARIADB_PORT") .parse() @@ -24,17 +41,14 @@ pub async fn create_mariadb_connector() -> MariaDbStorageConnector { let username = std::env::var("MARIADB_USERNAME").expect("MARIADB_USERNAME"); let password = std::env::var("MARIADB_PASSWORD").expect("MARIADB_PASSWORD"); - let config = DatabaseConfig { + DatabaseConfig { host: "localhost".to_string(), port, name: database, username, password: SecretString::from(password), max_connections: 5, - }; - MariaDbStorageConnector::connect(&config) - .await - .expect("connect failed") + } } /// Registers a new resource group with a random external ID and a fixed test password. diff --git a/components/spider-storage/tests/mariadb_test.rs b/components/spider-storage/tests/mariadb_test.rs index 88343c82..f58a020f 100644 --- a/components/spider-storage/tests/mariadb_test.rs +++ b/components/spider-storage/tests/mariadb_test.rs @@ -269,7 +269,7 @@ async fn test_get_error_wrong_state() { async fn test_cancel_job_with_cleanup_transitions_to_cleanup_ready() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; - let (graph, inputs) = single_task_graph(); + let (graph, inputs) = build_flat_task_graph(1, TEST_INPUT_PAYLOAD_SIZE, false, true); let job_submission = ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); @@ -403,7 +403,7 @@ async fn test_commit_outputs_without_commit_task() { async fn test_commit_outputs_with_commit_task() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; - let (graph, inputs) = single_task_graph(); + let (graph, inputs) = build_flat_task_graph(1, TEST_INPUT_PAYLOAD_SIZE, true, false); let job_submission = ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); diff --git a/components/spider-storage/tests/recovery_test.rs b/components/spider-storage/tests/recovery_test.rs new file mode 100644 index 00000000..84860644 --- /dev/null +++ b/components/spider-storage/tests/recovery_test.rs @@ -0,0 +1,481 @@ +use std::{net::IpAddr, time::Duration}; + +use spider_core::{ + job::JobState, + task::TaskIndex, + types::{ + id::{JobId, TaskInstanceId}, + io::TaskInput, + }, +}; +use spider_storage::{ + db::ExternalJobOrchestration, + ready_queue::{ReadyQueueConfig, ReadyQueueEntry}, + state::{Runtime, ServiceState, StorageServerError, create_runtime}, + task_instance_pool::TaskInstancePoolConfig, +}; +use spider_tdl::wire::{TaskInputsSerializer, TaskOutputsSerializer}; + +use crate::{ + mariadb_infra::{create_mariadb_config, create_mariadb_connector}, + task_graph_builder::build_flat_task_graph, +}; + +#[tokio::test] +async fn restarted_storage_cache_does_not_recover_ready_job() -> anyhow::Result<()> { + let db_config = create_mariadb_config(); + let (runtime, _) = create_runtime( + &db_config, + &ReadyQueueConfig::default(), + &TaskInstancePoolConfig::default(), + ) + .await?; + let service = runtime.get_service_state(); + let job_id = create_registered_job(&service, false, false).await?; + assert_eq!(service.get_job_state(job_id).await?, JobState::Ready); + runtime.stop().await?; + + let (recovered_runtime, _) = create_runtime( + &db_config, + &ReadyQueueConfig::default(), + &TaskInstancePoolConfig::default(), + ) + .await?; + let recovered_service = recovered_runtime.get_service_state(); + let start_result = recovered_service.start_job(job_id).await; + assert!( + matches!(start_result, Err(StorageServerError::JobNotFound(id)) if id == job_id), + "ready job should not be recovered into cache" + ); + assert_eq!( + recovered_service.get_job_state(job_id).await?, + JobState::Ready + ); + recovered_runtime.stop().await?; + Ok(()) +} + +#[tokio::test] +async fn restarted_storage_cache_recovers_running_job_from_start() -> anyhow::Result<()> { + let db_config = create_mariadb_config(); + let (job_id, recovered_service, recovered_runtime) = + restart_after_starting_job(&db_config, false, false).await?; + + let ready_entries = recovered_service + .poll_ready_tasks(32, Duration::from_secs(1)) + .await?; + let ready_entry = find_entry_for_job(ready_entries, job_id); + + let task_instance_id = + run_recovered_regular_task(&recovered_service, job_id, ready_entry.task_kind).await?; + let state = recovered_service + .succeed_task_instance( + recovered_service.session_id(), + job_id, + task_instance_id, + ready_entry.task_kind, + serialized_single_output()?, + ) + .await?; + assert_eq!(state, JobState::Succeeded); + + assert_eq!( + create_mariadb_connector().await.get_state(job_id).await?, + JobState::Succeeded + ); + recovered_runtime.stop().await?; + Ok(()) +} + +#[tokio::test] +async fn restarted_storage_cache_recovers_commit_ready_job() -> anyhow::Result<()> { + let db_config = create_mariadb_config(); + let (job_id, recovered_service, recovered_runtime) = + restart_after_commit_ready(&db_config).await?; + + let ready_entries = recovered_service + .poll_commit_ready_tasks(32, Duration::from_secs(1)) + .await?; + let _ready_entry = find_entry_for_job(ready_entries, job_id); + + let execution_manager_id = recovered_service + .register_execution_manager(IpAddr::from([127, 0, 0, 1])) + .await?; + let execution_context = recovered_service + .create_task_instance( + recovered_service.session_id(), + job_id, + spider_core::types::id::TaskId::Commit, + execution_manager_id, + ) + .await?; + let state = recovered_service + .succeed_commit_task_instance( + recovered_service.session_id(), + job_id, + execution_context.task_instance_id, + ) + .await?; + assert_eq!(state, JobState::Succeeded); + let expected_outputs = TaskOutputsSerializer::deserialize(&serialized_single_output()?)?; + assert_eq!( + recovered_service.get_job_outputs(job_id).await?, + expected_outputs + ); + + assert_eq!( + create_mariadb_connector().await.get_state(job_id).await?, + JobState::Succeeded + ); + recovered_runtime.stop().await?; + Ok(()) +} + +#[tokio::test] +async fn restarted_storage_cache_recovers_cleanup_ready_job() -> anyhow::Result<()> { + let db_config = create_mariadb_config(); + let (job_id, recovered_service, recovered_runtime) = + restart_after_cleanup_ready(&db_config).await?; + + let ready_entries = recovered_service + .poll_cleanup_ready_tasks(32, Duration::from_secs(1)) + .await?; + let _ready_entry = find_entry_for_job(ready_entries, job_id); + + let execution_manager_id = recovered_service + .register_execution_manager(IpAddr::from([127, 0, 0, 1])) + .await?; + let execution_context = recovered_service + .create_task_instance( + recovered_service.session_id(), + job_id, + spider_core::types::id::TaskId::Cleanup, + execution_manager_id, + ) + .await?; + let state = recovered_service + .succeed_cleanup_task_instance( + recovered_service.session_id(), + job_id, + execution_context.task_instance_id, + ) + .await?; + assert_eq!(state, JobState::Cancelled); + + assert_eq!( + create_mariadb_connector().await.get_state(job_id).await?, + JobState::Cancelled + ); + recovered_runtime.stop().await?; + Ok(()) +} + +/// Starts a job, stops the runtime, and creates a replacement runtime over the same database. +/// +/// # Returns +/// +/// The job ID, recovered service state, and recovered runtime on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`create_runtime`]'s return values on failure. +/// * Forwards [`create_and_start_job`]'s return values on failure. +/// * Forwards [`Runtime::stop`]'s return values on failure. +async fn restart_after_starting_job( + db_config: &spider_storage::DatabaseConfig, + with_commit: bool, + with_cleanup: bool, +) -> anyhow::Result<( + JobId, + ServiceState< + spider_storage::ready_queue::ReadyQueueSenderHandle, + spider_storage::db::MariaDbStorageConnector, + spider_storage::task_instance_pool::TaskInstancePoolHandle, + >, + Runtime< + spider_storage::ready_queue::ReadyQueueSenderHandle, + spider_storage::db::MariaDbStorageConnector, + spider_storage::task_instance_pool::TaskInstancePoolHandle, + >, +)> { + let (runtime, _) = create_runtime( + db_config, + &ReadyQueueConfig::default(), + &TaskInstancePoolConfig::default(), + ) + .await?; + let service = runtime.get_service_state(); + let job_id = create_and_start_job(&service, with_commit, with_cleanup).await?; + runtime.stop().await?; + + let (recovered_runtime, _) = create_runtime( + db_config, + &ReadyQueueConfig::default(), + &TaskInstancePoolConfig::default(), + ) + .await?; + let recovered_service = recovered_runtime.get_service_state(); + Ok((job_id, recovered_service, recovered_runtime)) +} + +/// Drives a job to [`JobState::CommitReady`], stops the runtime, and creates a replacement runtime. +/// +/// # Returns +/// +/// The job ID, recovered service state, and recovered runtime on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`restart_after_starting_job`]'s return values on failure. +/// * Forwards [`ServiceState::poll_ready_tasks`]'s return values on failure. +/// * Forwards [`run_recovered_regular_task`]'s return values on failure. +/// * Forwards [`serialized_single_output`]'s return values on failure. +/// * Forwards [`ServiceState::succeed_task_instance`]'s return values on failure. +/// * Forwards [`Runtime::stop`]'s return values on failure. +/// * Forwards [`create_runtime`]'s return values on failure. +async fn restart_after_commit_ready( + db_config: &spider_storage::DatabaseConfig, +) -> anyhow::Result<( + JobId, + ServiceState< + spider_storage::ready_queue::ReadyQueueSenderHandle, + spider_storage::db::MariaDbStorageConnector, + spider_storage::task_instance_pool::TaskInstancePoolHandle, + >, + Runtime< + spider_storage::ready_queue::ReadyQueueSenderHandle, + spider_storage::db::MariaDbStorageConnector, + spider_storage::task_instance_pool::TaskInstancePoolHandle, + >, +)> { + let (job_id, service, runtime) = restart_after_starting_job(db_config, true, false).await?; + let ready_entries = service.poll_ready_tasks(32, Duration::from_secs(1)).await?; + let ready_entry = find_entry_for_job(ready_entries, job_id); + let task_instance_id = + run_recovered_regular_task(&service, job_id, ready_entry.task_kind).await?; + let state = service + .succeed_task_instance( + service.session_id(), + job_id, + task_instance_id, + 0, + serialized_single_output()?, + ) + .await?; + assert_eq!(state, JobState::CommitReady); + runtime.stop().await?; + + let (recovered_runtime, _) = create_runtime( + db_config, + &ReadyQueueConfig::default(), + &TaskInstancePoolConfig::default(), + ) + .await?; + let recovered_service = recovered_runtime.get_service_state(); + Ok((job_id, recovered_service, recovered_runtime)) +} + +/// Drives a job to [`JobState::CleanupReady`], stops the runtime, and creates a replacement +/// runtime. +/// +/// # Returns +/// +/// The job ID, recovered service state, and recovered runtime on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`create_runtime`]'s return values on failure. +/// * Forwards [`create_and_start_job`]'s return values on failure. +/// * Forwards [`ServiceState::cancel_job`]'s return values on failure. +/// * Forwards [`Runtime::stop`]'s return values on failure. +async fn restart_after_cleanup_ready( + db_config: &spider_storage::DatabaseConfig, +) -> anyhow::Result<( + JobId, + ServiceState< + spider_storage::ready_queue::ReadyQueueSenderHandle, + spider_storage::db::MariaDbStorageConnector, + spider_storage::task_instance_pool::TaskInstancePoolHandle, + >, + Runtime< + spider_storage::ready_queue::ReadyQueueSenderHandle, + spider_storage::db::MariaDbStorageConnector, + spider_storage::task_instance_pool::TaskInstancePoolHandle, + >, +)> { + let (runtime, _) = create_runtime( + db_config, + &ReadyQueueConfig::default(), + &TaskInstancePoolConfig::default(), + ) + .await?; + let service = runtime.get_service_state(); + let job_id = create_and_start_job(&service, false, true).await?; + let state = service.cancel_job(job_id).await?; + assert_eq!(state, JobState::CleanupReady); + runtime.stop().await?; + + let (recovered_runtime, _) = create_runtime( + db_config, + &ReadyQueueConfig::default(), + &TaskInstancePoolConfig::default(), + ) + .await?; + let recovered_service = recovered_runtime.get_service_state(); + Ok((job_id, recovered_service, recovered_runtime)) +} + +/// Registers and starts a flat recovery-test job. +/// +/// # Returns +/// +/// The registered job ID on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`create_registered_job`]'s return values on failure. +/// * Forwards [`ServiceState::start_job`]'s return values on failure. +async fn create_and_start_job< + ReadyQueueSenderType: spider_storage::ready_queue::ReadyQueueSender, + DbConnectorType: spider_storage::db::DbStorage, + TaskInstancePoolConnectorType: spider_storage::task_instance_pool::TaskInstancePoolConnector, +>( + service: &ServiceState, + with_commit: bool, + with_cleanup: bool, +) -> anyhow::Result { + let job_id = create_registered_job(service, with_commit, with_cleanup).await?; + service.start_job(job_id).await?; + Ok(job_id) +} + +/// Registers a flat recovery-test job without starting it. +/// +/// # Returns +/// +/// The registered job ID on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`ServiceState::add_resource_group`]'s return values on failure. +/// * Forwards [`spider_core::task::TaskGraph::to_json`]'s return values on failure. +/// * Forwards [`serialize_inputs`]'s return values on failure. +/// * Forwards [`ServiceState::register_job`]'s return values on failure. +async fn create_registered_job< + ReadyQueueSenderType: spider_storage::ready_queue::ReadyQueueSender, + DbConnectorType: spider_storage::db::DbStorage, + TaskInstancePoolConnectorType: spider_storage::task_instance_pool::TaskInstancePoolConnector, +>( + service: &ServiceState, + with_commit: bool, + with_cleanup: bool, +) -> anyhow::Result { + let rg_id = service + .add_resource_group( + format!("recovery-test-{}", rand::random::()), + b"test-password".to_vec(), + ) + .await?; + let (task_graph, inputs) = build_flat_task_graph(1, 4, with_commit, with_cleanup); + Ok(service + .register_job(rg_id, task_graph.to_json()?, serialize_inputs(inputs)?) + .await?) +} + +/// Registers an execution manager and creates an instance for a recovered regular task. +/// +/// # Returns +/// +/// The created task instance ID on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`ServiceState::register_execution_manager`]'s return values on failure. +/// * Forwards [`ServiceState::create_task_instance`]'s return values on failure. +async fn run_recovered_regular_task< + ReadyQueueSenderType: spider_storage::ready_queue::ReadyQueueSender, + DbConnectorType: spider_storage::db::DbStorage, + TaskInstancePoolConnectorType: spider_storage::task_instance_pool::TaskInstancePoolConnector, +>( + service: &ServiceState, + job_id: JobId, + task_index: TaskIndex, +) -> anyhow::Result { + let execution_manager_id = service + .register_execution_manager(IpAddr::from([127, 0, 0, 1])) + .await?; + let execution_context = service + .create_task_instance( + service.session_id(), + job_id, + spider_core::types::id::TaskId::Index(task_index), + execution_manager_id, + ) + .await?; + Ok(execution_context.task_instance_id) +} + +/// Finds the ready-queue entry for a job. +/// +/// # Returns +/// +/// The matching ready-queue entry. +/// +/// # Panics +/// +/// Panics if no matching entry exists. +fn find_entry_for_job( + entries: Vec>, + job_id: JobId, +) -> ReadyQueueEntry { + entries + .into_iter() + .find(|entry| entry.job_id == job_id) + .expect("recovered job should be enqueued") +} + +/// Serializes task inputs into the storage service wire format. +/// +/// # Returns +/// +/// The serialized task inputs on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`TaskInputsSerializer::append`]'s return values on failure. +fn serialize_inputs(inputs: Vec) -> anyhow::Result> { + let mut serializer = TaskInputsSerializer::new(); + for input in inputs { + serializer.append(input)?; + } + Ok(serializer.release()) +} + +/// Serializes the single output payload used by recovery tests. +/// +/// # Returns +/// +/// The serialized task output on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`TaskOutputsSerializer::from_tuple`]'s return values on failure. +fn serialized_single_output() -> anyhow::Result> { + Ok(TaskOutputsSerializer::from_tuple(&(vec![1u8; 4],))?) +} diff --git a/components/spider-storage/tests/scheduling_infra.rs b/components/spider-storage/tests/scheduling_infra.rs index a089d66f..d4fa4878 100644 --- a/components/spider-storage/tests/scheduling_infra.rs +++ b/components/spider-storage/tests/scheduling_infra.rs @@ -98,7 +98,13 @@ use spider_storage::{ job_submission::ValidatedJobSubmission, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, }, - db::{DbError, ExternalJobOrchestration, InternalJobOrchestration, MariaDbStorageConnector}, + db::{ + DbError, + ExternalJobOrchestration, + InternalJobOrchestration, + MariaDbStorageConnector, + RecoverableJob, + }, ready_queue::ReadyQueueSender, task_instance_pool::{TaskInstanceMetadata, TaskInstancePoolConnector}, }; @@ -176,6 +182,10 @@ impl InternalJobOrchestration for NoopDbConnector { ) -> Result, DbError> { Ok(Vec::new()) } + + async fn get_recoverable_jobs(&self) -> Result, DbError> { + Ok(Vec::new()) + } } /// The result of running a workload to completion. diff --git a/components/spider-storage/tests/test_spider_storage.rs b/components/spider-storage/tests/test_spider_storage.rs index 78520dd4..6e69cc13 100644 --- a/components/spider-storage/tests/test_spider_storage.rs +++ b/components/spider-storage/tests/test_spider_storage.rs @@ -4,3 +4,4 @@ mod task_graph_builder; mod jcb_test; mod mariadb_test; +mod recovery_test;