From 617398ea84e1511977f33c7ffd55b6932fa8192a Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Mon, 27 Apr 2026 15:43:04 -0700 Subject: [PATCH 01/11] port in some functions from xqtl-protocol --- NAMESPACE | 3 + R/LD.R | 28 ++- R/file_utils.R | 30 ++- R/misc.R | 68 +++++- R/twas.R | 32 +++ R/variant_id.R | 118 +++++++++++ man/classify_variant_type.Rd | 21 ++ man/find_overlapping_regions.Rd | 22 ++ man/regions_overlap.Rd | 20 ++ tests/testthat/test_LD.R | 31 +++ tests/testthat/test_file_utils.R | 77 +++++++ tests/testthat/test_misc.R | 121 +++++++++++ tests/testthat/test_twas_method_fallback.R | 128 ++++++++++++ xqtl-protocol-pecotmr-audit.md | 232 +++++++++++++++++++++ 14 files changed, 922 insertions(+), 9 deletions(-) create mode 100644 man/classify_variant_type.Rd create mode 100644 man/find_overlapping_regions.Rd create mode 100644 man/regions_overlap.Rd create mode 100644 tests/testthat/test_twas_method_fallback.R create mode 100644 xqtl-protocol-pecotmr-audit.md diff --git a/NAMESPACE b/NAMESPACE index 9f408151..33776f91 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -27,6 +27,7 @@ export(bayes_n_weights) export(bayes_r_weights) export(build_top_loci) export(check_ld) +export(classify_variant_type) export(clean_context_names) export(coloc_post_processor) export(coloc_wrapper) @@ -63,6 +64,7 @@ export(filter_relatedness) export(filter_variants_by_ld_reference) export(find_data) export(find_duplicate_variants) +export(find_overlapping_regions) export(fine_mr) export(fit_mash_contrast) export(format_finemapping_output) @@ -169,6 +171,7 @@ export(region_data_to_colocboost_input) export(region_data_to_ind_input) export(region_data_to_rss_input) export(region_to_df) +export(regions_overlap) export(robust_mahalanobis) export(rss_analysis_pipeline) export(rss_basic_qc) diff --git a/R/LD.R b/R/LD.R index 569ccdbb..a037a5de 100644 --- a/R/LD.R +++ b/R/LD.R @@ -272,16 +272,30 @@ load_LD_matrix <- function(LD_meta_file_path, region, extract_coordinates = NULL if (is_geno) { geno_path <- resolve_genotype_path_for_region(source$meta_path, region) - return(load_LD_from_genotype(geno_path, region, - return_genotype = return_genotype, - n_sample = n_sample)) + result <- load_LD_from_genotype(geno_path, region, + return_genotype = return_genotype, + n_sample = n_sample) + } else { + # Pre-computed LD blocks (.cor.xz) + if (return_genotype) { + stop("return_genotype=TRUE requires genotype files, not pre-computed LD matrices.") + } + result <- load_LD_from_blocks(source$meta_path, region, extract_coordinates, n_sample = n_sample) } - # Pre-computed LD blocks (.cor.xz) - if (return_genotype) { - stop("return_genotype=TRUE requires genotype files, not pre-computed LD matrices.") + # Remove any duplicate variant IDs (safety net for boundary overlaps) + if (!is.null(result$LD_variants)) { + dup_idx <- which(duplicated(result$LD_variants)) + if (length(dup_idx) > 0) { + result$LD_variants <- result$LD_variants[-dup_idx] + result$LD_matrix <- result$LD_matrix[-dup_idx, -dup_idx, drop = FALSE] + if (!is.null(result$ref_panel)) { + result$ref_panel <- result$ref_panel[-dup_idx, , drop = FALSE] + } + } } - load_LD_from_blocks(source$meta_path, region, extract_coordinates, n_sample = n_sample) + + result } # ---------- Internal: resolve LD source type ---------- diff --git a/R/file_utils.R b/R/file_utils.R index 24807af8..353a0f26 100644 --- a/R/file_utils.R +++ b/R/file_utils.R @@ -1154,7 +1154,23 @@ load_twas_weights <- function(weight_db_files, conditions = NULL, ## Internal function to load and validate data from RDS files load_and_validate_data <- function(weight_db_files, conditions, variable_name_obj) { all_data <- do.call(c, lapply(unname(weight_db_files), function(rds_file) { - db <- readRDS(rds_file) + # Validate file before loading + if (!file.exists(rds_file)) { + warning(paste0("Skipping weight file '", rds_file, "': file does not exist.")) + return(NULL) + } + if (file.size(rds_file) <= 200) { + warning(paste0("Skipping weight file '", rds_file, "': file too small (", file.size(rds_file), " bytes), likely empty or corrupt.")) + return(NULL) + } + db <- tryCatch(readRDS(rds_file), error = function(e) { + warning(paste0("Skipping weight file '", rds_file, "': failed to read RDS — ", conditionMessage(e))) + return(NULL) + }) + if (!is.list(db) || length(db) == 0) { + warning(paste0("Skipping weight file '", rds_file, "': unexpected structure (not a non-empty list).")) + return(NULL) + } gene <- names(db) # Filter by conditions if specified if (!is.null(conditions)) { @@ -1447,6 +1463,18 @@ load_rss_data <- function(sumstat_path, column_file_path = NULL, n_sample = 0, n n <- NULL } } + # Validate determined sample size + if (!is.null(n)) { + if (length(n) != 1) { + stop("Sample size must be a single value, got length ", length(n), ".") + } + if (is.na(n) || !is.finite(n) || n <= 0) { + stop("Invalid sample size determined: ", n, + ". Sample size must be a positive finite number.", + "\n Hint: check n_sample, n_case, n_control parameters or the ", + "n_sample/n_case/n_control columns in your summary statistics.") + } + } return(list(sumstats = sumstats, n = n, var_y = var_y)) } diff --git a/R/misc.R b/R/misc.R index 9320e21d..ae6db864 100644 --- a/R/misc.R +++ b/R/misc.R @@ -560,6 +560,72 @@ find_data <- function(x, depth_obj, show_path = FALSE, rm_null = TRUE, rm_dup = } +#' Convert region specifications to a GRanges object +#' +#' Accepts region strings ("chr1:100-200", "1_100_200"), character vectors of +#' such strings, or data.frames with chrom/start/end columns. Returns a +#' \code{\link[GenomicRanges]{GRanges}} object. +#' +#' @param regions A region string, character vector, or data.frame with +#' chrom/start/end columns. +#' @return A \code{GRanges} object. +#' @noRd +as_granges <- function(regions) { + if (is.character(regions)) { + df <- region_to_df(regions) + } else if (is.data.frame(regions)) { + if (!all(c("chrom", "start", "end") %in% names(regions))) { + stop("data.frame must have columns: chrom, start, end") + } + df <- regions + } else { + stop("regions must be a character vector or data.frame with chrom/start/end columns") + } + # GRanges expects character seqnames; prefix with "chr" if numeric + seqnames <- as.character(df$chrom) + if (!any(grepl("^chr", seqnames))) { + seqnames <- paste0("chr", seqnames) + } + GenomicRanges::GRanges( + seqnames = seqnames, + ranges = IRanges::IRanges(start = as.integer(df$start), end = as.integer(df$end)) + ) +} + +#' Test whether two genomic regions overlap +#' +#' @param region_a A region string ("chr1:100-200" or "1_100_200") or a +#' single-row data.frame with chrom/start/end columns. +#' @param region_b A region string or single-row data.frame. +#' @return Logical scalar: TRUE if the regions share at least one base pair. +#' @importFrom GenomicRanges GRanges +#' @importFrom IRanges IRanges findOverlaps +#' @export +regions_overlap <- function(region_a, region_b) { + gr_a <- as_granges(region_a) + gr_b <- as_granges(region_b) + length(IRanges::findOverlaps(gr_a, gr_b)) > 0 +} + +#' Find which target regions overlap a query region +#' +#' @param query A single region string or single-row data.frame with +#' chrom/start/end columns. +#' @param targets A character vector of region strings, or a multi-row +#' data.frame with chrom/start/end columns. +#' @return Integer vector of 1-based indices into \code{targets} that overlap +#' the query. Empty integer vector if no overlaps. +#' @importFrom GenomicRanges GRanges +#' @importFrom IRanges IRanges findOverlaps +#' @importFrom S4Vectors subjectHits +#' @export +find_overlapping_regions <- function(query, targets) { + gr_query <- as_granges(query) + gr_targets <- as_granges(targets) + hits <- IRanges::findOverlaps(gr_query, gr_targets) + unique(S4Vectors::subjectHits(hits)) +} + thisFile <- function() { cmdArgs <- commandArgs(trailingOnly = FALSE) needle <- "--file=" @@ -1028,4 +1094,4 @@ detect_outliers_mahalanobis <- function(x, prob = 0.99, row.names = NULL, stringsAsFactors = FALSE ) -} \ No newline at end of file +} diff --git a/R/twas.R b/R/twas.R index ed3e55fb..dad545b3 100644 --- a/R/twas.R +++ b/R/twas.R @@ -289,6 +289,37 @@ build_twas_score_row <- function(twas_rs, weight_db, context, study) { ) } +# Internal: for each gene-context-study group, if the selected method produced +# NA/Inf TWAS z-scores, fall back to the next best method by rsq_cv. +apply_method_fallback <- function(df) { + if (nrow(df) == 0 || !all(c("molecular_id", "context", "gwas_study", "is_selected_method", "twas_z", "rsq_cv", "is_imputable") %in% names(df))) { + return(df) + } + groups <- split(seq_len(nrow(df)), list(df$molecular_id, df$context, df$gwas_study), drop = TRUE) + for (idxs in groups) { + sel_idx <- idxs[df$is_selected_method[idxs]] + if (length(sel_idx) != 1) next + z_val <- df$twas_z[sel_idx] + if (!is.na(z_val) && is.finite(z_val)) next + # Selected method has invalid z — try fallback + other_idxs <- setdiff(idxs, sel_idx) + valid_mask <- !is.na(df$twas_z[other_idxs]) & is.finite(df$twas_z[other_idxs]) + if (any(valid_mask)) { + candidates <- other_idxs[valid_mask] + best <- candidates[which.max(df$rsq_cv[candidates])] + df$is_selected_method[sel_idx] <- FALSE + df$is_selected_method[best] <- TRUE + message(paste0("TWAS method fallback for ", df$molecular_id[sel_idx], + " / ", df$context[sel_idx], " / ", df$gwas_study[sel_idx], + ": ", df$method[sel_idx], " -> ", df$method[best])) + } else { + # No method has valid z — mark group as non-imputable + df$is_imputable[idxs] <- FALSE + } + } + df +} + #' @importFrom stringr str_remove #' @importFrom purrr list_flatten #' @export @@ -593,6 +624,7 @@ twas_pipeline <- function(twas_weights_data, return(list(twas_result = NULL, twas_data = NULL, mr_result = NULL)) } twas_table <- merge(twas_table, twas_results_table, by = c("molecular_id", "context", "method")) + twas_table <- apply_method_fallback(twas_table) twas_table <- twas_table[twas_table$is_imputable, , drop = FALSE] if (output_twas_data & nrow(twas_table) > 0) { twas_data_subset <- format_twas_data(twas_data, twas_table) diff --git a/R/variant_id.R b/R/variant_id.R index 85a556e1..8a39f3a5 100644 --- a/R/variant_id.R +++ b/R/variant_id.R @@ -228,3 +228,121 @@ region_to_df <- function(ld_region_id, colnames = c("chrom", "start", "end")) { colnames(region_of_interest) <- colnames return(region_of_interest) } + +#' Ensure two sets of variant IDs use matching chr prefix convention +#' +#' Detects whether \code{ids_a} and \code{ids_b} have mismatched chr prefixes. +#' If mismatched, normalizes both to canonical format (with "chr" prefix) using +#' \code{\link{normalize_variant_id}}. If already matching, returns inputs +#' unchanged. +#' +#' @param ids_a Character vector of variant IDs. +#' @param ids_b Character vector of variant IDs. +#' @return A list with components \code{ids_a} and \code{ids_b}, both normalized +#' to canonical chr-prefix format if they were mismatched. +#' @noRd +ensure_chr_match <- function(ids_a, ids_b) { + has_chr_a <- any(grepl("^chr", ids_a[!is.na(ids_a)][1:min(5, sum(!is.na(ids_a)))])) + has_chr_b <- any(grepl("^chr", ids_b[!is.na(ids_b)][1:min(5, sum(!is.na(ids_b)))])) + if (has_chr_a == has_chr_b) { + return(list(ids_a = ids_a, ids_b = ids_b)) + } + list( + ids_a = normalize_variant_id(ids_a, chr_prefix = TRUE), + ids_b = normalize_variant_id(ids_b, chr_prefix = TRUE) + ) +} + +#' Convert region specifications to a GRanges object +#' +#' Accepts region strings ("chr1:100-200", "1_100_200"), character vectors of +#' such strings, or data.frames with chrom/start/end columns. Returns a +#' \code{\link[GenomicRanges]{GRanges}} object. +#' +#' @param regions A region string, character vector, or data.frame with +#' chrom/start/end columns. +#' @return A \code{GRanges} object. +#' @noRd +as_granges <- function(regions) { + if (is.character(regions)) { + df <- region_to_df(regions) + } else if (is.data.frame(regions)) { + if (!all(c("chrom", "start", "end") %in% names(regions))) { + stop("data.frame must have columns: chrom, start, end") + } + df <- regions + } else { + stop("regions must be a character vector or data.frame with chrom/start/end columns") + } + # GRanges expects character seqnames; prefix with "chr" if numeric + seqnames <- as.character(df$chrom) + if (!any(grepl("^chr", seqnames))) { + seqnames <- paste0("chr", seqnames) + } + GenomicRanges::GRanges( + seqnames = seqnames, + ranges = IRanges::IRanges(start = as.integer(df$start), end = as.integer(df$end)) + ) +} + +#' Test whether two genomic regions overlap +#' +#' @param region_a A region string ("chr1:100-200" or "1_100_200") or a +#' single-row data.frame with chrom/start/end columns. +#' @param region_b A region string or single-row data.frame. +#' @return Logical scalar: TRUE if the regions share at least one base pair. +#' @importFrom GenomicRanges GRanges +#' @importFrom IRanges IRanges findOverlaps +#' @export +regions_overlap <- function(region_a, region_b) { + gr_a <- as_granges(region_a) + gr_b <- as_granges(region_b) + length(IRanges::findOverlaps(gr_a, gr_b)) > 0 +} + +#' Find which target regions overlap a query region +#' +#' @param query A single region string or single-row data.frame with +#' chrom/start/end columns. +#' @param targets A character vector of region strings, or a multi-row +#' data.frame with chrom/start/end columns. +#' @return Integer vector of 1-based indices into \code{targets} that overlap +#' the query. Empty integer vector if no overlaps. +#' @importFrom GenomicRanges GRanges +#' @importFrom IRanges IRanges findOverlaps +#' @importFrom S4Vectors subjectHits +#' @export +find_overlapping_regions <- function(query, targets) { + gr_query <- as_granges(query) + gr_targets <- as_granges(targets) + hits <- IRanges::findOverlaps(gr_query, gr_targets) + unique(S4Vectors::subjectHits(hits)) +} + +#' Classify variant type from allele strings +#' +#' Determines whether each variant is a SNP, insertion, deletion, or +#' multi-nucleotide polymorphism (MNP) based on the allele lengths. +#' +#' @param ids A character vector of variant IDs in "chr:pos:ref:alt" format, +#' or a data.frame with A2 (ref) and A1 (alt) columns (e.g., from +#' \code{\link{parse_variant_id}}). +#' @return A character vector with one of "SNP", "insertion", "deletion", or +#' "MNP" for each variant. +#' @export +classify_variant_type <- function(ids) { + if (is.character(ids)) { + ids <- parse_variant_id(ids) + } + if (!is.data.frame(ids) || !all(c("A2", "A1") %in% names(ids))) { + stop("Input must be a character vector of variant IDs or a data.frame with A2 and A1 columns.") + } + len_ref <- nchar(ids$A2) + len_alt <- nchar(ids$A1) + type <- character(nrow(ids)) + type[len_ref == 1L & len_alt == 1L & grepl("^[ATCG]$", ids$A2) & grepl("^[ATCG]$", ids$A1)] <- "SNP" + type[len_ref == len_alt & (len_ref > 1L | !grepl("^[ATCG]$", ids$A2) | !grepl("^[ATCG]$", ids$A1)) & type == ""] <- "MNP" + type[len_ref > len_alt] <- "deletion" + type[len_alt > len_ref] <- "insertion" + type +} diff --git a/man/classify_variant_type.Rd b/man/classify_variant_type.Rd new file mode 100644 index 00000000..f1359ecd --- /dev/null +++ b/man/classify_variant_type.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{classify_variant_type} +\alias{classify_variant_type} +\title{Classify variant type from allele strings} +\usage{ +classify_variant_type(ids) +} +\arguments{ +\item{ids}{A character vector of variant IDs in "chr:pos:ref:alt" format, +or a data.frame with A2 (ref) and A1 (alt) columns (e.g., from +\code{\link{parse_variant_id}}).} +} +\value{ +A character vector with one of "SNP", "insertion", "deletion", or + "MNP" for each variant. +} +\description{ +Determines whether each variant is a SNP, insertion, deletion, or +multi-nucleotide polymorphism (MNP) based on the allele lengths. +} diff --git a/man/find_overlapping_regions.Rd b/man/find_overlapping_regions.Rd new file mode 100644 index 00000000..3fe5f2ef --- /dev/null +++ b/man/find_overlapping_regions.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{find_overlapping_regions} +\alias{find_overlapping_regions} +\title{Find which target regions overlap a query region} +\usage{ +find_overlapping_regions(query, targets) +} +\arguments{ +\item{query}{A single region string or single-row data.frame with +chrom/start/end columns.} + +\item{targets}{A character vector of region strings, or a multi-row +data.frame with chrom/start/end columns.} +} +\value{ +Integer vector of 1-based indices into \code{targets} that overlap + the query. Empty integer vector if no overlaps. +} +\description{ +Find which target regions overlap a query region +} diff --git a/man/regions_overlap.Rd b/man/regions_overlap.Rd new file mode 100644 index 00000000..8bebac80 --- /dev/null +++ b/man/regions_overlap.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{regions_overlap} +\alias{regions_overlap} +\title{Test whether two genomic regions overlap} +\usage{ +regions_overlap(region_a, region_b) +} +\arguments{ +\item{region_a}{A region string ("chr1:100-200" or "1_100_200") or a +single-row data.frame with chrom/start/end columns.} + +\item{region_b}{A region string or single-row data.frame.} +} +\value{ +Logical scalar: TRUE if the regions share at least one base pair. +} +\description{ +Test whether two genomic regions overlap +} diff --git a/tests/testthat/test_LD.R b/tests/testthat/test_LD.R index b25634b7..281fcb5b 100644 --- a/tests/testthat/test_LD.R +++ b/tests/testthat/test_LD.R @@ -1544,3 +1544,34 @@ test_that("ld_prune_by_correlation verbose reports no pruning", { "no columns pruned" ) }) + +# ============================================================================= +# load_LD_matrix duplicate variant removal +# ============================================================================= + +test_that("load_LD_matrix dedup removes duplicated variants from result", { + # Simulate what load_LD_matrix does after calling the backend: a result with + # duplicated LD_variants should have duplicates removed. + # We test the dedup logic by constructing a mock result and verifying + # the internal dedup code path via the exported function's contract. + # Since we can't easily call the real function without data, test the dedup + # behavior directly on the result structure. + mat <- matrix(1:16, nrow = 4, ncol = 4) + variants <- c("chr1:100:A:G", "chr1:200:C:T", "chr1:100:A:G", "chr1:300:T:A") + ref <- data.frame(chrom = c(1,1,1,1), pos = c(100,200,100,300), + A2 = c("A","C","A","T"), A1 = c("G","T","G","A")) + + # Apply the same dedup logic used in load_LD_matrix + dup_idx <- which(duplicated(variants)) + expect_equal(dup_idx, 3L) + + variants_clean <- variants[-dup_idx] + mat_clean <- mat[-dup_idx, -dup_idx, drop = FALSE] + ref_clean <- ref[-dup_idx, , drop = FALSE] + + expect_equal(length(variants_clean), 3) + expect_equal(nrow(mat_clean), 3) + expect_equal(ncol(mat_clean), 3) + expect_equal(nrow(ref_clean), 3) + expect_false(any(duplicated(variants_clean))) +}) diff --git a/tests/testthat/test_file_utils.R b/tests/testthat/test_file_utils.R index 0dfd8749..23de3a8b 100644 --- a/tests/testthat/test_file_utils.R +++ b/tests/testthat/test_file_utils.R @@ -2761,3 +2761,80 @@ test_that("load_regional_multivariate_data returns correct fields", { expect_true(is.matrix(result$residual_Y)) expect_equal(nrow(result$X), 100L) }) + +# ============================================================================= +# load_twas_weights file pre-validation +# ============================================================================= + +test_that("load_twas_weights skips non-existent files with warning", { + expect_warning( + tryCatch( + load_twas_weights(c("/nonexistent/path/fake.rds")), + error = function(e) NULL + ), + "does not exist" + ) +}) + +test_that("load_twas_weights skips too-small files with warning", { + tmp <- tempfile(fileext = ".rds") + writeLines("x", tmp) # tiny file, not valid RDS + on.exit(unlink(tmp)) + expect_warning( + tryCatch( + load_twas_weights(tmp), + error = function(e) NULL + ), + "too small" + ) +}) + +test_that("load_twas_weights skips corrupt RDS files with warning", { + tmp <- tempfile(fileext = ".rds") + writeBin(as.raw(rep(0L, 500)), tmp) # 500 bytes of garbage + on.exit(unlink(tmp)) + expect_warning( + tryCatch( + load_twas_weights(tmp), + error = function(e) NULL + ), + "failed to read RDS" + ) +}) + +test_that("load_twas_weights skips non-list RDS with warning", { + tmp <- tempfile(fileext = ".rds") + saveRDS(paste0("x", seq_len(10000)), tmp) # valid RDS but not a list; large enough to pass size check + on.exit(unlink(tmp)) + expect_warning( + tryCatch( + load_twas_weights(tmp), + error = function(e) NULL + ), + "unexpected structure" + ) +}) + +# ============================================================================= +# load_rss_data sample size validation +# ============================================================================= + +test_that("load_rss_data rejects negative sample size", { + skip_if_not_installed("MungeSumstats") + sumstat_file <- file.path(test_path("test_data"), "test_sumstats.tsv.gz") + skip_if_not(file.exists(sumstat_file), "test sumstat file not found") + expect_error( + suppressMessages(load_rss_data(sumstat_file, n_sample = -100)), + "Invalid sample size" + ) +}) + +test_that("load_rss_data rejects Inf sample size", { + skip_if_not_installed("MungeSumstats") + sumstat_file <- file.path(test_path("test_data"), "test_sumstats.tsv.gz") + skip_if_not(file.exists(sumstat_file), "test sumstat file not found") + expect_error( + suppressMessages(load_rss_data(sumstat_file, n_sample = Inf)), + "Invalid sample size" + ) +}) diff --git a/tests/testthat/test_misc.R b/tests/testthat/test_misc.R index 093c392a..43e09021 100644 --- a/tests/testthat/test_misc.R +++ b/tests/testthat/test_misc.R @@ -1926,3 +1926,124 @@ test_that("find_data with numeric indices in list_name path", { result <- find_data(x, c(1, "results", "2", "val")) expect_equal(result, c(10, 20)) }) + +# ============================================================================= +# regions_overlap +# ============================================================================= + +test_that("regions_overlap detects overlapping regions on same chromosome", { + expect_true(regions_overlap("chr1:100-300", "chr1:200-400")) +}) + +test_that("regions_overlap returns FALSE for non-overlapping same-chr regions", { + expect_false(regions_overlap("chr1:100-200", "chr1:300-400")) +}) + +test_that("regions_overlap returns FALSE for different chromosomes", { + expect_false(regions_overlap("chr1:100-300", "chr2:100-300")) +}) + +test_that("regions_overlap detects touching boundaries", { + expect_true(regions_overlap("chr1:100-200", "chr1:200-300")) +}) + +test_that("regions_overlap works with underscore-separated IDs", { + expect_true(regions_overlap("1_100_300", "1_200_400")) + expect_false(regions_overlap("1_100_200", "2_100_200")) +}) + +test_that("regions_overlap works with data.frame input", { + df_a <- data.frame(chrom = 1, start = 100, end = 300) + df_b <- data.frame(chrom = 1, start = 200, end = 400) + expect_true(regions_overlap(df_a, df_b)) +}) + +# ============================================================================= +# find_overlapping_regions +# ============================================================================= + +test_that("find_overlapping_regions returns correct indices", { + query <- "chr1:100-300" + targets <- c("chr1:200-400", "chr2:100-200", "chr1:50-150") + result <- find_overlapping_regions(query, targets) + expect_true(1 %in% result) + expect_true(3 %in% result) + expect_false(2 %in% result) +}) + +test_that("find_overlapping_regions returns empty vector for no matches", { + query <- "chr1:100-200" + targets <- c("chr2:100-200", "chr3:100-200") + result <- find_overlapping_regions(query, targets) + expect_length(result, 0) +}) + +test_that("find_overlapping_regions works with data.frame targets", { + query <- "chr1:100-300" + targets <- data.frame(chrom = c(1, 2, 1), start = c(200, 100, 50), end = c(400, 200, 150)) + result <- find_overlapping_regions(query, targets) + expect_true(1 %in% result) + expect_true(3 %in% result) + expect_false(2 %in% result) +}) + +# ============================================================================= +# classify_variant_type +# ============================================================================= + +test_that("classify_variant_type identifies SNPs", { + expect_equal(classify_variant_type("chr1:100:A:G"), "SNP") +}) + +test_that("classify_variant_type identifies insertions", { + expect_equal(classify_variant_type("chr1:100:A:ATG"), "insertion") +}) + +test_that("classify_variant_type identifies deletions", { + expect_equal(classify_variant_type("chr1:100:ATG:A"), "deletion") +}) + +test_that("classify_variant_type identifies MNPs", { + expect_equal(classify_variant_type("chr1:100:AT:GC"), "MNP") +}) + +test_that("classify_variant_type handles vector input", { + ids <- c("chr1:100:A:G", "chr1:200:ATG:A", "chr1:300:A:ATG", "chr1:400:AT:GC") + result <- classify_variant_type(ids) + expect_equal(result, c("SNP", "deletion", "insertion", "MNP")) +}) + +test_that("classify_variant_type accepts data.frame input", { + df <- data.frame(A2 = c("A", "ATG"), A1 = c("G", "A")) + result <- classify_variant_type(df) + expect_equal(result, c("SNP", "deletion")) +}) + +# ============================================================================= +# ensure_chr_match +# ============================================================================= + +test_that("ensure_chr_match returns unchanged when both have chr prefix", { + ids_a <- c("chr1:100:A:G", "chr1:200:C:T") + ids_b <- c("chr1:150:A:G", "chr1:250:C:T") + result <- pecotmr:::ensure_chr_match(ids_a, ids_b) + expect_equal(result$ids_a, ids_a) + expect_equal(result$ids_b, ids_b) +}) + +test_that("ensure_chr_match normalizes when prefixes mismatch", { + ids_a <- c("chr1:100:A:G", "chr1:200:C:T") + ids_b <- c("1:150:A:G", "1:250:C:T") + result <- pecotmr:::ensure_chr_match(ids_a, ids_b) + expect_true(all(grepl("^chr", result$ids_a))) + expect_true(all(grepl("^chr", result$ids_b))) +}) + +test_that("ensure_chr_match returns unchanged when both lack chr prefix", { + ids_a <- c("1:100:A:G", "1:200:C:T") + ids_b <- c("1:150:A:G", "1:250:C:T") + result <- pecotmr:::ensure_chr_match(ids_a, ids_b) + # Both already match (no prefix), so returned unchanged + expect_equal(result$ids_a, ids_a) + expect_equal(result$ids_b, ids_b) +}) diff --git a/tests/testthat/test_twas_method_fallback.R b/tests/testthat/test_twas_method_fallback.R new file mode 100644 index 00000000..3ad8ac8b --- /dev/null +++ b/tests/testthat/test_twas_method_fallback.R @@ -0,0 +1,128 @@ +context("TWAS method fallback") + +# Helper to build a minimal twas_table for testing apply_method_fallback +make_twas_table <- function(methods, twas_z_values, rsq_values, is_selected, gwas_study = "study1") { + data.frame( + molecular_id = "gene1", + context = "ctx1", + gwas_study = gwas_study, + method = methods, + is_selected_method = is_selected, + is_imputable = TRUE, + rsq_cv = rsq_values, + pval_cv = 0.01, + twas_z = twas_z_values, + twas_pval = ifelse(is.na(twas_z_values) | !is.finite(twas_z_values), NA, 0.05), + type = "eQTL", + chr = 1, + block = "chr1_100_200", + stringsAsFactors = FALSE + ) +} + +# Access the internal function +apply_fallback <- pecotmr:::apply_method_fallback + +test_that("no fallback when selected method has valid z", { + df <- make_twas_table( + methods = c("susie", "enet", "lasso"), + twas_z_values = c(2.5, 1.8, 1.2), + rsq_values = c(0.3, 0.2, 0.1), + is_selected = c(TRUE, FALSE, FALSE) + ) + result <- apply_fallback(df) + expect_equal(result$method[result$is_selected_method], "susie") +}) + +test_that("fallback to next best method when selected has NA z", { + df <- make_twas_table( + methods = c("susie", "enet", "lasso"), + twas_z_values = c(NA, 1.8, 1.2), + rsq_values = c(0.3, 0.2, 0.1), + is_selected = c(TRUE, FALSE, FALSE) + ) + result <- apply_fallback(df) + expect_equal(result$method[result$is_selected_method], "enet") + expect_false(result$is_selected_method[result$method == "susie"]) +}) + +test_that("fallback to next best method when selected has Inf z", { + df <- make_twas_table( + methods = c("susie", "enet", "lasso"), + twas_z_values = c(Inf, 1.8, 1.2), + rsq_values = c(0.3, 0.2, 0.1), + is_selected = c(TRUE, FALSE, FALSE) + ) + result <- apply_fallback(df) + expect_equal(result$method[result$is_selected_method], "enet") +}) + +test_that("fallback picks highest rsq among valid candidates", { + df <- make_twas_table( + methods = c("susie", "enet", "lasso"), + twas_z_values = c(NA, 1.8, 2.0), + rsq_values = c(0.3, 0.1, 0.25), + is_selected = c(TRUE, FALSE, FALSE) + ) + result <- apply_fallback(df) + # lasso has higher rsq (0.25) than enet (0.1) + expect_equal(result$method[result$is_selected_method], "lasso") +}) + +test_that("all methods NA sets is_imputable to FALSE", { + df <- make_twas_table( + methods = c("susie", "enet", "lasso"), + twas_z_values = c(NA, NA, NA), + rsq_values = c(0.3, 0.2, 0.1), + is_selected = c(TRUE, FALSE, FALSE) + ) + result <- apply_fallback(df) + expect_true(all(!result$is_imputable)) +}) + +test_that("fallback is per-study: one study needs fallback, another does not", { + df1 <- make_twas_table( + methods = c("susie", "enet"), + twas_z_values = c(NA, 1.5), + rsq_values = c(0.3, 0.2), + is_selected = c(TRUE, FALSE), + gwas_study = "study1" + ) + df2 <- make_twas_table( + methods = c("susie", "enet"), + twas_z_values = c(2.5, 1.8), + rsq_values = c(0.3, 0.2), + is_selected = c(TRUE, FALSE), + gwas_study = "study2" + ) + df <- rbind(df1, df2) + result <- apply_fallback(df) + # study1: fallback to enet + s1 <- result[result$gwas_study == "study1", ] + expect_equal(s1$method[s1$is_selected_method], "enet") + # study2: no change + s2 <- result[result$gwas_study == "study2", ] + expect_equal(s2$method[s2$is_selected_method], "susie") +}) + +test_that("fallback handles empty data frame", { + df <- data.frame( + molecular_id = character(), context = character(), gwas_study = character(), + method = character(), is_selected_method = logical(), is_imputable = logical(), + rsq_cv = numeric(), pval_cv = numeric(), twas_z = numeric(), twas_pval = numeric(), + stringsAsFactors = FALSE + ) + result <- apply_fallback(df) + expect_equal(nrow(result), 0) +}) + +test_that("fallback handles -Inf z", { + df <- make_twas_table( + methods = c("susie", "enet"), + twas_z_values = c(-Inf, 1.5), + rsq_values = c(0.3, 0.2), + is_selected = c(TRUE, FALSE) + ) + result <- apply_fallback(df) + expect_equal(result$method[result$is_selected_method], "enet") +}) diff --git a/xqtl-protocol-pecotmr-audit.md b/xqtl-protocol-pecotmr-audit.md new file mode 100644 index 00000000..f98c84e2 --- /dev/null +++ b/xqtl-protocol-pecotmr-audit.md @@ -0,0 +1,232 @@ +# xqtl-protocol pecotmr Usage Audit + +## 1. API Mismatches (Calls Expected to Fail) + +### A. `load_quantile_twas_weights` does not exist in pecotmr + +**File:** `code/pecotmr_integration/twas_ctwas.ipynb` (quantile_twas cell) + +```r +twas_weights_results[[gene_db]] <- load_quantile_twas_weights( + weight_db_files = weight_dbs, tau_values = tau_values, + between_cluster = 0.8, num_intervals = 3) +``` + +This function is not defined anywhere in pecotmr. It was likely planned as part of the quantile TWAS feature but never implemented (or was removed). This cell will fail outright. + +### B. `twas_pipeline` called with nonexistent `quantile_twas` parameter + +**File:** `code/pecotmr_integration/twas_ctwas.ipynb` (quantile_twas cell) + +```r +twas_results_db <- twas_pipeline(..., quantile_twas = TRUE, ...) +``` + +`twas_pipeline()` at `R/twas.R:290` has no `quantile_twas` parameter and no `...` in its signature. This will error with "unused argument." + +### C. `rss_analysis_pipeline` called with renamed parameter `stochastic_ld_sample` + +**File:** `code/mnm_analysis/mnm_methods/rss_analysis.ipynb` (univariate_rss cell) + +```r +rss_analysis_pipeline(..., stochastic_ld_sample = ${stochastic_ld_sample}, ...) +``` + +This parameter was renamed to `sketch_samples` in the current pecotmr API (`R/univariate_pipeline.R:214`). The function has no `...`, so this will error with "unused argument." + +### D. `load_multitrait_tensorqtl_sumstat` called with wrong parameter name + +**File:** `code/multivariate_genome/MASH/mash_preprocessing.ipynb` (random_null_tensorqtl_1 cell) + +```r +pecotmr::load_multitrait_tensorqtl_sumstat( + phenotype_path = phenotype_path, ..., na_remove = T/F) +``` + +Two problems: +- First parameter is named `sumstats_paths` in pecotmr (`R/mash_wrapper.R:141`), not `phenotype_path` +- The parameter `na_remove` was renamed to `nan_remove` (`R/mash_wrapper.R:143`) + +Both will cause "unused argument" errors since there's no `...`. + +### E. `mash_ran_null_sample` - typo and removed parameters + +**File:** `code/multivariate_genome/MASH/mash_preprocessing.ipynb` (random_null_tensorqtl_1 cell) + +```r +pecotmr::mash_ran_null_sample(dat, n_random, n_null, + expected_ncondition, exclude_condition, z_only = TRUE, seed = ...) +``` + +Three problems: +- Function name is `mash_rand_null_sample` (with a "d") -- `R/mash_wrapper.R:568` +- `expected_ncondition` parameter no longer exists +- `z_only` parameter no longer exists + +The current signature is `mash_rand_null_sample(dat, n_random, n_null, exclude_condition, seed = NULL)`. + +### F. `get_ctwas_meta_data` is deprecated + +**File:** `code/pecotmr_integration/twas_ctwas.ipynb` (ctwas_1 and ctwas_3 cells) + +Used extensively. Still works but emits deprecation warnings and will eventually be removed. The replacement is `ld_loader()` per `R/ctwas_wrapper.R:58`. + +--- + +## 2. Safety/Sanity Checks That Could Move to pecotmr + +### A. Weight file pre-validation + +**File:** twas_ctwas.ipynb (twas cell, quantile_twas cell) + +Before calling `load_twas_weights()`, xqtl-protocol: +- Checks `file.size(file) > 200` (non-trivial file) +- Wraps `readRDS(file)` in `tryCatch` to filter corrupt files +- Validates nested structure (`twas_variant_names` key exists) +- Filters out NULL/empty results + +**Recommendation:** `load_twas_weights()` should do this validation internally -- skip files that are too small, corrupt, or structurally invalid, rather than requiring every caller to implement the same filter. + +### B. NA/Inf z-score filtering after TWAS + +**File:** twas_ctwas.ipynb (ctwas_1 cell) + +```r +z_gene[[study]] <- z_gene[[study]][ + !is.na(z_gene[[study]]$z) & !is.infinite(z_gene[[study]]$z) & + z_gene[[study]]$id %in% names(weight_list[[study]]),] +``` + +**Recommendation:** `twas_pipeline()` or the TWAS z-score computation itself should guarantee clean output. Downstream consumers shouldn't need to re-filter. + +### C. Duplicate LD variant removal + +**File:** twas_ctwas.ipynb (ctwas_1 cell) + +```r +dup_idx <- which(duplicated(LD_list$LD_variants)) +if (length(dup_idx) >= 1) LD_list$LD_matrix <- LD_list$LD_matrix[-dup_idx, -dup_idx] +``` + +**Recommendation:** `load_LD_matrix()` should handle this internally. Duplicate variants in the LD matrix are a data integrity issue that the loader should resolve before returning. + +### D. GWAS sample size validation + +**File:** twas_ctwas.ipynb (ctwas_1 cell) + +```r +if(length(z_snp[['sample_size']][[study]]!=1) | z_snp[['sample_size']][[study]] <= 0) { + stop("Please check sample size provided for ", study, " at --gwas_meta_data. ") +} +``` + +**Recommendation:** Could be validated inside `harmonize_gwas()` or the GWAS metadata loading step in pecotmr. + +### E. chr prefix normalization + +Scattered across multiple locations: +- ctwas_1: `ifelse(grepl("^chr", snp_map$id), snp_map$id, paste0("chr", snp_map$id))` +- mnm_postprocessing: `if(any(grepl("chr", qtl_all_var))) add_chr_prefix(gwas_all_var) else gsub("chr", "", gwas_all_var)` +- mash_preprocessing: retry with `gsub("chr", "", region)` on failure + +**Recommendation:** pecotmr already has `normalize_variant_id()` and internal `strip_chr_prefix()`, but variant ID harmonization at the "chr" level should be consistently handled in all loading functions rather than requiring callers to do it. + +### F. Genomic region overlap detection + +**File:** SuSiE_enloc.ipynb (susie_coloc cell) + +Manual region parsing and overlap checking: +```r +split_region <- unlist(strsplit(region, "_")) +block_chrom <- as.numeric(split_region[1] %>% gsub("chr","",.)) +block_start <- ... +if (gene_region$chrom == block_chrom && + (gene_region$start <= block_end | gene_region$end >= block_start)) +``` + +**Recommendation:** pecotmr has `parse_region()` and `region_to_df()` but lacks a simple `regions_overlap(a, b)` utility. This pattern is repeated enough to justify one. + +--- + +## 3. Generalizable Pipeline Logic Worth Moving to pecotmr + +### A. TWAS method selection with fallback (high value) + +**File:** twas_ctwas.ipynb -- `update_twas_method()` function + +This ~40-line function handles a real problem: the "best" TWAS method (by cross-validation) sometimes produces NA/Inf results for a specific GWAS. It falls back to the next-best method by rsq. This logic is not xQTL-specific -- it applies to any TWAS analysis. + +**What it does:** For each gene-context-GWAS group, if the selected method yielded invalid z/p-values, pick the best alternative method that has valid results and meets the rsq threshold. + +### B. TWAS-to-cTWAS region assembly orchestration (high value) + +**File:** twas_ctwas.ipynb (ctwas_1 cell) + +The entire workflow of: +1. Loading per-region TWAS results +2. Trimming variants via `trim_ctwas_variants()` +3. Getting chromosome-wide LD variant info +4. Harmonizing GWAS via `harmonize_gwas()` +5. Re-computing TWAS z-scores when variants are trimmed (calling `twas_analysis()` with fresh LD) +6. Assembling into cTWAS region data via `assemble_region_data()` + +This is ~200 lines of orchestration that any TWAS-to-cTWAS pipeline would need. It currently depends on a few ctwas package functions but the overall flow is generalizable. + +### C. cTWAS fine-mapping with LD diagnosis and recovery (high value) + +**File:** twas_ctwas.ipynb (ctwas_3 cell) + +The workflow of: +1. Screen regions -> fine-map -> diagnose LD mismatch -> identify problematic genes -> re-fine-map without LD -> merge boundary regions + +This is a robust, production-tested recipe for dealing with real-world LD mismatches. It's not specific to xQTL data at all. + +### D. GWAS metadata loading and per-study LD caching (medium value) + +**Files:** rss_analysis.ipynb, twas_ctwas.ipynb + +The pattern of: +- Reading a GWAS metadata TSV with study_id, chrom, file_path, sample_size columns +- Mapping studies to per-region file paths +- Caching LD matrices by study to avoid re-loading + +This is boilerplate that every multi-study RSS analysis repeats. The Python `load_regional_rss_data()` function in rss_analysis.ipynb does something similar -- it could inform an R equivalent. + +### E. MASH data batching and merging (medium value) + +**File:** mash_preprocessing.ipynb (susie_to_mash_1, susie_to_mash_2 cells) + +The pipeline of: +1. Processing regions in chunks (`per_chunk = 100`) +2. Extracting strong/random/null z-score matrices per region +3. Renaming rownames with region IDs for uniqueness +4. Merging across regions with `merge_mash_data()` +5. Filtering invalid entries with `filter_invalid_summary_stat()` +6. Computing `ZtZ = t(Z) %*% Z / n` + +The batching/merging/filtering logic is generalizable. pecotmr already has `merge_mash_data()` and `filter_invalid_summary_stat()`, but the end-to-end orchestration (batch -> merge -> filter -> ZtZ) could be a single pipeline function. + +### F. QTL-GWAS overlap analysis pipeline (medium value, Python) + +**File:** mnm_postprocessing.ipynb (overlap_qtl_gwas cells, Python) + +Loads QTL and GWAS metadata, groups by chromosome, checks region overlap, intersects variant lists with chr-prefix harmonization. This is applicable to any pairwise colocalization setup, not just xQTL. + +### G. Variant feature engineering (lower value, Python) + +**File:** gems_pipeline.py + +Parsing "chr:pos:ref:alt" variant IDs into structured fields and classifying as SNP/indel/insertion/deletion. Simple but broadly useful. pecotmr has `parse_variant_id()` already but the SNP/indel classification could be an addition. + +--- + +## Summary + +| Category | Count | Severity | +|----------|-------|----------| +| Calls that will fail | 5 (A-E) | Blocking | +| Deprecated but still working | 1 (F) | Warning | +| Sanity checks to absorb | 6 (A-F) | Robustness | +| Generalizable pipeline logic | 7 (A-G) | Feature opportunities | + +The most impactful items are the **quantile TWAS breakage** (the entire quantile_twas workflow is dead -- both `load_quantile_twas_weights` and the `quantile_twas` param to `twas_pipeline` don't exist), the **`stochastic_ld_sample` rename**, and the **MASH parameter name changes**. For generalizable logic, the **TWAS method fallback** and **cTWAS assembly/diagnosis pipelines** are the highest-value candidates to move into pecotmr. From aed51e62d7a1994eca12b74d33e7d440e1598324 Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Tue, 12 May 2026 12:15:12 -0700 Subject: [PATCH 02/11] twas refactor part 1 --- NAMESPACE | 1 + R/LD.R | 61 ++++++ R/file_utils.R | 16 +- R/mash_wrapper.R | 2 + R/mr.R | 8 +- R/twas.R | 325 +++++++++++++----------------- man/load_ld_sketch.Rd | 33 +++ man/twas_analysis.Rd | 8 +- man/twas_z.Rd | 2 +- tests/testthat/test_misc.R | 20 +- tests/testthat/test_twas.R | 172 ++++++++++++---- tests/testthat/test_twas_sketch.R | 284 ++++++++++++++++++++++++++ twas-ld-sketch-plan.md | 142 +++++++++++++ twas-pipeline-analysis.md | 78 +++++++ 14 files changed, 912 insertions(+), 240 deletions(-) create mode 100644 man/load_ld_sketch.Rd create mode 100644 tests/testthat/test_twas_sketch.R create mode 100644 twas-ld-sketch-plan.md create mode 100644 twas-pipeline-analysis.md diff --git a/NAMESPACE b/NAMESPACE index 33776f91..45f802ef 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -117,6 +117,7 @@ export(ld_mismatch_qc) export(ld_prune_by_correlation) export(load_LD_matrix) export(load_genotype_region) +export(load_ld_sketch) export(load_multicontext_sumstats) export(load_multitask_regional_data) export(load_multitrait_R_sumstat) diff --git a/R/LD.R b/R/LD.R index a037a5de..f5df6dd2 100644 --- a/R/LD.R +++ b/R/LD.R @@ -487,6 +487,67 @@ load_LD_from_genotype <- function(genotype_path, region, ) } +# ---------- LD sketch: genotype loading ---------- + +#' HWE-based standardization of a genotype matrix +#' +#' Centers by 2*allele_freq, scales by sqrt(2*allele_freq*(1-allele_freq)). +#' Assumes monomorphic variants have already been removed. +#' +#' @param X Numeric genotype matrix (n x p). +#' @param allele_freq Numeric vector of allele frequencies (length p). +#' @return Standardized matrix (n x p). +#' @noRd +standardize_genotype_hwe <- function(X, allele_freq) { + X_std <- sweep(X, 2, 2 * allele_freq) + sweep(X_std, 2, sqrt(2 * allele_freq * (1 - allele_freq)), "/") +} + +#' Load LD sketch genotypes for a region +#' +#' Loads genotype data for a region via \code{load_LD_matrix(return_genotype=TRUE)} +#' and removes monomorphic variants. Returns the raw genotype matrix and metadata, +#' which callers can use to derive either a correlation matrix R (for summary-based +#' weight training or fine-mapping) or an SVD (for TWAS z-score computation). +#' +#' @param ld_meta_file_path Path to the LD metadata TSV file. +#' @param region Region of interest: "chr:start-end" string or data.frame with chrom/start/end. +#' @param n_sample Optional original panel sample size for computing variance +#' (= 2*p*(1-p)*n/(n-1)). Passed through to \code{load_LD_matrix()}. +#' +#' @return A list with: +#' \describe{ +#' \item{X}{Raw genotype matrix (n_sketch x p) after removing monomorphic variants.} +#' \item{n_sketch}{Number of rows (samples) in the sketch genotype matrix.} +#' \item{ref_panel}{Data.frame with variant metadata (chrom, pos, A2, A1, variant_id, +#' allele_freq, and optionally variance, n_nomiss).} +#' \item{variant_ids}{Character vector of variant IDs (canonical format) after +#' removing monomorphic variants.} +#' } +#' @export +load_ld_sketch <- function(ld_meta_file_path, region, n_sample = NULL) { + result <- load_LD_matrix(ld_meta_file_path, region, return_genotype = TRUE, n_sample = n_sample) + X <- result$LD_matrix + variant_ids <- result$LD_variants + ref_panel <- result$ref_panel + + # Remove monomorphic variants (zero variance under HWE) + p <- ref_panel$allele_freq + polymorphic <- p > 0 & p < 1 + if (!all(polymorphic)) { + X <- X[, polymorphic, drop = FALSE] + variant_ids <- variant_ids[polymorphic] + ref_panel <- ref_panel[polymorphic, , drop = FALSE] + } + + list( + X = X, + n_sketch = nrow(X), + ref_panel = ref_panel, + variant_ids = variant_ids + ) +} + # ---------- Internal: load LD from pre-computed blocks ---------- #' Load pre-computed LD from block-based metadata files. diff --git a/R/file_utils.R b/R/file_utils.R index 353a0f26..748f1b1d 100644 --- a/R/file_utils.R +++ b/R/file_utils.R @@ -1230,12 +1230,18 @@ load_twas_weights <- function(weight_db_files, conditions = NULL, multi_variants <- unique(find_data(combined_all_data$mnm_rs, c(2, variable_name_obj))) for (context in overl_contexts) { uni_variants <- get_nested_element(combined_all_data[[gene]][[context]], variable_name_obj) - multi_weights <- setNames(rep(0, length(uni_variants)), uni_variants) + # Harmonize chr prefix convention between multivariate and univariate variant IDs + chr_matched <- ensure_chr_match(multi_variants, uni_variants) + multi_variants_h <- chr_matched$ids_a + uni_variants_h <- chr_matched$ids_b + multi_weights <- setNames(rep(0, length(uni_variants_h)), uni_variants_h) multi_weights <- lapply(combined_all_data[["mnm_rs"]][[context]]$twas_weights, function(weight_list) { - aligned_weights <- setNames(rep(0, length(uni_variants)), uni_variants) - method_weight_variants <- names(unlist(weight_list)) - overlap_variants <- method_weight_variants[method_weight_variants %in% multi_variants[multi_variants %in% uni_variants]] # overlapping variants from method, multivariate, univariate - aligned_weights[overlap_variants] <- unlist(weight_list)[overlap_variants] + aligned_weights <- setNames(rep(0, length(uni_variants_h)), uni_variants_h) + weight_vals <- unlist(weight_list) + names(weight_vals) <- ensure_chr_match(names(weight_vals), uni_variants_h)$ids_a + method_weight_variants <- names(weight_vals) + overlap_variants <- method_weight_variants[method_weight_variants %in% multi_variants_h[multi_variants_h %in% uni_variants_h]] + aligned_weights[overlap_variants] <- weight_vals[overlap_variants] aligned_weights <- as.matrix(aligned_weights) }) combined_all_data[[gene]][[context]]$twas_weights <- c(combined_all_data[[gene]][[context]]$twas_weights, multi_weights) diff --git a/R/mash_wrapper.R b/R/mash_wrapper.R index 1605e658..88a69d6e 100644 --- a/R/mash_wrapper.R +++ b/R/mash_wrapper.R @@ -763,6 +763,8 @@ merge_sumstats_matrices <- function(matrix_list, value_column, ref_panel = NULL, cohort_df <- cbind(cohort_variants_df, value = df2[, value_column, drop = FALSE]) # Step 4: Merge with LD reference and filter + # Normalize ld_meta_file chrom to integer to match parse_variant_id output + ld_meta_file$chrom <- as.integer(strip_chr_prefix(as.character(ld_meta_file$chrom))) variants_ld_block_match <- merge(cohort_df, ld_meta_file, by = "chrom", allow.cartesian = TRUE) %>% filter(pos > start & pos < end) %>% select(-path) diff --git a/R/mr.R b/R/mr.R index 33bcc29d..6d3a79a0 100644 --- a/R/mr.R +++ b/R/mr.R @@ -97,8 +97,12 @@ mr_format <- function(susie_result, condition, gwas_sumstats_db, coverage = NULL ) susie_cs_result_formatted <- susie_cs_result_formatted$target_data_qced[, c("gene_name", "variant_id", "bhat_x", "sbhat_x", "cs", "pip")] } - # Normalize variant IDs to canonical format for matching - gwas_sumstats_db_extracted$variant_id <- normalize_variant_id(gwas_sumstats_db_extracted$variant_id) + # Ensure consistent chr prefix convention before intersecting + if (!is.null(susie_cs_result_formatted$variant_id) && !is.null(gwas_sumstats_db_extracted$variant_id)) { + chr_matched <- ensure_chr_match(susie_cs_result_formatted$variant_id, gwas_sumstats_db_extracted$variant_id) + susie_cs_result_formatted$variant_id <- chr_matched$ids_a + gwas_sumstats_db_extracted$variant_id <- chr_matched$ids_b + } common_variants <- intersect(susie_cs_result_formatted$variant_id, gwas_sumstats_db_extracted$variant_id) if (length(common_variants) == 0) return(.create_null_mr_df(gene_name, mr_format_spec)) diff --git a/R/twas.R b/R/twas.R index dad545b3..8dc8ad9d 100644 --- a/R/twas.R +++ b/R/twas.R @@ -29,67 +29,6 @@ #' @export harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, ld_reference_sample_size, column_file_path = NULL, comment_string = "#") { - # Function to group contexts based on start and end positions - group_contexts_by_region <- function(twas_weights_data, molecular_id, chrom, tolerance = 5000) { - region_info_df <- do.call(rbind, lapply(names(twas_weights_data$weights), function(context) { - wgt_range <- parse_variant_id(rownames(twas_weights_data[["weights"]][[context]]))$pos - data.frame(context = context, start = min(wgt_range), end = max(wgt_range)) - })) - if (nrow(region_info_df) == 1) { - # Handle case with only one context - single_context_group <- list( - context_group_1 = list( - contexts = region_info_df$context, - query_region = paste0(chrom, ":", region_info_df$start, "-", region_info_df$end), - all_variants = unique(rownames(twas_weights_data[["weights"]][[region_info_df$context]])) - ) - ) - return(single_context_group) - } - # Calculate distance matrix and perform hierarchical clustering - clusters <- cutree(hclust(dist(region_info_df[, c("start", "end")])), h = tolerance) - # Group contexts and determine query regions - region_groups <- split(region_info_df, clusters) %>% - lapply(function(group) { - list(contexts = group$context, query_region = paste0( - chrom, ":", min(group$start), - "-", max(group$end) - )) - }) - # Create IRanges objects and merge overlapping intervals - intervals <- IRanges(start = unlist(lapply(region_groups, function(context_group) { - as.numeric(gsub( - "^.*:\\s*|\\s*-.*$", "", - context_group$query_region - )) - })), end = unlist(lapply( - region_groups, - function(context_group) as.numeric(sub("^.*?\\-", "", context_group$query_region)) - ))) - reduced_intervals <- reduce(intervals) - - # Find which original groups are merged, and update region_groups lists - overlaps <- findOverlaps(intervals, reduced_intervals) - # Create merged groups based on overlap mapping - merged_groups <- lapply(seq_along(reduced_intervals), function(i) { - context_indices <- queryHits(overlaps)[subjectHits(overlaps) == i] - merged_contexts <- unlist(lapply(context_indices, function(idx) region_groups[[idx]]$contexts)) - list(contexts = merged_contexts, query_region = paste0(chrom, ":", start(reduced_intervals[i]), "-", end(reduced_intervals[i]))) - }) - names(merged_groups) <- paste0("context_group_", seq_along(merged_groups)) - # add variant names for coordinate extraction - for (group in names(merged_groups)) { - contexts <- merged_groups[[group]]$contexts - merged_groups[[group]]$all_variants <- unique(do.call(c, lapply( - contexts, - function(context) { - rownames(twas_weights_data[["weights"]][[context]]) - } - ))) - } - return(merged_groups) - } - # Step 1: load TWAS weights data molecular_ids <- names(twas_weights_data) chrom <- as.integer(parse_number(gsub(":.*$", "", rownames(twas_weights_data[[1]]$weights[[1]])[1]))) @@ -98,130 +37,114 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, names(gwas_files) <- unique(gwas_meta_df$study_id[gwas_meta_df$chrom == chrom]) results <- list() - # Step 2: Load LD for all events/genes by clustered context region - for (molecular_id in molecular_ids) { - twas_weights_data[[molecular_id]][["variant_names"]] <- lapply(twas_weights_data[[molecular_id]]$weights, function(x) rownames(x)) - } - region_variants <- variant_id_to_df(unique(do.call(c, find_data(twas_weights_data, c(2, "variant_names"))))) - region_of_interest <- data.frame(chrom = chrom, start = min(region_variants$pos), end = max(region_variants$pos)) - LD_list <- load_LD_matrix(ld_meta_file_path, region_of_interest, region_variants, - n_sample = ld_reference_sample_size) - # Convert LDData S4 object to legacy list format if needed - if (is(LD_list, "LDData")) { - LD_list <- ld_data_to_list(LD_list) - } - # remove duplicate variants - dup_idx <- which(duplicated(LD_list$LD_variants)) - if (length(dup_idx) >= 1) { - LD_list$LD_variants <- LD_list$LD_variants[-dup_idx] - LD_list$LD_matrix <- LD_list$LD_matrix[-dup_idx, -dup_idx] - LD_list$ref_panel <- LD_list$ref_panel[-dup_idx, ] - } - - # loop through genes/events: + # Per-gene loop: each gene loads its own LD sketch independently for (molecular_id in molecular_ids) { mol_data <- twas_weights_data[[molecular_id]] mol_res <- list(chrom = chrom, variant_names = list()) mol_res[["data_type"]] <- if ("data_type" %in% names(mol_data)) mol_data$data_type - - # group contexts based on the variant position - context_clusters <- group_contexts_by_region(mol_data, molecular_id, chrom, tolerance = 5000) - - # loop through contexts: grouping contexts can be useful during TWAS data harmonization to stratify variants for LD loading - for (context_group in names(context_clusters)) { - cluster <- context_clusters[[context_group]] - contexts <- cluster$contexts - query_region <- cluster$query_region - region_of_interest <- region_to_df(query_region) - all_variants <- variant_id_to_df(cluster$all_variants) - - # Step 3: load GWAS data for clustered context groups - for (study in names(gwas_files)) { - gwas_file <- gwas_files[study] - gwas_data_sumstats <- harmonize_gwas(gwas_file, query_region=query_region, - LD_list$LD_variants, c("beta", "z"), - match_min_prop = 0, column_file_path = column_file_path, comment_string = comment_string) - if(is.null(gwas_data_sumstats)) next - # loop through context within the context group: - for (context in contexts) { - weights_matrix <- mol_data[["weights"]][[context]] - - # Step 4: harmonize weights, flip allele - weights_matrix <- cbind(variant_id_to_df(rownames(weights_matrix)), weights_matrix) - weights_matrix_qced <- match_ref_panel(weights_matrix, LD_list$LD_variants, - colnames(weights_matrix)[!colnames(weights_matrix) %in% c("chrom", "pos", "A2", "A1")], - match_min_prop = 0 + contexts <- names(mol_data$weights) + + # Step 2: Build gene window from all contexts' variant positions + all_weight_variants <- unique(do.call(c, lapply(contexts, function(ctx) rownames(mol_data$weights[[ctx]])))) + variant_positions <- parse_variant_id(all_weight_variants)$pos + gene_region <- paste0(chrom, ":", min(variant_positions), "-", max(variant_positions)) + + # Step 3: Load LD sketch for this gene's window and compute SVD + sketch <- load_ld_sketch(ld_meta_file_path, gene_region, n_sample = ld_reference_sample_size) + X_std <- standardize_genotype_hwe(sketch$X, sketch$ref_panel$allele_freq) + svd_result <- safe_svd(X_std, tol = 0) + + # Step 4: Harmonize GWAS and weights against sketch variants + for (study in names(gwas_files)) { + gwas_file <- gwas_files[study] + gwas_data_sumstats <- harmonize_gwas(gwas_file, query_region = gene_region, + sketch$variant_ids, c("beta", "z"), + match_min_prop = 0, column_file_path = column_file_path, + comment_string = comment_string) + if (is.null(gwas_data_sumstats)) next + + for (context in contexts) { + weights_matrix <- mol_data[["weights"]][[context]] + + # Harmonize weights against sketch reference + weights_matrix <- cbind(variant_id_to_df(rownames(weights_matrix)), weights_matrix) + weights_matrix_qced <- match_ref_panel(weights_matrix, sketch$variant_ids, + colnames(weights_matrix)[!colnames(weights_matrix) %in% c("chrom", "pos", "A2", "A1")], + match_min_prop = 0 + ) + qced_data <- weights_matrix_qced$target_data_qced + weights_matrix_subset <- as.matrix(qced_data[, !colnames(qced_data) %in% c( + "chrom", "pos", "A2", "A1", "variant_id", "variants_id_original" + ), drop = FALSE]) + rownames(weights_matrix_subset) <- qced_data$variant_id + + # Ensure consistent chr prefix convention before intersecting + chr_matched <- ensure_chr_match(gwas_data_sumstats$variant_id, sketch$variant_ids) + gwas_data_sumstats$variant_id <- chr_matched$ids_a + rownames(weights_matrix_subset) <- ensure_chr_match(rownames(weights_matrix_subset), gwas_data_sumstats$variant_id)$ids_a + weights_matrix_subset <- weights_matrix_subset[rownames(weights_matrix_subset) %in% gwas_data_sumstats$variant_id, , drop = FALSE] + if (nrow(weights_matrix_subset) == 0) next + postqc_weight_variants <- rownames(weights_matrix_subset) + + # Step 5: adjust SuSiE weights based on available variants + if ("susie_weights" %in% colnames(mol_data[["weights"]][[context]])) { + adjusted_susie_weights <- adjust_susie_weights(mol_data, + keep_variants = postqc_weight_variants, run_allele_qc = TRUE, + variable_name_obj = c("variant_names", context), + susie_obj = c("susie_results", context), + twas_weights_table = c("weights", context), postqc_weight_variants, match_min_prop = 0 ) - qced_data <- weights_matrix_qced$target_data_qced - weights_matrix_subset <- as.matrix(qced_data[, !colnames(qced_data) %in% c( - "chrom", "pos", "A2", "A1", "variant_id", "variants_id_original" - ), drop = FALSE]) - rownames(weights_matrix_subset) <- qced_data$variant_id - - # intersect post-qc gwas and post-qc weight variants (all now in canonical chr-prefix format) - gwas_LD_variants <- intersect(gwas_data_sumstats$variant_id, LD_list$LD_variants) - weights_matrix_subset <- weights_matrix_subset[rownames(weights_matrix_subset) %in% gwas_data_sumstats$variant_id, , drop = FALSE] - if (nrow(weights_matrix_subset) == 0) next - postqc_weight_variants <- rownames(weights_matrix_subset) - - # Step 5: adjust SuSiE weights based on available variants - if ("susie_weights" %in% colnames(mol_data[["weights"]][[context]])) { - adjusted_susie_weights <- adjust_susie_weights(mol_data, - keep_variants = postqc_weight_variants, run_allele_qc = TRUE, - variable_name_obj = c("variant_names", context), - susie_obj = c("susie_results", context), - twas_weights_table = c("weights", context), postqc_weight_variants, match_min_prop = 0 - ) - weights_matrix_subset <- cbind( - susie_weights = setNames(adjusted_susie_weights$adjusted_susie_weights, adjusted_susie_weights$remained_variants_ids), - weights_matrix_subset[adjusted_susie_weights$remained_variants_ids, !colnames(weights_matrix_subset) %in% "susie_weights", drop = FALSE] - ) - susie_intermediate <- mol_data$susie_results[[context]][c("pip", "cs_variants", "cs_purity")] - names(susie_intermediate[["pip"]]) <- rownames(weights_matrix) # original variants that is not qced yet - pip <- susie_intermediate[["pip"]] - pip_qced <- match_ref_panel(cbind(parse_variant_id(names(pip)), pip), LD_list$LD_variants, "pip", match_min_prop = 0) - susie_intermediate[["pip"]] <- abs(pip_qced$target_data_qced$pip) - names(susie_intermediate[["pip"]]) <- pip_qced$target_data_qced$variant_id - susie_intermediate[["cs_variants"]] <- lapply(susie_intermediate[["cs_variants"]], function(x) { - variant_qc <- match_ref_panel(x, LD_list$LD_variants, match_min_prop = 0) - variant_qc$target_data_qced$variant_id[variant_qc$target_data_qced$variant_id %in% postqc_weight_variants] - }) - mol_res[["susie_weights_intermediate_qced"]][[context]] <- susie_intermediate - } - rm(weights_matrix) # context specific original weight matrix - gc() - - if (nrow(weights_matrix_subset) == 0) { - warning("weights_matrix_subset is empty. Skipping this context.") - next - } - mol_res[["variant_names"]][[context]][[study]] <- rownames(weights_matrix_subset) + weights_matrix_subset <- cbind( + susie_weights = setNames(adjusted_susie_weights$adjusted_susie_weights, adjusted_susie_weights$remained_variants_ids), + weights_matrix_subset[adjusted_susie_weights$remained_variants_ids, !colnames(weights_matrix_subset) %in% "susie_weights", drop = FALSE] + ) + susie_intermediate <- mol_data$susie_results[[context]][c("pip", "cs_variants", "cs_purity")] + names(susie_intermediate[["pip"]]) <- rownames(weights_matrix) # original variants not yet qced + pip <- susie_intermediate[["pip"]] + pip_qced <- match_ref_panel(cbind(parse_variant_id(names(pip)), pip), sketch$variant_ids, "pip", match_min_prop = 0) + susie_intermediate[["pip"]] <- abs(pip_qced$target_data_qced$pip) + names(susie_intermediate[["pip"]]) <- pip_qced$target_data_qced$variant_id + susie_intermediate[["cs_variants"]] <- lapply(susie_intermediate[["cs_variants"]], function(x) { + variant_qc <- match_ref_panel(x, sketch$variant_ids, match_min_prop = 0) + variant_qc$target_data_qced$variant_id[variant_qc$target_data_qced$variant_id %in% postqc_weight_variants] + }) + mol_res[["susie_weights_intermediate_qced"]][[context]] <- susie_intermediate + } + rm(weights_matrix) - # Step 6: scale weights by variance (from ref_panel, populated by load_LD_matrix) - variance <- LD_list$ref_panel$variance[match(rownames(weights_matrix_subset), LD_list$ref_panel$variant_id)] - mol_res[["weights_qced"]][[context]][[study]] <- list(scaled_weights = weights_matrix_subset * sqrt(variance), weights = weights_matrix_subset) + if (nrow(weights_matrix_subset) == 0) { + warning("weights_matrix_subset is empty. Skipping this context.") + next } - # Combine gwas sumstat across different context for a single context group (all variant_ids now in canonical format) - gwas_data_sumstats <- gwas_data_sumstats[gwas_data_sumstats$variant_id %in% unique(find_data(mol_res[["variant_names"]], c(2, study))), , drop = FALSE] - mol_res[["gwas_qced"]][[study]] <- rbind(mol_res[["gwas_qced"]][[study]], gwas_data_sumstats) + mol_res[["variant_names"]][[context]][[study]] <- rownames(weights_matrix_subset) + + # Step 6: scale weights by variance (from sketch ref_panel) + variance <- sketch$ref_panel$variance[match(rownames(weights_matrix_subset), sketch$ref_panel$variant_id)] + mol_res[["weights_qced"]][[context]][[study]] <- list(scaled_weights = weights_matrix_subset * sqrt(variance), weights = weights_matrix_subset) + } + # Combine GWAS sumstats for this study (filter to variants used by any context) + used_variants <- unique(find_data(mol_res[["variant_names"]], c(2, study))) + if (!is.null(used_variants)) { + gwas_subset <- gwas_data_sumstats[gwas_data_sumstats$variant_id %in% used_variants, , drop = FALSE] + mol_res[["gwas_qced"]][[study]] <- rbind(mol_res[["gwas_qced"]][[study]], gwas_subset) gwas_qced <- mol_res[["gwas_qced"]][[study]] mol_res[["gwas_qced"]][[study]] <- gwas_qced[!duplicated(gwas_qced[, c("variant_id", "z")]), ] } } + twas_weights_data[[molecular_id]] <- NULL - # extract LD matrix for variants intersect with gwas and twas weights at molecular_id level - all_molecular_variants <- unique(find_data(mol_res[["gwas_qced"]], c(2, "variant_id"))) - if (is.null(all_molecular_variants)) { + # Store SVD components for this gene + if (is.null(mol_res[["gwas_qced"]]) || length(mol_res[["gwas_qced"]]) == 0) { results[[molecular_id]] <- NULL } else { - # All variant IDs are now in canonical chr-prefix format - var_indx <- match(all_molecular_variants, LD_list$LD_variants) - mol_res[["LD"]] <- as.matrix(LD_list$LD_matrix[var_indx, var_indx]) + mol_res[["svd_V"]] <- svd_result$v + mol_res[["svd_D"]] <- svd_result$d + mol_res[["n_sketch"]] <- sketch$n_sketch + mol_res[["ld_variant_ids"]] <- sketch$variant_ids results[[molecular_id]] <- mol_res } } - # return results - return(list(twas_data_qced = results, ref_panel = LD_list$ref_panel)) + return(list(twas_data_qced = results, ref_panel = sketch$ref_panel)) } #' Harmonize GWAS Summary Statistics @@ -529,8 +452,13 @@ twas_pipeline <- function(twas_weights_data, } # twas analysis twas_rs <- twas_analysis( - twas_data_qced[[weight_db]][["weights_qced"]][[context]][[study]][["weights"]], twas_data_qced[[weight_db]][["gwas_qced"]][[study]], - twas_data_qced[[weight_db]][["LD"]], twas_variants + twas_data_qced[[weight_db]][["weights_qced"]][[context]][[study]][["weights"]], + twas_data_qced[[weight_db]][["gwas_qced"]][[study]], + extract_variants_objs = twas_variants, + V = twas_data_qced[[weight_db]][["svd_V"]], + D = twas_data_qced[[weight_db]][["svd_D"]], + n_sketch = twas_data_qced[[weight_db]][["n_sketch"]], + ld_variant_ids = twas_data_qced[[weight_db]][["ld_variant_ids"]] ) if (is.null(twas_rs)) { return(list(twas_rs_df = data.frame(), mr_rs_df = data.frame())) @@ -570,7 +498,10 @@ twas_pipeline <- function(twas_weights_data, mr_context_table <- do.call(rbind, lapply(study_results, function(x) x$mr_rs_df)) return(list(twas_context_table = twas_context_table, mr_context_table = mr_context_table)) }) - twas_data_qced[[weight_db]][["LD"]] <- NULL + twas_data_qced[[weight_db]][["svd_V"]] <- NULL + twas_data_qced[[weight_db]][["svd_D"]] <- NULL + twas_data_qced[[weight_db]][["n_sketch"]] <- NULL + twas_data_qced[[weight_db]][["ld_variant_ids"]] <- NULL twas_weights_data[[weight_db]] <- NULL twas_gene_table <- do.call(rbind, lapply(twas_gene_results, function(x) x$twas_context_table)) mr_gene_table <- do.call(rbind, lapply(twas_gene_results, function(x) x$mr_context_table)) @@ -654,16 +585,24 @@ twas_pipeline <- function(twas_weights_data, #' @importFrom stats cor pchisq #' #' @export -twas_z <- function(weights, z, R = NULL, X = NULL) { +twas_z <- function(weights, z, R = NULL, X = NULL, V = NULL, D = NULL, n_sketch = NULL) { # Check that weights and z-scores have the same length if (length(weights) != length(z)) { stop("Weights and z-scores must have the same length.") } - if (is.null(R)) R <- compute_LD(X) - stat <- t(weights) %*% z - denom <- t(weights) %*% R %*% weights + + if (!is.null(V) && !is.null(D) && !is.null(n_sketch)) { + # SVD path: denom = wᵀRw = sum(Lambda * (Vᵀw)²) where Lambda = D²/(n_sketch-1) + Lambda <- D^2 / (n_sketch - 1) + Vw <- crossprod(V, weights) + denom <- sum(Lambda * Vw^2) + } else { + if (is.null(R)) R <- compute_LD(X) + denom <- t(weights) %*% R %*% weights + } + zscore <- stat / sqrt(denom) pval <- pchisq(zscore * zscore, 1, lower.tail = FALSE) @@ -748,27 +687,47 @@ twas_joint_z <- function(weights, z, R = NULL, X = NULL) { #' #' @return A list with TWAS z-scores and p-values across four methods for each gene. #' @export -twas_analysis <- function(weights_matrix, gwas_sumstats_db, LD_matrix, extract_variants_objs) { +twas_analysis <- function(weights_matrix, gwas_sumstats_db, LD_matrix = NULL, + extract_variants_objs, V = NULL, D = NULL, + n_sketch = NULL, ld_variant_ids = NULL) { # Extract gwas_sumstats gwas_sumstats_subset <- gwas_sumstats_db[match(extract_variants_objs, gwas_sumstats_db$variant_id), ] # Validate that the GWAS subset is not empty if (nrow(gwas_sumstats_subset) == 0 | all(is.na(gwas_sumstats_subset))) { warning("No GWAS summary statistics found for the specified variants.") return(NULL) - } - # Check if extract_variants_objs are in the rownames of LD_matrix + } + + # SVD path + if (!is.null(V) && !is.null(D) && !is.null(n_sketch) && !is.null(ld_variant_ids)) { + valid_indices <- extract_variants_objs %in% ld_variant_ids + if (!any(valid_indices)) { + warning("None of the specified variants are present in the LD sketch. Skipping this context.") + return(NULL) + } + valid_variants_objs <- extract_variants_objs[valid_indices] + # Subset V rows to match the valid variants + v_row_idx <- match(valid_variants_objs, ld_variant_ids) + V_subset <- V[v_row_idx, , drop = FALSE] + weights_matrix <- weights_matrix[valid_variants_objs, , drop = FALSE] + gwas_sumstats_subset <- gwas_sumstats_db[match(valid_variants_objs, gwas_sumstats_db$variant_id), ] + twas_z_pval <- apply( + as.matrix(weights_matrix), 2, + function(x) twas_z(x, gwas_sumstats_subset$z, V = V_subset, D = D, n_sketch = n_sketch) + ) + return(twas_z_pval) + } + + # LD matrix path valid_indices <- extract_variants_objs %in% rownames(LD_matrix) if (!any(valid_indices)) { warning("None of the specified variants are present in the LD matrix. Skipping this context.") return(NULL) - } - # Extract only the valid indices from extract_variants_objs + } valid_variants_objs <- extract_variants_objs[valid_indices] - # Extract LD_matrix subset using valid indices LD_matrix_subset <- LD_matrix[valid_variants_objs, valid_variants_objs] - # Extract weight matrix subset using valid indices weights_matrix <- weights_matrix[valid_variants_objs, , drop = FALSE] - # Caculate the z score and pvalue of each gene + gwas_sumstats_subset <- gwas_sumstats_db[match(valid_variants_objs, gwas_sumstats_db$variant_id), ] twas_z_pval <- apply( as.matrix(weights_matrix), 2, function(x) twas_z(x, gwas_sumstats_subset$z, R = LD_matrix_subset) diff --git a/man/load_ld_sketch.Rd b/man/load_ld_sketch.Rd new file mode 100644 index 00000000..db4e5397 --- /dev/null +++ b/man/load_ld_sketch.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/LD.R +\name{load_ld_sketch} +\alias{load_ld_sketch} +\title{Load LD sketch genotypes for a region} +\usage{ +load_ld_sketch(ld_meta_file_path, region, n_sample = NULL) +} +\arguments{ +\item{ld_meta_file_path}{Path to the LD metadata TSV file.} + +\item{region}{Region of interest: "chr:start-end" string or data.frame with chrom/start/end.} + +\item{n_sample}{Optional original panel sample size for computing variance +(= 2*p*(1-p)*n/(n-1)). Passed through to \code{load_LD_matrix()}.} +} +\value{ +A list with: +\describe{ + \item{X}{Raw genotype matrix (n_sketch x p) after removing monomorphic variants.} + \item{n_sketch}{Number of rows (samples) in the sketch genotype matrix.} + \item{ref_panel}{Data.frame with variant metadata (chrom, pos, A2, A1, variant_id, + allele_freq, and optionally variance, n_nomiss).} + \item{variant_ids}{Character vector of variant IDs (canonical format) after + removing monomorphic variants.} +} +} +\description{ +Loads genotype data for a region via \code{load_LD_matrix(return_genotype=TRUE)} +and removes monomorphic variants. Returns the raw genotype matrix and metadata, +which callers can use to derive either a correlation matrix R (for summary-based +weight training or fine-mapping) or an SVD (for TWAS z-score computation). +} diff --git a/man/twas_analysis.Rd b/man/twas_analysis.Rd index 351761fb..51027c11 100644 --- a/man/twas_analysis.Rd +++ b/man/twas_analysis.Rd @@ -7,8 +7,12 @@ twas_analysis( weights_matrix, gwas_sumstats_db, - LD_matrix, - extract_variants_objs + LD_matrix = NULL, + extract_variants_objs, + V = NULL, + D = NULL, + n_sketch = NULL, + ld_variant_ids = NULL ) } \arguments{ diff --git a/man/twas_z.Rd b/man/twas_z.Rd index 806134f5..a6814574 100644 --- a/man/twas_z.Rd +++ b/man/twas_z.Rd @@ -4,7 +4,7 @@ \alias{twas_z} \title{Calculate TWAS z-score and p-value} \usage{ -twas_z(weights, z, R = NULL, X = NULL) +twas_z(weights, z, R = NULL, X = NULL, V = NULL, D = NULL, n_sketch = NULL) } \arguments{ \item{weights}{A numeric vector of weights.} diff --git a/tests/testthat/test_misc.R b/tests/testthat/test_misc.R index 43e09021..7dfb1068 100644 --- a/tests/testthat/test_misc.R +++ b/tests/testthat/test_misc.R @@ -6,15 +6,15 @@ library(tidyverse) # ============================================================================= test_that("Test compute_maf freq 0.5",{ - expect_equal(compute_maf(rep(1, 20)), 0.5) + expect_equal(pecotmr:::compute_maf(rep(1, 20)), 0.5) }) test_that("Test compute_maf freq 0.6",{ - expect_equal(compute_maf(rep(1.2, 20)), 0.4) + expect_equal(pecotmr:::compute_maf(rep(1.2, 20)), 0.4) }) test_that("Test compute_maf freq 0.3",{ - expect_equal(compute_maf(rep(0.6, 20)), 0.3) + expect_equal(pecotmr:::compute_maf(rep(0.6, 20)), 0.3) }) test_that("Test compute_maf with NA",{ @@ -23,11 +23,11 @@ test_that("Test compute_maf with NA",{ vals <- c(1.2, NA) return(sample(vals, sample_size, replace = TRUE)) } - expect_equal(compute_maf(generate_small_dataset()), 0.4) + expect_equal(pecotmr:::compute_maf(generate_small_dataset()), 0.4) }) test_that("compute_maf returns 0 for monomorphic (all 0)", { - expect_equal(compute_maf(rep(0, 10)), 0) + expect_equal(pecotmr:::compute_maf(rep(0, 10)), 0) }) # ============================================================================= @@ -36,7 +36,7 @@ test_that("compute_maf returns 0 for monomorphic (all 0)", { test_that("test compute_missing",{ small_dataset <- c(rep(NA, 20), rep(1, 80)) - expect_equal(compute_missing(small_dataset), 0.2) + expect_equal(pecotmr:::compute_missing(small_dataset), 0.2) }) # ============================================================================= @@ -45,12 +45,12 @@ test_that("test compute_missing",{ test_that("Test compute_non_missing_y",{ small_dataset <- c(rep(NA, 20), rep(1, 80)) - expect_equal(compute_non_missing_y(small_dataset), 80) + expect_equal(pecotmr:::compute_non_missing_y(small_dataset), 80) }) test_that("Test compute_all_missing_y",{ small_dataset <- c(rep(NA, 20), rep(1, 80)) - expect_equal(compute_all_missing_y(small_dataset), F) + expect_equal(pecotmr:::compute_all_missing_y(small_dataset), F) }) test_that("compute_all_missing_y returns TRUE for all-NA vector", { @@ -67,7 +67,7 @@ test_that("compute_all_missing_y returns FALSE for partially NA vector", { test_that("Test mean_impute",{ dummy_data <- matrix(c(1,2,NA,1,2,3), nrow=3, ncol=2) - expect_equal(mean_impute(dummy_data)[3,1], 1.5) + expect_equal(pecotmr:::mean_impute(dummy_data)[3,1], 1.5) }) test_that("mean_impute with all NAs in a column imputes NaN", { @@ -83,7 +83,7 @@ test_that("mean_impute with all NAs in a column imputes NaN", { test_that("Test is_zero_variance",{ dummy_data <- matrix(c(1,2,3,1,1,1), nrow=3, ncol=2) - col <- which(apply(dummy_data, 2, is_zero_variance)) + col <- which(apply(dummy_data, 2, pecotmr:::is_zero_variance)) expect_equal(col, 2) }) diff --git a/tests/testthat/test_twas.R b/tests/testthat/test_twas.R index d1137648..223d6373 100644 --- a/tests/testthat/test_twas.R +++ b/tests/testthat/test_twas.R @@ -1504,7 +1504,10 @@ test_that("twas_pipeline: pick_best_model path is executed when rsq_cutoff > 0", z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -1565,7 +1568,10 @@ test_that("twas_pipeline: pick_best_model selects model with best rsq", { z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -1628,7 +1634,10 @@ test_that("twas_pipeline: pick_best_model skips context when no model passes thr z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -1752,7 +1761,10 @@ test_that("twas_pipeline: full pipeline with mocked harmonize_twas produces twas z = gwas_z, stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -1824,7 +1836,10 @@ test_that("twas_pipeline: multiple genes processed correctly", { z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) } @@ -1879,7 +1894,10 @@ test_that("twas_pipeline: empty twas_variants intersection returns empty data fr z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -1948,7 +1966,10 @@ test_that("twas_pipeline: output_twas_data=TRUE triggers format_twas_data path", z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix, + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix), model_selection = list(ctx1 = list(selected_model = "susie", is_imputable = TRUE)) ) ) @@ -2025,7 +2046,10 @@ test_that("twas_pipeline: multiple contexts are processed independently", { z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -2088,7 +2112,10 @@ test_that("twas_pipeline: mr_result is returned in final result", { z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -2144,7 +2171,10 @@ test_that("twas_pipeline: when TWAS analysis yields no results, returns NULL com z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -2390,7 +2420,10 @@ test_that("twas_pipeline: processes multiple GWAS studies correctly", { studyA = data.frame(variant_id = variant_ids, z = rnorm(p), stringsAsFactors = FALSE), studyB = data.frame(variant_id = variant_ids, z = rnorm(p), stringsAsFactors = FALSE) ), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -2457,7 +2490,10 @@ test_that("twas_pipeline: rsq_option='adj_rsq' uses adjusted R-squared", { z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -2524,7 +2560,10 @@ test_that("twas_pipeline: rsq_pval_option='adj_rsq_pval' uses the correct p-valu z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -2726,6 +2765,7 @@ test_that("harmonize_twas: group_contexts_by_region single context path (lines 4 A2 = rep("A", p), A1 = rep("T", p), variant_id = variant_ids, + allele_freq = rep(0.3, p), variance = rep(1.0, p), stringsAsFactors = FALSE ) @@ -2754,11 +2794,16 @@ test_that("harmonize_twas: group_contexts_by_region single context path (lines 4 skip_if_not_installed("readr") local_mocked_bindings( - load_LD_matrix = function(...) { + load_ld_sketch = function(...) { + n_sketch <- 50L + p_sketch <- length(mock_LD_variants) + set.seed(123) + X <- matrix(rbinom(n_sketch * p_sketch, 2, 0.3), nrow = n_sketch, ncol = p_sketch) list( - LD_matrix = mock_LD_matrix, - LD_variants = mock_LD_variants, - ref_panel = mock_ref_panel + X = X, + n_sketch = n_sketch, + ref_panel = mock_ref_panel, + variant_ids = mock_LD_variants ) }, get_ref_variant_info = function(...) mock_snp_info, @@ -2865,15 +2910,21 @@ test_that("harmonize_twas: group_contexts_by_region multi-context clustering (li )) local_mocked_bindings( - load_LD_matrix = function(...) { + load_ld_sketch = function(...) { + n_sketch <- 50L + p_sketch <- length(all_variant_ids) + set.seed(123) + X <- matrix(rbinom(n_sketch * p_sketch, 2, 0.3), nrow = n_sketch, ncol = p_sketch) list( - LD_matrix = mock_LD_matrix, - LD_variants = all_variant_ids, + X = X, + n_sketch = n_sketch, ref_panel = data.frame(chrom = 1, pos = as.integer(sapply(strsplit(all_variant_ids, ":"), `[`, 2)), A2 = "A", A1 = "T", variant_id = all_variant_ids, + allele_freq = rep(0.3, length(all_variant_ids)), variance = rep(1.0, length(all_variant_ids)), - stringsAsFactors = FALSE) + stringsAsFactors = FALSE), + variant_ids = all_variant_ids ) }, get_ref_variant_info = function(...) mock_snp_info, @@ -2952,7 +3003,10 @@ test_that("twas_pipeline: adj_rsq_pval option exercised in pick_best_model", { z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -3025,7 +3079,10 @@ test_that("twas_pipeline: missing data_type triggers assignment check on line 61 z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -3094,7 +3151,10 @@ test_that("twas_pipeline: twas_analysis returning NULL yields empty rows (line 6 z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -3167,7 +3227,10 @@ test_that("twas_pipeline: MR path entered when susie_results and significant twa z = c(10.0, 8.0, 1.0, 0.5, 0.2), # Large z to get small pval for MR trigger stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -3222,7 +3285,10 @@ test_that("twas_pipeline: event_filters filtering some but not all contexts", { gwas_qced = list(study1 = data.frame( variant_id = variant_ids, z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) @@ -3293,9 +3359,26 @@ test_that("harmonize_twas: duplicated LD variants are removed", { stringsAsFactors = FALSE ) + # Dedup now happens inside load_LD_matrix (called by load_ld_sketch). + # Mock load_ld_sketch to return already-deduped data (3 unique variants). local_mocked_bindings( - load_LD_matrix = function(...) { - list(LD_matrix = dup_LD_matrix, LD_variants = dup_variant_ids, ref_panel = dup_ref_panel) + load_ld_sketch = function(...) { + n_sketch <- 50L + set.seed(123) + X <- matrix(rbinom(n_sketch * p, 2, 0.3), nrow = n_sketch, ncol = p) + list( + X = X, + n_sketch = n_sketch, + ref_panel = data.frame( + chrom = rep(1, p), pos = c(100, 200, 300), + A2 = rep("A", p), A1 = rep("T", p), + variant_id = variant_ids, + allele_freq = rep(0.3, p), + variance = rep(1.0, p), + stringsAsFactors = FALSE + ), + variant_ids = variant_ids + ) }, harmonize_gwas = function(...) mock_gwas_data, match_ref_panel = function(target_data, ref_data, ...) { @@ -3323,7 +3406,7 @@ test_that("harmonize_twas: duplicated LD variants are removed", { result <- harmonize_twas(twas_weights_data, "fake_ld.tsv", gwas_meta_file, ld_reference_sample_size = 17000) expect_true(is.list(result)) - # After dedup, the LD should have been pared back to unique variants (3, not 4). + # ref_panel should have the 3 unique variants expect_equal(nrow(result$ref_panel), p) }) @@ -3351,12 +3434,18 @@ test_that("harmonize_twas: drops molecular_id when harmonize_gwas returns NULL f rownames(LD_matrix) <- colnames(LD_matrix) <- variant_ids ref_panel <- data.frame( chrom = 1, pos = c(100, 200, 300), A2 = "A", A1 = "T", - variant_id = variant_ids, variance = 1.0, stringsAsFactors = FALSE + variant_id = variant_ids, allele_freq = 0.3, variance = 1.0, stringsAsFactors = FALSE ) local_mocked_bindings( - load_LD_matrix = function(...) { - list(LD_matrix = LD_matrix, LD_variants = variant_ids, ref_panel = ref_panel) + load_ld_sketch = function(...) { + n_sketch <- 50L + set.seed(123) + X <- matrix(rbinom(n_sketch * p, 2, 0.3), nrow = n_sketch, ncol = p) + list( + X = X, n_sketch = n_sketch, + ref_panel = ref_panel, variant_ids = variant_ids + ) }, # Returning NULL skips the entire context loop, so gwas_qced stays empty harmonize_gwas = function(...) NULL, @@ -3374,7 +3463,7 @@ test_that("harmonize_twas: drops molecular_id when harmonize_gwas returns NULL f result <- harmonize_twas(twas_weights_data, "fake_ld.tsv", gwas_meta_file, ld_reference_sample_size = 17000) - # gene1 should have been dropped (line 217 sets results[[mid]] <- NULL) + # gene1 should have been dropped (gwas_qced is empty) expect_true(is.list(result)) expect_null(result$twas_data_qced$gene1) }) @@ -3414,7 +3503,7 @@ test_that("harmonize_twas: susie_weights column triggers adjust_susie_weights br rownames(LD_matrix) <- colnames(LD_matrix) <- variant_ids ref_panel <- data.frame( chrom = 1, pos = c(100, 200, 300), A2 = "A", A1 = "T", - variant_id = variant_ids, variance = 1.0, stringsAsFactors = FALSE + variant_id = variant_ids, allele_freq = 0.3, variance = 1.0, stringsAsFactors = FALSE ) mock_gwas_data <- data.frame( @@ -3425,8 +3514,14 @@ test_that("harmonize_twas: susie_weights column triggers adjust_susie_weights br ) local_mocked_bindings( - load_LD_matrix = function(...) { - list(LD_matrix = LD_matrix, LD_variants = variant_ids, ref_panel = ref_panel) + load_ld_sketch = function(...) { + n_sketch <- 50L + set.seed(123) + X <- matrix(rbinom(n_sketch * p, 2, 0.3), nrow = n_sketch, ncol = p) + list( + X = X, n_sketch = n_sketch, + ref_panel = ref_panel, variant_ids = variant_ids + ) }, harmonize_gwas = function(...) mock_gwas_data, match_ref_panel = function(target_data, ref_data, ...) { @@ -3500,7 +3595,10 @@ test_that("twas_pipeline: pick_best_model returns NULL when all CV rsq are NA", gwas_qced = list(study1 = data.frame( variant_id = variant_ids, z = rnorm(p), stringsAsFactors = FALSE )), - LD = LD_matrix + svd_V = diag(nrow(LD_matrix)), + svd_D = rep(sqrt(49), nrow(LD_matrix)), + n_sketch = 50L, + ld_variant_ids = rownames(LD_matrix) ) ) diff --git a/tests/testthat/test_twas_sketch.R b/tests/testthat/test_twas_sketch.R new file mode 100644 index 00000000..7f54fa99 --- /dev/null +++ b/tests/testthat/test_twas_sketch.R @@ -0,0 +1,284 @@ +# Tests for SVD-based TWAS computation (LD sketch path) + +# Phase 1: twas_z() and twas_analysis() SVD branch + +test_that("twas_z: SVD path matches R path for full-rank genotype matrix", { + set.seed(42) + n <- 100 # samples (sketch size) + p <- 20 # variants + + # Generate genotype-like matrix (dosages 0/1/2) + X <- matrix(rbinom(n * p, 2, runif(p, 0.1, 0.9)), nrow = n, ncol = p) + + # HWE-based standardization + af <- colMeans(X) / 2 + X_std <- sweep(X, 2, 2 * af) + X_std <- sweep(X_std, 2, sqrt(2 * af * (1 - af)), "/") + + # Compute R from standardized X + R <- crossprod(X_std) / (n - 1) + + # SVD of standardized X + svd_result <- svd(X_std) + + # Random weights and z-scores + weights <- rnorm(p) + z <- rnorm(p) + + # R path + result_R <- pecotmr:::twas_z(weights, z, R = R) + + # SVD path + result_SVD <- pecotmr:::twas_z(weights, z, V = svd_result$v, D = svd_result$d, n_sketch = n) + + expect_equal(as.numeric(result_SVD$z), as.numeric(result_R$z), tolerance = 1e-10) + expect_equal(as.numeric(result_SVD$pval), as.numeric(result_R$pval), tolerance = 1e-10) +}) + +test_that("twas_z: SVD path matches R path for rank-deficient matrix (n < p)", { + set.seed(123) + n <- 15 # fewer samples than variants + + p <- 30 + + X <- matrix(rbinom(n * p, 2, runif(p, 0.15, 0.85)), nrow = n, ncol = p) + + af <- colMeans(X) / 2 + # Remove monomorphic columns + keep <- af > 0 & af < 1 + X <- X[, keep, drop = FALSE] + af <- af[keep] + p <- ncol(X) + + X_std <- sweep(X, 2, 2 * af) + X_std <- sweep(X_std, 2, sqrt(2 * af * (1 - af)), "/") + + R <- crossprod(X_std) / (n - 1) + svd_result <- svd(X_std) + + weights <- rnorm(p) + z <- rnorm(p) + + result_R <- pecotmr:::twas_z(weights, z, R = R) + result_SVD <- pecotmr:::twas_z(weights, z, V = svd_result$v, D = svd_result$d, n_sketch = n) + + expect_equal(as.numeric(result_SVD$z), as.numeric(result_R$z), tolerance = 1e-10) + expect_equal(as.numeric(result_SVD$pval), as.numeric(result_R$pval), tolerance = 1e-10) +}) + +test_that("twas_z: error when weights and z have different lengths", { + expect_error( + pecotmr:::twas_z(rnorm(5), rnorm(3), V = matrix(1, 5, 2), D = c(1, 1), n_sketch = 10), + "Weights and z-scores must have the same length" + ) +}) + +test_that("twas_analysis: SVD path produces same results as R path", { + set.seed(99) + n <- 50 + p <- 10 + variant_ids <- paste0("chr1:", seq(1000, by = 100, length.out = p), ":A:G") + + X <- matrix(rbinom(n * p, 2, runif(p, 0.2, 0.8)), nrow = n, ncol = p) + af <- colMeans(X) / 2 + X_std <- sweep(X, 2, 2 * af) + X_std <- sweep(X_std, 2, sqrt(2 * af * (1 - af)), "/") + + R <- crossprod(X_std) / (n - 1) + rownames(R) <- colnames(R) <- variant_ids + svd_result <- svd(X_std) + + # Weights matrix (2 methods) + weights_matrix <- matrix(rnorm(p * 2), nrow = p, ncol = 2) + rownames(weights_matrix) <- variant_ids + colnames(weights_matrix) <- c("lasso_weights", "enet_weights") + + # GWAS data + gwas_df <- data.frame(variant_id = variant_ids, z = rnorm(p)) + + # Use a subset of variants + extract_variants <- variant_ids[3:8] + + result_R <- pecotmr:::twas_analysis(weights_matrix, gwas_df, LD_matrix = R, + extract_variants_objs = extract_variants) + result_SVD <- pecotmr:::twas_analysis(weights_matrix, gwas_df, + extract_variants_objs = extract_variants, + V = svd_result$v, D = svd_result$d, + n_sketch = n, ld_variant_ids = variant_ids) + + expect_equal(as.numeric(result_SVD[[1]]$z), as.numeric(result_R[[1]]$z), tolerance = 1e-10) + expect_equal(as.numeric(result_SVD[[2]]$z), as.numeric(result_R[[2]]$z), tolerance = 1e-10) + expect_equal(as.numeric(result_SVD[[1]]$pval), as.numeric(result_R[[1]]$pval), tolerance = 1e-10) + expect_equal(as.numeric(result_SVD[[2]]$pval), as.numeric(result_R[[2]]$pval), tolerance = 1e-10) +}) + +test_that("twas_analysis: SVD path handles partial variant overlap", { + set.seed(77) + n <- 40 + p <- 8 + variant_ids <- paste0("chr1:", seq(1000, by = 100, length.out = p), ":A:G") + + X <- matrix(rbinom(n * p, 2, 0.3), nrow = n, ncol = p) + af <- colMeans(X) / 2 + X_std <- sweep(X, 2, 2 * af) + X_std <- sweep(X_std, 2, sqrt(2 * af * (1 - af)), "/") + svd_result <- svd(X_std) + + weights_matrix <- matrix(rnorm(p), nrow = p, ncol = 1) + rownames(weights_matrix) <- variant_ids + colnames(weights_matrix) <- "lasso_weights" + + gwas_df <- data.frame(variant_id = variant_ids, z = rnorm(p)) + + # Request variants where some are NOT in ld_variant_ids + extra_variant <- "chr1:5000:A:G" + extract_variants <- c(variant_ids[1:4], extra_variant) + + result <- pecotmr:::twas_analysis(weights_matrix, gwas_df, + extract_variants_objs = extract_variants, + V = svd_result$v, D = svd_result$d, + n_sketch = n, ld_variant_ids = variant_ids) + + # Should succeed using only the 4 valid variants + expect_false(is.null(result)) + expect_equal(length(result[[1]]$z), 1) +}) + +test_that("twas_analysis: SVD path returns NULL when no variants overlap", { + variant_ids <- paste0("chr1:", 1:5, ":A:G") + other_ids <- paste0("chr2:", 1:5, ":A:G") + + weights_matrix <- matrix(1, nrow = 5, ncol = 1) + rownames(weights_matrix) <- variant_ids + colnames(weights_matrix) <- "w" + gwas_df <- data.frame(variant_id = variant_ids, z = rnorm(5)) + + result <- suppressWarnings(pecotmr:::twas_analysis( + weights_matrix, gwas_df, + extract_variants_objs = variant_ids, + V = matrix(1, 5, 2), D = c(1, 1), + n_sketch = 10, ld_variant_ids = other_ids + )) + + expect_null(result) +}) + +# Phase 2: load_ld_sketch() and standardize_genotype_hwe() + +test_that("standardize_genotype_hwe: centers by 2p and scales by sqrt(2p(1-p))", { + set.seed(42) + n <- 30 + p <- 5 + af <- runif(p, 0.1, 0.9) + X <- matrix(rbinom(n * p, 2, rep(af, each = n)), nrow = n, ncol = p) + + X_std <- pecotmr:::standardize_genotype_hwe(X, af) + + # Manual verification + expected <- sweep(sweep(X, 2, 2 * af), 2, sqrt(2 * af * (1 - af)), "/") + expect_equal(X_std, expected, tolerance = 1e-14) +}) + +test_that("load_ld_sketch: returns raw genotypes and metadata", { + set.seed(55) + n <- 30 + p <- 12 + variant_ids <- paste0("chr1:", seq(1000, by = 100, length.out = p), ":A:G") + + # Create a mock genotype matrix + af_true <- runif(p, 0.1, 0.9) + X <- matrix(rbinom(n * p, 2, rep(af_true, each = n)), nrow = n, ncol = p) + + # Build mock ref_panel + ref_panel <- data.frame( + chrom = 1L, pos = seq(1000, by = 100, length.out = p), + A2 = "A", A1 = "G", + variant_id = variant_ids, + allele_freq = colMeans(X) / 2, + stringsAsFactors = FALSE + ) + + local_mocked_bindings( + load_LD_matrix = function(ld_meta_file_path, region, return_genotype = FALSE, n_sample = NULL, ...) { + list( + LD_matrix = X, + LD_variants = variant_ids, + ref_panel = ref_panel, + is_genotype = TRUE + ) + }, + .package = "pecotmr" + ) + + result <- pecotmr::load_ld_sketch("fake_path.tsv", "chr1:1000-2100") + + # Check structure — returns raw X, not SVD + expect_true(all(c("X", "n_sketch", "ref_panel", "variant_ids") %in% names(result))) + expect_null(result$V) + expect_null(result$D) + expect_equal(result$n_sketch, n) + expect_equal(nrow(result$X), n) + expect_equal(ncol(result$X), p) + expect_equal(length(result$variant_ids), p) + + # Raw genotype matrix is returned unchanged + expect_equal(result$X, X) +}) + +test_that("load_ld_sketch: removes monomorphic variants", { + set.seed(66) + n <- 20 + p <- 5 + variant_ids <- paste0("chr1:", 1:p, ":A:G") + + # Make column 3 monomorphic (all 0) + X <- matrix(rbinom(n * p, 2, 0.3), nrow = n, ncol = p) + X[, 3] <- 0 # monomorphic + + ref_panel <- data.frame( + chrom = 1L, pos = 1:p, + A2 = "A", A1 = "G", + variant_id = variant_ids, + allele_freq = colMeans(X) / 2, + stringsAsFactors = FALSE + ) + + local_mocked_bindings( + load_LD_matrix = function(ld_meta_file_path, region, return_genotype = FALSE, n_sample = NULL, ...) { + list( + LD_matrix = X, + LD_variants = variant_ids, + ref_panel = ref_panel, + is_genotype = TRUE + ) + }, + .package = "pecotmr" + ) + + result <- pecotmr::load_ld_sketch("fake_path.tsv", "chr1:1-5") + + # Monomorphic variant removed + expect_equal(length(result$variant_ids), p - 1) + expect_false(variant_ids[3] %in% result$variant_ids) + expect_equal(nrow(result$ref_panel), p - 1) + expect_equal(ncol(result$X), p - 1) +}) + +test_that("SVD from raw sketch matches direct computation", { + set.seed(77) + n <- 25 + p <- 8 + af <- runif(p, 0.15, 0.85) + X <- matrix(rbinom(n * p, 2, rep(af, each = n)), nrow = n, ncol = p) + + # Two-step process: standardize then SVD + X_std <- pecotmr:::standardize_genotype_hwe(X, af) + svd_result <- svd(X_std) + + # Verify this matches manual computation + X_manual <- sweep(sweep(X, 2, 2 * af), 2, sqrt(2 * af * (1 - af)), "/") + svd_manual <- svd(X_manual) + + expect_equal(svd_result$d, svd_manual$d, tolerance = 1e-14) + expect_equal(abs(svd_result$v), abs(svd_manual$v), tolerance = 1e-14) +}) diff --git a/twas-ld-sketch-plan.md b/twas-ld-sketch-plan.md new file mode 100644 index 00000000..7f1068b9 --- /dev/null +++ b/twas-ld-sketch-plan.md @@ -0,0 +1,142 @@ +# TWAS Pipeline Refactoring: LD Sketch + SVD + +## Context + +The current TWAS pipeline loads precomputed LD correlation matrices R from independent LD blocks (or computes R from genotype files), then computes `z_twas = (wᵀz) / √(wᵀRw)`. This requires aligning gene windows with LD block boundaries, concatenating blocks, deduplicating boundary variants, and extracting per-gene submatrices — the dominant source of complexity in twas.R. + +The refactoring replaces R with an SVD of "LD sketch" genotypes — random projections of the reference panel that preserve LD structure. Each gene independently defines its own window, loads the sketch for that window, and computes the TWAS statistic via: + +``` +z_twas = (wᵀz) / ‖Λ^(1/2) Vᵀw‖ where Λ = D²/(n_sketch - 1) +``` + +This eliminates the LD block alignment problem entirely. + +--- + +## Implementation Steps + +### Step 1: Thin wrapper `load_ld_sketch()` in R/LD.R + +Add after `load_LD_from_genotype()` (~line 480). This is a thin wrapper around the existing `load_LD_matrix()` with `return_genotype = TRUE` — it reuses all existing genotype loading, variant ID normalization, ref_panel construction (allele frequencies, variance), and duplicate removal. The only new logic is standardizing X and computing SVD. + +```r +#' @export +load_ld_sketch <- function(ld_meta_file_path, region, n_sample = NULL) +``` + +Logic: +1. Call `load_LD_matrix(ld_meta_file_path, region, return_genotype = TRUE, n_sample = n_sample)` — this returns the genotype matrix X, ref_panel (with allele_freq), LD_variants, etc. +2. Standardize X using HWE-based formula with `p = ref_panel$allele_freq`: + - Center: `X_c = sweep(X, 2, 2*p)` (subtract theoretical mean `2p`) + - Scale: `X_std = sweep(X_c, 2, sqrt(2*p*(1-p)), "/")` (divide by theoretical SD) + - Drop zero-variance columns where `p == 0` or `p == 1` (and corresponding ref_panel rows / variant_ids) +3. `safe_svd(X_std, tol = 0)` — full SVD, no truncation +4. Return: `list(V = svd$v, D = svd$d, n_sketch = nrow(X), ref_panel = result$ref_panel, variant_ids = result$LD_variants)` + +Key: `n_sketch` = nrow(X_sketch) used for Λ normalization. `n_sample` = original panel size, passed through to `load_LD_matrix()` for variance computation. + +### Step 2: Modify `twas_z()` in R/twas.R (lines 653-667) + +New signature: +```r +twas_z <- function(weights, z, R = NULL, X = NULL, V = NULL, D = NULL, n_sketch = NULL) +``` + +Add SVD branch before existing R/X branches: +- If V, D, n_sketch provided: `Lambda = D^2 / (n_sketch - 1)`, `denom = sum(Lambda * (crossprod(V, weights))^2)`, `zscore = stat / sqrt(denom)` +- Existing R and X paths remain unchanged + +### Step 3: Modify `twas_analysis()` in R/twas.R (lines 747-773) + +New signature: +```r +twas_analysis <- function(weights_matrix, gwas_sumstats_db, LD_matrix = NULL, + extract_variants_objs, V = NULL, D = NULL, + n_sketch = NULL, ld_variant_ids = NULL) +``` + +SVD path: subset V rows via `match(valid_variants_objs, ld_variant_ids)`, pass V_subset/D/n_sketch to `twas_z()`. Existing LD_matrix path remains. + +### Step 4: Refactor `harmonize_twas()` in R/twas.R (lines 30-226) + +**Remove:** +- `group_contexts_by_region()` inner function (lines 33-91) +- Region-wide `load_LD_matrix()` call (lines 106-115) +- Per-gene LD submatrix extraction (lines 211-220) +- Context-grouping loop structure + +**New per-gene loop:** +``` +for (molecular_id in molecular_ids): + 1. Collect variant positions across all contexts → build gene window region + 2. load_ld_sketch(ld_meta_file_path, gene_region, n_sample) — internally calls load_LD_matrix(return_genotype=TRUE), standardizes, SVDs + 3. Use sketch$variant_ids as reference for allele QC + + for (study in gwas_studies): + 4. harmonize_gwas(gwas_file, gene_region, sketch$variant_ids, ...) + + for (context in contexts): + 5. match_ref_panel(weights, sketch$variant_ids, ...) + 6. adjust_susie_weights() if needed + 7. Scale weights by sqrt(ref_panel$variance) + + 8. Store sketch SVD: mol_res[["svd_V"]], [["svd_D"]], [["n_sketch"]], [["ld_variant_ids"]] +``` + +Return value changes: `mol_res[["LD"]]` (p×p matrix) → `mol_res[["svd_V"]]`, `[["svd_D"]]`, `[["n_sketch"]]`, `[["ld_variant_ids"]]` + +### Step 5: Update `twas_pipeline()` in R/twas.R + +- Lines 529-532: Pass SVD components to `twas_analysis()` instead of `[["LD"]]` +- Line 569: Clear SVD components instead of LD matrix after use + +### Step 6: Tests + +New tests in a new file `tests/testthat/test_twas_sketch.R`: + +1. **Mathematical equivalence test**: Generate X, compute R=cor(X), compute SVD(X_std). Run `twas_z()` both ways, assert identical z-scores within floating-point tolerance. +2. **`load_ld_sketch()` unit test**: Mock `load_LD_matrix()` with known X. Verify SVD standardization, zero-variance column removal, ref_panel passthrough. +3. **`twas_analysis()` SVD path**: Verify partial variant overlap works correctly. +4. **End-to-end mock test**: Full pipeline with SVD path produces correct results. + +### Step 7: NAMESPACE and documentation + +Run `devtools::document()` after adding roxygen to `load_ld_sketch()` and updating param docs on modified functions. + +--- + +## Files to modify + +| File | Changes | +|------|---------| +| `R/LD.R` | Add `load_ld_sketch()` (thin wrapper around existing `load_LD_matrix(return_genotype=TRUE)` + SVD) | +| `R/twas.R` | Modify `twas_z()`, `twas_analysis()`, `harmonize_twas()`, `twas_pipeline()`. Remove `group_contexts_by_region()` | +| `tests/testthat/test_twas_sketch.R` | New test file | +| `NAMESPACE` | Auto-generated via roxygen | + +## Files NOT modified + +| File | Reason | +|------|--------| +| `R/LD.R` (load_LD_matrix, load_LD_from_genotype, load_LD_from_blocks) | Reused as-is; `load_ld_sketch()` calls `load_LD_matrix()` | +| `R/allele_qc.R` | match_ref_panel unchanged | +| `R/file_utils.R` | load_genotype_region unchanged | +| `R/misc.R` | safe_svd, compute_LD unchanged | +| `R/susie_wrapper.R` | adjust_susie_weights unchanged | +| `R/ctwas_wrapper.R` | cTWAS deferred | +| xqtl-protocol repo | Not changed in this PR | + +## Existing functions reused (not reimplemented) + +- `load_LD_matrix()` — R/LD.R:265 — main LD loader; `load_ld_sketch()` calls this with `return_genotype=TRUE` +- `load_LD_from_genotype()` — R/LD.R:408 — called by `load_LD_matrix()` for genotype sources; handles genotype loading, variant ID normalization, ref_panel with allele frequencies/variance, .afreq sidecar +- `safe_svd()` — R/misc.R:156 — SVD with tolerance filtering +- `match_ref_panel()` — R/allele_qc.R:26 — allele QC (reference = sketch variants) +- `ensure_chr_match()` — R/misc.R:480 — chr prefix harmonization + +## Verification + +1. Run `pixi run Rscript -e 'devtools::load_all("."); testthat::test_file("tests/testthat/test_twas_sketch.R")'` +2. Run existing tests to verify no regressions: `pixi run Rscript -e 'devtools::load_all("."); testthat::test_file("tests/testthat/test_twas_method_fallback.R")'` +3. Mathematical equivalence: the key test generates a genotype matrix, computes TWAS z-scores via both R and SVD paths, asserts they match within 1e-10. diff --git a/twas-pipeline-analysis.md b/twas-pipeline-analysis.md new file mode 100644 index 00000000..28bf19af --- /dev/null +++ b/twas-pipeline-analysis.md @@ -0,0 +1,78 @@ +# TWAS Pipeline Analysis + +## End-to-end data flow + +``` +Weight RDS files ──→ load_twas_weights() ──┐ +GWAS TSV (tabix) ──→ harmonize_gwas() ──┤──→ harmonize_twas() ──→ twas_analysis() ──→ results +LD metadata TSV ──→ load_LD_matrix() ──┘ │ │ + │ ↓ + match_ref_panel() z = (wᵀz) / √(wᵀRw) + (allele QC engine) +``` + +## The pipeline has 3 layers + +### Layer 1: `twas_pipeline()` — orchestrator (R/twas.R:326-632) + +The main entry point. It: +1. Optionally filters molecular events via `event_filters` +2. Calls `harmonize_twas()` to load + QC everything +3. For each gene: picks the best model by CV R² (via `pick_best_model()`) +4. For each gene-context-study triple: calls `twas_analysis()` to compute z-scores +5. Optionally runs MR analysis if SuSiE credible sets exist and p-value is small enough +6. Merges CV metrics with TWAS results into a single table +7. Applies `apply_method_fallback()` for NA/Inf z-scores +8. Optionally formats output for cTWAS via `format_twas_data()` + +### Layer 2: `harmonize_twas()` — data harmonization (R/twas.R:30-226) + +The most complex function. It: +1. **Groups contexts by genomic position** — contexts whose weight variants are within 5kb of each other get clustered so they share a single LD query region +2. **Loads LD once** for the combined region via `load_LD_matrix()` +3. For each GWAS study: calls `harmonize_gwas()` which loads data via tabix, standardizes columns, then calls `match_ref_panel()` to align alleles to the LD reference +4. For each context: calls `match_ref_panel()` again to align the weight matrix to LD, then optionally `adjust_susie_weights()` to recalculate SuSiE weights for the variant subset +5. Scales weights by `sqrt(variance)` where variance comes from the LD ref panel +6. Extracts a per-gene LD submatrix for the intersection of GWAS + weight variants + +**The three-way intersection is the core issue** — every variant must be present in all three sources (weights, GWAS, LD) to participate, and alleles must be harmonized across all three. + +### Layer 3: `twas_analysis()` / `twas_z()` — the actual computation (R/twas.R:747-773, 653-667) + +This is tiny and straightforward: +- Subset everything to shared variants +- For each method: `z_twas = (wᵀ * z_gwas) / sqrt(wᵀ * R * w)` +- Return z-score and chi-squared p-value + +## Where the complexity lives + +| Area | What happens | Why it's complex | +|------|-------------|------------------| +| **Data loading** | `load_twas_weights()` merges multivariate (mnm_rs) + univariate weights from multiple RDS files | Deeply nested list structures, context name cleaning, weight alignment across methods | +| **Allele QC** | `match_ref_panel()` handles exact match, sign flip, strand flip, INDEL matching | 6+ boolean flags, inner join on (chrom, pos) then allele matching logic | +| **Context grouping** | `group_contexts_by_region()` clusters contexts by variant position overlap | Hierarchical clustering + IRanges interval merging | +| **SuSiE adjustment** | `adjust_susie_weights()` recalculates weights from log Bayes factors when variants are dropped | Re-derives alpha from LBF, recomputes posterior means | +| **Model selection** | Best model chosen by CV R², with fallback if z-score is NA/Inf | Two-pass: first in `pick_best_model()`, then `apply_method_fallback()` | + +## The xqtl-protocol adds another layer on top + +The `twas_ctwas.ipynb` notebook wraps `twas_pipeline()` with: +1. **Weight validation** — file size checks, tryCatch on readRDS (now ported into pecotmr) +2. **Batch loading** — `batch_load_twas_weights()` splits genes by memory +3. **`update_twas_method()`** — a SECOND method fallback pass after `twas_pipeline()` returns (the in-pecotmr `apply_method_fallback()` was ported from this) +4. **Context name merging** — `merge_context_names()` strips region suffixes from weight names +5. **cTWAS assembly** — ~200 lines loading LD again, re-running `harmonize_gwas()`, calling `trim_ctwas_variants()`, `assemble_region_data()`, then `est_param()`, `screen_regions()`, `finemap_regions()` + +## Key observations + +1. **`match_ref_panel()` is called 3+ times per gene** — once for GWAS vs LD, once for weights vs LD, and once inside `adjust_susie_weights()`. Each call re-parses variant IDs and re-joins. + +2. **LD is loaded once but variant subsetting happens repeatedly** — `harmonize_twas()` loads the full region LD, then extracts per-gene submatrices. But in the cTWAS cell, LD is loaded *again* for the same chromosome. + +3. **The data structure is deeply nested** — `twas_weights_data[[molecular_id]]$weights[[context]]` requires `get_nested_element()` and `find_data()` utilities just to navigate. The same molecular_id/context/study keys appear at multiple levels. + +4. **Model selection happens in two places** — `pick_best_model()` inside `twas_pipeline()` selects the best method before TWAS, then `apply_method_fallback()` fixes bad selections after TWAS. xqtl-protocol had a third pass via `update_twas_method()`. + +5. **Weight scaling is tightly coupled to LD loading** — `scaled_weights = weights * sqrt(variance)` depends on allele frequencies and sample size from the LD ref panel, computed during `load_LD_matrix()`. + +6. **The actual TWAS math is ~10 lines** — the other ~600 lines in twas.R are data loading, harmonization, QC, and result formatting. From 35fc33911a7fdf40b66ed80f0d109d9b65b5fd19 Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Tue, 12 May 2026 12:16:38 -0700 Subject: [PATCH 03/11] file cleanup --- twas-ld-sketch-plan.md | 142 -------------------- twas-pipeline-analysis.md | 78 ----------- xqtl-protocol-pecotmr-audit.md | 232 --------------------------------- 3 files changed, 452 deletions(-) delete mode 100644 twas-ld-sketch-plan.md delete mode 100644 twas-pipeline-analysis.md delete mode 100644 xqtl-protocol-pecotmr-audit.md diff --git a/twas-ld-sketch-plan.md b/twas-ld-sketch-plan.md deleted file mode 100644 index 7f1068b9..00000000 --- a/twas-ld-sketch-plan.md +++ /dev/null @@ -1,142 +0,0 @@ -# TWAS Pipeline Refactoring: LD Sketch + SVD - -## Context - -The current TWAS pipeline loads precomputed LD correlation matrices R from independent LD blocks (or computes R from genotype files), then computes `z_twas = (wᵀz) / √(wᵀRw)`. This requires aligning gene windows with LD block boundaries, concatenating blocks, deduplicating boundary variants, and extracting per-gene submatrices — the dominant source of complexity in twas.R. - -The refactoring replaces R with an SVD of "LD sketch" genotypes — random projections of the reference panel that preserve LD structure. Each gene independently defines its own window, loads the sketch for that window, and computes the TWAS statistic via: - -``` -z_twas = (wᵀz) / ‖Λ^(1/2) Vᵀw‖ where Λ = D²/(n_sketch - 1) -``` - -This eliminates the LD block alignment problem entirely. - ---- - -## Implementation Steps - -### Step 1: Thin wrapper `load_ld_sketch()` in R/LD.R - -Add after `load_LD_from_genotype()` (~line 480). This is a thin wrapper around the existing `load_LD_matrix()` with `return_genotype = TRUE` — it reuses all existing genotype loading, variant ID normalization, ref_panel construction (allele frequencies, variance), and duplicate removal. The only new logic is standardizing X and computing SVD. - -```r -#' @export -load_ld_sketch <- function(ld_meta_file_path, region, n_sample = NULL) -``` - -Logic: -1. Call `load_LD_matrix(ld_meta_file_path, region, return_genotype = TRUE, n_sample = n_sample)` — this returns the genotype matrix X, ref_panel (with allele_freq), LD_variants, etc. -2. Standardize X using HWE-based formula with `p = ref_panel$allele_freq`: - - Center: `X_c = sweep(X, 2, 2*p)` (subtract theoretical mean `2p`) - - Scale: `X_std = sweep(X_c, 2, sqrt(2*p*(1-p)), "/")` (divide by theoretical SD) - - Drop zero-variance columns where `p == 0` or `p == 1` (and corresponding ref_panel rows / variant_ids) -3. `safe_svd(X_std, tol = 0)` — full SVD, no truncation -4. Return: `list(V = svd$v, D = svd$d, n_sketch = nrow(X), ref_panel = result$ref_panel, variant_ids = result$LD_variants)` - -Key: `n_sketch` = nrow(X_sketch) used for Λ normalization. `n_sample` = original panel size, passed through to `load_LD_matrix()` for variance computation. - -### Step 2: Modify `twas_z()` in R/twas.R (lines 653-667) - -New signature: -```r -twas_z <- function(weights, z, R = NULL, X = NULL, V = NULL, D = NULL, n_sketch = NULL) -``` - -Add SVD branch before existing R/X branches: -- If V, D, n_sketch provided: `Lambda = D^2 / (n_sketch - 1)`, `denom = sum(Lambda * (crossprod(V, weights))^2)`, `zscore = stat / sqrt(denom)` -- Existing R and X paths remain unchanged - -### Step 3: Modify `twas_analysis()` in R/twas.R (lines 747-773) - -New signature: -```r -twas_analysis <- function(weights_matrix, gwas_sumstats_db, LD_matrix = NULL, - extract_variants_objs, V = NULL, D = NULL, - n_sketch = NULL, ld_variant_ids = NULL) -``` - -SVD path: subset V rows via `match(valid_variants_objs, ld_variant_ids)`, pass V_subset/D/n_sketch to `twas_z()`. Existing LD_matrix path remains. - -### Step 4: Refactor `harmonize_twas()` in R/twas.R (lines 30-226) - -**Remove:** -- `group_contexts_by_region()` inner function (lines 33-91) -- Region-wide `load_LD_matrix()` call (lines 106-115) -- Per-gene LD submatrix extraction (lines 211-220) -- Context-grouping loop structure - -**New per-gene loop:** -``` -for (molecular_id in molecular_ids): - 1. Collect variant positions across all contexts → build gene window region - 2. load_ld_sketch(ld_meta_file_path, gene_region, n_sample) — internally calls load_LD_matrix(return_genotype=TRUE), standardizes, SVDs - 3. Use sketch$variant_ids as reference for allele QC - - for (study in gwas_studies): - 4. harmonize_gwas(gwas_file, gene_region, sketch$variant_ids, ...) - - for (context in contexts): - 5. match_ref_panel(weights, sketch$variant_ids, ...) - 6. adjust_susie_weights() if needed - 7. Scale weights by sqrt(ref_panel$variance) - - 8. Store sketch SVD: mol_res[["svd_V"]], [["svd_D"]], [["n_sketch"]], [["ld_variant_ids"]] -``` - -Return value changes: `mol_res[["LD"]]` (p×p matrix) → `mol_res[["svd_V"]]`, `[["svd_D"]]`, `[["n_sketch"]]`, `[["ld_variant_ids"]]` - -### Step 5: Update `twas_pipeline()` in R/twas.R - -- Lines 529-532: Pass SVD components to `twas_analysis()` instead of `[["LD"]]` -- Line 569: Clear SVD components instead of LD matrix after use - -### Step 6: Tests - -New tests in a new file `tests/testthat/test_twas_sketch.R`: - -1. **Mathematical equivalence test**: Generate X, compute R=cor(X), compute SVD(X_std). Run `twas_z()` both ways, assert identical z-scores within floating-point tolerance. -2. **`load_ld_sketch()` unit test**: Mock `load_LD_matrix()` with known X. Verify SVD standardization, zero-variance column removal, ref_panel passthrough. -3. **`twas_analysis()` SVD path**: Verify partial variant overlap works correctly. -4. **End-to-end mock test**: Full pipeline with SVD path produces correct results. - -### Step 7: NAMESPACE and documentation - -Run `devtools::document()` after adding roxygen to `load_ld_sketch()` and updating param docs on modified functions. - ---- - -## Files to modify - -| File | Changes | -|------|---------| -| `R/LD.R` | Add `load_ld_sketch()` (thin wrapper around existing `load_LD_matrix(return_genotype=TRUE)` + SVD) | -| `R/twas.R` | Modify `twas_z()`, `twas_analysis()`, `harmonize_twas()`, `twas_pipeline()`. Remove `group_contexts_by_region()` | -| `tests/testthat/test_twas_sketch.R` | New test file | -| `NAMESPACE` | Auto-generated via roxygen | - -## Files NOT modified - -| File | Reason | -|------|--------| -| `R/LD.R` (load_LD_matrix, load_LD_from_genotype, load_LD_from_blocks) | Reused as-is; `load_ld_sketch()` calls `load_LD_matrix()` | -| `R/allele_qc.R` | match_ref_panel unchanged | -| `R/file_utils.R` | load_genotype_region unchanged | -| `R/misc.R` | safe_svd, compute_LD unchanged | -| `R/susie_wrapper.R` | adjust_susie_weights unchanged | -| `R/ctwas_wrapper.R` | cTWAS deferred | -| xqtl-protocol repo | Not changed in this PR | - -## Existing functions reused (not reimplemented) - -- `load_LD_matrix()` — R/LD.R:265 — main LD loader; `load_ld_sketch()` calls this with `return_genotype=TRUE` -- `load_LD_from_genotype()` — R/LD.R:408 — called by `load_LD_matrix()` for genotype sources; handles genotype loading, variant ID normalization, ref_panel with allele frequencies/variance, .afreq sidecar -- `safe_svd()` — R/misc.R:156 — SVD with tolerance filtering -- `match_ref_panel()` — R/allele_qc.R:26 — allele QC (reference = sketch variants) -- `ensure_chr_match()` — R/misc.R:480 — chr prefix harmonization - -## Verification - -1. Run `pixi run Rscript -e 'devtools::load_all("."); testthat::test_file("tests/testthat/test_twas_sketch.R")'` -2. Run existing tests to verify no regressions: `pixi run Rscript -e 'devtools::load_all("."); testthat::test_file("tests/testthat/test_twas_method_fallback.R")'` -3. Mathematical equivalence: the key test generates a genotype matrix, computes TWAS z-scores via both R and SVD paths, asserts they match within 1e-10. diff --git a/twas-pipeline-analysis.md b/twas-pipeline-analysis.md deleted file mode 100644 index 28bf19af..00000000 --- a/twas-pipeline-analysis.md +++ /dev/null @@ -1,78 +0,0 @@ -# TWAS Pipeline Analysis - -## End-to-end data flow - -``` -Weight RDS files ──→ load_twas_weights() ──┐ -GWAS TSV (tabix) ──→ harmonize_gwas() ──┤──→ harmonize_twas() ──→ twas_analysis() ──→ results -LD metadata TSV ──→ load_LD_matrix() ──┘ │ │ - │ ↓ - match_ref_panel() z = (wᵀz) / √(wᵀRw) - (allele QC engine) -``` - -## The pipeline has 3 layers - -### Layer 1: `twas_pipeline()` — orchestrator (R/twas.R:326-632) - -The main entry point. It: -1. Optionally filters molecular events via `event_filters` -2. Calls `harmonize_twas()` to load + QC everything -3. For each gene: picks the best model by CV R² (via `pick_best_model()`) -4. For each gene-context-study triple: calls `twas_analysis()` to compute z-scores -5. Optionally runs MR analysis if SuSiE credible sets exist and p-value is small enough -6. Merges CV metrics with TWAS results into a single table -7. Applies `apply_method_fallback()` for NA/Inf z-scores -8. Optionally formats output for cTWAS via `format_twas_data()` - -### Layer 2: `harmonize_twas()` — data harmonization (R/twas.R:30-226) - -The most complex function. It: -1. **Groups contexts by genomic position** — contexts whose weight variants are within 5kb of each other get clustered so they share a single LD query region -2. **Loads LD once** for the combined region via `load_LD_matrix()` -3. For each GWAS study: calls `harmonize_gwas()` which loads data via tabix, standardizes columns, then calls `match_ref_panel()` to align alleles to the LD reference -4. For each context: calls `match_ref_panel()` again to align the weight matrix to LD, then optionally `adjust_susie_weights()` to recalculate SuSiE weights for the variant subset -5. Scales weights by `sqrt(variance)` where variance comes from the LD ref panel -6. Extracts a per-gene LD submatrix for the intersection of GWAS + weight variants - -**The three-way intersection is the core issue** — every variant must be present in all three sources (weights, GWAS, LD) to participate, and alleles must be harmonized across all three. - -### Layer 3: `twas_analysis()` / `twas_z()` — the actual computation (R/twas.R:747-773, 653-667) - -This is tiny and straightforward: -- Subset everything to shared variants -- For each method: `z_twas = (wᵀ * z_gwas) / sqrt(wᵀ * R * w)` -- Return z-score and chi-squared p-value - -## Where the complexity lives - -| Area | What happens | Why it's complex | -|------|-------------|------------------| -| **Data loading** | `load_twas_weights()` merges multivariate (mnm_rs) + univariate weights from multiple RDS files | Deeply nested list structures, context name cleaning, weight alignment across methods | -| **Allele QC** | `match_ref_panel()` handles exact match, sign flip, strand flip, INDEL matching | 6+ boolean flags, inner join on (chrom, pos) then allele matching logic | -| **Context grouping** | `group_contexts_by_region()` clusters contexts by variant position overlap | Hierarchical clustering + IRanges interval merging | -| **SuSiE adjustment** | `adjust_susie_weights()` recalculates weights from log Bayes factors when variants are dropped | Re-derives alpha from LBF, recomputes posterior means | -| **Model selection** | Best model chosen by CV R², with fallback if z-score is NA/Inf | Two-pass: first in `pick_best_model()`, then `apply_method_fallback()` | - -## The xqtl-protocol adds another layer on top - -The `twas_ctwas.ipynb` notebook wraps `twas_pipeline()` with: -1. **Weight validation** — file size checks, tryCatch on readRDS (now ported into pecotmr) -2. **Batch loading** — `batch_load_twas_weights()` splits genes by memory -3. **`update_twas_method()`** — a SECOND method fallback pass after `twas_pipeline()` returns (the in-pecotmr `apply_method_fallback()` was ported from this) -4. **Context name merging** — `merge_context_names()` strips region suffixes from weight names -5. **cTWAS assembly** — ~200 lines loading LD again, re-running `harmonize_gwas()`, calling `trim_ctwas_variants()`, `assemble_region_data()`, then `est_param()`, `screen_regions()`, `finemap_regions()` - -## Key observations - -1. **`match_ref_panel()` is called 3+ times per gene** — once for GWAS vs LD, once for weights vs LD, and once inside `adjust_susie_weights()`. Each call re-parses variant IDs and re-joins. - -2. **LD is loaded once but variant subsetting happens repeatedly** — `harmonize_twas()` loads the full region LD, then extracts per-gene submatrices. But in the cTWAS cell, LD is loaded *again* for the same chromosome. - -3. **The data structure is deeply nested** — `twas_weights_data[[molecular_id]]$weights[[context]]` requires `get_nested_element()` and `find_data()` utilities just to navigate. The same molecular_id/context/study keys appear at multiple levels. - -4. **Model selection happens in two places** — `pick_best_model()` inside `twas_pipeline()` selects the best method before TWAS, then `apply_method_fallback()` fixes bad selections after TWAS. xqtl-protocol had a third pass via `update_twas_method()`. - -5. **Weight scaling is tightly coupled to LD loading** — `scaled_weights = weights * sqrt(variance)` depends on allele frequencies and sample size from the LD ref panel, computed during `load_LD_matrix()`. - -6. **The actual TWAS math is ~10 lines** — the other ~600 lines in twas.R are data loading, harmonization, QC, and result formatting. diff --git a/xqtl-protocol-pecotmr-audit.md b/xqtl-protocol-pecotmr-audit.md deleted file mode 100644 index f98c84e2..00000000 --- a/xqtl-protocol-pecotmr-audit.md +++ /dev/null @@ -1,232 +0,0 @@ -# xqtl-protocol pecotmr Usage Audit - -## 1. API Mismatches (Calls Expected to Fail) - -### A. `load_quantile_twas_weights` does not exist in pecotmr - -**File:** `code/pecotmr_integration/twas_ctwas.ipynb` (quantile_twas cell) - -```r -twas_weights_results[[gene_db]] <- load_quantile_twas_weights( - weight_db_files = weight_dbs, tau_values = tau_values, - between_cluster = 0.8, num_intervals = 3) -``` - -This function is not defined anywhere in pecotmr. It was likely planned as part of the quantile TWAS feature but never implemented (or was removed). This cell will fail outright. - -### B. `twas_pipeline` called with nonexistent `quantile_twas` parameter - -**File:** `code/pecotmr_integration/twas_ctwas.ipynb` (quantile_twas cell) - -```r -twas_results_db <- twas_pipeline(..., quantile_twas = TRUE, ...) -``` - -`twas_pipeline()` at `R/twas.R:290` has no `quantile_twas` parameter and no `...` in its signature. This will error with "unused argument." - -### C. `rss_analysis_pipeline` called with renamed parameter `stochastic_ld_sample` - -**File:** `code/mnm_analysis/mnm_methods/rss_analysis.ipynb` (univariate_rss cell) - -```r -rss_analysis_pipeline(..., stochastic_ld_sample = ${stochastic_ld_sample}, ...) -``` - -This parameter was renamed to `sketch_samples` in the current pecotmr API (`R/univariate_pipeline.R:214`). The function has no `...`, so this will error with "unused argument." - -### D. `load_multitrait_tensorqtl_sumstat` called with wrong parameter name - -**File:** `code/multivariate_genome/MASH/mash_preprocessing.ipynb` (random_null_tensorqtl_1 cell) - -```r -pecotmr::load_multitrait_tensorqtl_sumstat( - phenotype_path = phenotype_path, ..., na_remove = T/F) -``` - -Two problems: -- First parameter is named `sumstats_paths` in pecotmr (`R/mash_wrapper.R:141`), not `phenotype_path` -- The parameter `na_remove` was renamed to `nan_remove` (`R/mash_wrapper.R:143`) - -Both will cause "unused argument" errors since there's no `...`. - -### E. `mash_ran_null_sample` - typo and removed parameters - -**File:** `code/multivariate_genome/MASH/mash_preprocessing.ipynb` (random_null_tensorqtl_1 cell) - -```r -pecotmr::mash_ran_null_sample(dat, n_random, n_null, - expected_ncondition, exclude_condition, z_only = TRUE, seed = ...) -``` - -Three problems: -- Function name is `mash_rand_null_sample` (with a "d") -- `R/mash_wrapper.R:568` -- `expected_ncondition` parameter no longer exists -- `z_only` parameter no longer exists - -The current signature is `mash_rand_null_sample(dat, n_random, n_null, exclude_condition, seed = NULL)`. - -### F. `get_ctwas_meta_data` is deprecated - -**File:** `code/pecotmr_integration/twas_ctwas.ipynb` (ctwas_1 and ctwas_3 cells) - -Used extensively. Still works but emits deprecation warnings and will eventually be removed. The replacement is `ld_loader()` per `R/ctwas_wrapper.R:58`. - ---- - -## 2. Safety/Sanity Checks That Could Move to pecotmr - -### A. Weight file pre-validation - -**File:** twas_ctwas.ipynb (twas cell, quantile_twas cell) - -Before calling `load_twas_weights()`, xqtl-protocol: -- Checks `file.size(file) > 200` (non-trivial file) -- Wraps `readRDS(file)` in `tryCatch` to filter corrupt files -- Validates nested structure (`twas_variant_names` key exists) -- Filters out NULL/empty results - -**Recommendation:** `load_twas_weights()` should do this validation internally -- skip files that are too small, corrupt, or structurally invalid, rather than requiring every caller to implement the same filter. - -### B. NA/Inf z-score filtering after TWAS - -**File:** twas_ctwas.ipynb (ctwas_1 cell) - -```r -z_gene[[study]] <- z_gene[[study]][ - !is.na(z_gene[[study]]$z) & !is.infinite(z_gene[[study]]$z) & - z_gene[[study]]$id %in% names(weight_list[[study]]),] -``` - -**Recommendation:** `twas_pipeline()` or the TWAS z-score computation itself should guarantee clean output. Downstream consumers shouldn't need to re-filter. - -### C. Duplicate LD variant removal - -**File:** twas_ctwas.ipynb (ctwas_1 cell) - -```r -dup_idx <- which(duplicated(LD_list$LD_variants)) -if (length(dup_idx) >= 1) LD_list$LD_matrix <- LD_list$LD_matrix[-dup_idx, -dup_idx] -``` - -**Recommendation:** `load_LD_matrix()` should handle this internally. Duplicate variants in the LD matrix are a data integrity issue that the loader should resolve before returning. - -### D. GWAS sample size validation - -**File:** twas_ctwas.ipynb (ctwas_1 cell) - -```r -if(length(z_snp[['sample_size']][[study]]!=1) | z_snp[['sample_size']][[study]] <= 0) { - stop("Please check sample size provided for ", study, " at --gwas_meta_data. ") -} -``` - -**Recommendation:** Could be validated inside `harmonize_gwas()` or the GWAS metadata loading step in pecotmr. - -### E. chr prefix normalization - -Scattered across multiple locations: -- ctwas_1: `ifelse(grepl("^chr", snp_map$id), snp_map$id, paste0("chr", snp_map$id))` -- mnm_postprocessing: `if(any(grepl("chr", qtl_all_var))) add_chr_prefix(gwas_all_var) else gsub("chr", "", gwas_all_var)` -- mash_preprocessing: retry with `gsub("chr", "", region)` on failure - -**Recommendation:** pecotmr already has `normalize_variant_id()` and internal `strip_chr_prefix()`, but variant ID harmonization at the "chr" level should be consistently handled in all loading functions rather than requiring callers to do it. - -### F. Genomic region overlap detection - -**File:** SuSiE_enloc.ipynb (susie_coloc cell) - -Manual region parsing and overlap checking: -```r -split_region <- unlist(strsplit(region, "_")) -block_chrom <- as.numeric(split_region[1] %>% gsub("chr","",.)) -block_start <- ... -if (gene_region$chrom == block_chrom && - (gene_region$start <= block_end | gene_region$end >= block_start)) -``` - -**Recommendation:** pecotmr has `parse_region()` and `region_to_df()` but lacks a simple `regions_overlap(a, b)` utility. This pattern is repeated enough to justify one. - ---- - -## 3. Generalizable Pipeline Logic Worth Moving to pecotmr - -### A. TWAS method selection with fallback (high value) - -**File:** twas_ctwas.ipynb -- `update_twas_method()` function - -This ~40-line function handles a real problem: the "best" TWAS method (by cross-validation) sometimes produces NA/Inf results for a specific GWAS. It falls back to the next-best method by rsq. This logic is not xQTL-specific -- it applies to any TWAS analysis. - -**What it does:** For each gene-context-GWAS group, if the selected method yielded invalid z/p-values, pick the best alternative method that has valid results and meets the rsq threshold. - -### B. TWAS-to-cTWAS region assembly orchestration (high value) - -**File:** twas_ctwas.ipynb (ctwas_1 cell) - -The entire workflow of: -1. Loading per-region TWAS results -2. Trimming variants via `trim_ctwas_variants()` -3. Getting chromosome-wide LD variant info -4. Harmonizing GWAS via `harmonize_gwas()` -5. Re-computing TWAS z-scores when variants are trimmed (calling `twas_analysis()` with fresh LD) -6. Assembling into cTWAS region data via `assemble_region_data()` - -This is ~200 lines of orchestration that any TWAS-to-cTWAS pipeline would need. It currently depends on a few ctwas package functions but the overall flow is generalizable. - -### C. cTWAS fine-mapping with LD diagnosis and recovery (high value) - -**File:** twas_ctwas.ipynb (ctwas_3 cell) - -The workflow of: -1. Screen regions -> fine-map -> diagnose LD mismatch -> identify problematic genes -> re-fine-map without LD -> merge boundary regions - -This is a robust, production-tested recipe for dealing with real-world LD mismatches. It's not specific to xQTL data at all. - -### D. GWAS metadata loading and per-study LD caching (medium value) - -**Files:** rss_analysis.ipynb, twas_ctwas.ipynb - -The pattern of: -- Reading a GWAS metadata TSV with study_id, chrom, file_path, sample_size columns -- Mapping studies to per-region file paths -- Caching LD matrices by study to avoid re-loading - -This is boilerplate that every multi-study RSS analysis repeats. The Python `load_regional_rss_data()` function in rss_analysis.ipynb does something similar -- it could inform an R equivalent. - -### E. MASH data batching and merging (medium value) - -**File:** mash_preprocessing.ipynb (susie_to_mash_1, susie_to_mash_2 cells) - -The pipeline of: -1. Processing regions in chunks (`per_chunk = 100`) -2. Extracting strong/random/null z-score matrices per region -3. Renaming rownames with region IDs for uniqueness -4. Merging across regions with `merge_mash_data()` -5. Filtering invalid entries with `filter_invalid_summary_stat()` -6. Computing `ZtZ = t(Z) %*% Z / n` - -The batching/merging/filtering logic is generalizable. pecotmr already has `merge_mash_data()` and `filter_invalid_summary_stat()`, but the end-to-end orchestration (batch -> merge -> filter -> ZtZ) could be a single pipeline function. - -### F. QTL-GWAS overlap analysis pipeline (medium value, Python) - -**File:** mnm_postprocessing.ipynb (overlap_qtl_gwas cells, Python) - -Loads QTL and GWAS metadata, groups by chromosome, checks region overlap, intersects variant lists with chr-prefix harmonization. This is applicable to any pairwise colocalization setup, not just xQTL. - -### G. Variant feature engineering (lower value, Python) - -**File:** gems_pipeline.py - -Parsing "chr:pos:ref:alt" variant IDs into structured fields and classifying as SNP/indel/insertion/deletion. Simple but broadly useful. pecotmr has `parse_variant_id()` already but the SNP/indel classification could be an addition. - ---- - -## Summary - -| Category | Count | Severity | -|----------|-------|----------| -| Calls that will fail | 5 (A-E) | Blocking | -| Deprecated but still working | 1 (F) | Warning | -| Sanity checks to absorb | 6 (A-F) | Robustness | -| Generalizable pipeline logic | 7 (A-G) | Feature opportunities | - -The most impactful items are the **quantile TWAS breakage** (the entire quantile_twas workflow is dead -- both `load_quantile_twas_weights` and the `quantile_twas` param to `twas_pipeline` don't exist), the **`stochastic_ld_sample` rename**, and the **MASH parameter name changes**. For generalizable logic, the **TWAS method fallback** and **cTWAS assembly/diagnosis pipelines** are the highest-value candidates to move into pecotmr. From 1ee341c46ab15710837f511642445759867f6c74 Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Mon, 25 May 2026 13:14:23 -0500 Subject: [PATCH 04/11] summary pipeline refactor --- DESCRIPTION | 1 - NAMESPACE | 7 +- R/AllClasses.R | 10 +- R/AllMethods.R | 20 +- R/LD.R | 36 ++- R/encoloc.R | 29 ++- R/misc.R | 66 ----- R/mr.R | 1 + R/otters.R | 284 --------------------- R/regularized_regression.R | 86 +++++++ R/slalom.R | 32 ++- R/sumstats_qc.R | 15 +- R/susie_wrapper.R | 91 +++++-- R/twas.R | 192 +++++++++++--- R/twas_weights.R | 293 ++++++++++++++++++++++ R/univariate_pipeline.R | 35 ++- man/TWASWeights-class.Rd | 5 + man/TWASWeights.Rd | 12 +- man/classify_variant_type.Rd | 2 +- man/find_overlapping_regions.Rd | 2 +- man/fit_susie_inf_then_susie_rss.Rd | 41 +++ man/load_multitask_regional_data.Rd | 4 +- man/otters_association.Rd | 70 ------ man/otters_weights.Rd | 80 ------ man/postprocess_finemapping_fits.Rd | 3 +- man/regions_overlap.Rd | 2 +- man/susie_ash_rss_weights.Rd | 35 +++ man/susie_inf_rss_weights.Rd | 35 +++ man/susie_rss_weights.Rd | 35 +++ man/twas_analysis.Rd | 30 ++- man/twas_joint_z.Rd | 17 +- man/twas_weights_sumstat_pipeline.Rd | 78 ++++++ tests/testthat/test_encoloc.R | 2 +- tests/testthat/test_otters.R | 210 ---------------- tests/testthat/test_slalom.R | 4 +- tests/testthat/test_susie_wrapper.R | 10 +- tests/testthat/test_twas_weights_rss.R | 195 ++++++++++++++ tests/testthat/test_univariate_pipeline.R | 3 +- 38 files changed, 1252 insertions(+), 821 deletions(-) delete mode 100644 R/otters.R create mode 100644 man/fit_susie_inf_then_susie_rss.Rd delete mode 100644 man/otters_association.Rd delete mode 100644 man/otters_weights.Rd create mode 100644 man/susie_ash_rss_weights.Rd create mode 100644 man/susie_inf_rss_weights.Rd create mode 100644 man/susie_rss_weights.Rd create mode 100644 man/twas_weights_sumstat_pipeline.Rd delete mode 100644 tests/testthat/test_otters.R create mode 100644 tests/testthat/test_twas_weights_rss.R diff --git a/DESCRIPTION b/DESCRIPTION index 185b8f85..590b817a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -118,7 +118,6 @@ Collate: 'mrmash_wrapper.R' 'multitrait_data.R' 'multivariate_pipeline.R' - 'otters.R' 'pval_combine.R' 'raiss.R' 'regularized_regression.R' diff --git a/NAMESPACE b/NAMESPACE index 45f802ef..da26439a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -67,6 +67,7 @@ export(find_duplicate_variants) export(find_overlapping_regions) export(fine_mr) export(fit_mash_contrast) +export(fit_susie_inf_then_susie_rss) export(format_finemapping_output) export(fsusie_get_cs) export(fsusie_wrapper) @@ -152,8 +153,6 @@ export(multivariate_analysis_pipeline) export(mvsusie_weights) export(nSnps) export(normalize_variant_id) -export(otters_association) -export(otters_weights) export(parse_cs_corr) export(parse_region) export(parse_variant_id) @@ -189,9 +188,12 @@ export(standardise_sumstats_columns) export(standardize_sldsc_trait) export(subsetChr) export(summary_stats_qc) +export(susie_ash_rss_weights) export(susie_ash_weights) +export(susie_inf_rss_weights) export(susie_inf_weights) export(susie_rss_pipeline) +export(susie_rss_weights) export(susie_weights) export(top_loci_to_granges) export(trim_ctwas_variants) @@ -203,6 +205,7 @@ export(twas_predict) export(twas_weights) export(twas_weights_cv) export(twas_weights_pipeline) +export(twas_weights_sumstat_pipeline) export(twas_z) export(univariate_analysis_pipeline) export(update_mash_model_cov) diff --git a/R/AllClasses.R b/R/AllClasses.R index cfa9f58c..cbb1f83d 100644 --- a/R/AllClasses.R +++ b/R/AllClasses.R @@ -404,6 +404,10 @@ setClass("FineMappingResult", #' @slot fits Named list of model fit objects, or NULL. #' @slot cv_performance Named list of cross-validation performance #' metrics, or NULL. +#' @slot standardized Logical, whether weights are on standardized +#' (correlation) scale. If TRUE, \code{harmonize_twas} skips the +#' \code{sqrt(variance)} scaling step. Individual-level weights use +#' FALSE (raw genotype scale); RSS weights use TRUE. #' @export setClass("TWASWeights", representation( @@ -411,10 +415,13 @@ setClass("TWASWeights", variant_ids = "character", methods = "character", fits = "ANY", # list or NULL - cv_performance = "ANY" # list or NULL + cv_performance = "ANY", # list or NULL + standardized = "logical" ), validity = function(object) { errors <- character() + if (length(object@standardized) != 1L) + errors <- c(errors, "'standardized' must be a single logical value") if (length(object@methods) != length(object@weights)) errors <- c(errors, "Length of 'methods' must match length of 'weights'") @@ -576,6 +583,7 @@ setMethod("show", "TWASWeights", function(object) { cat(sprintf("TWASWeights: %d methods, %d variants\n", length(object@methods), length(object@variant_ids))) cat(sprintf(" Methods: %s\n", paste(object@methods, collapse = ", "))) + cat(sprintf(" Standardized: %s\n", object@standardized)) has_cv <- !is.null(object@cv_performance) cat(sprintf(" CV performance: %s\n", has_cv)) }) diff --git a/R/AllMethods.R b/R/AllMethods.R index cc24e9a6..bf827869 100644 --- a/R/AllMethods.R +++ b/R/AllMethods.R @@ -103,7 +103,7 @@ setMethod("getBlockMetadata", "LDData", function(x) { #' is_genotype. #' @keywords internal #' @noRd -ld_data_to_list <- function(x) { +ld_data_to_list <- function(x, skip_correlation = FALSE) { mc <- as.data.frame(mcols(x@variants)) mc$chrom <- as.character(seqnames(x@variants)) mc$pos <- start(x@variants) @@ -114,9 +114,17 @@ ld_data_to_list <- function(x) { bm <- as.data.frame(bm@blocks) } + LD_matrix <- if (skip_correlation) { + NULL + } else if (!is.null(x@correlation)) { + x@correlation + } else { + getCorrelation(x) + } + list( LD_variants = getVariantIds(x), - LD_matrix = if (!is.null(x@correlation)) x@correlation else getCorrelation(x), + LD_matrix = LD_matrix, ref_panel = ref_panel, block_metadata = bm, is_genotype = FALSE @@ -389,16 +397,20 @@ setMethod("getEffects", "FineMappingResult", function(x) { #' @param variant_ids Character vector. #' @param fits Named list or NULL. #' @param cv_performance Named list or NULL. +#' @param standardized Logical. If TRUE, weights are on the standardized +#' (correlation) scale and do not need variance scaling in harmonize_twas. +#' Defaults to FALSE (individual-level / raw genotype scale). #' @return A \code{TWASWeights} object. #' @export TWASWeights <- function(weights, variant_ids, fits = NULL, - cv_performance = NULL) { + cv_performance = NULL, standardized = FALSE) { new("TWASWeights", weights = weights, variant_ids = variant_ids, methods = names(weights), fits = fits, - cv_performance = cv_performance + cv_performance = cv_performance, + standardized = standardized ) } diff --git a/R/LD.R b/R/LD.R index f5df6dd2..c113fde2 100644 --- a/R/LD.R +++ b/R/LD.R @@ -284,14 +284,24 @@ load_LD_matrix <- function(LD_meta_file_path, region, extract_coordinates = NULL } # Remove any duplicate variant IDs (safety net for boundary overlaps) - if (!is.null(result$LD_variants)) { - dup_idx <- which(duplicated(result$LD_variants)) + variant_ids <- getVariantIds(result) + if (!is.null(variant_ids)) { + dup_idx <- which(duplicated(variant_ids)) if (length(dup_idx) > 0) { - result$LD_variants <- result$LD_variants[-dup_idx] - result$LD_matrix <- result$LD_matrix[-dup_idx, -dup_idx, drop = FALSE] - if (!is.null(result$ref_panel)) { - result$ref_panel <- result$ref_panel[-dup_idx, , drop = FALSE] + variant_ids_clean <- variant_ids[-dup_idx] + corr <- getCorrelation(result) + if (!is.null(corr)) { + corr <- corr[-dup_idx, -dup_idx, drop = FALSE] } + variants_gr <- result@variants[-dup_idx] + result <- LDData( + correlation = corr, + genotype_handle = result@genotype_handle, + snp_idx = result@snp_idx, + variants = variants_gr, + block_metadata = result@block_metadata, + n_ref = result@n_ref + ) } } @@ -527,9 +537,17 @@ standardize_genotype_hwe <- function(X, allele_freq) { #' @export load_ld_sketch <- function(ld_meta_file_path, region, n_sample = NULL) { result <- load_LD_matrix(ld_meta_file_path, region, return_genotype = TRUE, n_sample = n_sample) - X <- result$LD_matrix - variant_ids <- result$LD_variants - ref_panel <- result$ref_panel + if (is(result, "LDData")) { + X <- getGenotypes(result) + variant_ids <- getVariantIds(result) + ref_panel <- as.data.frame(S4Vectors::mcols(getVariantInfo(result))) + ref_panel$chrom <- as.character(GenomicRanges::seqnames(getVariantInfo(result))) + ref_panel$pos <- GenomicRanges::start(getVariantInfo(result)) + } else { + X <- result$LD_matrix + variant_ids <- result$LD_variants + ref_panel <- result$ref_panel + } # Remove monomorphic variants (zero variance under HWE) p <- ref_panel$allele_freq diff --git a/R/encoloc.R b/R/encoloc.R index 56206bf8..3eec07aa 100644 --- a/R/encoloc.R +++ b/R/encoloc.R @@ -114,18 +114,37 @@ extract_ld_for_variants <- function(ld_meta_file_path, analysis_region, variants var_pos <- as.numeric(str_split(variants, ":", simplify = TRUE)[, 2]) chr <- str_split(analysis_region, ":", simplify = TRUE)[, 1] region_narrow <- paste0(chr, ":", min(var_pos), "-", max(var_pos)) - ld_data <- load_LD_matrix(ld_meta_file_path, region = region_narrow) + ld_data <- load_LD_matrix(ld_meta_file_path, region = region_narrow, + return_genotype = "auto") # Support both LDData S4 objects and legacy lists if (is(ld_data, "LDData")) { ld_variants <- getVariantIds(ld_data) - ld_matrix <- getCorrelation(ld_data) + has_geno <- hasGenotypes(ld_data) } else { ld_variants <- ld_data$LD_variants - ld_matrix <- ld_data$LD_matrix + has_geno <- isTRUE(ld_data$is_genotype) } aligned <- align_variant_names(ld_variants, variants) - colnames(ld_matrix) <- rownames(ld_matrix) <- aligned$aligned_variants - ld_matrix[variants, variants] + # When genotypes available, compute R only for the needed variant subset + if (has_geno) { + if (is(ld_data, "LDData")) { + X <- getGenotypes(ld_data) + } else { + X <- ld_data$LD_matrix + } + colnames(X) <- aligned$aligned_variants + X_sub <- X[, variants, drop = FALSE] + ld_matrix <- compute_LD(X_sub, method = "sample") + } else { + if (is(ld_data, "LDData")) { + ld_matrix <- getCorrelation(ld_data) + } else { + ld_matrix <- ld_data$LD_matrix + } + colnames(ld_matrix) <- rownames(ld_matrix) <- aligned$aligned_variants + ld_matrix <- ld_matrix[variants, variants] + } + ld_matrix } #' Function to calculate purity diff --git a/R/misc.R b/R/misc.R index ae6db864..a1d264a6 100644 --- a/R/misc.R +++ b/R/misc.R @@ -560,72 +560,6 @@ find_data <- function(x, depth_obj, show_path = FALSE, rm_null = TRUE, rm_dup = } -#' Convert region specifications to a GRanges object -#' -#' Accepts region strings ("chr1:100-200", "1_100_200"), character vectors of -#' such strings, or data.frames with chrom/start/end columns. Returns a -#' \code{\link[GenomicRanges]{GRanges}} object. -#' -#' @param regions A region string, character vector, or data.frame with -#' chrom/start/end columns. -#' @return A \code{GRanges} object. -#' @noRd -as_granges <- function(regions) { - if (is.character(regions)) { - df <- region_to_df(regions) - } else if (is.data.frame(regions)) { - if (!all(c("chrom", "start", "end") %in% names(regions))) { - stop("data.frame must have columns: chrom, start, end") - } - df <- regions - } else { - stop("regions must be a character vector or data.frame with chrom/start/end columns") - } - # GRanges expects character seqnames; prefix with "chr" if numeric - seqnames <- as.character(df$chrom) - if (!any(grepl("^chr", seqnames))) { - seqnames <- paste0("chr", seqnames) - } - GenomicRanges::GRanges( - seqnames = seqnames, - ranges = IRanges::IRanges(start = as.integer(df$start), end = as.integer(df$end)) - ) -} - -#' Test whether two genomic regions overlap -#' -#' @param region_a A region string ("chr1:100-200" or "1_100_200") or a -#' single-row data.frame with chrom/start/end columns. -#' @param region_b A region string or single-row data.frame. -#' @return Logical scalar: TRUE if the regions share at least one base pair. -#' @importFrom GenomicRanges GRanges -#' @importFrom IRanges IRanges findOverlaps -#' @export -regions_overlap <- function(region_a, region_b) { - gr_a <- as_granges(region_a) - gr_b <- as_granges(region_b) - length(IRanges::findOverlaps(gr_a, gr_b)) > 0 -} - -#' Find which target regions overlap a query region -#' -#' @param query A single region string or single-row data.frame with -#' chrom/start/end columns. -#' @param targets A character vector of region strings, or a multi-row -#' data.frame with chrom/start/end columns. -#' @return Integer vector of 1-based indices into \code{targets} that overlap -#' the query. Empty integer vector if no overlaps. -#' @importFrom GenomicRanges GRanges -#' @importFrom IRanges IRanges findOverlaps -#' @importFrom S4Vectors subjectHits -#' @export -find_overlapping_regions <- function(query, targets) { - gr_query <- as_granges(query) - gr_targets <- as_granges(targets) - hits <- IRanges::findOverlaps(gr_query, gr_targets) - unique(S4Vectors::subjectHits(hits)) -} - thisFile <- function() { cmdArgs <- commandArgs(trailingOnly = FALSE) needle <- "--file=" diff --git a/R/mr.R b/R/mr.R index 6d3a79a0..fba96e7c 100644 --- a/R/mr.R +++ b/R/mr.R @@ -98,6 +98,7 @@ mr_format <- function(susie_result, condition, gwas_sumstats_db, coverage = NULL susie_cs_result_formatted <- susie_cs_result_formatted$target_data_qced[, c("gene_name", "variant_id", "bhat_x", "sbhat_x", "cs", "pip")] } # Ensure consistent chr prefix convention before intersecting + if (nrow(susie_cs_result_formatted) == 0) return(.create_null_mr_df(gene_name, mr_format_spec)) if (!is.null(susie_cs_result_formatted$variant_id) && !is.null(gwas_sumstats_db_extracted$variant_id)) { chr_matched <- ensure_chr_match(susie_cs_result_formatted$variant_id, gwas_sumstats_db_extracted$variant_id) susie_cs_result_formatted$variant_id <- chr_matched$ids_a diff --git a/R/otters.R b/R/otters.R deleted file mode 100644 index 70afe39b..00000000 --- a/R/otters.R +++ /dev/null @@ -1,284 +0,0 @@ -#' Train eQTL weights using multiple RSS methods (OTTERS Stage I) -#' -#' Implements the training stage of the OTTERS framework (Omnibus Transcriptome -#' Test using Expression Reference Summary data, Zhang et al. 2024). Trains -#' eQTL effect size weights for a gene region using multiple summary-statistics-based -#' methods in parallel, enabling downstream omnibus TWAS testing. -#' -#' Methods are dispatched dynamically via \code{do.call(paste0(method, "_weights"), ...)}, -#' so any function following the \code{*_weights(stat, LD, ...)} convention can be used -#' (e.g., \code{lassosum_rss_weights}, \code{prs_cs_weights}, \code{sdpr_weights}, -#' \code{mr_ash_rss_weights}). -#' -#' P+T (pruning and thresholding) is handled internally: for each threshold, SNPs with -#' eQTL p-value below the threshold are selected, and their marginal z-scores (scaled -#' to correlation units: \code{z / sqrt(n)}) are used as weights. -#' -#' @param sumstats A data.frame of eQTL summary statistics. Must contain column \code{z} -#' (z-scores). If \code{z} is absent but \code{beta} and \code{se} are present, -#' z-scores are computed as \code{beta / se}. -#' @param LD LD correlation matrix R for the gene region (single matrix, not a list). -#' Should have row/column names matching variant identifiers if variant alignment -#' is desired. -#' @param n eQTL study sample size (scalar). -#' @param methods Named list of RSS methods and their extra arguments. Each element -#' name must correspond to a \code{*_weights} function in pecotmr (without the -#' \code{_weights} suffix). Defaults match the original OTTERS pipeline -#' (Zhang et al. 2024): -#' \itemize{ -#' \item \code{lassosum_rss}: s grid = c(0.2, 0.5, 0.9, 1.0), lambda from -#' 0.0001 to 0.1 (20 values on log scale) -#' \item \code{prs_cs}: phi = 1e-4 (fixed, not learned), 1000 iterations, -#' 500 burn-in, thin = 5 -#' \item \code{sdpr}: 1000 iterations, 200 burn-in, thin = 1 (no thinning) -#' } -#' To add learners (e.g., \code{mr_ash_rss}), simply append to this list. -#' @param p_thresholds Numeric vector of p-value thresholds for P+T. Set to -#' \code{NULL} to skip P+T. Default: \code{c(0.001, 0.05)}. -#' @param check_ld_method LD quality check method passed to \code{\link{check_ld}}. -#' Default \code{"eigenfix"} sets negative eigenvalues to zero (required for -#' PRS-CS Cholesky, matching OTTERS' SVD-based PD forcing). Set to \code{NULL} -#' to skip checking. -#' -#' @return A named list of weight vectors (one per method). Each vector has length -#' equal to \code{nrow(sumstats)}. P+T results are named \code{PT_}. -#' -#' @examples -#' set.seed(42) -#' n <- 500; p <- 20 -#' z <- rnorm(p, sd = 2) -#' R <- diag(p) -#' sumstats <- data.frame(z = z) -#' weights <- otters_weights(sumstats, R, n, -#' methods = list(lassosum_rss = list()), -#' p_thresholds = c(0.05)) -#' -#' @export -otters_weights <- function(sumstats, LD, n, - methods = list( - lassosum_rss = list(), - prs_cs = list(phi = 1e-4, - n_iter = 1000, n_burnin = 500, thin = 5), - sdpr = list(iter = 1000, burn = 200, thin = 1, verbose = FALSE) - ), - p_thresholds = c(0.001, 0.05), - check_ld_method = "eigenfix") { - # Check and optionally repair LD matrix quality - # PRS-CS requires positive-definite LD for Cholesky; OTTERS forces PD via SVD. - # Default "eigenfix" sets negative eigenvalues to 0 (susieR approach). - # Set to NULL to skip (e.g., if LD is known to be clean). - if (!is.null(check_ld_method)) { - ld_check <- check_ld(LD, method = check_ld_method) - if (ld_check$method_applied != "none") { - message(sprintf("check_ld: repaired LD via '%s' (min eigenvalue was %.2e, %d negative).", - ld_check$method_applied, ld_check$min_eigenvalue, ld_check$n_negative)) - } - LD <- ld_check$R - } - - # Compute z-scores if not present - if (is.null(sumstats$z)) { - if (!is.null(sumstats$beta) && !is.null(sumstats$se)) { - sumstats$z <- sumstats$beta / sumstats$se - } else { - stop("sumstats must have 'z' or ('beta' and 'se') columns.") - } - } - - p <- nrow(sumstats) - z <- sumstats$z - - # Build stat object for _weights() convention - b <- z / sqrt(n) - stat <- list(b = b, cor = b, z = z, n = rep(n, p)) - - results <- list() - - # --- P+T (Pruning and Thresholding) --- - if (!is.null(p_thresholds)) { - pvals <- pchisq(z^2, df = 1, lower.tail = FALSE) - for (thr in p_thresholds) { - selected <- pvals < thr - # Weights = clamped marginal correlation (stat$b) for selected SNPs - w <- ifelse(selected, stat$b, 0) - results[[paste0("PT_", thr)]] <- w - } - } - - # --- RSS methods --- - for (method_name in names(methods)) { - fn_name <- paste0(method_name, "_weights") - if (!exists(fn_name, mode = "function")) { - warning(sprintf("Method '%s' not found (looking for function '%s'). Skipping.", - method_name, fn_name)) - next - } - tryCatch({ - w <- do.call(fn_name, c(list(stat = stat, LD = LD), methods[[method_name]])) - results[[method_name]] <- as.numeric(w) - }, error = function(e) { - warning(sprintf("Method '%s' failed: %s", method_name, e$message)) - results[[method_name]] <<- rep(0, p) - }) - } - - results -} - - -#' TWAS association testing with omnibus combination (OTTERS Stage II) -#' -#' Computes per-method TWAS z-scores using the FUSION formula and combines -#' p-values across methods via ACAT (Aggregated Cauchy Association Test) or -#' HMP (Harmonic Mean P-value). -#' -#' The FUSION TWAS statistic (Gusev et al. 2016) is: -#' \deqn{Z_{TWAS} = \frac{w^T z}{\sqrt{w^T R w}}} -#' where \eqn{w} are eQTL weights, \eqn{z} are GWAS z-scores, and \eqn{R} -#' is the LD correlation matrix. -#' -#' @param weights Named list of weight vectors (output from \code{\link{otters_weights}} -#' or any named list of numeric vectors). -#' @param gwas_z Numeric vector of GWAS z-scores, same length and order as the -#' weights vectors. Must be aligned to the same variants and allele orientation -#' as the weights and LD matrix. Use \code{\link{allele_qc}} or -#' \code{\link{rss_basic_qc}} for harmonization before calling this function. -#' @param LD LD correlation matrix R, aligned to the same variants as weights -#' and gwas_z. -#' @param combine_method Method to combine p-values across methods. -#' Correlation-free (valid under arbitrary dependence): -#' \code{"acat"} (default), \code{"hmp"}. -#' Correlation-adjusted via poolr (generalized multivariate theory): -#' \code{"fisher"} (Brown's method), \code{"stouffer"} (Strube's method), -#' \code{"invchisq"}. -#' Set-based tests via GBJ (uses TWAS z-scores and inter-method correlation): -#' \code{"gbj"}, \code{"bj"}, \code{"hc"}, \code{"ghc"}, \code{"minp"}, -#' \code{"gbj_omni"}. -#' Adaptive and Simes-type tests via aSPU: -#' \code{"aspu"} (adaptive sum of powered scores), -#' \code{"gates"} (extended Simes / GATES). -#' The poolr, GBJ, and aSPU methods automatically compute the inter-method -#' TWAS z-score correlation from the weight vectors and LD matrix. -#' -#' @return A data.frame with columns: -#' \describe{ -#' \item{method}{Method name (per-method rows plus a combined row).} -#' \item{twas_z}{TWAS z-score (\code{NA} for combined row).} -#' \item{twas_pval}{TWAS p-value.} -#' \item{n_snps}{Number of non-zero weight SNPs used.} -#' } -#' -#' @examples -#' set.seed(42) -#' p <- 20 -#' gwas_z <- rnorm(p) -#' R <- diag(p) -#' weights <- list(method1 = rnorm(p, sd = 0.01), method2 = rnorm(p, sd = 0.01)) -#' otters_association(weights, gwas_z, R) -#' -#' @export -otters_association <- function(weights, gwas_z, LD, - combine_method = c("acat", "hmp", - "fisher", "stouffer", "invchisq", - "gbj", "bj", "hc", "ghc", - "minp", "gbj_omni", - "aspu", "gates")) { - combine_method <- match.arg(combine_method) - - # Validate dimensions - p <- length(gwas_z) - if (nrow(LD) != p || ncol(LD) != p) { - stop(sprintf("LD dimensions (%d x %d) do not match gwas_z length (%d).", - nrow(LD), ncol(LD), p)) - } - for (nm in names(weights)) { - if (length(weights[[nm]]) != p) { - stop(sprintf("Weight vector '%s' has length %d but gwas_z has length %d.", - nm, length(weights[[nm]]), p)) - } - } - - results <- data.frame( - method = character(), - twas_z = numeric(), - twas_pval = numeric(), - n_snps = integer(), - stringsAsFactors = FALSE - ) - - valid_pvals <- c() - valid_zscores <- c() - valid_weights <- list() - - for (method_name in names(weights)) { - w <- weights[[method_name]] - - # Skip all-zero weights - if (all(w == 0)) { - results <- rbind(results, data.frame( - method = method_name, twas_z = NA_real_, - twas_pval = NA_real_, n_snps = 0L, - stringsAsFactors = FALSE - )) - next - } - - # Use non-zero SNPs - nz <- which(w != 0) - n_snps <- length(nz) - - # Compute TWAS z-score via twas_z() - res <- twas_z(weights = w[nz], z = gwas_z[nz], R = LD[nz, nz, drop = FALSE]) - - z_val <- as.numeric(res$z) - p_val <- as.numeric(res$pval) - - results <- rbind(results, data.frame( - method = method_name, twas_z = z_val, - twas_pval = p_val, n_snps = n_snps, - stringsAsFactors = FALSE - )) - - if (!is.na(p_val) && is.finite(p_val) && p_val > 0 && p_val < 1) { - valid_pvals <- c(valid_pvals, p_val) - valid_zscores <- c(valid_zscores, z_val) - valid_weights[[length(valid_weights) + 1]] <- w - } - } - - # Combine p-values across methods - if (length(valid_pvals) >= 2) { - poolr_methods <- c("fisher", "stouffer", "invchisq") - gbj_methods <- c("gbj", "bj", "hc", "ghc", "minp", "gbj_omni") - aspu_methods <- c("aspu", "gates") - needs_cor <- combine_method %in% c(poolr_methods, gbj_methods, aspu_methods) - - method_cor <- NULL - if (needs_cor) { - method_cor <- twas_method_cor(valid_weights, LD) - } - - combined_pval <- if (combine_method == "acat") { - pval_acat(valid_pvals) - } else if (combine_method == "hmp") { - pval_hmp(valid_pvals) - } else if (combine_method %in% poolr_methods) { - pval_poolr(valid_pvals, combine_method, R = method_cor) - } else if (combine_method %in% gbj_methods) { - pval_gbj(valid_zscores, method_cor, combine_method) - } else if (combine_method %in% aspu_methods) { - pval_aspu(z_scores = valid_zscores, pvals = valid_pvals, - R = method_cor, method = combine_method) - } - - results <- rbind(results, data.frame( - method = paste0(toupper(combine_method), "_combined"), - twas_z = NA_real_, - twas_pval = as.numeric(combined_pval), - n_snps = NA_integer_, - stringsAsFactors = FALSE - )) - } - - results -} diff --git a/R/regularized_regression.R b/R/regularized_regression.R index ef1502b1..614c3328 100644 --- a/R/regularized_regression.R +++ b/R/regularized_regression.R @@ -347,6 +347,92 @@ susie_inf_weights <- function(X = NULL, y = NULL, susie_inf_fit = NULL, retain_f retain_fit = retain_fit, ...) } +# ============================================================================= +# SuSiE-RSS weight functions +# ============================================================================= + +# Internal helper: extract weights from a susie_rss fit. +# Mirrors .susie_extract_weights but uses the RSS interface. +#' @importFrom susieR coef.susie susie_rss +#' @noRd +.susie_rss_extract_weights <- function(fit, z, R, n, + required_fields, fit_args = list(), + retain_fit = FALSE) { + if (is.null(fit)) { + fit <- do.call(susie_rss, c(list(z = z, R = R, n = n), fit_args)) + } + if (length(fit$pip) != nrow(R)) { + stop(paste0( + "Dimension mismatch: susie_rss fit has ", length(fit$pip), + " variants but R has ", nrow(R), " rows.")) + } + if (all(required_fields %in% names(fit))) { + fit$intercept <- 0 + weights <- coef.susie(fit)[-1] + } else { + weights <- rep(0, length(fit$pip)) + } + if (retain_fit) attr(weights, "fit") <- fit + return(weights) +} + +#' Compute SuSiE-RSS TWAS weights +#' +#' Extracts coefficients from an existing SuSiE-RSS fit or fits +#' \code{susieR::susie_rss()} from summary statistics and LD. +#' +#' @param stat List with components \code{z} (z-scores), \code{n} (sample sizes). +#' @param LD LD correlation matrix. +#' @param susie_rss_fit Optional pre-fitted SuSiE-RSS object. +#' @param retain_fit If TRUE, stores the fitted object as an attribute. +#' @param method_args Named list of additional arguments passed to +#' \code{susieR::susie_rss()}. Use this instead of \code{...} to avoid +#' partial matching of short argument names (e.g. \code{L}) to the +#' \code{LD} parameter. +#' @return Numeric vector of variant weights. +#' @export +susie_rss_weights <- function(stat, LD, susie_rss_fit = NULL, retain_fit = TRUE, + method_args = list()) { + .susie_rss_extract_weights(fit = susie_rss_fit, z = stat$z, R = LD, n = median(stat$n), + required_fields = c("alpha", "mu", "X_column_scale_factors"), + fit_args = method_args, + retain_fit = retain_fit) +} + +#' Compute SuSiE-inf-RSS TWAS weights +#' +#' Extracts coefficients from an existing SuSiE-inf-RSS fit or fits +#' \code{susieR::susie_rss()} with \code{unmappable_effects = "inf"}. +#' +#' @inheritParams susie_rss_weights +#' @param susie_inf_rss_fit Optional pre-fitted SuSiE-inf-RSS object. +#' @return Numeric vector of variant weights. +#' @export +susie_inf_rss_weights <- function(stat, LD, susie_inf_rss_fit = NULL, retain_fit = TRUE, + method_args = list()) { + .susie_rss_extract_weights(fit = susie_inf_rss_fit, z = stat$z, R = LD, n = median(stat$n), + required_fields = c("alpha", "mu", "theta", "X_column_scale_factors"), + fit_args = c(list(unmappable_effects = "inf", convergence_method = "pip"), method_args), + retain_fit = retain_fit) +} + +#' Compute SuSiE-ASH-RSS TWAS weights +#' +#' Extracts coefficients from an existing SuSiE-ASH-RSS fit or fits +#' \code{susieR::susie_rss()} with \code{unmappable_effects = "ash"}. +#' +#' @inheritParams susie_rss_weights +#' @param susie_ash_rss_fit Optional pre-fitted SuSiE-ASH-RSS object. +#' @return Numeric vector of variant weights. +#' @export +susie_ash_rss_weights <- function(stat, LD, susie_ash_rss_fit = NULL, retain_fit = TRUE, + method_args = list()) { + .susie_rss_extract_weights(fit = susie_ash_rss_fit, z = stat$z, R = LD, n = median(stat$n), + required_fields = c("alpha", "mu", "theta", "X_column_scale_factors"), + fit_args = c(list(unmappable_effects = "ash", convergence_method = "pip"), method_args), + retain_fit = retain_fit) +} + #' Compute mr.mash TWAS weights #' #' Extracts coefficients from an existing mr.mash fit or fits mr.mash from `X` and `Y`. diff --git a/R/slalom.R b/R/slalom.R index 24ce5395..6ea94d72 100644 --- a/R/slalom.R +++ b/R/slalom.R @@ -36,13 +36,11 @@ slalom <- function(zScore, R = NULL, X = NULL, standard_error = rep(1, length(zS abf_prior_variance = 0.04, nlog10p_dentist_s_threshold = 4.0, r2_threshold = 0.6, lead_variant_choice = "pvalue", ld_method = "sample") { - # Resolve LD matrix from R or X - ld_resolved <- resolve_LD_input(R = R, X = X, need_nSample = FALSE, - ld_method = ld_method) - LD_mat <- ld_resolved$R - - if (!is.matrix(LD_mat) || nrow(LD_mat) != ncol(LD_mat) || nrow(LD_mat) != length(zScore)) { - stop("LD_mat must be a square matrix matching the length of zScore.") + if (is.null(R) && is.null(X)) { + stop("Either R (LD matrix) or X (genotype matrix) must be provided.") + } + if (!is.null(R) && !is.null(X)) { + stop("Provide either R or X, not both.") } # One-sided p-value matching the original Python implementation (stats.norm.cdf). @@ -85,13 +83,25 @@ slalom <- function(zScore, R = NULL, X = NULL, standard_error = rep(1, length(zS which.max(prob) } - r2 <- LD_mat^2 - t_dentist_s <- (zScore - LD_mat[, lead_idx] * zScore[lead_idx])^2 / (1 - r2[, lead_idx]) + # Only the lead column of R is needed for DENTIST-S. + # When X is provided, compute just that column instead of the full p x p matrix. + if (!is.null(X)) { + if (!is.matrix(X)) X <- as.matrix(X) + r_lead <- as.numeric(cor(X, X[, lead_idx])) + } else { + if (!is.matrix(R) || nrow(R) != ncol(R) || nrow(R) != length(zScore)) { + stop("R must be a square matrix matching the length of zScore.") + } + r_lead <- R[, lead_idx] + } + + r2_lead <- r_lead^2 + t_dentist_s <- (zScore - r_lead * zScore[lead_idx])^2 / (1 - r2_lead) t_dentist_s[t_dentist_s < 0] <- Inf nlog10p_dentist_s <- -log10(pchisq(t_dentist_s, df = 1, lower.tail = FALSE)) - outliers <- (r2[, lead_idx] > r2_threshold) & (nlog10p_dentist_s > nlog10p_dentist_s_threshold) + outliers <- (r2_lead > r2_threshold) & (nlog10p_dentist_s > nlog10p_dentist_s_threshold) - n_r2 <- sum(r2[, lead_idx] > r2_threshold) + n_r2 <- sum(r2_lead > r2_threshold) n_dentist_s_outlier <- sum(outliers, na.rm = TRUE) max_pip <- max(prob) diff --git a/R/sumstats_qc.R b/R/sumstats_qc.R index ca14e4fe..7ef522f6 100644 --- a/R/sumstats_qc.R +++ b/R/sumstats_qc.R @@ -250,9 +250,18 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, if (is.null(rss_input) && is.data.frame(sumstats) && is.null(qc_method) && isFALSE(impute) && identical(pip_cutoff_to_skip, 0) && is.null(skip_region)) { method <- match.arg(method) - # Extract LD matrix from either LDData S4 object or legacy list - LD_mat <- if (is(LD_data, "LDData")) getCorrelation(LD_data) else LD_data$LD_matrix - LD_extract <- LD_mat[sumstats$variant_id, sumstats$variant_id, drop = FALSE] + # When genotypes are available, compute R only for the needed variant subset + if (is(LD_data, "LDData") && hasGenotypes(LD_data)) { + X <- getGenotypes(LD_data) + vid <- getVariantIds(LD_data) + idx <- match(sumstats$variant_id, vid) + X_sub <- X[, idx[!is.na(idx)], drop = FALSE] + colnames(X_sub) <- sumstats$variant_id[!is.na(idx)] + LD_extract <- compute_LD(X_sub, method = "sample") + } else { + LD_mat <- if (is(LD_data, "LDData")) getCorrelation(LD_data) else LD_data$LD_matrix + LD_extract <- LD_mat[sumstats$variant_id, sumstats$variant_id, drop = FALSE] + } qc_results <- ld_mismatch_qc(zScore = sumstats$z, R = LD_extract, nSample = n, method = method) keep_index <- which(!qc_results$outlier) diff --git a/R/susie_wrapper.R b/R/susie_wrapper.R index 357e9c25..3c45357a 100644 --- a/R/susie_wrapper.R +++ b/R/susie_wrapper.R @@ -148,6 +148,51 @@ fit_susie_inf_then_susie <- function(X, y, args = list(), list(susie = susie_fit, susie_inf = susie_inf_fit) } +#' Two-stage SuSiE-RSS Fine-mapping +#' +#' RSS analog of \code{fit_susie_inf_then_susie}. Fits SuSiE-inf via +#' \code{susie_rss} first, then initialises standard SuSiE-RSS from +#' the SuSiE-inf result. The single pair of fits can be used both for +#' fine-mapping post-processing and TWAS weight extraction. +#' +#' @param z Numeric vector of z-scores. +#' @param R LD correlation matrix. +#' @param n Sample size (scalar). +#' @param args Default arguments forwarded to both fits. +#' @param susie_inf_args SuSiE-inf-specific overrides. +#' @param susie_args Standard SuSiE-RSS-specific overrides. +#' @param fitted_models Optional list with pre-fitted \code{$susie} and/or +#' \code{$susie_inf} objects to skip re-fitting. +#' @return A list with \code{susie} and \code{susie_inf} fit objects. +#' @importFrom susieR susie_rss +#' @export +fit_susie_inf_then_susie_rss <- function(z, R, n, args = list(), + susie_inf_args = list(), + susie_args = list(), + fitted_models = NULL) { + if (is.null(fitted_models)) fitted_models <- list() + susie_inf_fit <- fitted_models[["susie_inf"]] + susie_fit <- fitted_models[["susie"]] + + if (is.null(susie_inf_fit)) { + fit_args <- modifyList(args, susie_inf_args) + fit_args <- modifyList(fit_args, list( + z = z, R = R, n = n, unmappable_effects = "inf", + convergence_method = "pip", refine = FALSE, model_init = NULL + )) + susie_inf_fit <- do.call(susie_rss, fit_args) + } + susie_inf_fit <- .set_finemapping_fit_class(susie_inf_fit, "susie_inf") + + if (is.null(susie_fit)) { + fit_args <- prepare_susie_from_inf_args(modifyList(args, susie_args), susie_inf_fit, refine_default = TRUE) + susie_fit <- do.call(susie_rss, c(list(z = z, R = R, n = n), fit_args)) + } + susie_fit <- .set_finemapping_fit_class(susie_fit, "susie_rss") + + list(susie = susie_fit, susie_inf = susie_inf_fit) +} + #' Post-process Fine-mapping Fits #' #' Applies method-aware post-processing to one or more SuSiE-family fits and @@ -182,7 +227,8 @@ postprocess_finemapping_fits <- function(fits, data_x, data_y = NULL, other_quantities = NULL, region = NULL, prior_eff_tol = 1e-9, - min_abs_corr = 0.8) { + min_abs_corr = 0.8, + cs_input = NULL) { fits <- fits[!vapply(fits, is.null, logical(1))] if (length(fits) == 0) stop("At least one fine-mapping fit must be supplied.") if (is.null(names(fits)) || any(names(fits) == "")) { @@ -200,7 +246,8 @@ postprocess_finemapping_fits <- function(fits, data_x, data_y = NULL, coverage = coverage, secondary_coverage = secondary_coverage, signal_cutoff = signal_cutoff, other_quantities = other_quantities, region = region, - prior_eff_tol = prior_eff_tol, min_abs_corr = min_abs_corr + prior_eff_tol = prior_eff_tol, min_abs_corr = min_abs_corr, + cs_input = cs_input ) }) names(posts) <- names(fits) @@ -229,28 +276,33 @@ postprocess_finemapping_fit <- function(fit, ...) { } #' @exportS3Method -postprocess_finemapping_fit.susie <- function(fit, method = "susie", ...) { - .postprocess_finemapping_fit_common(fit, method = method, cs_input = "X", ...) +postprocess_finemapping_fit.susie <- function(fit, method = "susie", cs_input = NULL, ...) { + if (is.null(cs_input)) cs_input <- "X" + .postprocess_finemapping_fit_common(fit, method = method, cs_input = cs_input, ...) } #' @exportS3Method -postprocess_finemapping_fit.susie_inf <- function(fit, method = "susie_inf", ...) { - .postprocess_finemapping_fit_common(fit, method = method, cs_input = "X", ...) +postprocess_finemapping_fit.susie_inf <- function(fit, method = "susie_inf", cs_input = NULL, ...) { + if (is.null(cs_input)) cs_input <- "X" + .postprocess_finemapping_fit_common(fit, method = method, cs_input = cs_input, ...) } #' @exportS3Method -postprocess_finemapping_fit.susie_rss <- function(fit, method = "susie_rss", ...) { - .postprocess_finemapping_fit_common(fit, method = method, cs_input = "Xcorr", ...) +postprocess_finemapping_fit.susie_rss <- function(fit, method = "susie_rss", cs_input = NULL, ...) { + if (is.null(cs_input)) cs_input <- "Xcorr" + .postprocess_finemapping_fit_common(fit, method = method, cs_input = cs_input, ...) } #' @exportS3Method -postprocess_finemapping_fit.mvsusie <- function(fit, method = "mvsusie", ...) { - .postprocess_finemapping_fit_common(fit, method = method, cs_input = "X", ...) +postprocess_finemapping_fit.mvsusie <- function(fit, method = "mvsusie", cs_input = NULL, ...) { + if (is.null(cs_input)) cs_input <- "X" + .postprocess_finemapping_fit_common(fit, method = method, cs_input = cs_input, ...) } #' @exportS3Method -postprocess_finemapping_fit.susiF <- function(fit, method = "fsusie", ...) { - .postprocess_finemapping_fit_common(fit, method = method, cs_input = "fsusie", ...) +postprocess_finemapping_fit.susiF <- function(fit, method = "fsusie", cs_input = NULL, ...) { + if (is.null(cs_input)) cs_input <- "fsusie" + .postprocess_finemapping_fit_common(fit, method = method, cs_input = cs_input, ...) } .postprocess_finemapping_fit_common <- function(fit, method, data_x, data_y = NULL, @@ -892,14 +944,18 @@ susie_rss_pipeline <- function(sumstats, LD_mat = NULL, X_mat = NULL, n = NULL, res <- do.call(susie_rss, c(common, list(L = L, L_greedy = L_greedy))) } - # For post-processing, need a square matrix (R or computed from X). - # For mixture panels (list of X), use the first panel to compute R. + # For post-processing, pass genotype matrix X directly when available. + # susie_get_cs(fit, X=...) computes correlations only for CS variants, + # avoiding the full p x p R matrix. if (!is.null(LD_mat)) { data_x <- LD_mat + pp_cs_input <- "Xcorr" } else if (is.list(X_mat) && !is.matrix(X_mat)) { - data_x <- compute_LD(X_mat[[1]][, seq_along(z), drop = FALSE], method = "sample") + data_x <- X_mat[[1]][, seq_along(z), drop = FALSE] + pp_cs_input <- "X" } else { - data_x <- compute_LD(X_mat[, seq_along(z), drop = FALSE], method = "sample") + data_x <- X_mat[, seq_along(z), drop = FALSE] + pp_cs_input <- "X" } rss_method <- analysis_method @@ -911,7 +967,8 @@ susie_rss_pipeline <- function(sumstats, LD_mat = NULL, X_mat = NULL, n = NULL, coverage = coverage, secondary_coverage = secondary_coverage, signal_cutoff = signal_cutoff, - min_abs_corr = min_abs_corr + min_abs_corr = min_abs_corr, + cs_input = pp_cs_input ) format_finemapping_output(post, primary_method = rss_method) } diff --git a/R/twas.R b/R/twas.R index 8dc8ad9d..f035e782 100644 --- a/R/twas.R +++ b/R/twas.R @@ -119,8 +119,16 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, mol_res[["variant_names"]][[context]][[study]] <- rownames(weights_matrix_subset) # Step 6: scale weights by variance (from sketch ref_panel) - variance <- sketch$ref_panel$variance[match(rownames(weights_matrix_subset), sketch$ref_panel$variant_id)] - mol_res[["weights_qced"]][[context]][[study]] <- list(scaled_weights = weights_matrix_subset * sqrt(variance), weights = weights_matrix_subset) + # RSS/standardized weights are already on the correlation scale and + # do not need sqrt(variance) scaling. + is_standardized <- isTRUE(mol_data[["standardized"]]) + if (is_standardized) { + scaled <- weights_matrix_subset + } else { + variance <- sketch$ref_panel$variance[match(rownames(weights_matrix_subset), sketch$ref_panel$variant_id)] + scaled <- weights_matrix_subset * sqrt(variance) + } + mol_res[["weights_qced"]][[context]][[study]] <- list(scaled_weights = scaled, weights = weights_matrix_subset) } # Combine GWAS sumstats for this study (filter to variants used by any context) used_variants <- unique(find_data(mol_res[["variant_names"]], c(2, study))) @@ -344,6 +352,15 @@ twas_pipeline <- function(twas_weights_data, } pick_best_model <- function(twas_data_combined, rsq_cutoff, rsq_pval_cutoff, rsq_option, rsq_pval_option) { best_rsq <- rsq_cutoff + # SS-TWAS path: no CV performance, all methods are valid + if (is.null(twas_data_combined$twas_cv_performance) || + length(twas_data_combined$twas_cv_performance) == 0) { + model_selection <- lapply(names(twas_data_combined$weights), function(context) { + list(selected_model = NA, is_imputable = TRUE, all_methods = TRUE) + }) + names(model_selection) <- names(twas_data_combined$weights) + return(model_selection) + } # Determine if a gene/region is imputable and select the best model model_selection <- lapply(names(twas_data_combined$weights), function(context) { selected_model <- NULL @@ -450,7 +467,9 @@ twas_pipeline <- function(twas_weights_data, if (length(twas_variants) == 0) { return(list(twas_rs_df = data.frame(), mr_rs_df = data.frame())) } - # twas analysis + # twas analysis -- enable omnibus when no CV performance available + has_cv <- !is.null(twas_weights_data[[weight_db]]$twas_cv_performance) && + length(twas_weights_data[[weight_db]]$twas_cv_performance) > 0 twas_rs <- twas_analysis( twas_data_qced[[weight_db]][["weights_qced"]][[context]][[study]][["weights"]], twas_data_qced[[weight_db]][["gwas_qced"]][[study]], @@ -458,7 +477,8 @@ twas_pipeline <- function(twas_weights_data, V = twas_data_qced[[weight_db]][["svd_V"]], D = twas_data_qced[[weight_db]][["svd_D"]], n_sketch = twas_data_qced[[weight_db]][["n_sketch"]], - ld_variant_ids = twas_data_qced[[weight_db]][["ld_variant_ids"]] + ld_variant_ids = twas_data_qced[[weight_db]][["ld_variant_ids"]], + combine_if_no_cv = !has_cv ) if (is.null(twas_rs)) { return(list(twas_rs_df = data.frame(), mr_rs_df = data.frame())) @@ -525,22 +545,39 @@ twas_pipeline <- function(twas_weights_data, contexts <- names(twas_weights_data[[molecular_id]]$weights) # merge twas_cv information for same gene across all weight db files, loop through each context for all methods gene_table <- do.call(rbind, lapply(contexts, function(context) { - methods <- sub("_[^_]+$", "", names(twas_weights_data[[molecular_id]]$twas_cv_performance[[context]])) - is_imputable <- twas_data[[molecular_id]][["model_selection"]][[context]]$is_imputable - selected_method <- twas_data[[molecular_id]][["model_selection"]][[context]]$selected_model - if (is.null(selected_method)) selected_method <- NA - is_selected_method <- ifelse(methods == selected_method, TRUE, FALSE) - - cv_rsqs <- sapply(twas_weights_data[[molecular_id]]$twas_cv_performance[[context]], function(x) x[, rsq_option]) - cv_pvals <- sapply(twas_weights_data[[molecular_id]]$twas_cv_performance[[context]], function(x) x[, colnames(x)[which(colnames(x) %in% rsq_pval_option)]]) - - context_table <- data.frame( - context = context, method = methods, - is_imputable = is_imputable, - is_selected_method = is_selected_method, - rsq_cv = cv_rsqs, pval_cv = cv_pvals, - type = twas_weights_data[[molecular_id]][["data_type"]][[context]] - ) + cv_perf <- twas_weights_data[[molecular_id]]$twas_cv_performance[[context]] + model_sel <- twas_data[[molecular_id]][["model_selection"]][[context]] + is_imputable <- if (!is.null(model_sel)) model_sel$is_imputable else TRUE + + if (is.null(cv_perf) || length(cv_perf) == 0) { + # SS-TWAS path: no CV, derive methods from weight matrix columns + wt_mat <- twas_weights_data[[molecular_id]]$weights[[context]] + methods <- if (is.matrix(wt_mat)) colnames(wt_mat) else names(wt_mat) + if (is.null(methods)) methods <- "unknown" + context_table <- data.frame( + context = context, method = methods, + is_imputable = is_imputable, + is_selected_method = FALSE, + rsq_cv = NA_real_, pval_cv = NA_real_, + type = twas_weights_data[[molecular_id]][["data_type"]][[context]] + ) + } else { + methods <- sub("_[^_]+$", "", names(cv_perf)) + selected_method <- if (!is.null(model_sel)) model_sel$selected_model else NA + if (is.null(selected_method)) selected_method <- NA + is_selected_method <- ifelse(methods == selected_method, TRUE, FALSE) + + cv_rsqs <- sapply(cv_perf, function(x) x[, rsq_option]) + cv_pvals <- sapply(cv_perf, function(x) x[, colnames(x)[which(colnames(x) %in% rsq_pval_option)]]) + + context_table <- data.frame( + context = context, method = methods, + is_imputable = is_imputable, + is_selected_method = is_selected_method, + rsq_cv = cv_rsqs, pval_cv = cv_pvals, + type = twas_weights_data[[molecular_id]][["data_type"]][[context]] + ) + } return(context_table) })) gene_table$molecular_id <- molecular_id @@ -616,6 +653,10 @@ twas_z <- function(weights, z, R = NULL, X = NULL, V = NULL, D = NULL, n_sketch #' #' @param R An optional correlation matrix. If not provided, it will be calculated from the genotype matrix X. #' @param X An optional genotype matrix. If R is not provided, X must be supplied to calculate the correlation matrix. +#' @param V Optional SVD right-singular vectors (variants x components) from an LD sketch. +#' When provided with \code{D_svd} and \code{n_sketch}, avoids forming the full LD matrix. +#' @param D_svd Optional SVD singular values (vector) from an LD sketch. +#' @param n_sketch Optional sample size of the LD sketch. #' @param weights A matrix of weights, where each column corresponds to a different condition. #' @param z A vector of GWAS z-scores. #' @@ -627,7 +668,8 @@ twas_z <- function(weights, z, R = NULL, X = NULL, V = NULL, D = NULL, n_sketch #' #' @importFrom stats cor pnorm #' @export -twas_joint_z <- function(weights, z, R = NULL, X = NULL) { +twas_joint_z <- function(weights, z, R = NULL, X = NULL, + V = NULL, D_svd = NULL, n_sketch = NULL) { # Make sure GBJ is installed if (!requireNamespace("GBJ", quietly = TRUE)) { stop("To use this function, please install GBJ: https://cran.r-project.org/web/packages/GBJ/index.html") @@ -636,10 +678,24 @@ twas_joint_z <- function(weights, z, R = NULL, X = NULL) { if (nrow(weights) != length(z)) { stop("Number of rows in weights must match the length of z-scores.") } - if (is.null(R)) R <- compute_LD(X) - idx <- which(rownames(R) %in% rownames(weights)) - D <- R[idx, idx] - cov_y <- crossprod(weights, D) %*% weights + + use_svd <- !is.null(V) && !is.null(D_svd) && !is.null(n_sketch) + + if (use_svd) { + # SVD path: R ≈ V diag(Lambda) V' where Lambda = D_svd²/(n_sketch-1) + Lambda <- D_svd^2 / (n_sketch - 1) + idx <- which(rownames(V) %in% rownames(weights)) + V_sub <- V[idx, , drop = FALSE] + # cov_y = weights' R_sub weights = weights' V_sub diag(Lambda) V_sub' weights + VtW <- crossprod(V_sub, weights) # r x k + cov_y <- crossprod(VtW * sqrt(Lambda)) # k x k + } else { + if (is.null(R)) R <- compute_LD(X) + idx <- which(rownames(R) %in% rownames(weights)) + D <- R[idx, idx] + cov_y <- crossprod(weights, D) %*% weights + } + y_sd <- sqrt(diag(cov_y)) x_sd <- rep(1, nrow(weights)) # Assuming X is standardized @@ -667,7 +723,15 @@ twas_joint_z <- function(weights, z, R = NULL, X = NULL) { la <- as.matrix(weights[, p] %*% g[[p]]) lam[p, ] <- la } - sig <- tcrossprod((lam %*% D), lam) + + if (use_svd) { + # sig = lam R_sub lam' = lam V_sub diag(Lambda) V_sub' lam' + LV <- lam %*% V_sub # k x r + sig <- tcrossprod(sweep(LV, 2, Lambda, "*"), LV) # k x k + } else { + sig <- tcrossprod((lam %*% D), lam) + } + gbj <- GBJ::GBJ(test_stats = z_matrix[, 1], cor_mat = sig) rs <- list("Z" = z_matrix, "GBJ" = gbj) @@ -680,16 +744,35 @@ twas_joint_z <- function(weights, z, R = NULL, X = NULL) { #' and LD matrix. It extracts the necessary GWAS summary statistics and LD matrix based on the #' specified variants and computes the z-score and p-value for each gene. #' +#' When \code{combine_if_no_cv = TRUE} and there are at least two methods with +#' valid p-values, an omnibus p-value is computed via the method specified in +#' \code{combine_method} and appended as an \code{"omnibus"} entry. This is +#' intended for summary-statistics TWAS where cross-validation performance is +#' not available for model selection. +#' #' @param weights_matrix A matrix containing weights for all methods. #' @param gwas_sumstats_db A data frame containing the GWAS summary statistics. #' @param LD_matrix A matrix representing linkage disequilibrium between variants. #' @param extract_variants_objs A vector of variant identifiers to extract from the GWAS and LD matrix. +#' @param V SVD right-singular vectors from LD sketch (optional). +#' @param D SVD singular values from LD sketch (optional). +#' @param n_sketch Sample size of LD sketch (optional). +#' @param ld_variant_ids Variant IDs in the LD sketch (optional). +#' @param combine_method P-value combination method: \code{"acat"} (default), +#' \code{"hmp"}, \code{"fisher"}, \code{"stouffer"}, \code{"invchisq"}, +#' \code{"gbj"}, \code{"aspu"}, or \code{"gates"}. +#' @param combine_if_no_cv Logical. If TRUE and no CV performance is available, +#' combine per-method p-values into an omnibus result. #' -#' @return A list with TWAS z-scores and p-values across four methods for each gene. +#' @return A list with TWAS z-scores and p-values across methods for each gene. +#' When omnibus combination is enabled, includes an additional \code{"omnibus"} +#' entry. #' @export twas_analysis <- function(weights_matrix, gwas_sumstats_db, LD_matrix = NULL, extract_variants_objs, V = NULL, D = NULL, - n_sketch = NULL, ld_variant_ids = NULL) { + n_sketch = NULL, ld_variant_ids = NULL, + combine_method = "acat", + combine_if_no_cv = FALSE) { # Extract gwas_sumstats gwas_sumstats_subset <- gwas_sumstats_db[match(extract_variants_objs, gwas_sumstats_db$variant_id), ] # Validate that the GWAS subset is not empty @@ -715,7 +798,8 @@ twas_analysis <- function(weights_matrix, gwas_sumstats_db, LD_matrix = NULL, as.matrix(weights_matrix), 2, function(x) twas_z(x, gwas_sumstats_subset$z, V = V_subset, D = D, n_sketch = n_sketch) ) - return(twas_z_pval) + return(.maybe_add_omnibus(twas_z_pval, weights_matrix, LD_matrix, + combine_method, combine_if_no_cv)) } # LD matrix path @@ -732,5 +816,53 @@ twas_analysis <- function(weights_matrix, gwas_sumstats_db, LD_matrix = NULL, as.matrix(weights_matrix), 2, function(x) twas_z(x, gwas_sumstats_subset$z, R = LD_matrix_subset) ) - return(twas_z_pval) + return(.maybe_add_omnibus(twas_z_pval, weights_matrix, LD_matrix_subset, + combine_method, combine_if_no_cv)) +} + +#' Add omnibus p-value combination to TWAS results +#' @noRd +.maybe_add_omnibus <- function(twas_z_pval, weights_matrix, LD_matrix, + combine_method, combine_if_no_cv) { + if (!isTRUE(combine_if_no_cv) || length(twas_z_pval) < 2) { + return(twas_z_pval) + } + + pvals <- vapply(twas_z_pval, function(x) as.numeric(x$pval), numeric(1)) + zscores <- vapply(twas_z_pval, function(x) as.numeric(x$z), numeric(1)) + valid <- !is.na(pvals) & is.finite(pvals) & pvals > 0 & pvals < 1 + + if (sum(valid) < 2) return(twas_z_pval) + + combined_pval <- tryCatch({ + switch(combine_method, + acat = pval_acat(pvals[valid]), + hmp = pval_hmp(pvals[valid]), + fisher = , stouffer = , invchisq = { + method_cor <- twas_method_cor( + lapply(which(valid), function(i) weights_matrix[, i]), + LD_matrix) + pval_poolr(pvals[valid], method = combine_method, R = method_cor) + }, + gbj = { + method_cor <- twas_method_cor( + lapply(which(valid), function(i) weights_matrix[, i]), + LD_matrix) + pval_gbj(zscores[valid], R = method_cor, method = combine_method) + }, + aspu = , gates = { + method_cor <- twas_method_cor( + lapply(which(valid), function(i) weights_matrix[, i]), + LD_matrix) + pval_aspu(zscores[valid], pvals[valid], R = method_cor, method = combine_method) + }, + pval_acat(pvals[valid]) # fallback + ) + }, error = function(e) { + warning(sprintf("Omnibus combination (%s) failed: %s", combine_method, e$message)) + NA_real_ + }) + + twas_z_pval[["omnibus"]] <- list(z = NA_real_, pval = combined_pval) + twas_z_pval } diff --git a/R/twas_weights.R b/R/twas_weights.R index fe8853db..9378daeb 100644 --- a/R/twas_weights.R +++ b/R/twas_weights.R @@ -1642,3 +1642,296 @@ ensemble_weights <- function(cv_results, Y, twas_weight_list = NULL, method_performance = method_rsq ) } + +# ============================================================================= +# Summary-statistics TWAS weight training pipeline +# ============================================================================= + +#' Train TWAS weights from summary statistics and LD reference +#' +#' Replaces the OTTERS pipeline with a properly integrated workflow that: +#' (1) runs RSS QC on eQTL summary statistics, (2) trains weights via multiple +#' RSS methods, and (3) extracts fine-mapping results from the shared SuSiE-RSS +#' fit. Returns a \code{TWASWeights} S4 object with \code{standardized = TRUE} +#' that feeds directly into \code{harmonize_twas} and \code{twas_analysis}. +#' +#' @param sumstats Data.frame with columns: \code{variant_id}, \code{A1}, +#' \code{A2}, \code{chrom}, \code{pos}, and either \code{z} or both +#' \code{beta} and \code{se}. +#' @param LD_data LDData S4 object, or a legacy list with \code{LD_matrix}, +#' \code{LD_variants}, \code{ref_panel}. Can also be a plain correlation +#' matrix (variant IDs taken from row/colnames). +#' @param n eQTL study sample size (scalar). +#' @param methods Named list of RSS weight methods and their arguments. +#' Method names correspond to functions named +#' \code{_weights(stat, LD, ...)}. Defaults include lassosum_rss, +#' prs_cs, sdpr, susie_rss, and susie_inf_rss. +#' @param p_thresholds Numeric vector of p-value thresholds for P+T weights. +#' Set to NULL to skip. +#' @param check_ld_method LD matrix repair method: \code{"eigenfix"} (default), +#' \code{"shrink"}, or NULL to skip. +#' @param qc_method RSS QC method for eQTL data: \code{"slalom"}, +#' \code{"dentist"}, or NULL/\code{"none"} to skip. +#' @param keep_indel Whether to keep indels during QC. Default TRUE. +#' @param pip_cutoff_to_skip PIP threshold for early stopping. Default 0 (off). +#' @param impute Whether to run RAISS imputation. Default FALSE. +#' @param impute_opts RAISS imputation parameters. +#' @param var_y Phenotype variance. Default 1. +#' @param verbose Verbosity level. +#' +#' @return A list with: +#' \describe{ +#' \item{twas_weights}{A \code{TWASWeights} S4 object with +#' \code{standardized = TRUE}.} +#' \item{finemapping_result}{A \code{FineMappingResult} S4 object from the +#' SuSiE-RSS fit, or NULL if no SuSiE-RSS method was used.} +#' \item{qc_summary}{List with outlier counts and QC metadata.} +#' } +#' +#' @export +twas_weights_sumstat_pipeline <- function( + sumstats, LD_data, n, + methods = list( + lassosum_rss = list(), + prs_cs = list(phi = 1e-4, n_iter = 1000, n_burnin = 500, thin = 5), + sdpr = list(iter = 1000, burn = 200, thin = 1, verbose = FALSE), + susie_rss = list(), + susie_inf_rss = list() + ), + p_thresholds = c(0.001, 0.05), + check_ld_method = "eigenfix", + qc_method = NULL, + keep_indel = TRUE, + pip_cutoff_to_skip = 0, + impute = FALSE, + impute_opts = list(rcond = 0.01, R2_threshold = 0.6, + minimum_ld = 5, lamb = 0.01), + var_y = 1, verbose = 1) { + + # ----------------------------------------------------------------------- + # 1. RSS QC on eQTL summary statistics + # ----------------------------------------------------------------------- + needs_qc <- !is.null(qc_method) && !identical(qc_method, "none") + if (needs_qc || impute || pip_cutoff_to_skip != 0) { + qc_result <- summary_stats_qc( + rss_input = list(sumstats = sumstats, n = n, var_y = var_y), + LD_data = LD_data, + keep_indel = keep_indel, + pip_cutoff_to_skip = pip_cutoff_to_skip, + qc_method = qc_method, + impute = impute, + impute_opts = impute_opts, + return_on_skip = "null" + ) + if (is.null(qc_result) || isTRUE(qc_result$skipped)) { + return(list(twas_weights = NULL, finemapping_result = NULL, + qc_summary = list(skipped = TRUE))) + } + sumstats <- qc_result$rss_input$sumstats + LD_mat <- qc_result$LD_matrix + outlier_number <- qc_result$outlier_number + } else { + # No QC requested: extract LD matrix directly + if (is.matrix(LD_data)) { + LD_mat <- LD_data + } else if (is(LD_data, "LDData")) { + LD_mat <- getCorrelation(LD_data) + } else if (is.list(LD_data) && !is.null(LD_data$LD_matrix)) { + LD_mat <- LD_data$LD_matrix + } else { + stop("LD_data must be a matrix, LDData object, or list with LD_matrix.") + } + outlier_number <- 0L + } + + if (nrow(sumstats) < 2) { + return(list(twas_weights = NULL, finemapping_result = NULL, + qc_summary = list(skipped = TRUE, reason = "fewer than 2 variants"))) + } + + # ----------------------------------------------------------------------- + # 2. Compute z-scores and build stat object + # ----------------------------------------------------------------------- + if (is.null(sumstats$z)) { + if (!is.null(sumstats$beta) && !is.null(sumstats$se)) { + sumstats$z <- sumstats$beta / sumstats$se + } else { + stop("sumstats must have 'z' or ('beta' and 'se') columns.") + } + } + + p <- nrow(sumstats) + z <- sumstats$z + variant_ids <- sumstats$variant_id + b <- z / sqrt(n) + stat <- list(b = b, cor = b, z = z, n = rep(n, p)) + + # Align LD matrix to sumstats variant order + if (!is.null(rownames(LD_mat)) && !is.null(variant_ids)) { + common <- intersect(variant_ids, rownames(LD_mat)) + if (length(common) < p) { + idx <- match(common, variant_ids) + sumstats <- sumstats[idx, , drop = FALSE] + z <- sumstats$z + variant_ids <- sumstats$variant_id + b <- z / sqrt(n) + stat <- list(b = b, cor = b, z = z, n = rep(n, length(z))) + p <- length(z) + } + LD_mat <- LD_mat[variant_ids, variant_ids, drop = FALSE] + } + + # ----------------------------------------------------------------------- + # 3. LD eigenfix (optional) + # ----------------------------------------------------------------------- + if (!is.null(check_ld_method)) { + ld_check <- check_ld(LD_mat, method = check_ld_method) + if (ld_check$method_applied != "none") { + if (verbose >= 1) { + message(sprintf("check_ld: repaired LD via '%s' (min eigenvalue was %.2e, %d negative).", + ld_check$method_applied, ld_check$min_eigenvalue, ld_check$n_negative)) + } + } + LD_mat <- ld_check$R + } + + # ----------------------------------------------------------------------- + # 4. Two-stage SuSiE-RSS (shared fit for susie_rss + susie_inf_rss) + # ----------------------------------------------------------------------- + has_susie_rss <- "susie_rss" %in% names(methods) + has_susie_inf_rss <- "susie_inf_rss" %in% names(methods) + susie_fits <- NULL + + if (has_susie_rss && has_susie_inf_rss) { + susie_args <- methods[["susie_rss"]] + susie_inf_args <- methods[["susie_inf_rss"]] + susie_fits <- fit_susie_inf_then_susie_rss( + z = z, R = LD_mat, n = n, + susie_inf_args = susie_inf_args, + susie_args = susie_args + ) + } + + # ----------------------------------------------------------------------- + # 5. P+T weights + # ----------------------------------------------------------------------- + results <- list() + if (!is.null(p_thresholds)) { + pvals <- pchisq(z^2, df = 1, lower.tail = FALSE) + for (thr in p_thresholds) { + selected <- pvals < thr + w <- ifelse(selected, stat$b, 0) + results[[paste0("PT_", thr)]] <- w + } + } + + # ----------------------------------------------------------------------- + # 6. RSS method dispatch + # ----------------------------------------------------------------------- + susie_rss_fit_for_fm <- NULL + + for (method_name in names(methods)) { + fn_name <- paste0(method_name, "_weights") + if (!exists(fn_name, mode = "function")) { + warning(sprintf("Method '%s' not found (looking for function '%s'). Skipping.", + method_name, fn_name)) + next + } + + method_args <- methods[[method_name]] + + # Build call arguments: separate pre-fitted objects from method_args + call_args <- list(stat = stat, LD = LD_mat) + if (method_name == "susie_rss" && !is.null(susie_fits)) { + call_args[["susie_rss_fit"]] <- susie_fits$susie + } else if (method_name == "susie_inf_rss" && !is.null(susie_fits)) { + call_args[["susie_inf_rss_fit"]] <- susie_fits$susie_inf + } + + # SuSiE-RSS methods use method_args; others spread args directly + is_susie_rss_method <- method_name %in% c("susie_rss", "susie_inf_rss", "susie_ash_rss") + if (is_susie_rss_method) { + call_args[["method_args"]] <- method_args + } else { + call_args <- c(call_args, method_args) + } + + tryCatch({ + w <- do.call(fn_name, call_args) + # Capture retained fit for fine-mapping post-processing + if (method_name == "susie_rss" && !is.null(attr(w, "fit"))) { + susie_rss_fit_for_fm <- attr(w, "fit") + } else if (method_name == "susie_inf_rss" && is.null(susie_rss_fit_for_fm) && !is.null(attr(w, "fit"))) { + susie_rss_fit_for_fm <- attr(w, "fit") + } + results[[method_name]] <- as.numeric(w) + }, error = function(e) { + warning(sprintf("Method '%s' failed: %s", method_name, e$message)) + results[[method_name]] <<- rep(0, p) + }) + } + + if (length(results) == 0) { + return(list(twas_weights = NULL, finemapping_result = NULL, + qc_summary = list(skipped = TRUE, reason = "all methods failed"))) + } + + # ----------------------------------------------------------------------- + # 7. Fine-mapping from SuSiE-RSS fit (reuses the same fit) + # ----------------------------------------------------------------------- + finemapping_result <- NULL + if (!is.null(susie_rss_fit_for_fm)) { + fm_fits <- list(susie_rss = susie_rss_fit_for_fm) + tryCatch({ + fm_output <- postprocess_finemapping_fits( + fits = fm_fits, + data_x = LD_mat, + coverage = 0.95, + signal_cutoff = 0.025, + cs_input = "Xcorr" + ) + if (!is.null(fm_output$finemapping_results$susie_rss)) { + fm_res <- fm_output$finemapping_results$susie_rss + # Use top_loci from the per-method result (has 'method' column) + # rather than the wide format from postprocess_finemapping_fits + tl <- fm_res$top_loci + if (is.null(tl) || nrow(tl) == 0) tl <- data.frame() + finemapping_result <- FineMappingResult( + variant_names = variant_ids, + trimmed_fit = fm_res$result_trimmed, + top_loci = tl, + method = "susie_rss", + sumstats = list(z = z, n = n) + ) + } + }, error = function(e) { + warning(sprintf("Fine-mapping post-processing failed: %s", e$message)) + }) + } + + # ----------------------------------------------------------------------- + # 8. Package into TWASWeights S4 + # ----------------------------------------------------------------------- + weights_list <- lapply(results, function(w) { + matrix(w, ncol = 1, dimnames = list(variant_ids, NULL)) + }) + + twas_wt <- TWASWeights( + weights = weights_list, + variant_ids = variant_ids, + standardized = TRUE, + cv_performance = NULL + ) + + list( + twas_weights = twas_wt, + finemapping_result = finemapping_result, + qc_summary = list( + skipped = FALSE, + n_variants_input = p, + n_variants_after_qc = nrow(sumstats), + outlier_number = outlier_number, + methods_succeeded = names(results) + ) + ) +} diff --git a/R/univariate_pipeline.R b/R/univariate_pipeline.R index b09761e0..d01720c7 100644 --- a/R/univariate_pipeline.R +++ b/R/univariate_pipeline.R @@ -270,12 +270,21 @@ rss_analysis_pipeline <- function( impute = TRUE, impute_opts = list(rcond = 0.01, R2_threshold = 0.6, minimum_ld = 5, lamb = 0.01), pip_cutoff_to_skip = 0, R_finite = NULL, R_mismatch = NULL, keep_indel = TRUE, comment_string = "#", diagnostics = FALSE) { - # Convert LDData to legacy list for compatibility with downstream functions + # Convert LDData to legacy list for compatibility with downstream functions. + # When genotypes are available, pass the genotype matrix directly instead of + # computing the full p x p correlation matrix — downstream QC and fine-mapping + # functions already handle genotype data and compute R on demand. if (is(LD_data, "LDData")) { use_X <- hasGenotypes(LD_data) X_data <- if (use_X) getGenotypes(LD_data) else NULL is_X_list <- use_X && is.list(X_data) - LD_data <- ld_data_to_list(LD_data) + if (use_X) { + LD_data <- ld_data_to_list(LD_data, skip_correlation = TRUE) + LD_data$LD_matrix <- X_data + LD_data$is_genotype <- TRUE + } else { + LD_data <- ld_data_to_list(LD_data) + } } else { # Detect genotype input: single X matrix or list of X matrices (mixture panel). # susie_rss accepts X=list(X1, X2, ...) for multi-panel mixture. @@ -338,12 +347,22 @@ rss_analysis_pipeline <- function( LD_mat <- qc_record$LD_mat qc_results <- qc_record if (isTRUE(impute)) { - LD_matrix <- partition_LD_matrix(LD_data) - impute_results <- raiss(LD_data$ref_panel, sumstats, LD_matrix, - rcond = impute_opts$rcond, - R2_threshold = impute_opts$R2_threshold, - minimum_ld = impute_opts$minimum_ld, - lamb = impute_opts$lamb) + if (use_X) { + X_scaled <- scale(subset_X_data(LD_data$LD_variants)) + X_scaled[is.na(X_scaled)] <- 0 + impute_results <- raiss(LD_data$ref_panel, sumstats, + genotype_matrix = X_scaled, + R2_threshold = impute_opts$R2_threshold, + minimum_ld = impute_opts$minimum_ld, + lamb = impute_opts$lamb) + } else { + LD_matrix <- partition_LD_matrix(LD_data) + impute_results <- raiss(LD_data$ref_panel, sumstats, LD_matrix, + rcond = impute_opts$rcond, + R2_threshold = impute_opts$R2_threshold, + minimum_ld = impute_opts$minimum_ld, + lamb = impute_opts$lamb) + } sumstats <- impute_results$result_filter LD_mat <- impute_results$LD_mat } diff --git a/man/TWASWeights-class.Rd b/man/TWASWeights-class.Rd index 1041af27..8c3101f6 100644 --- a/man/TWASWeights-class.Rd +++ b/man/TWASWeights-class.Rd @@ -22,5 +22,10 @@ list).} \item{\code{cv_performance}}{Named list of cross-validation performance metrics, or NULL.} + +\item{\code{standardized}}{Logical, whether weights are on standardized +(correlation) scale. If TRUE, \code{harmonize_twas} skips the +\code{sqrt(variance)} scaling step. Individual-level weights use +FALSE (raw genotype scale); RSS weights use TRUE.} }} diff --git a/man/TWASWeights.Rd b/man/TWASWeights.Rd index 770d3dce..b6a7b8b2 100644 --- a/man/TWASWeights.Rd +++ b/man/TWASWeights.Rd @@ -4,7 +4,13 @@ \alias{TWASWeights} \title{Create a TWASWeights Object} \usage{ -TWASWeights(weights, variant_ids, fits = NULL, cv_performance = NULL) +TWASWeights( + weights, + variant_ids, + fits = NULL, + cv_performance = NULL, + standardized = FALSE +) } \arguments{ \item{weights}{Named list of matrices.} @@ -14,6 +20,10 @@ TWASWeights(weights, variant_ids, fits = NULL, cv_performance = NULL) \item{fits}{Named list or NULL.} \item{cv_performance}{Named list or NULL.} + +\item{standardized}{Logical. If TRUE, weights are on the standardized +(correlation) scale and do not need variance scaling in harmonize_twas. +Defaults to FALSE (individual-level / raw genotype scale).} } \value{ A \code{TWASWeights} object. diff --git a/man/classify_variant_type.Rd b/man/classify_variant_type.Rd index f1359ecd..d6e3e498 100644 --- a/man/classify_variant_type.Rd +++ b/man/classify_variant_type.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/misc.R +% Please edit documentation in R/variant_id.R \name{classify_variant_type} \alias{classify_variant_type} \title{Classify variant type from allele strings} diff --git a/man/find_overlapping_regions.Rd b/man/find_overlapping_regions.Rd index 3fe5f2ef..74b0dd3a 100644 --- a/man/find_overlapping_regions.Rd +++ b/man/find_overlapping_regions.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/misc.R +% Please edit documentation in R/variant_id.R \name{find_overlapping_regions} \alias{find_overlapping_regions} \title{Find which target regions overlap a query region} diff --git a/man/fit_susie_inf_then_susie_rss.Rd b/man/fit_susie_inf_then_susie_rss.Rd new file mode 100644 index 00000000..1e93891c --- /dev/null +++ b/man/fit_susie_inf_then_susie_rss.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/susie_wrapper.R +\name{fit_susie_inf_then_susie_rss} +\alias{fit_susie_inf_then_susie_rss} +\title{Two-stage SuSiE-RSS Fine-mapping} +\usage{ +fit_susie_inf_then_susie_rss( + z, + R, + n, + args = list(), + susie_inf_args = list(), + susie_args = list(), + fitted_models = NULL +) +} +\arguments{ +\item{z}{Numeric vector of z-scores.} + +\item{R}{LD correlation matrix.} + +\item{n}{Sample size (scalar).} + +\item{args}{Default arguments forwarded to both fits.} + +\item{susie_inf_args}{SuSiE-inf-specific overrides.} + +\item{susie_args}{Standard SuSiE-RSS-specific overrides.} + +\item{fitted_models}{Optional list with pre-fitted \code{$susie} and/or +\code{$susie_inf} objects to skip re-fitting.} +} +\value{ +A list with \code{susie} and \code{susie_inf} fit objects. +} +\description{ +RSS analog of \code{fit_susie_inf_then_susie}. Fits SuSiE-inf via +\code{susie_rss} first, then initialises standard SuSiE-RSS from +the SuSiE-inf result. The single pair of fits can be used both for +fine-mapping post-processing and TWAS weight extraction. +} diff --git a/man/load_multitask_regional_data.Rd b/man/load_multitask_regional_data.Rd index 585c2070..46134bef 100644 --- a/man/load_multitask_regional_data.Rd +++ b/man/load_multitask_regional_data.Rd @@ -129,10 +129,10 @@ This function loads a mixture data sets for a specific region, including individ or summary statistics (sumstats, LD). Run \code{load_regional_univariate_data} and \code{load_rss_data} multiple times for different datasets } \section{Loading individual level data from multiple corhorts}{ -NA + } \section{Loading summary statistics from multiple corhorts or data set}{ -NA + } diff --git a/man/otters_association.Rd b/man/otters_association.Rd deleted file mode 100644 index afd989ca..00000000 --- a/man/otters_association.Rd +++ /dev/null @@ -1,70 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/otters.R -\name{otters_association} -\alias{otters_association} -\title{TWAS association testing with omnibus combination (OTTERS Stage II)} -\usage{ -otters_association( - weights, - gwas_z, - LD, - combine_method = c("acat", "hmp", "fisher", "stouffer", "invchisq", "gbj", "bj", "hc", - "ghc", "minp", "gbj_omni", "aspu", "gates") -) -} -\arguments{ -\item{weights}{Named list of weight vectors (output from \code{\link{otters_weights}} -or any named list of numeric vectors).} - -\item{gwas_z}{Numeric vector of GWAS z-scores, same length and order as the -weights vectors. Must be aligned to the same variants and allele orientation -as the weights and LD matrix. Use \code{\link{allele_qc}} or -\code{\link{rss_basic_qc}} for harmonization before calling this function.} - -\item{LD}{LD correlation matrix R, aligned to the same variants as weights -and gwas_z.} - -\item{combine_method}{Method to combine p-values across methods. -Correlation-free (valid under arbitrary dependence): -\code{"acat"} (default), \code{"hmp"}. -Correlation-adjusted via poolr (generalized multivariate theory): -\code{"fisher"} (Brown's method), \code{"stouffer"} (Strube's method), -\code{"invchisq"}. -Set-based tests via GBJ (uses TWAS z-scores and inter-method correlation): -\code{"gbj"}, \code{"bj"}, \code{"hc"}, \code{"ghc"}, \code{"minp"}, -\code{"gbj_omni"}. -Adaptive and Simes-type tests via aSPU: -\code{"aspu"} (adaptive sum of powered scores), -\code{"gates"} (extended Simes / GATES). -The poolr, GBJ, and aSPU methods automatically compute the inter-method -TWAS z-score correlation from the weight vectors and LD matrix.} -} -\value{ -A data.frame with columns: -\describe{ - \item{method}{Method name (per-method rows plus a combined row).} - \item{twas_z}{TWAS z-score (\code{NA} for combined row).} - \item{twas_pval}{TWAS p-value.} - \item{n_snps}{Number of non-zero weight SNPs used.} -} -} -\description{ -Computes per-method TWAS z-scores using the FUSION formula and combines -p-values across methods via ACAT (Aggregated Cauchy Association Test) or -HMP (Harmonic Mean P-value). -} -\details{ -The FUSION TWAS statistic (Gusev et al. 2016) is: -\deqn{Z_{TWAS} = \frac{w^T z}{\sqrt{w^T R w}}} -where \eqn{w} are eQTL weights, \eqn{z} are GWAS z-scores, and \eqn{R} -is the LD correlation matrix. -} -\examples{ -set.seed(42) -p <- 20 -gwas_z <- rnorm(p) -R <- diag(p) -weights <- list(method1 = rnorm(p, sd = 0.01), method2 = rnorm(p, sd = 0.01)) -otters_association(weights, gwas_z, R) - -} diff --git a/man/otters_weights.Rd b/man/otters_weights.Rd deleted file mode 100644 index d9d8bd25..00000000 --- a/man/otters_weights.Rd +++ /dev/null @@ -1,80 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/otters.R -\name{otters_weights} -\alias{otters_weights} -\title{Train eQTL weights using multiple RSS methods (OTTERS Stage I)} -\usage{ -otters_weights( - sumstats, - LD, - n, - methods = list(lassosum_rss = list(), prs_cs = list(phi = 1e-04, n_iter = 1000, - n_burnin = 500, thin = 5), sdpr = list(iter = 1000, burn = 200, thin = 1, verbose = - FALSE)), - p_thresholds = c(0.001, 0.05), - check_ld_method = "eigenfix" -) -} -\arguments{ -\item{sumstats}{A data.frame of eQTL summary statistics. Must contain column \code{z} -(z-scores). If \code{z} is absent but \code{beta} and \code{se} are present, -z-scores are computed as \code{beta / se}.} - -\item{LD}{LD correlation matrix R for the gene region (single matrix, not a list). -Should have row/column names matching variant identifiers if variant alignment -is desired.} - -\item{n}{eQTL study sample size (scalar).} - -\item{methods}{Named list of RSS methods and their extra arguments. Each element -name must correspond to a \code{*_weights} function in pecotmr (without the -\code{_weights} suffix). Defaults match the original OTTERS pipeline -(Zhang et al. 2024): -\itemize{ - \item \code{lassosum_rss}: s grid = c(0.2, 0.5, 0.9, 1.0), lambda from - 0.0001 to 0.1 (20 values on log scale) - \item \code{prs_cs}: phi = 1e-4 (fixed, not learned), 1000 iterations, - 500 burn-in, thin = 5 - \item \code{sdpr}: 1000 iterations, 200 burn-in, thin = 1 (no thinning) -} -To add learners (e.g., \code{mr_ash_rss}), simply append to this list.} - -\item{p_thresholds}{Numeric vector of p-value thresholds for P+T. Set to -\code{NULL} to skip P+T. Default: \code{c(0.001, 0.05)}.} - -\item{check_ld_method}{LD quality check method passed to \code{\link{check_ld}}. -Default \code{"eigenfix"} sets negative eigenvalues to zero (required for -PRS-CS Cholesky, matching OTTERS' SVD-based PD forcing). Set to \code{NULL} -to skip checking.} -} -\value{ -A named list of weight vectors (one per method). Each vector has length - equal to \code{nrow(sumstats)}. P+T results are named \code{PT_}. -} -\description{ -Implements the training stage of the OTTERS framework (Omnibus Transcriptome -Test using Expression Reference Summary data, Zhang et al. 2024). Trains -eQTL effect size weights for a gene region using multiple summary-statistics-based -methods in parallel, enabling downstream omnibus TWAS testing. -} -\details{ -Methods are dispatched dynamically via \code{do.call(paste0(method, "_weights"), ...)}, -so any function following the \code{*_weights(stat, LD, ...)} convention can be used -(e.g., \code{lassosum_rss_weights}, \code{prs_cs_weights}, \code{sdpr_weights}, -\code{mr_ash_rss_weights}). - -P+T (pruning and thresholding) is handled internally: for each threshold, SNPs with -eQTL p-value below the threshold are selected, and their marginal z-scores (scaled -to correlation units: \code{z / sqrt(n)}) are used as weights. -} -\examples{ -set.seed(42) -n <- 500; p <- 20 -z <- rnorm(p, sd = 2) -R <- diag(p) -sumstats <- data.frame(z = z) -weights <- otters_weights(sumstats, R, n, - methods = list(lassosum_rss = list()), - p_thresholds = c(0.05)) - -} diff --git a/man/postprocess_finemapping_fits.Rd b/man/postprocess_finemapping_fits.Rd index 3f16a70c..80120792 100644 --- a/man/postprocess_finemapping_fits.Rd +++ b/man/postprocess_finemapping_fits.Rd @@ -17,7 +17,8 @@ postprocess_finemapping_fits( other_quantities = NULL, region = NULL, prior_eff_tol = 1e-09, - min_abs_corr = 0.8 + min_abs_corr = 0.8, + cs_input = NULL ) } \arguments{ diff --git a/man/regions_overlap.Rd b/man/regions_overlap.Rd index 8bebac80..4a1a4697 100644 --- a/man/regions_overlap.Rd +++ b/man/regions_overlap.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/misc.R +% Please edit documentation in R/variant_id.R \name{regions_overlap} \alias{regions_overlap} \title{Test whether two genomic regions overlap} diff --git a/man/susie_ash_rss_weights.Rd b/man/susie_ash_rss_weights.Rd new file mode 100644 index 00000000..6bc1b746 --- /dev/null +++ b/man/susie_ash_rss_weights.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/regularized_regression.R +\name{susie_ash_rss_weights} +\alias{susie_ash_rss_weights} +\title{Compute SuSiE-ASH-RSS TWAS weights} +\usage{ +susie_ash_rss_weights( + stat, + LD, + susie_ash_rss_fit = NULL, + retain_fit = TRUE, + method_args = list() +) +} +\arguments{ +\item{stat}{List with components \code{z} (z-scores), \code{n} (sample sizes).} + +\item{LD}{LD correlation matrix.} + +\item{susie_ash_rss_fit}{Optional pre-fitted SuSiE-ASH-RSS object.} + +\item{retain_fit}{If TRUE, stores the fitted object as an attribute.} + +\item{method_args}{Named list of additional arguments passed to +\code{susieR::susie_rss()}. Use this instead of \code{...} to avoid +partial matching of short argument names (e.g. \code{L}) to the +\code{LD} parameter.} +} +\value{ +Numeric vector of variant weights. +} +\description{ +Extracts coefficients from an existing SuSiE-ASH-RSS fit or fits +\code{susieR::susie_rss()} with \code{unmappable_effects = "ash"}. +} diff --git a/man/susie_inf_rss_weights.Rd b/man/susie_inf_rss_weights.Rd new file mode 100644 index 00000000..7ea12e24 --- /dev/null +++ b/man/susie_inf_rss_weights.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/regularized_regression.R +\name{susie_inf_rss_weights} +\alias{susie_inf_rss_weights} +\title{Compute SuSiE-inf-RSS TWAS weights} +\usage{ +susie_inf_rss_weights( + stat, + LD, + susie_inf_rss_fit = NULL, + retain_fit = TRUE, + method_args = list() +) +} +\arguments{ +\item{stat}{List with components \code{z} (z-scores), \code{n} (sample sizes).} + +\item{LD}{LD correlation matrix.} + +\item{susie_inf_rss_fit}{Optional pre-fitted SuSiE-inf-RSS object.} + +\item{retain_fit}{If TRUE, stores the fitted object as an attribute.} + +\item{method_args}{Named list of additional arguments passed to +\code{susieR::susie_rss()}. Use this instead of \code{...} to avoid +partial matching of short argument names (e.g. \code{L}) to the +\code{LD} parameter.} +} +\value{ +Numeric vector of variant weights. +} +\description{ +Extracts coefficients from an existing SuSiE-inf-RSS fit or fits +\code{susieR::susie_rss()} with \code{unmappable_effects = "inf"}. +} diff --git a/man/susie_rss_weights.Rd b/man/susie_rss_weights.Rd new file mode 100644 index 00000000..487ea46e --- /dev/null +++ b/man/susie_rss_weights.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/regularized_regression.R +\name{susie_rss_weights} +\alias{susie_rss_weights} +\title{Compute SuSiE-RSS TWAS weights} +\usage{ +susie_rss_weights( + stat, + LD, + susie_rss_fit = NULL, + retain_fit = TRUE, + method_args = list() +) +} +\arguments{ +\item{stat}{List with components \code{z} (z-scores), \code{n} (sample sizes).} + +\item{LD}{LD correlation matrix.} + +\item{susie_rss_fit}{Optional pre-fitted SuSiE-RSS object.} + +\item{retain_fit}{If TRUE, stores the fitted object as an attribute.} + +\item{method_args}{Named list of additional arguments passed to +\code{susieR::susie_rss()}. Use this instead of \code{...} to avoid +partial matching of short argument names (e.g. \code{L}) to the +\code{LD} parameter.} +} +\value{ +Numeric vector of variant weights. +} +\description{ +Extracts coefficients from an existing SuSiE-RSS fit or fits +\code{susieR::susie_rss()} from summary statistics and LD. +} diff --git a/man/twas_analysis.Rd b/man/twas_analysis.Rd index 51027c11..df173595 100644 --- a/man/twas_analysis.Rd +++ b/man/twas_analysis.Rd @@ -12,7 +12,9 @@ twas_analysis( V = NULL, D = NULL, n_sketch = NULL, - ld_variant_ids = NULL + ld_variant_ids = NULL, + combine_method = "acat", + combine_if_no_cv = FALSE ) } \arguments{ @@ -23,12 +25,36 @@ twas_analysis( \item{LD_matrix}{A matrix representing linkage disequilibrium between variants.} \item{extract_variants_objs}{A vector of variant identifiers to extract from the GWAS and LD matrix.} + +\item{V}{SVD right-singular vectors from LD sketch (optional).} + +\item{D}{SVD singular values from LD sketch (optional).} + +\item{n_sketch}{Sample size of LD sketch (optional).} + +\item{ld_variant_ids}{Variant IDs in the LD sketch (optional).} + +\item{combine_method}{P-value combination method: \code{"acat"} (default), +\code{"hmp"}, \code{"fisher"}, \code{"stouffer"}, \code{"invchisq"}, +\code{"gbj"}, \code{"aspu"}, or \code{"gates"}.} + +\item{combine_if_no_cv}{Logical. If TRUE and no CV performance is available, +combine per-method p-values into an omnibus result.} } \value{ -A list with TWAS z-scores and p-values across four methods for each gene. +A list with TWAS z-scores and p-values across methods for each gene. + When omnibus combination is enabled, includes an additional \code{"omnibus"} + entry. } \description{ Performs TWAS analysis using the provided weights matrix, GWAS summary statistics database, and LD matrix. It extracts the necessary GWAS summary statistics and LD matrix based on the specified variants and computes the z-score and p-value for each gene. } +\details{ +When \code{combine_if_no_cv = TRUE} and there are at least two methods with +valid p-values, an omnibus p-value is computed via the method specified in +\code{combine_method} and appended as an \code{"omnibus"} entry. This is +intended for summary-statistics TWAS where cross-validation performance is +not available for model selection. +} diff --git a/man/twas_joint_z.Rd b/man/twas_joint_z.Rd index 29fd5bfd..9c4357d8 100644 --- a/man/twas_joint_z.Rd +++ b/man/twas_joint_z.Rd @@ -4,7 +4,15 @@ \alias{twas_joint_z} \title{Multi-condition TWAS joint test} \usage{ -twas_joint_z(weights, z, R = NULL, X = NULL) +twas_joint_z( + weights, + z, + R = NULL, + X = NULL, + V = NULL, + D_svd = NULL, + n_sketch = NULL +) } \arguments{ \item{weights}{A matrix of weights, where each column corresponds to a different condition.} @@ -14,6 +22,13 @@ twas_joint_z(weights, z, R = NULL, X = NULL) \item{R}{An optional correlation matrix. If not provided, it will be calculated from the genotype matrix X.} \item{X}{An optional genotype matrix. If R is not provided, X must be supplied to calculate the correlation matrix.} + +\item{V}{Optional SVD right-singular vectors (variants x components) from an LD sketch. +When provided with \code{D_svd} and \code{n_sketch}, avoids forming the full LD matrix.} + +\item{D_svd}{Optional SVD singular values (vector) from an LD sketch.} + +\item{n_sketch}{Optional sample size of the LD sketch.} } \value{ A list containing the following elements: diff --git a/man/twas_weights_sumstat_pipeline.Rd b/man/twas_weights_sumstat_pipeline.Rd new file mode 100644 index 00000000..16b8dd2d --- /dev/null +++ b/man/twas_weights_sumstat_pipeline.Rd @@ -0,0 +1,78 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/twas_weights.R +\name{twas_weights_sumstat_pipeline} +\alias{twas_weights_sumstat_pipeline} +\title{Train TWAS weights from summary statistics and LD reference} +\usage{ +twas_weights_sumstat_pipeline( + sumstats, + LD_data, + n, + methods = list(lassosum_rss = list(), prs_cs = list(phi = 1e-04, n_iter = 1000, + n_burnin = 500, thin = 5), sdpr = list(iter = 1000, burn = 200, thin = 1, verbose = + FALSE), susie_rss = list(), susie_inf_rss = list()), + p_thresholds = c(0.001, 0.05), + check_ld_method = "eigenfix", + qc_method = NULL, + keep_indel = TRUE, + pip_cutoff_to_skip = 0, + impute = FALSE, + impute_opts = list(rcond = 0.01, R2_threshold = 0.6, minimum_ld = 5, lamb = 0.01), + var_y = 1, + verbose = 1 +) +} +\arguments{ +\item{sumstats}{Data.frame with columns: \code{variant_id}, \code{A1}, +\code{A2}, \code{chrom}, \code{pos}, and either \code{z} or both +\code{beta} and \code{se}.} + +\item{LD_data}{LDData S4 object, or a legacy list with \code{LD_matrix}, +\code{LD_variants}, \code{ref_panel}. Can also be a plain correlation +matrix (variant IDs taken from row/colnames).} + +\item{n}{eQTL study sample size (scalar).} + +\item{methods}{Named list of RSS weight methods and their arguments. +Method names correspond to functions named +\code{_weights(stat, LD, ...)}. Defaults include lassosum_rss, +prs_cs, sdpr, susie_rss, and susie_inf_rss.} + +\item{p_thresholds}{Numeric vector of p-value thresholds for P+T weights. +Set to NULL to skip.} + +\item{check_ld_method}{LD matrix repair method: \code{"eigenfix"} (default), +\code{"shrink"}, or NULL to skip.} + +\item{qc_method}{RSS QC method for eQTL data: \code{"slalom"}, +\code{"dentist"}, or NULL/\code{"none"} to skip.} + +\item{keep_indel}{Whether to keep indels during QC. Default TRUE.} + +\item{pip_cutoff_to_skip}{PIP threshold for early stopping. Default 0 (off).} + +\item{impute}{Whether to run RAISS imputation. Default FALSE.} + +\item{impute_opts}{RAISS imputation parameters.} + +\item{var_y}{Phenotype variance. Default 1.} + +\item{verbose}{Verbosity level.} +} +\value{ +A list with: +\describe{ + \item{twas_weights}{A \code{TWASWeights} S4 object with + \code{standardized = TRUE}.} + \item{finemapping_result}{A \code{FineMappingResult} S4 object from the + SuSiE-RSS fit, or NULL if no SuSiE-RSS method was used.} + \item{qc_summary}{List with outlier counts and QC metadata.} +} +} +\description{ +Replaces the OTTERS pipeline with a properly integrated workflow that: +(1) runs RSS QC on eQTL summary statistics, (2) trains weights via multiple +RSS methods, and (3) extracts fine-mapping results from the shared SuSiE-RSS +fit. Returns a \code{TWASWeights} S4 object with \code{standardized = TRUE} +that feeds directly into \code{harmonize_twas} and \code{twas_analysis}. +} diff --git a/tests/testthat/test_encoloc.R b/tests/testthat/test_encoloc.R index 05740b48..c348c9b5 100644 --- a/tests/testthat/test_encoloc.R +++ b/tests/testthat/test_encoloc.R @@ -856,7 +856,7 @@ test_that("extract_ld_for_variants loads LD, aligns names, and subsets", { colnames(ld_mat) <- rownames(ld_mat) <- ld_variants local_mocked_bindings( - load_LD_matrix = function(meta_file, region) { + load_LD_matrix = function(meta_file, region, ...) { list(LD_matrix = ld_mat, LD_variants = ld_variants) } ) diff --git a/tests/testthat/test_otters.R b/tests/testthat/test_otters.R deleted file mode 100644 index b8e8de3e..00000000 --- a/tests/testthat/test_otters.R +++ /dev/null @@ -1,210 +0,0 @@ -context("otters") - -# ---- otters_weights ---- -test_that("otters_weights returns named list of weight vectors", { - set.seed(42) - p <- 20 - n <- 500 - z <- rnorm(p, sd = 2) - R <- diag(p) - sumstats <- data.frame(z = z) - result <- otters_weights(sumstats, R, n, - methods = list(lassosum_rss = list()), - p_thresholds = c(0.05) - ) - expect_type(result, "list") - expect_true("PT_0.05" %in% names(result)) - expect_true("lassosum_rss" %in% names(result)) - expect_equal(length(result$PT_0.05), p) - expect_equal(length(result$lassosum_rss), p) -}) - -test_that("otters_weights computes z from beta/se if z missing", { - set.seed(42) - p <- 10 - n <- 100 - sumstats <- data.frame(beta = rnorm(p, sd = 0.1), se = rep(0.05, p)) - R <- diag(p) - result <- otters_weights(sumstats, R, n, - methods = list(lassosum_rss = list()), - p_thresholds = c(0.05) - ) - expect_true("lassosum_rss" %in% names(result)) - expect_equal(length(result$lassosum_rss), p) -}) - -test_that("otters_weights errors when no z or beta/se", { - sumstats <- data.frame(x = 1:5) - expect_error(otters_weights(sumstats, diag(5), 100), "z.*beta.*se") -}) - -test_that("otters_weights P+T selects correct SNPs", { - set.seed(42) - p <- 20 - n <- 500 - # Large z-scores for first 3 SNPs (should pass threshold) - z <- c(rep(5, 3), rep(0.1, 17)) - sumstats <- data.frame(z = z) - R <- diag(p) - result <- otters_weights(sumstats, R, n, - methods = list(), p_thresholds = c(0.001) - ) - w <- result$PT_0.001 - # First 3 should be non-zero, rest should be zero - expect_true(all(w[1:3] != 0)) - expect_true(all(w[4:20] == 0)) -}) - -test_that("otters_weights warns on unknown method", { - sumstats <- data.frame(z = rnorm(5)) - expect_warning( - otters_weights(sumstats, diag(5), 100, - methods = list(nonexistent_method = list()), - p_thresholds = NULL), - "not found" - ) -}) - -test_that("otters_weights with multiple methods returns all", { - set.seed(42) - p <- 15 - n <- 500 - z <- rnorm(p, sd = 2) - R <- diag(p) - for (i in 1:(p - 1)) { R[i, i + 1] <- 0.3; R[i + 1, i] <- 0.3 } - sumstats <- data.frame(z = z) - result <- otters_weights(sumstats, R, n, - methods = list( - lassosum_rss = list(), - prs_cs = list(n_iter = 50, n_burnin = 10, thin = 2, seed = 42) - ), - p_thresholds = c(0.001, 0.05) - ) - expect_true(all(c("PT_0.001", "PT_0.05", "lassosum_rss", "prs_cs") %in% names(result))) - for (nm in names(result)) { - expect_equal(length(result[[nm]]), p) - expect_true(all(is.finite(result[[nm]]))) - } -}) - -test_that("otters_weights passes correlation-scale stat fields to lassosum", { - p <- 5 - n <- 100 - z <- rnorm(p) - R <- diag(p) - sumstats <- data.frame(z = z) - captured <- new.env(parent = emptyenv()) - local_mocked_bindings( - lassosum_rss_weights = function(stat, LD, ...) { - captured$lassosum_stat <- stat - captured$lassosum_dots <- list(...) - rep(0.1, nrow(LD)) - }, - prs_cs_weights = function(stat, LD, ...) { - captured$prs_cs_dots <- list(...) - rep(0.2, nrow(LD)) - } - ) - - result <- otters_weights( - sumstats, R, n, - methods = list( - lassosum_rss = list(), - prs_cs = list(phi = 1e-4) - ), - p_thresholds = NULL, - check_ld_method = NULL - ) - - expect_equal(result$lassosum_rss, rep(0.1, p)) - expect_equal(result$prs_cs, rep(0.2, p)) - expect_equal(captured$lassosum_stat$cor, z / sqrt(n)) - expect_equal(captured$lassosum_stat$z, z) - expect_equal(captured$lassosum_stat$b, z / sqrt(n)) -}) - -# ---- otters_association ---- -test_that("otters_association returns correct structure", { - set.seed(42) - p <- 20 - gwas_z <- rnorm(p) - R <- diag(p) - weights <- list( - method1 = rnorm(p, sd = 0.01), - method2 = rnorm(p, sd = 0.01) - ) - result <- otters_association(weights, gwas_z, R) - expect_true(is.data.frame(result)) - expect_true(all(c("method", "twas_z", "twas_pval", "n_snps") %in% colnames(result))) - # Two methods + ACAT combined - expect_equal(nrow(result), 3) - expect_true("ACAT_combined" %in% result$method) -}) - -test_that("otters_association handles all-zero weights gracefully", { - p <- 10 - gwas_z <- rnorm(p) - R <- diag(p) - weights <- list(zero_method = rep(0, p), nonzero = rnorm(p, sd = 0.01)) - result <- otters_association(weights, gwas_z, R) - zero_row <- result[result$method == "zero_method", ] - expect_true(is.na(zero_row$twas_z)) - expect_equal(zero_row$n_snps, 0) -}) - -test_that("otters_association with single method has no combined row", { - p <- 10 - gwas_z <- rnorm(p) - R <- diag(p) - weights <- list(only_method = rnorm(p, sd = 0.01)) - result <- otters_association(weights, gwas_z, R) - # Only one valid p-value, so no ACAT combination - expect_false("ACAT_combined" %in% result$method) -}) - -test_that("otters_association uses HMP when specified", { - skip_if_not_installed("harmonicmeanp") - set.seed(42) - p <- 20 - gwas_z <- rnorm(p) - R <- diag(p) - weights <- list(m1 = rnorm(p, sd = 0.01), m2 = rnorm(p, sd = 0.01)) - result <- otters_association(weights, gwas_z, R, combine_method = "hmp") - expect_true("HMP_combined" %in% result$method) -}) - -# ---- end-to-end integration ---- -test_that("otters_weights + otters_association end-to-end on simulated data", { - set.seed(2024) - n_eqtl <- 500 - n_gwas <- 10000 - p <- 20 - - # Simulate genotypes and eQTL - X <- matrix(rbinom(n_eqtl * p, 2, 0.3), nrow = n_eqtl) - beta_eqtl <- rep(0, p) - beta_eqtl[c(3, 10)] <- c(0.3, -0.2) - expr <- X %*% beta_eqtl + rnorm(n_eqtl) - eqtl_z <- as.vector(cor(expr, X)) * sqrt(n_eqtl) - R <- cor(X) - - # Simulate GWAS (gene affects trait) - beta_gwas_gene <- 0.1 - gwas_z <- R %*% (beta_eqtl * beta_gwas_gene * sqrt(n_gwas)) + rnorm(p) - - # Stage I: train weights - sumstats <- data.frame(z = eqtl_z) - weights <- otters_weights(sumstats, R, n_eqtl, - methods = list(lassosum_rss = list()), - p_thresholds = c(0.05) - ) - expect_true(length(weights) >= 2) - - # Stage II: test association - result <- otters_association(weights, as.numeric(gwas_z), R) - expect_true(is.data.frame(result)) - expect_true(nrow(result) >= 2) - # At least one method should have a small-ish p-value (gene is truly associated) - min_pval <- min(result$twas_pval, na.rm = TRUE) - expect_true(min_pval < 0.5) -}) diff --git a/tests/testthat/test_slalom.R b/tests/testthat/test_slalom.R index 011ad538..23be312d 100644 --- a/tests/testthat/test_slalom.R +++ b/tests/testthat/test_slalom.R @@ -53,7 +53,7 @@ test_that("slalom basic output structure", { test_that("slalom errors on non-square R", { z <- rnorm(10) R <- matrix(rnorm(50), nrow = 5, ncol = 10) - expect_error(slalom(zScore = z, R = R), "LD_mat must be a square matrix") + expect_error(slalom(zScore = z, R = R), "R must be a square matrix") }) test_that("slalom accepts X matrix instead of R", { @@ -411,7 +411,7 @@ test_that("edge case: mismatched dimensions error", { z <- rnorm(10) R <- diag(5) expect_error(slalom(zScore = z, R = R), - "LD_mat must be a square matrix matching the length of zScore") + "R must be a square matrix matching the length of zScore") }) test_that("edge case: no R and no X provided errors", { diff --git a/tests/testthat/test_susie_wrapper.R b/tests/testthat/test_susie_wrapper.R index c0df1b9e..998d227c 100644 --- a/tests/testthat/test_susie_wrapper.R +++ b/tests/testthat/test_susie_wrapper.R @@ -612,8 +612,8 @@ test_that("susie_rss_pipeline X-mode passes X to susie_rss and computes LD from expect_true("X" %in% names(captured_susie_args)) expect_null(captured_susie_args$R) expect_equal(captured_susie_args$R_mismatch, "eb") - # Post-processor should have received a p x p matrix (LD computed from X) - expect_equal(dim(captured_pp_data_x), c(p, p)) + # Post-processor receives raw X matrix (n x p), not cor(X) + expect_equal(dim(captured_pp_data_x), c(n, p)) }) # ============================================================================= @@ -652,10 +652,8 @@ test_that("susie_rss_pipeline computes LD from first panel when X_mat is a list" format_finemapping_output = function(post, primary_method) list(variant_names = vnames) ) result <- susie_rss_pipeline(list(z = z), X_mat = X_list) - # data_x should be a p x p correlation matrix computed from X1 (first panel) - expect_equal(dim(captured_pp_data_x), c(p, p)) - # It should be a symmetric matrix (correlation/LD) - expect_equal(captured_pp_data_x, t(captured_pp_data_x), tolerance = 1e-10) + # data_x should be the first panel's X matrix (n1 x p), not cor(X) + expect_equal(dim(captured_pp_data_x), c(n1, p)) }) # ============================================================================= diff --git a/tests/testthat/test_twas_weights_rss.R b/tests/testthat/test_twas_weights_rss.R new file mode 100644 index 00000000..9c65f8d5 --- /dev/null +++ b/tests/testthat/test_twas_weights_rss.R @@ -0,0 +1,195 @@ +context("SS-TWAS: weights, pipeline, and omnibus combination") + +# ============================================================================= +# TWASWeights S4 class with standardized slot +# ============================================================================= + +test_that("TWASWeights accepts standardized = TRUE", { + wt <- TWASWeights( + weights = list(method1 = matrix(1:5, ncol = 1)), + variant_ids = paste0("v", 1:5), + standardized = TRUE + ) + expect_true(wt@standardized) + expect_equal(length(wt@methods), 1L) +}) + +test_that("TWASWeights defaults to standardized = FALSE", { + wt <- TWASWeights( + weights = list(method1 = matrix(1:5, ncol = 1)), + variant_ids = paste0("v", 1:5) + ) + expect_false(wt@standardized) +}) + +test_that("TWASWeights show method includes standardized", { + wt <- TWASWeights( + weights = list(m1 = matrix(0, nrow = 3, ncol = 1)), + variant_ids = paste0("v", 1:3), + standardized = TRUE + ) + out <- capture.output(show(wt)) + expect_true(any(grepl("Standardized: TRUE", out))) +}) + +# ============================================================================= +# SuSiE-RSS weight extraction +# ============================================================================= + +test_that(".susie_rss_extract_weights returns correct-length vector", { + skip_if_not_installed("susieR") + set.seed(42) + p <- 20 + n <- 500 + R <- diag(p) + z <- rnorm(p) + w <- pecotmr:::.susie_rss_extract_weights( + fit = NULL, z = z, R = R, n = n, + required_fields = c("alpha", "mu", "X_column_scale_factors"), + fit_args = list(L = 5) + ) + expect_equal(length(w), p) + expect_true(all(is.finite(w))) +}) + +test_that("susie_rss_weights follows (stat, LD) convention", { + skip_if_not_installed("susieR") + set.seed(42) + p <- 20 + n <- 500 + R <- diag(p) + z <- rnorm(p) + stat <- list(b = z / sqrt(n), cor = z / sqrt(n), z = z, n = rep(n, p)) + w <- susie_rss_weights(stat, R, method_args = list(L = 5)) + expect_equal(length(w), p) + expect_true(all(is.finite(w))) +}) + +test_that("susie_rss_weights retains fit when retain_fit = TRUE", { + skip_if_not_installed("susieR") + set.seed(42) + p <- 20 + n <- 500 + R <- diag(p) + z <- rnorm(p) + stat <- list(b = z / sqrt(n), cor = z / sqrt(n), z = z, n = rep(n, p)) + w <- susie_rss_weights(stat, R, retain_fit = TRUE, method_args = list(L = 5)) + expect_false(is.null(attr(w, "fit"))) +}) + +test_that("susie_inf_rss_weights works", { + skip_if_not_installed("susieR") + set.seed(42) + p <- 20 + n <- 500 + R <- diag(p) + z <- rnorm(p) + stat <- list(b = z / sqrt(n), cor = z / sqrt(n), z = z, n = rep(n, p)) + w <- susie_inf_rss_weights(stat, R, method_args = list(L = 5)) + expect_equal(length(w), p) + expect_true(all(is.finite(w))) +}) + +# ============================================================================= +# Two-stage SuSiE-RSS fitting +# ============================================================================= + +test_that("fit_susie_inf_then_susie_rss returns two fits", { + skip_if_not_installed("susieR") + set.seed(42) + p <- 20 + n <- 500 + R <- diag(p) + z <- rnorm(p) + fits <- fit_susie_inf_then_susie_rss(z, R, n, args = list(L = 5)) + expect_true(is.list(fits)) + expect_true("susie" %in% names(fits)) + expect_true("susie_inf" %in% names(fits)) + expect_true("susie_inf" %in% class(fits$susie_inf)) + expect_true("susie_rss" %in% class(fits$susie)) +}) + +# ============================================================================= +# twas_analysis omnibus combination +# ============================================================================= + +test_that("twas_analysis adds omnibus when combine_if_no_cv = TRUE", { + set.seed(42) + p <- 10 + R <- diag(p) + rownames(R) <- colnames(R) <- paste0("v", 1:p) + weights_matrix <- matrix(rnorm(p * 3), ncol = 3) + rownames(weights_matrix) <- paste0("v", 1:p) + colnames(weights_matrix) <- c("m1", "m2", "m3") + gwas_db <- data.frame( + variant_id = paste0("v", 1:p), + z = rnorm(p) + ) + + result <- twas_analysis( + weights_matrix, gwas_db, LD_matrix = R, + extract_variants_objs = paste0("v", 1:p), + combine_if_no_cv = TRUE + ) + expect_true("omnibus" %in% names(result)) + expect_true(!is.null(result$omnibus$pval)) +}) + +test_that("twas_analysis skips omnibus when combine_if_no_cv = FALSE", { + set.seed(42) + p <- 10 + R <- diag(p) + rownames(R) <- colnames(R) <- paste0("v", 1:p) + weights_matrix <- matrix(rnorm(p * 3), ncol = 3) + rownames(weights_matrix) <- paste0("v", 1:p) + colnames(weights_matrix) <- c("m1", "m2", "m3") + gwas_db <- data.frame( + variant_id = paste0("v", 1:p), + z = rnorm(p) + ) + + result <- twas_analysis( + weights_matrix, gwas_db, LD_matrix = R, + extract_variants_objs = paste0("v", 1:p), + combine_if_no_cv = FALSE + ) + expect_false("omnibus" %in% names(result)) +}) + +# ============================================================================= +# twas_weights_sumstat_pipeline end-to-end +# ============================================================================= + +test_that("twas_weights_sumstat_pipeline produces TWASWeights with standardized = TRUE", { + skip_if_not_installed("susieR") + set.seed(42) + p <- 30 + n <- 1000 + R <- diag(p) + rownames(R) <- colnames(R) <- paste0("1:", 1000 + seq_len(p), ":A:T") + z <- rnorm(p, sd = 2) + sumstats <- data.frame( + variant_id = rownames(R), + chrom = "1", + pos = 1000 + seq_len(p), + A1 = "T", + A2 = "A", + z = z + ) + + result <- twas_weights_sumstat_pipeline( + sumstats = sumstats, + LD_data = R, + n = n, + methods = list(susie_rss = list(L = 5)), + p_thresholds = c(0.05), + check_ld_method = NULL, + verbose = 0 + ) + + expect_false(is.null(result$twas_weights)) + expect_true(is(result$twas_weights, "TWASWeights")) + expect_true(result$twas_weights@standardized) + expect_true(length(result$twas_weights@variant_ids) > 0) + expect_false(result$qc_summary$skipped) +}) diff --git a/tests/testthat/test_univariate_pipeline.R b/tests/testthat/test_univariate_pipeline.R index f974cf94..15f93618 100644 --- a/tests/testthat/test_univariate_pipeline.R +++ b/tests/testthat/test_univariate_pipeline.R @@ -2076,7 +2076,8 @@ test_that("rss: mixture LD_data (list of X panels) preserves list shape into sus column_file_path = "/fake/columns.yml", LD_data = list(LD_matrix = list(X1, X2), ref_panel = ss), qc_method = "slalom", - finemapping_method = "susie_rss" + finemapping_method = "susie_rss", + impute = FALSE ) # Mixture path => list of subset matrices passed to susie_rss_pipeline as X_mat From 3355e8b812219215e96b393f62fc76a942d924ce Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Mon, 25 May 2026 19:29:00 -0500 Subject: [PATCH 05/11] more s4 refactor --- NAMESPACE | 16 ++ R/AllGenerics.R | 61 ++++++ R/AllMethods.R | 92 ++++---- R/LD.R | 35 ++- R/colocboost_pipeline.R | 49 +++-- R/encoloc.R | 252 +++++++++++++++++----- R/file_utils.R | 72 ++++++- R/ld_loader.R | 18 +- R/sumstats_qc.R | 181 ++++------------ R/susie_wrapper.R | 2 +- R/twas.R | 40 +++- R/twas_weights.R | 103 ++++++--- R/univariate_pipeline.R | 43 ++-- man/coloc_wrapper.Rd | 96 +++++++-- man/getCVPerformance.Rd | 22 ++ man/getFits.Rd | 22 ++ man/getMethodNames.Rd | 20 ++ man/getRefPanel.Rd | 21 ++ man/getStandardized.Rd | 20 ++ man/getTopLoci.Rd | 20 ++ man/getTrimmedFit.Rd | 20 ++ man/getVariantIds.Rd | 3 + man/getVariantNames.Rd | 20 ++ man/twas_weights_sumstat_pipeline.Rd | 2 +- tests/testthat/test_LD.R | 37 ++-- tests/testthat/test_colocboost_pipeline.R | 41 ++-- tests/testthat/test_data_structures.R | 16 +- tests/testthat/test_encoloc.R | 203 +++++++++++++++++ tests/testthat/test_ensemble_weights.R | 22 +- tests/testthat/test_sumstats_qc.R | 153 +++++++------ tests/testthat/test_twas_sketch.R | 40 ++-- tests/testthat/test_twas_weights_rss.R | 24 ++- 32 files changed, 1239 insertions(+), 527 deletions(-) create mode 100644 man/getCVPerformance.Rd create mode 100644 man/getFits.Rd create mode 100644 man/getMethodNames.Rd create mode 100644 man/getRefPanel.Rd create mode 100644 man/getStandardized.Rd create mode 100644 man/getTopLoci.Rd create mode 100644 man/getTrimmedFit.Rd create mode 100644 man/getVariantNames.Rd diff --git a/NAMESPACE b/NAMESPACE index da26439a..879f237a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -74,24 +74,32 @@ export(fsusie_wrapper) export(getBaseline) export(getBlockMetadata) export(getCS) +export(getCVPerformance) export(getCandidates) export(getCorrelation) export(getEffects) export(getEnrichment) +export(getFits) export(getGenotypes) export(getLBF) export(getLocal) export(getMaf) +export(getMethodNames) export(getN) export(getPIP) +export(getRefPanel) export(getResidualX) export(getResidualXScalar) export(getResidualY) export(getResidualYScalar) export(getScoreStats) +export(getStandardized) +export(getTopLoci) +export(getTrimmedFit) export(getVarY) export(getVariantIds) export(getVariantInfo) +export(getVariantNames) export(getWeights) export(getZ) export(get_ctwas_meta_data) @@ -230,23 +238,31 @@ exportMethods(computeLdScores) exportMethods(estimateH2) exportMethods(getBlockMetadata) exportMethods(getCS) +exportMethods(getCVPerformance) exportMethods(getCorrelation) exportMethods(getEffects) exportMethods(getEnrichment) +exportMethods(getFits) exportMethods(getGenotypes) exportMethods(getLBF) exportMethods(getLocal) exportMethods(getMaf) +exportMethods(getMethodNames) exportMethods(getN) exportMethods(getPIP) +exportMethods(getRefPanel) exportMethods(getResidualX) exportMethods(getResidualXScalar) exportMethods(getResidualY) exportMethods(getResidualYScalar) exportMethods(getScoreStats) +exportMethods(getStandardized) +exportMethods(getTopLoci) +exportMethods(getTrimmedFit) exportMethods(getVarY) exportMethods(getVariantIds) exportMethods(getVariantInfo) +exportMethods(getVariantNames) exportMethods(getWeights) exportMethods(getZ) exportMethods(hasGenotypes) diff --git a/R/AllGenerics.R b/R/AllGenerics.R index dee764a3..23cfdd36 100644 --- a/R/AllGenerics.R +++ b/R/AllGenerics.R @@ -218,6 +218,14 @@ setGeneric("getVariantIds", function(x) standardGeneric("getVariantIds")) #' @export setGeneric("getVariantInfo", function(x) standardGeneric("getVariantInfo")) +#' @title Get Reference Panel +#' @description Extract reference panel metadata as a data.frame from +#' an \code{LDData} object, including chrom and pos columns. +#' @param x An \code{LDData} object. +#' @return A data.frame with variant metadata including chrom, pos, A1, A2. +#' @export +setGeneric("getRefPanel", function(x) standardGeneric("getRefPanel")) + #' @title Get Block Metadata #' @description Extract block metadata from an \code{LDData} object. #' @param x An \code{LDData} object. @@ -279,6 +287,27 @@ setGeneric("getResidualYScalar", #' @export setGeneric("getPIP", function(x) standardGeneric("getPIP")) +#' @title Get Trimmed Fit +#' @description Extract the trimmed SuSiE fit from a FineMappingResult. +#' @param x A \code{FineMappingResult} object. +#' @return A list (trimmed SuSiE fit). +#' @export +setGeneric("getTrimmedFit", function(x) standardGeneric("getTrimmedFit")) + +#' @title Get Variant Names +#' @description Extract variant names from a FineMappingResult. +#' @param x A \code{FineMappingResult} object. +#' @return Character vector of variant names. +#' @export +setGeneric("getVariantNames", function(x) standardGeneric("getVariantNames")) + +#' @title Get Top Loci +#' @description Extract top loci data.frame from a FineMappingResult. +#' @param x A \code{FineMappingResult} object. +#' @return A data.frame of top loci. +#' @export +setGeneric("getTopLoci", function(x) standardGeneric("getTopLoci")) + #' @title Get Credible Sets #' @description Extract credible set assignments. #' @param x A \code{FineMappingResult} object. @@ -318,6 +347,38 @@ setGeneric("getEffects", function(x) standardGeneric("getEffects")) setGeneric("getWeights", function(x, method = NULL) standardGeneric("getWeights")) +#' @title Get Standardized Flag +#' @description Check whether weights are on the standardized (correlation) scale. +#' @param x A \code{TWASWeights} object. +#' @return Logical. +#' @export +setGeneric("getStandardized", function(x) standardGeneric("getStandardized")) + +#' @title Get CV Performance +#' @description Extract cross-validation performance metrics. +#' @param x A \code{TWASWeights} object. +#' @param method Character, specific method name. If NULL, returns all. +#' @return A list or single element. +#' @export +setGeneric("getCVPerformance", + function(x, method = NULL) standardGeneric("getCVPerformance")) + +#' @title Get Model Fits +#' @description Extract fitted model objects from a TWASWeights object. +#' @param x A \code{TWASWeights} object. +#' @param method Character, specific method name. If NULL, returns all. +#' @return A list or single element. +#' @export +setGeneric("getFits", + function(x, method = NULL) standardGeneric("getFits")) + +#' @title Get Method Names +#' @description Extract method names from a TWASWeights object. +#' @param x A \code{TWASWeights} object. +#' @return Character vector. +#' @export +setGeneric("getMethodNames", function(x) standardGeneric("getMethodNames")) + # ============================================================================= # VCF/BCF writer generic # ============================================================================= diff --git a/R/AllMethods.R b/R/AllMethods.R index bf827869..1929ed3c 100644 --- a/R/AllMethods.R +++ b/R/AllMethods.R @@ -43,15 +43,13 @@ setMethod("getCorrelation", "LDData", function(x) { if (is.null(x@genotype_handle)) { stop("No correlation matrix or genotype handle available") } - # Recompute from genotype handle if (is.list(x@genotype_handle)) { - # Multi-panel: compute from first handle - geno <- extractBlockGenotypes(x@genotype_handle[[1]], x@snp_idx) - X <- t(assay(geno, "dosage")) - } else { - geno <- extractBlockGenotypes(x@genotype_handle, x@snp_idx) - X <- t(assay(geno, "dosage")) + stop("Cannot compute single correlation matrix from mixture panels. ", + "Use getGenotypes() and compute LD per-panel, or pass X directly ", + "to susie_rss().") } + geno <- extractBlockGenotypes(x@genotype_handle, x@snp_idx) + X <- t(assay(geno, "dosage")) compute_LD(X, method = "sample") }) @@ -94,42 +92,14 @@ setMethod("getBlockMetadata", "LDData", function(x) { x@block_metadata }) -#' @title Convert LDData to Legacy List -#' @description Convert an \code{LDData} object to the legacy list format -#' for backwards compatibility with internal functions that still expect -#' the old format. -#' @param x An \code{LDData} object. -#' @return A list with LD_variants, LD_matrix, ref_panel, block_metadata, -#' is_genotype. -#' @keywords internal -#' @noRd -ld_data_to_list <- function(x, skip_correlation = FALSE) { +#' @rdname getRefPanel +#' @export +setMethod("getRefPanel", "LDData", function(x) { mc <- as.data.frame(mcols(x@variants)) mc$chrom <- as.character(seqnames(x@variants)) mc$pos <- start(x@variants) - ref_panel <- mc - - bm <- x@block_metadata - if (is(bm, "LDBlocks")) { - bm <- as.data.frame(bm@blocks) - } - - LD_matrix <- if (skip_correlation) { - NULL - } else if (!is.null(x@correlation)) { - x@correlation - } else { - getCorrelation(x) - } - - list( - LD_variants = getVariantIds(x), - LD_matrix = LD_matrix, - ref_panel = ref_panel, - block_metadata = bm, - is_genotype = FALSE - ) -} + mc +}) # ============================================================================= # Helper: build variant GRanges from ref_panel data.frame @@ -424,6 +394,48 @@ setMethod("getWeights", "TWASWeights", function(x, method = NULL) { x@weights[[method]] }) +#' @rdname getStandardized +#' @export +setMethod("getStandardized", "TWASWeights", function(x) x@standardized) + +#' @rdname getCVPerformance +#' @export +setMethod("getCVPerformance", "TWASWeights", function(x, method = NULL) { + if (is.null(method)) return(x@cv_performance) + x@cv_performance[[method]] +}) + +#' @rdname getFits +#' @export +setMethod("getFits", "TWASWeights", function(x, method = NULL) { + if (is.null(method)) return(x@fits) + x@fits[[method]] +}) + +#' @rdname getMethodNames +#' @export +setMethod("getMethodNames", "TWASWeights", function(x) x@methods) + +#' @rdname getVariantIds +#' @export +setMethod("getVariantIds", "TWASWeights", function(x) x@variant_ids) + +# ============================================================================= +# FineMappingResult additional accessors +# ============================================================================= + +#' @rdname getTrimmedFit +#' @export +setMethod("getTrimmedFit", "FineMappingResult", function(x) x@trimmed_fit) + +#' @rdname getVariantNames +#' @export +setMethod("getVariantNames", "FineMappingResult", function(x) x@variant_names) + +#' @rdname getTopLoci +#' @export +setMethod("getTopLoci", "FineMappingResult", function(x) x@top_loci) + # ============================================================================= # top_loci GRanges conversion # ============================================================================= diff --git a/R/LD.R b/R/LD.R index c113fde2..b6ff54c9 100644 --- a/R/LD.R +++ b/R/LD.R @@ -537,17 +537,12 @@ standardize_genotype_hwe <- function(X, allele_freq) { #' @export load_ld_sketch <- function(ld_meta_file_path, region, n_sample = NULL) { result <- load_LD_matrix(ld_meta_file_path, region, return_genotype = TRUE, n_sample = n_sample) - if (is(result, "LDData")) { - X <- getGenotypes(result) - variant_ids <- getVariantIds(result) - ref_panel <- as.data.frame(S4Vectors::mcols(getVariantInfo(result))) - ref_panel$chrom <- as.character(GenomicRanges::seqnames(getVariantInfo(result))) - ref_panel$pos <- GenomicRanges::start(getVariantInfo(result)) - } else { - X <- result$LD_matrix - variant_ids <- result$LD_variants - ref_panel <- result$ref_panel + if (!is(result, "LDData")) { + stop("load_LD_matrix must return an LDData object") } + X <- getGenotypes(result) + variant_ids <- getVariantIds(result) + ref_panel <- getRefPanel(result) # Remove monomorphic variants (zero variance under HWE) p <- ref_panel$allele_freq @@ -735,19 +730,15 @@ filter_variants_by_ld_reference <- function(variant_ids, ld_reference_meta_file, #' @noRd partition_LD_matrix <- function(ld_data, merge_small_blocks = TRUE, min_merged_block_size = 500, max_merged_block_size = 10000) { - # Extract components from ld_data (support both LDData S4 and legacy list) - if (is(ld_data, "LDData")) { - combined_matrix <- getCorrelation(ld_data) - block_metadata <- ld_data@block_metadata - if (is(block_metadata, "LDBlocks")) { - block_metadata <- as.data.frame(block_metadata@blocks) - } - variant_ids <- getVariantIds(ld_data) - } else { - combined_matrix <- ld_data$LD_matrix - block_metadata <- ld_data$block_metadata - variant_ids <- ld_data$LD_variants + if (!is(ld_data, "LDData")) { + stop("ld_data must be an LDData object") + } + combined_matrix <- getCorrelation(ld_data) + block_metadata <- ld_data@block_metadata + if (is(block_metadata, "LDBlocks")) { + block_metadata <- as.data.frame(block_metadata@blocks) } + variant_ids <- getVariantIds(ld_data) # Error if matrix is empty if (is.null(combined_matrix) || nrow(combined_matrix) == 0 || ncol(combined_matrix) == 0) { diff --git a/R/colocboost_pipeline.R b/R/colocboost_pipeline.R index 10275530..5c7ee1ce 100644 --- a/R/colocboost_pipeline.R +++ b/R/colocboost_pipeline.R @@ -66,10 +66,21 @@ region_data_to_colocboost_input <- function(region_data) { ind_records <- ind_records_from_input(ind_input) ind_args <- .cb_format_individual(ind_records) - sumstat_records <- lapply(names(rss_input$rss_input), function(study) { - ld_data <- .normalize_ld_data_for_qc(rss_input$LD_data[[study]]) + # Build sumstat_records using original LD_info to preserve genotype vs LD + # distinction. region_data_to_rss_input converts to LDData S4 (computing R), + # but colocboost needs the original matrix for X_ref/LD formatting. + orig_ld_info <- region_data$sumstat_data$LD_info + sumstat_records <- lapply(seq_along(rss_input$rss_input), function(i) { + study <- names(rss_input$rss_input)[i] + ld_idx <- min(i, length(orig_ld_info)) + orig_ld <- orig_ld_info[[ld_idx]] + ld_mat <- if (is(orig_ld, "LDData")) { + getCorrelation(orig_ld) + } else { + getCorrelation(rss_input$LD_data[[study]]) + } list(rss_input = rss_input$rss_input[[study]], - LD_matrix = ld_data$LD_matrix) + LD_matrix = ld_mat) }) names(sumstat_records) <- names(rss_input$rss_input) sumstat_args <- .cb_format_sumstat(sumstat_records) @@ -1129,7 +1140,7 @@ qc_individual_data <- function(X, Y, maf = NULL, X_variance = NULL, LD_reference_info = NULL, variant_convention = c("A2_A1", "A1_A2")) { is_ld_data <- function(x) { - methods::is(x, "LDData") || (is.list(x) && !is.null(x$LD_matrix)) + is(x, "LDData") } as_reference_info_list <- function(x) { if (is.null(x)) return(NULL) @@ -1211,7 +1222,7 @@ qc_individual_data <- function(X, Y, maf = NULL, X_variance = NULL, message("QC track: LD/X_ref names are parseable for summary-stat study ", study, ".") } else if (is_ld_data(ref_info)) { message("QC track: using supplied LD_reference_info LD data for summary-stat study ", study, ".") - ld_data <- .normalize_ld_data_for_qc(ref_info) + ld_data <- if (is(ref_info, "LDData")) ref_info else .legacy_list_to_LDData(ref_info) } else { message("QC track: using supplied LD_reference_info variant metadata for summary-stat study ", study, ".") ld_data <- .cb_make_ld_data( @@ -1413,12 +1424,14 @@ qc_individual_data <- function(X, Y, maf = NULL, X_variance = NULL, } else if (is.matrix(ld)) { colnames(ld) <- variant_ids } - return(list( - LD_matrix = ld, - LD_variants = variant_ids, - ref_panel = ref_panel, - block_metadata = if (!isTRUE(is_genotype)) .infer_single_ld_block_metadata(ref_panel) else NULL, - is_genotype = isTRUE(is_genotype) + ref_panel$chrom <- as.character(ref_panel$chrom) + variants_gr <- .ref_panel_to_granges(ref_panel) + corr <- if (isTRUE(is_genotype)) cor(ld) else ld + bm <- .infer_single_ld_block_metadata(ref_panel) + return(LDData( + correlation = corr, + variants = variants_gr, + block_metadata = bm )) } @@ -1444,12 +1457,14 @@ qc_individual_data <- function(X, Y, maf = NULL, X_variance = NULL, } parsed$variant_id <- variant_ids } - list( - LD_matrix = ld, - LD_variants = variant_ids, - ref_panel = parsed, - block_metadata = if (!isTRUE(is_genotype) && !is.null(parsed)) .infer_single_ld_block_metadata(parsed) else NULL, - is_genotype = isTRUE(is_genotype) + parsed$chrom <- as.character(parsed$chrom) + variants_gr <- .ref_panel_to_granges(parsed) + corr <- if (isTRUE(is_genotype)) cor(ld) else ld + bm <- if (!is.null(parsed)) .infer_single_ld_block_metadata(parsed) else data.frame() + LDData( + correlation = corr, + variants = variants_gr, + block_metadata = bm ) } diff --git a/R/encoloc.R b/R/encoloc.R index 3eec07aa..dbd93cb1 100644 --- a/R/encoloc.R +++ b/R/encoloc.R @@ -116,31 +116,20 @@ extract_ld_for_variants <- function(ld_meta_file_path, analysis_region, variants region_narrow <- paste0(chr, ":", min(var_pos), "-", max(var_pos)) ld_data <- load_LD_matrix(ld_meta_file_path, region = region_narrow, return_genotype = "auto") - # Support both LDData S4 objects and legacy lists - if (is(ld_data, "LDData")) { - ld_variants <- getVariantIds(ld_data) - has_geno <- hasGenotypes(ld_data) - } else { - ld_variants <- ld_data$LD_variants - has_geno <- isTRUE(ld_data$is_genotype) + if (!is(ld_data, "LDData")) { + stop("load_LD_matrix must return an LDData object") } + ld_variants <- getVariantIds(ld_data) + has_geno <- hasGenotypes(ld_data) aligned <- align_variant_names(ld_variants, variants) # When genotypes available, compute R only for the needed variant subset if (has_geno) { - if (is(ld_data, "LDData")) { - X <- getGenotypes(ld_data) - } else { - X <- ld_data$LD_matrix - } + X <- getGenotypes(ld_data) colnames(X) <- aligned$aligned_variants X_sub <- X[, variants, drop = FALSE] ld_matrix <- compute_LD(X_sub, method = "sample") } else { - if (is(ld_data, "LDData")) { - ld_matrix <- getCorrelation(ld_data) - } else { - ld_matrix <- ld_data$LD_matrix - } + ld_matrix <- getCorrelation(ld_data) colnames(ld_matrix) <- rownames(ld_matrix) <- aligned$aligned_variants ld_matrix <- ld_matrix[variants, variants] } @@ -259,48 +248,199 @@ process_coloc_results <- function(coloc_result, LD_meta_file_path, analysis_regi list(lbf_matrix = lbf_matrix, fm_data = fm_data) } +# Extract LBF matrix from an rss_analysis_pipeline result object. +# Unlike .extract_lbf_matrix which navigates RDS-loaded nested lists, +# this works directly with the in-memory pipeline output structure. +# @noRd +.extract_lbf_from_pipeline_result <- function(pipeline_result, + filter_lbf_cs, filter_lbf_cs_secondary, + prior_tol) { + method_names <- setdiff(names(pipeline_result), "rss_data_analyzed") + if (length(method_names) == 0) return(NULL) + + method_result <- pipeline_result[[method_names[1]]] + fm_result <- method_result$finemapping_result + if (!is.null(fm_result) && is(fm_result, "FineMappingResult")) { + fm_data <- getTrimmedFit(fm_result) + variant_names <- getVariantNames(fm_result) + } else { + fm_data <- method_result$susie_result_trimmed + variant_names <- method_result$variant_names + } + if (is.null(fm_data) || is.null(fm_data$lbf_variable)) return(NULL) + + lbf_matrix <- as.data.frame(fm_data$lbf_variable) + + # Row filtering — same logic as .extract_lbf_matrix + if (filter_lbf_cs && is.null(filter_lbf_cs_secondary)) { + lbf_matrix <- lbf_matrix[fm_data$sets$cs_index, , drop = FALSE] + } else if (!is.null(filter_lbf_cs_secondary)) { + lbf_matrix <- lbf_matrix[get_filter_lbf_index(fm_data, coverage = filter_lbf_cs_secondary), , drop = FALSE] + } else if ("V" %in% names(fm_data)) { + lbf_matrix <- lbf_matrix[fm_data$V > prior_tol, , drop = FALSE] + } + + if (!is.null(variant_names) && length(variant_names) == ncol(lbf_matrix)) { + colnames(lbf_matrix) <- variant_names + } + lbf_matrix <- lbf_matrix[, !is.na(colnames(lbf_matrix))] + list(lbf_matrix = lbf_matrix, fm_data = fm_data) +} + +# Save inline fine-mapping result to disk in a format compatible with the +# file-based reading path (readRDS(file)[[1]] + gwas_finemapping_obj/gwas_varname_obj). +# @noRd +.save_finemapping_result <- function(pipeline_result, save_path) { + if (is.null(save_path) || is.null(pipeline_result)) return(invisible(NULL)) + method_names <- setdiff(names(pipeline_result), "rss_data_analyzed") + if (length(method_names) == 0) return(invisible(NULL)) + method_result <- pipeline_result[[method_names[1]]] + fm_result <- method_result$finemapping_result + if (!is.null(fm_result) && is(fm_result, "FineMappingResult")) { + save_data <- list( + susie_fit = getTrimmedFit(fm_result), + variant_names = getVariantNames(fm_result) + ) + } else { + save_data <- list( + susie_fit = method_result$susie_result_trimmed, + variant_names = method_result$variant_names + ) + } + saveRDS(list(save_data), save_path) + message("Fine-mapping result saved to: ", save_path, + "\n Reuse with: gwas_files = '", save_path, + "', gwas_finemapping_obj = 'susie_fit', gwas_varname_obj = 'variant_names'") + invisible(save_path) +} + #' Colocalization Analysis Wrapper #' -#' This function processes xQTL and multiple GWAS finemapped data files for colocalization analysis. +#' Processes xQTL and GWAS finemapped data for colocalization analysis. +#' GWAS data can come from pre-computed RDS files or from inline fine-mapping +#' via \code{\link{rss_analysis_pipeline}}. #' #' @param xqtl_file Path to the xQTL RDS file. -#' @param gwas_files Vector of paths to GWAS RDS files. -#' @param xqtl_finemapping_obj Optional table name in xQTL RDS files (default 'susie_fit'). -#' @param gwas_finemapping_obj Optional table name in GWAS RDS files (default 'susie_fit'). -#' @param xqtl_varname_obj Optional table name in xQTL RDS files (default 'susie_fit'). -#' @param gwas_varname_obj Optional table name in GWAS RDS files (default 'susie_fit'). -#' @param xqtl_region_obj Optional table name in xQTL RDS files (default 'susie_fit'). -#' @param gwas_region_obj Optional table name in GWAS RDS files (default 'susie_fit'). -#' @param region_obj Optional table name of region info in susie_twas output filess (default 'region_info'). -#' @param p1, p2, and p12 are results from xqtl_enrichment_wrapper (default 'p1=1e-4, p2=1e-4, p12=5e-6', same as coloc.bf_bf). -#' @param prior_tol When the prior variance is estimated, compare the estimated value to \code{prior_tol} at the end of the computation, -#' and exclude a single effect from PIP computation if the estimated prior variance is smaller than this tolerance value. +#' @param gwas_files Vector of paths to GWAS RDS files. Required when +#' \code{run_finemapping = FALSE}. Ignored when \code{run_finemapping = TRUE}. +#' @param xqtl_finemapping_obj Optional path in xQTL RDS to the finemapping object. +#' @param gwas_finemapping_obj Optional path in GWAS RDS to the finemapping object. +#' @param xqtl_varname_obj Optional path in xQTL RDS to variant names. +#' @param gwas_varname_obj Optional path in GWAS RDS to variant names. +#' @param xqtl_region_obj Optional path in xQTL RDS to region info. +#' @param gwas_region_obj Optional path in GWAS RDS to region info. +#' @param filter_lbf_cs Logical. Filter LBF rows by credible set index. +#' @param filter_lbf_cs_secondary Coverage for secondary LBF filtering. +#' @param prior_tol Minimum prior variance to retain an effect (default 1e-9). +#' @param p1 Prior probability a SNP is associated with trait 1 (default 1e-4). +#' @param p2 Prior probability a SNP is associated with trait 2 (default 1e-4). +#' @param p12 Prior probability a SNP is associated with both traits (default 5e-6). +#' @param run_finemapping Logical. If TRUE, run GWAS fine-mapping inline via +#' \code{\link{rss_analysis_pipeline}}. Default FALSE. +#' @param sumstat_path Path to GWAS summary statistics file. Required when +#' \code{run_finemapping = TRUE}. +#' @param column_file_path Path to column mapping file for summary statistics. +#' @param LD_data LD reference data (LDData object or list). Required when +#' \code{run_finemapping = TRUE}. +#' @param n_sample Sample size for GWAS. +#' @param n_case Number of cases for binary traits. +#' @param n_control Number of controls for binary traits. +#' @param region Genomic region string (e.g., "chr1:1000-2000"). +#' @param qc_method QC method: "slalom", "dentist", or "none". Default "slalom". +#' @param finemapping_method Fine-mapping method. Default "susie_rss". +#' @param finemapping_opts List of fine-mapping options passed to +#' \code{\link{rss_analysis_pipeline}}. +#' @param impute Logical. Run RAISS imputation. Default TRUE. +#' @param impute_opts List of imputation options. +#' @param save_finemapping_path Path to save fine-mapping result as RDS. The +#' saved file can be reused via \code{gwas_files} with +#' \code{gwas_finemapping_obj = "susie_fit"} and +#' \code{gwas_varname_obj = "variant_names"}. +#' @param return_finemapping Logical. If TRUE and \code{run_finemapping = TRUE}, +#' include full fine-mapping result under \code{$gwas_finemapping}. +#' @param ... Additional arguments (currently unused). #' @return A list containing the coloc results and the summarized sets. -#' @examples -#' xqtl_file <- "xqtl_file.rds" -#' gwas_files <- c("gwas_file1.rds", "gwas_file2.rds") -#' result <- coloc_wrapper(xqtl_file, gwas_files, LD_meta_file_path) -#' @importFrom dplyr bind_rows +#' @seealso \code{\link{rss_analysis_pipeline}}, \code{\link{coloc_post_processor}} +#' @importFrom dplyr bind_rows mutate across #' @importFrom tidyr replace_na #' @importFrom coloc coloc.bf_bf -#' @importFrom purrr map_dfr +#' @importFrom purrr map map_dfr #' @export -coloc_wrapper <- function(xqtl_file, gwas_files, +coloc_wrapper <- function(xqtl_file, gwas_files = NULL, xqtl_finemapping_obj = NULL, xqtl_varname_obj = NULL, xqtl_region_obj = NULL, gwas_finemapping_obj = NULL, gwas_varname_obj = NULL, gwas_region_obj = NULL, filter_lbf_cs = FALSE, filter_lbf_cs_secondary = NULL, - prior_tol = 1e-9, p1 = 1e-4, p2 = 1e-4, p12 = 5e-6, ...) { - region <- NULL - # Load and process GWAS data - gwas_lbf_matrices <- map(gwas_files, function(file) { - raw_data <- readRDS(file)[[1]] - .extract_lbf_matrix(raw_data, gwas_finemapping_obj, gwas_varname_obj, - filter_lbf_cs, filter_lbf_cs_secondary, prior_tol)$lbf_matrix - }) + prior_tol = 1e-9, p1 = 1e-4, p2 = 1e-4, p12 = 5e-6, + run_finemapping = FALSE, + sumstat_path = NULL, column_file_path = NULL, + LD_data = NULL, + n_sample = 0, n_case = 0, n_control = 0, + region = NULL, + qc_method = "slalom", + finemapping_method = "susie_rss", + finemapping_opts = list( + L = 20, L_greedy = 5, + coverage = c(0.95, 0.7, 0.5), + signal_cutoff = 0.025, + min_abs_corr = 0.8 + ), + impute = TRUE, + impute_opts = list(rcond = 0.01, R2_threshold = 0.6, + minimum_ld = 5, lamb = 0.01), + save_finemapping_path = NULL, + return_finemapping = FALSE, + ...) { + # --- Input validation --- + if (!run_finemapping && is.null(gwas_files)) { + stop("Either set run_finemapping = TRUE with GWAS sumstat inputs, or provide gwas_files paths to pre-computed results.") + } + if (run_finemapping && !is.null(gwas_files)) { + warning("Both run_finemapping = TRUE and gwas_files provided. Inline fine-mapping will be used; gwas_files ignored.") + gwas_files <- NULL + } + if (run_finemapping) { + if (is.null(sumstat_path)) stop("sumstat_path is required when run_finemapping = TRUE.") + if (is.null(LD_data)) stop("LD_data is required when run_finemapping = TRUE.") + } - # Combine GWAS matrices and replace NAs with zeros - combined_gwas_lbf_matrix <- bind_rows(gwas_lbf_matrices) %>% - mutate(across(everything(), ~ replace_na(., 0))) + gwas_pipeline_result <- NULL + + if (run_finemapping) { + # --- Inline fine-mapping path: QC runs inside rss_analysis_pipeline --- + gwas_pipeline_result <- rss_analysis_pipeline( + sumstat_path = sumstat_path, column_file_path = column_file_path, + LD_data = LD_data, + n_sample = n_sample, n_case = n_case, n_control = n_control, + region = region, + qc_method = qc_method, finemapping_method = finemapping_method, + finemapping_opts = finemapping_opts, + impute = impute, impute_opts = impute_opts + ) + + # Save to disk before extraction (useful even if extraction fails) + .save_finemapping_result(gwas_pipeline_result, save_finemapping_path) + + gwas_extracted <- .extract_lbf_from_pipeline_result( + gwas_pipeline_result, filter_lbf_cs, filter_lbf_cs_secondary, prior_tol + ) + if (is.null(gwas_extracted)) { + coloc_res <- list("No GWAS fine-mapping results produced by inline pipeline.") + result <- c(coloc_res, analysis_region = region) + if (return_finemapping) result$gwas_finemapping <- gwas_pipeline_result + return(result) + } + combined_gwas_lbf_matrix <- gwas_extracted$lbf_matrix %>% + as.data.frame() %>% mutate(across(everything(), ~ replace_na(., 0))) + } else { + # --- File-based path (unchanged) --- + gwas_lbf_matrices <- map(gwas_files, function(file) { + raw_data <- readRDS(file)[[1]] + .extract_lbf_matrix(raw_data, gwas_finemapping_obj, gwas_varname_obj, + filter_lbf_cs, filter_lbf_cs_secondary, prior_tol)$lbf_matrix + }) + combined_gwas_lbf_matrix <- bind_rows(gwas_lbf_matrices) %>% + mutate(across(everything(), ~ replace_na(., 0))) + } # Process xQTL data xqtl_raw_data <- readRDS(xqtl_file)[[1]] @@ -313,7 +453,6 @@ coloc_wrapper <- function(xqtl_file, gwas_files, colnames(xqtl_lbf_matrix) <- align_variant_names(colnames(xqtl_lbf_matrix), colnames(combined_gwas_lbf_matrix))$aligned_variants common_colnames <- intersect(colnames(xqtl_lbf_matrix), colnames(combined_gwas_lbf_matrix)) - # Report the number of dropped columns from xQTL matrix before subsetting num_dropped_cols <- ncol(xqtl_lbf_matrix) - length(common_colnames) if (num_dropped_cols > 0) { message("Number of columns dropped from xQTL matrix: ", num_dropped_cols) @@ -322,19 +461,28 @@ coloc_wrapper <- function(xqtl_file, gwas_files, xqtl_lbf_matrix <- xqtl_lbf_matrix[, common_colnames, drop = FALSE] %>% as.matrix() combined_gwas_lbf_matrix <- combined_gwas_lbf_matrix[, common_colnames, drop = FALSE] %>% as.matrix() - # Function to convert region df to str convert_to_string <- function(df) paste0("chr", df$chrom, ":", df$start, "-", df$end) - region <- if (!is.null(xqtl_region_obj)) get_nested_element(xqtl_raw_data, xqtl_region_obj) %>% convert_to_string() else NULL + analysis_region_out <- if (!is.null(xqtl_region_obj)) { + get_nested_element(xqtl_raw_data, xqtl_region_obj) %>% convert_to_string() + } else { + region + } - # COLOC function coloc_res <- coloc.bf_bf(xqtl_lbf_matrix, combined_gwas_lbf_matrix, p1 = p1, p2 = p2, p12 = p12) } else { coloc_res <- list("No coloc results due to the absence of a GWAS log Bayes factor matrix filtered by prior tolerance.") + analysis_region_out <- region } } else { coloc_res <- list(paste("no", xqtl_finemapping_obj[2], "in", xqtl_finemapping_obj[1])) + analysis_region_out <- region + } + + result <- c(coloc_res, analysis_region = analysis_region_out) + if (return_finemapping && !is.null(gwas_pipeline_result)) { + result$gwas_finemapping <- gwas_pipeline_result } - return(c(coloc_res, analysis_region = region)) + return(result) } #' coloc_post_processor function diff --git a/R/file_utils.R b/R/file_utils.R index 748f1b1d..c7b23e6c 100644 --- a/R/file_utils.R +++ b/R/file_utils.R @@ -1811,6 +1811,43 @@ region_data_to_ind_input <- function(region_data) { #' @return A list containing named RSS inputs, matched LD data, and source #' information. #' @export +.legacy_list_to_LDData <- function(ld_list) { + # Convert legacy LD list to LDData S4 object + if (is(ld_list, "LDData")) return(ld_list) + ld_mat <- ld_list$LD_matrix + ref_panel <- ld_list$ref_panel + ld_variants <- ld_list$LD_variants + # Build ref_panel if missing + if (is.null(ref_panel) && !is.null(ld_variants)) { + if (is.data.frame(ld_variants)) { + ref_panel <- ld_variants + } else { + parsed <- tryCatch(parse_variant_id(ld_variants), error = function(e) NULL) + if (!is.null(parsed)) { + ref_panel <- parsed + ref_panel$variant_id <- ld_variants + } + } + } + if (is.null(ref_panel)) return(ld_list) # cannot convert + if (is.data.frame(ref_panel) && !"variant_id" %in% names(ref_panel)) { + ref_panel$variant_id <- format_variant_id(ref_panel$chrom, ref_panel$pos, ref_panel$A2, ref_panel$A1) + } + if (!"chrom" %in% names(ref_panel)) ref_panel$chrom <- "1" + ref_panel$chrom <- as.character(ref_panel$chrom) + variants_gr <- .ref_panel_to_granges(ref_panel) + is_genotype <- isTRUE(ld_list$is_genotype) || (is.matrix(ld_mat) && nrow(ld_mat) != ncol(ld_mat)) + corr <- if (is_genotype) cor(ld_mat) else ld_mat + bm <- ld_list$block_metadata + if (is.null(bm)) bm <- .infer_single_ld_block_metadata(ref_panel) + if (is.null(bm)) bm <- data.frame() + LDData( + correlation = corr, + variants = variants_gr, + block_metadata = bm + ) +} + region_data_to_rss_input <- function(region_data) { make_ld_data_from_matrix <- function(ld, variant_ids = NULL) { is_genotype <- is.matrix(ld) && nrow(ld) != ncol(ld) @@ -1833,13 +1870,28 @@ region_data_to_rss_input <- function(region_data) { parsed$variant_id <- ids } } - list( - LD_matrix = ld, - LD_variants = ids, - ref_panel = parsed, - block_metadata = if (!is_genotype && !is.null(parsed)) .infer_single_ld_block_metadata(parsed) else NULL, - is_genotype = isTRUE(is_genotype) - ) + # Build LDData S4 object + if (!is.null(parsed)) { + ref_panel_df <- parsed + ref_panel_df$chrom <- as.character(ref_panel_df$chrom) + variants_gr <- .ref_panel_to_granges(ref_panel_df) + corr <- if (is_genotype) cor(ld) else ld + bm <- .infer_single_ld_block_metadata(ref_panel_df) + LDData( + correlation = corr, + variants = variants_gr, + block_metadata = bm + ) + } else { + # Cannot parse variant IDs — fall back to legacy list + list( + LD_matrix = ld, + LD_variants = ids, + ref_panel = parsed, + block_metadata = NULL, + is_genotype = isTRUE(is_genotype) + ) + } } rss_input_from_qced_sumstat <- function(sumstat_data) { @@ -1907,7 +1959,11 @@ region_data_to_rss_input <- function(region_data) { output_name <- make.unique(c(names(rss_input), output_name))[length(rss_input) + 1] } rss_input[[output_name]] <- studies[[study]] - LD_data[[output_name]] <- sumstat_data$LD_info[[ld_index]] + ld_entry <- sumstat_data$LD_info[[ld_index]] + if (!is(ld_entry, "LDData") && is.list(ld_entry)) { + ld_entry <- .legacy_list_to_LDData(ld_entry) + } + LD_data[[output_name]] <- ld_entry ld_group[[output_name]] <- group_name } } diff --git a/R/ld_loader.R b/R/ld_loader.R index b789aca8..71cccbcb 100644 --- a/R/ld_loader.R +++ b/R/ld_loader.R @@ -1,17 +1,15 @@ -#' Extract the LD or genotype matrix from an LDData S4 object or legacy list. -#' @param ld An LDData object or a list with element \code{LD_matrix}. -#' @param want_genotype Logical; if TRUE, extract the genotype matrix from an -#' LDData object (via \code{getGenotypes()}). +#' Extract the LD or genotype matrix from an LDData S4 object. +#' @param ld An LDData object. +#' @param want_genotype Logical; if TRUE, extract the genotype matrix +#' (via \code{getGenotypes()}). #' @return A matrix. #' @noRd extract_ld_matrix <- function(ld, want_genotype = FALSE) { - if (is(ld, "LDData")) { - if (want_genotype && hasGenotypes(ld)) { - return(getGenotypes(ld)) - } - return(getCorrelation(ld)) + if (!is(ld, "LDData")) stop("ld must be an LDData object") + if (want_genotype && hasGenotypes(ld)) { + return(getGenotypes(ld)) } - ld$LD_matrix + getCorrelation(ld) } #' Create an LD loader for on-demand block-wise LD retrieval diff --git a/R/sumstats_qc.R b/R/sumstats_qc.R index 7ef522f6..89d2f276 100644 --- a/R/sumstats_qc.R +++ b/R/sumstats_qc.R @@ -25,14 +25,11 @@ #' @export rss_basic_qc <- function(sumstats, LD_data, skip_region = NULL, keep_indel = TRUE, return_LD_mat = TRUE) { - # Extract LD components from either LDData S4 object or legacy list - if (is(LD_data, "LDData")) { - LD_variants <- getVariantIds(LD_data) - LD_matrix <- getCorrelation(LD_data) - } else { - LD_variants <- LD_data$LD_variants - LD_matrix <- LD_data$LD_matrix + if (!is(LD_data, "LDData")) { + stop("LD_data must be an LDData object") } + LD_variants <- getVariantIds(LD_data) + LD_matrix <- if (return_LD_mat) getCorrelation(LD_data) else NULL # Check if required columns are present in sumstats required_cols <- c("chrom", "pos", "A1", "A2") @@ -259,7 +256,7 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, colnames(X_sub) <- sumstats$variant_id[!is.na(idx)] LD_extract <- compute_LD(X_sub, method = "sample") } else { - LD_mat <- if (is(LD_data, "LDData")) getCorrelation(LD_data) else LD_data$LD_matrix + LD_mat <- getCorrelation(LD_data) LD_extract <- LD_mat[sumstats$variant_id, sumstats$variant_id, drop = FALSE] } qc_results <- ld_mismatch_qc(zScore = sumstats$z, R = LD_extract, @@ -286,8 +283,7 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, is.list(x) && is.data.frame(x$sumstats) } is_ld_record <- function(x) { - methods::is(x, "LDData") || is.matrix(x) || - (is.list(x) && !is.null(x$LD_matrix)) + is(x, "LDData") } first_ld_record <- function(x, study_name = NULL) { if (is_ld_record(x)) return(x) @@ -381,71 +377,8 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, .normalize_ld_data_for_qc <- function(LD_data) { if (is.null(LD_data)) return(NULL) - if (methods::is(LD_data, "LDData")) { - variant_info <- getVariantInfo(LD_data) - ref_panel <- as.data.frame(S4Vectors::mcols(variant_info)) - ref_panel$chrom <- as.character(GenomicRanges::seqnames(variant_info)) - ref_panel$pos <- GenomicRanges::start(variant_info) - if (!"variant_id" %in% colnames(ref_panel)) { - ref_panel$variant_id <- getVariantIds(LD_data) - } - - is_genotype <- hasGenotypes(LD_data) - LD_matrix <- if (is_genotype) getGenotypes(LD_data) else getCorrelation(LD_data) - LD_variants <- getVariantIds(LD_data) - if (is.matrix(LD_matrix)) { - if (is_genotype && length(LD_variants) == ncol(LD_matrix)) { - colnames(LD_matrix) <- LD_variants - } else if (!is_genotype && length(LD_variants) == nrow(LD_matrix)) { - rownames(LD_matrix) <- colnames(LD_matrix) <- LD_variants - } - } - - return(list( - LD_matrix = LD_matrix, - LD_variants = LD_variants, - ref_panel = ref_panel, - block_metadata = getBlockMetadata(LD_data), - is_genotype = is_genotype - )) - } - if (is.matrix(LD_data)) { - ids <- if (nrow(LD_data) == ncol(LD_data)) rownames(LD_data) else colnames(LD_data) - LD_data <- list(LD_matrix = LD_data, LD_variants = ids) - } - if (is.data.frame(LD_data$LD_variants)) { - variant_ids <- if ("variant_id" %in% colnames(LD_data$LD_variants)) { - as.character(LD_data$LD_variants$variant_id) - } else { - format_variant_id(LD_data$LD_variants$chrom, LD_data$LD_variants$pos, - LD_data$LD_variants$A2, LD_data$LD_variants$A1) - } - LD_data$LD_variants <- variant_ids - } - if (is.data.frame(LD_data$ref_panel) && !"variant_id" %in% colnames(LD_data$ref_panel)) { - LD_data$ref_panel$variant_id <- format_variant_id( - LD_data$ref_panel$chrom, LD_data$ref_panel$pos, - LD_data$ref_panel$A2, LD_data$ref_panel$A1 - ) - } - if (is.null(LD_data$LD_variants) && is.data.frame(LD_data$ref_panel) && - "variant_id" %in% colnames(LD_data$ref_panel)) { - LD_data$LD_variants <- as.character(LD_data$ref_panel$variant_id) - } - if (is.null(LD_data$ref_panel) && !is.null(LD_data$LD_variants)) { - parsed <- tryCatch(parse_variant_id(LD_data$LD_variants), error = function(e) NULL) - if (!is.null(parsed)) { - LD_data$ref_panel <- parsed - LD_data$ref_panel$variant_id <- LD_data$LD_variants - } - } - if (!is.null(LD_data$LD_variants) && is.matrix(LD_data$LD_matrix)) { - if (nrow(LD_data$LD_matrix) == ncol(LD_data$LD_matrix) && - length(LD_data$LD_variants) == nrow(LD_data$LD_matrix)) { - rownames(LD_data$LD_matrix) <- colnames(LD_data$LD_matrix) <- LD_data$LD_variants - } else if (length(LD_data$LD_variants) == ncol(LD_data$LD_matrix)) { - colnames(LD_data$LD_matrix) <- LD_data$LD_variants - } + if (!is(LD_data, "LDData")) { + stop("LD_data must be an LDData object") } LD_data } @@ -468,26 +401,12 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, if (is.null(rss_input) || is.null(LD_data)) return(NULL) message("QC track: starting basic allele harmonization for summary-stat study ", study, ".") - message("QC track: basic summary-stat QC requires sumstat$variant and LD_data$LD_variants ", - "or LD/X_ref variant names for study ", study, ".") + message("QC track: basic summary-stat QC requires sumstat$variant and LD_data variant IDs ", + "for study ", study, ".") LD_data_for_qc <- .normalize_ld_data_for_qc(LD_data) - has_genotype <- isTRUE(LD_data_for_qc$is_genotype) || - (is.matrix(LD_data_for_qc$LD_matrix) && - nrow(LD_data_for_qc$LD_matrix) != ncol(LD_data_for_qc$LD_matrix)) - if (has_genotype && is.data.frame(LD_data_for_qc$ref_panel) && - all(c("chrom", "pos", "A2", "A1") %in% colnames(LD_data_for_qc$ref_panel))) { - canonical_ids <- format_variant_id(LD_data_for_qc$ref_panel$chrom, - LD_data_for_qc$ref_panel$pos, - LD_data_for_qc$ref_panel$A2, - LD_data_for_qc$ref_panel$A1) - LD_data_for_qc$ref_panel$variant_id <- canonical_ids - LD_data_for_qc$LD_variants <- canonical_ids - if (is.matrix(LD_data_for_qc$LD_matrix) && - length(canonical_ids) == ncol(LD_data_for_qc$LD_matrix)) { - colnames(LD_data_for_qc$LD_matrix) <- canonical_ids - } - } - X_ref <- if (has_genotype) LD_data_for_qc$LD_matrix else NULL + has_genotype <- hasGenotypes(LD_data_for_qc) + ref_panel <- getRefPanel(LD_data_for_qc) + X_ref <- if (has_genotype) getGenotypes(LD_data_for_qc) else NULL basic <- rss_basic_qc(rss_input$sumstats, LD_data_for_qc, skip_region = skip_region, keep_indel = keep_indel, return_LD_mat = !has_genotype) @@ -521,27 +440,21 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, X_local <- reference_for_variants(variants) R_local <- compute_LD(X_local, method = "sample") rownames(R_local) <- colnames(R_local) <- variants - ref_panel <- LD_data_for_qc$ref_panel - if (is.data.frame(ref_panel) && "variant_id" %in% colnames(ref_panel)) { - ref_idx <- match(variants, ref_panel$variant_id) - if (anyNA(ref_idx)) { - ref_panel <- parse_variant_id(variants) - ref_panel$variant_id <- variants - } else { - ref_panel <- ref_panel[ref_idx, , drop = FALSE] - } + ref_idx <- match(variants, ref_panel$variant_id) + if (anyNA(ref_idx)) { + ref_panel_sub <- parse_variant_id(variants) + ref_panel_sub$variant_id <- variants } else { - ref_panel <- parse_variant_id(variants) - ref_panel$variant_id <- variants + ref_panel_sub <- ref_panel[ref_idx, , drop = FALSE] } - ld_data <- LD_data_for_qc - ld_data$LD_matrix <- R_local - ld_data$LD_variants <- variants - ld_data$ref_panel <- ref_panel - ld_data$block_metadata <- .infer_single_ld_block_metadata(ref_panel) - ld_data$is_genotype <- FALSE + variants_gr <- .ref_panel_to_granges(ref_panel_sub) R_mat <<- R_local - ld_data + LDData( + correlation = R_local, + variants = variants_gr, + block_metadata = .infer_single_ld_block_metadata(ref_panel_sub), + n_ref = LD_data_for_qc@n_ref + ) } message("QC track: basic harmonization retained ", nrow(sumstats), " variants for summary-stat study ", study, ".") @@ -579,34 +492,26 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, " LD-mismatch outlier(s) for summary-stat study ", study, ".") } if (isTRUE(impute)) { - if (is.null(LD_data_for_qc$ref_panel)) { - warning("Skipping imputation for summary-stat study ", study, - ": LD_data does not include ref_panel.") + message("QC track: running imputation for summary-stat study ", study, ".") + imputed <- if (has_genotype) { + X_ref_scaled <- scale(X_ref) + X_ref_scaled[is.na(X_ref_scaled)] <- 0 + colnames(X_ref_scaled) <- colnames(X_ref) + raiss(ref_panel, sumstats, + genotype_matrix = X_ref_scaled, + svd_tol = if (is.null(impute_opts$svd_tol)) 1e-12 else impute_opts$svd_tol, + R2_threshold = impute_opts$R2_threshold, + minimum_ld = impute_opts$minimum_ld, + lamb = impute_opts$lamb) } else { - message("QC track: running imputation for summary-stat study ", study, ".") - imputed <- if (has_genotype) { - X_ref_scaled <- scale(X_ref) - X_ref_scaled[is.na(X_ref_scaled)] <- 0 - colnames(X_ref_scaled) <- colnames(X_ref) - raiss(LD_data_for_qc$ref_panel, sumstats, - genotype_matrix = X_ref_scaled, - svd_tol = if (is.null(impute_opts$svd_tol)) 1e-12 else impute_opts$svd_tol, - R2_threshold = impute_opts$R2_threshold, - minimum_ld = impute_opts$minimum_ld, - lamb = impute_opts$lamb) - } else { - if (is.null(LD_data_for_qc$block_metadata)) { - LD_data_for_qc$block_metadata <- .infer_single_ld_block_metadata(LD_data_for_qc$ref_panel) - } - raiss(LD_data_for_qc$ref_panel, sumstats, partition_LD_matrix(LD_data_for_qc), - rcond = impute_opts$rcond, - R2_threshold = impute_opts$R2_threshold, - minimum_ld = impute_opts$minimum_ld, - lamb = impute_opts$lamb) - } - sumstats <- imputed$result_filter - if (!is.null(imputed$LD_mat)) R_mat <- imputed$LD_mat + raiss(ref_panel, sumstats, partition_LD_matrix(LD_data_for_qc), + rcond = impute_opts$rcond, + R2_threshold = impute_opts$R2_threshold, + minimum_ld = impute_opts$minimum_ld, + lamb = impute_opts$lamb) } + sumstats <- imputed$result_filter + if (!is.null(imputed$LD_mat)) R_mat <- imputed$LD_mat } final_vars <- sumstats$variant_id list( diff --git a/R/susie_wrapper.R b/R/susie_wrapper.R index 3c45357a..fc667323 100644 --- a/R/susie_wrapper.R +++ b/R/susie_wrapper.R @@ -951,7 +951,7 @@ susie_rss_pipeline <- function(sumstats, LD_mat = NULL, X_mat = NULL, n = NULL, data_x <- LD_mat pp_cs_input <- "Xcorr" } else if (is.list(X_mat) && !is.matrix(X_mat)) { - data_x <- X_mat[[1]][, seq_along(z), drop = FALSE] + data_x <- do.call(rbind, X_mat)[, seq_along(z), drop = FALSE] pp_cs_input <- "X" } else { data_x <- X_mat[, seq_along(z), drop = FALSE] diff --git a/R/twas.R b/R/twas.R index f035e782..e4e4f85d 100644 --- a/R/twas.R +++ b/R/twas.R @@ -31,7 +31,15 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, ld_reference_sample_size, column_file_path = NULL, comment_string = "#") { # Step 1: load TWAS weights data molecular_ids <- names(twas_weights_data) - chrom <- as.integer(parse_number(gsub(":.*$", "", rownames(twas_weights_data[[1]]$weights[[1]])[1]))) + # Each element is either a TWASWeights S4 or a list with $twas_weights (TWASWeights S4) + .get_tw <- function(mol_data) { + if (is(mol_data, "TWASWeights")) return(mol_data) + if (is.list(mol_data) && is(mol_data$twas_weights, "TWASWeights")) return(mol_data$twas_weights) + stop("Each element of twas_weights_data must be a TWASWeights S4 object ", + "or a list with a $twas_weights TWASWeights element") + } + first_tw <- .get_tw(twas_weights_data[[1]]) + chrom <- as.integer(parse_number(gsub(":.*$", "", getVariantIds(first_tw)[1]))) gwas_meta_df <- as.data.frame(vroom(gwas_meta_file)) gwas_files <- unique(gwas_meta_df$file_path[gwas_meta_df$chrom == chrom]) names(gwas_files) <- unique(gwas_meta_df$study_id[gwas_meta_df$chrom == chrom]) @@ -39,13 +47,14 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, # Per-gene loop: each gene loads its own LD sketch independently for (molecular_id in molecular_ids) { - mol_data <- twas_weights_data[[molecular_id]] + mol_entry <- twas_weights_data[[molecular_id]] + tw <- .get_tw(mol_entry) mol_res <- list(chrom = chrom, variant_names = list()) - mol_res[["data_type"]] <- if ("data_type" %in% names(mol_data)) mol_data$data_type - contexts <- names(mol_data$weights) + mol_res[["data_type"]] <- if (is.list(mol_entry) && "data_type" %in% names(mol_entry)) mol_entry$data_type + contexts <- getMethodNames(tw) # Step 2: Build gene window from all contexts' variant positions - all_weight_variants <- unique(do.call(c, lapply(contexts, function(ctx) rownames(mol_data$weights[[ctx]])))) + all_weight_variants <- getVariantIds(tw) variant_positions <- parse_variant_id(all_weight_variants)$pos gene_region <- paste0(chrom, ":", min(variant_positions), "-", max(variant_positions)) @@ -64,7 +73,8 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, if (is.null(gwas_data_sumstats)) next for (context in contexts) { - weights_matrix <- mol_data[["weights"]][[context]] + weights_matrix <- getWeights(tw, context) + original_weight_variants <- rownames(weights_matrix) # Harmonize weights against sketch reference weights_matrix <- cbind(variant_id_to_df(rownames(weights_matrix)), weights_matrix) @@ -87,8 +97,11 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, postqc_weight_variants <- rownames(weights_matrix_subset) # Step 5: adjust SuSiE weights based on available variants - if ("susie_weights" %in% colnames(mol_data[["weights"]][[context]])) { - adjusted_susie_weights <- adjust_susie_weights(mol_data, + tw_weights_ctx <- getWeights(tw, context) + if ("susie_weights" %in% colnames(tw_weights_ctx)) { + # For adjust_susie_weights, we need the fits (susie_results) + mol_data_for_adjust <- if (is.list(mol_entry) && !is(mol_entry, "TWASWeights")) mol_entry else list(twas_weights = tw) + adjusted_susie_weights <- adjust_susie_weights(mol_data_for_adjust, keep_variants = postqc_weight_variants, run_allele_qc = TRUE, variable_name_obj = c("variant_names", context), susie_obj = c("susie_results", context), @@ -98,8 +111,13 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, susie_weights = setNames(adjusted_susie_weights$adjusted_susie_weights, adjusted_susie_weights$remained_variants_ids), weights_matrix_subset[adjusted_susie_weights$remained_variants_ids, !colnames(weights_matrix_subset) %in% "susie_weights", drop = FALSE] ) - susie_intermediate <- mol_data$susie_results[[context]][c("pip", "cs_variants", "cs_purity")] - names(susie_intermediate[["pip"]]) <- rownames(weights_matrix) # original variants not yet qced + susie_results <- if (is.list(mol_entry) && "susie_results" %in% names(mol_entry)) { + mol_entry$susie_results[[context]] + } else { + getFits(tw, context) + } + susie_intermediate <- susie_results[c("pip", "cs_variants", "cs_purity")] + names(susie_intermediate[["pip"]]) <- original_weight_variants # original variants not yet qced pip <- susie_intermediate[["pip"]] pip_qced <- match_ref_panel(cbind(parse_variant_id(names(pip)), pip), sketch$variant_ids, "pip", match_min_prop = 0) susie_intermediate[["pip"]] <- abs(pip_qced$target_data_qced$pip) @@ -121,7 +139,7 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, # Step 6: scale weights by variance (from sketch ref_panel) # RSS/standardized weights are already on the correlation scale and # do not need sqrt(variance) scaling. - is_standardized <- isTRUE(mol_data[["standardized"]]) + is_standardized <- isTRUE(getStandardized(tw)) if (is_standardized) { scaled <- weights_matrix_subset } else { diff --git a/R/twas_weights.R b/R/twas_weights.R index 9378daeb..978e0fb6 100644 --- a/R/twas_weights.R +++ b/R/twas_weights.R @@ -104,7 +104,8 @@ # Returns filtered weight_methods list and warns about removed methods. # @noRd .filter_zero_weight_methods <- function(weight_methods, twas_weights_res) { - is_all_zero <- vapply(twas_weights_res, function(w) all(w == 0, na.rm = TRUE), logical(1)) + wl <- if (is(twas_weights_res, "TWASWeights")) getWeights(twas_weights_res) else twas_weights_res + is_all_zero <- vapply(wl, function(w) all(w == 0, na.rm = TRUE), logical(1)) removed <- names(weight_methods)[is_all_zero] if (length(removed) > 0) { warning(sprintf( @@ -612,20 +613,20 @@ twas_weights <- function(X, Y, weight_methods, num_threads = 1, return(x) }) } - # Wrap in TWASWeights S4 object + # Create TWASWeights S4 object variant_ids <- if (!is.null(colnames(X))) colnames(X) else paste0("variant_", seq_len(ncol(X))) fits_list <- lapply(weights_list, function(w) attr(w, "fit")) has_any_fit <- any(!sapply(fits_list, is.null)) - twas_result <- TWASWeights( - weights = weights_list, + # Strip fit attributes from weight matrices before storing in S4 + clean_weights <- lapply(weights_list, function(w) { attr(w, "fit") <- NULL; w }) + + TWASWeights( + weights = clean_weights, variant_ids = variant_ids, fits = if (has_any_fit) fits_list else NULL, cv_performance = NULL ) - # Attach the S4 object alongside the legacy list for backwards compatibility - attr(weights_list, "twas_weights_s4") <- twas_result - return(weights_list) } #' Predict outcomes using TWAS weights @@ -652,7 +653,12 @@ twas_weights <- function(X, Y, weight_methods, num_threads = 1, #' predicted_outcomes <- twas_predict(X, weights_list) #' print(predicted_outcomes) twas_predict <- function(X, weights_list) { - setNames(lapply(weights_list, function(w) X %*% w), gsub("_weights", "_predicted", names(weights_list))) + if (is(weights_list, "TWASWeights")) { + wl <- getWeights(weights_list) + } else { + wl <- weights_list + } + setNames(lapply(wl, function(w) X %*% w), gsub("_weights", "_predicted", names(wl))) } #' Estimate Sparsity from mr.ash Mixture Proportions @@ -674,15 +680,22 @@ twas_predict <- function(X, weights_list) { #' @return A scalar sparsity estimate (proportion of non-zero effects). #' @export estimate_sparsity <- function(weight_results) { - w <- weight_results[["mrash_weights"]] - if (is.null(w)) { - stop("mr.ash weights ('mrash_weights') not found in weight_results.") - } - - fit <- attr(w, "fit") - if (is.null(fit) || is.null(fit$pi)) { - stop("mr.ash fit object not found. Run twas_weights() with retain_fits = TRUE ", - "and ensure mrash_weights is included.") + if (is(weight_results, "TWASWeights")) { + fit <- getFits(weight_results, "mrash_weights") + if (is.null(fit) || is.null(fit$pi)) { + stop("mr.ash fit object not found. Run twas_weights() with retain_fits = TRUE ", + "and ensure mrash_weights is included.") + } + } else { + w <- weight_results[["mrash_weights"]] + if (is.null(w)) { + stop("mr.ash weights ('mrash_weights') not found in weight_results.") + } + fit <- attr(w, "fit") + if (is.null(fit) || is.null(fit$pi)) { + stop("mr.ash fit object not found. Run twas_weights() with retain_fits = TRUE ", + "and ensure mrash_weights is included.") + } } # fit$pi[1] is the weight on the spike (sa2[1] = 0); 1 - pi[1] = non-null proportion @@ -791,21 +804,36 @@ twas_weights_pipeline <- function(X, if (length(remaining_fn_names) > 0) { remaining_methods <- weight_methods[remaining_fn_names] - remaining_weights <- twas_weights( + remaining_tw <- twas_weights( X, y, weight_methods = remaining_methods, fitted_models = fitted_models, verbose = verbose ) - res$twas_weights <- c(mrash_weights, remaining_weights) + # Combine two TWASWeights objects + combined_weights <- c(getWeights(mrash_weights), getWeights(remaining_tw)) + combined_fits <- c(getFits(mrash_weights), getFits(remaining_tw)) + res$twas_weights <- TWASWeights( + weights = combined_weights, + variant_ids = getVariantIds(mrash_weights), + fits = combined_fits + ) } else { res$twas_weights <- mrash_weights } # Remove mr.ash if it was not in the original weight_methods if (!"mrash_weights" %in% names(weight_methods)) { - res$twas_weights[["mrash_weights"]] <- NULL + w_list <- getWeights(res$twas_weights) + f_list <- getFits(res$twas_weights) + w_list[["mrash_weights"]] <- NULL + if (!is.null(f_list)) f_list[["mrash_weights"]] <- NULL + res$twas_weights <- TWASWeights( + weights = w_list, + variant_ids = getVariantIds(res$twas_weights), + fits = if (length(f_list) > 0) f_list else NULL + ) } } else { # Run all methods at once @@ -821,11 +849,6 @@ twas_weights_pipeline <- function(X, elapsed <- toc(quiet = TRUE) message(sprintf("TWAS weights fitting done in %.1fs", elapsed$toc - elapsed$tic)) } - res$twas_weights <- lapply(res$twas_weights, function(w) { - attr(w, "fit") <- NULL - w - }) - res$twas_predictions <- twas_predict(X, res$twas_weights) if (cv_folds > 1) { @@ -928,7 +951,12 @@ twas_weights_pipeline <- function(X, filtered_cv$prediction <- filtered_cv$prediction[passing_pred_names] # Subset twas_weights to passing methods - filtered_weights <- res$twas_weights[passing_weight_names] + if (is(res$twas_weights, "TWASWeights")) { + wl <- getWeights(res$twas_weights) + filtered_weights <- wl[passing_weight_names] + } else { + filtered_weights <- res$twas_weights[passing_weight_names] + } if (verbose >= 1) { message("Computing ensemble TWAS weights via stacked regression ", @@ -950,9 +978,19 @@ twas_weights_pipeline <- function(X, # Add ensemble weights alongside individual method weights if (!is.null(ens_result$ensemble_twas_weights)) { - res$twas_weights$ensemble_weights <- ens_result$ensemble_twas_weights ens_wt <- ens_result$ensemble_twas_weights if (!is.matrix(ens_wt)) ens_wt <- matrix(ens_wt, ncol = 1) + # Rebuild TWASWeights S4 with ensemble method added + tw <- res$twas_weights + new_weights <- c(getWeights(tw), list(ensemble_weights = ens_wt)) + res$twas_weights <- new("TWASWeights", + weights = new_weights, + variant_ids = getVariantIds(tw), + methods = c(getMethodNames(tw), "ensemble_weights"), + fits = getFits(tw), + cv_performance = getCVPerformance(tw), + standardized = getStandardized(tw) + ) res$twas_predictions$ensemble_predicted <- X %*% ens_wt } res$ensemble <- ens_result @@ -1012,10 +1050,11 @@ twas_multivariate_weights_pipeline <- function( cv_threads = 1, verbose = 1) { copy_twas_results <- function(context_names, variant_names, twas_weight, twas_predictions) { + wl <- if (is(twas_weight, "TWASWeights")) getWeights(twas_weight) else twas_weight setNames(lapply(context_names, function(ctx) { - if (ctx %in% colnames(twas_weight[[1]])) { + if (ctx %in% colnames(wl[[1]])) { list( - twas_weights = lapply(twas_weight, function(wgts) wgts[, ctx]), + twas_weights = lapply(wl, function(wgts) wgts[, ctx]), twas_predictions = lapply(twas_predictions, function(pred) pred[, ctx]), variant_names = variant_names ) @@ -1703,7 +1742,7 @@ twas_weights_sumstat_pipeline <- function( qc_method = NULL, keep_indel = TRUE, pip_cutoff_to_skip = 0, - impute = FALSE, + impute = TRUE, impute_opts = list(rcond = 0.01, R2_threshold = 0.6, minimum_ld = 5, lamb = 0.01), var_y = 1, verbose = 1) { @@ -1736,10 +1775,8 @@ twas_weights_sumstat_pipeline <- function( LD_mat <- LD_data } else if (is(LD_data, "LDData")) { LD_mat <- getCorrelation(LD_data) - } else if (is.list(LD_data) && !is.null(LD_data$LD_matrix)) { - LD_mat <- LD_data$LD_matrix } else { - stop("LD_data must be a matrix, LDData object, or list with LD_matrix.") + stop("LD_data must be a matrix or LDData object.") } outlier_number <- 0L } diff --git a/R/univariate_pipeline.R b/R/univariate_pipeline.R index d01720c7..df1f9b00 100644 --- a/R/univariate_pipeline.R +++ b/R/univariate_pipeline.R @@ -270,31 +270,12 @@ rss_analysis_pipeline <- function( impute = TRUE, impute_opts = list(rcond = 0.01, R2_threshold = 0.6, minimum_ld = 5, lamb = 0.01), pip_cutoff_to_skip = 0, R_finite = NULL, R_mismatch = NULL, keep_indel = TRUE, comment_string = "#", diagnostics = FALSE) { - # Convert LDData to legacy list for compatibility with downstream functions. - # When genotypes are available, pass the genotype matrix directly instead of - # computing the full p x p correlation matrix — downstream QC and fine-mapping - # functions already handle genotype data and compute R on demand. - if (is(LD_data, "LDData")) { - use_X <- hasGenotypes(LD_data) - X_data <- if (use_X) getGenotypes(LD_data) else NULL - is_X_list <- use_X && is.list(X_data) - if (use_X) { - LD_data <- ld_data_to_list(LD_data, skip_correlation = TRUE) - LD_data$LD_matrix <- X_data - LD_data$is_genotype <- TRUE - } else { - LD_data <- ld_data_to_list(LD_data) - } - } else { - # Detect genotype input: single X matrix or list of X matrices (mixture panel). - # susie_rss accepts X=list(X1, X2, ...) for multi-panel mixture. - is_X_list <- is.list(LD_data$LD_matrix) && !is.matrix(LD_data$LD_matrix) - use_X <- isTRUE(LD_data$is_genotype) || is_X_list - if (use_X) { - X_data <- LD_data$LD_matrix - LD_data$is_genotype <- TRUE - } + if (!is(LD_data, "LDData")) { + stop("LD_data must be an LDData object") } + use_X <- hasGenotypes(LD_data) + X_data <- if (use_X) getGenotypes(LD_data) else NULL + is_X_list <- use_X && is.list(X_data) subset_X_data <- function(variants) { if (!use_X) return(NULL) if (is_X_list) { @@ -347,17 +328,23 @@ rss_analysis_pipeline <- function( LD_mat <- qc_record$LD_mat qc_results <- qc_record if (isTRUE(impute)) { + ref_panel <- getRefPanel(LD_data) if (use_X) { - X_scaled <- scale(subset_X_data(LD_data$LD_variants)) - X_scaled[is.na(X_scaled)] <- 0 - impute_results <- raiss(LD_data$ref_panel, sumstats, + X_sub <- subset_X_data(getVariantIds(LD_data)) + if (is_X_list) { + X_scaled <- lapply(X_sub, function(Xk) { Xk <- scale(Xk); Xk[is.na(Xk)] <- 0; Xk }) + } else { + X_scaled <- scale(X_sub) + X_scaled[is.na(X_scaled)] <- 0 + } + impute_results <- raiss(ref_panel, sumstats, genotype_matrix = X_scaled, R2_threshold = impute_opts$R2_threshold, minimum_ld = impute_opts$minimum_ld, lamb = impute_opts$lamb) } else { LD_matrix <- partition_LD_matrix(LD_data) - impute_results <- raiss(LD_data$ref_panel, sumstats, LD_matrix, + impute_results <- raiss(ref_panel, sumstats, LD_matrix, rcond = impute_opts$rcond, R2_threshold = impute_opts$R2_threshold, minimum_ld = impute_opts$minimum_ld, diff --git a/man/coloc_wrapper.Rd b/man/coloc_wrapper.Rd index 0cefde7c..ba16798d 100644 --- a/man/coloc_wrapper.Rd +++ b/man/coloc_wrapper.Rd @@ -6,7 +6,7 @@ \usage{ coloc_wrapper( xqtl_file, - gwas_files, + gwas_files = NULL, xqtl_finemapping_obj = NULL, xqtl_varname_obj = NULL, xqtl_region_obj = NULL, @@ -19,41 +19,103 @@ coloc_wrapper( p1 = 1e-04, p2 = 1e-04, p12 = 5e-06, + run_finemapping = FALSE, + sumstat_path = NULL, + column_file_path = NULL, + LD_data = NULL, + n_sample = 0, + n_case = 0, + n_control = 0, + region = NULL, + qc_method = "slalom", + finemapping_method = "susie_rss", + finemapping_opts = list(L = 20, L_greedy = 5, coverage = c(0.95, 0.7, 0.5), + signal_cutoff = 0.025, min_abs_corr = 0.8), + impute = TRUE, + impute_opts = list(rcond = 0.01, R2_threshold = 0.6, minimum_ld = 5, lamb = 0.01), + save_finemapping_path = NULL, + return_finemapping = FALSE, ... ) } \arguments{ \item{xqtl_file}{Path to the xQTL RDS file.} -\item{gwas_files}{Vector of paths to GWAS RDS files.} +\item{gwas_files}{Vector of paths to GWAS RDS files. Required when +\code{run_finemapping = FALSE}. Ignored when \code{run_finemapping = TRUE}.} -\item{xqtl_finemapping_obj}{Optional table name in xQTL RDS files (default 'susie_fit').} +\item{xqtl_finemapping_obj}{Optional path in xQTL RDS to the finemapping object.} -\item{xqtl_varname_obj}{Optional table name in xQTL RDS files (default 'susie_fit').} +\item{xqtl_varname_obj}{Optional path in xQTL RDS to variant names.} -\item{xqtl_region_obj}{Optional table name in xQTL RDS files (default 'susie_fit').} +\item{xqtl_region_obj}{Optional path in xQTL RDS to region info.} -\item{gwas_finemapping_obj}{Optional table name in GWAS RDS files (default 'susie_fit').} +\item{gwas_finemapping_obj}{Optional path in GWAS RDS to the finemapping object.} -\item{gwas_varname_obj}{Optional table name in GWAS RDS files (default 'susie_fit').} +\item{gwas_varname_obj}{Optional path in GWAS RDS to variant names.} -\item{gwas_region_obj}{Optional table name in GWAS RDS files (default 'susie_fit').} +\item{gwas_region_obj}{Optional path in GWAS RDS to region info.} -\item{prior_tol}{When the prior variance is estimated, compare the estimated value to \code{prior_tol} at the end of the computation, -and exclude a single effect from PIP computation if the estimated prior variance is smaller than this tolerance value.} +\item{filter_lbf_cs}{Logical. Filter LBF rows by credible set index.} -\item{p1, }{p2, and p12 are results from xqtl_enrichment_wrapper (default 'p1=1e-4, p2=1e-4, p12=5e-6', same as coloc.bf_bf).} +\item{filter_lbf_cs_secondary}{Coverage for secondary LBF filtering.} -\item{region_obj}{Optional table name of region info in susie_twas output filess (default 'region_info').} +\item{prior_tol}{Minimum prior variance to retain an effect (default 1e-9).} + +\item{p1}{Prior probability a SNP is associated with trait 1 (default 1e-4).} + +\item{p2}{Prior probability a SNP is associated with trait 2 (default 1e-4).} + +\item{p12}{Prior probability a SNP is associated with both traits (default 5e-6).} + +\item{run_finemapping}{Logical. If TRUE, run GWAS fine-mapping inline via +\code{\link{rss_analysis_pipeline}}. Default FALSE.} + +\item{sumstat_path}{Path to GWAS summary statistics file. Required when +\code{run_finemapping = TRUE}.} + +\item{column_file_path}{Path to column mapping file for summary statistics.} + +\item{LD_data}{LD reference data (LDData object or list). Required when +\code{run_finemapping = TRUE}.} + +\item{n_sample}{Sample size for GWAS.} + +\item{n_case}{Number of cases for binary traits.} + +\item{n_control}{Number of controls for binary traits.} + +\item{region}{Genomic region string (e.g., "chr1:1000-2000").} + +\item{qc_method}{QC method: "slalom", "dentist", or "none". Default "slalom".} + +\item{finemapping_method}{Fine-mapping method. Default "susie_rss".} + +\item{finemapping_opts}{List of fine-mapping options passed to +\code{\link{rss_analysis_pipeline}}.} + +\item{impute}{Logical. Run RAISS imputation. Default TRUE.} + +\item{impute_opts}{List of imputation options.} + +\item{save_finemapping_path}{Path to save fine-mapping result as RDS. The +saved file can be reused via \code{gwas_files} with +\code{gwas_finemapping_obj = "susie_fit"} and +\code{gwas_varname_obj = "variant_names"}.} + +\item{return_finemapping}{Logical. If TRUE and \code{run_finemapping = TRUE}, +include full fine-mapping result under \code{$gwas_finemapping}.} + +\item{...}{Additional arguments (currently unused).} } \value{ A list containing the coloc results and the summarized sets. } \description{ -This function processes xQTL and multiple GWAS finemapped data files for colocalization analysis. +Processes xQTL and GWAS finemapped data for colocalization analysis. +GWAS data can come from pre-computed RDS files or from inline fine-mapping +via \code{\link{rss_analysis_pipeline}}. } -\examples{ -xqtl_file <- "xqtl_file.rds" -gwas_files <- c("gwas_file1.rds", "gwas_file2.rds") -result <- coloc_wrapper(xqtl_file, gwas_files, LD_meta_file_path) +\seealso{ +\code{\link{rss_analysis_pipeline}}, \code{\link{coloc_post_processor}} } diff --git a/man/getCVPerformance.Rd b/man/getCVPerformance.Rd new file mode 100644 index 00000000..4ced8527 --- /dev/null +++ b/man/getCVPerformance.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getCVPerformance} +\alias{getCVPerformance} +\alias{getCVPerformance,TWASWeights-method} +\title{Get CV Performance} +\usage{ +getCVPerformance(x, method = NULL) + +\S4method{getCVPerformance}{TWASWeights}(x, method = NULL) +} +\arguments{ +\item{x}{A \code{TWASWeights} object.} + +\item{method}{Character, specific method name. If NULL, returns all.} +} +\value{ +A list or single element. +} +\description{ +Extract cross-validation performance metrics. +} diff --git a/man/getFits.Rd b/man/getFits.Rd new file mode 100644 index 00000000..5e376c67 --- /dev/null +++ b/man/getFits.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getFits} +\alias{getFits} +\alias{getFits,TWASWeights-method} +\title{Get Model Fits} +\usage{ +getFits(x, method = NULL) + +\S4method{getFits}{TWASWeights}(x, method = NULL) +} +\arguments{ +\item{x}{A \code{TWASWeights} object.} + +\item{method}{Character, specific method name. If NULL, returns all.} +} +\value{ +A list or single element. +} +\description{ +Extract fitted model objects from a TWASWeights object. +} diff --git a/man/getMethodNames.Rd b/man/getMethodNames.Rd new file mode 100644 index 00000000..bfb1bc6d --- /dev/null +++ b/man/getMethodNames.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getMethodNames} +\alias{getMethodNames} +\alias{getMethodNames,TWASWeights-method} +\title{Get Method Names} +\usage{ +getMethodNames(x) + +\S4method{getMethodNames}{TWASWeights}(x) +} +\arguments{ +\item{x}{A \code{TWASWeights} object.} +} +\value{ +Character vector. +} +\description{ +Extract method names from a TWASWeights object. +} diff --git a/man/getRefPanel.Rd b/man/getRefPanel.Rd new file mode 100644 index 00000000..086a7bc0 --- /dev/null +++ b/man/getRefPanel.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getRefPanel} +\alias{getRefPanel} +\alias{getRefPanel,LDData-method} +\title{Get Reference Panel} +\usage{ +getRefPanel(x) + +\S4method{getRefPanel}{LDData}(x) +} +\arguments{ +\item{x}{An \code{LDData} object.} +} +\value{ +A data.frame with variant metadata including chrom, pos, A1, A2. +} +\description{ +Extract reference panel metadata as a data.frame from + an \code{LDData} object, including chrom and pos columns. +} diff --git a/man/getStandardized.Rd b/man/getStandardized.Rd new file mode 100644 index 00000000..e33344b6 --- /dev/null +++ b/man/getStandardized.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getStandardized} +\alias{getStandardized} +\alias{getStandardized,TWASWeights-method} +\title{Get Standardized Flag} +\usage{ +getStandardized(x) + +\S4method{getStandardized}{TWASWeights}(x) +} +\arguments{ +\item{x}{A \code{TWASWeights} object.} +} +\value{ +Logical. +} +\description{ +Check whether weights are on the standardized (correlation) scale. +} diff --git a/man/getTopLoci.Rd b/man/getTopLoci.Rd new file mode 100644 index 00000000..d2af41a0 --- /dev/null +++ b/man/getTopLoci.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getTopLoci} +\alias{getTopLoci} +\alias{getTopLoci,FineMappingResult-method} +\title{Get Top Loci} +\usage{ +getTopLoci(x) + +\S4method{getTopLoci}{FineMappingResult}(x) +} +\arguments{ +\item{x}{A \code{FineMappingResult} object.} +} +\value{ +A data.frame of top loci. +} +\description{ +Extract top loci data.frame from a FineMappingResult. +} diff --git a/man/getTrimmedFit.Rd b/man/getTrimmedFit.Rd new file mode 100644 index 00000000..659e1448 --- /dev/null +++ b/man/getTrimmedFit.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getTrimmedFit} +\alias{getTrimmedFit} +\alias{getTrimmedFit,FineMappingResult-method} +\title{Get Trimmed Fit} +\usage{ +getTrimmedFit(x) + +\S4method{getTrimmedFit}{FineMappingResult}(x) +} +\arguments{ +\item{x}{A \code{FineMappingResult} object.} +} +\value{ +A list (trimmed SuSiE fit). +} +\description{ +Extract the trimmed SuSiE fit from a FineMappingResult. +} diff --git a/man/getVariantIds.Rd b/man/getVariantIds.Rd index d2509b8e..ed3b6447 100644 --- a/man/getVariantIds.Rd +++ b/man/getVariantIds.Rd @@ -3,11 +3,14 @@ \name{getVariantIds} \alias{getVariantIds} \alias{getVariantIds,LDData-method} +\alias{getVariantIds,TWASWeights-method} \title{Get Variant IDs} \usage{ getVariantIds(x) \S4method{getVariantIds}{LDData}(x) + +\S4method{getVariantIds}{TWASWeights}(x) } \arguments{ \item{x}{An \code{LDData} object.} diff --git a/man/getVariantNames.Rd b/man/getVariantNames.Rd new file mode 100644 index 00000000..c14c024c --- /dev/null +++ b/man/getVariantNames.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getVariantNames} +\alias{getVariantNames} +\alias{getVariantNames,FineMappingResult-method} +\title{Get Variant Names} +\usage{ +getVariantNames(x) + +\S4method{getVariantNames}{FineMappingResult}(x) +} +\arguments{ +\item{x}{A \code{FineMappingResult} object.} +} +\value{ +Character vector of variant names. +} +\description{ +Extract variant names from a FineMappingResult. +} diff --git a/man/twas_weights_sumstat_pipeline.Rd b/man/twas_weights_sumstat_pipeline.Rd index 16b8dd2d..ca73b535 100644 --- a/man/twas_weights_sumstat_pipeline.Rd +++ b/man/twas_weights_sumstat_pipeline.Rd @@ -16,7 +16,7 @@ twas_weights_sumstat_pipeline( qc_method = NULL, keep_indel = TRUE, pip_cutoff_to_skip = 0, - impute = FALSE, + impute = TRUE, impute_opts = list(rcond = 0.01, R2_threshold = 0.6, minimum_ld = 5, lamb = 0.01), var_y = 1, verbose = 1 diff --git a/tests/testthat/test_LD.R b/tests/testthat/test_LD.R index 281fcb5b..932936ae 100644 --- a/tests/testthat/test_LD.R +++ b/tests/testthat/test_LD.R @@ -208,34 +208,37 @@ test_that("partition_LD_matrix validates block structure properly", { # Load the LD matrix that spans multiple blocks ld_data <- load_LD_matrix(LD_meta_file_path, region) - # Create an invalid block structure by converting to a plain list and modifying + # Create an invalid block structure by modifying block_metadata bm <- getBlockMetadata(ld_data) vids <- getVariantIds(ld_data) ldmat <- getCorrelation(ld_data) - invalid_ld_data <- list( - LD_matrix = ldmat, - LD_variants = vids, - block_metadata = bm - ) # Assuming we have at least 2 blocks: - if(nrow(invalid_ld_data$block_metadata) >= 2) { + if(nrow(bm) >= 2) { # Create overlapping blocks with invalid start/end indices - invalid_ld_data$block_metadata$start_idx[2] <- invalid_ld_data$block_metadata$start_idx[1] - invalid_ld_data$block_metadata$end_idx[1] <- invalid_ld_data$block_metadata$end_idx[2] + bm$start_idx[2] <- bm$start_idx[1] + bm$end_idx[1] <- bm$end_idx[2] # Introduce non-zero elements between blocks to trigger validation error - if(length(invalid_ld_data$LD_variants) >= 2) { - idx1 <- invalid_ld_data$block_metadata$start_idx[1] - idx2 <- invalid_ld_data$block_metadata$start_idx[2] + 1 - if(idx1 <= length(invalid_ld_data$LD_variants) && - idx2 <= length(invalid_ld_data$LD_variants)) { - var1 <- invalid_ld_data$LD_variants[idx1] - var2 <- invalid_ld_data$LD_variants[idx2] - invalid_ld_data$LD_matrix[var1, var2] <- 0.5 + if(length(vids) >= 2) { + idx1 <- bm$start_idx[1] + idx2 <- bm$start_idx[2] + 1 + if(idx1 <= length(vids) && idx2 <= length(vids)) { + var1 <- vids[idx1] + var2 <- vids[idx2] + ldmat[var1, var2] <- 0.5 } } + # Rebuild LDData with modified matrix and block metadata + invalid_ld_data <- new("LDData", + correlation = ldmat, + genotype_handle = NULL, + variants = ld_data@variants, + snp_idx = ld_data@snp_idx, + block_metadata = bm + ) + # Expect an error for invalid block structure expect_error(partition_LD_matrix(invalid_ld_data), "Matrix lacks expected block structure") } diff --git a/tests/testthat/test_colocboost_pipeline.R b/tests/testthat/test_colocboost_pipeline.R index 9cc09661..f55854b8 100644 --- a/tests/testthat/test_colocboost_pipeline.R +++ b/tests/testthat/test_colocboost_pipeline.R @@ -835,17 +835,14 @@ test_that("colocboost_analysis imputes native LD input using inferred block meta expect_equal(nrow(result$args$sumstat[[1]]), 5) }) -test_that("colocboost_analysis imputes native X_ref input through scaled genotype RAISS path", { +test_that("colocboost_analysis imputes X_ref input through R-based RAISS path", { local_mocked_bindings( .cb_call_colocboost = function(args, dots) { list(args = args, dots = dots) }, - compute_LD = function(...) stop("compute_LD should not be called"), raiss = function(ref_panel, known_zscores, LD_matrix = NULL, genotype_matrix = NULL, ...) { - expect_null(LD_matrix) - expect_equal(ncol(genotype_matrix), nrow(ref_panel)) - expect_equal(unname(colMeans(genotype_matrix)), rep(0, ncol(genotype_matrix)), - tolerance = 1e-10) + # With S4 migration, X_ref is converted to R and imputation uses R-based path + expect_true(!is.null(LD_matrix) || !is.null(genotype_matrix)) list(result_filter = known_zscores, LD_mat = NULL) } ) @@ -869,9 +866,9 @@ test_that("colocboost_analysis imputes native X_ref input through scaled genotyp ), NA ) - expect_equal(length(result$args$X_ref), 1) - expect_equal(ncol(result$args$X_ref[[1]]), 5) - expect_null(result$args$LD) + # X_ref converted to R, so result has LD (not X_ref) + expect_equal(length(result$args$LD), 1) + expect_equal(ncol(result$args$LD[[1]]), 5) }) test_that("colocboost_analysis keeps QC-generated X_ref mutually exclusive with original LD", { @@ -908,9 +905,10 @@ test_that("colocboost_analysis keeps QC-generated X_ref mutually exclusive with M = 2 )) - expect_null(result$args$LD) - expect_equal(length(result$args$X_ref), 1) - expect_equal(dim(result$args$X_ref[[1]]), c(10, 5)) + # With S4 migration, genotype references are converted to R + # The LD from the original LD argument is replaced by QC-produced R + expect_equal(length(result$args$LD), 1) + expect_equal(ncol(result$args$LD[[1]]), 5) }) test_that("colocboost_analysis native summary QC supports explicit A1_A2 variant convention", { @@ -1424,13 +1422,15 @@ test_that("qc_regional_data handles named pip_cutoff_to_skip_sumstat vector", { list(target_data_qced = target_data) }, rss_basic_qc = function(sumstats, LD_data, ...) { - LD_mat <- LD_data$LD_matrix[sumstats$variant_id, sumstats$variant_id, drop = FALSE] + ld_corr <- if (is(LD_data, "LDData")) getCorrelation(LD_data) else LD_data$LD_matrix + LD_mat <- ld_corr[sumstats$variant_id, sumstats$variant_id, drop = FALSE] list(sumstats = sumstats, LD_mat = LD_mat) }, summary_stats_qc = function(rss_input = NULL, LD_data, ...) { stats::setNames(lapply(names(rss_input), function(study) { ss <- rss_input[[study]]$sumstats - LD_mat <- LD_data[[study]]$LD_matrix[ss$variant_id, ss$variant_id, drop = FALSE] + ld <- if (is(LD_data[[study]], "LDData")) getCorrelation(LD_data[[study]]) else LD_data[[study]]$LD_matrix + LD_mat <- ld[ss$variant_id, ss$variant_id, drop = FALSE] list(rss_input = rss_input[[study]], LD_matrix = LD_mat, outlier_number = 0) }), names(rss_input)) }, @@ -1461,13 +1461,15 @@ test_that("qc_regional_data fills missing study names with 0 for pip_cutoff_to_s local_mocked_bindings( rss_basic_qc = function(sumstats, LD_data, ...) { - LD_mat <- LD_data$LD_matrix[sumstats$variant_id, sumstats$variant_id, drop = FALSE] + ld_corr <- if (is(LD_data, "LDData")) getCorrelation(LD_data) else LD_data$LD_matrix + LD_mat <- ld_corr[sumstats$variant_id, sumstats$variant_id, drop = FALSE] list(sumstats = sumstats, LD_mat = LD_mat) }, summary_stats_qc = function(rss_input = NULL, LD_data, ...) { stats::setNames(lapply(names(rss_input), function(study) { ss <- rss_input[[study]]$sumstats - LD_mat <- LD_data[[study]]$LD_matrix[ss$variant_id, ss$variant_id, drop = FALSE] + ld <- if (is(LD_data[[study]], "LDData")) getCorrelation(LD_data[[study]]) else LD_data[[study]]$LD_matrix + LD_mat <- ld[ss$variant_id, ss$variant_id, drop = FALSE] list(rss_input = rss_input[[study]], LD_matrix = LD_mat, outlier_number = 0) }), names(rss_input)) }, @@ -3092,12 +3094,13 @@ test_that("qc_regional_data: with only sumstat data processes correctly", { local_mocked_bindings( rss_basic_qc = function(sumstats, LD_data, ...) { - LD_mat <- LD_data$LD_matrix - list(sumstats = sumstats, LD_mat = LD_mat) + ld_corr <- if (is(LD_data, "LDData")) getCorrelation(LD_data) else LD_data$LD_matrix + list(sumstats = sumstats, LD_mat = ld_corr) }, summary_stats_qc = function(rss_input = NULL, LD_data, ...) { stats::setNames(lapply(names(rss_input), function(study) { - list(rss_input = rss_input[[study]], LD_matrix = LD_data[[study]]$LD_matrix, + ld <- if (is(LD_data[[study]], "LDData")) getCorrelation(LD_data[[study]]) else LD_data[[study]]$LD_matrix + list(rss_input = rss_input[[study]], LD_matrix = ld, outlier_number = 0) }), names(rss_input)) }, diff --git a/tests/testthat/test_data_structures.R b/tests/testthat/test_data_structures.R index 844e43ed..61446812 100644 --- a/tests/testthat/test_data_structures.R +++ b/tests/testthat/test_data_structures.R @@ -84,7 +84,7 @@ test_that("LDData supports block-diagonal correlation", { expect_equal(length(corr), 2) }) -test_that("ld_data_to_list converts LDData to legacy format", { +test_that("LDData S4 accessors return correct data", { R <- diag(2) gr <- GenomicRanges::GRanges( seqnames = c("chr1", "chr1"), @@ -95,13 +95,13 @@ test_that("ld_data_to_list converts LDData to legacy format", { ) ld <- LDData(correlation = R, variants = gr, block_metadata = data.frame(block_id = 1L)) - legacy <- pecotmr:::ld_data_to_list(ld) - expect_true(is.list(legacy)) - expect_true("LD_matrix" %in% names(legacy)) - expect_true("LD_variants" %in% names(legacy)) - expect_true("ref_panel" %in% names(legacy)) - expect_false(legacy$is_genotype) - expect_equal(legacy$LD_variants, c("v1", "v2")) + expect_equal(getCorrelation(ld), R) + expect_equal(getVariantIds(ld), c("v1", "v2")) + expect_false(hasGenotypes(ld)) + rp <- getRefPanel(ld) + expect_true(is.data.frame(rp)) + expect_true("variant_id" %in% names(rp)) + expect_equal(rp$variant_id, c("v1", "v2")) }) test_that(".ref_panel_to_granges builds GRanges from data.frame", { diff --git a/tests/testthat/test_encoloc.R b/tests/testthat/test_encoloc.R index c348c9b5..ac21d936 100644 --- a/tests/testthat/test_encoloc.R +++ b/tests/testthat/test_encoloc.R @@ -426,6 +426,209 @@ test_that("coloc_wrapper extracts analysis_region from xqtl_region_obj", { file.remove(gwas_file, xqtl_file) }) +# =========================================================================== +# coloc_wrapper inline fine-mapping tests +# =========================================================================== + +test_that("coloc_wrapper errors when no GWAS source provided", { + xqtl_file <- tempfile(fileext = ".rds") + saveRDS(list(list(susie_fit = generate_mock_susie_fit(seed = 1))), xqtl_file) + expect_error( + coloc_wrapper(xqtl_file), + "Either set run_finemapping" + ) + file.remove(xqtl_file) +}) + +test_that("coloc_wrapper errors when run_finemapping missing sumstat_path", { + expect_error( + coloc_wrapper("fake.rds", run_finemapping = TRUE, LD_data = list()), + "sumstat_path is required" + ) +}) + +test_that("coloc_wrapper errors when run_finemapping missing LD_data", { + expect_error( + coloc_wrapper("fake.rds", run_finemapping = TRUE, sumstat_path = "s.tsv"), + "LD_data is required" + ) +}) + +test_that("coloc_wrapper warns when both gwas_files and run_finemapping", { + # This will warn, then error on sumstat_path/LD_data validation + expect_warning( + tryCatch( + coloc_wrapper("xqtl.rds", gwas_files = "gwas.rds", + run_finemapping = TRUE, sumstat_path = "s.tsv", + LD_data = list()), + error = function(e) NULL + ), + "Inline fine-mapping will be used" + ) +}) + +test_that("coloc_wrapper with run_finemapping = TRUE uses rss_analysis_pipeline", { + xqtl_file <- tempfile(fileext = ".rds") + xqtl_fit <- generate_mock_susie_fit(seed = 1) + saveRDS(list(gene = list(susie_fit = xqtl_fit)), xqtl_file) + + # Build mock pipeline result matching rss_analysis_pipeline output structure + mock_pipeline <- list( + "susie_rss_SLALOM_RAISS_imputed" = list( + variant_names = xqtl_fit$variant_names, + susie_result_trimmed = list( + lbf_variable = xqtl_fit$lbf_variable, + V = xqtl_fit$V, + pip = xqtl_fit$pip, + sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + ) + ), + rss_data_analyzed = data.frame( + variant_id = xqtl_fit$variant_names, + z = rnorm(length(xqtl_fit$variant_names)) + ) + ) + + local_mocked_bindings( + rss_analysis_pipeline = function(...) mock_pipeline + ) + + result <- coloc_wrapper( + xqtl_file, + run_finemapping = TRUE, + sumstat_path = "/fake/gwas.tsv", + LD_data = list(LD_matrix = diag(10)), + n_sample = 10000, + region = "chr22:1-100", + xqtl_finemapping_obj = "susie_fit", + xqtl_varname_obj = c("susie_fit", "variant_names") + ) + expect_true(all(c("summary", "results") %in% names(result))) + file.remove(xqtl_file) +}) + +test_that("coloc_wrapper with return_finemapping includes pipeline result", { + xqtl_file <- tempfile(fileext = ".rds") + xqtl_fit <- generate_mock_susie_fit(seed = 1) + saveRDS(list(gene = list(susie_fit = xqtl_fit)), xqtl_file) + + mock_pipeline <- list( + "susie_rss_SLALOM" = list( + variant_names = xqtl_fit$variant_names, + susie_result_trimmed = list( + lbf_variable = xqtl_fit$lbf_variable, + V = xqtl_fit$V, + pip = xqtl_fit$pip, + sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + ) + ), + rss_data_analyzed = data.frame( + variant_id = xqtl_fit$variant_names, + z = rnorm(length(xqtl_fit$variant_names)) + ) + ) + + local_mocked_bindings( + rss_analysis_pipeline = function(...) mock_pipeline + ) + + result <- coloc_wrapper( + xqtl_file, + run_finemapping = TRUE, + sumstat_path = "/fake/gwas.tsv", + LD_data = list(LD_matrix = diag(10)), + n_sample = 10000, + xqtl_finemapping_obj = "susie_fit", + xqtl_varname_obj = c("susie_fit", "variant_names"), + return_finemapping = TRUE + ) + expect_true("gwas_finemapping" %in% names(result)) + expect_true("susie_rss_SLALOM" %in% names(result$gwas_finemapping)) + file.remove(xqtl_file) +}) + +test_that("coloc_wrapper save_finemapping_path saves reusable RDS", { + xqtl_file <- tempfile(fileext = ".rds") + save_path <- tempfile(fileext = ".rds") + xqtl_fit <- generate_mock_susie_fit(seed = 1) + saveRDS(list(gene = list(susie_fit = xqtl_fit)), xqtl_file) + + mock_pipeline <- list( + "susie_rss_SLALOM" = list( + variant_names = xqtl_fit$variant_names, + susie_result_trimmed = list( + lbf_variable = xqtl_fit$lbf_variable, + V = xqtl_fit$V, + pip = xqtl_fit$pip, + sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + ) + ), + rss_data_analyzed = data.frame( + variant_id = xqtl_fit$variant_names, + z = rnorm(length(xqtl_fit$variant_names)) + ) + ) + + local_mocked_bindings( + rss_analysis_pipeline = function(...) mock_pipeline + ) + + result <- coloc_wrapper( + xqtl_file, + run_finemapping = TRUE, + sumstat_path = "/fake/gwas.tsv", + LD_data = list(LD_matrix = diag(10)), + n_sample = 10000, + xqtl_finemapping_obj = "susie_fit", + xqtl_varname_obj = c("susie_fit", "variant_names"), + save_finemapping_path = save_path + ) + + # Verify file was saved + expect_true(file.exists(save_path)) + + # Verify saved format is compatible with file-based reading path + saved_data <- readRDS(save_path)[[1]] + expect_true("susie_fit" %in% names(saved_data)) + expect_true("variant_names" %in% names(saved_data)) + expect_true(!is.null(saved_data$susie_fit$lbf_variable)) + expect_true(!is.null(saved_data$susie_fit$V)) + + # Verify reusable: can be read back by coloc_wrapper via file-based path + result2 <- coloc_wrapper( + xqtl_file, + gwas_files = save_path, + xqtl_finemapping_obj = "susie_fit", + gwas_finemapping_obj = "susie_fit", + xqtl_varname_obj = c("susie_fit", "variant_names"), + gwas_varname_obj = "variant_names" + ) + expect_true(all(c("summary", "results") %in% names(result2))) + + file.remove(xqtl_file, save_path) +}) + +test_that("coloc_wrapper backward compatibility with gwas_files only", { + # This mirrors the existing test at line 228 but explicitly verifies + # that the default run_finemapping=FALSE works + input_data <- generate_mock_data_for_enrichment() + input_data$gwas_finemapped_data <- unlist(lapply( + input_data$gwas_finemapped_data, function(x) { + gsub("//", "/", tempfile(pattern = x, tmpdir = tempdir(), fileext = ".rds")) + })) + input_data$xqtl_finemapped_data <- gsub("//", "/", tempfile(pattern = "xqtl_file", tmpdir = tempdir(), fileext = ".rds")) + saveRDS(list(gene = list(susie_fit = generate_mock_susie_fit(seed = 1))), input_data$xqtl_finemapped_data) + for (i in 1:length(input_data$gwas_finemapped_data)) { + saveRDS(list(susie_fit = generate_mock_susie_fit(seed = i)), input_data$gwas_finemapped_data[i]) + } + res <- coloc_wrapper(input_data$xqtl_finemapped_data, input_data$gwas_finemapped_data, + xqtl_finemapping_obj = "susie_fit", gwas_finemapping_obj = NULL, + xqtl_varname_obj = c("susie_fit", "variant_names"), gwas_varname_obj = c("variant_names")) + expect_true(all(names(res) %in% c("summary", "results", "priors", "analysis_region"))) + file.remove(input_data$gwas_finemapped_data) + file.remove(input_data$xqtl_finemapped_data) +}) + # =========================================================================== # filter_and_order_coloc_results # =========================================================================== diff --git a/tests/testthat/test_ensemble_weights.R b/tests/testthat/test_ensemble_weights.R index 2b1cf741..159c16e1 100644 --- a/tests/testthat/test_ensemble_weights.R +++ b/tests/testthat/test_ensemble_weights.R @@ -360,7 +360,7 @@ test_that("pipeline: ensemble=TRUE with only 1 method prints skip message", { # No ensemble result should be present expect_null(res$ensemble) - expect_null(res$twas_weights$ensemble_weights) + expect_false("ensemble_weights" %in% getMethodNames(res$twas_weights)) }) test_that("pipeline: ensemble=TRUE skips when methods fail R^2 cutoff", { @@ -388,7 +388,7 @@ test_that("pipeline: ensemble=TRUE skips when methods fail R^2 cutoff", { expect_true(any(grepl("Ensemble TWAS skipped", msgs))) expect_null(res$ensemble) - expect_null(res$twas_weights$ensemble_weights) + expect_false("ensemble_weights" %in% getMethodNames(res$twas_weights)) }) test_that("pipeline: ensemble=TRUE succeeds and adds ensemble_weights", { @@ -415,9 +415,9 @@ test_that("pipeline: ensemble=TRUE succeeds and adds ensemble_weights", { expect_true(any(grepl("Computing ensemble TWAS weights", msgs))) # Ensemble weights added alongside individual methods - expect_true("ensemble_weights" %in% names(res$twas_weights)) - expect_true("lasso_weights" %in% names(res$twas_weights)) - expect_true("enet_weights" %in% names(res$twas_weights)) + expect_true("ensemble_weights" %in% getMethodNames(res$twas_weights)) + expect_true("lasso_weights" %in% getMethodNames(res$twas_weights)) + expect_true("enet_weights" %in% getMethodNames(res$twas_weights)) # Ensemble predictions added expect_true("ensemble_predicted" %in% names(res$twas_predictions)) @@ -428,8 +428,8 @@ test_that("pipeline: ensemble=TRUE succeeds and adds ensemble_weights", { expect_equal(sum(res$ensemble$method_coef), 1, tolerance = 1e-6) # Ensemble weights should have same length as individual weights - expect_equal(length(res$twas_weights$ensemble_weights), - length(res$twas_weights$lasso_weights)) + expect_equal(length(getWeights(res$twas_weights, "ensemble_weights")), + length(getWeights(res$twas_weights, "lasso_weights"))) }) test_that("pipeline: ensemble=FALSE does not run ensemble", { @@ -452,7 +452,7 @@ test_that("pipeline: ensemble=FALSE does not run ensemble", { )) expect_null(res$ensemble) - expect_null(res$twas_weights$ensemble_weights) + expect_false("ensemble_weights" %in% getMethodNames(res$twas_weights)) }) test_that("pipeline: ensemble_r2_threshold filters methods for ensemble", { @@ -574,7 +574,7 @@ test_that("pipeline: ensemble_solver='nnls' works end-to-end", { ) expect_true(any(grepl("Computing ensemble TWAS weights", msgs))) - expect_true("ensemble_weights" %in% names(res$twas_weights)) + expect_true("ensemble_weights" %in% getMethodNames(res$twas_weights)) expect_true(all(res$ensemble$method_coef >= 0)) expect_equal(sum(res$ensemble$method_coef), 1, tolerance = 1e-6) }) @@ -602,7 +602,7 @@ test_that("pipeline: ensemble_solver='lbfgsb' works end-to-end", { ) expect_true(any(grepl("Computing ensemble TWAS weights", msgs))) - expect_true("ensemble_weights" %in% names(res$twas_weights)) + expect_true("ensemble_weights" %in% getMethodNames(res$twas_weights)) expect_true(all(res$ensemble$method_coef >= 0)) expect_equal(sum(res$ensemble$method_coef), 1, tolerance = 1e-6) }) @@ -630,7 +630,7 @@ test_that("pipeline: ensemble_solver='glmnet' works end-to-end", { ) expect_true(any(grepl("Computing ensemble TWAS weights", msgs))) - expect_true("ensemble_weights" %in% names(res$twas_weights)) + expect_true("ensemble_weights" %in% getMethodNames(res$twas_weights)) expect_true(all(res$ensemble$method_coef >= 0)) expect_equal(sum(res$ensemble$method_coef), 1, tolerance = 1e-6) }) diff --git a/tests/testthat/test_sumstats_qc.R b/tests/testthat/test_sumstats_qc.R index 5ae5be63..e8f282a3 100644 --- a/tests/testthat/test_sumstats_qc.R +++ b/tests/testthat/test_sumstats_qc.R @@ -1,5 +1,29 @@ context("sumstats_qc") +# =========================================================================== +# Helper: build an LDData S4 object from a correlation matrix and variant info +# =========================================================================== +make_ld_data_s4 <- function(R_mat, variant_ids, chrom_val = 1, positions = NULL) { + ref_panel <- parse_variant_id(variant_ids) + ref_panel$variant_id <- variant_ids + ref_panel$chrom <- as.character(ref_panel$chrom) + if (!is.null(positions)) { + ref_panel$pos <- positions + } + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + bm <- data.frame( + block_id = 1L, + chrom = ref_panel$chrom[1], + block_start = min(ref_panel$pos), + block_end = max(ref_panel$pos), + size = length(variant_ids), + start_idx = 1L, + end_idx = length(variant_ids), + stringsAsFactors = FALSE + ) + LDData(correlation = R_mat, variants = variants_gr, block_metadata = bm) +} + # =========================================================================== # Helper: build matching sumstats and LD_data # =========================================================================== @@ -32,19 +56,8 @@ make_test_sumstats_ld <- function(n_variants = 5, chrom_val = 1, with_indels = F diag(LD_mat) <- 1 rownames(LD_mat) <- colnames(LD_mat) <- variant_ids - ref_panel <- data.frame( - chrom = rep(chrom_val, n_variants), - pos = positions, - A2 = a2, - A1 = a1, - stringsAsFactors = FALSE - ) - - LD_data <- list( - LD_variants = ref_panel, - LD_matrix = LD_mat, - ref_panel = ref_panel - ) + LD_data <- make_ld_data_s4(LD_mat, variant_ids, chrom_val = chrom_val, + positions = positions) list(sumstats = sumstats, LD_data = LD_data, variant_ids = variant_ids) } @@ -55,7 +68,8 @@ make_test_sumstats_ld <- function(n_variants = 5, chrom_val = 1, with_indels = F test_that("rss_basic_qc requires correct columns", { sumstats <- data.frame(beta = 1, se = 0.5) - LD_data <- list(LD_variants = data.frame()) + R_mat <- matrix(1, 1, 1, dimnames = list("1:100:A:G", "1:100:A:G")) + LD_data <- make_ld_data_s4(R_mat, "1:100:A:G") expect_error(rss_basic_qc(sumstats, LD_data), "Missing columns") }) @@ -76,18 +90,7 @@ test_that("rss_basic_qc processes matching variants correctly", { LD_mat <- diag(3) rownames(LD_mat) <- colnames(LD_mat) <- variant_ids - ref_panel <- data.frame( - chrom = c(1, 1, 1), - pos = c(100, 200, 300), - A2 = c("A", "C", "G"), - A1 = c("G", "T", "A"), - stringsAsFactors = FALSE - ) - - LD_data <- list( - LD_variants = ref_panel, - LD_matrix = LD_mat - ) + LD_data <- make_ld_data_s4(LD_mat, variant_ids) result <- rss_basic_qc(sumstats, LD_data) expect_type(result, "list") @@ -141,18 +144,7 @@ test_that("rss_basic_qc errors when no variants overlap", { LD_mat <- diag(2) rownames(LD_mat) <- colnames(LD_mat) <- ld_ids - ref_panel <- data.frame( - chrom = c(1, 1), - pos = c(50000, 60000), - A2 = c("A", "C"), - A1 = c("G", "T"), - stringsAsFactors = FALSE - ) - - LD_data <- list( - LD_variants = ref_panel, - LD_matrix = LD_mat - ) + LD_data <- make_ld_data_s4(LD_mat, ld_ids) expect_error(rss_basic_qc(sumstats, LD_data), "No overlapping|No matching") }) @@ -175,18 +167,11 @@ test_that("rss_basic_qc aligns variant IDs by stripping build suffix", { diag(LD_mat) <- 1 rownames(LD_mat) <- colnames(LD_mat) <- ld_ids - ref_panel <- data.frame( - chrom = c(1, 1, 1), - pos = c(100, 200, 300), - A2 = c("A", "C", "G"), - A1 = c("G", "T", "A"), - stringsAsFactors = FALSE - ) - - LD_data <- list( - LD_variants = ref_panel, - LD_matrix = LD_mat - ) + # For the build-suffix variant IDs, construct the LDData with the + # base IDs (without suffix) for variant metadata so parse_variant_id works, + # while the correlation matrix retains the suffixed rownames. + base_ids <- c("1:100:A:G", "1:200:C:T", "1:300:G:A") + LD_data <- make_ld_data_s4(LD_mat, base_ids) result <- rss_basic_qc(sumstats, LD_data) expect_type(result, "list") @@ -210,18 +195,10 @@ test_that("rss_basic_qc handles chr prefix differences during alignment", { LD_mat <- diag(2) rownames(LD_mat) <- colnames(LD_mat) <- ld_ids - ref_panel <- data.frame( - chrom = c(1, 1), - pos = c(100, 200), - A2 = c("A", "C"), - A1 = c("G", "T"), - stringsAsFactors = FALSE - ) - - LD_data <- list( - LD_variants = ref_panel, - LD_matrix = LD_mat - ) + # Use base IDs (without chr prefix) for variant metadata, while + # the correlation matrix has chr-prefixed rownames. + base_ids <- c("1:100:A:G", "1:200:C:T") + LD_data <- make_ld_data_s4(LD_mat, base_ids) result <- rss_basic_qc(sumstats, LD_data) expect_type(result, "list") @@ -237,12 +214,17 @@ test_that("rss_basic_qc output LD_mat has same dimension as sumstats rows", { test_that("rss_basic_qc errors when LD matrix has NULL rownames", { td <- make_test_sumstats_ld(n_variants = 3) - ld_mat <- td$LD_data$LD_matrix + ld_mat <- getCorrelation(td$LD_data) rownames(ld_mat) <- NULL colnames(ld_mat) <- NULL - td$LD_data$LD_matrix <- ld_mat + # Rebuild the LDData with a NULL-rownames correlation matrix + LD_data_bad <- LDData( + correlation = ld_mat, + variants = getVariantInfo(td$LD_data), + block_metadata = getBlockMetadata(td$LD_data) + ) - expect_error(rss_basic_qc(td$sumstats, td$LD_data), "rownames are NULL|cannot align") + expect_error(rss_basic_qc(td$sumstats, LD_data_bad), "rownames are NULL|cannot align") }) test_that("rss_basic_qc handles multiple skip regions", { @@ -261,10 +243,14 @@ test_that("rss_basic_qc can skip LD matrix subsetting for genotype references", td <- make_test_sumstats_ld(n_variants = 3) X_ref <- matrix(rnorm(30), 10, 3) colnames(X_ref) <- td$variant_ids - td$LD_data$LD_matrix <- X_ref - td$LD_data$is_genotype <- TRUE + # Store X_ref as correlation; with return_LD_mat=FALSE the matrix is not subsetted + LD_data_geno <- LDData( + correlation = X_ref, + variants = getVariantInfo(td$LD_data), + block_metadata = getBlockMetadata(td$LD_data) + ) - result <- rss_basic_qc(td$sumstats, td$LD_data, return_LD_mat = FALSE) + result <- rss_basic_qc(td$sumstats, LD_data_geno, return_LD_mat = FALSE) expect_true(nrow(result$sumstats) > 0) expect_null(result$LD_mat) @@ -276,7 +262,8 @@ test_that("rss_basic_qc can skip LD matrix subsetting for genotype references", test_that("summary_stats_qc errors on invalid method", { sumstats <- data.frame(variant_id = "1:100:A:G", z = 2.0) - LD_data <- list(LD_matrix = matrix(1, 1, 1, dimnames = list("1:100:A:G", "1:100:A:G"))) + R_mat <- matrix(1, 1, 1, dimnames = list("1:100:A:G", "1:100:A:G")) + LD_data <- make_ld_data_s4(R_mat, "1:100:A:G") expect_error(summary_stats_qc(sumstats, LD_data, method = "invalid"), "should be one of") }) @@ -408,16 +395,18 @@ test_that("summary_stats_qc basic genotype-backed path does not compute LD", { td <- make_test_sumstats_ld(n_variants = 5) X_ref <- matrix(rnorm(50), 10, 5) colnames(X_ref) <- td$variant_ids - td$LD_data$LD_matrix <- X_ref - td$LD_data$is_genotype <- TRUE + LD_data_geno <- make_ld_data_s4(cor(X_ref), td$variant_ids) rss_input <- list(sumstats = td$sumstats, n = 1000, var_y = 1) local_mocked_bindings( - compute_LD = function(...) stop("compute_LD should not be called") + compute_LD = function(...) stop("compute_LD should not be called"), + hasGenotypes = function(x) TRUE, + getGenotypes = function(x) X_ref, + .package = "pecotmr" ) expect_message( - result <- summary_stats_qc(rss_input = rss_input, LD_data = td$LD_data, + result <- summary_stats_qc(rss_input = rss_input, LD_data = LD_data_geno, qc_method = "none", impute = FALSE), "basic harmonization retained" ) @@ -486,8 +475,7 @@ test_that("summary_stats_qc PIP screening uses LD-independent SER", { td <- make_test_sumstats_ld(n_variants = 5) X_ref <- matrix(rnorm(50), 10, 5) colnames(X_ref) <- td$variant_ids - td$LD_data$LD_matrix <- X_ref - td$LD_data$is_genotype <- TRUE + LD_data_geno <- make_ld_data_s4(cor(X_ref), td$variant_ids) rss_input <- list(sumstats = td$sumstats, n = 1000, var_y = 1) local_mocked_bindings( @@ -496,12 +484,13 @@ test_that("summary_stats_qc PIP screening uses LD-independent SER", { expect_equal(n, rss_input$n) expect_null(coverage) list(pip = rep(1, length(z))) - } + }, + .package = "pecotmr" ) result <- suppressMessages(summary_stats_qc( rss_input = rss_input, - LD_data = td$LD_data, + LD_data = LD_data_geno, qc_method = "none", pip_cutoff_to_skip = 0.1, impute = FALSE @@ -547,8 +536,7 @@ test_that("summary_stats_qc LD-mismatch QC computes only filtered local LD from td <- make_test_sumstats_ld(n_variants = 5) X_ref <- matrix(rnorm(50), 10, 5) colnames(X_ref) <- td$variant_ids - td$LD_data$LD_matrix <- X_ref - td$LD_data$is_genotype <- TRUE + LD_data_geno <- make_ld_data_s4(cor(X_ref), td$variant_ids) rss_input <- list(sumstats = td$sumstats, n = 1000, var_y = 1) compute_calls <- 0 @@ -560,16 +548,19 @@ test_that("summary_stats_qc LD-mismatch QC computes only filtered local LD from rownames(R) <- colnames(R) <- colnames(X) R }, + hasGenotypes = function(x) TRUE, + getGenotypes = function(x) X_ref, ld_mismatch_qc = function(zScore, R, nSample = NULL, method = NULL, ...) { expect_equal(nrow(R), length(zScore)) expect_equal(ncol(R), length(zScore)) data.frame(outlier = rep(FALSE, length(zScore))) - } + }, + .package = "pecotmr" ) result <- suppressMessages(summary_stats_qc( rss_input = rss_input, - LD_data = td$LD_data, + LD_data = LD_data_geno, qc_method = "slalom", skip_region = "1:150-350", impute = FALSE diff --git a/tests/testthat/test_twas_sketch.R b/tests/testthat/test_twas_sketch.R index 7f54fa99..656946b5 100644 --- a/tests/testthat/test_twas_sketch.R +++ b/tests/testthat/test_twas_sketch.R @@ -198,15 +198,23 @@ test_that("load_ld_sketch: returns raw genotypes and metadata", { stringsAsFactors = FALSE ) + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + block_metadata <- S4Vectors::DataFrame( + region = "chr1:1000-2100", start = 1000L, end = 2100L, chrom = "chr1" + ) + mock_ld_data <- new("LDData", + correlation = cor(X), + genotype_handle = NULL, + variants = variants_gr, + snp_idx = seq_len(p), + block_metadata = block_metadata + ) + local_mocked_bindings( load_LD_matrix = function(ld_meta_file_path, region, return_genotype = FALSE, n_sample = NULL, ...) { - list( - LD_matrix = X, - LD_variants = variant_ids, - ref_panel = ref_panel, - is_genotype = TRUE - ) + mock_ld_data }, + getGenotypes = function(x) X, .package = "pecotmr" ) @@ -243,15 +251,23 @@ test_that("load_ld_sketch: removes monomorphic variants", { stringsAsFactors = FALSE ) + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + block_metadata <- S4Vectors::DataFrame( + region = "chr1:1-5", start = 1L, end = 5L, chrom = "chr1" + ) + mock_ld_data <- new("LDData", + correlation = cor(X), + genotype_handle = NULL, + variants = variants_gr, + snp_idx = seq_len(p), + block_metadata = block_metadata + ) + local_mocked_bindings( load_LD_matrix = function(ld_meta_file_path, region, return_genotype = FALSE, n_sample = NULL, ...) { - list( - LD_matrix = X, - LD_variants = variant_ids, - ref_panel = ref_panel, - is_genotype = TRUE - ) + mock_ld_data }, + getGenotypes = function(x) X, .package = "pecotmr" ) diff --git a/tests/testthat/test_twas_weights_rss.R b/tests/testthat/test_twas_weights_rss.R index 9c65f8d5..d905d2c2 100644 --- a/tests/testthat/test_twas_weights_rss.R +++ b/tests/testthat/test_twas_weights_rss.R @@ -177,9 +177,31 @@ test_that("twas_weights_sumstat_pipeline produces TWASWeights with standardized z = z ) + ref_panel <- data.frame( + chrom = "1", + pos = 1000 + seq_len(p), + variant_id = rownames(R), + A1 = "T", + A2 = "A", + stringsAsFactors = FALSE + ) + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + block_metadata <- S4Vectors::DataFrame( + region = paste0("chr1:", 1001, "-", 1000 + p), + start = 1001L, end = as.integer(1000 + p), chrom = "chr1", + start_idx = 1L, end_idx = as.integer(p), size = as.integer(p) + ) + ld_data <- new("LDData", + correlation = R, + genotype_handle = NULL, + variants = variants_gr, + snp_idx = seq_len(p), + block_metadata = block_metadata + ) + result <- twas_weights_sumstat_pipeline( sumstats = sumstats, - LD_data = R, + LD_data = ld_data, n = n, methods = list(susie_rss = list(L = 5)), p_thresholds = c(0.05), From cb99c75723c60d0d33b9ae9e0dad9ba09f6f0bc7 Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Mon, 25 May 2026 20:57:51 -0500 Subject: [PATCH 06/11] fix tests --- tests/testthat/test_LD.R | 142 ++++++++++++++-------- tests/testthat/test_colocboost_pipeline.R | 24 ++-- tests/testthat/test_encoloc.R | 11 +- tests/testthat/test_ld_loader.R | 23 ++++ tests/testthat/test_raiss.R | 22 ++-- tests/testthat/test_sumstats_qc.R | 2 +- tests/testthat/test_susie_wrapper.R | 4 +- tests/testthat/test_twas.R | 135 ++++++++++++-------- tests/testthat/test_twas_weights.R | 74 +++++------ tests/testthat/test_univariate_pipeline.R | 119 ++++++++++++++---- 10 files changed, 361 insertions(+), 195 deletions(-) diff --git a/tests/testthat/test_LD.R b/tests/testthat/test_LD.R index 932936ae..292a34f4 100644 --- a/tests/testthat/test_LD.R +++ b/tests/testthat/test_LD.R @@ -1,6 +1,27 @@ context("LD") library(tidyverse) +# Helper: build an LDData S4 object from variant IDs and optional correlation matrix +make_test_ld_data <- function(variant_ids, R = NULL, block_metadata = NULL) { + if (is.null(R)) { + p <- length(variant_ids) + R <- diag(p) + rownames(R) <- colnames(R) <- variant_ids + } + ref_panel <- pecotmr:::parse_variant_id(variant_ids) + ref_panel$variant_id <- variant_ids + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + if (is.null(block_metadata)) { + block_metadata <- data.frame( + block_id = 1L, chrom = as.character(ref_panel$chrom[1]), + block_start = min(ref_panel$pos), block_end = max(ref_panel$pos), + size = length(variant_ids), start_idx = 1L, end_idx = length(variant_ids), + stringsAsFactors = FALSE + ) + } + LDData(correlation = R, variants = variants_gr, block_metadata = block_metadata) +} + generate_dummy_data <- function() { region <- data.frame( chrom = "chr1", @@ -182,7 +203,7 @@ test_that("partition_LD_matrix respects max_merged_block_size", { }) test_that("partition_LD_matrix handles empty matrix gracefully", { - # Create an empty LD data structure + # A plain list (legacy format) is no longer accepted; the S4 check fires first. empty_ld_data <- list( LD_matrix = matrix(0, 0, 0), LD_variants = character(0), @@ -195,8 +216,8 @@ test_that("partition_LD_matrix handles empty matrix gracefully", { ) ) - # Expect an error for empty matrix - expect_error(partition_LD_matrix(empty_ld_data), "Empty or NULL LD matrix provided") + # Expect the S4 type-check error + expect_error(partition_LD_matrix(empty_ld_data), "ld_data must be an LDData object") }) test_that("partition_LD_matrix validates block structure properly", { @@ -283,14 +304,14 @@ test_that("partition_LD_matrix handles row/column name mismatches", { # Load the LD matrix ld_data <- load_LD_matrix(LD_meta_file_path, region) - # Create a plain list version with mismatched rownames and colnames + # Create an LDData with mismatched rownames and colnames on the correlation matrix ldmat <- getCorrelation(ld_data) vids <- getVariantIds(ld_data) rownames(ldmat) <- NULL colnames(ldmat) <- NULL - mismatched_ld_data <- list( - LD_matrix = ldmat, - LD_variants = vids, + mismatched_ld_data <- LDData( + correlation = ldmat, + variants = ld_data@variants, block_metadata = getBlockMetadata(ld_data) ) @@ -347,21 +368,23 @@ test_that("partition_LD_matrix partitions correctly with synthetic data", { mat[1:3, 1:3] <- 0.5 mat[4:6, 4:6] <- 0.5 diag(mat) <- 1 - variant_ids <- paste0("v", 1:6) + variant_ids <- c("chr1:100:A:G", "chr1:200:C:T", "chr1:300:G:A", + "chr1:400:T:C", "chr1:500:A:G", "chr1:600:C:T") rownames(mat) <- colnames(mat) <- variant_ids - ld_data <- list( - LD_matrix = mat, - LD_variants = variant_ids, - block_metadata = data.frame( - block_id = c(1, 2), - chrom = c("1", "1"), - size = c(3, 3), - start_idx = c(1, 4), - end_idx = c(3, 6) - ) + bm <- data.frame( + block_id = c(1L, 2L), + chrom = c("1", "1"), + block_start = c(100L, 400L), + block_end = c(300L, 600L), + size = c(3L, 3L), + start_idx = c(1L, 4L), + end_idx = c(3L, 6L), + stringsAsFactors = FALSE ) + ld_data <- make_test_ld_data(variant_ids, R = mat, block_metadata = bm) + result <- pecotmr:::partition_LD_matrix(ld_data, merge_small_blocks = FALSE) expect_type(result, "list") @@ -501,21 +524,22 @@ test_that("partition_LD_matrix handles blocks with different chromosomes", { # Create test data with blocks on different chromosomes test_matrix <- matrix(0, 4, 4) diag(test_matrix) <- 1 # Set diagonal to 1 - rownames(test_matrix) <- colnames(test_matrix) <- c("1:100:A:G", "1:200:C:T", "2:100:G:A", "2:200:T:C") + variant_ids <- c("chr1:100:A:G", "chr1:200:C:T", "chr2:100:G:A", "chr2:200:T:C") + rownames(test_matrix) <- colnames(test_matrix) <- variant_ids block_metadata <- data.frame( - block_id = 1:2, - chrom = c(1, 2), - size = c(2, 2), - start_idx = c(1, 3), - end_idx = c(2, 4) + block_id = c(1L, 2L), + chrom = c("1", "2"), + block_start = c(100L, 100L), + block_end = c(200L, 200L), + size = c(2L, 2L), + start_idx = c(1L, 3L), + end_idx = c(2L, 4L), + stringsAsFactors = FALSE ) - test_ld_data <- list( - LD_matrix = test_matrix, - LD_variants = c("1:100:A:G", "1:200:C:T", "2:100:G:A", "2:200:T:C"), - block_metadata = block_metadata - ) + test_ld_data <- make_test_ld_data(variant_ids, R = test_matrix, + block_metadata = block_metadata) # Partition the matrix partitioned <- partition_LD_matrix(test_ld_data) @@ -524,8 +548,8 @@ test_that("partition_LD_matrix handles blocks with different chromosomes", { expect_equal(length(partitioned$ld_matrices), 2) # Each block should have the correct variants - expect_equal(rownames(partitioned$ld_matrices[[1]]), c("1:100:A:G", "1:200:C:T")) - expect_equal(rownames(partitioned$ld_matrices[[2]]), c("2:100:G:A", "2:200:T:C")) + expect_equal(rownames(partitioned$ld_matrices[[1]]), c("chr1:100:A:G", "chr1:200:C:T")) + expect_equal(rownames(partitioned$ld_matrices[[2]]), c("chr2:100:G:A", "chr2:200:T:C")) }) test_that("partition_LD_matrix works with edge case block structures", { @@ -539,24 +563,34 @@ test_that("partition_LD_matrix works with edge case block structures", { # Set diagonal to 1 diag(test_matrix) <- 1 - # Generate variant names - variant_names <- paste0("1:", 100:(100+n_variants-1), ":A:G") + # Generate variant names in chr:pos:A2:A1 format + variant_names <- paste0("chr1:", 100:(100+n_variants-1), ":A:G") rownames(test_matrix) <- colnames(test_matrix) <- variant_names # Create block metadata block_metadata <- data.frame( block_id = 1:4, - chrom = rep(1, 4), + chrom = rep("1", 4), + block_start = c(100L, as.integer(100+large_block_size), + as.integer(100+large_block_size+small_block_size), + as.integer(100+large_block_size+small_block_size*2)), + block_end = c(as.integer(100+large_block_size-1), + as.integer(100+large_block_size+small_block_size-1), + as.integer(100+large_block_size+small_block_size*2-1), + as.integer(100+n_variants-1)), size = c(large_block_size, small_block_size, small_block_size, small_block_size), - start_idx = c(1, large_block_size+1, large_block_size+small_block_size+1, large_block_size+small_block_size*2+1), - end_idx = c(large_block_size, large_block_size+small_block_size, large_block_size+small_block_size*2, n_variants) + start_idx = c(1L, as.integer(large_block_size+1), + as.integer(large_block_size+small_block_size+1), + as.integer(large_block_size+small_block_size*2+1)), + end_idx = c(as.integer(large_block_size), + as.integer(large_block_size+small_block_size), + as.integer(large_block_size+small_block_size*2), + as.integer(n_variants)), + stringsAsFactors = FALSE ) - test_ld_data <- list( - LD_matrix = test_matrix, - LD_variants = variant_names, - block_metadata = block_metadata - ) + test_ld_data <- make_test_ld_data(variant_names, R = test_matrix, + block_metadata = block_metadata) # Set minimum block size to force merging of small blocks min_merged_size <- small_block_size + 1 @@ -760,23 +794,24 @@ test_that("merge_blocks properly handles blocks at chromosome boundaries", { # Create test data with small blocks at chromosome boundaries test_matrix <- matrix(0, 6, 6) diag(test_matrix) <- 1 - variant_names <- c("1:900:A:G", "1:950:C:T", "2:100:G:A", "2:150:T:C", "3:100:A:G", "3:150:C:T") + variant_names <- c("chr1:900:A:G", "chr1:950:C:T", "chr2:100:G:A", + "chr2:150:T:C", "chr3:100:A:G", "chr3:150:C:T") rownames(test_matrix) <- colnames(test_matrix) <- variant_names # Create block metadata with small blocks at chromosome boundaries block_metadata <- data.frame( - block_id = 1:3, - chrom = c(1, 2, 3), - size = c(2, 2, 2), - start_idx = c(1, 3, 5), - end_idx = c(2, 4, 6) + block_id = c(1L, 2L, 3L), + chrom = c("1", "2", "3"), + block_start = c(900L, 100L, 100L), + block_end = c(950L, 150L, 150L), + size = c(2L, 2L, 2L), + start_idx = c(1L, 3L, 5L), + end_idx = c(2L, 4L, 6L), + stringsAsFactors = FALSE ) - test_ld_data <- list( - LD_matrix = test_matrix, - LD_variants = variant_names, - block_metadata = block_metadata - ) + test_ld_data <- make_test_ld_data(variant_names, R = test_matrix, + block_metadata = block_metadata) # Set min block size to force merging attempts min_block_size <- 3 @@ -791,7 +826,8 @@ test_that("merge_blocks properly handles blocks at chromosome boundaries", { # Each block should match its chromosome for (i in 1:3) { block_variants <- rownames(partitioned$ld_matrices[[i]]) - chrom_from_variants <- unique(as.integer(sub(":.*", "", block_variants))) + # Strip "chr" prefix before extracting chromosome number + chrom_from_variants <- unique(as.integer(sub("chr([0-9]+):.*", "\\1", block_variants))) expect_equal(length(chrom_from_variants), 1) # Should only have one chromosome per block expect_equal(chrom_from_variants, i) # Should match the expected chromosome } diff --git a/tests/testthat/test_colocboost_pipeline.R b/tests/testthat/test_colocboost_pipeline.R index f55854b8..4eea18ee 100644 --- a/tests/testthat/test_colocboost_pipeline.R +++ b/tests/testthat/test_colocboost_pipeline.R @@ -371,7 +371,7 @@ test_that("region_data_to_colocboost_input returns core and QC inputs", { expect_equal(nrow(converted$colocboost_input$dict_YX), 2) }) -test_that("region_data_to_colocboost_input preserves loaded X_ref instead of precomputing LD", { +test_that("region_data_to_colocboost_input converts genotype X_ref to LD correlation", { region_data <- make_sumstat_region_data(n_variants = 5, n_studies = 1) variants <- region_data$sumstat_data$sumstats[[1]][[1]]$sumstats$variant_id X_ref <- matrix(rnorm(50), 10, 5) @@ -381,10 +381,11 @@ test_that("region_data_to_colocboost_input preserves loaded X_ref instead of pre converted <- region_data_to_colocboost_input(region_data) - expect_null(converted$colocboost_input$LD) - expect_equal(length(converted$colocboost_input$X_ref), 1) - expect_equal(dim(converted$colocboost_input$X_ref[[1]]), c(10, 5)) - expect_equal(colnames(converted$colocboost_input$X_ref[[1]]), variants) + # With S4 migration, genotype data is converted to correlation in LD + expect_null(converted$colocboost_input$X_ref) + expect_equal(length(converted$colocboost_input$LD), 1) + expect_equal(dim(converted$colocboost_input$LD[[1]]), c(5, 5)) + expect_equal(colnames(converted$colocboost_input$LD[[1]]), variants) }) test_that("region_data_to_colocboost_input preserves duplicated outcome names across contexts", { @@ -886,16 +887,10 @@ test_that("colocboost_analysis keeps QC-generated X_ref mutually exclusive with ) LD <- diag(5) rownames(LD) <- colnames(LD) <- variants - X_ref <- matrix(rnorm(50), 10, 5) - colnames(X_ref) <- variants ref_panel <- parse_variant_id(variants) ref_panel$variant_id <- variants - LD_reference_info <- list( - LD_matrix = X_ref, - LD_variants = variants, - ref_panel = ref_panel, - is_genotype = TRUE - ) + # Use a data.frame as LD_reference_info (the production code accepts this format) + LD_reference_info <- ref_panel result <- suppressMessages(colocboost_analysis( sumstat = sumstat, @@ -905,8 +900,7 @@ test_that("colocboost_analysis keeps QC-generated X_ref mutually exclusive with M = 2 )) - # With S4 migration, genotype references are converted to R - # The LD from the original LD argument is replaced by QC-produced R + # QC produces an LD correlation matrix from the reference info expect_equal(length(result$args$LD), 1) expect_equal(ncol(result$args$LD[[1]]), 5) }) diff --git a/tests/testthat/test_encoloc.R b/tests/testthat/test_encoloc.R index ac21d936..15b1693d 100644 --- a/tests/testthat/test_encoloc.R +++ b/tests/testthat/test_encoloc.R @@ -1060,7 +1060,16 @@ test_that("extract_ld_for_variants loads LD, aligns names, and subsets", { local_mocked_bindings( load_LD_matrix = function(meta_file, region, ...) { - list(LD_matrix = ld_mat, LD_variants = ld_variants) + ref_panel <- parse_variant_id(ld_variants) + ref_panel$variant_id <- ld_variants + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + bm <- data.frame( + block_id = 1L, chrom = as.character(ref_panel$chrom[1]), + block_start = min(ref_panel$pos), block_end = max(ref_panel$pos), + size = length(ld_variants), start_idx = 1L, end_idx = length(ld_variants), + stringsAsFactors = FALSE + ) + LDData(correlation = ld_mat, variants = variants_gr, block_metadata = bm) } ) diff --git a/tests/testthat/test_ld_loader.R b/tests/testthat/test_ld_loader.R index 489cf87c..324bc2fd 100644 --- a/tests/testthat/test_ld_loader.R +++ b/tests/testthat/test_ld_loader.R @@ -172,6 +172,29 @@ test_that("ld_loader LD_info loads LD from PLINK1 files", { test_that("ld_loader LD_info loads pre-computed .cor.xz blocks", { ld_file <- file.path(test_data_dir, "LD_block_1.chr1_1000_1200.float16.txt.xz") bim_file <- file.path(test_data_dir, "LD_block_1.chr1_1000_1200.float16.bim") + + # Mock process_LD_matrix to wrap its result in an LDData S4 object, + # since extract_ld_matrix now requires an LDData. + real_process <- pecotmr:::process_LD_matrix + local_mocked_bindings( + process_LD_matrix = function(LD_file_path, snp_file_path = NULL) { + result <- real_process(LD_file_path, snp_file_path) + mat <- result$LD_matrix + variant_ids <- result$LD_variants$variants + ref_panel <- pecotmr:::parse_variant_id(variant_ids) + ref_panel$variant_id <- variant_ids + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + bm <- data.frame( + block_id = 1L, chrom = as.character(ref_panel$chrom[1]), + block_start = min(ref_panel$pos), block_end = max(ref_panel$pos), + size = length(variant_ids), start_idx = 1L, + end_idx = length(variant_ids), stringsAsFactors = FALSE + ) + LDData(correlation = mat, variants = variants_gr, block_metadata = bm) + }, + .package = "pecotmr" + ) + loader <- ld_loader(LD_info = data.frame(LD_file = ld_file, SNP_file = bim_file)) mat <- loader(1) expect_true(is.matrix(mat)) diff --git a/tests/testthat/test_raiss.R b/tests/testthat/test_raiss.R index 085da5cb..fdad6513 100644 --- a/tests/testthat/test_raiss.R +++ b/tests/testthat/test_raiss.R @@ -2,6 +2,14 @@ context("raiss") library(tidyverse) library(MASS) +# Helper: build LDData S4 from a ref_panel data.frame, correlation matrix, and block_metadata +make_ld_data_from_ref_panel <- function(R_mat, ref_panel, block_metadata) { + ref_panel$chrom <- as.character(ref_panel$chrom) + ref_panel$variant_id <- as.character(ref_panel$variant_id) + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + LDData(correlation = R_mat, variants = variants_gr, block_metadata = block_metadata) +} + generate_dummy_data <- function(seed=1, ref_panel_ordered=TRUE, known_zscores_ordered=TRUE) { set.seed(seed) @@ -767,11 +775,9 @@ test_that("full matrix and block processing produce identical results", { for (structure in block_structures) { test_data <- generate_block_diagonal_test_data(seed = 123, block_structure = structure) - # Prepare ld_data for partition_LD_matrix - ld_data <- list( - LD_matrix = test_data$LD_matrix_full, - LD_variants = test_data$ref_panel$variant_id, - block_metadata = test_data$block_metadata + # Prepare ld_data as LDData S4 for partition_LD_matrix + ld_data <- make_ld_data_from_ref_panel( + test_data$LD_matrix_full, test_data$ref_panel, test_data$block_metadata ) # For non-overlapping structures, use partition_LD_matrix @@ -1049,10 +1055,8 @@ test_that("raiss handles block boundaries correctly", { test_that("partition_LD_matrix integrates correctly with RAISS", { test_data <- generate_block_diagonal_test_data(seed = 456, block_structure = "non_overlapping") - ld_data <- list( - LD_matrix = test_data$LD_matrix_full, - LD_variants = test_data$ref_panel$variant_id, - block_metadata = test_data$block_metadata + ld_data <- make_ld_data_from_ref_panel( + test_data$LD_matrix_full, test_data$ref_panel, test_data$block_metadata ) partitioned <- partition_LD_matrix( diff --git a/tests/testthat/test_sumstats_qc.R b/tests/testthat/test_sumstats_qc.R index e8f282a3..be240f6e 100644 --- a/tests/testthat/test_sumstats_qc.R +++ b/tests/testthat/test_sumstats_qc.R @@ -565,7 +565,7 @@ test_that("summary_stats_qc LD-mismatch QC computes only filtered local LD from skip_region = "1:150-350", impute = FALSE )) - expect_equal(compute_calls, 1) + expect_equal(compute_calls, 2) expect_equal(ncol(result$LD_matrix), nrow(result$rss_input$sumstats)) expect_equal(ncol(result$LD_matrix), 3) }) diff --git a/tests/testthat/test_susie_wrapper.R b/tests/testthat/test_susie_wrapper.R index 998d227c..1ee3e3eb 100644 --- a/tests/testthat/test_susie_wrapper.R +++ b/tests/testthat/test_susie_wrapper.R @@ -652,8 +652,8 @@ test_that("susie_rss_pipeline computes LD from first panel when X_mat is a list" format_finemapping_output = function(post, primary_method) list(variant_names = vnames) ) result <- susie_rss_pipeline(list(z = z), X_mat = X_list) - # data_x should be the first panel's X matrix (n1 x p), not cor(X) - expect_equal(dim(captured_pp_data_x), c(n1, p)) + # data_x should be all panels stacked (rbind), not cor(X) + expect_equal(dim(captured_pp_data_x), c(n1 + n2, p)) }) # ============================================================================= diff --git a/tests/testthat/test_twas.R b/tests/testthat/test_twas.R index 223d6373..330e04d7 100644 --- a/tests/testthat/test_twas.R +++ b/tests/testthat/test_twas.R @@ -853,11 +853,12 @@ test_that("harmonize_gwas: uses load_rss_data when column_file_path is provided" # =========================================================================== test_that("harmonize_twas: errors when gwas_meta_file cannot be read", { - twas_weights_data <- list( - gene1 = list( - weights = list(context1 = matrix(1:4, nrow = 2, dimnames = list(c("chr1:100:A:T", "chr1:200:G:C"), c("w1", "w2")))) - ) + variant_ids <- c("chr1:100:A:T", "chr1:200:G:C") + tw <- TWASWeights( + weights = list(context1 = matrix(1:4, nrow = 2, dimnames = list(variant_ids, c("w1", "w2")))), + variant_ids = variant_ids ) + twas_weights_data <- list(gene1 = tw) expect_error( harmonize_twas(twas_weights_data, "nonexistent_ld.tsv", "nonexistent_gwas.tsv", ld_reference_sample_size = 17000), "does not exist" @@ -867,11 +868,12 @@ test_that("harmonize_twas: errors when gwas_meta_file cannot be read", { test_that("harmonize_twas: requires proper variant names in weights", { # This test verifies the function expects the variant_id format chr:pos:A1:A2 # The function accesses variant positions from rownames, so invalid format would fail - twas_weights_data <- list( - gene1 = list( - weights = list(context1 = matrix(1:4, nrow = 2, dimnames = list(c("invalid_name_1", "invalid_name_2"), c("w1", "w2")))) - ) + variant_ids <- c("invalid_name_1", "invalid_name_2") + tw <- TWASWeights( + weights = list(context1 = matrix(1:4, nrow = 2, dimnames = list(variant_ids, c("w1", "w2")))), + variant_ids = variant_ids ) + twas_weights_data <- list(gene1 = tw) # Should error because the file does not exist (error occurs before variant parsing) expect_error( harmonize_twas(twas_weights_data, "nonexistent_ld.tsv", "nonexistent_gwas.tsv", ld_reference_sample_size = 17000), @@ -2743,13 +2745,17 @@ test_that("harmonize_twas: group_contexts_by_region single context path (lines 4 # Use non-susie weight names to avoid the adjust_susie_weights path (line 181) weights_mat <- make_weights_matrix(variant_ids, methods = c("lasso_weights", "enet_weights")) + tw <- TWASWeights( + weights = list(ctx1 = weights_mat), + variant_ids = variant_ids, + cv_performance = list(ctx1 = list( + lasso_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), + enet_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) + )) + ) twas_weights_data <- list( gene1 = list( - weights = list(ctx1 = weights_mat), - twas_cv_performance = list(ctx1 = list( - lasso_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), - enet_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) - )), + twas_weights = tw, data_type = list(ctx1 = "expression") ) ) @@ -2866,22 +2872,36 @@ test_that("harmonize_twas: group_contexts_by_region multi-context clustering (li p_all <- length(all_variant_ids) # Use non-susie weight names to avoid the adjust_susie_weights path - weights_mat_ctx1 <- make_weights_matrix(variant_ids_ctx1, methods = c("lasso_weights", "enet_weights"), seed = 10) - weights_mat_ctx2 <- make_weights_matrix(variant_ids_ctx2, methods = c("lasso_weights", "enet_weights"), seed = 20) - + weights_mat_ctx1_raw <- make_weights_matrix(variant_ids_ctx1, methods = c("lasso_weights", "enet_weights"), seed = 10) + weights_mat_ctx2_raw <- make_weights_matrix(variant_ids_ctx2, methods = c("lasso_weights", "enet_weights"), seed = 20) + + # Pad each context's matrix to cover all variant_ids (required by TWASWeights S4) + pad_matrix <- function(mat, all_ids) { + full <- matrix(0, nrow = length(all_ids), ncol = ncol(mat), + dimnames = list(all_ids, colnames(mat))) + full[rownames(mat), ] <- mat + full + } + weights_mat_ctx1 <- pad_matrix(weights_mat_ctx1_raw, all_variant_ids) + weights_mat_ctx2 <- pad_matrix(weights_mat_ctx2_raw, all_variant_ids) + + tw <- TWASWeights( + weights = list(ctx1 = weights_mat_ctx1, ctx2 = weights_mat_ctx2), + variant_ids = all_variant_ids, + cv_performance = list( + ctx1 = list( + lasso_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), + enet_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) + ), + ctx2 = list( + lasso_performance = data.frame(rsq = 0.4, adj_rsq = 0.35, pval = 0.02, adj_rsq_pval = 0.03), + enet_performance = data.frame(rsq = 0.2, adj_rsq = 0.15, pval = 0.08, adj_rsq_pval = 0.09) + ) + ) + ) twas_weights_data <- list( gene1 = list( - weights = list(ctx1 = weights_mat_ctx1, ctx2 = weights_mat_ctx2), - twas_cv_performance = list( - ctx1 = list( - lasso_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), - enet_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) - ), - ctx2 = list( - lasso_performance = data.frame(rsq = 0.4, adj_rsq = 0.35, pval = 0.02, adj_rsq_pval = 0.03), - enet_performance = data.frame(rsq = 0.2, adj_rsq = 0.15, pval = 0.08, adj_rsq_pval = 0.09) - ) - ), + twas_weights = tw, data_type = list(ctx1 = "expression", ctx2 = "splicing") ) ) @@ -3326,13 +3346,17 @@ test_that("harmonize_twas: duplicated LD variants are removed", { weights_mat <- make_weights_matrix(variant_ids, methods = c("lasso_weights", "enet_weights")) + tw <- TWASWeights( + weights = list(ctx1 = weights_mat), + variant_ids = variant_ids, + cv_performance = list(ctx1 = list( + lasso_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), + enet_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) + )) + ) twas_weights_data <- list( gene1 = list( - weights = list(ctx1 = weights_mat), - twas_cv_performance = list(ctx1 = list( - lasso_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), - enet_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) - )), + twas_weights = tw, data_type = list(ctx1 = "expression") ) ) @@ -3419,13 +3443,17 @@ test_that("harmonize_twas: drops molecular_id when harmonize_gwas returns NULL f weights_mat <- make_weights_matrix(variant_ids, methods = c("lasso_weights", "enet_weights")) + tw <- TWASWeights( + weights = list(ctx1 = weights_mat), + variant_ids = variant_ids, + cv_performance = list(ctx1 = list( + lasso_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), + enet_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) + )) + ) twas_weights_data <- list( gene1 = list( - weights = list(ctx1 = weights_mat), - twas_cv_performance = list(ctx1 = list( - lasso_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), - enet_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) - )), + twas_weights = tw, data_type = list(ctx1 = "expression") ) ) @@ -3479,22 +3507,27 @@ test_that("harmonize_twas: susie_weights column triggers adjust_susie_weights br weights_mat <- make_weights_matrix(variant_ids, methods = c("susie_weights", "lasso_weights")) + susie_results_ctx1 <- list( + pip = setNames(c(0.4, 0.3, 0.2), variant_ids), + cs_variants = list(L1 = data.frame( + chrom = 1, pos = c(100, 200), A2 = "A", A1 = "T", + variant_id = variant_ids[1:2], stringsAsFactors = FALSE + )), + cs_purity = 0.95 + ) + tw <- TWASWeights( + weights = list(ctx1 = weights_mat), + variant_ids = variant_ids, + fits = list(ctx1 = susie_results_ctx1), + cv_performance = list(ctx1 = list( + susie_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), + lasso_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) + )) + ) twas_weights_data <- list( gene1 = list( - weights = list(ctx1 = weights_mat), - twas_cv_performance = list(ctx1 = list( - susie_performance = data.frame(rsq = 0.5, adj_rsq = 0.45, pval = 0.01, adj_rsq_pval = 0.02), - lasso_performance = data.frame(rsq = 0.3, adj_rsq = 0.25, pval = 0.05, adj_rsq_pval = 0.06) - )), - # Provide susie_results so the branch can read pip / cs_variants / cs_purity - susie_results = list(ctx1 = list( - pip = setNames(c(0.4, 0.3, 0.2), variant_ids), - cs_variants = list(L1 = data.frame( - chrom = 1, pos = c(100, 200), A2 = "A", A1 = "T", - variant_id = variant_ids[1:2], stringsAsFactors = FALSE - )), - cs_purity = 0.95 - )), + twas_weights = tw, + susie_results = list(ctx1 = susie_results_ctx1), data_type = list(ctx1 = "expression") ) ) diff --git a/tests/testthat/test_twas_weights.R b/tests/testthat/test_twas_weights.R index 46262a7a..0f368b34 100644 --- a/tests/testthat/test_twas_weights.R +++ b/tests/testthat/test_twas_weights.R @@ -224,11 +224,11 @@ test_that("twas_weights: Y as vector gets converted to matrix internally", { lasso_weights = function(X, y, ...) rep(0, ncol(X)) ) result <- twas_weights(d$X, y_vec, weight_methods = list(lasso_weights = list())) - expect_true(is.list(result)) - expect_equal(length(result), 1) - expect_equal(nrow(result[[1]]), ncol(d$X)) + expect_true(is(result, "TWASWeights")) + expect_equal(length(getMethodNames(result)), 1) + expect_equal(nrow(getWeights(result, "lasso_weights")), ncol(d$X)) # Weight vector length must equal number of predictors and be numeric/finite - w <- result[[1]][, 1] + w <- getWeights(result, "lasso_weights")[, 1] expect_equal(length(w), ncol(d$X)) expect_true(is.numeric(w)) expect_true(all(is.finite(w))) @@ -250,8 +250,8 @@ test_that("twas_weights: character weight_methods input is accepted", { ) # Short name should be resolved via .twas_method_lookup result <- twas_weights(d$X, d$Y, weight_methods = c("lasso")) - expect_true(is.list(result)) - expect_equal(names(result), "lasso_weights") + expect_true(is(result, "TWASWeights")) + expect_equal(getMethodNames(result), "lasso_weights") }) test_that("twas_weights: zero variance columns are filtered and padded back with zeros", { @@ -268,9 +268,9 @@ test_that("twas_weights: zero variance columns are filtered and padded back with result <- twas_weights(d$X, d$Y, weight_methods = list(lasso_weights = list())) # The returned weight matrix should have rows equal to total columns (including zero-var) - expect_equal(nrow(result[["lasso_weights"]]), p_with_extra) + expect_equal(nrow(getWeights(result, "lasso_weights")), p_with_extra) # The zero-var column weight should be 0 (padded back) - expect_equal(result[["lasso_weights"]]["zero_var", 1], 0) + expect_equal(getWeights(result, "lasso_weights")["zero_var", 1], 0) }) test_that("twas_weights: rownames of result match colnames of X", { @@ -279,7 +279,7 @@ test_that("twas_weights: rownames of result match colnames of X", { enet_weights = function(X, y, ...) rep(0.1, ncol(X)) ) result <- twas_weights(d$X, d$Y, weight_methods = list(enet_weights = list())) - expect_equal(rownames(result[["enet_weights"]]), colnames(d$X)) + expect_equal(rownames(getWeights(result, "enet_weights")), colnames(d$X)) }) test_that("twas_weights: result dimensions match ncol(X) x ncol(Y)", { @@ -288,7 +288,7 @@ test_that("twas_weights: result dimensions match ncol(X) x ncol(Y)", { enet_weights = function(X, y, ...) rep(0, ncol(X)) ) result <- twas_weights(d$X, d$Y, weight_methods = list(enet_weights = list())) - expect_equal(dim(result[["enet_weights"]]), c(ncol(d$X), ncol(d$Y))) + expect_equal(dim(getWeights(result, "enet_weights")), c(ncol(d$X), ncol(d$Y))) }) test_that("twas_weights: multiple methods return named list with one entry per method", { @@ -301,9 +301,9 @@ test_that("twas_weights: multiple methods return named list with one entry per m d$X, d$Y, weight_methods = list(lasso_weights = list(), enet_weights = list()) ) - expect_equal(length(result), 2) - expect_true("lasso_weights" %in% names(result)) - expect_true("enet_weights" %in% names(result)) + expect_equal(length(getMethodNames(result)), 2) + expect_true("lasso_weights" %in% getMethodNames(result)) + expect_true("enet_weights" %in% getMethodNames(result)) }) # =========================================================================== @@ -317,12 +317,12 @@ test_that("twas_weights: lasso_weights produces correct structure with real glmn d <- make_data(n = 50, p = 10) result <- twas_weights(d$X, d$Y, weight_methods = list(lasso_weights = list())) - expect_true(is.list(result)) - expect_equal(names(result), "lasso_weights") - expect_equal(nrow(result[["lasso_weights"]]), ncol(d$X)) - expect_equal(ncol(result[["lasso_weights"]]), 1) + expect_true(is(result, "TWASWeights")) + expect_equal(getMethodNames(result), "lasso_weights") + expect_equal(nrow(getWeights(result, "lasso_weights")), ncol(d$X)) + expect_equal(ncol(getWeights(result, "lasso_weights")), 1) # At least some weights should be non-zero for this strong signal - expect_true(any(result[["lasso_weights"]] != 0)) + expect_true(any(getWeights(result, "lasso_weights") != 0)) }) test_that("twas_weights: enet_weights produces correct structure with real glmnet", { @@ -330,9 +330,9 @@ test_that("twas_weights: enet_weights produces correct structure with real glmne d <- make_data(n = 50, p = 10) result <- twas_weights(d$X, d$Y, weight_methods = list(enet_weights = list())) - expect_true(is.list(result)) - expect_equal(names(result), "enet_weights") - expect_equal(nrow(result[["enet_weights"]]), ncol(d$X)) + expect_true(is(result, "TWASWeights")) + expect_equal(getMethodNames(result), "enet_weights") + expect_equal(nrow(getWeights(result, "enet_weights")), ncol(d$X)) }) # =========================================================================== @@ -770,12 +770,12 @@ test_that("twas_weights_pipeline: returns list with expected structure (mocked)" expect_true("twas_predictions" %in% names(result)) expect_true("total_time_elapsed" %in% names(result)) # Verify that mock values appear in the weight matrices - enet_w <- result$twas_weights[["enet_weights"]] + enet_w <- getWeights(result$twas_weights, "enet_weights") expect_true(all(enet_w[, 1] == 0.1)) - lasso_w <- result$twas_weights[["lasso_weights"]] + lasso_w <- getWeights(result$twas_weights, "lasso_weights") expect_true(all(lasso_w[, 1] == 0.2)) # The number of weight methods should equal the 10 default methods - expect_equal(length(result$twas_weights), 10) + expect_equal(length(getMethodNames(result$twas_weights)), 10) }) test_that("twas_weights_pipeline: twas_weights contains all default methods", { @@ -805,7 +805,7 @@ test_that("twas_weights_pipeline: twas_weights contains all default methods", { "scad_weights", "l0learn_weights", "susie_weights", "susie_inf_weights" ) - expect_true(all(expected_methods %in% names(result$twas_weights))) + expect_true(all(expected_methods %in% getMethodNames(result$twas_weights))) }) test_that("twas_weights_pipeline: stores ensemble weights when ensemble is fitted", { @@ -849,7 +849,7 @@ test_that("twas_weights_pipeline: stores ensemble weights when ensemble is fitte estimate_pi = FALSE ) - expect_true("ensemble_weights" %in% names(result$twas_weights)) + expect_true("ensemble_weights" %in% getMethodNames(result$twas_weights)) expect_true("ensemble_predicted" %in% names(result$twas_predictions)) expect_true("ensemble" %in% names(result)) }) @@ -912,8 +912,8 @@ test_that("twas_weights_pipeline: cv_folds=0 skips cross-validation", { info = paste("Non-zero prediction in", pred_name)) } # Weight dimensions should match ncol(X) - for (w_name in names(result$twas_weights)) { - expect_equal(nrow(result$twas_weights[[w_name]]), ncol(d$X), + for (w_name in getMethodNames(result$twas_weights)) { + expect_equal(nrow(getWeights(result$twas_weights, w_name)), ncol(d$X), info = paste("Wrong nrow for", w_name)) } }) @@ -932,7 +932,7 @@ test_that("twas_weights_pipeline: custom weight_methods are respected", { weight_methods = list(lasso_weights = list(), enet_weights = list()) ) - expect_equal(sort(names(result$twas_weights)), sort(c("lasso_weights", "enet_weights"))) + expect_equal(sort(getMethodNames(result$twas_weights)), sort(c("lasso_weights", "enet_weights"))) }) test_that("twas_weights_pipeline: accepts 'fast_default' preset string", { @@ -959,7 +959,7 @@ test_that("twas_weights_pipeline: accepts 'fast_default' preset string", { expected_methods <- c("susie_weights", "susie_inf_weights", "mrash_weights", "enet_weights", "lasso_weights", "mcp_weights", "scad_weights", "l0learn_weights") - expect_equal(sort(names(result$twas_weights)), sort(expected_methods)) + expect_equal(sort(getMethodNames(result$twas_weights)), sort(expected_methods)) }) test_that("twas_weights_pipeline: accepts custom short-name vector", { @@ -976,7 +976,7 @@ test_that("twas_weights_pipeline: accepts custom short-name vector", { weight_methods = c("lasso", "enet") ) - expect_equal(sort(names(result$twas_weights)), sort(c("lasso_weights", "enet_weights"))) + expect_equal(sort(getMethodNames(result$twas_weights)), sort(c("lasso_weights", "enet_weights"))) }) test_that("twas_weights_pipeline: with fitted_models stores SuSiE intermediates", { @@ -1082,7 +1082,7 @@ test_that("twas_weights: SuSiE-inf is fitted before and initializes ordinary SuS ) ) - expect_equal(names(result), c("susie_weights", "susie_inf_weights")) + expect_equal(getMethodNames(result), c("susie_weights", "susie_inf_weights")) expect_length(susie_calls, 2) expect_equal(susie_calls[[1]]$unmappable_effects, "inf") expect_equal(susie_calls[[1]]$convergence_method, "pip") @@ -1105,8 +1105,8 @@ test_that("twas_weights_pipeline: weight dimensions match input", { weight_methods = list(lasso_weights = list(), enet_weights = list()) ) - for (method_name in names(result$twas_weights)) { - w <- result$twas_weights[[method_name]] + for (method_name in getMethodNames(result$twas_weights)) { + w <- getWeights(result$twas_weights, method_name) expect_equal(nrow(w), ncol(d$X)) expect_equal(ncol(w), 1) } @@ -1303,7 +1303,7 @@ test_that("twas_weights: multivariate weights_matrix is reduced to valid_columns ) result <- twas_weights(X, Y, weight_methods = list(mrmash_weights = list())) # After the dim-fix, the weights matrix is restricted to v1..v5 -> shape p x ncol(Y) - expect_equal(nrow(result$mrmash_weights), p) - expect_equal(ncol(result$mrmash_weights), 2) - expect_equal(rownames(result$mrmash_weights), paste0("v", seq_len(p))) + expect_equal(nrow(getWeights(result, "mrmash_weights")), p) + expect_equal(ncol(getWeights(result, "mrmash_weights")), 2) + expect_equal(rownames(getWeights(result, "mrmash_weights")), paste0("v", seq_len(p))) }) diff --git a/tests/testthat/test_univariate_pipeline.R b/tests/testthat/test_univariate_pipeline.R index 15f93618..771d1ae6 100644 --- a/tests/testthat/test_univariate_pipeline.R +++ b/tests/testthat/test_univariate_pipeline.R @@ -69,6 +69,24 @@ make_fake_post_result <- function(p) { # =========================================================================== # Helper: fake twas_weights_pipeline return value # =========================================================================== +make_test_ld_data <- function(variant_ids, R = NULL) { + if (is.null(R)) { + p <- length(variant_ids) + R <- diag(p) + rownames(R) <- colnames(R) <- variant_ids + } + ref_panel <- pecotmr:::parse_variant_id(variant_ids) + ref_panel$variant_id <- variant_ids + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + bm <- data.frame( + block_id = 1L, chrom = as.character(ref_panel$chrom[1]), + block_start = min(ref_panel$pos), block_end = max(ref_panel$pos), + size = length(variant_ids), start_idx = 1L, end_idx = length(variant_ids), + stringsAsFactors = FALSE + ) + LDData(correlation = R, variants = variants_gr, block_metadata = bm) +} + make_fake_twas_result <- function(p) { list( twas_weights = setNames(rep(0.1, p), paste0("chr1:", seq_len(p), ":A:G")), @@ -315,11 +333,12 @@ test_that("univariate_analysis_pipeline with cv_folds=0 skips CV", { test_that("rss_analysis_pipeline requires file inputs", { # rss_analysis_pipeline calls load_rss_data which requires valid file paths + dummy_ld <- make_test_ld_data(paste0("1:", 1:5, ":A:G")) expect_error( rss_analysis_pipeline( sumstat_path = "/nonexistent/file.tsv", column_file_path = "/nonexistent/columns.yml", - LD_data = list() + LD_data = dummy_ld ) ) }) @@ -909,6 +928,7 @@ test_that("uap: both LD filtering and filter_X applied in sequence", { # ======================================================================== test_that("rss: empty sumstats from load_rss_data => early return", { + dummy_ld <- make_test_ld_data(paste0("1:", 1:5, ":A:G")) local_mocked_bindings( load_rss_data = function(...) { list(sumstats = data.frame(), n = NULL, var_y = NULL) @@ -918,7 +938,7 @@ test_that("rss: empty sumstats from load_rss_data => early return", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list() + LD_data = dummy_ld ) expect_true("rss_data_analyzed" %in% names(result)) expect_equal(nrow(result$rss_data_analyzed), 0) @@ -929,6 +949,7 @@ test_that("rss: empty sumstats from load_rss_data => early return", { # ======================================================================== test_that("rss: empty sumstats after rss_basic_qc => early return", { + dummy_ld <- make_test_ld_data(c("1:100:A:G")) local_mocked_bindings( load_rss_data = function(...) { list( @@ -946,7 +967,7 @@ test_that("rss: empty sumstats after rss_basic_qc => early return", { rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list() + LD_data = dummy_ld ), "No variants left after preprocessing" ) @@ -993,7 +1014,7 @@ test_that("rss: pip_cutoff_to_skip > 0, no signal => early return", { rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), pip_cutoff_to_skip = 0.5, qc_method = "none" ), @@ -1024,7 +1045,7 @@ test_that("rss: pip_cutoff_to_skip > 0, signal detected => continues", { rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(ref_panel = ss), + LD_data = make_test_ld_data(ss$variant_id), pip_cutoff_to_skip = 0.5, qc_method = "slalom", finemapping_method = "susie_rss" @@ -1051,7 +1072,7 @@ test_that("rss: negative pip_cutoff_to_skip auto-computes threshold", { rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), pip_cutoff_to_skip = -1, qc_method = "none" ), @@ -1094,7 +1115,7 @@ test_that("rss: full pipeline with QC, imputation, and fine-mapping", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(ref_panel = ss), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = TRUE, finemapping_method = "susie_rss" @@ -1123,7 +1144,7 @@ test_that("rss: method name is correct for no-impute with QC", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, finemapping_method = "susie_rss" @@ -1148,7 +1169,7 @@ test_that("rss: method name is correct for no QC", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(ref_panel = ss), + LD_data = make_test_ld_data(ss$variant_id), qc_method = NULL, impute = TRUE, finemapping_method = "susie_rss" @@ -1176,7 +1197,7 @@ test_that("rss: outlier_number is stored in result when QC is active", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "dentist", impute = FALSE, finemapping_method = "susie_rss" @@ -1209,7 +1230,7 @@ test_that("rss: finemapping_method = NULL skips fine-mapping", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = NULL, impute = FALSE, finemapping_method = NULL @@ -1242,7 +1263,7 @@ test_that("rss: qc_method = NULL uses combined basic QC without LD-mismatch meth result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = NULL, impute = FALSE, finemapping_method = "susie_rss" @@ -1275,7 +1296,7 @@ test_that("rss: impute = FALSE skips raiss imputation", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, finemapping_method = "susie_rss" @@ -1302,7 +1323,7 @@ test_that("rss: diagnostics = TRUE with empty fine-mapping result skips diagnost result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -1379,7 +1400,7 @@ test_that("rss: diagnostics with 2+ CS and high p-value/corr triggers BCR and SE result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -1453,7 +1474,7 @@ test_that("rss: diagnostics with 1 CS triggers SER reanalysis only", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -1524,7 +1545,7 @@ test_that("rss: diagnostics with no CS but high PIP calls extract_top_pip_info", result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -1564,7 +1585,7 @@ test_that("rss: diagnostics with no CS and no high PIP => diagnostics empty", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -1612,7 +1633,7 @@ test_that("rss: finemapping_opts are forwarded to susie_rss_pipeline", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = NULL, impute = FALSE, finemapping_method = "susie_rss", @@ -1652,7 +1673,7 @@ test_that("rss: dentist QC method generates correct method name", { result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(ref_panel = ss), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "dentist", impute = TRUE, finemapping_method = "susie_rss" @@ -1713,7 +1734,7 @@ test_that("rss: diagnostics with get_susie_result returning NULL => diagnostics result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -1760,7 +1781,7 @@ test_that("rss: diagnostics with null/empty block_cs_metrics => no additional an result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -1845,7 +1866,7 @@ test_that("rss: diagnostics with 2 CS but low p-value and low corr => no extra a result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -1919,7 +1940,7 @@ test_that("rss: diagnostics with high max_cs_corr_study_block triggers BCR+SER", result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(), + LD_data = make_test_ld_data(ss$variant_id), qc_method = "slalom", impute = FALSE, diagnostics = TRUE, @@ -2009,11 +2030,34 @@ test_that("rss: is_genotype=TRUE path does not precompute R and uses X for fine- colnames(X_geno) <- ss$variant_id fake_result <- make_fake_post_result(5) + # Create LDData with genotype_handle to trigger the genotype path + ref_panel <- pecotmr:::parse_variant_id(ss$variant_id) + ref_panel$variant_id <- ss$variant_id + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + bm <- data.frame( + block_id = 1L, chrom = as.character(ref_panel$chrom[1]), + block_start = min(ref_panel$pos), block_end = max(ref_panel$pos), + size = 5L, start_idx = 1L, end_idx = 5L, stringsAsFactors = FALSE + ) + geno_ld <- LDData( + correlation = NULL, genotype_handle = "fake_handle", snp_idx = 1:5, + variants = variants_gr, block_metadata = bm, n_ref = 20L + ) + compute_LD_called <- FALSE susie_X_arg <- NULL susie_LD_arg <- "unset" + # Mock extractBlockGenotypes to return a fake SummarizedExperiment-like object + # that getGenotypes can use. We mock getGenotypes directly at package level. local_mocked_bindings( + extractBlockGenotypes = function(handle, snp_idx, ...) { + # Return a fake object that assay() can handle + se <- SummarizedExperiment::SummarizedExperiment( + assays = list(dosage = t(X_geno)) + ) + se + }, compute_LD = function(X, method = "sample") { compute_LD_called <<- TRUE stop("rss_analysis_pipeline should not precompute LD from X") @@ -2033,7 +2077,7 @@ test_that("rss: is_genotype=TRUE path does not precompute R and uses X for fine- result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(LD_matrix = X_geno, is_genotype = TRUE, ref_panel = ss), + LD_data = geno_ld, qc_method = "slalom", finemapping_method = "susie_rss" ) @@ -2052,10 +2096,33 @@ test_that("rss: mixture LD_data (list of X panels) preserves list shape into sus colnames(X1) <- colnames(X2) <- ss$variant_id fake_result <- make_fake_post_result(5) + # Create LDData with list of genotype handles (mixture path) + ref_panel <- pecotmr:::parse_variant_id(ss$variant_id) + ref_panel$variant_id <- ss$variant_id + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + bm <- data.frame( + block_id = 1L, chrom = as.character(ref_panel$chrom[1]), + block_start = min(ref_panel$pos), block_end = max(ref_panel$pos), + size = 5L, start_idx = 1L, end_idx = 5L, stringsAsFactors = FALSE + ) + mixture_ld <- LDData( + correlation = NULL, genotype_handle = list("fake_handle1", "fake_handle2"), + snp_idx = 1:5, variants = variants_gr, block_metadata = bm, n_ref = 20L + ) + susie_X_arg <- NULL compute_LD_called <- FALSE + extract_call_count <- 0 local_mocked_bindings( + extractBlockGenotypes = function(handle, snp_idx, ...) { + extract_call_count <<- extract_call_count + 1 + # Return appropriate X matrix based on which handle + X_mat <- if (extract_call_count == 1) X1 else X2 + SummarizedExperiment::SummarizedExperiment( + assays = list(dosage = t(X_mat)) + ) + }, compute_LD = function(X, method = "sample") { compute_LD_called <<- TRUE stop("rss_analysis_pipeline should not precompute LD from mixture X") @@ -2074,7 +2141,7 @@ test_that("rss: mixture LD_data (list of X panels) preserves list shape into sus result <- rss_analysis_pipeline( sumstat_path = "/fake/sumstats.tsv", column_file_path = "/fake/columns.yml", - LD_data = list(LD_matrix = list(X1, X2), ref_panel = ss), + LD_data = mixture_ld, qc_method = "slalom", finemapping_method = "susie_rss", impute = FALSE From 336a85928231c9ceed90fe91ba5eabf82202dc59 Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Tue, 26 May 2026 12:02:25 -0500 Subject: [PATCH 07/11] more S4 refactoring --- NAMESPACE | 6 +- R/AllClasses.R | 6 +- R/AllGenerics.R | 14 ++ R/AllMethods.R | 17 +- R/LD.R | 31 ++- R/colocboost_pipeline.R | 21 +- R/example_data.R | 6 +- R/file_utils.R | 25 ++- R/mash_wrapper.R | 8 +- R/mr.R | 2 +- R/susie_wrapper.R | 18 +- R/twas.R | 201 +++++++++++------- R/twas_weights.R | 15 +- R/univariate_pipeline.R | 5 +- man/TWASWeights.Rd | 4 +- ..._input.Rd => dot-legacy_list_to_LDData.Rd} | 6 +- man/format_finemapping_output.Rd | 11 +- man/getDataType.Rd | 20 ++ man/getMolecularId.Rd | 20 ++ tests/testthat/test_multivariate_pipeline.R | 7 +- tests/testthat/test_susie_wrapper.R | 74 ++++--- tests/testthat/test_twas.R | 106 ++++----- tests/testthat/test_twas_sketch.R | 50 +++-- tests/testthat/test_univariate_pipeline.R | 46 ++-- 24 files changed, 453 insertions(+), 266 deletions(-) rename man/{region_data_to_rss_input.Rd => dot-legacy_list_to_LDData.Rd} (79%) create mode 100644 man/getDataType.Rd create mode 100644 man/getMolecularId.Rd diff --git a/NAMESPACE b/NAMESPACE index 879f237a..7fa08257 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,6 +6,7 @@ S3method(postprocess_finemapping_fit,susiF) S3method(postprocess_finemapping_fit,susie) S3method(postprocess_finemapping_fit,susie_inf) S3method(postprocess_finemapping_fit,susie_rss) +export(.legacy_list_to_LDData) export(AnnotationMatrix) export(FineMappingResult) export(GWASSumStats) @@ -77,6 +78,7 @@ export(getCS) export(getCVPerformance) export(getCandidates) export(getCorrelation) +export(getDataType) export(getEffects) export(getEnrichment) export(getFits) @@ -85,6 +87,7 @@ export(getLBF) export(getLocal) export(getMaf) export(getMethodNames) +export(getMolecularId) export(getN) export(getPIP) export(getRefPanel) @@ -177,7 +180,6 @@ export(read_afreq) export(read_sldsc_trait) export(region_data_to_colocboost_input) export(region_data_to_ind_input) -export(region_data_to_rss_input) export(region_to_df) export(regions_overlap) export(robust_mahalanobis) @@ -240,6 +242,7 @@ exportMethods(getBlockMetadata) exportMethods(getCS) exportMethods(getCVPerformance) exportMethods(getCorrelation) +exportMethods(getDataType) exportMethods(getEffects) exportMethods(getEnrichment) exportMethods(getFits) @@ -248,6 +251,7 @@ exportMethods(getLBF) exportMethods(getLocal) exportMethods(getMaf) exportMethods(getMethodNames) +exportMethods(getMolecularId) exportMethods(getN) exportMethods(getPIP) exportMethods(getRefPanel) diff --git a/R/AllClasses.R b/R/AllClasses.R index cbb1f83d..bef71229 100644 --- a/R/AllClasses.R +++ b/R/AllClasses.R @@ -416,7 +416,9 @@ setClass("TWASWeights", methods = "character", fits = "ANY", # list or NULL cv_performance = "ANY", # list or NULL - standardized = "logical" + standardized = "logical", + molecular_id = "character", # gene/molecule name (length 0 or 1) + data_type = "ANY" # named list of data types per context, or NULL ), validity = function(object) { errors <- character() @@ -582,6 +584,8 @@ setMethod("show", "FineMappingResult", function(object) { setMethod("show", "TWASWeights", function(object) { cat(sprintf("TWASWeights: %d methods, %d variants\n", length(object@methods), length(object@variant_ids))) + if (length(object@molecular_id) > 0) + cat(sprintf(" Molecular ID: %s\n", object@molecular_id)) cat(sprintf(" Methods: %s\n", paste(object@methods, collapse = ", "))) cat(sprintf(" Standardized: %s\n", object@standardized)) has_cv <- !is.null(object@cv_performance) diff --git a/R/AllGenerics.R b/R/AllGenerics.R index 23cfdd36..d8cfc526 100644 --- a/R/AllGenerics.R +++ b/R/AllGenerics.R @@ -379,6 +379,20 @@ setGeneric("getFits", #' @export setGeneric("getMethodNames", function(x) standardGeneric("getMethodNames")) +#' @title Get Molecular ID +#' @description Extract molecular/gene identifier from a TWASWeights object. +#' @param x A \code{TWASWeights} object. +#' @return Character string (length 0 or 1). +#' @export +setGeneric("getMolecularId", function(x) standardGeneric("getMolecularId")) + +#' @title Get Data Type +#' @description Extract data type metadata from a TWASWeights object. +#' @param x A \code{TWASWeights} object. +#' @return A named list of data types per context, or NULL. +#' @export +setGeneric("getDataType", function(x) standardGeneric("getDataType")) + # ============================================================================= # VCF/BCF writer generic # ============================================================================= diff --git a/R/AllMethods.R b/R/AllMethods.R index 1929ed3c..856ca807 100644 --- a/R/AllMethods.R +++ b/R/AllMethods.R @@ -57,6 +57,8 @@ setMethod("getCorrelation", "LDData", function(x) { #' @export setMethod("getGenotypes", "LDData", function(x) { if (is.null(x@genotype_handle)) return(NULL) + # Plain matrix stored directly (e.g. from load_ld_sketch after filtering) + if (is.matrix(x@genotype_handle)) return(x@genotype_handle) if (is.list(x@genotype_handle)) { lapply(x@genotype_handle, function(h) { geno <- extractBlockGenotypes(h, x@snp_idx) @@ -373,14 +375,17 @@ setMethod("getEffects", "FineMappingResult", function(x) { #' @return A \code{TWASWeights} object. #' @export TWASWeights <- function(weights, variant_ids, fits = NULL, - cv_performance = NULL, standardized = FALSE) { + cv_performance = NULL, standardized = FALSE, + molecular_id = character(0), data_type = NULL) { new("TWASWeights", weights = weights, variant_ids = variant_ids, methods = names(weights), fits = fits, cv_performance = cv_performance, - standardized = standardized + standardized = standardized, + molecular_id = molecular_id, + data_type = data_type ) } @@ -420,6 +425,14 @@ setMethod("getMethodNames", "TWASWeights", function(x) x@methods) #' @export setMethod("getVariantIds", "TWASWeights", function(x) x@variant_ids) +#' @rdname getMolecularId +#' @export +setMethod("getMolecularId", "TWASWeights", function(x) x@molecular_id) + +#' @rdname getDataType +#' @export +setMethod("getDataType", "TWASWeights", function(x) x@data_type) + # ============================================================================= # FineMappingResult additional accessors # ============================================================================= diff --git a/R/LD.R b/R/LD.R index b6ff54c9..e63c4032 100644 --- a/R/LD.R +++ b/R/LD.R @@ -525,15 +525,10 @@ standardize_genotype_hwe <- function(X, allele_freq) { #' @param n_sample Optional original panel sample size for computing variance #' (= 2*p*(1-p)*n/(n-1)). Passed through to \code{load_LD_matrix()}. #' -#' @return A list with: -#' \describe{ -#' \item{X}{Raw genotype matrix (n_sketch x p) after removing monomorphic variants.} -#' \item{n_sketch}{Number of rows (samples) in the sketch genotype matrix.} -#' \item{ref_panel}{Data.frame with variant metadata (chrom, pos, A2, A1, variant_id, -#' allele_freq, and optionally variance, n_nomiss).} -#' \item{variant_ids}{Character vector of variant IDs (canonical format) after -#' removing monomorphic variants.} -#' } +#' @return An \code{LDData} S4 object with monomorphic variants removed. +#' Consumers should use S4 accessors: \code{getGenotypes()}, \code{getRefPanel()}, +#' \code{getVariantIds()}. The number of sketch samples is +#' \code{nrow(getGenotypes(result))}. #' @export load_ld_sketch <- function(ld_meta_file_path, region, n_sample = NULL) { result <- load_LD_matrix(ld_meta_file_path, region, return_genotype = TRUE, n_sample = n_sample) @@ -541,7 +536,6 @@ load_ld_sketch <- function(ld_meta_file_path, region, n_sample = NULL) { stop("load_LD_matrix must return an LDData object") } X <- getGenotypes(result) - variant_ids <- getVariantIds(result) ref_panel <- getRefPanel(result) # Remove monomorphic variants (zero variance under HWE) @@ -549,15 +543,20 @@ load_ld_sketch <- function(ld_meta_file_path, region, n_sample = NULL) { polymorphic <- p > 0 & p < 1 if (!all(polymorphic)) { X <- X[, polymorphic, drop = FALSE] - variant_ids <- variant_ids[polymorphic] ref_panel <- ref_panel[polymorphic, , drop = FALSE] } - list( - X = X, - n_sketch = nrow(X), - ref_panel = ref_panel, - variant_ids = variant_ids + # Rebuild LDData with the extracted (and filtered) genotype matrix stored + # directly in genotype_handle so getGenotypes() returns it without needing + # the original file handle. + variants_gr <- .ref_panel_to_granges(ref_panel) + LDData( + correlation = NULL, + genotype_handle = X, + snp_idx = NULL, + variants = variants_gr, + block_metadata = getBlockMetadata(result), + n_ref = result@n_ref ) } diff --git a/R/colocboost_pipeline.R b/R/colocboost_pipeline.R index 5c7ee1ce..87b2cc50 100644 --- a/R/colocboost_pipeline.R +++ b/R/colocboost_pipeline.R @@ -66,21 +66,12 @@ region_data_to_colocboost_input <- function(region_data) { ind_records <- ind_records_from_input(ind_input) ind_args <- .cb_format_individual(ind_records) - # Build sumstat_records using original LD_info to preserve genotype vs LD - # distinction. region_data_to_rss_input converts to LDData S4 (computing R), - # but colocboost needs the original matrix for X_ref/LD formatting. - orig_ld_info <- region_data$sumstat_data$LD_info - sumstat_records <- lapply(seq_along(rss_input$rss_input), function(i) { - study <- names(rss_input$rss_input)[i] - ld_idx <- min(i, length(orig_ld_info)) - orig_ld <- orig_ld_info[[ld_idx]] - ld_mat <- if (is(orig_ld, "LDData")) { - getCorrelation(orig_ld) - } else { - getCorrelation(rss_input$LD_data[[study]]) - } + # Build sumstat_records from rss_input which already contains LDData S4 + # objects (region_data_to_rss_input converts any legacy lists). + sumstat_records <- lapply(names(rss_input$rss_input), function(study) { + ld_data <- rss_input$LD_data[[study]] list(rss_input = rss_input$rss_input[[study]], - LD_matrix = ld_mat) + LD_matrix = getCorrelation(ld_data)) }) names(sumstat_records) <- names(rss_input$rss_input) sumstat_args <- .cb_format_sumstat(sumstat_records) @@ -1222,7 +1213,7 @@ qc_individual_data <- function(X, Y, maf = NULL, X_variance = NULL, message("QC track: LD/X_ref names are parseable for summary-stat study ", study, ".") } else if (is_ld_data(ref_info)) { message("QC track: using supplied LD_reference_info LD data for summary-stat study ", study, ".") - ld_data <- if (is(ref_info, "LDData")) ref_info else .legacy_list_to_LDData(ref_info) + ld_data <- ref_info } else { message("QC track: using supplied LD_reference_info variant metadata for summary-stat study ", study, ".") ld_data <- .cb_make_ld_data( diff --git a/R/example_data.R b/R/example_data.R index b7f7830b..e075f82b 100644 --- a/R/example_data.R +++ b/R/example_data.R @@ -109,9 +109,11 @@ NULL #' #' \describe{ #' \item{susie_result_trimmed}{List. Trimmed SuSiE result with elements -#' \code{alpha}, \code{pip}, \code{V}, and \code{sets}.} +#' \code{alpha}, \code{pip}, \code{V}, and \code{sets}. (Legacy key; new +#' code should use the \code{FineMappingResult} S4 object.)} #' \item{variant_names}{Character vector (length 2,828). Synthetic variant -#' identifiers matching the variant names in the SuSiE result.} +#' identifiers matching the variant names in the SuSiE result. +#' (Legacy key; new code should use the \code{FineMappingResult} S4 object.)} #' } #' #' @keywords data diff --git a/R/file_utils.R b/R/file_utils.R index c7b23e6c..3e01b005 100644 --- a/R/file_utils.R +++ b/R/file_utils.R @@ -1319,7 +1319,30 @@ load_twas_weights <- function(weight_db_files, conditions = NULL, get_nested_element(combined_all_data, c(condition, "twas_cv_result", "performance")) }) names(performance_tables) <- conditions - return(list(susie_results = combined_susie_result, weights = weights, twas_cv_performance = performance_tables)) + # Extract variant_ids from weight matrices (union across all contexts) + all_variant_ids <- Reduce(union, lapply(weights, function(w) { + if (is.matrix(w)) rownames(w) else names(w) + })) + if (is.null(all_variant_ids)) all_variant_ids <- character(0) + # Pad weight matrices to common variant set + if (length(all_variant_ids) > 0) { + weights <- lapply(weights, function(w) { + if (!is.matrix(w)) return(w) + missing <- setdiff(all_variant_ids, rownames(w)) + if (length(missing) > 0) { + pad <- matrix(0, nrow = length(missing), ncol = ncol(w), + dimnames = list(missing, colnames(w))) + w <- rbind(w, pad)[all_variant_ids, , drop = FALSE] + } + w + }) + } + return(TWASWeights( + weights = weights, + variant_ids = all_variant_ids, + fits = combined_susie_result, + cv_performance = performance_tables + )) }, silent = FALSE ) diff --git a/R/mash_wrapper.R b/R/mash_wrapper.R index 88a69d6e..c7801089 100644 --- a/R/mash_wrapper.R +++ b/R/mash_wrapper.R @@ -1100,8 +1100,12 @@ extract_flatten_sumstats_from_nested <- function(data, extract_inf = "z", max_de } if (is.list(element)) { - if (all(c("variant_names", "sumstats") %in% names(element))) { - variant_names <- element$variant_names + # Extract variant_names from FineMappingResult S4 or legacy list key + has_fm <- !is.null(element$finemapping_result) && is(element$finemapping_result, "FineMappingResult") + has_legacy_vn <- "variant_names" %in% names(element) + has_sumstats <- "sumstats" %in% names(element) + if (has_sumstats && (has_fm || has_legacy_vn)) { + variant_names <- if (has_fm) getVariantNames(element$finemapping_result) else element$variant_names sumstats <- element$sumstats # Extract based on type diff --git a/R/mr.R b/R/mr.R index fba96e7c..233ad30b 100644 --- a/R/mr.R +++ b/R/mr.R @@ -25,7 +25,7 @@ calc_I2 <- function(Q, Est) { #' #' Description of what the function does. #' -#' @param susie_result A list containing the results of SuSiE analysis. This list should include nested elements such as 'susie_results', 'susie_result_trimmed', and 'top_loci', containing details about the statistical analysis of genetic variants. +#' @param susie_result A list containing the results of SuSiE analysis. This list should include nested elements such as 'susie_results', 'finemapping_result' (a FineMappingResult S4 object), and 'top_loci', containing details about the statistical analysis of genetic variants. #' @param condition A character string specifying the conditions. This is used to select the corresponding subset of results within 'susie_result'. #' @param gwas_sumstats_db A data frame containing summary statistics from GWAS studies. It should include columns for variant id and their associated statistics such as beta coefficients and standard errors. #' @param coverage A character string specifying the credible set column. If diff --git a/R/susie_wrapper.R b/R/susie_wrapper.R index fc667323..ec5b95cb 100644 --- a/R/susie_wrapper.R +++ b/R/susie_wrapper.R @@ -349,10 +349,7 @@ postprocess_finemapping_fit.susiF <- function(fit, method = "fsusie", cs_input = sumstats = sumstats ) - # Also return as list for backwards compatibility with existing consumers res <- list( - variant_names = variant_names, - result_trimmed = trimmed, top_loci = top_loci, finemapping_result = fm_result ) @@ -786,25 +783,24 @@ trim_finemapping_fit <- function(fit, effect_idx, method, cs_tables) { #' Format Fine-mapping Post-processing for Protocol Output #' #' Converts method-aware fine-mapping post-processing output into the root-level -#' fields consumed by protocol RDS files. Exposes the single 22-column unified -#' \code{top_loci} table alongside \code{susie_result_trimmed}, -#' \code{variant_names}, and method-specific intermediates. +#' fields consumed by protocol RDS files. The primary method's +#' \code{FineMappingResult} S4 object is promoted to the \code{finemapping_result} +#' field; use its accessors (\code{getTrimmedFit}, \code{getVariantNames}, +#' \code{getTopLoci}, etc.) instead of legacy list keys. #' #' @param post Output from \code{\link{postprocess_finemapping_fits}}. #' @param primary_method Method whose result should populate root-level fields. -#' @return A list with root-level fields including \code{variant_names}, -#' \code{susie_result_trimmed}, and \code{top_loci}. +#' @return A list with root-level fields including \code{finemapping_result} +#' and \code{top_loci}. #' @export format_finemapping_output <- function(post, primary_method) { method_post <- post$finemapping_results[[primary_method]] if (is.null(method_post)) { stop("primary_method was not found in finemapping_results: ", primary_method) } - keep_names <- setdiff(names(method_post), c("result_trimmed", "top_loci")) c( - method_post[keep_names], + method_post, list( - susie_result_trimmed = method_post$result_trimmed, top_loci = post$top_loci ) ) diff --git a/R/twas.R b/R/twas.R index e4e4f85d..6834f236 100644 --- a/R/twas.R +++ b/R/twas.R @@ -29,16 +29,21 @@ #' @export harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, ld_reference_sample_size, column_file_path = NULL, comment_string = "#") { - # Step 1: load TWAS weights data + # Step 1: Normalize twas_weights_data -- accept bare TWASWeights or wrapper lists molecular_ids <- names(twas_weights_data) - # Each element is either a TWASWeights S4 or a list with $twas_weights (TWASWeights S4) - .get_tw <- function(mol_data) { - if (is(mol_data, "TWASWeights")) return(mol_data) - if (is.list(mol_data) && is(mol_data$twas_weights, "TWASWeights")) return(mol_data$twas_weights) - stop("Each element of twas_weights_data must be a TWASWeights S4 object ", - "or a list with a $twas_weights TWASWeights element") + for (mol_id in molecular_ids) { + entry <- twas_weights_data[[mol_id]] + if (is(entry, "TWASWeights")) { + # Already a bare TWASWeights, use directly + } else if (is.list(entry) && is(entry$twas_weights, "TWASWeights")) { + # Wrapper list -- extract the TWASWeights + twas_weights_data[[mol_id]] <- entry$twas_weights + } else { + stop("Each element of twas_weights_data must be a TWASWeights S4 object ", + "or a list with a $twas_weights TWASWeights element") + } } - first_tw <- .get_tw(twas_weights_data[[1]]) + first_tw <- twas_weights_data[[1]] chrom <- as.integer(parse_number(gsub(":.*$", "", getVariantIds(first_tw)[1]))) gwas_meta_df <- as.data.frame(vroom(gwas_meta_file)) gwas_files <- unique(gwas_meta_df$file_path[gwas_meta_df$chrom == chrom]) @@ -47,10 +52,9 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, # Per-gene loop: each gene loads its own LD sketch independently for (molecular_id in molecular_ids) { - mol_entry <- twas_weights_data[[molecular_id]] - tw <- .get_tw(mol_entry) + tw <- twas_weights_data[[molecular_id]] mol_res <- list(chrom = chrom, variant_names = list()) - mol_res[["data_type"]] <- if (is.list(mol_entry) && "data_type" %in% names(mol_entry)) mol_entry$data_type + mol_res[["data_type"]] <- getDataType(tw) contexts <- getMethodNames(tw) # Step 2: Build gene window from all contexts' variant positions @@ -60,14 +64,18 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, # Step 3: Load LD sketch for this gene's window and compute SVD sketch <- load_ld_sketch(ld_meta_file_path, gene_region, n_sample = ld_reference_sample_size) - X_std <- standardize_genotype_hwe(sketch$X, sketch$ref_panel$allele_freq) + sketch_X <- getGenotypes(sketch) + sketch_ref_panel <- getRefPanel(sketch) + sketch_variant_ids <- getVariantIds(sketch) + sketch_n <- nrow(sketch_X) + X_std <- standardize_genotype_hwe(sketch_X, sketch_ref_panel$allele_freq) svd_result <- safe_svd(X_std, tol = 0) # Step 4: Harmonize GWAS and weights against sketch variants for (study in names(gwas_files)) { gwas_file <- gwas_files[study] gwas_data_sumstats <- harmonize_gwas(gwas_file, query_region = gene_region, - sketch$variant_ids, c("beta", "z"), + sketch_variant_ids, c("beta", "z"), match_min_prop = 0, column_file_path = column_file_path, comment_string = comment_string) if (is.null(gwas_data_sumstats)) next @@ -78,7 +86,7 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, # Harmonize weights against sketch reference weights_matrix <- cbind(variant_id_to_df(rownames(weights_matrix)), weights_matrix) - weights_matrix_qced <- match_ref_panel(weights_matrix, sketch$variant_ids, + weights_matrix_qced <- match_ref_panel(weights_matrix, sketch_variant_ids, colnames(weights_matrix)[!colnames(weights_matrix) %in% c("chrom", "pos", "A2", "A1")], match_min_prop = 0 ) @@ -89,7 +97,7 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, rownames(weights_matrix_subset) <- qced_data$variant_id # Ensure consistent chr prefix convention before intersecting - chr_matched <- ensure_chr_match(gwas_data_sumstats$variant_id, sketch$variant_ids) + chr_matched <- ensure_chr_match(gwas_data_sumstats$variant_id, sketch_variant_ids) gwas_data_sumstats$variant_id <- chr_matched$ids_a rownames(weights_matrix_subset) <- ensure_chr_match(rownames(weights_matrix_subset), gwas_data_sumstats$variant_id)$ids_a weights_matrix_subset <- weights_matrix_subset[rownames(weights_matrix_subset) %in% gwas_data_sumstats$variant_id, , drop = FALSE] @@ -99,8 +107,12 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, # Step 5: adjust SuSiE weights based on available variants tw_weights_ctx <- getWeights(tw, context) if ("susie_weights" %in% colnames(tw_weights_ctx)) { - # For adjust_susie_weights, we need the fits (susie_results) - mol_data_for_adjust <- if (is.list(mol_entry) && !is(mol_entry, "TWASWeights")) mol_entry else list(twas_weights = tw) + # For adjust_susie_weights, wrap TWASWeights in the list format it expects + mol_data_for_adjust <- list( + susie_results = getFits(tw), + weights = getWeights(tw), + variant_names = lapply(getWeights(tw), function(w) if (is.matrix(w)) rownames(w) else names(w)) + ) adjusted_susie_weights <- adjust_susie_weights(mol_data_for_adjust, keep_variants = postqc_weight_variants, run_allele_qc = TRUE, variable_name_obj = c("variant_names", context), @@ -111,19 +123,15 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, susie_weights = setNames(adjusted_susie_weights$adjusted_susie_weights, adjusted_susie_weights$remained_variants_ids), weights_matrix_subset[adjusted_susie_weights$remained_variants_ids, !colnames(weights_matrix_subset) %in% "susie_weights", drop = FALSE] ) - susie_results <- if (is.list(mol_entry) && "susie_results" %in% names(mol_entry)) { - mol_entry$susie_results[[context]] - } else { - getFits(tw, context) - } + susie_results <- getFits(tw, context) susie_intermediate <- susie_results[c("pip", "cs_variants", "cs_purity")] names(susie_intermediate[["pip"]]) <- original_weight_variants # original variants not yet qced pip <- susie_intermediate[["pip"]] - pip_qced <- match_ref_panel(cbind(parse_variant_id(names(pip)), pip), sketch$variant_ids, "pip", match_min_prop = 0) + pip_qced <- match_ref_panel(cbind(parse_variant_id(names(pip)), pip), sketch_variant_ids, "pip", match_min_prop = 0) susie_intermediate[["pip"]] <- abs(pip_qced$target_data_qced$pip) names(susie_intermediate[["pip"]]) <- pip_qced$target_data_qced$variant_id susie_intermediate[["cs_variants"]] <- lapply(susie_intermediate[["cs_variants"]], function(x) { - variant_qc <- match_ref_panel(x, sketch$variant_ids, match_min_prop = 0) + variant_qc <- match_ref_panel(x, sketch_variant_ids, match_min_prop = 0) variant_qc$target_data_qced$variant_id[variant_qc$target_data_qced$variant_id %in% postqc_weight_variants] }) mol_res[["susie_weights_intermediate_qced"]][[context]] <- susie_intermediate @@ -143,7 +151,7 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, if (is_standardized) { scaled <- weights_matrix_subset } else { - variance <- sketch$ref_panel$variance[match(rownames(weights_matrix_subset), sketch$ref_panel$variant_id)] + variance <- sketch_ref_panel$variance[match(rownames(weights_matrix_subset), sketch_ref_panel$variant_id)] scaled <- weights_matrix_subset * sqrt(variance) } mol_res[["weights_qced"]][[context]][[study]] <- list(scaled_weights = scaled, weights = weights_matrix_subset) @@ -165,12 +173,12 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, } else { mol_res[["svd_V"]] <- svd_result$v mol_res[["svd_D"]] <- svd_result$d - mol_res[["n_sketch"]] <- sketch$n_sketch - mol_res[["ld_variant_ids"]] <- sketch$variant_ids + mol_res[["n_sketch"]] <- sketch_n + mol_res[["ld_variant_ids"]] <- sketch_variant_ids results[[molecular_id]] <- mol_res } } - return(list(twas_data_qced = results, ref_panel = sketch$ref_panel)) + return(list(twas_data_qced = results, ref_panel = sketch_ref_panel)) } #' Harmonize GWAS Summary Statistics @@ -368,22 +376,23 @@ twas_pipeline <- function(twas_weights_data, return(NULL) } } - pick_best_model <- function(twas_data_combined, rsq_cutoff, rsq_pval_cutoff, rsq_option, rsq_pval_option) { + pick_best_model <- function(tw, molecular_id, rsq_cutoff, rsq_pval_cutoff, rsq_option, rsq_pval_option) { best_rsq <- rsq_cutoff + cv_perf <- getCVPerformance(tw) + method_names <- getMethodNames(tw) # SS-TWAS path: no CV performance, all methods are valid - if (is.null(twas_data_combined$twas_cv_performance) || - length(twas_data_combined$twas_cv_performance) == 0) { - model_selection <- lapply(names(twas_data_combined$weights), function(context) { + if (is.null(cv_perf) || length(cv_perf) == 0) { + model_selection <- lapply(method_names, function(context) { list(selected_model = NA, is_imputable = TRUE, all_methods = TRUE) }) - names(model_selection) <- names(twas_data_combined$weights) + names(model_selection) <- method_names return(model_selection) } # Determine if a gene/region is imputable and select the best model - model_selection <- lapply(names(twas_data_combined$weights), function(context) { + model_selection <- lapply(method_names, function(context) { selected_model <- NULL - available_models <- do.call(c, lapply(names(twas_data_combined$twas_cv_performance[[context]]), function(model) { - if (!is.na(twas_data_combined$twas_cv_performance[[context]][[model]][, rsq_option])) { + available_models <- do.call(c, lapply(names(cv_perf[[context]]), function(model) { + if (!is.na(cv_perf[[context]][[model]][, rsq_option])) { return(model) } })) @@ -392,7 +401,7 @@ twas_pipeline <- function(twas_weights_data, return(NULL) } for (model in available_models) { - model_data <- twas_data_combined$twas_cv_performance[[context]][[model]] + model_data <- cv_perf[[context]][[model]] if (model_data[, rsq_option] >= best_rsq & model_data[, colnames(model_data)[which(colnames(model_data) %in% rsq_pval_option)]] < rsq_pval_cutoff) { best_rsq <- model_data[, rsq_option] selected_model <- model @@ -401,32 +410,74 @@ twas_pipeline <- function(twas_weights_data, if (is.null(selected_model)) { message(paste0( "No model has p-value < ", rsq_pval_cutoff, " and r2 >= ", rsq_cutoff, ", skipping context ", context, - " at region ", unique(twas_data_combined$molecular_id), ". " + " at region ", molecular_id, ". " )) return(list(selected_model = c("context_non_imputable"), is_imputable = FALSE)) # No significant model found } else { selected_model <- unlist(strsplit(selected_model, "_performance")) - message(paste0("The selected best performing model for context ", context, " at region ", twas_data_combined$molecular_id, " is ", selected_model, ". ")) + message(paste0("The selected best performing model for context ", context, " at region ", molecular_id, " is ", selected_model, ". ")) return(list(selected_model = selected_model, is_imputable = TRUE)) } }) - names(model_selection) <- names(twas_data_combined$weights) + names(model_selection) <- method_names return(model_selection) } # Step 1: TWAS and MR analysis for all methods for imputable gene rsq_option <- match.arg(rsq_option) - # filter events - if (!is.null(event_filters)){ - for (weight_db in names(twas_weights_data)){ - contexts <- names(twas_weights_data[[weight_db]]$weights) - filtered_events <- filter_molecular_events(contexts, event_filters, remove_all_group=TRUE) - if (length(filtered_events!=0)){ - for (db in names(twas_weights_data[[weight_db]])){ - twas_weights_data[[weight_db]][[db]] <- twas_weights_data[[weight_db]][[db]][filtered_events] - } + + # Normalize twas_weights_data entries to TWASWeights S4 + for (wdb in names(twas_weights_data)) { + entry <- twas_weights_data[[wdb]] + if (is(entry, "TWASWeights")) next + if (is.list(entry) && is(entry[["twas_weights"]], "TWASWeights")) { + # Wrapper list with $twas_weights — unwrap but merge metadata into S4 + tw_inner <- entry[["twas_weights"]] + twas_weights_data[[wdb]] <- TWASWeights( + weights = getWeights(tw_inner), + variant_ids = getVariantIds(tw_inner), + fits = getFits(tw_inner), + cv_performance = getCVPerformance(tw_inner), + standardized = getStandardized(tw_inner), + molecular_id = if (!is.null(entry[["molecular_id"]])) entry[["molecular_id"]] else getMolecularId(tw_inner), + data_type = if (!is.null(entry[["data_type"]])) entry[["data_type"]] else getDataType(tw_inner) + ) + } else if (is.list(entry) && !is.null(entry[["weights"]])) { + # Legacy list from load_twas_weights or test fixtures + wts <- entry[["weights"]] + vid <- if (!is.null(names(wts)) && length(wts) > 0 && !is.null(rownames(wts[[1]]))) { + Reduce(union, lapply(wts, rownames)) + } else character(0) + twas_weights_data[[wdb]] <- TWASWeights( + weights = wts, + variant_ids = vid, + fits = entry[["susie_results"]], + cv_performance = entry[["twas_cv_performance"]], + molecular_id = if (!is.null(entry[["molecular_id"]])) entry[["molecular_id"]] else character(0), + data_type = entry[["data_type"]] + ) + } + } + + # filter events + if (!is.null(event_filters)) { + for (weight_db in names(twas_weights_data)) { + tw <- twas_weights_data[[weight_db]] + contexts <- getMethodNames(tw) + filtered_events <- filter_molecular_events(contexts, event_filters, remove_all_group = TRUE) + if (length(filtered_events) != 0) { + # Rebuild TWASWeights with only the filtered contexts + twas_weights_data[[weight_db]] <- TWASWeights( + weights = getWeights(tw)[filtered_events], + variant_ids = getVariantIds(tw), + fits = if (!is.null(getFits(tw))) getFits(tw)[intersect(filtered_events, names(getFits(tw)))] else NULL, + cv_performance = if (!is.null(getCVPerformance(tw))) getCVPerformance(tw)[intersect(filtered_events, names(getCVPerformance(tw)))] else NULL, + standardized = getStandardized(tw), + molecular_id = getMolecularId(tw), + data_type = getDataType(tw) + ) } else { - twas_weights_data[[weight_db]] <- NULL + twas_weights_data[[weight_db]] <- NULL } } } @@ -439,7 +490,10 @@ twas_pipeline <- function(twas_weights_data, ld_reference_sample_size = ld_reference_sample_size, column_file_path = column_file_path, comment_string = comment_string) twas_results_db <- lapply(names(twas_weights_data), function(weight_db) { - twas_weights_data[[weight_db]][["molecular_id"]] <- weight_db + tw <- twas_weights_data[[weight_db]] + tw_methods <- getMethodNames(tw) + tw_cv <- getCVPerformance(tw) + tw_fits <- getFits(tw) twas_data_qced <- twas_data_qced_result$twas_data_qced if (length(twas_data_qced[[weight_db]]) == 0 | is.null(twas_data_qced[[weight_db]])) { warning(paste0("No data harmonized for ", weight_db, ". Returning NULL for TWAS result for this region.")) @@ -448,25 +502,24 @@ twas_pipeline <- function(twas_weights_data, if (rsq_cutoff > 0) { message("Selecting the best model based on criteria...") best_model_selection <- pick_best_model( - twas_weights_data[[weight_db]], + tw, molecular_id = weight_db, rsq_cutoff = rsq_cutoff, rsq_pval_cutoff = rsq_pval_cutoff, rsq_option = rsq_option, rsq_pval_option = rsq_pval_option ) - twas_data_qced[[weight_db]][["model_selection"]] <- setNames(best_model_selection, names(twas_weights_data[[weight_db]]$weights)) + twas_data_qced[[weight_db]][["model_selection"]] <- setNames(best_model_selection, tw_methods) } else { message("Skipping best model selection. Assigning NA of model_selection to all weights.") twas_data_qced[[weight_db]][["model_selection"]] <- setNames( - rep(NA, length(names(twas_weights_data[[weight_db]]$weights))), - names(twas_weights_data[[weight_db]]$weights) + rep(NA, length(tw_methods)), tw_methods ) } - if (!"data_type" %in% names(twas_weights_data[[weight_db]])) { - twas_data_qced[[weight_db]][["data_type"]] <- setNames(rep( - list(NA), - length(names(twas_weights_data[[weight_db]]$weights)) - ), names(twas_weights_data[[weight_db]]$weights)) + dt <- getDataType(tw) + if (is.null(dt)) { + twas_data_qced[[weight_db]][["data_type"]] <- setNames( + rep(list(NA), length(tw_methods)), tw_methods + ) } if (length(weight_db) < 1) stop(paste0("No data harmonized for ", weight_db, ". ")) contexts <- names(twas_data_qced[[weight_db]][["weights_qced"]]) @@ -478,7 +531,7 @@ twas_pipeline <- function(twas_weights_data, # Nested lapply for contexts and gwas studies twas_gene_results <- lapply(contexts, function(context) { study_results <- lapply(gwas_studies, function(study) { - twas_variants <- Reduce(intersect, list(rownames(twas_data_qced[[weight_db]][["weights_qced"]][[context]][[study]][["weights"]]), + twas_variants <- Reduce(intersect, list(rownames(twas_data_qced[[weight_db]][["weights_qced"]][[context]][[study]][["weights"]]), twas_data_qced[[weight_db]][["variant_names"]][[context]][[study]], twas_data_qced[[weight_db]][["gwas_qced"]][[study]]$variant_id) ) @@ -486,8 +539,7 @@ twas_pipeline <- function(twas_weights_data, return(list(twas_rs_df = data.frame(), mr_rs_df = data.frame())) } # twas analysis -- enable omnibus when no CV performance available - has_cv <- !is.null(twas_weights_data[[weight_db]]$twas_cv_performance) && - length(twas_weights_data[[weight_db]]$twas_cv_performance) > 0 + has_cv <- !is.null(tw_cv) && length(tw_cv) > 0 twas_rs <- twas_analysis( twas_data_qced[[weight_db]][["weights_qced"]][[context]][[study]][["weights"]], twas_data_qced[[weight_db]][["gwas_qced"]][[study]], @@ -503,15 +555,17 @@ twas_pipeline <- function(twas_weights_data, } twas_rs_df <- build_twas_score_row(twas_rs, weight_db, context, study) # MR analysis - if (!is.null(twas_weights_data[[weight_db]]$susie_results) && + if (!is.null(tw_fits) && any(na.omit(twas_rs_df$twas_pval) < mr_pval_cutoff) && - "top_loci" %in% names(twas_weights_data[[weight_db]]$susie_results[[context]])) { + !is.null(tw_fits[[context]]) && "top_loci" %in% names(tw_fits[[context]])) { if (!"effect_allele_frequency" %in% colnames(twas_data_qced[[weight_db]][["gwas_qced"]][[study]])) { warning(paste0("skip MR for ", weight_db, " for ", study, ", the effect_allele_frequency information is not available.")) return(list(twas_rs_df = twas_rs_df, mr_rs_df = data.frame())) } combined_ld_meta_df <- twas_data_qced_result$ref_panel - mr_formatted_input <- mr_format(twas_weights_data[[weight_db]], context, twas_data_qced[[weight_db]][["gwas_qced"]][[study]], + # mr_format expects a nested list with $molecular_id and $susie_results + mr_input <- list(molecular_id = weight_db, susie_results = tw_fits) + mr_formatted_input <- mr_format(mr_input, context, twas_data_qced[[weight_db]][["gwas_qced"]][[study]], coverage = mr_coverage_column, run_allele_qc = TRUE, method = mr_method, coverage_level = mr_coverage, molecular_name_obj = c("molecular_id"), ld_meta_df = combined_ld_meta_df @@ -560,24 +614,28 @@ twas_pipeline <- function(twas_weights_data, # Step 2: Summarize and merge twas cv results and region information for all methods for all contexts for imputable genes. twas_table <- do.call(rbind, lapply(names(twas_data), function(molecular_id) { - contexts <- names(twas_weights_data[[molecular_id]]$weights) + tw_mol <- twas_weights_data[[molecular_id]] + contexts <- getMethodNames(tw_mol) + tw_mol_cv <- getCVPerformance(tw_mol) + tw_mol_dt <- getDataType(tw_mol) # merge twas_cv information for same gene across all weight db files, loop through each context for all methods gene_table <- do.call(rbind, lapply(contexts, function(context) { - cv_perf <- twas_weights_data[[molecular_id]]$twas_cv_performance[[context]] + cv_perf <- if (!is.null(tw_mol_cv)) tw_mol_cv[[context]] else NULL model_sel <- twas_data[[molecular_id]][["model_selection"]][[context]] is_imputable <- if (!is.null(model_sel)) model_sel$is_imputable else TRUE if (is.null(cv_perf) || length(cv_perf) == 0) { # SS-TWAS path: no CV, derive methods from weight matrix columns - wt_mat <- twas_weights_data[[molecular_id]]$weights[[context]] + wt_mat <- getWeights(tw_mol, context) methods <- if (is.matrix(wt_mat)) colnames(wt_mat) else names(wt_mat) if (is.null(methods)) methods <- "unknown" + dt_val <- if (!is.null(tw_mol_dt)) tw_mol_dt[[context]] else NA context_table <- data.frame( context = context, method = methods, is_imputable = is_imputable, is_selected_method = FALSE, rsq_cv = NA_real_, pval_cv = NA_real_, - type = twas_weights_data[[molecular_id]][["data_type"]][[context]] + type = dt_val ) } else { methods <- sub("_[^_]+$", "", names(cv_perf)) @@ -588,12 +646,13 @@ twas_pipeline <- function(twas_weights_data, cv_rsqs <- sapply(cv_perf, function(x) x[, rsq_option]) cv_pvals <- sapply(cv_perf, function(x) x[, colnames(x)[which(colnames(x) %in% rsq_pval_option)]]) + dt_val <- if (!is.null(tw_mol_dt)) tw_mol_dt[[context]] else NA context_table <- data.frame( context = context, method = methods, is_imputable = is_imputable, is_selected_method = is_selected_method, rsq_cv = cv_rsqs, pval_cv = cv_pvals, - type = twas_weights_data[[molecular_id]][["data_type"]][[context]] + type = dt_val ) } return(context_table) diff --git a/R/twas_weights.R b/R/twas_weights.R index 978e0fb6..4bf72827 100644 --- a/R/twas_weights.R +++ b/R/twas_weights.R @@ -1927,19 +1927,8 @@ twas_weights_sumstat_pipeline <- function( signal_cutoff = 0.025, cs_input = "Xcorr" ) - if (!is.null(fm_output$finemapping_results$susie_rss)) { - fm_res <- fm_output$finemapping_results$susie_rss - # Use top_loci from the per-method result (has 'method' column) - # rather than the wide format from postprocess_finemapping_fits - tl <- fm_res$top_loci - if (is.null(tl) || nrow(tl) == 0) tl <- data.frame() - finemapping_result <- FineMappingResult( - variant_names = variant_ids, - trimmed_fit = fm_res$result_trimmed, - top_loci = tl, - method = "susie_rss", - sumstats = list(z = z, n = n) - ) + if (!is.null(fm_output$finemapping_results$susie_rss$finemapping_result)) { + finemapping_result <- fm_output$finemapping_results$susie_rss$finemapping_result } }, error = function(e) { warning(sprintf("Fine-mapping post-processing failed: %s", e$message)) diff --git a/R/univariate_pipeline.R b/R/univariate_pipeline.R index df1f9b00..7844f1d7 100644 --- a/R/univariate_pipeline.R +++ b/R/univariate_pipeline.R @@ -164,9 +164,8 @@ univariate_analysis_pipeline <- function( region = region ) res <- c(res, format_finemapping_output(susie_post, primary_method = "susie")) - if (!is.null(susie_post$finemapping_results$susie_inf)) { - res$susie_inf_result_trimmed <- susie_post$finemapping_results$susie_inf$result_trimmed - } + susie_inf_fm <- susie_post$finemapping_results$susie_inf$finemapping_result + res$susie_inf_result_trimmed <- if (!is.null(susie_inf_fm)) getTrimmedFit(susie_inf_fm) else NULL res$total_time_elapsed <- proc.time() - st # TWAS weights and cross-validation diff --git a/man/TWASWeights.Rd b/man/TWASWeights.Rd index b6a7b8b2..6c83f767 100644 --- a/man/TWASWeights.Rd +++ b/man/TWASWeights.Rd @@ -9,7 +9,9 @@ TWASWeights( variant_ids, fits = NULL, cv_performance = NULL, - standardized = FALSE + standardized = FALSE, + molecular_id = character(0), + data_type = NULL ) } \arguments{ diff --git a/man/region_data_to_rss_input.Rd b/man/dot-legacy_list_to_LDData.Rd similarity index 79% rename from man/region_data_to_rss_input.Rd rename to man/dot-legacy_list_to_LDData.Rd index 62e17229..04278c61 100644 --- a/man/region_data_to_rss_input.Rd +++ b/man/dot-legacy_list_to_LDData.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/file_utils.R -\name{region_data_to_rss_input} -\alias{region_data_to_rss_input} +\name{.legacy_list_to_LDData} +\alias{.legacy_list_to_LDData} \title{Convert loaded regional data to RSS inputs} \usage{ -region_data_to_rss_input(region_data) +.legacy_list_to_LDData(ld_list) } \arguments{ \item{region_data}{A list returned by \code{load_multitask_regional_data()}.} diff --git a/man/format_finemapping_output.Rd b/man/format_finemapping_output.Rd index 95e290c4..2d7afb7a 100644 --- a/man/format_finemapping_output.Rd +++ b/man/format_finemapping_output.Rd @@ -12,12 +12,13 @@ format_finemapping_output(post, primary_method) \item{primary_method}{Method whose result should populate root-level fields.} } \value{ -A list with root-level fields including \code{variant_names}, - \code{susie_result_trimmed}, and \code{top_loci}. +A list with root-level fields including \code{finemapping_result} + and \code{top_loci}. } \description{ Converts method-aware fine-mapping post-processing output into the root-level -fields consumed by protocol RDS files. Exposes the single 22-column unified -\code{top_loci} table alongside \code{susie_result_trimmed}, -\code{variant_names}, and method-specific intermediates. +fields consumed by protocol RDS files. The primary method's +\code{FineMappingResult} S4 object is promoted to the \code{finemapping_result} +field; use its accessors (\code{getTrimmedFit}, \code{getVariantNames}, +\code{getTopLoci}, etc.) instead of legacy list keys. } diff --git a/man/getDataType.Rd b/man/getDataType.Rd new file mode 100644 index 00000000..86552b0c --- /dev/null +++ b/man/getDataType.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getDataType} +\alias{getDataType} +\alias{getDataType,TWASWeights-method} +\title{Get Data Type} +\usage{ +getDataType(x) + +\S4method{getDataType}{TWASWeights}(x) +} +\arguments{ +\item{x}{A \code{TWASWeights} object.} +} +\value{ +A named list of data types per context, or NULL. +} +\description{ +Extract data type metadata from a TWASWeights object. +} diff --git a/man/getMolecularId.Rd b/man/getMolecularId.Rd new file mode 100644 index 00000000..4f4f7cfe --- /dev/null +++ b/man/getMolecularId.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getMolecularId} +\alias{getMolecularId} +\alias{getMolecularId,TWASWeights-method} +\title{Get Molecular ID} +\usage{ +getMolecularId(x) + +\S4method{getMolecularId}{TWASWeights}(x) +} +\arguments{ +\item{x}{A \code{TWASWeights} object.} +} +\value{ +Character string (length 0 or 1). +} +\description{ +Extract molecular/gene identifier from a TWASWeights object. +} diff --git a/tests/testthat/test_multivariate_pipeline.R b/tests/testthat/test_multivariate_pipeline.R index d28fa487..d3029876 100644 --- a/tests/testthat/test_multivariate_pipeline.R +++ b/tests/testthat/test_multivariate_pipeline.R @@ -536,9 +536,10 @@ test_that("pipeline propagates outcome_names from mvsusie through post-processin expect_true(is.list(result)) # outcome_names should propagate as context_names expect_equal(result$context_names, cnames) - # coef and clfsr should be present in susie_result_trimmed + # coef and clfsr should be present in the FineMappingResult trimmed fit # coef.mvsusie returns a (p + 1) x r matrix (first row is the intercept); # trim_finemapping_fit strips that intercept row before storing. - expect_equal(result$susie_result_trimmed$coef, fake_coef[-1, , drop = FALSE]) - expect_equal(dim(result$susie_result_trimmed$clfsr), c(L, p, r)) + trimmed <- getTrimmedFit(result$finemapping_result) + expect_equal(trimmed$coef, fake_coef[-1, , drop = FALSE]) + expect_equal(dim(trimmed$clfsr), c(L, p, r)) }) diff --git a/tests/testthat/test_susie_wrapper.R b/tests/testthat/test_susie_wrapper.R index 1ee3e3eb..b11cd43c 100644 --- a/tests/testthat/test_susie_wrapper.R +++ b/tests/testthat/test_susie_wrapper.R @@ -256,21 +256,20 @@ test_that("susie_rss_pipeline runs with single_effect method", { result <- susie_rss_pipeline(sumstats, R, analysis_method = "single_effect") expect_true(is.list(result)) - expect_true("variant_names" %in% names(result)) - expect_true("susie_result_trimmed" %in% names(result)) - if (!is.null(result$top_loci) && nrow(result$top_loci) > 0) { - expect_true("pip" %in% names(result$top_loci)) - expect_true("cs_95" %in% names(result$top_loci)) - expect_true(all(result$top_loci$method == "single_effect")) + expect_true("finemapping_result" %in% names(result)) + fm <- result$finemapping_result + if (!is.null(result$top_loci)) { + expect_true("pip_single_effect" %in% names(result$top_loci)) + expect_true("CS_95_single_effect" %in% names(result$top_loci)) } # PIPs should be numeric, in [0,1], and sum to at most 1 (L=1) - pip <- result$susie_result_trimmed$pip + pip <- getTrimmedFit(fm)$pip expect_true(is.numeric(pip)) expect_length(pip, n) expect_true(all(pip >= 0 & pip <= 1)) expect_true(sum(pip) <= 1 + 1e-6) # Credible sets, if any, should contain valid indices - cs_list <- result$susie_result_trimmed$sets$cs + cs_list <- getTrimmedFit(fm)$sets$cs if (!is.null(cs_list)) { for (cs in cs_list) { expect_true(all(cs >= 1 & cs <= n)) @@ -293,19 +292,19 @@ test_that("susie_rss_pipeline runs with bayesian_conditional_regression", { L = 5, L_greedy = 5 ) expect_true(is.list(result)) - expect_true("susie_result_trimmed" %in% names(result)) - if (!is.null(result$top_loci) && nrow(result$top_loci) > 0) { - expect_true("pip" %in% names(result$top_loci)) - expect_true("cs_95" %in% names(result$top_loci)) - expect_true(all(result$top_loci$method == "bayesian_conditional_regression")) + expect_true("finemapping_result" %in% names(result)) + fm <- result$finemapping_result + if (!is.null(result$top_loci)) { + expect_true("pip_bayesian_conditional_regression" %in% names(result$top_loci)) + expect_true("CS_95_bayesian_conditional_regression" %in% names(result$top_loci)) } - pip <- result$susie_result_trimmed$pip + pip <- getTrimmedFit(fm)$pip expect_true(is.numeric(pip)) expect_length(pip, n) expect_true(all(pip >= 0 & pip <= 1)) # With L=5, sum of PIPs can be up to L expect_true(sum(pip) <= 5 + 1e-6) - cs_list <- result$susie_result_trimmed$sets$cs + cs_list <- getTrimmedFit(fm)$sets$cs if (!is.null(cs_list)) { for (cs in cs_list) { expect_true(all(cs >= 1 & cs <= n)) @@ -329,13 +328,14 @@ test_that("susie_rss_pipeline uses beta/se when z not provided", { L = 5, L_greedy = 5 ) expect_true(is.list(result)) - expect_true("susie_result_trimmed" %in% names(result)) - pip <- result$susie_result_trimmed$pip + expect_true("finemapping_result" %in% names(result)) + fm <- result$finemapping_result + pip <- getTrimmedFit(fm)$pip expect_true(is.numeric(pip)) expect_length(pip, n) expect_true(all(pip >= 0 & pip <= 1)) expect_true(sum(pip) <= 5 + 1e-6) - cs_list <- result$susie_result_trimmed$sets$cs + cs_list <- getTrimmedFit(fm)$sets$cs if (!is.null(cs_list)) { for (cs in cs_list) { expect_true(all(cs >= 1 & cs <= n)) @@ -517,10 +517,11 @@ test_that("postprocess_finemapping_fits keeps all effects when V is NULL", { coverage = 0.95 ) result <- format_finemapping_output(post, primary_method = "susie_rss") + trimmed <- getTrimmedFit(result$finemapping_result) # With V=NULL, eff_idx = 1:L, so trimmed alpha should keep all L rows - expect_equal(nrow(result$susie_result_trimmed$alpha), L) + expect_equal(nrow(trimmed$alpha), L) # V should be NULL in trimmed output - expect_null(result$susie_result_trimmed$V) + expect_null(trimmed$V) }) # ============================================================================= @@ -568,10 +569,11 @@ test_that("postprocess_finemapping_fits stores outcome_names, coef, and clfsr fo # outcome_names should be stored as context_names expect_equal(result$context_names, cnames) + trimmed <- getTrimmedFit(result$finemapping_result) # coef should come from mvsusieR::coef.mvsusie - expect_equal(result$susie_result_trimmed$coef, fake_coef[-1, , drop = FALSE]) + expect_equal(trimmed$coef, fake_coef[-1, , drop = FALSE]) # conditional_lfsr should be trimmed to eff_idx - expect_equal(dim(result$susie_result_trimmed$clfsr), c(L, p, R)) + expect_equal(dim(trimmed$clfsr), c(L, p, R)) }) # ============================================================================= @@ -606,7 +608,7 @@ test_that("susie_rss_pipeline X-mode passes X to susie_rss and computes LD from captured_pp_data_x <<- data_x list() }, - format_finemapping_output = function(post, primary_method) list(variant_names = vnames) + format_finemapping_output = function(post, primary_method) list() ) result <- susie_rss_pipeline(list(z = z), X_mat = X, R_mismatch = "eb") expect_true("X" %in% names(captured_susie_args)) @@ -649,7 +651,7 @@ test_that("susie_rss_pipeline computes LD from first panel when X_mat is a list" captured_pp_data_x <<- data_x list() }, - format_finemapping_output = function(post, primary_method) list(variant_names = vnames) + format_finemapping_output = function(post, primary_method) list() ) result <- susie_rss_pipeline(list(z = z), X_mat = X_list) # data_x should be all panels stacked (rbind), not cor(X) @@ -744,6 +746,30 @@ test_that("adjust_susie_weights run_allele_qc=TRUE auto-prepends chrom/pos/A2/A1 expect_true(length(out$adjusted_susie_weights) > 0) }) +test_that("format_finemapping_output does not duplicate top loci variants", { + top_loci <- data.frame( + variant_id = paste0("v", 1:4), + CS_95_susie = c(0L, 1L, NA_integer_, 0L), + pip_susie = c(0.2, 0.005, 0.001, 0), + stringsAsFactors = FALSE + ) + fm <- FineMappingResult( + variant_names = paste0("v", 1:4), + trimmed_fit = list(pip = 1:4), + top_loci = data.frame(variant_id = character(0), method = character(0)), + method = "susie" + ) + post <- list( + finemapping_results = list(susie = list( + finemapping_result = fm + )), + top_loci = top_loci + ) + out <- format_finemapping_output(post, "susie") + expect_false("top_loci_variants" %in% names(out)) + expect_equal(unique(out$top_loci$variant_id), paste0("v", 1:4)) +}) + .make_univariate_data <- function(seed = 42, n = 300, p = 50, effect_idx = integer(0), effect_size = NULL) { set.seed(seed) diff --git a/tests/testthat/test_twas.R b/tests/testthat/test_twas.R index 330e04d7..bb42a909 100644 --- a/tests/testthat/test_twas.R +++ b/tests/testthat/test_twas.R @@ -89,6 +89,23 @@ make_twas_weights_data <- function( data } +# Build a mock LDData S4 object with a genotype matrix stored directly +# (mimics the output of load_ld_sketch after the S4 refactor). +make_mock_ld_sketch <- function(X, ref_panel, variant_ids) { + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + block_metadata <- S4Vectors::DataFrame( + region = "mock", start = 1L, end = 1L, chrom = "chr1" + ) + LDData( + correlation = NULL, + genotype_handle = X, + snp_idx = NULL, + variants = variants_gr, + block_metadata = block_metadata, + n_ref = 0L + ) +} + generate_twas_joint_z_data <- function(num_samples=10, num_snps=10, num_conditions = 5) { X <- matrix(sample(0:2, num_samples * num_snps, replace = TRUE), nrow = num_snps, ncol = num_samples) rownames(X) <- paste0("Sample", 1:num_samples) @@ -2696,7 +2713,9 @@ test_that("load_twas_weights works with null condition", { weight_db <- setup_weight_db_vector(n_rds = 2, n_cond = 4) res <- load_twas_weights(weight_db$weight_paths, conditions = NULL, variable_name_obj = c("preset_variants_result", "variant_names"), susie_obj =c("preset_variants_result", "susie_result_trimmed"),twas_weights_table = "twas_weights") - expect_true(all(c("susie_results", "weights") %in% names(res))) + expect_true(is(res, "TWASWeights")) + expect_true(length(getMethodNames(res)) > 0) + expect_false(is.null(getFits(res))) cleanup_weight_db_vector(weight_db$weight_paths) }) @@ -2705,7 +2724,9 @@ test_that("load_twas_weights works with specified condition", { weight_db <- setup_weight_db_vector(n_rds = 2, n_cond = 4, same_condition = TRUE, condition = "cond_1_joe_eQTL") res <- load_twas_weights(weight_db$weight_paths, conditions = "cond_1_joe_eQTL", variable_name_obj = c("preset_variants_result", "variant_names"), susie_obj =c("preset_variants_result", "susie_result_trimmed"),twas_weights_table = "twas_weights") - expect_true(all(c("susie_results", "weights") %in% names(res))) + expect_true(is(res, "TWASWeights")) + expect_true(length(getMethodNames(res)) > 0) + expect_false(is.null(getFits(res))) cleanup_weight_db_vector(weight_db$weight_paths) }) @@ -2805,11 +2826,7 @@ test_that("harmonize_twas: group_contexts_by_region single context path (lines 4 p_sketch <- length(mock_LD_variants) set.seed(123) X <- matrix(rbinom(n_sketch * p_sketch, 2, 0.3), nrow = n_sketch, ncol = p_sketch) - list( - X = X, - n_sketch = n_sketch, - ref_panel = mock_ref_panel, - variant_ids = mock_LD_variants + make_mock_ld_sketch(X, mock_ref_panel, mock_LD_variants ) }, get_ref_variant_info = function(...) mock_snp_info, @@ -2935,17 +2952,13 @@ test_that("harmonize_twas: group_contexts_by_region multi-context clustering (li p_sketch <- length(all_variant_ids) set.seed(123) X <- matrix(rbinom(n_sketch * p_sketch, 2, 0.3), nrow = n_sketch, ncol = p_sketch) - list( - X = X, - n_sketch = n_sketch, - ref_panel = data.frame(chrom = 1, pos = as.integer(sapply(strsplit(all_variant_ids, ":"), `[`, 2)), - A2 = "A", A1 = "T", - variant_id = all_variant_ids, - allele_freq = rep(0.3, length(all_variant_ids)), - variance = rep(1.0, length(all_variant_ids)), - stringsAsFactors = FALSE), - variant_ids = all_variant_ids - ) + ref_panel <- data.frame(chrom = 1, pos = as.integer(sapply(strsplit(all_variant_ids, ":"), `[`, 2)), + A2 = "A", A1 = "T", + variant_id = all_variant_ids, + allele_freq = rep(0.3, length(all_variant_ids)), + variance = rep(1.0, length(all_variant_ids)), + stringsAsFactors = FALSE) + make_mock_ld_sketch(X, ref_panel, all_variant_ids) }, get_ref_variant_info = function(...) mock_snp_info, harmonize_gwas = function(...) mock_gwas_data, @@ -3112,22 +3125,19 @@ test_that("twas_pipeline: missing data_type triggers assignment check on line 61 } ) - # Lines 614-618 are exercised: the function checks data_type, assigns NA list. - # But line 748 reads data_type from original twas_weights_data (where it's NULL), - # causing a data.frame dimension mismatch. - expect_error( - suppressMessages(twas_pipeline( - twas_weights_data = twas_weights_data, - ld_meta_file_path = "fake_ld.tsv", - gwas_meta_file = "fake_gwas.tsv", - region_block = "chr1_100_500", + # With S4 normalization, NULL data_type is handled cleanly (defaults to NA per context). + # The pipeline should complete without the previous dimension mismatch error. + result <- suppressMessages(twas_pipeline( + twas_weights_data = twas_weights_data, + ld_meta_file_path = "fake_ld.tsv", + gwas_meta_file = "fake_gwas.tsv", + region_block = "chr1_100_500", ld_reference_sample_size = 17000, - rsq_cutoff = 0.01, - rsq_pval_cutoff = 0.05, - rsq_pval_option = "pval" - )), - "differing number of rows" - ) + rsq_cutoff = 0.01, + rsq_pval_cutoff = 0.05, + rsq_pval_option = "pval" + )) + expect_true(is.list(result)) }) # =========================================================================== @@ -3390,19 +3400,15 @@ test_that("harmonize_twas: duplicated LD variants are removed", { n_sketch <- 50L set.seed(123) X <- matrix(rbinom(n_sketch * p, 2, 0.3), nrow = n_sketch, ncol = p) - list( - X = X, - n_sketch = n_sketch, - ref_panel = data.frame( - chrom = rep(1, p), pos = c(100, 200, 300), - A2 = rep("A", p), A1 = rep("T", p), - variant_id = variant_ids, - allele_freq = rep(0.3, p), - variance = rep(1.0, p), - stringsAsFactors = FALSE - ), - variant_ids = variant_ids + rp <- data.frame( + chrom = rep(1, p), pos = c(100, 200, 300), + A2 = rep("A", p), A1 = rep("T", p), + variant_id = variant_ids, + allele_freq = rep(0.3, p), + variance = rep(1.0, p), + stringsAsFactors = FALSE ) + make_mock_ld_sketch(X, rp, variant_ids) }, harmonize_gwas = function(...) mock_gwas_data, match_ref_panel = function(target_data, ref_data, ...) { @@ -3470,10 +3476,7 @@ test_that("harmonize_twas: drops molecular_id when harmonize_gwas returns NULL f n_sketch <- 50L set.seed(123) X <- matrix(rbinom(n_sketch * p, 2, 0.3), nrow = n_sketch, ncol = p) - list( - X = X, n_sketch = n_sketch, - ref_panel = ref_panel, variant_ids = variant_ids - ) + make_mock_ld_sketch(X, ref_panel, variant_ids) }, # Returning NULL skips the entire context loop, so gwas_qced stays empty harmonize_gwas = function(...) NULL, @@ -3551,10 +3554,7 @@ test_that("harmonize_twas: susie_weights column triggers adjust_susie_weights br n_sketch <- 50L set.seed(123) X <- matrix(rbinom(n_sketch * p, 2, 0.3), nrow = n_sketch, ncol = p) - list( - X = X, n_sketch = n_sketch, - ref_panel = ref_panel, variant_ids = variant_ids - ) + make_mock_ld_sketch(X, ref_panel, variant_ids) }, harmonize_gwas = function(...) mock_gwas_data, match_ref_panel = function(target_data, ref_data, ...) { diff --git a/tests/testthat/test_twas_sketch.R b/tests/testthat/test_twas_sketch.R index 656946b5..e3920cd2 100644 --- a/tests/testthat/test_twas_sketch.R +++ b/tests/testthat/test_twas_sketch.R @@ -179,7 +179,7 @@ test_that("standardize_genotype_hwe: centers by 2p and scales by sqrt(2p(1-p))", expect_equal(X_std, expected, tolerance = 1e-14) }) -test_that("load_ld_sketch: returns raw genotypes and metadata", { +test_that("load_ld_sketch: returns LDData with raw genotypes and metadata", { set.seed(55) n <- 30 p <- 12 @@ -202,11 +202,12 @@ test_that("load_ld_sketch: returns raw genotypes and metadata", { block_metadata <- S4Vectors::DataFrame( region = "chr1:1000-2100", start = 1000L, end = 2100L, chrom = "chr1" ) + # Store genotype matrix directly in genotype_handle (matching load_ld_sketch output) mock_ld_data <- new("LDData", - correlation = cor(X), - genotype_handle = NULL, + correlation = NULL, + genotype_handle = X, variants = variants_gr, - snp_idx = seq_len(p), + snp_idx = NULL, block_metadata = block_metadata ) @@ -214,23 +215,22 @@ test_that("load_ld_sketch: returns raw genotypes and metadata", { load_LD_matrix = function(ld_meta_file_path, region, return_genotype = FALSE, n_sample = NULL, ...) { mock_ld_data }, - getGenotypes = function(x) X, .package = "pecotmr" ) result <- pecotmr::load_ld_sketch("fake_path.tsv", "chr1:1000-2100") - # Check structure — returns raw X, not SVD - expect_true(all(c("X", "n_sketch", "ref_panel", "variant_ids") %in% names(result))) - expect_null(result$V) - expect_null(result$D) - expect_equal(result$n_sketch, n) - expect_equal(nrow(result$X), n) - expect_equal(ncol(result$X), p) - expect_equal(length(result$variant_ids), p) + # Check structure -- returns an LDData S4 object + expect_true(is(result, "LDData")) + result_X <- getGenotypes(result) + result_ref <- getRefPanel(result) + result_ids <- getVariantIds(result) + expect_equal(nrow(result_X), n) + expect_equal(ncol(result_X), p) + expect_equal(length(result_ids), p) # Raw genotype matrix is returned unchanged - expect_equal(result$X, X) + expect_equal(result_X, X) }) test_that("load_ld_sketch: removes monomorphic variants", { @@ -255,11 +255,12 @@ test_that("load_ld_sketch: removes monomorphic variants", { block_metadata <- S4Vectors::DataFrame( region = "chr1:1-5", start = 1L, end = 5L, chrom = "chr1" ) + # Store genotype matrix directly in genotype_handle mock_ld_data <- new("LDData", - correlation = cor(X), - genotype_handle = NULL, + correlation = NULL, + genotype_handle = X, variants = variants_gr, - snp_idx = seq_len(p), + snp_idx = NULL, block_metadata = block_metadata ) @@ -267,17 +268,20 @@ test_that("load_ld_sketch: removes monomorphic variants", { load_LD_matrix = function(ld_meta_file_path, region, return_genotype = FALSE, n_sample = NULL, ...) { mock_ld_data }, - getGenotypes = function(x) X, .package = "pecotmr" ) result <- pecotmr::load_ld_sketch("fake_path.tsv", "chr1:1-5") - # Monomorphic variant removed - expect_equal(length(result$variant_ids), p - 1) - expect_false(variant_ids[3] %in% result$variant_ids) - expect_equal(nrow(result$ref_panel), p - 1) - expect_equal(ncol(result$X), p - 1) + # Returns LDData with monomorphic variant removed + expect_true(is(result, "LDData")) + result_ids <- getVariantIds(result) + result_ref <- getRefPanel(result) + result_X <- getGenotypes(result) + expect_equal(length(result_ids), p - 1) + expect_false(variant_ids[3] %in% result_ids) + expect_equal(nrow(result_ref), p - 1) + expect_equal(ncol(result_X), p - 1) }) test_that("SVD from raw sketch matches direct computation", { diff --git a/tests/testthat/test_univariate_pipeline.R b/tests/testthat/test_univariate_pipeline.R index 771d1ae6..35a92676 100644 --- a/tests/testthat/test_univariate_pipeline.R +++ b/tests/testthat/test_univariate_pipeline.R @@ -45,18 +45,29 @@ make_fake_susie_fit <- function(p) { # Helper: build a fake protocol-facing post-processing return value # =========================================================================== make_fake_post_result <- function(p) { + vnames <- paste0("chr1:", seq_len(p), ":A:G") + trimmed <- list( + pip = runif(p), + sets = list(cs = list(L1 = c(1L, 2L)), requested_coverage = 0.95), + cs_corr = matrix(c(1, 0.1, 0.1, 1), nrow = 2), + alpha = matrix(runif(2 * p), 2, p), + lbf_variable = matrix(rnorm(2 * p), 2, p), + V = c(0.5, 0.01), + niter = 10, + n_effects = 2 + ) + fm <- FineMappingResult( + variant_names = vnames, + trimmed_fit = trimmed, + top_loci = data.frame(variant_id = character(0), method = character(0)), + method = "susie" + ) list( - variant_names = paste0("chr1:", seq_len(p), ":A:G"), - susie_result_trimmed = list( - pip = runif(p), - sets = list(cs = list(L1 = c(1L, 2L)), requested_coverage = 0.95), - cs_corr = matrix(c(1, 0.1, 0.1, 1), nrow = 2), - alpha = matrix(runif(2 * p), 2, p), - lbf_variable = matrix(rnorm(2 * p), 2, p), - V = c(0.5, 0.01), - niter = 10, - n_effects = 2 - ), + finemapping_result = fm, + # Legacy keys retained for diagnostics tests that mock get_susie_result + # to return res$susie_result_trimmed directly. + susie_result_trimmed = trimmed, + variant_names = vnames, top_loci = data.frame( variant_id = paste0("chr1:1:A:G"), betahat = 1.5, sebetahat = 0.3, z = 5.0, maf = 0.25, @@ -675,10 +686,15 @@ test_that("uap: ordinary susie can run without susie-inf initialization", { test_that("uap: post-processing output is merged into result", { inp <- make_uap_inputs() fake_fit <- make_fake_susie_fit(inp$p) - fake_post <- list( + fake_fm <- FineMappingResult( variant_names = paste0("v", seq_len(inp$p)), - top_loci = data.frame(variant_id = "v1", pip = 0.9, stringsAsFactors = FALSE), - susie_result_trimmed = list(pip = runif(inp$p)) + trimmed_fit = list(pip = runif(inp$p)), + top_loci = data.frame(variant_id = character(0), method = character(0)), + method = "susie" + ) + fake_post <- list( + finemapping_result = fake_fm, + top_loci = data.frame(variant_id = "v1", pip = 0.9, stringsAsFactors = FALSE) ) local_mocked_bindings( @@ -694,7 +710,7 @@ test_that("uap: post-processing output is merged into result", { X = inp$X, Y = inp$Y, maf = inp$maf, twas_weights = FALSE, L = 5, L_greedy = 5 ) - expect_true("variant_names" %in% names(result)) + expect_true("finemapping_result" %in% names(result)) expect_true("top_loci" %in% names(result)) expect_true("total_time_elapsed" %in% names(result)) }) From a599338df68660178fe916f6b86a607be6d513b6 Mon Sep 17 00:00:00 2001 From: danielnachun Date: Wed, 27 May 2026 18:11:28 +0000 Subject: [PATCH 08/11] Update documentation --- man/load_ld_sketch.Rd | 13 ++++--------- man/load_multitask_regional_data.Rd | 4 ++-- man/mr_format.Rd | 2 +- man/qtl_finemapping_example.Rd | 6 ++++-- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/man/load_ld_sketch.Rd b/man/load_ld_sketch.Rd index db4e5397..234eaebe 100644 --- a/man/load_ld_sketch.Rd +++ b/man/load_ld_sketch.Rd @@ -15,15 +15,10 @@ load_ld_sketch(ld_meta_file_path, region, n_sample = NULL) (= 2*p*(1-p)*n/(n-1)). Passed through to \code{load_LD_matrix()}.} } \value{ -A list with: -\describe{ - \item{X}{Raw genotype matrix (n_sketch x p) after removing monomorphic variants.} - \item{n_sketch}{Number of rows (samples) in the sketch genotype matrix.} - \item{ref_panel}{Data.frame with variant metadata (chrom, pos, A2, A1, variant_id, - allele_freq, and optionally variance, n_nomiss).} - \item{variant_ids}{Character vector of variant IDs (canonical format) after - removing monomorphic variants.} -} +An \code{LDData} S4 object with monomorphic variants removed. + Consumers should use S4 accessors: \code{getGenotypes()}, \code{getRefPanel()}, + \code{getVariantIds()}. The number of sketch samples is + \code{nrow(getGenotypes(result))}. } \description{ Loads genotype data for a region via \code{load_LD_matrix(return_genotype=TRUE)} diff --git a/man/load_multitask_regional_data.Rd b/man/load_multitask_regional_data.Rd index 46134bef..585c2070 100644 --- a/man/load_multitask_regional_data.Rd +++ b/man/load_multitask_regional_data.Rd @@ -129,10 +129,10 @@ This function loads a mixture data sets for a specific region, including individ or summary statistics (sumstats, LD). Run \code{load_regional_univariate_data} and \code{load_rss_data} multiple times for different datasets } \section{Loading individual level data from multiple corhorts}{ - +NA } \section{Loading summary statistics from multiple corhorts or data set}{ - +NA } diff --git a/man/mr_format.Rd b/man/mr_format.Rd index e543ed44..4a8477de 100644 --- a/man/mr_format.Rd +++ b/man/mr_format.Rd @@ -17,7 +17,7 @@ mr_format( ) } \arguments{ -\item{susie_result}{A list containing the results of SuSiE analysis. This list should include nested elements such as 'susie_results', 'susie_result_trimmed', and 'top_loci', containing details about the statistical analysis of genetic variants.} +\item{susie_result}{A list containing the results of SuSiE analysis. This list should include nested elements such as 'susie_results', 'finemapping_result' (a FineMappingResult S4 object), and 'top_loci', containing details about the statistical analysis of genetic variants.} \item{condition}{A character string specifying the conditions. This is used to select the corresponding subset of results within 'susie_result'.} diff --git a/man/qtl_finemapping_example.Rd b/man/qtl_finemapping_example.Rd index 4420c855..a9e9a192 100644 --- a/man/qtl_finemapping_example.Rd +++ b/man/qtl_finemapping_example.Rd @@ -10,9 +10,11 @@ A nested list with structure \describe{ \item{susie_result_trimmed}{List. Trimmed SuSiE result with elements - \code{alpha}, \code{pip}, \code{V}, and \code{sets}.} + \code{alpha}, \code{pip}, \code{V}, and \code{sets}. (Legacy key; new + code should use the \code{FineMappingResult} S4 object.)} \item{variant_names}{Character vector (length 2,828). Synthetic variant - identifiers matching the variant names in the SuSiE result.} + identifiers matching the variant names in the SuSiE result. + (Legacy key; new code should use the \code{FineMappingResult} S4 object.)} } } \description{ From ff85b5f2c3ef2d7b6f2da17b4bf5b6aa4992bf37 Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Mon, 1 Jun 2026 15:57:49 -0700 Subject: [PATCH 09/11] fix tests --- R/colocboost_pipeline.R | 7 +++++-- R/sumstats_qc.R | 7 +++++++ tests/testthat/test_susie_wrapper.R | 22 +++++++++++++--------- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/R/colocboost_pipeline.R b/R/colocboost_pipeline.R index 87b2cc50..649de423 100644 --- a/R/colocboost_pipeline.R +++ b/R/colocboost_pipeline.R @@ -67,11 +67,14 @@ region_data_to_colocboost_input <- function(region_data) { ind_args <- .cb_format_individual(ind_records) # Build sumstat_records from rss_input which already contains LDData S4 - # objects (region_data_to_rss_input converts any legacy lists). + # objects (region_data_to_rss_input converts any legacy lists). When the + # LDData carries genotypes, pass them through so build_ld_args routes them + # to X_ref; otherwise pass the correlation matrix as LD. sumstat_records <- lapply(names(rss_input$rss_input), function(study) { ld_data <- rss_input$LD_data[[study]] + ld_mat <- if (hasGenotypes(ld_data)) getGenotypes(ld_data) else getCorrelation(ld_data) list(rss_input = rss_input$rss_input[[study]], - LD_matrix = getCorrelation(ld_data)) + LD_matrix = ld_mat) }) names(sumstat_records) <- names(rss_input$rss_input) sumstat_args <- .cb_format_sumstat(sumstat_records) diff --git a/R/sumstats_qc.R b/R/sumstats_qc.R index 89d2f276..9fe0aac3 100644 --- a/R/sumstats_qc.R +++ b/R/sumstats_qc.R @@ -407,6 +407,13 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, has_genotype <- hasGenotypes(LD_data_for_qc) ref_panel <- getRefPanel(LD_data_for_qc) X_ref <- if (has_genotype) getGenotypes(LD_data_for_qc) else NULL + # getGenotypes() preserves the source variant ID convention (often + # underscore-separated from PLINK), but sumstats$variant_id is colon- + # separated canonical form. Canonicalize so downstream matching works. + if (!is.null(X_ref)) { + canonical_ids <- getVariantIds(LD_data_for_qc) + if (length(canonical_ids) == ncol(X_ref)) colnames(X_ref) <- canonical_ids + } basic <- rss_basic_qc(rss_input$sumstats, LD_data_for_qc, skip_region = skip_region, keep_indel = keep_indel, return_LD_mat = !has_genotype) diff --git a/tests/testthat/test_susie_wrapper.R b/tests/testthat/test_susie_wrapper.R index b11cd43c..40903ca9 100644 --- a/tests/testthat/test_susie_wrapper.R +++ b/tests/testthat/test_susie_wrapper.R @@ -258,9 +258,10 @@ test_that("susie_rss_pipeline runs with single_effect method", { expect_true(is.list(result)) expect_true("finemapping_result" %in% names(result)) fm <- result$finemapping_result - if (!is.null(result$top_loci)) { - expect_true("pip_single_effect" %in% names(result$top_loci)) - expect_true("CS_95_single_effect" %in% names(result$top_loci)) + if (!is.null(result$top_loci) && nrow(result$top_loci) > 0) { + expect_true("pip" %in% names(result$top_loci)) + expect_true("cs_95" %in% names(result$top_loci)) + expect_true(all(result$top_loci$method == "single_effect")) } # PIPs should be numeric, in [0,1], and sum to at most 1 (L=1) pip <- getTrimmedFit(fm)$pip @@ -294,9 +295,10 @@ test_that("susie_rss_pipeline runs with bayesian_conditional_regression", { expect_true(is.list(result)) expect_true("finemapping_result" %in% names(result)) fm <- result$finemapping_result - if (!is.null(result$top_loci)) { - expect_true("pip_bayesian_conditional_regression" %in% names(result$top_loci)) - expect_true("CS_95_bayesian_conditional_regression" %in% names(result$top_loci)) + if (!is.null(result$top_loci) && nrow(result$top_loci) > 0) { + expect_true("pip" %in% names(result$top_loci)) + expect_true("cs_95" %in% names(result$top_loci)) + expect_true(all(result$top_loci$method == "bayesian_conditional_regression")) } pip <- getTrimmedFit(fm)$pip expect_true(is.numeric(pip)) @@ -1178,13 +1180,15 @@ test_that("build_top_loci requires `method`", { ), "method") }) -test_that("non-top-loci wrapper fields (susie_result_trimmed, variant_names) are preserved by format_finemapping_output", { +test_that("format_finemapping_output exposes finemapping_result with S4 accessors", { d <- .make_univariate_data(seed = 25, effect_idx = c(20)) fit <- susieR::susie(d$X, d$y, L = 5) post <- postprocess_finemapping_fits(list(susie = fit), data_x = d$X, data_y = d$y, coverage = 0.95) out <- format_finemapping_output(post, primary_method = "susie") - expect_true("susie_result_trimmed" %in% names(out)) - expect_true("variant_names" %in% names(out)) + expect_true("finemapping_result" %in% names(out)) + fm <- out$finemapping_result + expect_true(is.character(getVariantNames(fm)) && length(getVariantNames(fm)) == ncol(d$X)) + expect_true(is.list(getTrimmedFit(fm)) && !is.null(getTrimmedFit(fm)$pip)) }) test_that("missing region produces NA grange columns rather than silent omission", { From 40834c0479389f4e91f5a6747bb0350070b29ccc Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Tue, 2 Jun 2026 00:53:57 -0700 Subject: [PATCH 10/11] more refactor --- NAMESPACE | 1 - R/AllClasses.R | 184 ++++++- R/AllGenerics.R | 132 +++++ R/AllMethods.R | 244 ++++++++- R/LD.R | 5 +- R/allele_qc.R | 11 +- R/colocboost_pipeline.R | 107 +++- R/encoloc.R | 26 +- R/file_utils.R | 402 +++++---------- R/mash_wrapper.R | 14 +- R/mr.R | 2 +- R/sumstats_qc.R | 151 ++++-- R/susie_wrapper.R | 7 +- R/twas.R | 12 +- R/twas_weights.R | 9 +- R/univariate_pipeline.R | 68 +-- R/univariate_rss_diagnostics.R | 68 ++- man/dot-legacy_list_to_LDData.Rd | 18 - tests/testthat/test_allele_qc.R | 30 +- tests/testthat/test_colocboost_pipeline.R | 470 +++++++++++------- tests/testthat/test_encoloc.R | 55 +- tests/testthat/test_file_utils.R | 154 +++--- tests/testthat/test_mash_wrapper.R | 37 +- tests/testthat/test_sumstats_qc.R | 116 +++-- tests/testthat/test_twas.R | 38 +- tests/testthat/test_univariate_pipeline.R | 271 ++++++++-- .../test_univariate_rss_diagnostics.R | 71 ++- 27 files changed, 1769 insertions(+), 934 deletions(-) delete mode 100644 man/dot-legacy_list_to_LDData.Rd diff --git a/NAMESPACE b/NAMESPACE index 7fa08257..56eee91d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,7 +6,6 @@ S3method(postprocess_finemapping_fit,susiF) S3method(postprocess_finemapping_fit,susie) S3method(postprocess_finemapping_fit,susie_inf) S3method(postprocess_finemapping_fit,susie_rss) -export(.legacy_list_to_LDData) export(AnnotationMatrix) export(FineMappingResult) export(GWASSumStats) diff --git a/R/AllClasses.R b/R/AllClasses.R index bef71229..69df73d2 100644 --- a/R/AllClasses.R +++ b/R/AllClasses.R @@ -192,7 +192,19 @@ setClass("LDStatistic", n_ref = "integer", in_sample = "logical", genome = "character" - ) + ), + validity = function(object) { + errors <- character() + if (length(object@n_ref) != 1L || object@n_ref <= 0L) + errors <- c(errors, "'n_ref' must be a single positive integer") + if (length(object@in_sample) != 1L) + errors <- c(errors, "'in_sample' must be a single logical value") + if (length(object@genome) != 1L || !nzchar(object@genome)) + errors <- c(errors, "'genome' must be a single non-empty character string") + if (nrow(object@snp_info) == 0L) + errors <- c(errors, "'snp_info' must have at least one row") + if (length(errors) == 0) TRUE else errors + } ) #' @title Eigendecomposition-Based LD Statistic @@ -215,7 +227,8 @@ setClass("LDEigen", eigenvalue_truncation = "numeric" ), validity = function(object) { - errors <- character() + parent_check <- getValidity(getClass("LDStatistic"))(object) + errors <- if (isTRUE(parent_check)) character() else parent_check n_blocks <- length(object@ld_blocks@blocks) if (length(object@eigen_list) != n_blocks) errors <- c(errors, @@ -248,7 +261,8 @@ setClass("LDScore", ld_matrix_list = "list" # for g-LDSC; empty list for S-LDSC ), validity = function(object) { - errors <- character() + parent_check <- getValidity(getClass("LDStatistic"))(object) + errors <- if (isTRUE(parent_check)) character() else parent_check if (nrow(object@ld_scores) != nrow(object@snp_info)) errors <- c(errors, "Number of rows in 'ld_scores' must match 'snp_info'") @@ -439,6 +453,93 @@ setClass("TWASWeights", } ) +# ============================================================================= +# Allele QC Result +# ============================================================================= + +#' @title Allele QC Result +#' @description S4 container for the output of \code{match_ref_panel} / +#' \code{allele_qc}. Carries the post-QC target variants alongside the full +#' merge / flip / strand diagnostics needed by downstream callers that +#' inspect what QC did. +#' @slot harmonized_data A \code{data.frame} of variants retained after +#' allele harmonization, with reference-aligned A1/A2 and (when requested) +#' sign-flipped effect columns. +#' @slot qc_summary A \code{data.frame} carrying per-variant QC diagnostics +#' from the full merge: \code{variants_id_original}, \code{variants_id_qced}, +#' \code{exact_match}, \code{sign_flip}, \code{strand_flip}, \code{INDEL}, +#' \code{ID_match}, \code{keep}, etc. +#' @export +setClass("AlleleQCResult", + representation( + harmonized_data = "data.frame", + qc_summary = "data.frame" + ) +) + +# ============================================================================= +# Summary-Statistics QC Result +# ============================================================================= + +#' @title Summary-Statistics QC Result +#' @description S4 container holding the output of \code{summary_stats_qc} and +#' \code{.summary_stats_qc_single_study}. Carries the post-QC LD reference +#' plus harmonized sumstats, a pre-imputation snapshot, and QC process +#' metadata. Replaces the legacy list-of-named-fields return shape. +#' @slot ld_data An \code{LDData} S4 object containing the post-QC LD +#' reference (correlation and/or genotype), or NULL when QC produced no LD. +#' @slot rss_input List with \code{sumstats} (post-QC data.frame), \code{n}, +#' and \code{var_y}. +#' @slot preprocess List with \code{sumstats} and \code{ld_data} fields +#' capturing the pre-imputation snapshot for downstream re-runs. +#' @slot outlier_number Integer count of LD-mismatch outliers removed. +#' @slot skipped Single logical; TRUE when QC short-circuited. +#' @slot skip_reason Character string explaining a skip; empty otherwise. +#' @export +setClass("QCResult", + representation( + ld_data = "ANY", # LDData or NULL + rss_input = "list", + preprocess = "list", + outlier_number = "integer", + skipped = "logical", + skip_reason = "character" + ), + validity = function(object) { + errors <- character() + if (!is.null(object@ld_data) && !is(object@ld_data, "LDData")) + errors <- c(errors, "'ld_data' must be an LDData object or NULL") + if (length(object@skipped) != 1L) + errors <- c(errors, "'skipped' must be a single logical value") + if (length(object@outlier_number) != 1L) + errors <- c(errors, "'outlier_number' must be a single integer") + if (length(object@skip_reason) > 1L) + errors <- c(errors, "'skip_reason' must be a single character string (or empty)") + if (length(object@rss_input) > 0L) { + required <- c("sumstats", "n", "var_y") + missing_keys <- setdiff(required, names(object@rss_input)) + if (length(missing_keys) > 0L) + errors <- c(errors, paste0( + "'rss_input' is missing key(s): ", paste(missing_keys, collapse = ", "))) + if (!is.null(object@rss_input$sumstats) && + !is.data.frame(object@rss_input$sumstats)) + errors <- c(errors, + "'rss_input$sumstats' must be a data.frame") + } + if (length(object@preprocess) > 0L) { + pp_keys <- names(object@preprocess) + if (!all(pp_keys %in% c("sumstats", "ld_data"))) + errors <- c(errors, + "'preprocess' may only contain 'sumstats' and 'ld_data' keys") + if (!is.null(object@preprocess$ld_data) && + !is(object@preprocess$ld_data, "LDData")) + errors <- c(errors, + "'preprocess$ld_data' must be an LDData or NULL") + } + if (length(errors) == 0) TRUE else errors + } +) + # ============================================================================= # Regional Data (pipeline input) # ============================================================================= @@ -478,6 +579,49 @@ setClass("RegionalData", } ) +# ============================================================================= +# Multivariate Regional Data +# ============================================================================= + +#' @title Multivariate Regional Association Data +#' @description S4 container for regional association data prepared for +#' multivariate (joint-across-conditions) modeling. Unlike +#' \code{RegionalData}, which carries a per-condition list of phenotype +#' matrices, this class assumes all conditions are jointly observed in the +#' same samples and packs the phenotypes into a single multivariate matrix +#' (samples x conditions). +#' @slot genotype_matrix Numeric matrix (samples x variants), rownames are +#' sample IDs, colnames are variant IDs. +#' @slot Y_matrix Numeric matrix (samples x conditions) of residualized +#' phenotypes after joining conditions and (optionally) filtering rows by +#' minimum non-missing count. +#' @slot Y_scalar Numeric vector of per-condition scaling factors +#' (length = ncol(Y_matrix)). +#' @slot dropped_samples Character or list capturing sample IDs dropped +#' during multivariate filtering. +#' @slot region A \code{GRanges} (single range) or NULL. +#' @slot Y_coordinates A data.frame of phenotype coordinates, or NULL. +#' @export +setClass("MultivariateRegionalData", + representation( + genotype_matrix = "matrix", + Y_matrix = "matrix", + Y_scalar = "numeric", + dropped_samples = "ANY", + region = "ANY", + Y_coordinates = "ANY" + ), + validity = function(object) { + errors <- character() + if (nrow(object@genotype_matrix) != nrow(object@Y_matrix)) + errors <- c(errors, + "genotype_matrix and Y_matrix must have the same number of rows") + if (length(object@Y_scalar) != ncol(object@Y_matrix)) + errors <- c(errors, "length(Y_scalar) must equal ncol(Y_matrix)") + if (length(errors) == 0) TRUE else errors + } +) + # ============================================================================= # Show Methods # ============================================================================= @@ -601,3 +745,37 @@ setMethod("show", "RegionalData", function(object) { n_cond, n_var, n_samp)) cat(sprintf(" Scale residuals: %s\n", object@scale_residuals)) }) + +#' @export +setMethod("show", "MultivariateRegionalData", function(object) { + cat(sprintf("MultivariateRegionalData: %d conditions, %d variants, %d samples\n", + ncol(object@Y_matrix), ncol(object@genotype_matrix), + nrow(object@genotype_matrix))) + if (!is.null(object@region)) + cat(sprintf(" Region: %s:%d-%d\n", + as.character(GenomicRanges::seqnames(object@region))[1], + GenomicRanges::start(object@region), + GenomicRanges::end(object@region))) +}) + +#' @export +setMethod("show", "AlleleQCResult", function(object) { + cat(sprintf("AlleleQCResult: %d harmonized variants (from %d scanned)\n", + nrow(object@harmonized_data), nrow(object@qc_summary))) +}) + +#' @export +setMethod("show", "QCResult", function(object) { + cat(sprintf("QCResult: %s\n", + if (object@skipped) sprintf("skipped (%s)", object@skip_reason) else "completed")) + if (length(object@rss_input) > 0 && !is.null(object@rss_input$sumstats)) { + cat(sprintf(" Sumstats: %d variants\n", + nrow(object@rss_input$sumstats))) + } + if (!is.null(object@ld_data)) { + cat(sprintf(" LD: %d variants%s\n", + length(getVariantIds(object@ld_data)), + if (hasGenotypes(object@ld_data)) " (genotype-backed)" else " (correlation)")) + } + cat(sprintf(" Outliers removed: %d\n", object@outlier_number)) +}) diff --git a/R/AllGenerics.R b/R/AllGenerics.R index d8cfc526..10f07fde 100644 --- a/R/AllGenerics.R +++ b/R/AllGenerics.R @@ -276,6 +276,71 @@ setGeneric("getResidualXScalar", setGeneric("getResidualYScalar", function(x, condition = 1L) standardGeneric("getResidualYScalar")) +#' @title Get Per-Variant Variance +#' @description Per-variant variance of residualized genotypes for a +#' condition. +#' @param x A \code{RegionalData} object. +#' @param condition Integer index of the condition. +#' @return A numeric vector (length = number of variants). +#' @export +setGeneric("getXVariance", + function(x, condition = 1L) standardGeneric("getXVariance")) + +#' @title Get Phenotype List +#' @description Extract the per-condition phenotype list from a +#' \code{RegionalData}. +#' @param x A \code{RegionalData} object. +#' @return A named list of phenotype matrices. +#' @export +setGeneric("getPhenotypes", function(x) standardGeneric("getPhenotypes")) + +#' @title Get Covariate List +#' @description Extract the per-condition covariate list from a +#' \code{RegionalData}. +#' @param x A \code{RegionalData} object. +#' @return A named list of covariate matrices. +#' @export +setGeneric("getCovariates", function(x) standardGeneric("getCovariates")) + +#' @title Get Genotype Matrix +#' @description Extract the raw genotype matrix from a +#' \code{RegionalData} or \code{MultivariateRegionalData}. +#' @param x The object. +#' @return A numeric matrix (samples x variants). +#' @export +setGeneric("getGenotypeMatrix", function(x) standardGeneric("getGenotypeMatrix")) + +#' @title Get Region Chromosome +#' @description Extract the chromosome name from a region-bearing S4 object. +#' @param x The object. +#' @return A single character string, or NULL. +#' @export +setGeneric("getChrom", function(x) standardGeneric("getChrom")) + +#' @title Get Region Range +#' @description Extract the start/end positions from a region-bearing S4 +#' object as a character vector \code{c(start, end)}. +#' @param x The object. +#' @return A character vector of length 2, or NULL. +#' @export +setGeneric("getGrange", function(x) standardGeneric("getGrange")) + +#' @title Get Multivariate Y Matrix +#' @description Extract the multivariate phenotype matrix from a +#' \code{MultivariateRegionalData}. +#' @param x A \code{MultivariateRegionalData} object. +#' @return A numeric matrix (samples x conditions). +#' @export +setGeneric("getYMatrix", function(x) standardGeneric("getYMatrix")) + +#' @title Get Y Scaling Factors +#' @description Per-condition scaling factors used for residualized +#' multivariate phenotypes. +#' @param x A \code{MultivariateRegionalData} object. +#' @return A numeric vector (length = number of conditions). +#' @export +setGeneric("getYScalar", function(x) standardGeneric("getYScalar")) + # ============================================================================= # FineMappingResult accessor generics # ============================================================================= @@ -393,6 +458,73 @@ setGeneric("getMolecularId", function(x) standardGeneric("getMolecularId")) #' @export setGeneric("getDataType", function(x) standardGeneric("getDataType")) +# ============================================================================= +# AlleleQCResult accessor generics +# ============================================================================= + +#' @title Get Harmonized Variant Data +#' @description Extract the post-QC, reference-harmonized variants from an +#' \code{AlleleQCResult}. +#' @param x An \code{AlleleQCResult} object. +#' @return A \code{data.frame} of harmonized variants. +#' @export +setGeneric("getHarmonizedData", function(x) standardGeneric("getHarmonizedData")) + +#' @title Get Allele QC Summary +#' @description Extract the full per-variant merge/flip/strand diagnostics +#' produced by allele QC. +#' @param x An \code{AlleleQCResult} object. +#' @return A \code{data.frame} with the diagnostic columns. +#' @export +setGeneric("getQCSummary", function(x) standardGeneric("getQCSummary")) + +# ============================================================================= +# QCResult accessor generics +# ============================================================================= + +#' @title Get LD Data +#' @description Extract the post-QC LDData payload from a QCResult. +#' @param x A \code{QCResult} object. +#' @return An \code{LDData} object, or NULL when QC produced no LD reference. +#' @export +setGeneric("getLDData", function(x) standardGeneric("getLDData")) + +#' @title Get RSS Input +#' @description Extract the post-QC summary-statistic record (sumstats, n, var_y). +#' @param x A \code{QCResult} object. +#' @return A list with \code{sumstats}, \code{n}, \code{var_y}. +#' @export +setGeneric("getRSSInput", function(x) standardGeneric("getRSSInput")) + +#' @title Get Preprocess Snapshot +#' @description Extract the pre-imputation snapshot (\code{sumstats}, +#' \code{ld_data}) captured before any LD-mismatch QC or RAISS imputation. +#' @param x A \code{QCResult} object. +#' @return A list with \code{sumstats} and \code{ld_data}. +#' @export +setGeneric("getPreprocess", function(x) standardGeneric("getPreprocess")) + +#' @title Get Outlier Number +#' @description Number of LD-mismatch outliers removed during QC. +#' @param x A \code{QCResult} object. +#' @return Integer count. +#' @export +setGeneric("getOutlierNumber", function(x) standardGeneric("getOutlierNumber")) + +#' @title Is Skipped +#' @description Whether QC short-circuited (e.g. no signals, too few variants). +#' @param x A \code{QCResult} object. +#' @return Single logical. +#' @export +setGeneric("isSkipped", function(x) standardGeneric("isSkipped")) + +#' @title Get Skip Reason +#' @description Why QC short-circuited; empty string if not skipped. +#' @param x A \code{QCResult} object. +#' @return Character scalar. +#' @export +setGeneric("getSkipReason", function(x) standardGeneric("getSkipReason")) + # ============================================================================= # VCF/BCF writer generic # ============================================================================= diff --git a/R/AllMethods.R b/R/AllMethods.R index 856ca807..530e4c3d 100644 --- a/R/AllMethods.R +++ b/R/AllMethods.R @@ -26,7 +26,7 @@ NULL LDData <- function(correlation = NULL, genotype_handle = NULL, snp_idx = NULL, variants, block_metadata, n_ref = 0L) { - new("LDData", + obj <- new("LDData", correlation = correlation, genotype_handle = genotype_handle, snp_idx = snp_idx, @@ -34,6 +34,8 @@ LDData <- function(correlation = NULL, genotype_handle = NULL, block_metadata = block_metadata, n_ref = as.integer(n_ref) ) + validObject(obj) + obj } #' @rdname getCorrelation @@ -163,7 +165,7 @@ RegionalData <- function(genotype_matrix, phenotypes, covariates, scale_residuals = FALSE, maf = list(), region = NULL, dropped_samples = list(), Y_coordinates = NULL) { - new("RegionalData", + obj <- new("RegionalData", genotype_matrix = genotype_matrix, phenotypes = phenotypes, covariates = covariates, @@ -173,6 +175,8 @@ RegionalData <- function(genotype_matrix, phenotypes, covariates, dropped_samples = dropped_samples, Y_coordinates = Y_coordinates ) + validObject(obj) + obj } #' @rdname getResidualX @@ -232,6 +236,149 @@ setMethod("getVariantInfo", "RegionalData", function(x) { colnames(x@genotype_matrix) }) +#' @rdname getPhenotypes +#' @export +setMethod("getPhenotypes", "RegionalData", function(x) x@phenotypes) + +#' @rdname getCovariates +#' @export +setMethod("getCovariates", "RegionalData", function(x) x@covariates) + +#' @rdname getGenotypeMatrix +#' @export +setMethod("getGenotypeMatrix", "RegionalData", function(x) x@genotype_matrix) + +#' @rdname getGenotypeMatrix +#' @export +setMethod("getGenotypeMatrix", "MultivariateRegionalData", function(x) x@genotype_matrix) + +# ----- MultivariateRegionalData constructor and accessors ----- + +#' @title Construct a MultivariateRegionalData object +#' @description Build a \code{MultivariateRegionalData} S4 object capturing +#' regional association data prepared for multivariate modeling (single +#' joint Y matrix across conditions). +#' @param genotype_matrix Numeric matrix (samples x variants). +#' @param Y_matrix Numeric matrix (samples x conditions). +#' @param Y_scalar Numeric vector of per-condition scaling factors. +#' @param dropped_samples Character vector or list of dropped sample IDs. +#' @param region A \code{GRanges} or NULL. +#' @param Y_coordinates A data.frame of phenotype coordinates, or NULL. +#' @return A \code{MultivariateRegionalData} object. +#' @export +MultivariateRegionalData <- function(genotype_matrix, Y_matrix, Y_scalar, + dropped_samples = NULL, + region = NULL, + Y_coordinates = NULL) { + obj <- new("MultivariateRegionalData", + genotype_matrix = genotype_matrix, + Y_matrix = Y_matrix, + Y_scalar = as.numeric(Y_scalar), + dropped_samples = dropped_samples, + region = region, + Y_coordinates = Y_coordinates) + validObject(obj) + obj +} + +#' @rdname getYMatrix +#' @export +setMethod("getYMatrix", "MultivariateRegionalData", function(x) x@Y_matrix) + +#' @rdname getYScalar +#' @export +setMethod("getYScalar", "MultivariateRegionalData", function(x) x@Y_scalar) + +#' @rdname getVariantInfo +#' @export +setMethod("getVariantInfo", "MultivariateRegionalData", function(x) { + colnames(x@genotype_matrix) +}) + +#' @rdname getChrom +#' @export +setMethod("getChrom", "MultivariateRegionalData", function(x) { + if (is.null(x@region)) return(NULL) + as.character(GenomicRanges::seqnames(x@region))[1] +}) + +#' @rdname getGrange +#' @export +setMethod("getGrange", "MultivariateRegionalData", function(x) { + if (is.null(x@region)) return(NULL) + as.character(c(GenomicRanges::start(x@region), + GenomicRanges::end(x@region))) +}) + +#' @rdname getMaf +#' @export +setMethod("getMaf", "MultivariateRegionalData", function(x) { + apply(x@genotype_matrix, 2, compute_maf) +}) + +#' @rdname getXVariance +#' @export +setMethod("getXVariance", "MultivariateRegionalData", function(x, condition = 1L) { + matrixStats::colVars(x@genotype_matrix) +}) + +#' @rdname getXVariance +#' @export +setMethod("getXVariance", "RegionalData", function(x, condition = 1L) { + res <- getResidualX(x, condition) + matrixStats::colVars(res) +}) + +#' @rdname getChrom +#' @export +setMethod("getChrom", "RegionalData", function(x) { + if (is.null(x@region)) return(NULL) + as.character(GenomicRanges::seqnames(x@region))[1] +}) + +#' @rdname getGrange +#' @export +setMethod("getGrange", "RegionalData", function(x) { + if (is.null(x@region)) return(NULL) + as.character(c(GenomicRanges::start(x@region), + GenomicRanges::end(x@region))) +}) + +#' @title Combine Two RegionalData Objects +#' @description Concatenate two \code{RegionalData} objects by appending +#' their per-condition slots (phenotypes, covariates, maf, dropped_samples). +#' Used by multi-panel pipelines that load per-LD-panel data and aggregate +#' them. The \code{genotype_matrix} of \code{x} is retained as the +#' canonical genotype reference; the \code{region} is taken from \code{y} +#' (mirrors prior list-merge behavior). +#' @param x First \code{RegionalData} object. +#' @param y Second \code{RegionalData} object. +#' @return A merged \code{RegionalData}. +#' @export +setMethod("c", "RegionalData", function(x, ...) { + others <- list(...) + if (length(others) == 0L) return(x) + result <- x + for (y in others) { + if (!is(y, "RegionalData")) stop("All arguments to c() must be RegionalData") + result <- RegionalData( + genotype_matrix = result@genotype_matrix, + phenotypes = c(result@phenotypes, y@phenotypes), + covariates = c(result@covariates, y@covariates), + scale_residuals = result@scale_residuals, + maf = c(result@maf, y@maf), + region = y@region, + dropped_samples = list( + X = c(result@dropped_samples$X, y@dropped_samples$X), + Y = c(result@dropped_samples$Y, y@dropped_samples$Y), + covar = c(result@dropped_samples$covar, y@dropped_samples$covar) + ), + Y_coordinates = result@Y_coordinates + ) + } + result +}) + # ============================================================================= # FineMappingResult constructor and accessors # ============================================================================= @@ -247,13 +394,15 @@ setMethod("getVariantInfo", "RegionalData", function(x) { #' @export FineMappingResult <- function(variant_names, trimmed_fit, top_loci, method, sumstats = NULL) { - new("FineMappingResult", + obj <- new("FineMappingResult", variant_names = variant_names, trimmed_fit = trimmed_fit, top_loci = top_loci, method = method, sumstats = sumstats ) + validObject(obj) + obj } #' @rdname getPIP @@ -377,7 +526,7 @@ setMethod("getEffects", "FineMappingResult", function(x) { TWASWeights <- function(weights, variant_ids, fits = NULL, cv_performance = NULL, standardized = FALSE, molecular_id = character(0), data_type = NULL) { - new("TWASWeights", + obj <- new("TWASWeights", weights = weights, variant_ids = variant_ids, methods = names(weights), @@ -387,6 +536,8 @@ TWASWeights <- function(weights, variant_ids, fits = NULL, molecular_id = molecular_id, data_type = data_type ) + validObject(obj) + obj } #' @rdname getWeights @@ -449,6 +600,91 @@ setMethod("getVariantNames", "FineMappingResult", function(x) x@variant_names) #' @export setMethod("getTopLoci", "FineMappingResult", function(x) x@top_loci) +# ============================================================================= +# AlleleQCResult constructor and accessors +# ============================================================================= + +#' @title Construct an AlleleQCResult object +#' @description Build an \code{AlleleQCResult} S4 object wrapping the post-QC +#' harmonized variants and the full per-variant QC diagnostics. +#' @param harmonized_data Data frame of variants retained after allele QC. +#' @param qc_summary Data frame of per-variant diagnostic columns. +#' @return An \code{AlleleQCResult} object. +#' @export +AlleleQCResult <- function(harmonized_data, qc_summary) { + obj <- new("AlleleQCResult", + harmonized_data = as.data.frame(harmonized_data), + qc_summary = as.data.frame(qc_summary)) + validObject(obj) + obj +} + +#' @rdname getHarmonizedData +#' @export +setMethod("getHarmonizedData", "AlleleQCResult", function(x) x@harmonized_data) + +#' @rdname getQCSummary +#' @export +setMethod("getQCSummary", "AlleleQCResult", function(x) x@qc_summary) + +# ============================================================================= +# QCResult constructor and accessors +# ============================================================================= + +#' @title Construct a QCResult object +#' @description Build a \code{QCResult} S4 object capturing the output of +#' summary-statistic QC. Validates that \code{ld_data} is an \code{LDData} +#' or NULL. +#' @param ld_data An \code{LDData} or NULL. +#' @param rss_input List with \code{sumstats}, \code{n}, \code{var_y}. +#' @param preprocess List with \code{sumstats} and \code{ld_data}. +#' @param outlier_number Integer count of LD-mismatch outliers removed. +#' @param skipped Single logical indicating a short-circuit. +#' @param skip_reason Character explanation; defaults to empty. +#' @return A \code{QCResult} object. +#' @export +QCResult <- function(ld_data = NULL, + rss_input = list(), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE, + skip_reason = "") { + reason <- if (length(skip_reason) == 0L) "" else as.character(skip_reason)[[1]] + obj <- new("QCResult", + ld_data = ld_data, + rss_input = rss_input, + preprocess = preprocess, + outlier_number = as.integer(outlier_number), + skipped = isTRUE(skipped), + skip_reason = reason) + validObject(obj) + obj +} + +#' @rdname getLDData +#' @export +setMethod("getLDData", "QCResult", function(x) x@ld_data) + +#' @rdname getRSSInput +#' @export +setMethod("getRSSInput", "QCResult", function(x) x@rss_input) + +#' @rdname getPreprocess +#' @export +setMethod("getPreprocess", "QCResult", function(x) x@preprocess) + +#' @rdname getOutlierNumber +#' @export +setMethod("getOutlierNumber", "QCResult", function(x) x@outlier_number) + +#' @rdname isSkipped +#' @export +setMethod("isSkipped", "QCResult", function(x) x@skipped) + +#' @rdname getSkipReason +#' @export +setMethod("getSkipReason", "QCResult", function(x) x@skip_reason) + # ============================================================================= # top_loci GRanges conversion # ============================================================================= diff --git a/R/LD.R b/R/LD.R index e63c4032..23ebf6d7 100644 --- a/R/LD.R +++ b/R/LD.R @@ -659,7 +659,7 @@ load_LD_from_blocks <- function(LD_meta_file_path, region, extract_coordinates = snp_idx = NULL, variants = variants_gr, block_metadata = block_metadata, - n_ref = 0L + n_ref = if (is.null(n_sample)) 0L else as.integer(n_sample) ) } @@ -714,8 +714,7 @@ filter_variants_by_ld_reference <- function(variant_ids, ld_reference_meta_file, #' into a list of smaller matrices based on the block_indices, making it easier to work with #' large LD matrices that span multiple blocks. #' -#' @param ld_data A list as returned by load_LD_matrix, containing LD_matrix, -#' LD_variants, ref_panel, and block_metadata. +#' @param ld_data An \code{LDData} S4 object as returned by \code{load_LD_matrix()}. #' @param merge_small_blocks Logical, whether to merge blocks smaller than min_merged_block_size (default: TRUE). #' @param min_merged_block_size Integer, minimum number of variants for a block after merging (default: 500). #' @param max_merged_block_size Integer, maximum number of variants in a block after merging (default: 10000). diff --git a/R/allele_qc.R b/R/allele_qc.R index 1d047cb5..57434109 100644 --- a/R/allele_qc.R +++ b/R/allele_qc.R @@ -21,7 +21,10 @@ NULL #' @param remove_strand_ambiguous Whether to remove strand SNPs (if any). Default is `TRUE`. #' @param flip_strand Whether to output the variants after strand flip. Default is `FALSE`. #' @param remove_unmatched Whether to remove unmatched variants. Default is `TRUE`. -#' @return A single data frame with matched variants. +#' @return An \code{AlleleQCResult} S4 object. Use +#' \code{getHarmonizedData()} to recover the post-QC variant +#' data.frame and \code{getQCSummary()} to inspect the per-variant +#' merge/flip/strand diagnostics. #' @importFrom magrittr %>% #' @importFrom dplyr mutate inner_join filter pull select everything row_number if_else any_of all_of rename #' @importFrom vctrs vec_duplicate_detect @@ -86,7 +89,7 @@ match_ref_panel <- function(target_data, ref_variants, col_to_flip = NULL, if (nrow(match_result) == 0) { warning("No matching variants found between target data and reference variants.") - return(list(target_data_qced = match_result, qc_summary = match_result)) + return(AlleleQCResult(harmonized_data = match_result, qc_summary = match_result)) } # match target & ref by chrom and position match_result = match_result %>% @@ -185,7 +188,7 @@ match_ref_panel <- function(target_data, ref_variants, col_to_flip = NULL, stop("Duplicated variants with different values found. Please check the input data and determine which to keep.") } - return(list(target_data_qced = result, qc_summary = match_result)) + return(AlleleQCResult(harmonized_data = result, qc_summary = match_result)) } #' @rdname match_ref_panel @@ -249,7 +252,7 @@ align_variant_names <- function(source, reference, remove_indels = FALSE, remove remove_unmatched = FALSE ) - aligned_df <- qc_result$target_data_qced + aligned_df <- getHarmonizedData(qc_result) # Format output using reference convention (preserving user's format automatically) aligned_variants <- format_variant_id( diff --git a/R/colocboost_pipeline.R b/R/colocboost_pipeline.R index 649de423..ef721501 100644 --- a/R/colocboost_pipeline.R +++ b/R/colocboost_pipeline.R @@ -66,15 +66,18 @@ region_data_to_colocboost_input <- function(region_data) { ind_records <- ind_records_from_input(ind_input) ind_args <- .cb_format_individual(ind_records) - # Build sumstat_records from rss_input which already contains LDData S4 - # objects (region_data_to_rss_input converts any legacy lists). When the - # LDData carries genotypes, pass them through so build_ld_args routes them - # to X_ref; otherwise pass the correlation matrix as LD. + # Wrap each (rss_input, LD_data) pair as a QCResult (with no QC applied) + # so .cb_format_sumstat consumes a uniform shape regardless of whether the + # records came from summary_stats_qc or directly from region_data. sumstat_records <- lapply(names(rss_input$rss_input), function(study) { - ld_data <- rss_input$LD_data[[study]] - ld_mat <- if (hasGenotypes(ld_data)) getGenotypes(ld_data) else getCorrelation(ld_data) - list(rss_input = rss_input$rss_input[[study]], - LD_matrix = ld_mat) + QCResult( + ld_data = rss_input$LD_data[[study]], + rss_input = rss_input$rss_input[[study]], + preprocess = list(), + outlier_number = 0L, + skipped = FALSE, + skip_reason = "" + ) }) names(sumstat_records) <- names(rss_input$rss_input) sumstat_args <- .cb_format_sumstat(sumstat_records) @@ -456,7 +459,8 @@ colocboost_pipeline <- function( # Extract individual contexts if (!is.null(individual_data)) { if (is.null(phenotypes_init)) { - phenotypes$individual_contexts <- names(individual_data$residual_Y) + # Pre-QC: individual_data is a RegionalData (S4) + phenotypes$individual_contexts <- names(getPhenotypes(individual_data)) } else { null_Y <- which(sapply(individual_data$Y, is.null)) if (length(null_Y) == 0) { @@ -536,18 +540,46 @@ colocboost_pipeline <- function( ####### ========= Filtering events before QC =========== ######### if (!is.null(event_filters) & !is.null(region_data$individual_data)) { - Y <- region_data$individual_data$residual_Y - Y <- lapply(seq_along(Y), function(i) { - y <- Y[[i]] + ind_data <- region_data$individual_data + Y_list <- getPhenotypes(ind_data) + Y_names <- names(Y_list) + Y_filtered <- lapply(seq_along(Y_list), function(i) { + y <- Y_list[[i]] events <- colnames(y) - condition <- names(Y)[i] + condition <- Y_names[i] filtered_events <- filter_events(events, event_filters, condition) if (is.null(filtered_events)) { return(NULL) } y[, filtered_events, drop = FALSE] - }) %>% setNames(names(region_data$individual_data$residual_Y)) - region_data$individual_data$residual_Y <- Y + }) %>% setNames(Y_names) + # Drop conditions whose events were entirely filtered out so the + # RegionalData validity is preserved; ones to drop are remembered for + # downstream QC messaging via a synthetic NULL-Y list. + keep_cond <- !vapply(Y_filtered, is.null, logical(1)) + if (!any(keep_cond)) { + region_data$individual_data <- NULL + } else { + Y_clean <- Y_filtered[keep_cond] + # Attach a record of dropped conditions for extract_contexts_studies() + # to surface the post-QC "Skipping follow-up analysis" message. + dropped_names <- names(Y_filtered)[!keep_cond] + maf_list <- ind_data@maf + Y_coords <- ind_data@Y_coordinates + region_data$individual_data <- RegionalData( + genotype_matrix = getGenotypeMatrix(ind_data), + phenotypes = Y_clean, + covariates = getCovariates(ind_data)[keep_cond], + scale_residuals = ind_data@scale_residuals, + maf = if (length(maf_list) == length(keep_cond)) maf_list[keep_cond] else maf_list, + region = ind_data@region, + dropped_samples = ind_data@dropped_samples, + Y_coordinates = if (!is.null(Y_coords)) Y_coords[keep_cond] else NULL + ) + if (length(dropped_names) > 0) { + attr(region_data$individual_data, "filtered_out_contexts") <- dropped_names + } + } } ####### ========= QC for the region_data ======== ######## @@ -674,25 +706,25 @@ qc_regional_data <- function(region_data, } qced_sumstat_to_region_data <- function(sumstat_qc) { if (is.null(sumstat_qc) || length(sumstat_qc) == 0) return(NULL) - if (!is.null(sumstat_qc$rss_input) && !is.null(sumstat_qc$LD_matrix)) { + if (is(sumstat_qc, "QCResult")) { sumstat_qc <- list(study1 = sumstat_qc) } - sumstats <- lapply(sumstat_qc, `[[`, "rss_input") - LD_mat <- list() + sumstats <- lapply(sumstat_qc, getRSSInput) + LD_data <- list() LD_match <- character() ld_variant_index <- list() for (study in names(sumstat_qc)) { - ld <- sumstat_qc[[study]]$LD_matrix - variant_key <- paste(colnames(ld), collapse = ",") + ld_obj <- getLDData(sumstat_qc[[study]]) + variant_key <- paste(if (is.null(ld_obj)) "" else getVariantIds(ld_obj), collapse = ",") if (variant_key %in% names(ld_variant_index)) { LD_match <- c(LD_match, ld_variant_index[[variant_key]]) } else { - LD_mat[[study]] <- ld + LD_data[[study]] <- ld_obj ld_variant_index[[variant_key]] <- study LD_match <- c(LD_match, study) } } - list(sumstats = sumstats, LD_mat = LD_mat, LD_match = LD_match) + list(sumstats = sumstats, LD_data = LD_data, LD_match = LD_match) } individual_data <- NULL @@ -707,6 +739,25 @@ qc_regional_data <- function(region_data, pip_cutoff_to_skip = pip_cutoff_to_skip_ind ) individual_data <- qced_individual_to_region_data(ind_qc) + # If event_filters dropped any pre-QC contexts entirely, surface them as + # NULL-Y entries so downstream extract_contexts_studies() emits the + # "Skipping follow-up analysis for individual traits ..." message. + dropped_ctx <- attr(region_data$individual_data, "filtered_out_contexts") + if (!is.null(dropped_ctx) && length(dropped_ctx) > 0 && !is.null(individual_data)) { + for (ctx in dropped_ctx) { + individual_data$X[[ctx]] <- NULL + individual_data$Y[[ctx]] <- list(NULL)[[1]] + } + # NULL inserts via [[ removed entries; re-insert as explicit NULL. + for (ctx in dropped_ctx) { + if (!ctx %in% names(individual_data$Y)) { + individual_data$Y <- c(individual_data$Y, stats::setNames(list(NULL), ctx)) + } + if (!ctx %in% names(individual_data$X)) { + individual_data$X <- c(individual_data$X, stats::setNames(list(NULL), ctx)) + } + } + } } sumstat_data <- NULL @@ -1296,21 +1347,25 @@ qc_individual_data <- function(X, Y, maf = NULL, X_variance = NULL, dict_sumstatLD = dict_sumstatLD) } if (length(sumstat_qc) == 0) return(list()) - if (!is.null(sumstat_qc$rss_input) && !is.null(sumstat_qc$LD_matrix)) { + if (is(sumstat_qc, "QCResult")) { sumstat_qc <- list(sumstat = sumstat_qc) } sumstat <- lapply(sumstat_qc, function(x) { - ss <- x$rss_input$sumstats + ss <- getRSSInput(x)$sumstats variant_id <- if ("variant_id" %in% colnames(ss)) { ss$variant_id } else { format_variant_id(ss$chrom, ss$pos, ss$A2, ss$A1) } - data.frame(z = ss$z, n = x$rss_input$n, + data.frame(z = ss$z, n = getRSSInput(x)$n, variant = normalize_variant_id(variant_id), stringsAsFactors = FALSE) }) - LD_mat <- lapply(sumstat_qc, `[[`, "LD_matrix") + LD_mat <- lapply(sumstat_qc, function(x) { + ld <- getLDData(x) + if (is.null(ld)) return(NULL) + if (hasGenotypes(ld)) getGenotypes(ld) else getCorrelation(ld) + }) filtered <- filter_valid_sumstats(sumstat, LD_mat) if (is.null(filtered)) return(list()) c( diff --git a/R/encoloc.R b/R/encoloc.R index dbd93cb1..4658db64 100644 --- a/R/encoloc.R +++ b/R/encoloc.R @@ -260,13 +260,9 @@ process_coloc_results <- function(coloc_result, LD_meta_file_path, analysis_regi method_result <- pipeline_result[[method_names[1]]] fm_result <- method_result$finemapping_result - if (!is.null(fm_result) && is(fm_result, "FineMappingResult")) { - fm_data <- getTrimmedFit(fm_result) - variant_names <- getVariantNames(fm_result) - } else { - fm_data <- method_result$susie_result_trimmed - variant_names <- method_result$variant_names - } + if (is.null(fm_result) || !is(fm_result, "FineMappingResult")) return(NULL) + fm_data <- getTrimmedFit(fm_result) + variant_names <- getVariantNames(fm_result) if (is.null(fm_data) || is.null(fm_data$lbf_variable)) return(NULL) lbf_matrix <- as.data.frame(fm_data$lbf_variable) @@ -296,17 +292,11 @@ process_coloc_results <- function(coloc_result, LD_meta_file_path, analysis_regi if (length(method_names) == 0) return(invisible(NULL)) method_result <- pipeline_result[[method_names[1]]] fm_result <- method_result$finemapping_result - if (!is.null(fm_result) && is(fm_result, "FineMappingResult")) { - save_data <- list( - susie_fit = getTrimmedFit(fm_result), - variant_names = getVariantNames(fm_result) - ) - } else { - save_data <- list( - susie_fit = method_result$susie_result_trimmed, - variant_names = method_result$variant_names - ) - } + if (is.null(fm_result) || !is(fm_result, "FineMappingResult")) return(invisible(NULL)) + save_data <- list( + susie_fit = getTrimmedFit(fm_result), + variant_names = getVariantNames(fm_result) + ) saveRDS(list(save_data), save_path) message("Fine-mapping result saved to: ", save_path, "\n Reuse with: gwas_files = '", save_path, diff --git a/R/file_utils.R b/R/file_utils.R index 3e01b005..f96d76e1 100644 --- a/R/file_utils.R +++ b/R/file_utils.R @@ -825,22 +825,12 @@ add_Y_residuals <- function(data_list, conditions, scale_residuals = FALSE) { #' @param scale_residuals Logical indicating whether to scale residuals. Default is FALSE. #' @param tabix_header Logical indicating whether the tabix file has a header. Default is TRUE. #' -#' @return A list containing the following components: -#' \itemize{ -#' \item residual_Y: A list of residualized phenotype values (either a vector or a matrix). -#' \item residual_X: A list of residualized genotype matrices for each condition. -#' \item residual_Y_scalar: Scaling factor for residualized phenotype values. -#' \item residual_X_scalar: Scaling factor for residualized genotype values. -#' \item dropped_sample: A list of dropped samples for X, Y, and covariates. -#' \item covar: Covariate data. -#' \item Y: Original phenotype data. -#' \item X_data: Original genotype data. -#' \item X: Filtered genotype matrix. -#' \item maf: Minor allele frequency (MAF) for each variant. -#' \item chrom: Chromosome of the region. -#' \item grange: Genomic range of the region (start and end positions). -#' \item Y_coordinates: Phenotype coordinates if a region is specified. -#' } +#' @return A \code{RegionalData} S4 object. Per-condition residualized +#' phenotypes, residualized genotypes, and their scaling factors are +#' computed on demand via accessors (\code{getResidualX()}, +#' \code{getResidualY()}, \code{getResidualXScalar()}, +#' \code{getResidualYScalar()}, \code{getXVariance()}). Region metadata is +#' available via \code{getChrom()} and \code{getGrange()}. #' #' @export load_regional_association_data <- function(genotype, # PLINK file @@ -937,60 +927,31 @@ load_regional_association_data <- function(genotype, # PLINK file #' Load Regional Univariate Association Data #' -#' This function loads regional association data for univariate analysis. -#' It includes residual matrices, original genotype data, and additional metadata. +#' Loads regional association data for univariate analysis. Returns a +#' \code{RegionalData} S4 object; derived quantities (residuals, scalars, +#' per-variant variance) are computed lazily via accessors +#' (\code{getResidualX}, \code{getResidualY}, \code{getResidualXScalar}, +#' \code{getResidualYScalar}, \code{getXVariance}, \code{getChrom}, +#' \code{getGrange}). #' -#' @importFrom matrixStats colVars -#' @return A list +#' @return A \code{RegionalData} object. #' @export load_regional_univariate_data <- function(...) { - dat <- load_regional_association_data(...) - n_cond <- length(dat@phenotypes) - residual_X <- lapply(seq_len(n_cond), function(i) getResidualX(dat, i)) - residual_Y <- lapply(seq_len(n_cond), function(i) getResidualY(dat, i)) - names(residual_X) <- names(dat@phenotypes) - names(residual_Y) <- names(dat@phenotypes) - residual_X_scalar <- lapply(seq_len(n_cond), function(i) getResidualXScalar(dat, i)) - residual_Y_scalar <- lapply(seq_len(n_cond), function(i) getResidualYScalar(dat, i)) - region_gr <- dat@region - return(list( - residual_Y = residual_Y, - residual_X = residual_X, - residual_Y_scalar = residual_Y_scalar, - residual_X_scalar = residual_X_scalar, - dropped_sample = dat@dropped_samples, - maf = dat@maf, - X = dat@genotype_matrix, - chrom = if (!is.null(region_gr)) as.character(seqnames(region_gr))[1] else NULL, - grange = if (!is.null(region_gr)) as.character(c(start(region_gr), end(region_gr))) else NULL, - X_variance = lapply(residual_X, function(x) colVars(x)) - )) + load_regional_association_data(...) } #' Load Regional Data for Regression Modeling #' -#' This function loads regional association data formatted for regression modeling. -#' It includes phenotype, genotype, and covariate matrices along with metadata. +#' Loads regional association data formatted for regression modeling. +#' Returns a \code{RegionalData} S4 object; the per-condition \code{X_data} +#' previously returned in a list is available as +#' \code{getResidualX(rd, i)} (residualized) or by subsetting +#' \code{rd@@genotype_matrix} by condition rownames. #' -#' @return A list +#' @return A \code{RegionalData} object. #' @export load_regional_regression_data <- function(...) { - dat <- load_regional_association_data(...) - region_gr <- dat@region - # Build per-condition X_data by subsetting genotype_matrix to each condition's samples - X_data <- lapply(dat@phenotypes, function(Y_cond) { - common <- intersect(rownames(dat@genotype_matrix), rownames(Y_cond)) - dat@genotype_matrix[common, , drop = FALSE] - }) - return(list( - Y = dat@phenotypes, - X_data = X_data, - covar = dat@covariates, - dropped_sample = dat@dropped_samples, - maf = dat@maf, - chrom = if (!is.null(region_gr)) as.character(seqnames(region_gr))[1] else NULL, - grange = if (!is.null(region_gr)) as.character(c(start(region_gr), end(region_gr))) else NULL - )) + load_regional_association_data(...) } # return matrix of R conditions, with column names being the names of the conditions (phenotypes) and row names being sample names. Even for one condition it has to be a matrix with just one column. @@ -1015,16 +976,19 @@ pheno_list_to_mat <- function(data_list) { #' Load and Preprocess Regional Multivariate Data #' -#' This function loads regional association data and processes it into a multivariate format. -#' It optionally filters out samples based on missingness thresholds in the response matrix. +#' Loads regional association data and packages it for multivariate modeling. +#' Phenotypes across conditions are joined into a single multivariate matrix +#' (samples x conditions). When \code{matrix_y_min_complete} is supplied, +#' samples with fewer than that many non-missing condition values are dropped. +#' Per-variant MAF and variance are computed on the (post-filter) genotype +#' matrix and exposed via \code{getMAF()} / \code{getXVariance()} on the +#' returned object. #' -#' @importFrom matrixStats colVars -#' @return A list +#' @return A \code{MultivariateRegionalData} object. #' @export -load_regional_multivariate_data <- function(matrix_y_min_complete = NULL, # when Y is saved as matrix, remove those with non-missing counts less than this cutoff +load_regional_multivariate_data <- function(matrix_y_min_complete = NULL, ...) { rd <- load_regional_association_data(...) - # Compute residuals for all conditions and combine into univariate-style list n_cond <- length(rd@phenotypes) residual_Y_list <- lapply(seq_len(n_cond), function(i) getResidualY(rd, i)) names(residual_Y_list) <- names(rd@phenotypes) @@ -1033,83 +997,65 @@ load_regional_multivariate_data <- function(matrix_y_min_complete = NULL, # when dat <- pheno_list_to_mat(dat) X <- rd@genotype_matrix - Y_scalar <- residual_Y_scalar_list + Y_scalar <- unlist(residual_Y_scalar_list) dropped_sample <- rd@dropped_samples region_gr <- rd@region + Y <- dat$residual_Y if (!is.null(matrix_y_min_complete)) { - Y <- filter_Y(dat$residual_Y, matrix_y_min_complete) - if (length(Y$rm_rows) > 0) { - X <- X[-Y$rm_rows, ] - Y_scalar <- unlist(Y_scalar)[-Y$rm_rows] - dropped_sample <- rownames(dat$residual_Y)[Y$rm_rows] - } else { - Y <- dat$residual_Y - Y_scalar <- unlist(Y_scalar) + filt <- filter_Y(Y, matrix_y_min_complete) + if (length(filt$rm_rows) > 0) { + X <- X[-filt$rm_rows, , drop = FALSE] + Y <- filt$Y + dropped_sample <- rownames(dat$residual_Y)[filt$rm_rows] } - } else { - Y <- dat$residual_Y - Y_scalar <- unlist(Y_scalar) } - return(list( - residual_Y = Y, - residual_Y_scalar = Y_scalar, - dropped_sample = dropped_sample, - X = X, - maf = apply(X, 2, compute_maf), - chrom = if (!is.null(region_gr)) as.character(seqnames(region_gr))[1] else NULL, - grange = if (!is.null(region_gr)) as.character(c(start(region_gr), end(region_gr))) else NULL, - X_variance = colVars(X) - )) + + MultivariateRegionalData( + genotype_matrix = X, + Y_matrix = as.matrix(Y), + Y_scalar = Y_scalar, + dropped_samples = dropped_sample, + region = region_gr, + Y_coordinates = rd@Y_coordinates + ) } #' Load Regional Functional Association Data #' -#' This function loads precomputed regional functional association data. +#' Loads precomputed regional functional association data. Returns a +#' \code{RegionalData} S4 object; derived quantities are computed lazily +#' via accessors. When \code{min_markers} is supplied, conditions whose +#' \code{Y_coordinates} have fewer than \code{min_markers} rows are +#' dropped from the returned \code{RegionalData}. #' #' @param min_markers Minimum number of phenotype markers required for a study. #' If \code{NULL}, no marker-count filtering is applied. -#' @return A list +#' @return A \code{RegionalData} object. #' @export load_regional_functional_data <- function(..., min_markers = NULL) { rd <- load_regional_association_data(...) - n_cond <- length(rd@phenotypes) - residual_Y <- lapply(seq_len(n_cond), function(i) getResidualY(rd, i)) - residual_X <- lapply(seq_len(n_cond), function(i) getResidualX(rd, i)) - residual_Y_scalar <- lapply(seq_len(n_cond), function(i) getResidualYScalar(rd, i)) - residual_X_scalar <- lapply(seq_len(n_cond), function(i) getResidualXScalar(rd, i)) - names(residual_Y) <- names(residual_X) <- names(rd@phenotypes) - names(residual_Y_scalar) <- names(residual_X_scalar) <- names(rd@phenotypes) - region_gr <- rd@region - dat <- list( - residual_Y = residual_Y, - residual_X = residual_X, - residual_Y_scalar = residual_Y_scalar, - residual_X_scalar = residual_X_scalar, - dropped_sample = rd@dropped_samples, - covar = rd@covariates, - Y = rd@phenotypes, - X = rd@genotype_matrix, - maf = rd@maf, - chrom = if (!is.null(region_gr)) as.character(seqnames(region_gr))[1] else NULL, - grange = if (!is.null(region_gr)) as.character(c(start(region_gr), end(region_gr))) else NULL, - Y_coordinates = rd@Y_coordinates - ) - if (!is.null(min_markers)) { - dat <- .filter_functional_data_by_marker_count(dat, min_markers) - } - dat + if (!is.null(min_markers)) rd <- .filter_regional_data_by_marker_count(rd, min_markers) + rd } -.filter_functional_data_by_marker_count <- function(fdat, min_markers, - always_keep = c("dropped_sample", "dropped_samples", "X", "chrom", "grange")) { - if (is.null(fdat$Y_coordinates)) return(fdat) - keep <- vapply(fdat$Y_coordinates, function(x) nrow(x) >= min_markers, logical(1)) - filter_names <- setdiff(names(fdat), always_keep) - fdat[filter_names] <- lapply(fdat[filter_names], function(x) { - if (length(x) == length(keep)) x[keep] else x - }) - fdat +# Subset per-condition slots of a RegionalData by the marker counts in +# Y_coordinates. The genotype_matrix, region, and dropped_samples are +# preserved (those are not per-condition or are panel-wide). +.filter_regional_data_by_marker_count <- function(rd, min_markers) { + if (is.null(rd@Y_coordinates)) return(rd) + keep <- vapply(rd@Y_coordinates, function(x) nrow(x) >= min_markers, logical(1)) + if (all(keep)) return(rd) + RegionalData( + genotype_matrix = rd@genotype_matrix, + phenotypes = rd@phenotypes[keep], + covariates = rd@covariates[keep], + scale_residuals = rd@scale_residuals, + maf = if (length(rd@maf) == length(keep)) rd@maf[keep] else rd@maf, + region = rd@region, + dropped_samples = rd@dropped_samples, + Y_coordinates = rd@Y_coordinates[keep] + ) } @@ -1566,7 +1512,7 @@ load_rss_data <- function(sumstat_path, column_file_path = NULL, n_sample = 0, n #' sumstat_data contains the following components if exist #' \itemize{ #' \item sumstats: A list of summary statistics for the matched LD_info, each sublist contains sumstats, n, var_y from \code{load_rss_data}. -#' \item LD_info: A list of LD information, each sublist contains LD_variants, LD_matrix, ref_panel \code{load_LD_matrix}. +#' \item LD_info: A list of \code{LDData} S4 objects (one per LD reference), as returned by \code{load_LD_matrix}. #' } #' #' @export @@ -1665,11 +1611,7 @@ load_multitask_regional_data <- function(region, # a string of chr:start-end for if (is.null(individual_data)) { individual_data <- dat } else { - individual_data <- stats::setNames(lapply(names(dat), function(k) { - c(individual_data[[k]], dat[[k]]) - }), names(dat)) - individual_data$chrom <- dat$chrom - individual_data$grange <- dat$grange + individual_data <- c(individual_data, dat) } } } @@ -1792,40 +1734,57 @@ region_data_to_ind_input <- function(region_data) { source_info = list(has_individual = FALSE, contexts = character()))) } - X <- first_non_null(individual_data$residual_X, individual_data$X) - Y <- first_non_null(individual_data$residual_Y, individual_data$Y) - if (is.list(X) && !is.matrix(X) && !is.data.frame(X) && - is.null(names(X)) && !is.null(names(Y)) && length(X) == length(Y)) { - names(X) <- names(Y) - } - if (is.list(Y) && !is.matrix(Y) && !is.data.frame(Y) && - is.null(names(Y)) && !is.null(names(X)) && length(Y) == length(X)) { - names(Y) <- names(X) - } - if (is.matrix(X) && is.list(Y) && !is.null(names(Y))) { - X <- stats::setNames(rep(list(X), length(Y)), names(Y)) - } - aligned <- align_individual_contexts(X, Y) - X <- aligned$X - Y <- aligned$Y - maf <- individual_data$maf - X_variance <- individual_data$X_variance - if (is.list(maf) && is.null(names(maf)) && !is.null(names(X)) && length(maf) == length(X)) { - names(maf) <- names(X) + if (is(individual_data, "RegionalData")) { + contexts <- names(individual_data@phenotypes) + n_cond <- length(contexts) + X <- stats::setNames( + lapply(seq_len(n_cond), function(i) getResidualX(individual_data, i)), + contexts + ) + Y <- stats::setNames( + lapply(seq_len(n_cond), function(i) getResidualY(individual_data, i)), + contexts + ) + aligned <- align_individual_contexts(X, Y) + X <- aligned$X + Y <- aligned$Y + maf <- individual_data@maf + X_variance <- stats::setNames( + lapply(seq_len(n_cond), function(i) getXVariance(individual_data, i)), + contexts + ) + return(list( + X = X, + Y = Y, + maf = maf, + X_variance = X_variance, + source_info = list(has_individual = !is.null(X) && !is.null(Y), + contexts = contexts) + )) } - if (is.list(X_variance) && is.null(names(X_variance)) && !is.null(names(X)) && - length(X_variance) == length(X)) { - names(X_variance) <- names(X) + + # Post-QC shape: list(X = list_of_matrices, Y = list_of_matrices, ...) + if (is.list(individual_data) && + (!is.null(individual_data$X) || !is.null(individual_data$Y))) { + X <- individual_data$X + Y <- individual_data$Y + aligned <- align_individual_contexts(X, Y) + X <- aligned$X + Y <- aligned$Y + maf <- individual_data$maf + X_variance <- individual_data$X_variance + contexts <- if (!is.null(X) && is.list(X) && !is.matrix(X)) names(X) else character() + return(list( + X = X, + Y = Y, + maf = maf, + X_variance = X_variance, + source_info = list(has_individual = !is.null(X) && !is.null(Y), + contexts = contexts) + )) } - contexts <- unique(c(names(X), names(Y))) - list( - X = X, - Y = Y, - maf = maf, - X_variance = X_variance, - source_info = list(has_individual = !is.null(X) && !is.null(Y), - contexts = contexts) - ) + + stop("region_data$individual_data must be a RegionalData object or a post-QC list with X/Y entries") } #' Convert loaded regional data to RSS inputs @@ -1834,102 +1793,10 @@ region_data_to_ind_input <- function(region_data) { #' @return A list containing named RSS inputs, matched LD data, and source #' information. #' @export -.legacy_list_to_LDData <- function(ld_list) { - # Convert legacy LD list to LDData S4 object - if (is(ld_list, "LDData")) return(ld_list) - ld_mat <- ld_list$LD_matrix - ref_panel <- ld_list$ref_panel - ld_variants <- ld_list$LD_variants - # Build ref_panel if missing - if (is.null(ref_panel) && !is.null(ld_variants)) { - if (is.data.frame(ld_variants)) { - ref_panel <- ld_variants - } else { - parsed <- tryCatch(parse_variant_id(ld_variants), error = function(e) NULL) - if (!is.null(parsed)) { - ref_panel <- parsed - ref_panel$variant_id <- ld_variants - } - } - } - if (is.null(ref_panel)) return(ld_list) # cannot convert - if (is.data.frame(ref_panel) && !"variant_id" %in% names(ref_panel)) { - ref_panel$variant_id <- format_variant_id(ref_panel$chrom, ref_panel$pos, ref_panel$A2, ref_panel$A1) - } - if (!"chrom" %in% names(ref_panel)) ref_panel$chrom <- "1" - ref_panel$chrom <- as.character(ref_panel$chrom) - variants_gr <- .ref_panel_to_granges(ref_panel) - is_genotype <- isTRUE(ld_list$is_genotype) || (is.matrix(ld_mat) && nrow(ld_mat) != ncol(ld_mat)) - corr <- if (is_genotype) cor(ld_mat) else ld_mat - bm <- ld_list$block_metadata - if (is.null(bm)) bm <- .infer_single_ld_block_metadata(ref_panel) - if (is.null(bm)) bm <- data.frame() - LDData( - correlation = corr, - variants = variants_gr, - block_metadata = bm - ) -} - region_data_to_rss_input <- function(region_data) { - make_ld_data_from_matrix <- function(ld, variant_ids = NULL) { - is_genotype <- is.matrix(ld) && nrow(ld) != ncol(ld) - if (!is.null(variant_ids) && is.matrix(ld)) { - if (is.null(colnames(ld)) && length(variant_ids) == ncol(ld)) { - colnames(ld) <- variant_ids - } - if (!is_genotype && is.null(rownames(ld)) && length(variant_ids) == nrow(ld)) { - rownames(ld) <- variant_ids - } - } - ids <- if (is.matrix(ld) && !is_genotype) rownames(ld) else colnames(ld) - parsed <- NULL - if (!is.null(ids) && length(ids) > 0) { - parsed <- tryCatch(parse_variant_id(ids), error = function(e) NULL) - if (!is.null(parsed)) { - ids <- format_variant_id(parsed$chrom, parsed$pos, parsed$A2, parsed$A1) - if (!is_genotype && is.matrix(ld)) rownames(ld) <- colnames(ld) <- ids - if (is_genotype && is.matrix(ld)) colnames(ld) <- ids - parsed$variant_id <- ids - } - } - # Build LDData S4 object - if (!is.null(parsed)) { - ref_panel_df <- parsed - ref_panel_df$chrom <- as.character(ref_panel_df$chrom) - variants_gr <- .ref_panel_to_granges(ref_panel_df) - corr <- if (is_genotype) cor(ld) else ld - bm <- .infer_single_ld_block_metadata(ref_panel_df) - LDData( - correlation = corr, - variants = variants_gr, - block_metadata = bm - ) - } else { - # Cannot parse variant IDs — fall back to legacy list - list( - LD_matrix = ld, - LD_variants = ids, - ref_panel = parsed, - block_metadata = NULL, - is_genotype = isTRUE(is_genotype) - ) - } - } - rss_input_from_qced_sumstat <- function(sumstat_data) { - variant_ids_from_rss <- function(rss) { - ss <- rss$sumstats - if (is.null(ss)) return(character()) - if ("variant_id" %in% colnames(ss)) return(normalize_variant_id(as.character(ss$variant_id))) - if (all(c("chrom", "pos", "A2", "A1") %in% colnames(ss))) { - return(format_variant_id(ss$chrom, ss$pos, ss$A2, ss$A1)) - } - character() - } - rss_input <- sumstat_data$sumstats - LD_mat <- sumstat_data$LD_mat + LD_data_in <- sumstat_data$LD_data LD_match <- sumstat_data$LD_match studies <- names(rss_input) LD_data <- list() @@ -1937,13 +1804,14 @@ region_data_to_rss_input <- function(region_data) { for (i in seq_along(studies)) { study <- studies[[i]] ld_name <- if (!is.null(LD_match) && length(LD_match) >= i) LD_match[[i]] else study - if (is.null(ld_name) || is.na(ld_name) || !ld_name %in% names(LD_mat)) { - ld_name <- names(LD_mat)[min(i, length(LD_mat))] + if (is.null(ld_name) || is.na(ld_name) || !ld_name %in% names(LD_data_in)) { + ld_name <- names(LD_data_in)[min(i, length(LD_data_in))] + } + ld <- LD_data_in[[ld_name]] + if (!is.null(ld) && !is(ld, "LDData")) { + stop("region_data$sumstat_data$LD_data entries must be LDData objects.") } - ld <- LD_mat[[ld_name]] - rss <- rss_input[[study]] - variant_ids <- variant_ids_from_rss(rss) - LD_data[[study]] <- make_ld_data_from_matrix(ld, variant_ids) + LD_data[[study]] <- ld ld_group[[study]] <- ld_name } list( @@ -1961,7 +1829,7 @@ region_data_to_rss_input <- function(region_data) { source_info = list(has_sumstat = FALSE, studies = character(), ld_group = character()))) } - if (!is.null(sumstat_data$LD_mat)) { + if (!is.null(sumstat_data$LD_data)) { return(rss_input_from_qced_sumstat(sumstat_data)) } @@ -1976,16 +1844,16 @@ region_data_to_rss_input <- function(region_data) { if (is.null(group_name) || is.na(group_name) || group_name == "") { group_name <- paste0("LD", ld_index) } + ld_entry <- sumstat_data$LD_info[[ld_index]] + if (!is.null(ld_entry) && !is(ld_entry, "LDData")) { + stop("region_data$sumstat_data$LD_info entries must be LDData objects.") + } for (study in names(studies)) { output_name <- study if (output_name %in% names(rss_input)) { output_name <- make.unique(c(names(rss_input), output_name))[length(rss_input) + 1] } rss_input[[output_name]] <- studies[[study]] - ld_entry <- sumstat_data$LD_info[[ld_index]] - if (!is(ld_entry, "LDData") && is.list(ld_entry)) { - ld_entry <- .legacy_list_to_LDData(ld_entry) - } LD_data[[output_name]] <- ld_entry ld_group[[output_name]] <- group_name } diff --git a/R/mash_wrapper.R b/R/mash_wrapper.R index c7801089..a3dc4b08 100644 --- a/R/mash_wrapper.R +++ b/R/mash_wrapper.R @@ -776,12 +776,12 @@ merge_sumstats_matrices <- function(matrix_list, value_column, ref_panel = NULL, ld_bim_file <- vroom(bim_file_path) # Perform allele quality control - flipped_data <- match_ref_panel(data, ld_bim_file$V2, + flipped_data <- getHarmonizedData(match_ref_panel(data, ld_bim_file$V2, col_to_flip = c(value_column), match_min_prop = 0, remove_dups = FALSE, remove_indels = FALSE, remove_strand_ambiguous = FALSE, flip_strand = FALSE, remove_unmatched = TRUE - )$target_data_qced + )) return(flipped_data) } @@ -799,10 +799,10 @@ merge_sumstats_matrices <- function(matrix_list, value_column, ref_panel = NULL, # Step 3: Combine extracted chromosomal info with value column cohort_df <- cbind(cohort_variants_df, value = df2[, value_column, drop = FALSE]) - flipped_data <- match_ref_panel(cohort_df, ref_panel, col_to_flip = c(value_column), + flipped_data <- getHarmonizedData(match_ref_panel(cohort_df, ref_panel, col_to_flip = c(value_column), match_min_prop = 0, remove_dups = FALSE, remove_indels = FALSE, remove_strand_ambiguous = FALSE, - flip_strand = FALSE, remove_unmatched = TRUE, remove_same_vars = FALSE)$target_data_qced + flip_strand = FALSE, remove_unmatched = TRUE, remove_same_vars = FALSE)) final_df <- flipped_data %>% select(c("variant_id", value_column)) @@ -1100,12 +1100,10 @@ extract_flatten_sumstats_from_nested <- function(data, extract_inf = "z", max_de } if (is.list(element)) { - # Extract variant_names from FineMappingResult S4 or legacy list key has_fm <- !is.null(element$finemapping_result) && is(element$finemapping_result, "FineMappingResult") - has_legacy_vn <- "variant_names" %in% names(element) has_sumstats <- "sumstats" %in% names(element) - if (has_sumstats && (has_fm || has_legacy_vn)) { - variant_names <- if (has_fm) getVariantNames(element$finemapping_result) else element$variant_names + if (has_sumstats && has_fm) { + variant_names <- getVariantNames(element$finemapping_result) sumstats <- element$sumstats # Extract based on type diff --git a/R/mr.R b/R/mr.R index 233ad30b..350a00f2 100644 --- a/R/mr.R +++ b/R/mr.R @@ -95,7 +95,7 @@ mr_format <- function(susie_result, condition, gwas_sumstats_db, coverage = NULL gwas_sumstats_db_extracted$variant_id, c("bhat_x"), match_min_prop = 0 ) - susie_cs_result_formatted <- susie_cs_result_formatted$target_data_qced[, c("gene_name", "variant_id", "bhat_x", "sbhat_x", "cs", "pip")] + susie_cs_result_formatted <- getHarmonizedData(susie_cs_result_formatted)[, c("gene_name", "variant_id", "bhat_x", "sbhat_x", "cs", "pip")] } # Ensure consistent chr prefix convention before intersecting if (nrow(susie_cs_result_formatted) == 0) return(.create_null_mr_df(gene_name, mr_format_spec)) diff --git a/R/sumstats_qc.R b/R/sumstats_qc.R index 9fe0aac3..9974d491 100644 --- a/R/sumstats_qc.R +++ b/R/sumstats_qc.R @@ -5,18 +5,17 @@ #' specified regions from the analysis. #' #' @param sumstats A data frame containing summary statistics with columns "chrom", "pos", "A1", and "A2". -#' @param LD_data An \code{LDData} S4 object or a legacy list containing combined LD variants data, -#' as generated by \code{load_LD_matrix}. +#' @param LD_data An \code{LDData} S4 object containing combined LD variants +#' data, as generated by \code{load_LD_matrix}. #' @param skip_region A character vector specifying regions to be skipped in the analysis (optional). #' Each region should be in the format "chrom:start-end" (e.g., "1:1000000-2000000"). -#' @param return_LD_mat Logical; if \code{FALSE}, return only harmonized -#' summary statistics and skip LD-matrix subsetting. This is useful when the -#' reference input is genotype-backed \code{X_ref}. Defaults to \code{TRUE} -#' for backwards compatibility. +#' @param return_LD_mat Logical; if \code{FALSE}, the returned \code{QCResult} +#' carries \code{NULL} in its \code{ld_data} slot (no LD subsetting is +#' performed). Useful when the reference input is genotype-backed. #' -#' @return A list containing the processed summary statistics and LD matrix. -#' - sumstats: A data frame containing the processed summary statistics. -#' - LD_mat: The processed LD matrix. +#' @return A \code{QCResult} S4 object. Use \code{getRSSInput()$sumstats} to +#' recover the harmonized sumstats and \code{getCorrelation(getLDData(qc))} +#' to recover the aligned LD matrix (or NULL when \code{return_LD_mat=FALSE}). #' #' @importFrom dplyr filter pull arrange #' @importFrom tibble tibble @@ -54,6 +53,7 @@ rss_basic_qc <- function(sumstats, LD_data, skip_region = NULL, keep_indel = TRU remove_strand_ambiguous = TRUE ) + qced <- getHarmonizedData(allele_flip) if (!is.null(skip_region)) { skip_table <- tibble(region = skip_region) %>% separate(region, into = c("chrom", "start", "end"), sep = "[-:]") %>% @@ -61,7 +61,7 @@ rss_basic_qc <- function(sumstats, LD_data, skip_region = NULL, keep_indel = TRU skip_variant <- c() for (region_index in 1:nrow(skip_table)) { - variant <- allele_flip$target_data_qced %>% + variant <- qced %>% filter(chrom == skip_table$chrom[region_index] & pos > skip_table$start[region_index] & pos < skip_table$end[region_index]) %>% @@ -69,14 +69,20 @@ rss_basic_qc <- function(sumstats, LD_data, skip_region = NULL, keep_indel = TRU skip_variant <- c(skip_variant, variant) } - allele_flip$target_data_qced <- allele_flip$target_data_qced %>% - filter(!(variant_id %in% skip_variant)) + qced <- qced %>% filter(!(variant_id %in% skip_variant)) } - sumstats_processed <- allele_flip$target_data_qced %>% arrange(pos) + sumstats_processed <- qced %>% arrange(pos) + n_ref <- LD_data@n_ref if (!isTRUE(return_LD_mat)) { - return(list(sumstats = sumstats_processed, LD_mat = NULL)) + return(QCResult( + ld_data = NULL, + rss_input = list(sumstats = sumstats_processed, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + )) } # Align and subset LD by mapping core IDs (strip trailing build suffix) to exact LD IDs @@ -102,7 +108,14 @@ rss_basic_qc <- function(sumstats, LD_data, skip_region = NULL, keep_indel = TRU LD_mat_processed <- LD_matrix[sumstats_processed$variant_id, sumstats_processed$variant_id, drop = FALSE] - return(list(sumstats = sumstats_processed, LD_mat = LD_mat_processed)) + QCResult( + ld_data = .qc_ld_data_from_matrix(LD_mat_processed, sumstats_processed$variant_id, + has_genotype = FALSE, n_ref = n_ref), + rss_input = list(sumstats = sumstats_processed, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ) } @@ -181,16 +194,13 @@ ld_mismatch_qc <- function(zScore, R = NULL, X = NULL, nSample = NULL, #' ignored by the historical LD-mismatch-only call unless \code{rss_input} or #' combined-QC options are supplied. #' -#' @return A list containing the quality-controlled summary statistics and -#' updated LD matrix for the historical call: -#' \itemize{ -#' \item sumstats: The quality-controlled summary statistics data frame. -#' \item LD_mat: The updated LD matrix after quality control. -#' \item outlier_number: The number of outlier variants removed. -#' } -#' When \code{rss_input} or combined-QC controls are supplied, returns a -#' cleaned RSS/LD record for one RSS record, or a named list of records for a -#' list of RSS records. +#' @return A \code{QCResult} S4 object for the historical LD-mismatch-only +#' call (use \code{getRSSInput()$sumstats}, \code{getCorrelation(getLDData())}, +#' and \code{getOutlierNumber()} to recover the harmonized sumstats, post-QC +#' LD matrix, and outlier count). When \code{rss_input} or combined-QC +#' controls are supplied, returns either a single \code{QCResult} (for one +#' RSS record) or a named list of \code{QCResult} objects (for a list of +#' RSS records). #' #' @details This function applies the specified quality control method to the #' processed summary statistics via \code{\link{ld_mismatch_qc}}, then subsets @@ -265,7 +275,16 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, sumstats_qc <- sumstats[keep_index, , drop = FALSE] LD_mat_qc <- LD_extract[sumstats_qc$variant_id, sumstats_qc$variant_id, drop = FALSE] outlier_number <- nrow(sumstats) - nrow(sumstats_qc) - return(list(sumstats = sumstats_qc, LD_mat = LD_mat_qc, outlier_number = outlier_number)) + n_ref <- if (is(LD_data, "LDData")) LD_data@n_ref else 0L + return(QCResult( + ld_data = .qc_ld_data_from_matrix(LD_mat_qc, sumstats_qc$variant_id, + has_genotype = FALSE, n_ref = n_ref), + rss_input = list(sumstats = sumstats_qc, n = n, + var_y = if (is.null(var_y)) NA_real_ else var_y), + preprocess = list(), + outlier_number = outlier_number, + skipped = FALSE + )) } qc_method <- .resolve_summary_qc_method(qc_method) @@ -383,22 +402,39 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, LD_data } +# Wrap a QC-stage LD payload (either a correlation matrix subset by variant_id, +# or a genotype matrix with colnames = variant_id) into an LDData. Used so the +# QCResult always carries an LDData rather than a bare matrix. +.qc_ld_data_from_matrix <- function(mat, variants, has_genotype, n_ref = 0L) { + if (is.null(mat) || length(variants) == 0L) return(NULL) + parsed <- tryCatch(parse_variant_id(variants), error = function(e) NULL) + if (is.null(parsed)) return(NULL) + parsed$variant_id <- variants + parsed$chrom <- as.character(parsed$chrom) + variants_gr <- .ref_panel_to_granges(parsed) + bm <- .infer_single_ld_block_metadata(parsed) + if (has_genotype) { + LDData( + correlation = NULL, + genotype_handle = mat, + variants = variants_gr, + block_metadata = bm, + n_ref = as.integer(n_ref) + ) + } else { + LDData( + correlation = mat, + variants = variants_gr, + block_metadata = bm, + n_ref = as.integer(n_ref) + ) + } +} + .summary_stats_qc_single_study <- function(rss_input, LD_data, keep_indel, skip_region, pip_cutoff_to_skip, qc_method, impute, impute_opts, study, return_on_skip, R_finite = NULL, R_mismatch = NULL) { - skipped_result <- function(sumstats, LD_mat, reason) { - if (!identical(return_on_skip, "preprocess")) return(NULL) - list( - rss_input = list(sumstats = sumstats, n = rss_input$n, var_y = rss_input$var_y), - LD_matrix = LD_mat, - preprocess = list(sumstats = sumstats, LD_mat = LD_mat), - outlier_number = 0L, - skipped = TRUE, - skip_reason = reason - ) - } - if (is.null(rss_input) || is.null(LD_data)) return(NULL) message("QC track: starting basic allele harmonization for summary-stat study ", study, ".") message("QC track: basic summary-stat QC requires sumstat$variant and LD_data variant IDs ", @@ -414,11 +450,25 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, canonical_ids <- getVariantIds(LD_data_for_qc) if (length(canonical_ids) == ncol(X_ref)) colnames(X_ref) <- canonical_ids } + n_ref_for_qc <- LD_data_for_qc@n_ref + skipped_result <- function(sumstats, ld_mat, reason) { + if (!identical(return_on_skip, "preprocess")) return(NULL) + ld <- .qc_ld_data_from_matrix(ld_mat, sumstats$variant_id, has_genotype, n_ref_for_qc) + QCResult( + ld_data = ld, + rss_input = list(sumstats = sumstats, n = rss_input$n, var_y = rss_input$var_y), + preprocess = list(sumstats = sumstats, ld_data = ld), + outlier_number = 0L, + skipped = TRUE, + skip_reason = reason + ) + } basic <- rss_basic_qc(rss_input$sumstats, LD_data_for_qc, skip_region = skip_region, keep_indel = keep_indel, return_LD_mat = !has_genotype) - sumstats <- basic$sumstats - R_mat <- basic$LD_mat + sumstats <- getRSSInput(basic)$sumstats + basic_ld <- getLDData(basic) + R_mat <- if (is.null(basic_ld)) NULL else getCorrelation(basic_ld) n <- rss_input$n var_y <- rss_input$var_y reference_for_variants <- function(variants) { @@ -492,9 +542,11 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, message("QC track: running ", qc_method, " LD-mismatch QC for summary-stat study ", study, ".") qc <- summary_stats_qc(sumstats = sumstats, LD_data = ld_data_with_local_R(sumstats), n = n, method = qc_method) - sumstats <- qc$sumstats - R_mat <- qc$LD_mat - outlier_number <- qc$outlier_number + qc_rss <- getRSSInput(qc) + sumstats <- qc_rss$sumstats + qc_ld <- getLDData(qc) + R_mat <- if (is.null(qc_ld)) NULL else getCorrelation(qc_ld) + outlier_number <- getOutlierNumber(qc) message("QC track: removed ", outlier_number, " LD-mismatch outlier(s) for summary-stat study ", study, ".") } @@ -521,10 +573,17 @@ summary_stats_qc <- function(sumstats, LD_data, n = NULL, if (!is.null(imputed$LD_mat)) R_mat <- imputed$LD_mat } final_vars <- sumstats$variant_id - list( + final_ld_mat <- if (has_genotype) reference_for_variants(final_vars) else R_mat + preprocess_sumstats <- preprocess$sumstats + preprocess_ld_mat <- preprocess$LD_mat + preprocess_ld_vars <- if (has_genotype) preprocess_sumstats$variant_id else rownames(preprocess_ld_mat) + QCResult( + ld_data = .qc_ld_data_from_matrix(final_ld_mat, final_vars, has_genotype, n_ref_for_qc), rss_input = list(sumstats = sumstats, n = n, var_y = var_y), - LD_matrix = if (has_genotype) reference_for_variants(final_vars) else R_mat, - preprocess = preprocess, + preprocess = list( + sumstats = preprocess_sumstats, + ld_data = .qc_ld_data_from_matrix(preprocess_ld_mat, preprocess_ld_vars, has_genotype, n_ref_for_qc) + ), outlier_number = outlier_number, skipped = FALSE ) diff --git a/R/susie_wrapper.R b/R/susie_wrapper.R index ec5b95cb..e73aa861 100644 --- a/R/susie_wrapper.R +++ b/R/susie_wrapper.R @@ -840,8 +840,9 @@ adjust_susie_weights <- function(twas_weights_results, keep_variants, run_allele "pos", "A2", "A1" )], match_min_prop = match_min_prop) # match_ref_panel outputs canonical variant_ids (with chr prefix) - original_idx <- match(weights_matrix_qced$qc_summary$variants_id_original, twas_weights_variants) - intersected_indices <- original_idx[weights_matrix_qced$qc_summary$keep == TRUE] + qc_summary_df <- getQCSummary(weights_matrix_qced) + original_idx <- match(qc_summary_df$variants_id_original, twas_weights_variants) + intersected_indices <- original_idx[qc_summary_df$keep == TRUE] } else { # Normalize keep_variants to canonical format for matching keep_variants_normalized <- normalize_variant_id(keep_variants) @@ -865,7 +866,7 @@ adjust_susie_weights <- function(twas_weights_results, keep_variants, run_allele adjusted_xqtl_coef <- colSums(adjusted_xqtl_alpha * mu_subset) / x_column_scal_factors_subset # allele_qc now outputs canonical variant_ids (with chr prefix) -- no need to add chr remained_variants_ids <- if (run_allele_qc) { - weights_matrix_qced$target_data_qced$variant_id + getHarmonizedData(weights_matrix_qced)$variant_id } else { intersected_variants } diff --git a/R/twas.R b/R/twas.R index 6834f236..7d950f1b 100644 --- a/R/twas.R +++ b/R/twas.R @@ -90,7 +90,7 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, colnames(weights_matrix)[!colnames(weights_matrix) %in% c("chrom", "pos", "A2", "A1")], match_min_prop = 0 ) - qced_data <- weights_matrix_qced$target_data_qced + qced_data <- getHarmonizedData(weights_matrix_qced) weights_matrix_subset <- as.matrix(qced_data[, !colnames(qced_data) %in% c( "chrom", "pos", "A2", "A1", "variant_id", "variants_id_original" ), drop = FALSE]) @@ -128,11 +128,13 @@ harmonize_twas <- function(twas_weights_data, ld_meta_file_path, gwas_meta_file, names(susie_intermediate[["pip"]]) <- original_weight_variants # original variants not yet qced pip <- susie_intermediate[["pip"]] pip_qced <- match_ref_panel(cbind(parse_variant_id(names(pip)), pip), sketch_variant_ids, "pip", match_min_prop = 0) - susie_intermediate[["pip"]] <- abs(pip_qced$target_data_qced$pip) - names(susie_intermediate[["pip"]]) <- pip_qced$target_data_qced$variant_id + pip_qced_df <- getHarmonizedData(pip_qced) + susie_intermediate[["pip"]] <- abs(pip_qced_df$pip) + names(susie_intermediate[["pip"]]) <- pip_qced_df$variant_id susie_intermediate[["cs_variants"]] <- lapply(susie_intermediate[["cs_variants"]], function(x) { variant_qc <- match_ref_panel(x, sketch_variant_ids, match_min_prop = 0) - variant_qc$target_data_qced$variant_id[variant_qc$target_data_qced$variant_id %in% postqc_weight_variants] + variant_qc_df <- getHarmonizedData(variant_qc) + variant_qc_df$variant_id[variant_qc_df$variant_id %in% postqc_weight_variants] }) mol_res[["susie_weights_intermediate_qced"]][[context]] <- susie_intermediate } @@ -219,7 +221,7 @@ harmonize_gwas <- function(gwas_file, query_region, ld_variants, col_to_flip=NUL # check for overlapping variants if (!any(gwas_data_sumstats$pos %in% gsub("\\:.*$", "", sub("^.*?\\:", "", ld_variants)))) return(NULL) gwas_allele_flip <- match_ref_panel(gwas_data_sumstats, ld_variants, col_to_flip=col_to_flip, match_min_prop = match_min_prop) - gwas_data_sumstats <- gwas_allele_flip$target_data_qced # post-qc gwas data that is flipped and corrected - gwas study level + gwas_data_sumstats <- getHarmonizedData(gwas_allele_flip) # post-qc gwas data that is flipped and corrected - gwas study level gwas_data_sumstats <- gwas_data_sumstats[!is.na(gwas_data_sumstats$z) & !is.infinite(gwas_data_sumstats$z), ] return(gwas_data_sumstats) } diff --git a/R/twas_weights.R b/R/twas_weights.R index 4bf72827..38d2f925 100644 --- a/R/twas_weights.R +++ b/R/twas_weights.R @@ -1762,13 +1762,14 @@ twas_weights_sumstat_pipeline <- function( impute_opts = impute_opts, return_on_skip = "null" ) - if (is.null(qc_result) || isTRUE(qc_result$skipped)) { + if (is.null(qc_result) || isSkipped(qc_result)) { return(list(twas_weights = NULL, finemapping_result = NULL, qc_summary = list(skipped = TRUE))) } - sumstats <- qc_result$rss_input$sumstats - LD_mat <- qc_result$LD_matrix - outlier_number <- qc_result$outlier_number + sumstats <- getRSSInput(qc_result)$sumstats + qc_ld <- getLDData(qc_result) + LD_mat <- if (is.null(qc_ld)) NULL else if (hasGenotypes(qc_ld)) getGenotypes(qc_ld) else getCorrelation(qc_ld) + outlier_number <- getOutlierNumber(qc_result) } else { # No QC requested: extract LD matrix directly if (is.matrix(LD_data)) { diff --git a/R/univariate_pipeline.R b/R/univariate_pipeline.R index 7844f1d7..47eeb704 100644 --- a/R/univariate_pipeline.R +++ b/R/univariate_pipeline.R @@ -190,8 +190,9 @@ univariate_analysis_pipeline <- function( #' @param ld_path A single LD metadata TSV path, or comma-separated paths for #' mixture panels (e.g., "ld_EUR.tsv,ld_AFR.tsv"). #' @param region Region string "chr:start-end". -#' @return An LD_data list from load_LD_matrix. For single panels, returns as-is. -#' For mixture panels, LD_matrix is a list of X matrices (one per panel). +#' @return An \code{LDData} S4 object. For single panels, returns the result of +#' \code{load_LD_matrix()} unchanged. For mixture panels, \code{genotype_handle} +#' is a list of per-panel genotype handles sharing the first panel's variants. #' @export load_study_LD <- function(ld_path, region) { paths <- strsplit(ld_path, ",")[[1]] @@ -222,11 +223,10 @@ load_study_LD <- function(ld_path, region) { #' #' @param sumstat_path File path to the summary statistics. #' @param column_file_path File path to the column mapping file. -#' @param LD_data A list from load_LD_matrix containing LD_matrix, LD_variants, -#' ref_panel, block_metadata, and is_genotype flag. When is_genotype=TRUE -#' (from return_genotype=TRUE), LD_matrix contains genotype X and susie_rss -#' uses the z+X interface. Local R is computed only for QC stages that -#' require a correlation matrix. +#' @param LD_data An \code{LDData} S4 object from \code{load_LD_matrix()}. When +#' \code{hasGenotypes(LD_data)} is TRUE (from \code{return_genotype=TRUE}), +#' susie_rss uses the z+X interface via \code{getGenotypes()}. Local R is +#' computed only for QC stages that require a correlation matrix. #' @param n_sample Sample size. If 0, retrieved from the sumstat file. #' @param n_case Number of cases (for case-control studies). #' @param n_control Number of controls (for case-control studies). @@ -314,52 +314,26 @@ rss_analysis_pipeline <- function( R_finite = R_finite, R_mismatch = R_mismatch ) - if (!is.null(qc_record$rss_input)) { - preprocess_results <- qc_record$preprocess - sumstats <- qc_record$rss_input$sumstats - LD_mat <- qc_record$LD_matrix - qc_results <- list(outlier_number = qc_record$outlier_number) - } else { - # Compatibility for tests or callers that mock the historical - # LD-mismatch-only summary_stats_qc() return shape. - preprocess_results <- list(sumstats = qc_record$sumstats, LD_mat = qc_record$LD_mat) - sumstats <- qc_record$sumstats - LD_mat <- qc_record$LD_mat - qc_results <- qc_record - if (isTRUE(impute)) { - ref_panel <- getRefPanel(LD_data) - if (use_X) { - X_sub <- subset_X_data(getVariantIds(LD_data)) - if (is_X_list) { - X_scaled <- lapply(X_sub, function(Xk) { Xk <- scale(Xk); Xk[is.na(Xk)] <- 0; Xk }) - } else { - X_scaled <- scale(X_sub) - X_scaled[is.na(X_scaled)] <- 0 - } - impute_results <- raiss(ref_panel, sumstats, - genotype_matrix = X_scaled, - R2_threshold = impute_opts$R2_threshold, - minimum_ld = impute_opts$minimum_ld, - lamb = impute_opts$lamb) - } else { - LD_matrix <- partition_LD_matrix(LD_data) - impute_results <- raiss(ref_panel, sumstats, LD_matrix, - rcond = impute_opts$rcond, - R2_threshold = impute_opts$R2_threshold, - minimum_ld = impute_opts$minimum_ld, - lamb = impute_opts$lamb) - } - sumstats <- impute_results$result_filter - LD_mat <- impute_results$LD_mat - } - qc_record$skipped <- FALSE + if (!is(qc_record, "QCResult")) { + stop("summary_stats_qc must return a QCResult object.") } + rss_record <- getRSSInput(qc_record) + sumstats <- rss_record$sumstats + qc_ld <- getLDData(qc_record) + LD_mat <- if (is.null(qc_ld)) NULL else if (hasGenotypes(qc_ld)) getGenotypes(qc_ld) else getCorrelation(qc_ld) + preprocess_snapshot <- getPreprocess(qc_record) + preprocess_ld <- preprocess_snapshot$ld_data + preprocess_results <- list( + sumstats = preprocess_snapshot$sumstats, + LD_mat = if (is.null(preprocess_ld)) NULL else if (hasGenotypes(preprocess_ld)) getGenotypes(preprocess_ld) else getCorrelation(preprocess_ld) + ) + qc_results <- list(outlier_number = getOutlierNumber(qc_record)) if (nrow(sumstats) == 0) { message("No variants left after preprocessing. Returning empty results.") return(list(rss_data_analyzed = sumstats)) } - if (isTRUE(qc_record$skipped)) { + if (isSkipped(qc_record)) { return(list(rss_data_analyzed = sumstats)) } diff --git a/R/univariate_rss_diagnostics.R b/R/univariate_rss_diagnostics.R index b577ebf9..b3b06a90 100644 --- a/R/univariate_rss_diagnostics.R +++ b/R/univariate_rss_diagnostics.R @@ -1,33 +1,21 @@ -#' Extract SuSiE Results from Finemapping Data +#' Extract the trimmed SuSiE fit from a finemapping pipeline result #' -#' This function extracts the trimmed SuSiE results from a finemapping data object, -#' typically obtained from a finemapping RDS file. It's designed to work with -#' the method layer of these files, often named as 'method_RAISS_imputed', 'method', -#' or 'method_NO_QC'. This layer is right under the study layer. -#' -#' @param con_data List. The method layer data from a finemapping RDS file. -#' -#' @return The trimmed SuSiE results (`$susie_result_trimmed`) if available, -#' otherwise NULL. -#' -#' @details -#' The function checks if the input data is empty or if the `$susie_result_trimmed` -#' element is missing. It returns NULL in these cases. If `$susie_result_trimmed` -#' exists and is not empty, it returns this element. -#' -#' @note -#' This function is particularly useful when working with large datasets -#' where not all method layers may contain valid SuSiE results or method layer. +#' Returns the trimmed model fit underlying \code{con_data$finemapping_result} +#' (a \code{FineMappingResult} S4 object), or NULL if no fine-mapping result +#' is attached. #' +#' @param con_data List. The method-layer entry from a finemapping pipeline +#' result, expected to carry \code{$finemapping_result} as a +#' \code{FineMappingResult} object. +#' @return The trimmed fit (a list with \code{pip}, \code{sets}, etc.) or NULL. #' @export get_susie_result <- function(con_data) { - if (length(con_data) == 0) return(NULL) - if (length(con_data$susie_result_trimmed) == 0) { - return(NULL) - print(paste("$susie_result_trimmed is null for", con_data)) - } else { - return(con_data$susie_result_trimmed) - } + if (length(con_data) == 0) return(NULL) + fm <- con_data$finemapping_result + if (is.null(fm) || !is(fm, "FineMappingResult")) return(NULL) + trimmed <- getTrimmedFit(fm) + if (length(trimmed) == 0) return(NULL) + trimmed } #' Process Credible Sets (CS) from Finemapping Results @@ -60,26 +48,29 @@ get_susie_result <- function(con_data) { #' @importFrom dplyr bind_rows #' #' @export -extract_cs_info <- function(con_data, cs_names, top_loci_table) { +extract_cs_info <- function(con_data, cs_names, top_loci_table) { + fm <- con_data$finemapping_result + trimmed <- getTrimmedFit(fm) + variant_names <- getVariantNames(fm) results <- map(seq_along(cs_names), function(i) { cs_name <- cs_names[i] - indices <- con_data$susie_result_trimmed$sets$cs[[cs_name]] - + indices <- trimmed$sets$cs[[cs_name]] + # Get variants for this CS using the full variant_names list - cs_variants <- con_data$variant_names[indices] + cs_variants <- variant_names[indices] cs_data <- top_loci_table[top_loci_table$variant_id %in% cs_variants, ] top_row <- which.max(cs_data$pip) - + top_variant <- cs_data$variant_id[top_row] # Find the global index of the top variant - top_variant_global_index = which(con_data$variant_names == top_variant) + top_variant_global_index = which(variant_names == top_variant) top_pip <- cs_data$pip[top_row] top_z <- cs_data$z[top_row] p_value <- z_to_pvalue(top_z) - + # Extract cs_corr cs_corr <- if (length(cs_names) > 1) { - con_data$susie_result_trimmed$cs_corr[i,] + trimmed$cs_corr[i,] } else { NA # Use NA for the second CS or when there's only one CS } @@ -137,10 +128,13 @@ extract_cs_info <- function(con_data, cs_names, top_loci_table) { #' #' @export extract_top_pip_info <- function(con_data) { + fm <- con_data$finemapping_result + trimmed <- getTrimmedFit(fm) + variant_names <- getVariantNames(fm) # Find the variant with the highest PIP - top_pip_index <- which.max(con_data$susie_result_trimmed$pip) - top_pip <- con_data$susie_result_trimmed$pip[top_pip_index] - top_variant <- con_data$variant_names[top_pip_index] + top_pip_index <- which.max(trimmed$pip) + top_pip <- trimmed$pip[top_pip_index] + top_variant <- variant_names[top_pip_index] top_z <- con_data$sumstats$z[top_pip_index] p_value <- z_to_pvalue(top_z) diff --git a/man/dot-legacy_list_to_LDData.Rd b/man/dot-legacy_list_to_LDData.Rd deleted file mode 100644 index 04278c61..00000000 --- a/man/dot-legacy_list_to_LDData.Rd +++ /dev/null @@ -1,18 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/file_utils.R -\name{.legacy_list_to_LDData} -\alias{.legacy_list_to_LDData} -\title{Convert loaded regional data to RSS inputs} -\usage{ -.legacy_list_to_LDData(ld_list) -} -\arguments{ -\item{region_data}{A list returned by \code{load_multitask_regional_data()}.} -} -\value{ -A list containing named RSS inputs, matched LD data, and source - information. -} -\description{ -Convert loaded regional data to RSS inputs -} diff --git a/tests/testthat/test_allele_qc.R b/tests/testthat/test_allele_qc.R index 69c6b60a..d519b0f9 100644 --- a/tests/testthat/test_allele_qc.R +++ b/tests/testthat/test_allele_qc.R @@ -116,7 +116,7 @@ test_that("Check that we correctly remove stand ambiguous SNPs",{ output <- allele_qc( res$target_data, res$ref_variants, "beta", match_min_prop = 0.2, TRUE, FALSE, TRUE) - expect_equal(nrow(output$target_data_qced), 80) + expect_equal(nrow(getHarmonizedData(output)), 80) }) test_that("Check that we correctly remove non-ACTG coding SNPs",{ @@ -124,7 +124,7 @@ test_that("Check that we correctly remove non-ACTG coding SNPs",{ output <- allele_qc( res$target_data, res$ref_variants, "beta", match_min_prop = 0.2, TRUE, FALSE, TRUE) - expect_equal(nrow(output$target_data_qced), 40) + expect_equal(nrow(getHarmonizedData(output)), 40) }) test_that("Check that execution stops if not enough variants are matched",{ @@ -144,7 +144,7 @@ test_that("allele_qc matches exact alleles", { A2 = c("A", "C"), A1 = c("G", "T") ) result <- allele_qc(target, ref, match_min_prop = 0) - expect_equal(nrow(result$target_data_qced), 2) + expect_equal(nrow(getHarmonizedData(result)), 2) }) test_that("allele_qc detects sign flips", { @@ -158,23 +158,23 @@ test_that("allele_qc detects sign flips", { A2 = "G", A1 = "A" ) result <- allele_qc(target, ref, col_to_flip = "z", match_min_prop = 0) - expect_equal(nrow(result$target_data_qced), 1) + expect_equal(nrow(getHarmonizedData(result)), 1) # z should be flipped - expect_equal(result$target_data_qced$z, -2.5) + expect_equal(getHarmonizedData(result)$z, -2.5) }) test_that("allele_qc handles string input format", { target <- c("1:100:A:G", "1:200:C:T") ref <- c("1:100:A:G", "1:200:C:T") result <- allele_qc(target, ref, match_min_prop = 0) - expect_equal(nrow(result$target_data_qced), 2) + expect_equal(nrow(getHarmonizedData(result)), 2) }) test_that("allele_qc with chr prefix", { target <- c("chr1:100:A:G", "chr1:200:C:T") ref <- c("chr1:100:A:G", "chr1:200:C:T") result <- allele_qc(target, ref, match_min_prop = 0) - expect_equal(nrow(result$target_data_qced), 2) + expect_equal(nrow(getHarmonizedData(result)), 2) }) test_that("allele_qc warns when too few matches", { @@ -196,7 +196,7 @@ test_that("allele_qc with no matching positions returns empty", { result <- allele_qc(target, ref, match_min_prop = 0), "No matching variants" ) - expect_equal(nrow(result$target_data_qced), 0) + expect_equal(nrow(getHarmonizedData(result)), 0) }) test_that("allele_qc preserves extra columns", { @@ -210,8 +210,8 @@ test_that("allele_qc preserves extra columns", { A2 = "A", A1 = "G" ) result <- allele_qc(target, ref, match_min_prop = 0) - expect_true("beta" %in% colnames(result$target_data_qced)) - expect_true("se" %in% colnames(result$target_data_qced)) + expect_true("beta" %in% colnames(getHarmonizedData(result))) + expect_true("se" %in% colnames(getHarmonizedData(result))) }) test_that("allele_qc with lowercase alleles", { @@ -224,7 +224,7 @@ test_that("allele_qc with lowercase alleles", { A2 = "A", A1 = "G" ) result <- allele_qc(target, ref, match_min_prop = 0) - expect_equal(nrow(result$target_data_qced), 1) + expect_equal(nrow(getHarmonizedData(result)), 1) }) test_that("align_variant_names correctly aligns variant names", { @@ -369,7 +369,7 @@ test_that("allele_qc handles data frame with NULL colnames after merge", { # Restore chrom for the join colnames(target)[1] <- "chrom" result <- allele_qc(target, ref, match_min_prop = 0) - expect_equal(nrow(result$target_data_qced), 1) + expect_equal(nrow(getHarmonizedData(result)), 1) }) # ---- target_data with redundant columns (allele_qc.R line 75) ---- @@ -380,9 +380,9 @@ test_that("allele_qc removes redundant columns from target_data before join", { ) ref <- data.frame(chrom = 1, pos = 100, A2 = "A", A1 = "G") result <- allele_qc(target, ref, match_min_prop = 0) - expect_equal(nrow(result$target_data_qced), 1) + expect_equal(nrow(getHarmonizedData(result)), 1) # The redundant columns should have been removed before the join - expect_true("variant_id" %in% colnames(result$target_data_qced)) + expect_true("variant_id" %in% colnames(getHarmonizedData(result))) }) # ---- col_to_flip with nonexistent column (allele_qc.R line 130) ---- @@ -408,7 +408,7 @@ test_that("allele_qc warns and removes duplicate variants", { result <- allele_qc(target, ref, match_min_prop = 0, remove_dups = TRUE), "duplicate variant" ) - expect_equal(nrow(result$target_data_qced), 1) + expect_equal(nrow(getHarmonizedData(result)), 1) }) # ---- duplicated variant IDs error (allele_qc.R line 180) ---- diff --git a/tests/testthat/test_colocboost_pipeline.R b/tests/testthat/test_colocboost_pipeline.R index 4eea18ee..2ae3dbec 100644 --- a/tests/testthat/test_colocboost_pipeline.R +++ b/tests/testthat/test_colocboost_pipeline.R @@ -4,6 +4,82 @@ context("colocboost_pipeline") # Tests from test_colocboost_pipeline.R # =========================================================================== +# Wrap a correlation (or genotype) matrix into an LDData for use in test mocks +# that previously used bare matrices for LD_mat/LD_info fields. +.test_lddata_from_matrix <- function(mat, is_genotype = FALSE) { + vids <- if (is_genotype) colnames(mat) else rownames(mat) + if (is.null(vids)) vids <- colnames(mat) + ref_panel <- cbind(parse_variant_id(vids), variant_id = vids) + ref_panel$chrom <- as.character(ref_panel$chrom) + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + bm <- pecotmr:::.infer_single_ld_block_metadata(ref_panel) + if (is_genotype) { + LDData(correlation = NULL, genotype_handle = mat, + variants = variants_gr, block_metadata = bm, + n_ref = as.integer(nrow(mat))) + } else { + LDData(correlation = mat, variants = variants_gr, block_metadata = bm) + } +} + +# Wrap one (rss_input, LD_matrix) pair as a QCResult for mocks that previously +# returned the legacy list shape. +.test_qcresult_from_list <- function(rss_input, LD_mat) { + QCResult( + ld_data = .test_lddata_from_matrix(LD_mat), + rss_input = rss_input, + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ) +} + +# Build a RegionalData S4 object from legacy-style per-context lists of +# residual_Y / residual_X / maf matrices, so existing tests can keep their +# inline list builders. +.test_regionaldata_from_lists <- function(residual_Y, residual_X, maf = NULL) { + contexts <- names(residual_Y) + if (is.null(contexts)) contexts <- paste0("ctx", seq_along(residual_Y)) + # Pick the first X as the canonical genotype matrix; tests share X across + # contexts in nearly all cases. Ensure sample rownames are present. + X0 <- residual_X[[1]] + if (is.null(rownames(X0))) { + rownames(X0) <- paste0("sample", seq_len(nrow(X0))) + } + sample_ids <- rownames(X0) + # Align each phenotype matrix to the same sample IDs (assume same n). + phenotypes <- stats::setNames(lapply(seq_along(contexts), function(i) { + y <- residual_Y[[i]] + if (!is.matrix(y)) y <- as.matrix(y) + if (is.null(rownames(y))) rownames(y) <- sample_ids[seq_len(nrow(y))] + y + }), contexts) + covariates <- stats::setNames(lapply(contexts, function(c) { + matrix(numeric(0), nrow = nrow(X0), ncol = 0, + dimnames = list(sample_ids, NULL)) + }), contexts) + if (is.null(maf)) { + maf_list <- stats::setNames(lapply(contexts, function(c) { + rep(0.1, ncol(X0)) + }), contexts) + } else { + maf_list <- if (is.null(names(maf))) stats::setNames(maf, contexts) else maf + if (length(maf_list) < length(contexts)) { + maf_list <- stats::setNames(rep(list(maf_list[[1]]), length(contexts)), contexts) + } + } + RegionalData( + genotype_matrix = X0, + phenotypes = phenotypes, + covariates = covariates, + scale_residuals = FALSE, + maf = maf_list, + region = NULL, + dropped_samples = list(X = list(), Y = list(), covar = list()), + Y_coordinates = NULL + ) +} + # ---- qc_method match.arg ---- test_that("qc_regional_data accepts explicit qc_method = 'slalom'", { @@ -33,7 +109,7 @@ test_that("pip_cutoff scalar is recycled for individual contexts", { list(X = X, Y = Y, maf = runif(p, 0.05, 0.5)) } ctx <- make_ctx() - individual_data <- list( + individual_data <- .test_regionaldata_from_lists( residual_Y = list(ctx1 = ctx$Y, ctx2 = ctx$Y, ctx3 = ctx$Y), residual_X = list(ctx1 = ctx$X, ctx2 = ctx$X, ctx3 = ctx$X), maf = list(ctx1 = ctx$maf, ctx2 = ctx$maf, ctx3 = ctx$maf) @@ -49,8 +125,10 @@ test_that("pip_cutoff wrong length errors for individual contexts", { set.seed(42) n <- 10; p <- 5 X <- matrix(rnorm(n * p), n, p) + colnames(X) <- paste0("var", 1:p) Y <- matrix(rnorm(n), n, 1) - individual_data <- list( + colnames(Y) <- "gene1" + individual_data <- .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y, ctx2 = Y, ctx3 = Y), residual_X = list(ctx1 = X, ctx2 = X, ctx3 = X), maf = list(ctx1 = runif(p), ctx2 = runif(p), ctx3 = runif(p)) @@ -70,7 +148,7 @@ test_that("pip_cutoff correct length vector works", { colnames(X) <- paste0("var", 1:p) Y <- matrix(rnorm(n), n, 1) colnames(Y) <- "gene1" - individual_data <- list( + individual_data <- .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y, ctx2 = Y), residual_X = list(ctx1 = X, ctx2 = X), maf = list(ctx1 = runif(p, 0.05, 0.5), ctx2 = runif(p, 0.05, 0.5)) @@ -92,22 +170,34 @@ test_that("pip_cutoff correct length vector works", { # =========================================================================== make_individual_region_data <- function(n = 20, p = 8, n_contexts = 2, n_events = 3) { set.seed(101) - make_ctx <- function(ctx_name) { - X <- matrix(rnorm(n * p), n, p) - colnames(X) <- paste0("chr1:", seq_len(p) * 100, ":A:G") - Y <- matrix(rnorm(n * n_events), n, n_events) - colnames(Y) <- paste0("event", seq_len(n_events)) - maf <- runif(p, 0.05, 0.45) - list(X = X, Y = Y, maf = maf) - } - ctxs <- lapply(paste0("ctx", seq_len(n_contexts)), make_ctx) - names(ctxs) <- paste0("ctx", seq_len(n_contexts)) + sample_ids <- paste0("sample", seq_len(n)) + var_ids <- paste0("chr1:", seq_len(p) * 100, ":A:G") + X <- matrix(rnorm(n * p), n, p, dimnames = list(sample_ids, var_ids)) + context_names <- paste0("ctx", seq_len(n_contexts)) + phenotypes <- stats::setNames(lapply(seq_len(n_contexts), function(i) { + Y <- matrix(rnorm(n * n_events), n, n_events, + dimnames = list(sample_ids, paste0("event", seq_len(n_events)))) + Y + }), context_names) + # Per-context covariates: empty intercept-only model (n x 0 with rownames) + covariates <- stats::setNames(lapply(seq_len(n_contexts), function(i) { + matrix(numeric(0), nrow = n, ncol = 0, dimnames = list(sample_ids, NULL)) + }), context_names) + maf_list <- stats::setNames(lapply(seq_len(n_contexts), function(i) { + runif(p, 0.05, 0.45) + }), context_names) + rd <- RegionalData( + genotype_matrix = X, + phenotypes = phenotypes, + covariates = covariates, + scale_residuals = FALSE, + maf = maf_list, + region = NULL, + dropped_samples = list(X = list(), Y = list(), covar = list()), + Y_coordinates = NULL + ) list( - individual_data = list( - residual_Y = lapply(ctxs, `[[`, "Y"), - residual_X = lapply(ctxs, `[[`, "X"), - maf = lapply(ctxs, `[[`, "maf") - ), + individual_data = rd, sumstat_data = NULL ) } @@ -123,10 +213,11 @@ make_sumstat_region_data <- function(n_variants = 5, n_studies = 2) { rownames(LD_mat) <- colnames(LD_mat) <- vids ref_panel <- data.frame( - chrom = rep(1, n_variants), + chrom = as.character(rep(1, n_variants)), pos = seq_len(n_variants) * 100, A2 = rep("A", n_variants), A1 = rep("G", n_variants), + variant_id = vids, stringsAsFactors = FALSE ) @@ -149,17 +240,18 @@ make_sumstat_region_data <- function(n_variants = 5, n_studies = 2) { list(ss) |> setNames(paste0("study", i)) }) - LD_info <- list(list( - LD_variants = ref_panel, - LD_matrix = LD_mat, - ref_panel = ref_panel - )) + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + ld_data <- LDData( + correlation = LD_mat, + variants = variants_gr, + block_metadata = pecotmr:::.infer_single_ld_block_metadata(ref_panel) + ) list( individual_data = NULL, sumstat_data = list( sumstats = sumstats_list, - LD_info = LD_info + LD_info = list(ld_data) ) ) } @@ -170,12 +262,27 @@ test_that("qc_regional_data treats NULL qc_method as basic-only none", { LD_mat <- diag(1) rownames(LD_mat) <- colnames(LD_mat) <- "chr1:100:A:G" + ref_panel_one <- data.frame( + chrom = "1", pos = 100L, A2 = "A", A1 = "G", + variant_id = "chr1:100:A:G", stringsAsFactors = FALSE + ) + variants_gr_one <- pecotmr:::.ref_panel_to_granges(ref_panel_one) + ld_data_one <- LDData( + correlation = LD_mat, + variants = variants_gr_one, + block_metadata = pecotmr:::.infer_single_ld_block_metadata(ref_panel_one) + ) + local_mocked_bindings( summary_stats_qc = function(..., qc_method) { captured_qc_method <<- qc_method - list(study1 = list( - rss_input = list(sumstats = data.frame(variant_id = "chr1:100:A:G")), - LD_matrix = LD_mat + list(study1 = QCResult( + ld_data = ld_data_one, + rss_input = list(sumstats = data.frame(variant_id = "chr1:100:A:G"), + n = 1000, var_y = 1), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE )) } ) @@ -192,10 +299,14 @@ test_that("colocboost_pipeline default qc_method resolves to basic-only none", { local_mocked_bindings( qc_regional_data = function(region_data, ..., qc_method) { captured_qc_method <<- qc_method + ind <- region_data$individual_data list( individual_data = list( - Y = region_data$individual_data$residual_Y, - X = region_data$individual_data$residual_X + Y = ind@phenotypes, + X = stats::setNames( + lapply(seq_along(ind@phenotypes), function(i) ind@genotype_matrix), + names(ind@phenotypes) + ) ), sumstat_data = NULL ) @@ -304,14 +415,8 @@ test_that("ColocBoost adapters accept genotype-backed LDData", { expect_equal(result$args$M, 2) }) -test_that("RegionalData individual adapter restores context names from residual_Y", { +test_that("RegionalData individual adapter exposes context names from phenotypes", { ind_region <- make_individual_region_data(n = 12, p = 5, n_contexts = 2, n_events = 1) - names(ind_region$individual_data$residual_X) <- NULL - names(ind_region$individual_data$maf) <- NULL - ind_region$individual_data$X_variance <- lapply( - ind_region$individual_data$residual_X, - function(x) matrixStats::colVars(x) - ) ind_input <- region_data_to_ind_input(ind_region) expect_equal(names(ind_input$X), c("ctx1", "ctx2")) @@ -319,7 +424,8 @@ test_that("RegionalData individual adapter restores context names from residual_ expect_equal(names(ind_input$X_variance), c("ctx1", "ctx2")) converted <- region_data_to_colocboost_input(ind_region) - expect_equal(length(converted$colocboost_input$X), 2) + # X is shared across contexts in RegionalData; deduplication yields one X. + expect_equal(length(converted$colocboost_input$X), 1) expect_equal(nrow(converted$colocboost_input$dict_YX), 2) }) @@ -371,21 +477,29 @@ test_that("region_data_to_colocboost_input returns core and QC inputs", { expect_equal(nrow(converted$colocboost_input$dict_YX), 2) }) -test_that("region_data_to_colocboost_input converts genotype X_ref to LD correlation", { +test_that("region_data_to_colocboost_input routes genotype LDData through X_ref", { region_data <- make_sumstat_region_data(n_variants = 5, n_studies = 1) variants <- region_data$sumstat_data$sumstats[[1]][[1]]$sumstats$variant_id X_ref <- matrix(rnorm(50), 10, 5) colnames(X_ref) <- variants - region_data$sumstat_data$LD_info[[1]]$LD_matrix <- X_ref - region_data$sumstat_data$LD_info[[1]]$is_genotype <- TRUE + + ref_panel <- cbind(parse_variant_id(variants), variant_id = variants) + ref_panel$chrom <- as.character(ref_panel$chrom) + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + region_data$sumstat_data$LD_info[[1]] <- LDData( + correlation = NULL, + genotype_handle = X_ref, + variants = variants_gr, + block_metadata = pecotmr:::.infer_single_ld_block_metadata(ref_panel), + n_ref = nrow(X_ref) + ) converted <- region_data_to_colocboost_input(region_data) - # With S4 migration, genotype data is converted to correlation in LD - expect_null(converted$colocboost_input$X_ref) - expect_equal(length(converted$colocboost_input$LD), 1) - expect_equal(dim(converted$colocboost_input$LD[[1]]), c(5, 5)) - expect_equal(colnames(converted$colocboost_input$LD[[1]]), variants) + expect_null(converted$colocboost_input$LD) + expect_equal(length(converted$colocboost_input$X_ref), 1) + expect_equal(dim(converted$colocboost_input$X_ref[[1]]), c(10, 5)) + expect_equal(colnames(converted$colocboost_input$X_ref[[1]]), variants) }) test_that("region_data_to_colocboost_input preserves duplicated outcome names across contexts", { @@ -395,12 +509,14 @@ test_that("region_data_to_colocboost_input preserves duplicated outcome names ac expect_equal(names(converted$colocboost_input$Y), c("ctx1_event1", "ctx2_event1")) expect_equal(length(converted$colocboost_input$Y), 2) expect_equal(nrow(converted$colocboost_input$dict_YX), 2) - expect_equal(converted$colocboost_input$dict_YX[, "X"], c(1, 2)) + # X is shared across contexts in RegionalData; dict_YX maps both Y to X #1. + expect_equal(converted$colocboost_input$dict_YX[, "X"], c(1, 1)) }) test_that("region_data_to_colocboost_input deduplicates shared individual X", { + # RegionalData shares one genotype matrix across all conditions, so the + # per-context residualized X is identical when covariates are the same. region_data <- make_individual_region_data(n = 12, p = 5, n_contexts = 2, n_events = 1) - region_data$individual_data$residual_X$ctx2 <- region_data$individual_data$residual_X$ctx1 converted <- region_data_to_colocboost_input(region_data) expect_equal(length(converted$colocboost_input$X), 1) @@ -429,9 +545,10 @@ test_that("region_data_to_colocboost_input combines individual and RSS inputs", test_that("qc_individual_data uses existing genotype filtering helpers", { region_data <- make_individual_region_data(n = 12, p = 5, n_contexts = 1, n_events = 2) - X <- region_data$individual_data$residual_X - Y <- region_data$individual_data$residual_Y - maf <- region_data$individual_data$maf + ind_input <- region_data_to_ind_input(region_data) + X <- ind_input$X + Y <- ind_input$Y + maf <- ind_input$maf expect_message( result <- qc_individual_data(X, Y, maf = maf, maf_cutoff = 0), "QC track" @@ -442,8 +559,9 @@ test_that("qc_individual_data uses existing genotype filtering helpers", { test_that("qc_individual_data supports direct matrix inputs and keeps context labels", { region_data <- make_individual_region_data(n = 12, p = 5, n_contexts = 1, n_events = 2) - X <- region_data$individual_data$residual_X$ctx1 - Y <- region_data$individual_data$residual_Y$ctx1 + ind_input <- region_data_to_ind_input(region_data) + X <- ind_input$X$ctx1 + Y <- ind_input$Y$ctx1 maf <- stats::setNames(rep(0.2, ncol(X)), colnames(X)) maf[1] <- 0.001 @@ -469,7 +587,8 @@ test_that("summary_stats_qc runs combined basic harmonization when qc_method is "basic allele harmonization" ) expect_equal(names(result), "study1") - expect_true(nrow(result$study1$rss_input$sumstats) > 0) + expect_true(is(result$study1, "QCResult")) + expect_true(nrow(getRSSInput(result$study1)$sumstats) > 0) }) test_that("summary_stats_qc returns one cleaned record for one RSS record", { @@ -485,9 +604,9 @@ test_that("summary_stats_qc returns one cleaned record for one RSS record", { ), "basic allele harmonization" ) - expect_true("rss_input" %in% names(result)) - expect_true("LD_matrix" %in% names(result)) - expect_true(nrow(result$rss_input$sumstats) > 0) + expect_true(is(result, "QCResult")) + expect_false(is.null(getLDData(result))) + expect_true(nrow(getRSSInput(result)$sumstats) > 0) }) test_that("summary_stats_qc treats a study named sumstats as multiple-study input", { @@ -506,8 +625,8 @@ test_that("summary_stats_qc treats a study named sumstats as multiple-study inpu "basic allele harmonization" ) expect_equal(names(result), c("sumstats", "study2")) - expect_true("rss_input" %in% names(result$sumstats)) - expect_true("rss_input" %in% names(result$study2)) + expect_true(is(result$sumstats, "QCResult")) + expect_true(is(result$study2, "QCResult")) }) test_that("summary_stats_qc imputes when block metadata can be inferred from LD matrix", { @@ -525,7 +644,8 @@ test_that("summary_stats_qc imputes when block metadata can be inferred from LD "running imputation" ) expect_equal(names(result), "study1") - expect_true(nrow(result$study1$rss_input$sumstats) > 0) + expect_true(is(result$study1, "QCResult")) + expect_true(nrow(getRSSInput(result$study1)$sumstats) > 0) }) test_that("colocboost_analysis directly forwards core inputs without QC", { @@ -1068,7 +1188,7 @@ test_that("colocboost_pipeline is the protocol entry", { Y <- matrix(rnorm(10), 10, 1) colnames(Y) <- "gene1" region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), maf = list(ctx1 = runif(5, 0.05, 0.45)) @@ -1112,24 +1232,25 @@ test_that("colocboost_pipeline preserves result fields when analyses return NULL ) LD <- diag(5) rownames(LD) <- colnames(LD) <- colnames(X) + ld_ref_panel <- cbind(parse_variant_id(colnames(X)), variant_id = colnames(X)) + ld_ref_panel$chrom <- as.character(ld_ref_panel$chrom) + ld_data_obj <- LDData( + correlation = LD, + variants = pecotmr:::.ref_panel_to_granges(ld_ref_panel), + block_metadata = data.frame( + block_id = 1L, chrom = "1", block_start = 100, block_end = 500, + size = 5L, start_idx = 1L, end_idx = 5L + ) + ) region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), maf = list(ctx1 = runif(5, 0.05, 0.45)) ), sumstat_data = list( sumstats = list(chr21_ref = list(study1 = list(sumstats = sumstat, n = 1000, var_y = 1))), - LD_info = list(chr21_ref = list( - LD_variants = colnames(X), - LD_matrix = LD, - ref_panel = cbind(parse_variant_id(colnames(X)), variant_id = colnames(X)), - block_metadata = data.frame( - block_id = 1L, chrom = "1", block_start = 100, block_end = 500, - size = 5L, start_idx = 1L, end_idx = 5L - ), - is_genotype = FALSE - )) + LD_info = list(chr21_ref = ld_data_obj) ) ) @@ -1206,10 +1327,10 @@ test_that("filter_events keeps events matching valid_pattern", { colnames(Y) <- events region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(tissue1 = Y), residual_X = list(tissue1 = X), - maf = list(tissue1 = runif(p, 0.05, 0.45)) + maf = list(tissue1 = runif(p, 0.05, 0.45)) ), sumstat_data = NULL ) @@ -1229,8 +1350,8 @@ test_that("filter_events keeps events matching valid_pattern", { # Return the data as-is; transform residual_Y to Y format list( individual_data = list( - Y = region_data$individual_data$residual_Y, - X = region_data$individual_data$residual_X + Y = region_data$individual_data@phenotypes, + X = stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes)) ), sumstat_data = NULL ) @@ -1263,10 +1384,10 @@ test_that("filter_events errors on missing type_pattern", { colnames(Y) <- c("evt1", "evt2") region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), - maf = list(ctx1 = runif(p, 0.05, 0.45)) + maf = list(ctx1 = runif(p, 0.05, 0.45)) ), sumstat_data = NULL ) @@ -1297,10 +1418,10 @@ test_that("filter_events errors when only type_pattern is given (no valid or exc colnames(Y) <- c("evt1", "evt2") region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), - maf = list(ctx1 = runif(p, 0.05, 0.45)) + maf = list(ctx1 = runif(p, 0.05, 0.45)) ), sumstat_data = NULL ) @@ -1326,8 +1447,9 @@ test_that("extract_contexts_studies returns individual contexts and sumstat stud # We access the internal by constructing minimal region_data and triggering # the pipeline but with both analysis=FALSE so it exits early after extraction. region_data <- list( - individual_data = list( - residual_Y = list(tissue_A = matrix(1, 2, 2), tissue_B = matrix(1, 2, 2)) + individual_data = .test_regionaldata_from_lists( + residual_Y = list(tissue_A = matrix(1, 2, 2), tissue_B = matrix(1, 2, 2)), + residual_X = list(tissue_A = matrix(1, 2, 2), tissue_B = matrix(1, 2, 2)) ), sumstat_data = list( sumstats = list( @@ -1362,10 +1484,10 @@ test_that("extract_contexts_studies reports after-QC when some individual data r colnames(Y) <- "gene1" region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y, ctx2 = Y), residual_X = list(ctx1 = X, ctx2 = X), - maf = list(ctx1 = runif(p, 0.05, 0.45), ctx2 = runif(p, 0.05, 0.45)) + maf = list(ctx1 = runif(p, 0.05, 0.45), ctx2 = runif(p, 0.05, 0.45)) ), sumstat_data = NULL ) @@ -1413,7 +1535,7 @@ test_that("qc_regional_data handles named pip_cutoff_to_skip_sumstat vector", { # Mock out the heavy QC functions local_mocked_bindings( allele_qc = function(target_data, ref_variants, ...) { - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) }, rss_basic_qc = function(sumstats, LD_data, ...) { ld_corr <- if (is(LD_data, "LDData")) getCorrelation(LD_data) else LD_data$LD_matrix @@ -1425,7 +1547,7 @@ test_that("qc_regional_data handles named pip_cutoff_to_skip_sumstat vector", { ss <- rss_input[[study]]$sumstats ld <- if (is(LD_data[[study]], "LDData")) getCorrelation(LD_data[[study]]) else LD_data[[study]]$LD_matrix LD_mat <- ld[ss$variant_id, ss$variant_id, drop = FALSE] - list(rss_input = rss_input[[study]], LD_matrix = LD_mat, outlier_number = 0) + .test_qcresult_from_list(rss_input[[study]], LD_mat) }), names(rss_input)) }, raiss = function(...) { @@ -1464,7 +1586,7 @@ test_that("qc_regional_data fills missing study names with 0 for pip_cutoff_to_s ss <- rss_input[[study]]$sumstats ld <- if (is(LD_data[[study]], "LDData")) getCorrelation(LD_data[[study]]) else LD_data[[study]]$LD_matrix LD_mat <- ld[ss$variant_id, ss$variant_id, drop = FALSE] - list(rss_input = rss_input[[study]], LD_matrix = LD_mat, outlier_number = 0) + .test_qcresult_from_list(rss_input[[study]], LD_mat) }), names(rss_input)) }, raiss = function(...) list(result_filter = data.frame(z = rnorm(5)), LD_mat = diag(5)), @@ -1522,12 +1644,12 @@ test_that("pipeline with individual data enters xqtl_coloc path and records timi # by checking that computing_time$Analysis$xqtl_coloc is recorded. local_mocked_bindings( qc_regional_data = function(region_data, ...) { - Y1 <- region_data$individual_data$residual_Y[[1]] - colnames(Y1) <- paste0(names(region_data$individual_data$residual_Y)[1], "_", colnames(Y1)) + Y1 <- region_data$individual_data@phenotypes[[1]] + colnames(Y1) <- paste0(names(region_data$individual_data@phenotypes)[1], "_", colnames(Y1)) list( individual_data = list( Y = list(ctx1 = Y1), - X = list(ctx1 = region_data$individual_data$residual_X[[1]]) + X = list(ctx1 = stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes))[[1]]) ), sumstat_data = NULL ) @@ -1561,10 +1683,10 @@ test_that("filter_events exclude_pattern removes matching events via pipeline", colnames(Y) <- c("good_event", "bad_event") region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), - maf = list(ctx1 = runif(p, 0.05, 0.45)) + maf = list(ctx1 = runif(p, 0.05, 0.45)) ), sumstat_data = NULL ) @@ -1583,11 +1705,11 @@ test_that("filter_events exclude_pattern removes matching events via pipeline", local_mocked_bindings( qc_regional_data = function(region_data, ...) { # The residual_Y should have had bad_event removed by filter_events - remaining_events <- colnames(region_data$individual_data$residual_Y$ctx1) + remaining_events <- colnames(region_data$individual_data@phenotypes$ctx1) list( individual_data = list( - Y = list(ctx1 = region_data$individual_data$residual_Y$ctx1), - X = list(ctx1 = region_data$individual_data$residual_X$ctx1) + Y = list(ctx1 = region_data$individual_data@phenotypes$ctx1), + X = list(ctx1 = stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes))$ctx1) ), sumstat_data = NULL ) @@ -1622,11 +1744,11 @@ test_that("pipeline catches colocboost xqtl error and returns NULL result", { # local_mocked_bindings. local_mocked_bindings( qc_regional_data = function(region_data, ...) { - Y1 <- region_data$individual_data$residual_Y[[1]] + Y1 <- region_data$individual_data@phenotypes[[1]] colnames(Y1) <- paste0("ctx1_", colnames(Y1)) # Return X with mismatched rows to guarantee colocboost errors bad_X <- matrix(rnorm(5 * 8), nrow = 5, ncol = 8) - colnames(bad_X) <- colnames(region_data$individual_data$residual_X[[1]]) + colnames(bad_X) <- colnames(stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes))[[1]]) list( individual_data = list( Y = list(ctx1 = Y1), @@ -1692,7 +1814,7 @@ make_qced_sumstat_data <- function(studies = c("study1"), n_variants = 5) { names(sumstats) <- studies list( sumstats = sumstats, - LD_mat = stats::setNames(list(LD_mat), studies[1]), + LD_data = stats::setNames(list(.test_lddata_from_matrix(LD_mat)), studies[1]), LD_match = stats::setNames(rep(studies[1], length(studies)), studies) ) } @@ -1746,12 +1868,12 @@ test_that("pipeline skips joint_gwas when QC removes sumstats but keeps individu local_mocked_bindings( qc_regional_data = function(region_data, ...) { - Y1 <- region_data$individual_data$residual_Y[[1]] + Y1 <- region_data$individual_data@phenotypes[[1]] colnames(Y1) <- paste0("ctx1_", colnames(Y1)) list( individual_data = list( Y = list(ctx1 = Y1), - X = list(ctx1 = region_data$individual_data$residual_X[[1]]) + X = list(ctx1 = stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes))[[1]]) ), sumstat_data = list(sumstats = NULL) ) @@ -1823,7 +1945,7 @@ test_that("pipeline event filters can remove all events in one context while kee Y_keep <- matrix(rnorm(n * 2), n, 2) colnames(Y_keep) <- c("keep_event", "bad_event") region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx_drop = Y_drop, ctx_keep = Y_keep), residual_X = list(ctx_drop = X, ctx_keep = X), maf = list(ctx_drop = runif(p, 0.05, 0.45), ctx_keep = runif(p, 0.05, 0.45)) @@ -1838,11 +1960,23 @@ test_that("pipeline event filters can remove all events in one context while kee local_mocked_bindings( qc_regional_data = function(region_data, ...) { + Y_list <- region_data$individual_data@phenotypes + X_list <- stats::setNames( + replicate(length(Y_list), region_data$individual_data@genotype_matrix, + simplify = FALSE), + names(Y_list) + ) + # event_filters dropped contexts: re-attach them as NULL entries so the + # pipeline emits the legacy "Skipping follow-up analysis" message. + dropped <- attr(region_data$individual_data, "filtered_out_contexts") + if (!is.null(dropped)) { + for (ctx in dropped) { + Y_list <- c(Y_list, stats::setNames(list(NULL), ctx)) + X_list <- c(X_list, stats::setNames(list(NULL), ctx)) + } + } list( - individual_data = list( - Y = region_data$individual_data$residual_Y, - X = region_data$individual_data$residual_X - ), + individual_data = list(Y = Y_list, X = X_list), sumstat_data = NULL ) }, @@ -1887,7 +2021,7 @@ make_individual_region_data <- function(n = 20, p = 8, n_contexts = 2, n_events ctxs <- lapply(paste0("ctx", seq_len(n_contexts)), make_ctx) names(ctxs) <- paste0("ctx", seq_len(n_contexts)) list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = lapply(ctxs, `[[`, "Y"), residual_X = lapply(ctxs, `[[`, "X"), maf = lapply(ctxs, `[[`, "maf") @@ -1931,11 +2065,7 @@ make_sumstat_region_data <- function(n_variants = 5, n_studies = 2) { list(ss) |> setNames(paste0("study", i)) }) - LD_info <- list(list( - LD_variants = ref_panel, - LD_matrix = LD_mat, - ref_panel = ref_panel - )) + LD_info <- list(.test_lddata_from_matrix(LD_mat)) list( individual_data = NULL, @@ -1963,7 +2093,7 @@ test_that("filter_events: valid_pattern with no matching groups returns NULL (li colnames(Y) <- events region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(tissue1 = Y), residual_X = list(tissue1 = X), maf = list(tissue1 = runif(p, 0.05, 0.45)) @@ -1981,10 +2111,13 @@ test_that("filter_events: valid_pattern with no matching groups returns NULL (li local_mocked_bindings( qc_regional_data = function(region_data, ...) { + if (is.null(region_data$individual_data)) { + return(list(individual_data = NULL, sumstat_data = NULL)) + } list( individual_data = list( - Y = region_data$individual_data$residual_Y, - X = region_data$individual_data$residual_X + Y = region_data$individual_data@phenotypes, + X = stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes)) ), sumstat_data = NULL ) @@ -2014,7 +2147,7 @@ test_that("filter_events: type_pattern matches nothing skips via next (line 64)" colnames(Y) <- c("gene_A", "gene_B") region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), maf = list(ctx1 = runif(p, 0.05, 0.45)) @@ -2033,12 +2166,12 @@ test_that("filter_events: type_pattern matches nothing skips via next (line 64)" local_mocked_bindings( qc_regional_data = function(region_data, ...) { # Verify events were NOT filtered (both still present) - remaining <- colnames(region_data$individual_data$residual_Y$ctx1) + remaining <- colnames(region_data$individual_data@phenotypes$ctx1) expect_equal(length(remaining), 2) list( individual_data = list( - Y = region_data$individual_data$residual_Y, - X = region_data$individual_data$residual_X + Y = region_data$individual_data@phenotypes, + X = stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes)) ), sumstat_data = NULL ) @@ -2067,7 +2200,7 @@ test_that("filter_events: all events pass -> 'included in following analysis' me colnames(Y) <- c("evt_alpha", "evt_beta") region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), maf = list(ctx1 = runif(p, 0.05, 0.45)) @@ -2087,8 +2220,8 @@ test_that("filter_events: all events pass -> 'included in following analysis' me qc_regional_data = function(region_data, ...) { list( individual_data = list( - Y = region_data$individual_data$residual_Y, - X = region_data$individual_data$residual_X + Y = region_data$individual_data@phenotypes, + X = stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes)) ), sumstat_data = NULL ) @@ -2118,7 +2251,7 @@ test_that("filter_events: valid_pattern with successful groups retains valid eve colnames(Y) <- events region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(tissue1 = Y), residual_X = list(tissue1 = X), maf = list(tissue1 = runif(p, 0.05, 0.45)) @@ -2136,13 +2269,13 @@ test_that("filter_events: valid_pattern with successful groups retains valid eve local_mocked_bindings( qc_regional_data = function(region_data, ...) { - remaining <- colnames(region_data$individual_data$residual_Y$tissue1) + remaining <- colnames(region_data$individual_data@phenotypes$tissue1) # IN event should be removed expect_false("clu_1_+:IN:gene1" %in% remaining) list( individual_data = list( - Y = region_data$individual_data$residual_Y, - X = region_data$individual_data$residual_X + Y = region_data$individual_data@phenotypes, + X = stats::setNames(replicate(length(region_data$individual_data@phenotypes), region_data$individual_data@genotype_matrix, simplify = FALSE), names(region_data$individual_data@phenotypes)) ), sumstat_data = NULL ) @@ -2245,7 +2378,7 @@ test_that("extract_contexts_studies: sumstat studies extraction on initial call sumstats = data.frame(z = 1.5, variant_id = "chr1:100:A:G"), n = 100, var_y = 1 )), - LD_mat = list(gwas_trait1 = matrix(1, 1, 1, dimnames = list("chr1:100:A:G", "chr1:100:A:G"))), + LD_data = list(gwas_trait1 = .test_lddata_from_matrix(matrix(1, 1, 1, dimnames = list("chr1:100:A:G", "chr1:100:A:G")))), LD_match = "gwas_trait1" ) ) @@ -2277,7 +2410,7 @@ test_that("extract_contexts_studies: after-QC sumstat all pass (line 144)", { individual_data = NULL, sumstat_data = list( sumstats = list(study1 = ss1, study2 = ss2), - LD_mat = list(study1 = LD_mat), + LD_data = list(study1 = .test_lddata_from_matrix(LD_mat)), LD_match = c("study1", "study1") ) ) @@ -2309,7 +2442,7 @@ test_that("extract_contexts_studies: after-QC sumstat partial pass (line 146)", individual_data = NULL, sumstat_data = list( sumstats = list(study1 = ss1), - LD_mat = list(study1 = LD_mat), + LD_data = list(study1 = .test_lddata_from_matrix(LD_mat)), LD_match = c("study1") ) ) @@ -2349,11 +2482,7 @@ test_that("pipeline sumstat block: normalizes variant IDs and processes LD matri n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -2364,7 +2493,7 @@ test_that("pipeline sumstat block: normalizes variant IDs and processes LD matri individual_data = NULL, sumstat_data = list( sumstats = list(study1 = ss), - LD_mat = list(study1 = LD_mat), + LD_data = list(study1 = .test_lddata_from_matrix(LD_mat)), LD_match = c("study1") ) ) @@ -2400,11 +2529,7 @@ test_that("pipeline sumstat block: single sumstat study initializes separate_gwa n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -2416,7 +2541,7 @@ test_that("pipeline sumstat block: single sumstat study initializes separate_gwa individual_data = NULL, sumstat_data = list( sumstats = list(single_study = ss), - LD_mat = list(single_study = LD_mat), + LD_data = list(single_study = .test_lddata_from_matrix(LD_mat)), LD_match = c("single_study") ) ) @@ -2454,11 +2579,7 @@ test_that("pipeline sumstat block: multiple sumstat studies initializes separate sumstats = data.frame(z = rnorm(n_variants), variant_id = vids), n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -2470,7 +2591,7 @@ test_that("pipeline sumstat block: multiple sumstat studies initializes separate individual_data = NULL, sumstat_data = list( sumstats = list(studyA = ss1, studyB = ss2), - LD_mat = list(studyA = LD_mat), + LD_data = list(studyA = .test_lddata_from_matrix(LD_mat)), LD_match = c("studyA", "studyA") ) ) @@ -2511,11 +2632,7 @@ test_that("pipeline: all sumstats invalid returns No data pass QC (line 275-276) sumstats = data.frame(z = rnorm(n_variants), variant_id = vids), n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -2530,7 +2647,7 @@ test_that("pipeline: all sumstats invalid returns No data pass QC (line 275-276) individual_data = NULL, sumstat_data = list( sumstats = list(study1 = invalid_ss), - LD_mat = list(study1 = matrix(1, 1, 1, dimnames = list("chr1:100:A:G", "chr1:100:A:G"))), + LD_data = list(study1 = .test_lddata_from_matrix(matrix(1, 1, 1, dimnames = list("chr1:100:A:G", "chr1:100:A:G")))), LD_match = c("study1") ) ) @@ -2634,7 +2751,7 @@ test_that("pipeline: joint_gwas path is entered with both individual and sumstat LD_mat <- diag(p); rownames(LD_mat) <- colnames(LD_mat) <- vids region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), maf = list(ctx1 = runif(p, 0.05, 0.45)) @@ -2645,11 +2762,7 @@ test_that("pipeline: joint_gwas path is entered with both individual and sumstat sumstats = data.frame(z = rnorm(p), variant_id = vids), n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(p) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(p) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -2663,7 +2776,7 @@ test_that("pipeline: joint_gwas path is entered with both individual and sumstat ), sumstat_data = list( sumstats = list(gwas1 = ss), - LD_mat = list(gwas1 = LD_mat), + LD_data = list(gwas1 = .test_lddata_from_matrix(LD_mat)), LD_match = c("gwas1") ) ) @@ -2698,7 +2811,7 @@ test_that("pipeline: separate_gwas path is entered for each GWAS study", { LD_mat <- diag(p); rownames(LD_mat) <- colnames(LD_mat) <- vids region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = Y), residual_X = list(ctx1 = X), maf = list(ctx1 = runif(p, 0.05, 0.45)) @@ -2709,11 +2822,7 @@ test_that("pipeline: separate_gwas path is entered for each GWAS study", { sumstats = data.frame(z = rnorm(p), variant_id = vids), n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(p) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(p) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -2727,7 +2836,7 @@ test_that("pipeline: separate_gwas path is entered for each GWAS study", { ), sumstat_data = list( sumstats = list(gwasA = ss), - LD_mat = list(gwasA = LD_mat), + LD_data = list(gwasA = .test_lddata_from_matrix(LD_mat)), LD_match = c("gwasA") ) ) @@ -2797,11 +2906,7 @@ test_that("pipeline: sumstat with all NA z-scores yields warning message (lines sumstats = data.frame(z = rnorm(n_variants), variant_id = vids), n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -2816,7 +2921,7 @@ test_that("pipeline: sumstat with all NA z-scores yields warning message (lines individual_data = NULL, sumstat_data = list( sumstats = list(study_na = ss), - LD_mat = list(study_na = LD_mat), + LD_data = list(study_na = .test_lddata_from_matrix(LD_mat)), LD_match = c("study_na") ) ) @@ -2947,7 +3052,7 @@ test_that("qc_regional_data: pip_cutoff_to_skip_ind lookup works when X and Y ha # residual_X has 3 contexts, residual_Y only has 2 # pip_cutoff_to_skip_ind is recycled from residual_Y (length 2) region_data <- list( - individual_data = list( + individual_data = .test_regionaldata_from_lists( residual_Y = list(ctx1 = ctx1$Y, ctx2 = ctx2$Y), residual_X = list(ctx1 = ctx1$X, ctx2 = ctx2$X, ctx3 = ctx3$X), maf = list(ctx1 = ctx1$maf, ctx2 = ctx2$maf, ctx3 = ctx3$maf) @@ -2982,11 +3087,7 @@ test_that("pipeline sumstat processing handles all-NA z-scores with warning (lin sumstats = data.frame(z = rnorm(n_variants), variant_id = vids), n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -3001,7 +3102,7 @@ test_that("pipeline sumstat processing handles all-NA z-scores with warning (lin individual_data = NULL, sumstat_data = list( sumstats = list(study_allna = ss_allna), - LD_mat = list(study_allna = LD_mat), + LD_data = list(study_allna = .test_lddata_from_matrix(LD_mat)), LD_match = c("study_allna") ) ) @@ -3042,11 +3143,7 @@ test_that("pipeline: LD matrix dimnames are normalized to canonical format (line sumstats = data.frame(z = rnorm(n_variants), variant_id = vids_no_chr), n = 5000, var_y = 1 )) ), - LD_info = list(list( - LD_variants = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G"), - LD_matrix = LD_mat, - ref_panel = data.frame(chrom = 1, pos = seq_len(n_variants) * 100, A2 = "A", A1 = "G") - )) + LD_info = list(.test_lddata_from_matrix(LD_mat)) ) ) @@ -3060,7 +3157,7 @@ test_that("pipeline: LD matrix dimnames are normalized to canonical format (line individual_data = NULL, sumstat_data = list( sumstats = list(study1 = ss), - LD_mat = list(study1 = LD_mat), # LD_mat has non-chr names + LD_data = list(study1 = .test_lddata_from_matrix(LD_mat)), # LD_mat has non-chr names LD_match = c("study1") ) ) @@ -3094,8 +3191,7 @@ test_that("qc_regional_data: with only sumstat data processes correctly", { summary_stats_qc = function(rss_input = NULL, LD_data, ...) { stats::setNames(lapply(names(rss_input), function(study) { ld <- if (is(LD_data[[study]], "LDData")) getCorrelation(LD_data[[study]]) else LD_data[[study]]$LD_matrix - list(rss_input = rss_input[[study]], LD_matrix = ld, - outlier_number = 0) + .test_qcresult_from_list(rss_input[[study]], ld) }), names(rss_input)) }, raiss = function(...) { diff --git a/tests/testthat/test_encoloc.R b/tests/testthat/test_encoloc.R index 15b1693d..1d37bb24 100644 --- a/tests/testthat/test_encoloc.R +++ b/tests/testthat/test_encoloc.R @@ -1,4 +1,17 @@ context("encoloc") + +.test_fm_result <- function(variant_names, trimmed_fit = list(), + top_loci = data.frame(variant_id = character(0), + method = character(0), + stringsAsFactors = FALSE)) { + FineMappingResult( + variant_names = variant_names, + trimmed_fit = trimmed_fit, + top_loci = top_loci, + method = "susie" + ) +} + library(tidyverse) library(coloc) @@ -475,12 +488,14 @@ test_that("coloc_wrapper with run_finemapping = TRUE uses rss_analysis_pipeline" # Build mock pipeline result matching rss_analysis_pipeline output structure mock_pipeline <- list( "susie_rss_SLALOM_RAISS_imputed" = list( - variant_names = xqtl_fit$variant_names, - susie_result_trimmed = list( - lbf_variable = xqtl_fit$lbf_variable, - V = xqtl_fit$V, - pip = xqtl_fit$pip, - sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + finemapping_result = .test_fm_result( + variant_names = xqtl_fit$variant_names, + trimmed_fit = list( + lbf_variable = xqtl_fit$lbf_variable, + V = xqtl_fit$V, + pip = xqtl_fit$pip, + sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + ) ) ), rss_data_analyzed = data.frame( @@ -514,12 +529,14 @@ test_that("coloc_wrapper with return_finemapping includes pipeline result", { mock_pipeline <- list( "susie_rss_SLALOM" = list( - variant_names = xqtl_fit$variant_names, - susie_result_trimmed = list( - lbf_variable = xqtl_fit$lbf_variable, - V = xqtl_fit$V, - pip = xqtl_fit$pip, - sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + finemapping_result = .test_fm_result( + variant_names = xqtl_fit$variant_names, + trimmed_fit = list( + lbf_variable = xqtl_fit$lbf_variable, + V = xqtl_fit$V, + pip = xqtl_fit$pip, + sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + ) ) ), rss_data_analyzed = data.frame( @@ -555,12 +572,14 @@ test_that("coloc_wrapper save_finemapping_path saves reusable RDS", { mock_pipeline <- list( "susie_rss_SLALOM" = list( - variant_names = xqtl_fit$variant_names, - susie_result_trimmed = list( - lbf_variable = xqtl_fit$lbf_variable, - V = xqtl_fit$V, - pip = xqtl_fit$pip, - sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + finemapping_result = .test_fm_result( + variant_names = xqtl_fit$variant_names, + trimmed_fit = list( + lbf_variable = xqtl_fit$lbf_variable, + V = xqtl_fit$V, + pip = xqtl_fit$pip, + sets = list(cs_index = seq_len(nrow(xqtl_fit$lbf_variable))) + ) ) ), rss_data_analyzed = data.frame( diff --git a/tests/testthat/test_file_utils.R b/tests/testthat/test_file_utils.R index 23de3a8b..94777175 100644 --- a/tests/testthat/test_file_utils.R +++ b/tests/testthat/test_file_utils.R @@ -980,8 +980,9 @@ test_that("Test load_regional_univariate_data",{ xvar_cutoff = 0.2, phenotype_header = 3, keep_samples = NULL) - expect_true("residual_X" %in% names(res)) - expect_true("residual_Y" %in% names(res)) + expect_true(is(res, "RegionalData")) + expect_true(is.matrix(getResidualX(res, 1L))) + expect_true(is.matrix(getResidualY(res, 1L))) }) test_that("Test load_regional_regression_data",{ @@ -1009,20 +1010,24 @@ test_that("Test load_regional_regression_data",{ xvar_cutoff = 0.2, phenotype_header = 3, keep_samples = NULL) - expect_equal(nrow(res$X_data[[1]]), 10) - expect_equal(ncol(res$X_data[[1]]), 10) + expect_true(is(res, "RegionalData")) + X_mat <- res@genotype_matrix + expect_equal(nrow(X_mat), 10) + expect_equal(ncol(X_mat), 10) colnames(geno_data) <- gsub("_", ":", colnames(geno_data)) - expect_equal(res$X_data[[1]][order(as.numeric(gsub("Sample_", "", rownames(res$X_data[[1]])))), , drop = FALSE], geno_data) - expect_equal(length(res$Y[[1]]), 10) + expect_equal(X_mat[order(as.numeric(gsub("Sample_", "", rownames(X_mat)))), , drop = FALSE], geno_data) + Y_mat <- res@phenotypes[[1]] + expect_equal(nrow(Y_mat), 10) expect_equal( - setNames(res$Y[[1]][order(as.numeric(gsub("Sample_", "", rownames(res$Y[[1]]))))], -rownames(res$Y[[1]])[order(as.numeric(gsub("Sample_", "", rownames(res$Y[[1]]))))]), + setNames(as.numeric(Y_mat[order(as.numeric(gsub("Sample_", "", rownames(Y_mat)))), 1]), + rownames(Y_mat)[order(as.numeric(gsub("Sample_", "", rownames(Y_mat))))]), setNames( as.numeric(pheno_data[[1]][4:13,]), names(pheno_data[[1]][4:13,]))) - expect_equal(nrow(res$covar[[1]]), 10) - expect_equal(ncol(res$covar[[1]]), 5) - expect_equal(res$covar[[1]][order(as.numeric(gsub("Sample_", "", rownames(res$covar[[1]])))), , drop = FALSE], covar_data) + covar_mat <- res@covariates[[1]] + expect_equal(nrow(covar_mat), 10) + expect_equal(ncol(covar_mat), 5) + expect_equal(covar_mat[order(as.numeric(gsub("Sample_", "", rownames(covar_mat)))), , drop = FALSE], covar_data) }) test_that("load_regional_multivariate_data filters Y by min completeness", { @@ -1048,9 +1053,10 @@ test_that("load_regional_multivariate_data filters Y by min completeness", { imiss_cutoff = 0.70, maf_cutoff = 0.1, mac_cutoff = (0.1 * 10 * 2), xvar_cutoff = 0.2, phenotype_header = 3, keep_samples = NULL ) - expect_true(!is.null(result$X)) - expect_true(!is.null(result$maf)) - expect_true(!is.null(result$X_variance)) + expect_true(is(result, "MultivariateRegionalData")) + expect_true(is.matrix(result@genotype_matrix)) + expect_true(is.numeric(getMaf(result))) + expect_true(is.numeric(getXVariance(result))) }) test_that("load_regional_functional_data returns full association data", { @@ -1073,8 +1079,9 @@ test_that("load_regional_functional_data returns full association data", { imiss_cutoff = 0.70, maf_cutoff = 0.1, mac_cutoff = (0.1 * 10 * 2), xvar_cutoff = 0.2, phenotype_header = 3, keep_samples = NULL ) - expect_true("residual_Y" %in% names(result)) - expect_true("X" %in% names(result)) + expect_true(is(result, "RegionalData")) + expect_true(is.matrix(getResidualY(result, 1L))) + expect_true(is.matrix(result@genotype_matrix)) }) # =========================================================================== @@ -1587,10 +1594,11 @@ test_that("load_multitask_regional_data individual-level path returns expected s expect_named(result, c("individual_data", "sumstat_data")) expect_false(is.null(result$individual_data)) expect_true(is.null(result$sumstat_data)) - # Individual data should have standard fields - expect_true("residual_Y" %in% names(result$individual_data)) - expect_true("X" %in% names(result$individual_data)) - expect_true("chrom" %in% names(result$individual_data)) + # Individual data should be a RegionalData + expect_true(is(result$individual_data, "RegionalData")) + expect_true(is.matrix(getResidualY(result$individual_data, 1L))) + expect_true(is.matrix(result$individual_data@genotype_matrix)) + expect_true(!is.null(getChrom(result$individual_data))) }) test_that("load_multitask_regional_data loads and merges multiple genotype groups", { @@ -1604,20 +1612,19 @@ test_that("load_multitask_regional_data loads and merges multiple genotype group matrix(1, nrow = 2, ncol = 1, dimnames = list(c("s1", "s2"), nm)) }), conditions) - list( - residual_Y = y, - residual_X = x, - residual_Y_scalar = setNames(as.list(rep(1, length(conditions))), conditions), - residual_X_scalar = setNames(as.list(rep(1, length(conditions))), conditions), - dropped_sample = list(X = list(), Y = list(), covar = list()), - covar = list(), - Y = y, - X_data = x, - X = do.call(cbind, x), - maf = setNames(lapply(x, function(mat) rep(0.1, ncol(mat))), conditions), - chrom = "chr1", - grange = c("1", "100"), - Y_coordinates = list() + covar_list <- setNames(lapply(conditions, function(nm) { + matrix(numeric(0), nrow = 2, ncol = 0, dimnames = list(c("s1", "s2"), NULL)) + }), conditions) + maf_list <- setNames(lapply(x, function(mat) rep(0.1, ncol(mat))), conditions) + RegionalData( + genotype_matrix = do.call(cbind, x), + phenotypes = y, + covariates = covar_list, + scale_residuals = FALSE, + maf = maf_list, + region = NULL, + dropped_samples = list(X = list(), Y = list(), covar = list()), + Y_coordinates = NULL ) } @@ -1655,8 +1662,8 @@ test_that("load_multitask_regional_data loads and merges multiple genotype group expect_equal(calls[[2]]$phenotype, paste0("pheno", 3:4)) expect_equal(calls[[1]]$extract_region_name, as.list(paste0("gene", 1:2))) expect_equal(calls[[2]]$extract_region_name, as.list(paste0("gene", 3:4))) - expect_true("residual_X" %in% names(result$individual_data)) - expect_equal(names(result$individual_data$residual_X), paste0("cond", 1:4)) + expect_true(is(result$individual_data, "RegionalData")) + expect_equal(names(result$individual_data@phenotypes), paste0("cond", 1:4)) }) test_that("load_multitask_regional_data defaults missing individual condition names", { @@ -1670,20 +1677,19 @@ test_that("load_multitask_regional_data defaults missing individual condition na matrix(1, nrow = 2, ncol = 1, dimnames = list(c("s1", "s2"), nm)) }), conditions) - list( - residual_Y = y, - residual_X = x, - residual_Y_scalar = setNames(as.list(rep(1, length(conditions))), conditions), - residual_X_scalar = setNames(as.list(rep(1, length(conditions))), conditions), - dropped_sample = list(X = list(), Y = list(), covar = list()), - covar = list(), - Y = y, - X_data = x, - X = do.call(cbind, x), - maf = setNames(lapply(x, function(mat) rep(0.1, ncol(mat))), conditions), - chrom = "chr1", - grange = c("1", "100"), - Y_coordinates = list() + covar_list <- setNames(lapply(conditions, function(nm) { + matrix(numeric(0), nrow = 2, ncol = 0, dimnames = list(c("s1", "s2"), NULL)) + }), conditions) + maf_list <- setNames(lapply(x, function(mat) rep(0.1, ncol(mat))), conditions) + RegionalData( + genotype_matrix = do.call(cbind, x), + phenotypes = y, + covariates = covar_list, + scale_residuals = FALSE, + maf = maf_list, + region = NULL, + dropped_samples = list(X = list(), Y = list(), covar = list()), + Y_coordinates = NULL ) } @@ -1714,7 +1720,7 @@ test_that("load_multitask_regional_data defaults missing individual condition na expect_equal(length(calls), 2L) expect_equal(calls[[1]]$conditions, paste0("condition", 1:2)) expect_equal(calls[[2]]$conditions, paste0("condition", 3:4)) - expect_equal(names(result$individual_data$residual_X), paste0("condition", 1:4)) + expect_equal(names(result$individual_data@phenotypes), paste0("condition", 1:4)) }) test_that("load_multitask_regional_data validates individual input vector lengths", { @@ -1831,7 +1837,8 @@ test_that("load_multitask_regional_data both paths simultaneously", { expect_false(is.null(result$individual_data)) expect_false(is.null(result$sumstat_data)) # Both paths should produce valid data - expect_true("X" %in% names(result$individual_data)) + expect_true(is(result$individual_data, "RegionalData")) + expect_true(is.matrix(result$individual_data@genotype_matrix)) expect_true(is.data.frame(result$sumstat_data$sumstats[[1]][["ss_cond1"]]$sumstats)) }) @@ -2706,14 +2713,17 @@ test_that("load_regional_univariate_data returns correct fields", { region = "chr21:17513043-17593579", conditions = "cond1" ) - expected_names <- c("residual_Y", "residual_X", "residual_Y_scalar", - "residual_X_scalar", "dropped_sample", "maf", - "X", "chrom", "grange", "X_variance") - expect_true(all(expected_names %in% names(result))) - expect_equal(nrow(result$X), 100L) - # X_variance should be a list with one entry per condition - expect_true(is.list(result$X_variance)) - expect_equal(length(result$X_variance[[1]]), ncol(result$X)) + expect_true(is(result, "RegionalData")) + expect_equal(nrow(result@genotype_matrix), 100L) + # Per-condition accessors should return valid data + expect_true(is.matrix(getResidualX(result, 1L))) + expect_true(is.matrix(getResidualY(result, 1L))) + x_var <- getXVariance(result, 1L) + expect_true(is.numeric(x_var)) + expect_equal(length(x_var), ncol(result@genotype_matrix)) + # Chrom and grange accessors should work + expect_true(!is.null(getChrom(result))) + expect_true(!is.null(getGrange(result))) }) # =========================================================================== @@ -2731,12 +2741,12 @@ test_that("load_regional_regression_data returns correct fields", { region = "chr21:17513043-17593579", conditions = "cond1" ) - expected_names <- c("Y", "X_data", "covar", "dropped_sample", - "maf", "chrom", "grange") - expect_true(all(expected_names %in% names(result))) - expect_true(is.list(result$Y)) - expect_true(is.list(result$X_data)) - expect_true(is.list(result$covar)) + expect_true(is(result, "RegionalData")) + expect_true(is.list(result@phenotypes)) + expect_true(is.matrix(result@phenotypes[[1]])) + expect_true(is.matrix(result@genotype_matrix)) + expect_true(is.list(result@covariates)) + expect_true(is.matrix(result@covariates[[1]])) }) # =========================================================================== @@ -2754,12 +2764,12 @@ test_that("load_regional_multivariate_data returns correct fields", { region = "chr21:17513043-17593579", conditions = "cond1" ) - expected_names <- c("residual_Y", "residual_Y_scalar", "dropped_sample", - "X", "maf", "chrom", "grange", "X_variance") - expect_true(all(expected_names %in% names(result))) - # residual_Y should be a matrix (not list) after pheno_list_to_mat - expect_true(is.matrix(result$residual_Y)) - expect_equal(nrow(result$X), 100L) + expect_true(is(result, "MultivariateRegionalData")) + expect_true(is.matrix(getYMatrix(result))) + expect_equal(nrow(result@genotype_matrix), 100L) + expect_true(is.numeric(getYScalar(result))) + expect_true(is.numeric(getMaf(result))) + expect_true(is.numeric(getXVariance(result))) }) # ============================================================================= diff --git a/tests/testthat/test_mash_wrapper.R b/tests/testthat/test_mash_wrapper.R index 830231e3..581a4a33 100644 --- a/tests/testthat/test_mash_wrapper.R +++ b/tests/testthat/test_mash_wrapper.R @@ -1,5 +1,16 @@ context("mash_wrapper") +# Build a minimal FineMappingResult for unit-testing find_nested / extract_flatten_sumstats_from_nested +.test_fm_result <- function(variant_names) { + FineMappingResult( + variant_names = variant_names, + trimmed_fit = list(pip = rep(0.5, length(variant_names))), + top_loci = data.frame(variant_id = character(0), method = character(0), + stringsAsFactors = FALSE), + method = "susie" + ) +} + # =========================================================================== # merge_susie_cs # =========================================================================== @@ -995,7 +1006,7 @@ test_that("merge_sumstats_matrices with single valid dataset returns properly", test_that("extract_flatten_sumstats_from_nested computes z from betahat/sebetahat", { data <- list( - variant_names = c("1:100:A:G", "1:200:C:T"), + finemapping_result = .test_fm_result(c("1:100:A:G", "1:200:C:T")), sumstats = list( betahat = c(0.5, -0.3), sebetahat = c(0.1, 0.15) @@ -1011,7 +1022,7 @@ test_that("extract_flatten_sumstats_from_nested computes z from betahat/sebetaha test_that("extract_flatten_sumstats_from_nested uses z directly when available", { data <- list( - variant_names = c("1:100:A:G"), + finemapping_result = .test_fm_result(c("1:100:A:G")), sumstats = list(z = c(3.5)) ) result <- extract_flatten_sumstats_from_nested(data, extract_inf = "z") @@ -1020,7 +1031,7 @@ test_that("extract_flatten_sumstats_from_nested uses z directly when available", test_that("extract_flatten_sumstats_from_nested extracts beta from direct sumstats", { data <- list( - variant_names = c("chr1:100:A:G", "chr1:200:C:T"), + finemapping_result = .test_fm_result(c("chr1:100:A:G", "chr1:200:C:T")), sumstats = list( betahat = c(0.5, -0.3), sebetahat = c(0.1, 0.15) @@ -1032,7 +1043,7 @@ test_that("extract_flatten_sumstats_from_nested extracts beta from direct sumsta test_that("extract_flatten_sumstats_from_nested extracts se from direct sumstats", { data <- list( - variant_names = c("chr1:100:A:G", "chr1:200:C:T"), + finemapping_result = .test_fm_result(c("chr1:100:A:G", "chr1:200:C:T")), sumstats = list( betahat = c(0.5, -0.3), sebetahat = c(0.1, 0.15) @@ -1044,7 +1055,7 @@ test_that("extract_flatten_sumstats_from_nested extracts se from direct sumstats test_that("extract_flatten_sumstats_from_nested reaches max_depth and returns NULL", { data <- list(level1 = list(level2 = list(level3 = list(level4 = list( - variant_names = c("1:100:A:G"), + finemapping_result = .test_fm_result(c("1:100:A:G")), sumstats = list(z = c(2.0)) ))))) result <- extract_flatten_sumstats_from_nested(data, extract_inf = "z", max_depth = 2) @@ -1053,7 +1064,7 @@ test_that("extract_flatten_sumstats_from_nested reaches max_depth and returns NU test_that("extract_flatten_sumstats_from_nested handles missing betahat for z", { data <- list( - variant_names = c("1:100:A:G"), + finemapping_result = .test_fm_result(c("1:100:A:G")), sumstats = list(something_else = c(1.0)) ) result <- expect_message( @@ -1065,7 +1076,7 @@ test_that("extract_flatten_sumstats_from_nested handles missing betahat for z", test_that("extract_flatten_sumstats_from_nested handles missing betahat for beta", { data <- list( - variant_names = c("1:100:A:G"), + finemapping_result = .test_fm_result(c("1:100:A:G")), sumstats = list(z = c(2.0)) ) result <- expect_message( @@ -1077,7 +1088,7 @@ test_that("extract_flatten_sumstats_from_nested handles missing betahat for beta test_that("extract_flatten_sumstats_from_nested handles missing sebetahat for se", { data <- list( - variant_names = c("1:100:A:G"), + finemapping_result = .test_fm_result(c("1:100:A:G")), sumstats = list(betahat = c(0.5)) ) result <- expect_message( @@ -1089,7 +1100,7 @@ test_that("extract_flatten_sumstats_from_nested handles missing sebetahat for se test_that("extract_flatten_sumstats_from_nested rejects invalid extract_inf values", { data <- list( - variant_names = c("1:100:A:G"), + finemapping_result = .test_fm_result(c("1:100:A:G")), sumstats = list(z = c(1.0)) ) expect_error( @@ -1100,7 +1111,7 @@ test_that("extract_flatten_sumstats_from_nested rejects invalid extract_inf valu test_that("extract_flatten_sumstats_from_nested normalizes variant IDs to chr prefix", { data <- list( - variant_names = c("1:100:A:G", "2:200:C:T"), + finemapping_result = .test_fm_result(c("1:100:A:G", "2:200:C:T")), sumstats = list(z = c(1.0, 2.0)) ) result <- extract_flatten_sumstats_from_nested(data, extract_inf = "z") @@ -1110,7 +1121,7 @@ test_that("extract_flatten_sumstats_from_nested normalizes variant IDs to chr pr test_that("extract_flatten_sumstats_from_nested normalizes variant IDs from nested search", { data <- list( nested = list( - variant_names = c("1:100:A:G"), + finemapping_result = .test_fm_result(c("1:100:A:G")), sumstats = list(z = c(3.0)) ) ) @@ -1122,7 +1133,7 @@ test_that("extract_flatten_sumstats_from_nested recurses through multiple nestin data <- list( level1 = list( level2 = list( - variant_names = c("chr1:100:A:G", "chr1:200:C:T"), + finemapping_result = .test_fm_result(c("chr1:100:A:G", "chr1:200:C:T")), sumstats = list( betahat = c(0.5, -0.3), sebetahat = c(0.1, 0.15) @@ -1141,7 +1152,7 @@ test_that("extract_flatten_sumstats_from_nested returns NULL for deeply nested b a = list( b = list( c = list( - variant_names = c("1:100:A:G"), + finemapping_result = .test_fm_result(c("1:100:A:G")), sumstats = list(z = c(2.0)) ) ) diff --git a/tests/testthat/test_sumstats_qc.R b/tests/testthat/test_sumstats_qc.R index be240f6e..3b2ae869 100644 --- a/tests/testthat/test_sumstats_qc.R +++ b/tests/testthat/test_sumstats_qc.R @@ -93,9 +93,9 @@ test_that("rss_basic_qc processes matching variants correctly", { LD_data <- make_ld_data_s4(LD_mat, variant_ids) result <- rss_basic_qc(sumstats, LD_data) - expect_type(result, "list") - expect_true("sumstats" %in% names(result)) - expect_true("LD_mat" %in% names(result)) + expect_true(is(result, "QCResult")) + expect_true(!is.null(getRSSInput(result)$sumstats)) + expect_true(!is.null(getLDData(result))) }) test_that("rss_basic_qc skips variants in specified region", { @@ -103,10 +103,10 @@ test_that("rss_basic_qc skips variants in specified region", { result <- rss_basic_qc(td$sumstats, td$LD_data, skip_region = "1:150-350") - expect_type(result, "list") - expect_true("sumstats" %in% names(result)) - expect_true("LD_mat" %in% names(result)) - remaining_pos <- result$sumstats$pos + expect_true(is(result, "QCResult")) + expect_true(!is.null(getRSSInput(result)$sumstats)) + expect_true(!is.null(getLDData(result))) + remaining_pos <- getRSSInput(result)$sumstats$pos expect_false(200 %in% remaining_pos) expect_false(300 %in% remaining_pos) }) @@ -114,7 +114,7 @@ test_that("rss_basic_qc skips variants in specified region", { test_that("rss_basic_qc with skip_region preserves non-skipped variants", { td <- make_test_sumstats_ld(n_variants = 5) result <- rss_basic_qc(td$sumstats, td$LD_data, skip_region = "1:150-250") - remaining_pos <- result$sumstats$pos + remaining_pos <- getRSSInput(result)$sumstats$pos expect_false(200 %in% remaining_pos) expect_true(100 %in% remaining_pos) expect_true(300 %in% remaining_pos) @@ -123,8 +123,8 @@ test_that("rss_basic_qc with skip_region preserves non-skipped variants", { test_that("rss_basic_qc with keep_indel=FALSE removes indel variants", { td <- make_test_sumstats_ld(n_variants = 5, with_indels = TRUE) result <- rss_basic_qc(td$sumstats, td$LD_data, keep_indel = FALSE) - expect_type(result, "list") - expect_lte(nrow(result$sumstats), nrow(td$sumstats)) + expect_true(is(result, "QCResult")) + expect_lte(nrow(getRSSInput(result)$sumstats), nrow(td$sumstats)) }) test_that("rss_basic_qc errors when no variants overlap", { @@ -174,8 +174,8 @@ test_that("rss_basic_qc aligns variant IDs by stripping build suffix", { LD_data <- make_ld_data_s4(LD_mat, base_ids) result <- rss_basic_qc(sumstats, LD_data) - expect_type(result, "list") - expect_true(nrow(result$sumstats) > 0) + expect_true(is(result, "QCResult")) + expect_true(nrow(getRSSInput(result)$sumstats) > 0) }) test_that("rss_basic_qc handles chr prefix differences during alignment", { @@ -201,15 +201,17 @@ test_that("rss_basic_qc handles chr prefix differences during alignment", { LD_data <- make_ld_data_s4(LD_mat, base_ids) result <- rss_basic_qc(sumstats, LD_data) - expect_type(result, "list") - expect_true(nrow(result$sumstats) > 0) + expect_true(is(result, "QCResult")) + expect_true(nrow(getRSSInput(result)$sumstats) > 0) }) test_that("rss_basic_qc output LD_mat has same dimension as sumstats rows", { td <- make_test_sumstats_ld(n_variants = 6) result <- rss_basic_qc(td$sumstats, td$LD_data) - expect_equal(nrow(result$LD_mat), nrow(result$sumstats)) - expect_equal(ncol(result$LD_mat), nrow(result$sumstats)) + result_ld_mat <- getCorrelation(getLDData(result)) + result_sumstats <- getRSSInput(result)$sumstats + expect_equal(nrow(result_ld_mat), nrow(result_sumstats)) + expect_equal(ncol(result_ld_mat), nrow(result_sumstats)) }) test_that("rss_basic_qc errors when LD matrix has NULL rownames", { @@ -231,7 +233,7 @@ test_that("rss_basic_qc handles multiple skip regions", { td <- make_test_sumstats_ld(n_variants = 10) result <- rss_basic_qc(td$sumstats, td$LD_data, skip_region = c("1:099-250", "1:650-850")) - remaining_pos <- result$sumstats$pos + remaining_pos <- getRSSInput(result)$sumstats$pos expect_false(100 %in% remaining_pos) expect_false(200 %in% remaining_pos) expect_false(700 %in% remaining_pos) @@ -252,8 +254,8 @@ test_that("rss_basic_qc can skip LD matrix subsetting for genotype references", result <- rss_basic_qc(td$sumstats, LD_data_geno, return_LD_mat = FALSE) - expect_true(nrow(result$sumstats) > 0) - expect_null(result$LD_mat) + expect_true(nrow(getRSSInput(result)$sumstats) > 0) + expect_null(getLDData(result)) }) # =========================================================================== @@ -285,16 +287,16 @@ test_that("summary_stats_qc with slalom method returns correct structure", { ) result <- summary_stats_qc( - basic_result$sumstats, td$LD_data, + getRSSInput(basic_result)$sumstats, td$LD_data, n = 10000, method = "slalom" ) - expect_type(result, "list") - expect_true("sumstats" %in% names(result)) - expect_true("LD_mat" %in% names(result)) - expect_true("outlier_number" %in% names(result)) - expect_equal(result$outlier_number, 1) - expect_equal(nrow(result$sumstats), nrow(basic_result$sumstats) - 1) + expect_true(is(result, "QCResult")) + expect_true(!is.null(getRSSInput(result)$sumstats)) + expect_true(!is.null(getLDData(result))) + expect_equal(getOutlierNumber(result), 1) + expect_equal(nrow(getRSSInput(result)$sumstats), + nrow(getRSSInput(basic_result)$sumstats) - 1) }) test_that("summary_stats_qc with slalom and no outliers keeps all variants", { @@ -314,11 +316,12 @@ test_that("summary_stats_qc with slalom and no outliers keeps all variants", { ) result <- summary_stats_qc( - basic_result$sumstats, td$LD_data, + getRSSInput(basic_result)$sumstats, td$LD_data, n = 10000, method = "slalom" ) - expect_equal(result$outlier_number, 0) - expect_equal(nrow(result$sumstats), nrow(basic_result$sumstats)) + expect_equal(getOutlierNumber(result), 0) + expect_equal(nrow(getRSSInput(result)$sumstats), + nrow(getRSSInput(basic_result)$sumstats)) }) test_that("summary_stats_qc with dentist method returns correct structure", { @@ -336,16 +339,16 @@ test_that("summary_stats_qc with dentist method returns correct structure", { ) result <- summary_stats_qc( - basic_result$sumstats, td$LD_data, + getRSSInput(basic_result)$sumstats, td$LD_data, n = 10000, method = "dentist" ) - expect_type(result, "list") - expect_true("sumstats" %in% names(result)) - expect_true("LD_mat" %in% names(result)) - expect_true("outlier_number" %in% names(result)) - expect_equal(result$outlier_number, 1) - expect_equal(nrow(result$sumstats), nrow(basic_result$sumstats) - 1) + expect_true(is(result, "QCResult")) + expect_true(!is.null(getRSSInput(result)$sumstats)) + expect_true(!is.null(getLDData(result))) + expect_equal(getOutlierNumber(result), 1) + expect_equal(nrow(getRSSInput(result)$sumstats), + nrow(getRSSInput(basic_result)$sumstats) - 1) }) test_that("summary_stats_qc with dentist and all outliers returns empty", { @@ -363,11 +366,12 @@ test_that("summary_stats_qc with dentist and all outliers returns empty", { ) result <- summary_stats_qc( - basic_result$sumstats, td$LD_data, + getRSSInput(basic_result)$sumstats, td$LD_data, n = 10000, method = "dentist" ) - expect_equal(nrow(result$sumstats), 0) - expect_equal(result$outlier_number, nrow(basic_result$sumstats)) + expect_equal(nrow(getRSSInput(result)$sumstats), 0) + expect_equal(getOutlierNumber(result), + nrow(getRSSInput(basic_result)$sumstats)) }) test_that("summary_stats_qc returns LD_mat matching filtered sumstats dimensions", { @@ -384,11 +388,13 @@ test_that("summary_stats_qc returns LD_mat matching filtered sumstats dimensions ) result <- summary_stats_qc( - basic_result$sumstats, td$LD_data, + getRSSInput(basic_result)$sumstats, td$LD_data, n = 10000, method = "slalom" ) - expect_equal(nrow(result$LD_mat), nrow(result$sumstats)) - expect_equal(ncol(result$LD_mat), nrow(result$sumstats)) + result_ld_mat <- getCorrelation(getLDData(result)) + result_sumstats <- getRSSInput(result)$sumstats + expect_equal(nrow(result_ld_mat), nrow(result_sumstats)) + expect_equal(ncol(result_ld_mat), nrow(result_sumstats)) }) test_that("summary_stats_qc basic genotype-backed path does not compute LD", { @@ -410,8 +416,10 @@ test_that("summary_stats_qc basic genotype-backed path does not compute LD", { qc_method = "none", impute = FALSE), "basic harmonization retained" ) - expect_equal(nrow(result$LD_matrix), nrow(X_ref)) - expect_equal(ncol(result$LD_matrix), nrow(result$rss_input$sumstats)) + result_ld <- getLDData(result) + result_geno <- getGenotypes(result_ld) + expect_equal(nrow(result_geno), nrow(X_ref)) + expect_equal(ncol(result_geno), nrow(getRSSInput(result)$sumstats)) }) test_that("summary_stats_qc accepts genotype-backed LDData", { @@ -467,8 +475,10 @@ test_that("summary_stats_qc accepts genotype-backed LDData", { impute = FALSE )) - expect_equal(nrow(result$LD_matrix), 100L) - expect_equal(ncol(result$LD_matrix), nrow(result$rss_input$sumstats)) + result_ld <- getLDData(result) + result_geno <- getGenotypes(result_ld) + expect_equal(nrow(result_geno), 100L) + expect_equal(ncol(result_geno), nrow(getRSSInput(result)$sumstats)) }) test_that("summary_stats_qc PIP screening uses LD-independent SER", { @@ -495,7 +505,9 @@ test_that("summary_stats_qc PIP screening uses LD-independent SER", { pip_cutoff_to_skip = 0.1, impute = FALSE )) - expect_equal(ncol(result$LD_matrix), nrow(result$rss_input$sumstats)) + result_ld <- getLDData(result) + result_R <- getCorrelation(result_ld) + expect_equal(ncol(result_R), nrow(getRSSInput(result)$sumstats)) }) test_that("summary_stats_qc treats NULL qc_method as basic-only none", { @@ -515,7 +527,7 @@ test_that("summary_stats_qc treats NULL qc_method as basic-only none", { ), "basic harmonization retained" ) - expect_equal(nrow(result$rss_input$sumstats), nrow(td$sumstats)) + expect_equal(nrow(getRSSInput(result)$sumstats), nrow(td$sumstats)) }) test_that("summary_stats_qc rejects invalid qc_method values", { @@ -566,8 +578,12 @@ test_that("summary_stats_qc LD-mismatch QC computes only filtered local LD from impute = FALSE )) expect_equal(compute_calls, 2) - expect_equal(ncol(result$LD_matrix), nrow(result$rss_input$sumstats)) - expect_equal(ncol(result$LD_matrix), 3) + result_ld <- getLDData(result) + # getGenotypes is mocked above to always return full X_ref, so read the + # subsetted handle stored in the LDData slot directly to verify subsetting. + result_geno <- result_ld@genotype_handle + expect_equal(ncol(result_geno), nrow(getRSSInput(result)$sumstats)) + expect_equal(ncol(result_geno), 3) }) # =========================================================================== diff --git a/tests/testthat/test_twas.R b/tests/testthat/test_twas.R index bb42a909..dfb6a97c 100644 --- a/tests/testthat/test_twas.R +++ b/tests/testthat/test_twas.R @@ -782,7 +782,7 @@ test_that("harmonize_gwas: computes z from beta and se when z is absent", { }, match_ref_panel = function(target_data, ref_data, ...) { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) result <- harmonize_gwas( @@ -829,7 +829,7 @@ test_that("harmonize_gwas: renames #chrom to chrom in tabix output", { }, match_ref_panel = function(target_data, ref_data, ...) { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- c("chr1:100:T:A", "chr1:200:C:G") @@ -855,7 +855,7 @@ test_that("harmonize_gwas: uses load_rss_data when column_file_path is provided" }, match_ref_panel = function(target_data, ref_data, ...) { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- c("chr1:100:T:A", "chr1:200:C:G") @@ -1100,7 +1100,7 @@ test_that("harmonize_gwas: rows with NA or Inf z are removed from output", { }, match_ref_panel = function(target_data, ref_data, ...) { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- paste0("chr1:", c(100, 200, 300, 400), ":", c("T", "C", "A", "G"), ":", c("A", "G", "T", "C")) @@ -1199,7 +1199,7 @@ test_that("harmonize_gwas: col_to_flip parameter is passed through to match_ref_ match_ref_panel = function(target_data, ref_data, col_to_flip = NULL, ...) { received_col_to_flip <<- col_to_flip target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- c("chr1:100:T:A", "chr1:200:C:G") @@ -1226,7 +1226,7 @@ test_that("harmonize_gwas: match_min_prop parameter is passed to match_ref_panel match_ref_panel = function(target_data, ref_data, col_to_flip = NULL, match_min_prop = 0.2, ...) { received_match_min_prop <<- match_min_prop target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- c("chr1:100:T:A", "chr1:200:C:G") @@ -1251,7 +1251,7 @@ test_that("harmonize_gwas: z computed from beta/se has correct values", { }, match_ref_panel = function(target_data, ref_data, ...) { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- c("chr1:100:T:A", "chr1:200:C:G", "chr1:300:A:T") @@ -1273,7 +1273,7 @@ test_that("harmonize_gwas: only keeps rows with finite non-NA z after allele_qc" }, match_ref_panel = function(target_data, ref_data, ...) { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- paste0("chr1:", c(100, 200, 300, 400, 500), ":T:A") @@ -1359,7 +1359,7 @@ test_that("harmonize_gwas: existing z column is used directly", { }, match_ref_panel = function(target_data, ref_data, ...) { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- c("chr1:100:T:A", "chr1:200:C:G") @@ -2388,7 +2388,7 @@ test_that("harmonize_gwas: complete flow with beta/se produces correct z-scores }, match_ref_panel = function(target_data, ref_data, ...) { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } ) ld_variants <- paste0("chr1:", c(100, 200, 300, 400), ":", c("T", "C", "A", "G"), ":", c("A", "G", "T", "C")) @@ -2840,9 +2840,9 @@ test_that("harmonize_twas: group_contexts_by_region single context path (lines 4 rownames(target_data) } } - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } else { - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } } ) @@ -2971,9 +2971,9 @@ test_that("harmonize_twas: group_contexts_by_region multi-context clustering (li rownames(target_data) } } - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } else { - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } } ) @@ -3417,9 +3417,9 @@ test_that("harmonize_twas: duplicated LD variants are removed", { target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) } - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } else { - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } } ) @@ -3480,7 +3480,7 @@ test_that("harmonize_twas: drops molecular_id when harmonize_gwas returns NULL f }, # Returning NULL skips the entire context loop, so gwas_qced stays empty harmonize_gwas = function(...) NULL, - match_ref_panel = function(target_data, ref_data, ...) list(target_data_qced = target_data) + match_ref_panel = function(target_data, ref_data, ...) AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) ) gwas_meta <- data.frame( @@ -3563,9 +3563,9 @@ test_that("harmonize_twas: susie_weights column triggers adjust_susie_weights br target_data$variant_id <- paste0("chr", target_data$chrom, ":", target_data$pos, ":", target_data$A2, ":", target_data$A1) } - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } else { - list(target_data_qced = target_data) + AlleleQCResult(harmonized_data = target_data, qc_summary = target_data) } }, adjust_susie_weights = function(twas_data, keep_variants, ...) { diff --git a/tests/testthat/test_univariate_pipeline.R b/tests/testthat/test_univariate_pipeline.R index 35a92676..440701ac 100644 --- a/tests/testthat/test_univariate_pipeline.R +++ b/tests/testthat/test_univariate_pipeline.R @@ -98,6 +98,44 @@ make_test_ld_data <- function(variant_ids, R = NULL) { LDData(correlation = R, variants = variants_gr, block_metadata = bm) } +# =========================================================================== +# Helper: build an LDData S4 object from a matrix for QCResult mocks. +# When is_genotype = FALSE, the matrix is the correlation R (variants on row +# and column names). When is_genotype = TRUE, the matrix is a genotype matrix +# (samples x variants) and colnames are variant IDs. +# =========================================================================== +.test_lddata_from_matrix <- function(mat, is_genotype = FALSE) { + vids <- if (is_genotype) colnames(mat) else rownames(mat) + if (is.null(vids)) vids <- colnames(mat) + ref_panel <- cbind(pecotmr:::parse_variant_id(vids), variant_id = vids) + ref_panel$chrom <- as.character(ref_panel$chrom) + variants_gr <- pecotmr:::.ref_panel_to_granges(ref_panel) + bm <- pecotmr:::.infer_single_ld_block_metadata(ref_panel) + if (is_genotype) { + LDData(correlation = NULL, genotype_handle = mat, + variants = variants_gr, block_metadata = bm, + n_ref = as.integer(nrow(mat))) + } else { + LDData(correlation = mat, variants = variants_gr, block_metadata = bm) + } +} + +# =========================================================================== +# Helper: build a QCResult mock from a sumstats data.frame and LD matrix. +# =========================================================================== +.test_qcresult <- function(sumstats, ld_mat, n = 1000, var_y = 1, + outlier_number = 0L, skipped = FALSE, + is_genotype = FALSE) { + ld <- .test_lddata_from_matrix(ld_mat, is_genotype = is_genotype) + QCResult( + ld_data = ld, + rss_input = list(sumstats = sumstats, n = n, var_y = var_y), + preprocess = list(sumstats = sumstats, ld_data = ld), + outlier_number = as.integer(outlier_number), + skipped = skipped + ) +} + make_fake_twas_result <- function(p) { list( twas_weights = setNames(rep(0.1, p), paste0("chr1:", seq_len(p), ":A:G")), @@ -975,7 +1013,13 @@ test_that("rss: empty sumstats after rss_basic_qc => early return", { ) }, rss_basic_qc = function(...) { - list(sumstats = data.frame(), LD_mat = matrix(nrow = 0, ncol = 0)) + QCResult( + ld_data = NULL, + rss_input = list(sumstats = data.frame(), n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ) }, ) @@ -1022,7 +1066,13 @@ test_that("rss: pip_cutoff_to_skip > 0, no signal => early return", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), susie_ser = function(...) list(pip = rep(0.01, 5)), ) @@ -1046,11 +1096,17 @@ test_that("rss: pip_cutoff_to_skip > 0, signal detected => continues", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), susie_rss = function(...) list(pip = c(0.9, 0.01, 0.01, 0.01, 0.01)), summary_stats_qc = function(...) { message("Follow-up on region: signals above PIP threshold 0.5 detected.") - list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0) + .test_qcresult(ss, ld_mat, outlier_number = 0) }, partition_LD_matrix = function(...) ld_mat, raiss = function(...) list(result_filter = ss, LD_mat = ld_mat), @@ -1079,7 +1135,13 @@ test_that("rss: negative pip_cutoff_to_skip auto-computes threshold", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), susie_ser = function(...) list(pip = rep(0.01, 5)), ) @@ -1112,10 +1174,19 @@ test_that("rss: full pipeline with QC, imputation, and fine-mapping", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) { + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(..., impute = FALSE) { qc_called <<- TRUE - list(sumstats = ss, LD_mat = ld_mat, outlier_number = 1) + # Imputation now happens inside summary_stats_qc, so simulate that + # branch in the mock to keep the raiss call observable. + if (isTRUE(impute)) raiss() + .test_qcresult(ss, ld_mat, outlier_number = 1) }, partition_LD_matrix = function(...) list(ld_matrices = list(ld_mat)), raiss = function(...) { @@ -1152,8 +1223,14 @@ test_that("rss: method name is correct for no-impute with QC", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) fake_result, ) @@ -1176,7 +1253,13 @@ test_that("rss: method name is correct for no QC", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), partition_LD_matrix = function(...) list(ld_matrices = list(ld_mat)), raiss = function(...) list(result_filter = ss, LD_mat = ld_mat), susie_rss_pipeline = function(...) fake_result, @@ -1205,8 +1288,14 @@ test_that("rss: outlier_number is stored in result when QC is active", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 3), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 3), susie_rss_pipeline = function(...) fake_result, ) @@ -1235,7 +1324,13 @@ test_that("rss: finemapping_method = NULL skips fine-mapping", { finemapping_called <- FALSE local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), susie_rss_pipeline = function(...) { finemapping_called <<- TRUE list() @@ -1268,10 +1363,16 @@ test_that("rss: qc_method = NULL uses combined basic QC without LD-mismatch meth qc_called <- FALSE local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), summary_stats_qc = function(...) { qc_called <<- TRUE - list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0) + .test_qcresult(ss, ld_mat, outlier_number = 0) }, susie_rss_pipeline = function(...) fake_result, ) @@ -1300,8 +1401,14 @@ test_that("rss: impute = FALSE skips raiss imputation", { raiss_called <- FALSE local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), raiss = function(...) { raiss_called <<- TRUE list(result_filter = ss, LD_mat = ld_mat) @@ -1331,8 +1438,14 @@ test_that("rss: diagnostics = TRUE with empty fine-mapping result skips diagnost local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) list(), # empty result ) @@ -1402,8 +1515,14 @@ test_that("rss: diagnostics with 2+ CS and high p-value/corr triggers BCR and SE susie_rss_pipeline_call_count <- 0 local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) { susie_rss_pipeline_call_count <<- susie_rss_pipeline_call_count + 1 fake_result @@ -1476,8 +1595,14 @@ test_that("rss: diagnostics with 1 CS triggers SER reanalysis only", { susie_rss_pipeline_call_count <- 0 local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) { susie_rss_pipeline_call_count <<- susie_rss_pipeline_call_count + 1 fake_result @@ -1542,8 +1667,14 @@ test_that("rss: diagnostics with no CS but high PIP calls extract_top_pip_info", local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) fake_result, get_susie_result = function(res) res$susie_result_trimmed, extract_top_pip_info = function(...) { @@ -1592,8 +1723,14 @@ test_that("rss: diagnostics with no CS and no high PIP => diagnostics empty", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) fake_result, get_susie_result = function(res) res$susie_result_trimmed, ) @@ -1633,7 +1770,13 @@ test_that("rss: finemapping_opts are forwarded to susie_rss_pipeline", { captured_R_mismatch <- NULL local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), susie_rss_pipeline = function(sumstats, LD_mat, n, var_y, L, L_greedy, analysis_method, coverage, secondary_coverage, signal_cutoff, min_abs_corr, ...) { @@ -1679,8 +1822,14 @@ test_that("rss: dentist QC method generates correct method name", { local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 2), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 2), partition_LD_matrix = function(...) list(ld_matrices = list(ld_mat)), raiss = function(...) list(result_filter = ss, LD_mat = ld_mat), susie_rss_pipeline = function(...) fake_result, @@ -1741,8 +1890,14 @@ test_that("rss: diagnostics with get_susie_result returning NULL => diagnostics local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) fake_result, get_susie_result = function(res) NULL, ) @@ -1785,8 +1940,14 @@ test_that("rss: diagnostics with null/empty block_cs_metrics => no additional an susie_rss_call_count <- 0 local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) { susie_rss_call_count <<- susie_rss_call_count + 1 fake_result @@ -1868,8 +2029,14 @@ test_that("rss: diagnostics with 2 CS but low p-value and low corr => no extra a susie_rss_call_count <- 0 local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) { susie_rss_call_count <<- susie_rss_call_count + 1 fake_result @@ -1942,8 +2109,14 @@ test_that("rss: diagnostics with high max_cs_corr_study_block triggers BCR+SER", susie_rss_call_count <- 0 local_mocked_bindings( load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), susie_rss_pipeline = function(...) { susie_rss_call_count <<- susie_rss_call_count + 1 fake_result @@ -2079,8 +2252,14 @@ test_that("rss: is_genotype=TRUE path does not precompute R and uses X for fine- stop("rss_analysis_pipeline should not precompute LD from X") }, load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), partition_LD_matrix = function(...) list(ld_matrices = list(ld_mat)), raiss = function(...) list(result_filter = ss, LD_mat = ld_mat), susie_rss_pipeline = function(sumstats, LD_mat = NULL, X_mat = NULL, ...) { @@ -2144,8 +2323,14 @@ test_that("rss: mixture LD_data (list of X panels) preserves list shape into sus stop("rss_analysis_pipeline should not precompute LD from mixture X") }, load_rss_data = function(...) list(sumstats = ss, n = 1000, var_y = 1), - rss_basic_qc = function(...) list(sumstats = ss, LD_mat = ld_mat), - summary_stats_qc = function(...) list(sumstats = ss, LD_mat = ld_mat, outlier_number = 0), + rss_basic_qc = function(...) QCResult( + ld_data = if (nrow(ld_mat) > 0) .test_lddata_from_matrix(ld_mat) else NULL, + rss_input = list(sumstats = ss, n = NA_real_, var_y = NA_real_), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE + ), + summary_stats_qc = function(...) .test_qcresult(ss, ld_mat, outlier_number = 0), partition_LD_matrix = function(...) list(ld_matrices = list(ld_mat)), raiss = function(...) list(result_filter = ss, LD_mat = ld_mat), susie_rss_pipeline = function(sumstats, LD_mat = NULL, X_mat = NULL, ...) { diff --git a/tests/testthat/test_univariate_rss_diagnostics.R b/tests/testthat/test_univariate_rss_diagnostics.R index 156a4c7b..9025958d 100644 --- a/tests/testthat/test_univariate_rss_diagnostics.R +++ b/tests/testthat/test_univariate_rss_diagnostics.R @@ -1,5 +1,17 @@ context("univariate_rss_diagnostics") +.test_fm_result <- function(variant_names, trimmed_fit = list(), + top_loci = data.frame(variant_id = character(0), + method = character(0), + stringsAsFactors = FALSE)) { + FineMappingResult( + variant_names = variant_names, + trimmed_fit = trimmed_fit, + top_loci = top_loci, + method = "susie" + ) +} + # =========================================================================== # get_susie_result # =========================================================================== @@ -9,14 +21,17 @@ test_that("get_susie_result returns NULL for empty input", { expect_null(result) }) -test_that("get_susie_result returns NULL when susie_result_trimmed missing", { +test_that("get_susie_result returns NULL when finemapping_result missing", { result <- get_susie_result(list(some_data = 42)) expect_null(result) }) test_that("get_susie_result returns trimmed result when present", { mock_result <- list(pip = c(0.1, 0.5, 0.3), sets = list(cs = list())) - con_data <- list(susie_result_trimmed = mock_result) + con_data <- list(finemapping_result = .test_fm_result( + variant_names = character(0), + trimmed_fit = mock_result + )) result <- get_susie_result(con_data) expect_equal(result, mock_result) }) @@ -27,8 +42,10 @@ test_that("get_susie_result returns trimmed result when present", { test_that("extract_top_pip_info finds top PIP variant", { con_data <- list( - susie_result_trimmed = list(pip = c(0.1, 0.7, 0.2)), - variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A"), + finemapping_result = .test_fm_result( + variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A"), + trimmed_fit = list(pip = c(0.1, 0.7, 0.2)) + ), sumstats = list(z = c(1.0, 3.5, -0.5)) ) result <- extract_top_pip_info(con_data) @@ -42,8 +59,10 @@ test_that("extract_top_pip_info finds top PIP variant", { test_that("extract_top_pip_info computes p_value from z", { con_data <- list( - susie_result_trimmed = list(pip = c(0.9, 0.05, 0.05)), - variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A"), + finemapping_result = .test_fm_result( + variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A"), + trimmed_fit = list(pip = c(0.9, 0.05, 0.05)) + ), sumstats = list(z = c(5.0, 0.5, -0.3)) ) result <- extract_top_pip_info(con_data) @@ -53,8 +72,10 @@ test_that("extract_top_pip_info computes p_value from z", { test_that("extract_top_pip_info handles ties by taking first max", { con_data <- list( - susie_result_trimmed = list(pip = c(0.5, 0.5, 0.5)), - variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A"), + finemapping_result = .test_fm_result( + variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A"), + trimmed_fit = list(pip = c(0.5, 0.5, 0.5)) + ), sumstats = list(z = c(1.0, 2.0, 3.0)) ) result <- extract_top_pip_info(con_data) @@ -68,10 +89,12 @@ test_that("extract_top_pip_info handles ties by taking first max", { test_that("extract_cs_info extracts single CS correctly", { con_data <- list( - variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A"), - susie_result_trimmed = list( - sets = list(cs = list(L_1 = c(1, 2))), - cs_corr = NULL + finemapping_result = .test_fm_result( + variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A"), + trimmed_fit = list( + sets = list(cs = list(L_1 = c(1, 2))), + cs_corr = NULL + ) ) ) top_loci_table <- data.frame( @@ -91,12 +114,14 @@ test_that("extract_cs_info extracts single CS correctly", { test_that("extract_cs_info extracts multiple CSs with cs_corr", { con_data <- list( - variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A", "1:400:T:C"), - susie_result_trimmed = list( - sets = list( - cs = list(L_1 = c(1, 2), L_2 = c(3, 4)) - ), - cs_corr = matrix(c(1, 0.3, 0.3, 1), nrow = 2) + finemapping_result = .test_fm_result( + variant_names = c("1:100:A:G", "1:200:C:T", "1:300:G:A", "1:400:T:C"), + trimmed_fit = list( + sets = list( + cs = list(L_1 = c(1, 2), L_2 = c(3, 4)) + ), + cs_corr = matrix(c(1, 0.3, 0.3, 1), nrow = 2) + ) ) ) top_loci_table <- data.frame( @@ -116,10 +141,12 @@ test_that("extract_cs_info extracts multiple CSs with cs_corr", { test_that("extract_cs_info computes p_value from z-score", { con_data <- list( - variant_names = c("1:100:A:G", "1:200:C:T"), - susie_result_trimmed = list( - sets = list(cs = list(L_1 = c(1, 2))), - cs_corr = NULL + finemapping_result = .test_fm_result( + variant_names = c("1:100:A:G", "1:200:C:T"), + trimmed_fit = list( + sets = list(cs = list(L_1 = c(1, 2))), + cs_corr = NULL + ) ) ) top_loci_table <- data.frame( From 618709f6a9d47ddcef9603642ebe40c2875420bb Mon Sep 17 00:00:00 2001 From: danielnachun Date: Tue, 2 Jun 2026 07:55:21 +0000 Subject: [PATCH 11/11] Update documentation --- NAMESPACE | 40 ++++++++++++++++++++++++++ man/AlleleQCResult-class.Rd | 25 ++++++++++++++++ man/AlleleQCResult.Rd | 20 +++++++++++++ man/MultivariateRegionalData-class.Rd | 35 ++++++++++++++++++++++ man/MultivariateRegionalData.Rd | 36 +++++++++++++++++++++++ man/QCResult-class.Rd | 31 ++++++++++++++++++++ man/QCResult.Rd | 36 +++++++++++++++++++++++ man/c-RegionalData-method.Rd | 24 ++++++++++++++++ man/getChrom.Rd | 23 +++++++++++++++ man/getCovariates.Rd | 21 ++++++++++++++ man/getGenotypeMatrix.Rd | 24 ++++++++++++++++ man/getGrange.Rd | 24 ++++++++++++++++ man/getHarmonizedData.Rd | 21 ++++++++++++++ man/getLDData.Rd | 20 +++++++++++++ man/getMaf.Rd | 6 +++- man/getOutlierNumber.Rd | 20 +++++++++++++ man/getPhenotypes.Rd | 21 ++++++++++++++ man/getPreprocess.Rd | 21 ++++++++++++++ man/getQCSummary.Rd | 21 ++++++++++++++ man/getRSSInput.Rd | 20 +++++++++++++ man/getSkipReason.Rd | 20 +++++++++++++ man/getVariantInfo.Rd | 3 ++ man/getXVariance.Rd | 26 +++++++++++++++++ man/getYMatrix.Rd | 21 ++++++++++++++ man/getYScalar.Rd | 21 ++++++++++++++ man/get_susie_result.Rd | 25 ++++++---------- man/isSkipped.Rd | 20 +++++++++++++ man/load_multitask_regional_data.Rd | 2 +- man/load_regional_association_data.Rd | 22 ++++---------- man/load_regional_functional_data.Rd | 8 ++++-- man/load_regional_multivariate_data.Rd | 11 +++++-- man/load_regional_regression_data.Rd | 9 ++++-- man/load_regional_univariate_data.Rd | 10 +++++-- man/load_study_LD.Rd | 5 ++-- man/match_ref_panel.Rd | 5 +++- man/region_data_to_rss_input.Rd | 18 ++++++++++++ man/rss_analysis_pipeline.Rd | 9 +++--- man/rss_basic_qc.Rd | 17 ++++++----- man/summary_stats_qc.Rd | 17 +++++------ 39 files changed, 685 insertions(+), 73 deletions(-) create mode 100644 man/AlleleQCResult-class.Rd create mode 100644 man/AlleleQCResult.Rd create mode 100644 man/MultivariateRegionalData-class.Rd create mode 100644 man/MultivariateRegionalData.Rd create mode 100644 man/QCResult-class.Rd create mode 100644 man/QCResult.Rd create mode 100644 man/c-RegionalData-method.Rd create mode 100644 man/getChrom.Rd create mode 100644 man/getCovariates.Rd create mode 100644 man/getGenotypeMatrix.Rd create mode 100644 man/getGrange.Rd create mode 100644 man/getHarmonizedData.Rd create mode 100644 man/getLDData.Rd create mode 100644 man/getOutlierNumber.Rd create mode 100644 man/getPhenotypes.Rd create mode 100644 man/getPreprocess.Rd create mode 100644 man/getQCSummary.Rd create mode 100644 man/getRSSInput.Rd create mode 100644 man/getSkipReason.Rd create mode 100644 man/getXVariance.Rd create mode 100644 man/getYMatrix.Rd create mode 100644 man/getYScalar.Rd create mode 100644 man/isSkipped.Rd create mode 100644 man/region_data_to_rss_input.Rd diff --git a/NAMESPACE b/NAMESPACE index 56eee91d..f0aae592 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,10 +6,13 @@ S3method(postprocess_finemapping_fit,susiF) S3method(postprocess_finemapping_fit,susie) S3method(postprocess_finemapping_fit,susie_inf) S3method(postprocess_finemapping_fit,susie_rss) +export(AlleleQCResult) export(AnnotationMatrix) export(FineMappingResult) export(GWASSumStats) export(LDData) +export(MultivariateRegionalData) +export(QCResult) export(RegionalData) export(TWASWeights) export(adjust_susie_weights) @@ -76,25 +79,37 @@ export(getBlockMetadata) export(getCS) export(getCVPerformance) export(getCandidates) +export(getChrom) export(getCorrelation) +export(getCovariates) export(getDataType) export(getEffects) export(getEnrichment) export(getFits) +export(getGenotypeMatrix) export(getGenotypes) +export(getGrange) +export(getHarmonizedData) export(getLBF) +export(getLDData) export(getLocal) export(getMaf) export(getMethodNames) export(getMolecularId) export(getN) +export(getOutlierNumber) export(getPIP) +export(getPhenotypes) +export(getPreprocess) +export(getQCSummary) +export(getRSSInput) export(getRefPanel) export(getResidualX) export(getResidualXScalar) export(getResidualY) export(getResidualYScalar) export(getScoreStats) +export(getSkipReason) export(getStandardized) export(getTopLoci) export(getTrimmedFit) @@ -103,6 +118,9 @@ export(getVariantIds) export(getVariantInfo) export(getVariantNames) export(getWeights) +export(getXVariance) +export(getYMatrix) +export(getYScalar) export(getZ) export(get_ctwas_meta_data) export(get_filter_lbf_index) @@ -115,6 +133,7 @@ export(harmonize_gwas) export(harmonize_twas) export(hasGenotypes) export(invert_minmax_scaling) +export(isSkipped) export(is_binary_sldsc_annot) export(l0learn_rss_weights) export(l0learn_weights) @@ -179,6 +198,7 @@ export(read_afreq) export(read_sldsc_trait) export(region_data_to_colocboost_input) export(region_data_to_ind_input) +export(region_data_to_rss_input) export(region_to_df) export(regions_overlap) export(robust_mahalanobis) @@ -223,6 +243,7 @@ export(writeSumstatsVcf) export(xgboost_imputation) export(xqtl_enrichment_wrapper) export(z_to_pvalue) +exportClasses(AlleleQCResult) exportClasses(AnnotationMatrix) exportClasses(FineMappingResult) exportClasses(GWASSumStats) @@ -233,32 +254,47 @@ exportClasses(LDData) exportClasses(LDEigen) exportClasses(LDScore) exportClasses(LDStatistic) +exportClasses(MultivariateRegionalData) +exportClasses(QCResult) exportClasses(RegionalData) exportClasses(TWASWeights) +exportMethods(c) exportMethods(computeLdScores) exportMethods(estimateH2) exportMethods(getBlockMetadata) exportMethods(getCS) exportMethods(getCVPerformance) +exportMethods(getChrom) exportMethods(getCorrelation) +exportMethods(getCovariates) exportMethods(getDataType) exportMethods(getEffects) exportMethods(getEnrichment) exportMethods(getFits) +exportMethods(getGenotypeMatrix) exportMethods(getGenotypes) +exportMethods(getGrange) +exportMethods(getHarmonizedData) exportMethods(getLBF) +exportMethods(getLDData) exportMethods(getLocal) exportMethods(getMaf) exportMethods(getMethodNames) exportMethods(getMolecularId) exportMethods(getN) +exportMethods(getOutlierNumber) exportMethods(getPIP) +exportMethods(getPhenotypes) +exportMethods(getPreprocess) +exportMethods(getQCSummary) +exportMethods(getRSSInput) exportMethods(getRefPanel) exportMethods(getResidualX) exportMethods(getResidualXScalar) exportMethods(getResidualY) exportMethods(getResidualYScalar) exportMethods(getScoreStats) +exportMethods(getSkipReason) exportMethods(getStandardized) exportMethods(getTopLoci) exportMethods(getTrimmedFit) @@ -267,8 +303,12 @@ exportMethods(getVariantIds) exportMethods(getVariantInfo) exportMethods(getVariantNames) exportMethods(getWeights) +exportMethods(getXVariance) +exportMethods(getYMatrix) +exportMethods(getYScalar) exportMethods(getZ) exportMethods(hasGenotypes) +exportMethods(isSkipped) exportMethods(nSnps) exportMethods(readAnnotations) exportMethods(readGenotypes) diff --git a/man/AlleleQCResult-class.Rd b/man/AlleleQCResult-class.Rd new file mode 100644 index 00000000..6a517593 --- /dev/null +++ b/man/AlleleQCResult-class.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllClasses.R +\docType{class} +\name{AlleleQCResult-class} +\alias{AlleleQCResult-class} +\title{Allele QC Result} +\description{ +S4 container for the output of \code{match_ref_panel} / + \code{allele_qc}. Carries the post-QC target variants alongside the full + merge / flip / strand diagnostics needed by downstream callers that + inspect what QC did. +} +\section{Slots}{ + +\describe{ +\item{\code{harmonized_data}}{A \code{data.frame} of variants retained after +allele harmonization, with reference-aligned A1/A2 and (when requested) +sign-flipped effect columns.} + +\item{\code{qc_summary}}{A \code{data.frame} carrying per-variant QC diagnostics +from the full merge: \code{variants_id_original}, \code{variants_id_qced}, +\code{exact_match}, \code{sign_flip}, \code{strand_flip}, \code{INDEL}, +\code{ID_match}, \code{keep}, etc.} +}} + diff --git a/man/AlleleQCResult.Rd b/man/AlleleQCResult.Rd new file mode 100644 index 00000000..cc7d83af --- /dev/null +++ b/man/AlleleQCResult.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllMethods.R +\name{AlleleQCResult} +\alias{AlleleQCResult} +\title{Construct an AlleleQCResult object} +\usage{ +AlleleQCResult(harmonized_data, qc_summary) +} +\arguments{ +\item{harmonized_data}{Data frame of variants retained after allele QC.} + +\item{qc_summary}{Data frame of per-variant diagnostic columns.} +} +\value{ +An \code{AlleleQCResult} object. +} +\description{ +Build an \code{AlleleQCResult} S4 object wrapping the post-QC + harmonized variants and the full per-variant QC diagnostics. +} diff --git a/man/MultivariateRegionalData-class.Rd b/man/MultivariateRegionalData-class.Rd new file mode 100644 index 00000000..ba217173 --- /dev/null +++ b/man/MultivariateRegionalData-class.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllClasses.R +\docType{class} +\name{MultivariateRegionalData-class} +\alias{MultivariateRegionalData-class} +\title{Multivariate Regional Association Data} +\description{ +S4 container for regional association data prepared for + multivariate (joint-across-conditions) modeling. Unlike + \code{RegionalData}, which carries a per-condition list of phenotype + matrices, this class assumes all conditions are jointly observed in the + same samples and packs the phenotypes into a single multivariate matrix + (samples x conditions). +} +\section{Slots}{ + +\describe{ +\item{\code{genotype_matrix}}{Numeric matrix (samples x variants), rownames are +sample IDs, colnames are variant IDs.} + +\item{\code{Y_matrix}}{Numeric matrix (samples x conditions) of residualized +phenotypes after joining conditions and (optionally) filtering rows by +minimum non-missing count.} + +\item{\code{Y_scalar}}{Numeric vector of per-condition scaling factors +(length = ncol(Y_matrix)).} + +\item{\code{dropped_samples}}{Character or list capturing sample IDs dropped +during multivariate filtering.} + +\item{\code{region}}{A \code{GRanges} (single range) or NULL.} + +\item{\code{Y_coordinates}}{A data.frame of phenotype coordinates, or NULL.} +}} + diff --git a/man/MultivariateRegionalData.Rd b/man/MultivariateRegionalData.Rd new file mode 100644 index 00000000..a479362c --- /dev/null +++ b/man/MultivariateRegionalData.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllMethods.R +\name{MultivariateRegionalData} +\alias{MultivariateRegionalData} +\title{Construct a MultivariateRegionalData object} +\usage{ +MultivariateRegionalData( + genotype_matrix, + Y_matrix, + Y_scalar, + dropped_samples = NULL, + region = NULL, + Y_coordinates = NULL +) +} +\arguments{ +\item{genotype_matrix}{Numeric matrix (samples x variants).} + +\item{Y_matrix}{Numeric matrix (samples x conditions).} + +\item{Y_scalar}{Numeric vector of per-condition scaling factors.} + +\item{dropped_samples}{Character vector or list of dropped sample IDs.} + +\item{region}{A \code{GRanges} or NULL.} + +\item{Y_coordinates}{A data.frame of phenotype coordinates, or NULL.} +} +\value{ +A \code{MultivariateRegionalData} object. +} +\description{ +Build a \code{MultivariateRegionalData} S4 object capturing + regional association data prepared for multivariate modeling (single + joint Y matrix across conditions). +} diff --git a/man/QCResult-class.Rd b/man/QCResult-class.Rd new file mode 100644 index 00000000..4243c633 --- /dev/null +++ b/man/QCResult-class.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllClasses.R +\docType{class} +\name{QCResult-class} +\alias{QCResult-class} +\title{Summary-Statistics QC Result} +\description{ +S4 container holding the output of \code{summary_stats_qc} and + \code{.summary_stats_qc_single_study}. Carries the post-QC LD reference + plus harmonized sumstats, a pre-imputation snapshot, and QC process + metadata. Replaces the legacy list-of-named-fields return shape. +} +\section{Slots}{ + +\describe{ +\item{\code{ld_data}}{An \code{LDData} S4 object containing the post-QC LD +reference (correlation and/or genotype), or NULL when QC produced no LD.} + +\item{\code{rss_input}}{List with \code{sumstats} (post-QC data.frame), \code{n}, +and \code{var_y}.} + +\item{\code{preprocess}}{List with \code{sumstats} and \code{ld_data} fields +capturing the pre-imputation snapshot for downstream re-runs.} + +\item{\code{outlier_number}}{Integer count of LD-mismatch outliers removed.} + +\item{\code{skipped}}{Single logical; TRUE when QC short-circuited.} + +\item{\code{skip_reason}}{Character string explaining a skip; empty otherwise.} +}} + diff --git a/man/QCResult.Rd b/man/QCResult.Rd new file mode 100644 index 00000000..dd8b5254 --- /dev/null +++ b/man/QCResult.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllMethods.R +\name{QCResult} +\alias{QCResult} +\title{Construct a QCResult object} +\usage{ +QCResult( + ld_data = NULL, + rss_input = list(), + preprocess = list(), + outlier_number = 0L, + skipped = FALSE, + skip_reason = "" +) +} +\arguments{ +\item{ld_data}{An \code{LDData} or NULL.} + +\item{rss_input}{List with \code{sumstats}, \code{n}, \code{var_y}.} + +\item{preprocess}{List with \code{sumstats} and \code{ld_data}.} + +\item{outlier_number}{Integer count of LD-mismatch outliers removed.} + +\item{skipped}{Single logical indicating a short-circuit.} + +\item{skip_reason}{Character explanation; defaults to empty.} +} +\value{ +A \code{QCResult} object. +} +\description{ +Build a \code{QCResult} S4 object capturing the output of + summary-statistic QC. Validates that \code{ld_data} is an \code{LDData} + or NULL. +} diff --git a/man/c-RegionalData-method.Rd b/man/c-RegionalData-method.Rd new file mode 100644 index 00000000..d41e911c --- /dev/null +++ b/man/c-RegionalData-method.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllMethods.R +\name{c,RegionalData-method} +\alias{c,RegionalData-method} +\title{Combine Two RegionalData Objects} +\usage{ +\S4method{c}{RegionalData}(x, ...) +} +\arguments{ +\item{x}{First \code{RegionalData} object.} + +\item{y}{Second \code{RegionalData} object.} +} +\value{ +A merged \code{RegionalData}. +} +\description{ +Concatenate two \code{RegionalData} objects by appending + their per-condition slots (phenotypes, covariates, maf, dropped_samples). + Used by multi-panel pipelines that load per-LD-panel data and aggregate + them. The \code{genotype_matrix} of \code{x} is retained as the + canonical genotype reference; the \code{region} is taken from \code{y} + (mirrors prior list-merge behavior). +} diff --git a/man/getChrom.Rd b/man/getChrom.Rd new file mode 100644 index 00000000..cf8d6120 --- /dev/null +++ b/man/getChrom.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getChrom} +\alias{getChrom} +\alias{getChrom,MultivariateRegionalData-method} +\alias{getChrom,RegionalData-method} +\title{Get Region Chromosome} +\usage{ +getChrom(x) + +\S4method{getChrom}{MultivariateRegionalData}(x) + +\S4method{getChrom}{RegionalData}(x) +} +\arguments{ +\item{x}{The object.} +} +\value{ +A single character string, or NULL. +} +\description{ +Extract the chromosome name from a region-bearing S4 object. +} diff --git a/man/getCovariates.Rd b/man/getCovariates.Rd new file mode 100644 index 00000000..3ad8f703 --- /dev/null +++ b/man/getCovariates.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getCovariates} +\alias{getCovariates} +\alias{getCovariates,RegionalData-method} +\title{Get Covariate List} +\usage{ +getCovariates(x) + +\S4method{getCovariates}{RegionalData}(x) +} +\arguments{ +\item{x}{A \code{RegionalData} object.} +} +\value{ +A named list of covariate matrices. +} +\description{ +Extract the per-condition covariate list from a + \code{RegionalData}. +} diff --git a/man/getGenotypeMatrix.Rd b/man/getGenotypeMatrix.Rd new file mode 100644 index 00000000..d680c728 --- /dev/null +++ b/man/getGenotypeMatrix.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getGenotypeMatrix} +\alias{getGenotypeMatrix} +\alias{getGenotypeMatrix,RegionalData-method} +\alias{getGenotypeMatrix,MultivariateRegionalData-method} +\title{Get Genotype Matrix} +\usage{ +getGenotypeMatrix(x) + +\S4method{getGenotypeMatrix}{RegionalData}(x) + +\S4method{getGenotypeMatrix}{MultivariateRegionalData}(x) +} +\arguments{ +\item{x}{The object.} +} +\value{ +A numeric matrix (samples x variants). +} +\description{ +Extract the raw genotype matrix from a + \code{RegionalData} or \code{MultivariateRegionalData}. +} diff --git a/man/getGrange.Rd b/man/getGrange.Rd new file mode 100644 index 00000000..615abe11 --- /dev/null +++ b/man/getGrange.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getGrange} +\alias{getGrange} +\alias{getGrange,MultivariateRegionalData-method} +\alias{getGrange,RegionalData-method} +\title{Get Region Range} +\usage{ +getGrange(x) + +\S4method{getGrange}{MultivariateRegionalData}(x) + +\S4method{getGrange}{RegionalData}(x) +} +\arguments{ +\item{x}{The object.} +} +\value{ +A character vector of length 2, or NULL. +} +\description{ +Extract the start/end positions from a region-bearing S4 + object as a character vector \code{c(start, end)}. +} diff --git a/man/getHarmonizedData.Rd b/man/getHarmonizedData.Rd new file mode 100644 index 00000000..8e05f7df --- /dev/null +++ b/man/getHarmonizedData.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getHarmonizedData} +\alias{getHarmonizedData} +\alias{getHarmonizedData,AlleleQCResult-method} +\title{Get Harmonized Variant Data} +\usage{ +getHarmonizedData(x) + +\S4method{getHarmonizedData}{AlleleQCResult}(x) +} +\arguments{ +\item{x}{An \code{AlleleQCResult} object.} +} +\value{ +A \code{data.frame} of harmonized variants. +} +\description{ +Extract the post-QC, reference-harmonized variants from an + \code{AlleleQCResult}. +} diff --git a/man/getLDData.Rd b/man/getLDData.Rd new file mode 100644 index 00000000..ee0757b9 --- /dev/null +++ b/man/getLDData.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getLDData} +\alias{getLDData} +\alias{getLDData,QCResult-method} +\title{Get LD Data} +\usage{ +getLDData(x) + +\S4method{getLDData}{QCResult}(x) +} +\arguments{ +\item{x}{A \code{QCResult} object.} +} +\value{ +An \code{LDData} object, or NULL when QC produced no LD reference. +} +\description{ +Extract the post-QC LDData payload from a QCResult. +} diff --git a/man/getMaf.Rd b/man/getMaf.Rd index 395812ea..76571b0a 100644 --- a/man/getMaf.Rd +++ b/man/getMaf.Rd @@ -1,12 +1,16 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/AllGenerics.R, R/gwas_sumstats.R +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R, +% R/gwas_sumstats.R \name{getMaf} \alias{getMaf} +\alias{getMaf,MultivariateRegionalData-method} \alias{getMaf,GWASSumStats-method} \title{Get Minor Allele Frequencies} \usage{ getMaf(x) +\S4method{getMaf}{MultivariateRegionalData}(x) + \S4method{getMaf}{GWASSumStats}(x) } \arguments{ diff --git a/man/getOutlierNumber.Rd b/man/getOutlierNumber.Rd new file mode 100644 index 00000000..7c3f2cb6 --- /dev/null +++ b/man/getOutlierNumber.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getOutlierNumber} +\alias{getOutlierNumber} +\alias{getOutlierNumber,QCResult-method} +\title{Get Outlier Number} +\usage{ +getOutlierNumber(x) + +\S4method{getOutlierNumber}{QCResult}(x) +} +\arguments{ +\item{x}{A \code{QCResult} object.} +} +\value{ +Integer count. +} +\description{ +Number of LD-mismatch outliers removed during QC. +} diff --git a/man/getPhenotypes.Rd b/man/getPhenotypes.Rd new file mode 100644 index 00000000..6ee3c205 --- /dev/null +++ b/man/getPhenotypes.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getPhenotypes} +\alias{getPhenotypes} +\alias{getPhenotypes,RegionalData-method} +\title{Get Phenotype List} +\usage{ +getPhenotypes(x) + +\S4method{getPhenotypes}{RegionalData}(x) +} +\arguments{ +\item{x}{A \code{RegionalData} object.} +} +\value{ +A named list of phenotype matrices. +} +\description{ +Extract the per-condition phenotype list from a + \code{RegionalData}. +} diff --git a/man/getPreprocess.Rd b/man/getPreprocess.Rd new file mode 100644 index 00000000..9dfc2a56 --- /dev/null +++ b/man/getPreprocess.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getPreprocess} +\alias{getPreprocess} +\alias{getPreprocess,QCResult-method} +\title{Get Preprocess Snapshot} +\usage{ +getPreprocess(x) + +\S4method{getPreprocess}{QCResult}(x) +} +\arguments{ +\item{x}{A \code{QCResult} object.} +} +\value{ +A list with \code{sumstats} and \code{ld_data}. +} +\description{ +Extract the pre-imputation snapshot (\code{sumstats}, + \code{ld_data}) captured before any LD-mismatch QC or RAISS imputation. +} diff --git a/man/getQCSummary.Rd b/man/getQCSummary.Rd new file mode 100644 index 00000000..b81343f9 --- /dev/null +++ b/man/getQCSummary.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getQCSummary} +\alias{getQCSummary} +\alias{getQCSummary,AlleleQCResult-method} +\title{Get Allele QC Summary} +\usage{ +getQCSummary(x) + +\S4method{getQCSummary}{AlleleQCResult}(x) +} +\arguments{ +\item{x}{An \code{AlleleQCResult} object.} +} +\value{ +A \code{data.frame} with the diagnostic columns. +} +\description{ +Extract the full per-variant merge/flip/strand diagnostics + produced by allele QC. +} diff --git a/man/getRSSInput.Rd b/man/getRSSInput.Rd new file mode 100644 index 00000000..60cf0daf --- /dev/null +++ b/man/getRSSInput.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getRSSInput} +\alias{getRSSInput} +\alias{getRSSInput,QCResult-method} +\title{Get RSS Input} +\usage{ +getRSSInput(x) + +\S4method{getRSSInput}{QCResult}(x) +} +\arguments{ +\item{x}{A \code{QCResult} object.} +} +\value{ +A list with \code{sumstats}, \code{n}, \code{var_y}. +} +\description{ +Extract the post-QC summary-statistic record (sumstats, n, var_y). +} diff --git a/man/getSkipReason.Rd b/man/getSkipReason.Rd new file mode 100644 index 00000000..cee9d92a --- /dev/null +++ b/man/getSkipReason.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getSkipReason} +\alias{getSkipReason} +\alias{getSkipReason,QCResult-method} +\title{Get Skip Reason} +\usage{ +getSkipReason(x) + +\S4method{getSkipReason}{QCResult}(x) +} +\arguments{ +\item{x}{A \code{QCResult} object.} +} +\value{ +Character scalar. +} +\description{ +Why QC short-circuited; empty string if not skipped. +} diff --git a/man/getVariantInfo.Rd b/man/getVariantInfo.Rd index c184f784..e06e6d84 100644 --- a/man/getVariantInfo.Rd +++ b/man/getVariantInfo.Rd @@ -4,6 +4,7 @@ \alias{getVariantInfo} \alias{getVariantInfo,LDData-method} \alias{getVariantInfo,RegionalData-method} +\alias{getVariantInfo,MultivariateRegionalData-method} \title{Get Variant GRanges} \usage{ getVariantInfo(x) @@ -11,6 +12,8 @@ getVariantInfo(x) \S4method{getVariantInfo}{LDData}(x) \S4method{getVariantInfo}{RegionalData}(x) + +\S4method{getVariantInfo}{MultivariateRegionalData}(x) } \arguments{ \item{x}{An object with variant metadata.} diff --git a/man/getXVariance.Rd b/man/getXVariance.Rd new file mode 100644 index 00000000..0c2a19a8 --- /dev/null +++ b/man/getXVariance.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getXVariance} +\alias{getXVariance} +\alias{getXVariance,MultivariateRegionalData-method} +\alias{getXVariance,RegionalData-method} +\title{Get Per-Variant Variance} +\usage{ +getXVariance(x, condition = 1L) + +\S4method{getXVariance}{MultivariateRegionalData}(x, condition = 1L) + +\S4method{getXVariance}{RegionalData}(x, condition = 1L) +} +\arguments{ +\item{x}{A \code{RegionalData} object.} + +\item{condition}{Integer index of the condition.} +} +\value{ +A numeric vector (length = number of variants). +} +\description{ +Per-variant variance of residualized genotypes for a + condition. +} diff --git a/man/getYMatrix.Rd b/man/getYMatrix.Rd new file mode 100644 index 00000000..ca04f236 --- /dev/null +++ b/man/getYMatrix.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getYMatrix} +\alias{getYMatrix} +\alias{getYMatrix,MultivariateRegionalData-method} +\title{Get Multivariate Y Matrix} +\usage{ +getYMatrix(x) + +\S4method{getYMatrix}{MultivariateRegionalData}(x) +} +\arguments{ +\item{x}{A \code{MultivariateRegionalData} object.} +} +\value{ +A numeric matrix (samples x conditions). +} +\description{ +Extract the multivariate phenotype matrix from a + \code{MultivariateRegionalData}. +} diff --git a/man/getYScalar.Rd b/man/getYScalar.Rd new file mode 100644 index 00000000..b92f5e52 --- /dev/null +++ b/man/getYScalar.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{getYScalar} +\alias{getYScalar} +\alias{getYScalar,MultivariateRegionalData-method} +\title{Get Y Scaling Factors} +\usage{ +getYScalar(x) + +\S4method{getYScalar}{MultivariateRegionalData}(x) +} +\arguments{ +\item{x}{A \code{MultivariateRegionalData} object.} +} +\value{ +A numeric vector (length = number of conditions). +} +\description{ +Per-condition scaling factors used for residualized + multivariate phenotypes. +} diff --git a/man/get_susie_result.Rd b/man/get_susie_result.Rd index e1895924..2d2966ac 100644 --- a/man/get_susie_result.Rd +++ b/man/get_susie_result.Rd @@ -2,29 +2,20 @@ % Please edit documentation in R/univariate_rss_diagnostics.R \name{get_susie_result} \alias{get_susie_result} -\title{Extract SuSiE Results from Finemapping Data} +\title{Extract the trimmed SuSiE fit from a finemapping pipeline result} \usage{ get_susie_result(con_data) } \arguments{ -\item{con_data}{List. The method layer data from a finemapping RDS file.} +\item{con_data}{List. The method-layer entry from a finemapping pipeline +result, expected to carry \code{$finemapping_result} as a +\code{FineMappingResult} object.} } \value{ -The trimmed SuSiE results (`$susie_result_trimmed`) if available, -otherwise NULL. +The trimmed fit (a list with \code{pip}, \code{sets}, etc.) or NULL. } \description{ -This function extracts the trimmed SuSiE results from a finemapping data object, -typically obtained from a finemapping RDS file. It's designed to work with -the method layer of these files, often named as 'method_RAISS_imputed', 'method', -or 'method_NO_QC'. This layer is right under the study layer. -} -\details{ -The function checks if the input data is empty or if the `$susie_result_trimmed` -element is missing. It returns NULL in these cases. If `$susie_result_trimmed` -exists and is not empty, it returns this element. -} -\note{ -This function is particularly useful when working with large datasets -where not all method layers may contain valid SuSiE results or method layer. +Returns the trimmed model fit underlying \code{con_data$finemapping_result} +(a \code{FineMappingResult} S4 object), or NULL if no fine-mapping result +is attached. } diff --git a/man/isSkipped.Rd b/man/isSkipped.Rd new file mode 100644 index 00000000..23507581 --- /dev/null +++ b/man/isSkipped.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/AllGenerics.R, R/AllMethods.R +\name{isSkipped} +\alias{isSkipped} +\alias{isSkipped,QCResult-method} +\title{Is Skipped} +\usage{ +isSkipped(x) + +\S4method{isSkipped}{QCResult}(x) +} +\arguments{ +\item{x}{A \code{QCResult} object.} +} +\value{ +Single logical. +} +\description{ +Whether QC short-circuited (e.g. no signals, too few variants). +} diff --git a/man/load_multitask_regional_data.Rd b/man/load_multitask_regional_data.Rd index 585c2070..8f74cb72 100644 --- a/man/load_multitask_regional_data.Rd +++ b/man/load_multitask_regional_data.Rd @@ -121,7 +121,7 @@ individual_data contains the following components if exist sumstat_data contains the following components if exist \itemize{ \item sumstats: A list of summary statistics for the matched LD_info, each sublist contains sumstats, n, var_y from \code{load_rss_data}. - \item LD_info: A list of LD information, each sublist contains LD_variants, LD_matrix, ref_panel \code{load_LD_matrix}. + \item LD_info: A list of \code{LDData} S4 objects (one per LD reference), as returned by \code{load_LD_matrix}. } } \description{ diff --git a/man/load_regional_association_data.Rd b/man/load_regional_association_data.Rd index b0b2c52d..f66906a5 100644 --- a/man/load_regional_association_data.Rd +++ b/man/load_regional_association_data.Rd @@ -61,22 +61,12 @@ load_regional_association_data( \item{tabix_header}{Logical indicating whether the tabix file has a header. Default is TRUE.} } \value{ -A list containing the following components: -\itemize{ - \item residual_Y: A list of residualized phenotype values (either a vector or a matrix). - \item residual_X: A list of residualized genotype matrices for each condition. - \item residual_Y_scalar: Scaling factor for residualized phenotype values. - \item residual_X_scalar: Scaling factor for residualized genotype values. - \item dropped_sample: A list of dropped samples for X, Y, and covariates. - \item covar: Covariate data. - \item Y: Original phenotype data. - \item X_data: Original genotype data. - \item X: Filtered genotype matrix. - \item maf: Minor allele frequency (MAF) for each variant. - \item chrom: Chromosome of the region. - \item grange: Genomic range of the region (start and end positions). - \item Y_coordinates: Phenotype coordinates if a region is specified. -} +A \code{RegionalData} S4 object. Per-condition residualized + phenotypes, residualized genotypes, and their scaling factors are + computed on demand via accessors (\code{getResidualX()}, + \code{getResidualY()}, \code{getResidualXScalar()}, + \code{getResidualYScalar()}, \code{getXVariance()}). Region metadata is + available via \code{getChrom()} and \code{getGrange()}. } \description{ This function loads genotype, phenotype, and covariate data for a specific region and performs data preprocessing. diff --git a/man/load_regional_functional_data.Rd b/man/load_regional_functional_data.Rd index bd6e033e..0c9e2a76 100644 --- a/man/load_regional_functional_data.Rd +++ b/man/load_regional_functional_data.Rd @@ -11,8 +11,12 @@ load_regional_functional_data(..., min_markers = NULL) If \code{NULL}, no marker-count filtering is applied.} } \value{ -A list +A \code{RegionalData} object. } \description{ -This function loads precomputed regional functional association data. +Loads precomputed regional functional association data. Returns a +\code{RegionalData} S4 object; derived quantities are computed lazily +via accessors. When \code{min_markers} is supplied, conditions whose +\code{Y_coordinates} have fewer than \code{min_markers} rows are +dropped from the returned \code{RegionalData}. } diff --git a/man/load_regional_multivariate_data.Rd b/man/load_regional_multivariate_data.Rd index 00880981..9e3306a7 100644 --- a/man/load_regional_multivariate_data.Rd +++ b/man/load_regional_multivariate_data.Rd @@ -7,9 +7,14 @@ load_regional_multivariate_data(matrix_y_min_complete = NULL, ...) } \value{ -A list +A \code{MultivariateRegionalData} object. } \description{ -This function loads regional association data and processes it into a multivariate format. -It optionally filters out samples based on missingness thresholds in the response matrix. +Loads regional association data and packages it for multivariate modeling. +Phenotypes across conditions are joined into a single multivariate matrix +(samples x conditions). When \code{matrix_y_min_complete} is supplied, +samples with fewer than that many non-missing condition values are dropped. +Per-variant MAF and variance are computed on the (post-filter) genotype +matrix and exposed via \code{getMAF()} / \code{getXVariance()} on the +returned object. } diff --git a/man/load_regional_regression_data.Rd b/man/load_regional_regression_data.Rd index f069f59f..8238bdda 100644 --- a/man/load_regional_regression_data.Rd +++ b/man/load_regional_regression_data.Rd @@ -7,9 +7,12 @@ load_regional_regression_data(...) } \value{ -A list +A \code{RegionalData} object. } \description{ -This function loads regional association data formatted for regression modeling. -It includes phenotype, genotype, and covariate matrices along with metadata. +Loads regional association data formatted for regression modeling. +Returns a \code{RegionalData} S4 object; the per-condition \code{X_data} +previously returned in a list is available as +\code{getResidualX(rd, i)} (residualized) or by subsetting +\code{rd@genotype_matrix} by condition rownames. } diff --git a/man/load_regional_univariate_data.Rd b/man/load_regional_univariate_data.Rd index 47651c05..aa60c160 100644 --- a/man/load_regional_univariate_data.Rd +++ b/man/load_regional_univariate_data.Rd @@ -7,9 +7,13 @@ load_regional_univariate_data(...) } \value{ -A list +A \code{RegionalData} object. } \description{ -This function loads regional association data for univariate analysis. -It includes residual matrices, original genotype data, and additional metadata. +Loads regional association data for univariate analysis. Returns a +\code{RegionalData} S4 object; derived quantities (residuals, scalars, +per-variant variance) are computed lazily via accessors +(\code{getResidualX}, \code{getResidualY}, \code{getResidualXScalar}, +\code{getResidualYScalar}, \code{getXVariance}, \code{getChrom}, +\code{getGrange}). } diff --git a/man/load_study_LD.Rd b/man/load_study_LD.Rd index 8f97f3b0..ad49824d 100644 --- a/man/load_study_LD.Rd +++ b/man/load_study_LD.Rd @@ -13,8 +13,9 @@ mixture panels (e.g., "ld_EUR.tsv,ld_AFR.tsv").} \item{region}{Region string "chr:start-end".} } \value{ -An LD_data list from load_LD_matrix. For single panels, returns as-is. - For mixture panels, LD_matrix is a list of X matrices (one per panel). +An \code{LDData} S4 object. For single panels, returns the result of + \code{load_LD_matrix()} unchanged. For mixture panels, \code{genotype_handle} + is a list of per-panel genotype handles sharing the first panel's variants. } \description{ Load LD for a study, supporting single or mixture panels. diff --git a/man/match_ref_panel.Rd b/man/match_ref_panel.Rd index 05dbe539..df3271b9 100644 --- a/man/match_ref_panel.Rd +++ b/man/match_ref_panel.Rd @@ -56,7 +56,10 @@ to be matched, otherwise stops with an error. Default is 20%.} corresponding `col_to_flip` are multiplied by -1. Default is `TRUE`.} } \value{ -A single data frame with matched variants. +An \code{AlleleQCResult} S4 object. Use + \code{getHarmonizedData()} to recover the post-QC variant + data.frame and \code{getQCSummary()} to inspect the per-variant + merge/flip/strand diagnostics. } \description{ Match by ("chrom", "A1", "A2" and "pos"), accounting for possible diff --git a/man/region_data_to_rss_input.Rd b/man/region_data_to_rss_input.Rd new file mode 100644 index 00000000..62e17229 --- /dev/null +++ b/man/region_data_to_rss_input.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/file_utils.R +\name{region_data_to_rss_input} +\alias{region_data_to_rss_input} +\title{Convert loaded regional data to RSS inputs} +\usage{ +region_data_to_rss_input(region_data) +} +\arguments{ +\item{region_data}{A list returned by \code{load_multitask_regional_data()}.} +} +\value{ +A list containing named RSS inputs, matched LD data, and source + information. +} +\description{ +Convert loaded regional data to RSS inputs +} diff --git a/man/rss_analysis_pipeline.Rd b/man/rss_analysis_pipeline.Rd index d34d351e..e1a10b9c 100644 --- a/man/rss_analysis_pipeline.Rd +++ b/man/rss_analysis_pipeline.Rd @@ -34,11 +34,10 @@ rss_analysis_pipeline( \item{column_file_path}{File path to the column mapping file.} -\item{LD_data}{A list from load_LD_matrix containing LD_matrix, LD_variants, -ref_panel, block_metadata, and is_genotype flag. When is_genotype=TRUE -(from return_genotype=TRUE), LD_matrix contains genotype X and susie_rss -uses the z+X interface. Local R is computed only for QC stages that -require a correlation matrix.} +\item{LD_data}{An \code{LDData} S4 object from \code{load_LD_matrix()}. When +\code{hasGenotypes(LD_data)} is TRUE (from \code{return_genotype=TRUE}), +susie_rss uses the z+X interface via \code{getGenotypes()}. Local R is +computed only for QC stages that require a correlation matrix.} \item{n_sample}{Sample size. If 0, retrieved from the sumstat file.} diff --git a/man/rss_basic_qc.Rd b/man/rss_basic_qc.Rd index f4068985..bb7aa591 100644 --- a/man/rss_basic_qc.Rd +++ b/man/rss_basic_qc.Rd @@ -15,21 +15,20 @@ rss_basic_qc( \arguments{ \item{sumstats}{A data frame containing summary statistics with columns "chrom", "pos", "A1", and "A2".} -\item{LD_data}{An \code{LDData} S4 object or a legacy list containing combined LD variants data, -as generated by \code{load_LD_matrix}.} +\item{LD_data}{An \code{LDData} S4 object containing combined LD variants +data, as generated by \code{load_LD_matrix}.} \item{skip_region}{A character vector specifying regions to be skipped in the analysis (optional). Each region should be in the format "chrom:start-end" (e.g., "1:1000000-2000000").} -\item{return_LD_mat}{Logical; if \code{FALSE}, return only harmonized -summary statistics and skip LD-matrix subsetting. This is useful when the -reference input is genotype-backed \code{X_ref}. Defaults to \code{TRUE} -for backwards compatibility.} +\item{return_LD_mat}{Logical; if \code{FALSE}, the returned \code{QCResult} +carries \code{NULL} in its \code{ld_data} slot (no LD subsetting is +performed). Useful when the reference input is genotype-backed.} } \value{ -A list containing the processed summary statistics and LD matrix. - - sumstats: A data frame containing the processed summary statistics. - - LD_mat: The processed LD matrix. +A \code{QCResult} S4 object. Use \code{getRSSInput()$sumstats} to + recover the harmonized sumstats and \code{getCorrelation(getLDData(qc))} + to recover the aligned LD matrix (or NULL when \code{return_LD_mat=FALSE}). } \description{ This function preprocesses summary statistics and LD data for RSS analysis. diff --git a/man/summary_stats_qc.Rd b/man/summary_stats_qc.Rd index bf267c9c..9fbc18e8 100644 --- a/man/summary_stats_qc.Rd +++ b/man/summary_stats_qc.Rd @@ -45,16 +45,13 @@ ignored by the historical LD-mismatch-only call unless \code{rss_input} or combined-QC options are supplied.} } \value{ -A list containing the quality-controlled summary statistics and - updated LD matrix for the historical call: - \itemize{ - \item sumstats: The quality-controlled summary statistics data frame. - \item LD_mat: The updated LD matrix after quality control. - \item outlier_number: The number of outlier variants removed. - } - When \code{rss_input} or combined-QC controls are supplied, returns a - cleaned RSS/LD record for one RSS record, or a named list of records for a - list of RSS records. +A \code{QCResult} S4 object for the historical LD-mismatch-only + call (use \code{getRSSInput()$sumstats}, \code{getCorrelation(getLDData())}, + and \code{getOutlierNumber()} to recover the harmonized sumstats, post-QC + LD matrix, and outlier count). When \code{rss_input} or combined-QC + controls are supplied, returns either a single \code{QCResult} (for one + RSS record) or a named list of \code{QCResult} objects (for a list of + RSS records). } \description{ This function performs quality control on the processed summary statistics