diff --git a/Cargo.lock b/Cargo.lock index 3eb78ef6b..8f74654a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1494,7 +1494,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -1729,9 +1729,12 @@ dependencies = [ [[package]] name = "fragile" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" +checksum = "8878864ba14bb86e818a412bfd6f18f9eabd4ec0f008a28e8f7eb61db532fcf9" +dependencies = [ + "futures-core", +] [[package]] name = "futures" @@ -2410,16 +2413,6 @@ version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" -[[package]] -name = "iri-string" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ -2888,15 +2881,14 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl" -version = "0.10.76" +version = "0.10.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +checksum = "a45fa2aa886c42762255da344f0a0d313e254066c46aad76f300c3d3da62d967" dependencies = [ "bitflags", "cfg-if", "foreign-types", "libc", - "once_cell", "openssl-macros", "openssl-sys", ] @@ -2920,9 +2912,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.112" +version = "0.9.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +checksum = "f28a22dc7140cda5f096e5e7724a6962ca81a7f8bfd2979f9b18c11af56318c4" dependencies = [ "cc", "libc", @@ -3399,9 +3391,9 @@ checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" [[package]] name = "rayon" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" dependencies = [ "either", "rayon-core", @@ -3609,7 +3601,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -3666,7 +3658,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -4302,7 +4294,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -4633,20 +4625,20 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.8" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +checksum = "4cfcf7e2740e6fc6d4d688b4ef00650406bb94adf4731e43c096c3a19fe40840" dependencies = [ "bitflags", "bytes", "futures-util", "http 1.4.1", "http-body", - "iri-string", "pin-project-lite", "tower", "tower-layer", "tower-service", + "url", ] [[package]] @@ -5160,7 +5152,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/cot/Cargo.toml b/cot/Cargo.toml index 75e9548d8..bc0067e35 100644 --- a/cot/Cargo.toml +++ b/cot/Cargo.toml @@ -24,7 +24,7 @@ blake3.workspace = true bytes.workspace = true chrono = { workspace = true, features = ["alloc", "serde", "clock"] } chrono-tz.workspace = true -clap.workspace = true +clap = {workspace = true, features = ["string"] } cot_core.workspace = true cot_macros.workspace = true deadpool-redis = { workspace = true, features = ["tokio-comp", "rt_tokio_1"], optional = true } diff --git a/cot/src/cli.rs b/cot/src/cli.rs index 97652a2e1..ef174bf39 100644 --- a/cot/src/cli.rs +++ b/cot/src/cli.rs @@ -7,6 +7,9 @@ use std::str::FromStr; use async_trait::async_trait; pub use clap; use clap::{Arg, ArgMatches, Command, value_parser}; +#[cfg(feature = "db")] +use cot::db::migrations::{MigrationEngine, SyncDynMigration}; +use cot::project::BootstrappedProject; use derive_more::Debug; use crate::{Bootstrapper, Error, Result}; @@ -16,6 +19,8 @@ const COLLECT_STATIC_SUBCOMMAND: &str = "collect-static"; const CHECK_SUBCOMMAND: &str = "check"; const LISTEN_PARAM: &str = "listen"; const COLLECT_STATIC_DIR_PARAM: &str = "dir"; +const MIGRATION_GROUP_SUBCOMMAND: &str = "migration"; +const MIGRATION_ROLLBACK_SUBCOMMAND: &str = "rollback"; /// A central point for configuring the default Command Line Interface (CLI) for /// Cot-powered projects. @@ -91,6 +96,12 @@ impl Cli { cli.add_task(Check); cli.add_task(CollectStatic); + let mut migration_group = + CliTaskGroup::new(MIGRATION_GROUP_SUBCOMMAND).about("Database migration commands"); + migration_group.add_task(MigrationRollback); + + cli.add_task(migration_group); + cli } @@ -389,6 +400,237 @@ impl CliTask for Check { } } +/// A group of related sub-tasks under a single parent subcommand. +/// +/// # Examples +/// +/// ``` +/// use async_trait::async_trait; +/// use clap::{ArgMatches, Command}; +/// use cot::cli::{Cli, CliTask, CliTaskGroup}; +/// use cot::project::WithConfig; +/// use cot::{Bootstrapper, Project}; +/// +/// struct Frobnicate; +/// +/// #[async_trait(?Send)] +/// impl CliTask for Frobnicate { +/// fn subcommand(&self) -> Command { +/// Command::new("frobnicate") +/// } +/// +/// async fn execute( +/// &mut self, +/// _matches: &ArgMatches, +/// _bootstrapper: Bootstrapper, +/// ) -> cot::Result<()> { +/// println!("Frobnicating..."); +/// +/// Ok(()) +/// } +/// } +/// +/// struct MyProject; +/// impl Project for MyProject { +/// fn register_tasks(&self, cli: &mut Cli) { +/// let mut group_command = CliTaskGroup::new("foo").about("Foo related commands"); +/// group_command.add_task(Frobnicate); +/// cli.add_task(group_command); +/// } +/// } +/// ``` +#[derive(Debug)] +pub struct CliTaskGroup { + name: String, + about: String, + #[debug("..")] + tasks: HashMap>, +} + +impl CliTaskGroup { + /// Create a subcommand group. + /// + /// # Examples + /// + /// ``` + /// use cot::cli::CliTaskGroup; + /// + /// let group = CliTaskGroup::new("command"); + /// ``` + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + about: String::new(), + tasks: HashMap::new(), + } + } + + /// Sets the description of the group, which is displayed in the help + /// message for the group's subcommands. + /// + /// # Example + /// ``` + /// use cot::cli::CliTaskGroup; + /// + /// let group = CliTaskGroup::new("command").about("This is a description for the command group"); + /// ``` + #[must_use] + pub fn about(mut self, about: impl Into) -> Self { + self.about = about.into(); + self + } + + /// Adds a new task to the subcommand group. + /// + /// # Panics + /// + /// Panics if a task with the same name has already been registered. + /// + /// # Examples + /// + /// ``` + /// use async_trait::async_trait; + /// use clap::{ArgMatches, Command}; + /// use cot::cli::{Cli, CliTask, CliTaskGroup}; + /// use cot::project::WithConfig; + /// use cot::{Bootstrapper, Project}; + /// + /// struct Frobnicate; + /// + /// #[async_trait(?Send)] + /// impl CliTask for Frobnicate { + /// fn subcommand(&self) -> Command { + /// Command::new("frobnicate") + /// } + /// + /// async fn execute( + /// &mut self, + /// _matches: &ArgMatches, + /// _bootstrapper: Bootstrapper, + /// ) -> cot::Result<()> { + /// println!("Frobnicating..."); + /// + /// Ok(()) + /// } + /// } + /// + /// struct MyProject; + /// impl Project for MyProject { + /// fn register_tasks(&self, cli: &mut Cli) { + /// let mut group_command = CliTaskGroup::new("foo").about("Foo related commands"); + /// group_command.add_task(Frobnicate); + /// cli.add_task(group_command); + /// } + /// } + /// ``` + pub fn add_task(&mut self, task: C) { + let subcommand = task.subcommand(); + let name = subcommand.get_name().to_owned(); + + assert!( + !self.tasks.contains_key(&name), + "Task with name {name} already exists in group '{}'", + self.name + ); + + self.tasks.insert(name, Box::new(task)); + } +} + +#[async_trait(?Send)] +impl CliTask for CliTaskGroup { + fn subcommand(&self) -> Command { + let name = self.name.clone(); + let mut cmd = Command::new(name); + + if !self.about.is_empty() { + cmd = cmd.about(self.about.clone()); + } + + cmd = cmd.subcommand_required(true).arg_required_else_help(true); + for task in self.tasks.values() { + cmd = cmd.subcommand(task.subcommand()); + } + cmd + } + + async fn execute( + &mut self, + matches: &ArgMatches, + bootstrapper: Bootstrapper, + ) -> Result<()> { + let (sub_name, sub_matches) = matches + .subcommand() + .expect("subcommand should be present since subcommand_required is true"); + + self.tasks + .get_mut(sub_name) + .expect("command should be registered") + .execute(sub_matches, bootstrapper) + .await + } +} + +struct MigrationRollback; + +#[async_trait(?Send)] +impl CliTask for MigrationRollback { + fn subcommand(&self) -> Command { + Command::new(MIGRATION_ROLLBACK_SUBCOMMAND) + .about("Rollback migrations up to the specified migration file") + .arg( + Arg::new("file") + .help("The migration filename to roll back to (e.g. 0001_initial or 0001)") + .value_name("FILE") + .required(true), + ) + } + + async fn execute( + &mut self, + matches: &ArgMatches, + bootstrapper: Bootstrapper, + ) -> Result<()> { + let file = matches + .get_one::("file") + .expect("required argument"); + + let bootstrapper = bootstrapper + .with_apps() + .with_database() + .await? + .with_cache() + .await? + .boot() + .await?; + + // migrations are currently tied to crates, so we use the crate name as the app + // name. + // TODO: cli command should take an explicit crate name as arg when workspaces + // are supported. + let crate_name = bootstrapper.project().cli_metadata().name; + + let BootstrappedProject { + context, + handler: _, + error_handler: _, + } = bootstrapper.finish(); + + #[cfg(feature = "db")] + { + let mut migrations: Vec> = Vec::new(); + for app in context.apps() { + migrations.extend(app.migrations()); + } + let migration_engine = MigrationEngine::new(migrations)?; + migration_engine + .rollback(context.database(), file, crate_name) + .await?; + } + Ok(()) + } +} + /// A macro to generate a [`CliMetadata`] struct from the Cargo manifest. #[macro_export] macro_rules! metadata { @@ -476,6 +718,60 @@ mod tests { ); } + #[test] + fn cli_new_includes_migration_rollback_group() { + let cli = Cli::new(); + + let migration_group = cli + .command + .get_subcommands() + .find(|command| command.get_name() == MIGRATION_GROUP_SUBCOMMAND) + .expect("migration group is registered"); + + assert!( + migration_group + .get_subcommands() + .any(|command| command.get_name() == MIGRATION_ROLLBACK_SUBCOMMAND) + ); + } + + #[cot::test] + async fn cli_task_group_dispatches_nested_task() { + use std::sync::atomic::{AtomicBool, Ordering}; + + struct NestedTask; + #[async_trait(?Send)] + impl CliTask for NestedTask { + fn subcommand(&self) -> Command { + Command::new("nested") + } + + async fn execute( + &mut self, + _matches: &ArgMatches, + _bootstrapper: Bootstrapper, + ) -> Result<()> { + TASK_CALLED.store(true, Ordering::SeqCst); + Ok(()) + } + } + + struct TestProject; + impl crate::Project for TestProject {} + + static TASK_CALLED: AtomicBool = AtomicBool::new(false); + TASK_CALLED.store(false, Ordering::SeqCst); + + let mut group = CliTaskGroup::new("group"); + group.add_task(NestedTask); + let matches = group.subcommand().get_matches_from(["group", "nested"]); + let bootstrapper = Bootstrapper::new(TestProject).with_config(ProjectConfig::default()); + + group.execute(&matches, bootstrapper).await.unwrap(); + + assert!(TASK_CALLED.load(Ordering::SeqCst)); + } + #[test] fn run_server_subcommand() { let matches = RunServer diff --git a/cot/src/db/migrations.rs b/cot/src/db/migrations.rs index 51dc19761..de71c6563 100644 --- a/cot/src/db/migrations.rs +++ b/cot/src/db/migrations.rs @@ -2,6 +2,7 @@ mod sorter; +use std::collections::{HashSet, VecDeque}; use std::fmt; use std::fmt::{Debug, Formatter}; use std::future::Future; @@ -215,6 +216,102 @@ impl MigrationEngine { Ok(()) } + /// Roll back necessary migrations up until the specified migration in an + /// app. + /// + /// # Errors + /// + /// Returns an error if there is an error while interacting with the + /// database or if there is an error while generating the migration + /// graph or if there is an error while unapplying a migration. + pub async fn rollback(&self, database: &Database, file: &str, app_name: &str) -> Result<()> { + info!("Rolling back migrations"); + + let rollback_plan = self.rollback_plan(file, app_name)?; + + for migration in rollback_plan { + if !Self::is_migration_applied(database, migration).await? { + continue; + } + + let span = tracing::span!( + Level::TRACE, + "rollback_migration", + app_name = migration.app_name(), + migration_name = migration.name() + ); + let _enter = span.enter(); + + info!( + "Rolling back migration {} for app {}", + migration.name(), + migration.app_name() + ); + + for operation in migration.operations().iter().rev() { + operation.backwards(database).await?; + } + + Self::mark_migration_unapplied(database, migration).await?; + } + + Ok(()) + } + + fn rollback_plan<'a>( + &'a self, + file_name: &str, + app_name: &str, + ) -> Result> { + let target_index = self + .migrations + .iter() + .position(|migration| { + migration.app_name() == app_name + && expand_migration_file_name(migration.name()).contains(&file_name) + }) + .ok_or_else(|| { + MigrationEngineError::Custom(format!( + "Migration with file name {file_name} not found for app {app_name}" + )) + })?; + + let mut rollback_indices = HashSet::new(); + // Seed later migrations in the same app, then include migrations from + // other apps only when they depend on that seed set. + rollback_indices.extend( + self.migrations + .iter() + .enumerate() + .filter(|(index, migration)| { + *index > target_index && migration.app_name() == app_name + }) + .map(|(index, _)| index), + ); + + let graph = MigrationSorter::generate_graph(&self.migrations).map_err(|e| { + MigrationEngineError::Custom(format!("Failed to generate migration graph: {e}")) + })?; + let mut queue = rollback_indices.iter().copied().collect::>(); + while let Some(index) = queue.pop_front() { + for &dependent_index in graph.get_edges(index) { + if rollback_indices.insert(dependent_index) { + // we found a migration that depends on the one we're rolling back, so let's + // add it to the queue which we will later traverse its dependents as well. + queue.push_back(dependent_index); + } + } + } + + Ok(self + .migrations + .iter() + .enumerate() + .rev() + .filter_map(|(index, migration)| rollback_indices.contains(&index).then_some(migration)) + .collect()) + } + async fn is_migration_applied( database: &Database, migration: &MigrationWrapper, @@ -241,6 +338,30 @@ impl MigrationEngine { database.insert(&mut applied_migration).await?; Ok(()) } + + async fn mark_migration_unapplied( + database: &Database, + migration: &MigrationWrapper, + ) -> Result<()> { + query!(AppliedMigration, $app == migration.app_name() && $name == migration.name()) + .delete(database) + .await?; + Ok(()) + } +} + +/// Resolves the possible migration names that can be used to refer to a +/// migration file. For example, for a migration file named `m_0001_initial`, +/// this function will return both `m_0001_initial` and `0001`. This allows +/// users to refer to migrations using either the full file name or just the +/// migration number when rolling back migrations. +fn expand_migration_file_name(file_name: &str) -> Vec<&str> { + let mut names = vec![file_name]; + let migration_number = file_name.split('_').nth(1); + if let Some(migration_number) = migration_number { + names.push(migration_number); + } + names } /// A migration operation that can be run forwards or backwards. @@ -2026,7 +2147,10 @@ mod tests { use sea_query::ColumnSpec; use super::*; + use crate::App; + use crate::auth::db::DatabaseUserApp; use crate::db::{ColumnType, DatabaseField, Identifier}; + use crate::session::db::SessionApp; struct TestMigration; @@ -2054,6 +2178,110 @@ mod tests { const OPERATIONS: &'static [Operation] = &[]; } + struct RollbackApp1Initial; + + impl Migration for RollbackApp1Initial { + const APP_NAME: &'static str = "rollback_app1"; + const MIGRATION_NAME: &'static str = "m_0001_initial"; + const DEPENDENCIES: &'static [MigrationDependency] = &[]; + const OPERATIONS: &'static [Operation] = &[Operation::create_model() + .table_name(Identifier::new("rollback_single__first")) + .fields(&[ + Field::new(Identifier::new("id"), ::TYPE) + .primary_key() + .auto(), + ]) + .build()]; + } + + struct RollbackApp10002; + + impl Migration for RollbackApp10002 { + const APP_NAME: &'static str = "rollback_app1"; + const MIGRATION_NAME: &'static str = "m_0002_second"; + const DEPENDENCIES: &'static [MigrationDependency] = &[MigrationDependency::migration( + "rollback_app1", + "m_0001_initial", + )]; + const OPERATIONS: &'static [Operation] = &[Operation::create_model() + .table_name(Identifier::new("rollback_app1__second")) + .fields(&[ + Field::new(Identifier::new("id"), ::TYPE) + .primary_key() + .auto(), + ]) + .build()]; + } + + struct RollbackApp1003; + + impl Migration for RollbackApp1003 { + const APP_NAME: &'static str = "rollback_app1"; + const MIGRATION_NAME: &'static str = "m_0003_third"; + const DEPENDENCIES: &'static [MigrationDependency] = &[MigrationDependency::migration( + "rollback_app1", + "m_0002_second", + )]; + const OPERATIONS: &'static [Operation] = &[Operation::create_model() + .table_name(Identifier::new("rollback_single__third")) + .fields(&[ + Field::new(Identifier::new("id"), ::TYPE) + .primary_key() + .auto(), + ]) + .build()]; + } + + struct RollbackApp2Initial; + + impl Migration for RollbackApp2Initial { + const APP_NAME: &'static str = "rollback_app2"; + const MIGRATION_NAME: &'static str = "m_0001_initial"; + const DEPENDENCIES: &'static [MigrationDependency] = &[]; + const OPERATIONS: &'static [Operation] = &[Operation::create_model() + .table_name(Identifier::new("rollback_app2__foo")) + .fields(&[ + Field::new(Identifier::new("id"), ::TYPE) + .primary_key() + .auto(), + ]) + .build()]; + } + + struct RollbackDependentInitial; + + impl Migration for RollbackDependentInitial { + const APP_NAME: &'static str = "rollback_dependent"; + const MIGRATION_NAME: &'static str = "m_0001_initial"; + const DEPENDENCIES: &'static [MigrationDependency] = &[MigrationDependency::migration( + "rollback_app1", + "m_0002_second", + )]; + const OPERATIONS: &'static [Operation] = &[Operation::create_model() + .table_name(Identifier::new("rollback_dependent__bar")) + .fields(&[ + Field::new(Identifier::new("id"), ::TYPE) + .primary_key() + .auto(), + ]) + .build()]; + } + + async fn assert_migration_applied(database: Database, app: &str, name: &str, expected: bool) { + let applied = query!(AppliedMigration, $app == app && $name == name) + .exists(&database) + .await + .unwrap(); + + assert_eq!(applied, expected, "{app}::{name}"); + } + + #[cot_macros::dbtest] + async fn test_migration_rollback_no_deps(test_db: &mut TestDatabase) { + let engine = MigrationEngine::new([RollbackApp1Initial]).unwrap(); + engine.run(&test_db.database()).await.unwrap(); + } + #[cot_macros::dbtest] async fn test_migration_engine_run(test_db: &mut TestDatabase) { let engine = MigrationEngine::new([TestMigration]).unwrap(); @@ -2077,6 +2305,117 @@ mod tests { assert!(result.is_ok()); } + #[cot_macros::dbtest] + async fn test_migration_engine_rollback_single_app(test_db: &mut TestDatabase) { + #[expect(trivial_casts)] + let engine = MigrationEngine::new([ + &RollbackApp1Initial as &SyncDynMigration, + &RollbackApp10002 as &SyncDynMigration, + &RollbackApp1003 as &SyncDynMigration, + ]) + .unwrap(); + + engine.run(&test_db.database()).await.unwrap(); + // migrations should be applied + + assert_migration_applied(test_db.database(), "rollback_app1", "m_0001_initial", true).await; + assert_migration_applied(test_db.database(), "rollback_app1", "m_0002_second", true).await; + assert_migration_applied(test_db.database(), "rollback_app1", "m_0003_third", true).await; + + // rollback everything except the initial migration + engine + .rollback(&test_db.database(), "0001", "rollback_app1") + .await + .unwrap(); + + // the initial migration should stay applied + assert_migration_applied(test_db.database(), "rollback_app1", "m_0001_initial", true).await; + // everything else should be unapplied + assert_migration_applied(test_db.database(), "rollback_app1", "m_0002_second", false).await; + assert_migration_applied(test_db.database(), "rollback_app1", "m_0003_third", false).await; + } + + #[cot_macros::dbtest] + async fn test_migration_rollback_unrelated_apps(test_db: &mut TestDatabase) { + let mut migrations = DatabaseUserApp::new().migrations(); + // combine migrations from multiple apps/crates + #[expect(trivial_casts)] + migrations.extend(wrap_migrations(&[ + &RollbackApp1Initial as &SyncDynMigration, + &RollbackApp10002 as &SyncDynMigration, + &RollbackApp2Initial as &SyncDynMigration, + ])); + migrations.extend(SessionApp::new().migrations()); + let engine = MigrationEngine::new(migrations).unwrap(); + + engine.run(&test_db.database()).await.unwrap(); + // migrations should be applied across all apps + assert_migration_applied(test_db.database(), "cot", "m_0001_initial", true).await; + assert_migration_applied(test_db.database(), "cot_session", "m_0001_initial", true).await; + assert_migration_applied(test_db.database(), "rollback_app1", "m_0001_initial", true).await; + assert_migration_applied(test_db.database(), "rollback_app1", "m_0002_second", true).await; + assert_migration_applied(test_db.database(), "rollback_app2", "m_0001_initial", true).await; + + // rollback every migration in the rollback_app1 app except the initial + engine + .rollback(&test_db.database(), "0001", "rollback_app1") + .await + .unwrap(); + + // the initial migration should stay applied + assert_migration_applied(test_db.database(), "rollback_app1", "m_0001_initial", true).await; + // everything else in the rollback_app1 app should be unapplied + assert_migration_applied(test_db.database(), "rollback_app1", "m_0002_second", false).await; + // migrations from other apps should remain unaffected + assert_migration_applied(test_db.database(), "cot", "m_0001_initial", true).await; + assert_migration_applied(test_db.database(), "cot_session", "m_0001_initial", true).await; + assert_migration_applied(test_db.database(), "rollback_app2", "m_0001_initial", true).await; + } + + #[cot_macros::dbtest] + async fn test_migration_engine_rollback_includes_dependent_apps(test_db: &mut TestDatabase) { + #[expect(trivial_casts)] + let engine = MigrationEngine::new([ + &RollbackApp1Initial as &SyncDynMigration, + &RollbackApp10002 as &SyncDynMigration, + &RollbackDependentInitial as &SyncDynMigration, + &RollbackApp2Initial as &SyncDynMigration, + ]) + .unwrap(); + + engine.run(&test_db.database()).await.unwrap(); + assert_migration_applied(test_db.database(), "rollback_app1", "m_0001_initial", true).await; + assert_migration_applied(test_db.database(), "rollback_app1", "m_0002_second", true).await; + assert_migration_applied( + test_db.database(), + "rollback_dependent", + "m_0001_initial", + true, + ) + .await; + assert_migration_applied(test_db.database(), "rollback_app2", "m_0001_initial", true).await; + + // rollback everything except the initial migration in the source/independent + // app + engine + .rollback(&test_db.database(), "0001", "rollback_app1") + .await + .unwrap(); + + assert_migration_applied(test_db.database(), "rollback_app1", "m_0001_initial", true).await; + assert_migration_applied(test_db.database(), "rollback_app1", "m_0002_second", false).await; + // the sink/dependent app should also be unapplied/rolled back + assert_migration_applied( + test_db.database(), + "rollback_dependent", + "m_0001_initial", + false, + ) + .await; + // migrations from non-dependent apps should remain unaffected + assert_migration_applied(test_db.database(), "rollback_app2", "m_0001_initial", true).await; + } + #[test] fn test_operation_create_model() { const OPERATION_CREATE_MODEL_FIELDS: &[Field; 2] = &[ diff --git a/cot/src/db/migrations/sorter.rs b/cot/src/db/migrations/sorter.rs index e2d5c7e66..46faf1f96 100644 --- a/cot/src/db/migrations/sorter.rs +++ b/cot/src/db/migrations/sorter.rs @@ -61,11 +61,11 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> { Ok(()) } - fn toposort(&mut self) -> Result<()> { - let lookup = Self::create_lookup_table(self.migrations)?; - let mut graph = Graph::new(self.migrations.len()); + pub(super) fn generate_graph(migrations: &[T]) -> Result { + let lookup = Self::create_lookup_table(migrations)?; + let mut graph = Graph::new(migrations.len()); - for (index, migration) in self.migrations.iter().enumerate() { + for (index, migration) in migrations.iter().enumerate() { for dependency in migration.dependencies() { let dependency_index = lookup .get(&MigrationLookup::from(dependency)) @@ -74,6 +74,12 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> { } } + Ok(graph) + } + + fn toposort(&mut self) -> Result<()> { + let mut graph = Self::generate_graph(self.migrations)?; + let mut sorted_indices = graph.toposort()?; apply_permutation(self.migrations, &mut sorted_indices); diff --git a/cot/src/utils/graph.rs b/cot/src/utils/graph.rs index 741f98400..bf6053efc 100644 --- a/cot/src/utils/graph.rs +++ b/cot/src/utils/graph.rs @@ -40,6 +40,10 @@ impl Graph { self.vertex_edges[from].push(to); } + pub(crate) fn get_edges(&self, from: usize) -> &[usize] { + &self.vertex_edges[from] + } + #[must_use] pub(crate) fn vertex_num(&self) -> usize { self.vertex_edges.len()