diff --git a/pyproject.toml b/pyproject.toml index 34af193..24c44d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,7 @@ convention = "google" requires = ["uv_build>=0.9.9,<0.10.0"] build-backend = "uv_build" -[tool.pytest] +[tool.pytest.ini_options] pythonpath = ["src"] log_cli = true log_cli_level = "INFO" diff --git a/scripts/parse_index.py b/scripts/parse_index.py deleted file mode 100644 index 1ca592e..0000000 --- a/scripts/parse_index.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Script to parse a file containing paths to genome data.""" - -import json -import sys -from pathlib import Path - - -def main() -> None: - """Extracts values associated with a specified key from a JSON file and prints them. - - Usage: - parse_json.py - - Arguments: - input_json_file: Path to the input JSON file. - key: The key whose values need to be extracted. - - Output: - Prints the values associated with the specified key, one per line. - """ - # Check if the correct number of arguments is provided - if len(sys.argv) < 3: - print("Usage: parse_json.py ") # noqa: T201 - sys.exit(1) - - input_file = sys.argv[1] # Path to the JSON file - target_key = sys.argv[2] # Key to extract values for - - try: - # Attempt to open and read the JSON file - with Path(input_file).open() as file: - data = json.load(file) - except Exception as e: - print(f"Error reading JSON file: {e}") # noqa: T201 - sys.exit(1) - - # throw error if top level data structure is incorrect - if not isinstance(data, dict): - print("Error: expected JSON file to be a dictionary") # noqa: T201 - sys.exit(1) - - # iterate over dictionary values and extract the values from dicts containing the target key - extracted_paths = [entry[target_key] for entry in data.values() if isinstance(entry, dict) and target_key in entry] - - # Print each extracted path on a new line - for path in extracted_paths: - print(path) # noqa: T201 - - -if __name__ == "__main__": - main() diff --git a/tests/parsers/refseq_importer/__init__.py b/source similarity index 100% rename from tests/parsers/refseq_importer/__init__.py rename to source diff --git a/src/cdm_data_loader_utils/parsers/annotation_parse.py b/src/cdm_data_loader_utils/parsers/annotation_parse.py new file mode 100644 index 0000000..59b30bc --- /dev/null +++ b/src/cdm_data_loader_utils/parsers/annotation_parse.py @@ -0,0 +1,386 @@ +""" + +RefSeq annotation parser for transforming NCBI Datasets API JSON into CDM-formatted Delta Lake tables. 
+ +Usage: + python src/cdm_data_loader_utils/parsers/annotation_parse.py \ + --accession GCF_000869125.1 \ + --output-path output/refseq/GCF_000869125.1 \ + --query + +""" + +from __future__ import annotations +import argparse +import json +from pathlib import Path +from typing import Optional + +import requests +from pyspark.sql import SparkSession +from pyspark.sql.types import StructType +from delta import configure_spark_with_delta_pip + +from cdm_data_loader_utils.parsers.kbase_cdm_pyspark import schema as cdm_schemas + + +# --------------------------------------------------------------------- +# Accession-based annotation fetch +# --------------------------------------------------------------------- +def fetch_annotation_json(accession: str) -> dict: + """Fetch annotation JSON from NCBI Datasets API.""" + url = f"https://api.ncbi.nlm.nih.gov/datasets/v2/genome/accession/{accession}/annotation_report" + resp = requests.get(url, headers={"Accept": "application/json"}, timeout=60) + resp.raise_for_status() + return resp.json() + + +# --------------------------------------------------------------------- +# SPARK SESSION +# --------------------------------------------------------------------- +def build_spark_session(app_name: str = "RefSeqAnnotationToCDM") -> SparkSession: + """Configure and return Spark session with Delta support.""" + builder = ( + SparkSession.builder.appName(app_name) + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") + ) + return configure_spark_with_delta_pip(builder).getOrCreate() + + +# --------------------------------------------------------------------- +# CDM TABLE SCHEMAS +# --------------------------------------------------------------------- +# Using centralized schemas +IDENTIFIER_SCHEMA = cdm_schemas["Identifier"] +NAME_SCHEMA = cdm_schemas["Name"] +FEATURE_SCHEMA = cdm_schemas["Feature"] +CONTIG_COLLECTION_X_FEATURE_SCHEMA = cdm_schemas["ContigCollection_x_Feature"] +CONTIG_COLLECTION_X_PROTEIN_SCHEMA = cdm_schemas["ContigCollection_x_Protein"] +FEATURE_X_PROTEIN_SCHEMA = cdm_schemas["Feature_x_Protein"] +CONTIG_SCHEMA = cdm_schemas["Contig"] +CONTIG_X_CONTIG_COLLECTION_SCHEMA = cdm_schemas["Contig_x_ContigCollection"] + + +# --------------------------------------------------------------------- +# CDM PREFIX NORMALIZATION +# --------------------------------------------------------------------- +def apply_prefix(identifier: str) -> str: + """Normalize identifiers to CDM-prefixed formats.""" + if identifier.startswith("GeneID:"): + return identifier.replace("GeneID:", "ncbigene:") + if identifier.startswith(("YP_", "XP_", "WP_", "NP_", "NC_")): + return f"refseq:{identifier}" + if identifier.startswith("GCF_"): + return f"insdc.gcf:{identifier}" + return identifier + + +# --------------------------------------------------------------------- +# Safe integer conversion +# --------------------------------------------------------------------- +def to_int(val: str) -> int | None: + try: + return int(val) + except Exception: + return None + + +# --------------------------------------------------------------------- +# IDENTIFIERS +# --------------------------------------------------------------------- +def load_identifiers(input_json: Path) -> list[tuple[str, str, str, str, str | None]]: + """Extract Identifier table records.""" + data = json.loads(input_json.read_text()) + out = [] + for report in data.get("reports", []): + ann = 
report.get("annotation", {}) + gene_id = ann.get("gene_id") + if not gene_id: + continue + entity_id = apply_prefix(f"GeneID:{gene_id}") + out.append((entity_id, gene_id, ann.get("name"), "RefSeq", ann.get("relationship"))) + return out + + +# --------------------------------------------------------------------- +# NAME EXTRACTION +# --------------------------------------------------------------------- +def load_names(input_json: Path) -> list[tuple[str, str, str, str]]: + """Extract Name table records.""" + data = json.loads(input_json.read_text()) + out = [] + for report in data.get("reports", []): + ann = report.get("annotation", {}) + gene_id = ann.get("gene_id") + if not gene_id: + continue + entity_id = apply_prefix(f"GeneID:{gene_id}") + for label, desc in [ + ("symbol", "RefSeq gene symbol"), + ("name", "RefSeq gene name"), + ("locus_tag", "RefSeq locus tag"), + ]: + val = ann.get(label) + if val: + out.append((entity_id, val, desc, "RefSeq")) + return out + + +# --------------------------------------------------------------------- +# FEATURE LOCATIONS +# --------------------------------------------------------------------- +def load_feature_records(input_json: Path) -> list[tuple]: + """Extract Feature table records.""" + data = json.loads(input_json.read_text()) + features = [] + for report in data.get("reports", []): + ann = report.get("annotation", {}) + gene_id = ann.get("gene_id") + if not gene_id: + continue + feature_id = apply_prefix(f"GeneID:{gene_id}") + for region in ann.get("genomic_regions", []): + for r in region.get("gene_range", {}).get("range", []): + strand = { + "plus": "positive", + "minus": "negative", + "unstranded": "unstranded", + }.get(r.get("orientation"), "unknown") + features.append( + ( + feature_id, + None, + None, + None, + to_int(r.get("end")), + None, + to_int(r.get("begin")), + strand, + "RefSeq", + None, + "gene", + ) + ) + return features + + +# --------------------------------------------------------------------- +# PARSE CONTIG_COLLECTION <-> FEATURE +# --------------------------------------------------------------------- +def load_contig_collection_x_feature(input_json: Path) -> list[tuple[str, str]]: + """Parse ContigCollection ↔ Feature links.""" + data = json.loads(input_json.read_text()) + links = [] + + for report in data.get("reports", []): + ann = report.get("annotation", {}) + gene_id = ann.get("gene_id") + regions = ann.get("genomic_regions", []) + + if not gene_id or not regions: + continue + + acc = regions[0].get("gene_range", {}).get("accession_version") + if acc: + links.append((apply_prefix(acc), apply_prefix(f"GeneID:{gene_id}"))) + + return links + + +# --------------------------------------------------------------------- +# PARSE CONTIG_COLLECTION <-> PROTEIN +# --------------------------------------------------------------------- +def load_contig_collection_x_protein(input_json: Path) -> list[tuple[str, str]]: + data = json.loads(input_json.read_text()) + links = [] + + for report in data.get("reports", []): + ann = report.get("annotation", {}) + proteins = ann.get("proteins", []) + annotations = ann.get("annotations", []) + + if not proteins or not annotations: + continue + + assembly = annotations[0].get("assembly_accession") + + if not assembly: + continue + + contig_id = apply_prefix(assembly) + + for p in proteins: + pid = p.get("accession_version") + if pid: + protein_id = apply_prefix(pid) + links.append((contig_id, protein_id)) + + return links + + +# 
--------------------------------------------------------------------- +# PARSE FEATURE <-> PROTEIN +# --------------------------------------------------------------------- +def load_feature_x_protein(input_json: Path) -> list[tuple[str, str]]: + data = json.loads(input_json.read_text()) + links = [] + + for report in data.get("reports", []): + ann = report.get("annotation", {}) + gene_id = ann.get("gene_id") + proteins = ann.get("proteins", []) + + if not gene_id or not proteins: + continue + + feature_id = apply_prefix(f"GeneID:{gene_id}") + + for p in proteins: + pid = p.get("accession_version") + if pid: + protein_id = apply_prefix(pid) + links.append((feature_id, protein_id)) + + return links + + +# --------------------------------------------------------------------- +# PARSE CONTIGS +# --------------------------------------------------------------------- +def load_contigs(input_json: Path) -> list[tuple[str, str | None, float | None, int | None]]: + """Parse Contig table.""" + data = json.loads(input_json.read_text()) + contigs = {} + + for report in data.get("reports", []): + for region in report.get("annotation", {}).get("genomic_regions", []): + acc = region.get("gene_range", {}).get("accession_version") + if acc: + contig_id = apply_prefix(acc) + contigs.setdefault(contig_id, {"hash": None, "gc_content": None, "length": None}) + + return [(cid, meta["hash"], meta["gc_content"], meta["length"]) for cid, meta in contigs.items()] + + +# --------------------------------------------------------------------- +# PARSE CONTIG <-> CONTIG_COLLECTION +# --------------------------------------------------------------------- +def load_contig_x_contig_collection(input_json: Path) -> list[tuple[str, str]]: + data = json.loads(input_json.read_text()) + links = [] + + for report in data.get("reports", []): + ann = report.get("annotation", {}) + regions = ann.get("genomic_regions", []) + annotations = ann.get("annotations", []) + + if not regions or not annotations: + continue + + contig = regions[0].get("gene_range", {}).get("accession_version") + assembly = annotations[0].get("assembly_accession") + + if contig and assembly: + contig_id = f"refseq:{contig}" + collection_id = apply_prefix(assembly) + links.append((contig_id, collection_id)) + + return links + + +# --------------------------------------------------------------------- +# DELTA TABLE +# --------------------------------------------------------------------- +def write_to_delta( + spark: SparkSession, + records: list[tuple], + output_path: str, + schema: StructType, +) -> None: + """Write records to Delta table.""" + if not records: + return + + df = spark.createDataFrame(records, schema=schema) + df.write.format("delta").mode("overwrite").option("overwriteSchema", "true").save(output_path) + + +# --------------------------------------------------------------------- +# SQL PREVIEW +# --------------------------------------------------------------------- +def run_sql_query(spark: SparkSession, delta_path: str) -> None: + """Run SQL queries to preview Delta tables.""" + for name in [ + "cdm_identifiers", + "cdm_names", + "cdm_features", + "cdm_contig_collection_x_feature", + "cdm_contig_collection_x_protein", + "cdm_feature_x_protein", + "cdm_contigs", + "cdm_contig_x_contig_collection", + ]: + print(f"\n[SQL] {name}:") + path = str(Path(delta_path) / name) + spark.read.format("delta").load(path).createOrReplaceTempView(name) + spark.sql(f"SELECT * FROM {name} LIMIT 20").show(truncate=False) + + +# 
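# ---------------------------------------------------------------------
# Illustrative sketch (not part of the committed module): how the helpers
# above combine outside the CLI. Rows produced by a loader are written
# with the centralized CDM schema and can then be read back as a Delta
# table for a quick preview, much as run_sql_query() does; the output
# directory used here is arbitrary.
spark = build_spark_session()

identifier_rows = [
    ("ncbigene:1234", "1234", "hypothetical protein", "RefSeq", None),
]
write_to_delta(spark, identifier_rows, "output/example/cdm_identifiers", IDENTIFIER_SCHEMA)

# Read the table back and show a few rows to confirm the write succeeded.
spark.read.format("delta").load("output/example/cdm_identifiers").show(truncate=False)

spark.stop()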
--------------------------------------------------------------------- +# CLI ENTRY +# --------------------------------------------------------------------- +def main() -> None: + """Entry point for RefSeq Annotation parser.""" + parser = argparse.ArgumentParser(description="RefSeq Annotation Parser to CDM") + parser.add_argument("--accession", required=True) + parser.add_argument("--output-path", required=True) + parser.add_argument("--query", action="store_true") + args = parser.parse_args() + + base_output = Path(args.output_path) + base_output.mkdir(parents=True, exist_ok=True) + + data = fetch_annotation_json(args.accession) + input_path = Path(f"/tmp/{args.accession}.json") + input_path.write_text(json.dumps(data, indent=2)) + + spark = build_spark_session() + + write_to_delta(spark, load_identifiers(input_path), str(base_output / "cdm_identifiers"), IDENTIFIER_SCHEMA) + write_to_delta(spark, load_names(input_path), str(base_output / "cdm_names"), NAME_SCHEMA) + write_to_delta(spark, load_feature_records(input_path), str(base_output / "cdm_features"), FEATURE_SCHEMA) + write_to_delta( + spark, + load_contig_collection_x_feature(input_path), + str(base_output / "cdm_contig_collection_x_feature"), + CONTIG_COLLECTION_X_FEATURE_SCHEMA, + ) + write_to_delta( + spark, + load_contig_collection_x_protein(input_path), + str(base_output / "cdm_contig_collection_x_protein"), + CONTIG_COLLECTION_X_PROTEIN_SCHEMA, + ) + write_to_delta( + spark, + load_feature_x_protein(input_path), + str(base_output / "cdm_feature_x_protein"), + FEATURE_X_PROTEIN_SCHEMA, + ) + write_to_delta(spark, load_contigs(input_path), str(base_output / "cdm_contigs"), CONTIG_SCHEMA) + write_to_delta( + spark, + load_contig_x_contig_collection(input_path), + str(base_output / "cdm_contig_x_contig_collection"), + CONTIG_X_CONTIG_COLLECTION_SCHEMA, + ) + + if args.query: + run_sql_query(spark, str(base_output)) + + spark.stop() + + +if __name__ == "__main__": + main() diff --git a/src/cdm_data_loader_utils/parsers/kbase_cdm_pyspark.py b/src/cdm_data_loader_utils/parsers/kbase_cdm_pyspark.py new file mode 100644 index 0000000..19be5e8 --- /dev/null +++ b/src/cdm_data_loader_utils/parsers/kbase_cdm_pyspark.py @@ -0,0 +1,610 @@ +"""Automated conversion of cdm_schema to PySpark.""" + +from pyspark.sql.types import BooleanType, DateType, FloatType, IntegerType, StringType, StructField, StructType + +schema = { + "Association": StructType( + [ + StructField("association_id", StringType(), nullable=False), + StructField("subject", StringType(), nullable=False), + StructField("object", StringType(), nullable=False), + StructField("predicate", StringType(), nullable=False), + StructField("negated", BooleanType(), nullable=True), + StructField("evidence_type", StringType(), nullable=True), + StructField("primary_knowledge_source", StringType(), nullable=True), + StructField("aggregator_knowledge_source", StringType(), nullable=True), + StructField("annotation_date", DateType(), nullable=True), + StructField("comments", StringType(), nullable=True), + ] + ), + "Association_x_SupportingObject": StructType( + [ + StructField("association_id", StringType(), nullable=False), + StructField("entity_id", StringType(), nullable=False), + ] + ), + "Cluster": StructType( + [ + StructField("cluster_id", StringType(), nullable=False), + StructField("description", StringType(), nullable=True), + StructField("name", StringType(), nullable=True), + StructField("entity_type", StringType(), nullable=False), + StructField("protocol_id", StringType(), 
nullable=True), + ] + ), + "ClusterMember": StructType( + [ + StructField("cluster_id", StringType(), nullable=False), + StructField("entity_id", StringType(), nullable=False), + StructField("is_representative", BooleanType(), nullable=True), + StructField("is_seed", BooleanType(), nullable=True), + StructField("score", FloatType(), nullable=True), + ] + ), + "Contig": StructType( + [ + StructField("contig_id", StringType(), nullable=False), + StructField("hash", StringType(), nullable=True), + StructField("gc_content", FloatType(), nullable=True), + StructField("length", IntegerType(), nullable=True), + ] + ), + "ContigCollection": StructType( + [ + StructField("contig_collection_id", StringType(), nullable=False), + StructField("hash", StringType(), nullable=True), + StructField("asm_score", FloatType(), nullable=True), + StructField("checkm_completeness", FloatType(), nullable=True), + StructField("checkm_contamination", FloatType(), nullable=True), + StructField("checkm_version", StringType(), nullable=True), + StructField("contig_bp", IntegerType(), nullable=True), + StructField("contig_collection_type", StringType(), nullable=True), + StructField("contig_l50", IntegerType(), nullable=True), + StructField("contig_l90", IntegerType(), nullable=True), + StructField("contig_n50", IntegerType(), nullable=True), + StructField("contig_n90", IntegerType(), nullable=True), + StructField("contig_logsum", FloatType(), nullable=True), + StructField("contig_max", IntegerType(), nullable=True), + StructField("contig_powersum", FloatType(), nullable=True), + StructField("gap_percent", FloatType(), nullable=True), + StructField("gc_average", FloatType(), nullable=True), + StructField("gc_std", FloatType(), nullable=True), + StructField("gtdb_taxon_id", StringType(), nullable=True), + StructField("n_chromosomes", IntegerType(), nullable=True), + StructField("n_contigs", IntegerType(), nullable=True), + StructField("n_scaffolds", IntegerType(), nullable=True), + StructField("ncbi_taxon_id", StringType(), nullable=True), + StructField("scaffold_l50", IntegerType(), nullable=True), + StructField("scaffold_l90", IntegerType(), nullable=True), + StructField("scaffold_n50", IntegerType(), nullable=True), + StructField("scaffold_n90", IntegerType(), nullable=True), + StructField("scaffold_bp", IntegerType(), nullable=True), + StructField("scaffold_logsum", FloatType(), nullable=True), + StructField("scaffold_maximum_length", IntegerType(), nullable=True), + StructField("scaffold_powersum", FloatType(), nullable=True), + StructField("scaffolds_n_over_50K", IntegerType(), nullable=True), + StructField("scaffolds_percent_over_50K", FloatType(), nullable=True), + StructField("scaffolds_total_length_over_50k", IntegerType(), nullable=True), + ] + ), + "ContigCollection_x_EncodedFeature": StructType( + [ + StructField("contig_collection_id", StringType(), nullable=False), + StructField("encoded_feature_id", StringType(), nullable=False), + ] + ), + "ContigCollection_x_Feature": StructType( + [ + StructField("contig_collection_id", StringType(), nullable=False), + StructField("feature_id", StringType(), nullable=False), + ] + ), + "ContigCollection_x_Protein": StructType( + [ + StructField("contig_collection_id", StringType(), nullable=False), + StructField("protein_id", StringType(), nullable=False), + ] + ), + "Contig_x_ContigCollection": StructType( + [ + StructField("contig_id", StringType(), nullable=False), + StructField("contig_collection_id", StringType(), nullable=False), + ] + ), + 
"Contig_x_EncodedFeature": StructType( + [ + StructField("contig_id", StringType(), nullable=False), + StructField("encoded_feature_id", StringType(), nullable=False), + ] + ), + "Contig_x_Feature": StructType( + [ + StructField("contig_id", StringType(), nullable=False), + StructField("feature_id", StringType(), nullable=False), + ] + ), + "Contig_x_Protein": StructType( + [ + StructField("contig_id", StringType(), nullable=False), + StructField("protein_id", StringType(), nullable=False), + ] + ), + "Contributor": StructType( + [ + StructField("contributor_id", StringType(), nullable=False), + StructField("contributor_type", StringType(), nullable=True), + StructField("name", StringType(), nullable=True), + StructField("given_name", StringType(), nullable=True), + StructField("family_name", StringType(), nullable=True), + ] + ), + "ContributorAffiliation": StructType( + [ + StructField("contributor_id", StringType(), nullable=False), + StructField("affiliation_id", StringType(), nullable=True), + ] + ), + "Contributor_x_DataSource": StructType( + [ + StructField("contributor_id", StringType(), nullable=False), + StructField("data_source_id", StringType(), nullable=False), + StructField("contributor_role", StringType(), nullable=True), + ] + ), + "Contributor_x_Role_x_Project": StructType( + [ + StructField("contributor_id", StringType(), nullable=False), + StructField("project_id", StringType(), nullable=False), + StructField("contributor_role", StringType(), nullable=True), + ] + ), + "ControlledTermValue": StructType( + [ + StructField("value_cv_label", StringType(), nullable=False), + StructField("raw_value", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + StructField("attribute_cv_id", StringType(), nullable=True), + StructField("attribute_cv_label", StringType(), nullable=True), + StructField("attribute_string", StringType(), nullable=True), + StructField("entity_id", StringType(), nullable=False), + ] + ), + "ControlledVocabularyTermValue": StructType( + [ + StructField("value_cv_label", StringType(), nullable=True), + StructField("value_cv_id", StringType(), nullable=True), + StructField("raw_value", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + StructField("attribute_cv_id", StringType(), nullable=True), + StructField("attribute_cv_label", StringType(), nullable=True), + StructField("attribute_string", StringType(), nullable=True), + StructField("entity_id", StringType(), nullable=False), + ] + ), + "DataSource": StructType( + [ + StructField("data_source_id", StringType(), nullable=False), + StructField("name", StringType(), nullable=True), + ] + ), + "DataSourceNew": StructType( + [ + StructField("data_source_id", StringType(), nullable=False), + StructField("name", StringType(), nullable=True), + StructField("comments", StringType(), nullable=True), + StructField("date_accessed", DateType(), nullable=False), + StructField("date_published", DateType(), nullable=True), + StructField("date_updated", DateType(), nullable=True), + StructField("license", StringType(), nullable=True), + StructField("publisher", StringType(), nullable=True), + StructField("resource_type", StringType(), nullable=False), + StructField("url", StringType(), nullable=True), + StructField("version", StringType(), nullable=True), + ] + ), + "DataSource_x_Description": StructType( + [ + StructField("data_source_id", StringType(), nullable=False), + StructField("resource_description_id", StringType(), nullable=False), + ] + ), + 
"DataSource_x_FundingReference": StructType( + [ + StructField("data_source_id", StringType(), nullable=False), + StructField("funding_reference_id", StringType(), nullable=False), + ] + ), + "DataSource_x_License": StructType( + [ + StructField("data_source_id", StringType(), nullable=False), + StructField("license_id", StringType(), nullable=False), + ] + ), + "DataSource_x_Title": StructType( + [ + StructField("data_source_id", StringType(), nullable=False), + StructField("resource_title_id", StringType(), nullable=False), + ] + ), + "DateTimeValue": StructType( + [ + StructField("date_time", DateType(), nullable=False), + StructField("raw_value", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + StructField("attribute_cv_id", StringType(), nullable=True), + StructField("attribute_cv_label", StringType(), nullable=True), + StructField("attribute_string", StringType(), nullable=True), + StructField("entity_id", StringType(), nullable=False), + ] + ), + "EncodedFeature": StructType( + [ + StructField("encoded_feature_id", StringType(), nullable=False), + StructField("hash", StringType(), nullable=True), + StructField("has_stop_codon", BooleanType(), nullable=True), + StructField("type", StringType(), nullable=True), + ] + ), + "EncodedFeature_x_Feature": StructType( + [ + StructField("encoded_feature_id", StringType(), nullable=False), + StructField("feature_id", StringType(), nullable=False), + ] + ), + "EncodedFeature_x_Protein": StructType( + [ + StructField("encoded_feature_id", StringType(), nullable=False), + StructField("protein_id", StringType(), nullable=False), + ] + ), + "EntailedEdge": StructType( + [ + StructField("subject", StringType(), nullable=True), + StructField("predicate", StringType(), nullable=True), + StructField("object", StringType(), nullable=True), + ] + ), + "Entity": StructType( + [ + StructField("entity_id", StringType(), nullable=False), + StructField("entity_type", StringType(), nullable=False), + StructField("data_source_id", StringType(), nullable=True), + StructField("data_source_entity_id", StringType(), nullable=True), + StructField("data_source_created", DateType(), nullable=False), + StructField("data_source_updated", DateType(), nullable=True), + StructField("created", DateType(), nullable=False), + StructField("updated", DateType(), nullable=False), + ] + ), + "Event": StructType( + [ + StructField("event_id", StringType(), nullable=False), + StructField("created_at", DateType(), nullable=True), + StructField("description", StringType(), nullable=True), + StructField("name", StringType(), nullable=True), + StructField("location", StringType(), nullable=True), + ] + ), + "Experiment": StructType( + [ + StructField("experiment_id", StringType(), nullable=False), + StructField("protocol_id", StringType(), nullable=False), + StructField("name", StringType(), nullable=True), + StructField("description", StringType(), nullable=True), + StructField("created_at", DateType(), nullable=True), + ] + ), + "ExperimentCondition": StructType( + [ + StructField("experiment_condition_id", StringType(), nullable=False), + StructField("experiment_id", StringType(), nullable=False), + StructField("variable_id", StringType(), nullable=False), + StructField("value", StringType(), nullable=True), + ] + ), + "ExperimentConditionSet": StructType( + [ + StructField("experiment_condition_set_id", StringType(), nullable=False), + StructField("experiment_condition_id", StringType(), nullable=False), + ] + ), + "Feature": StructType( + [ + 
StructField("feature_id", StringType(), nullable=False), + StructField("hash", StringType(), nullable=True), + StructField("cds_phase", StringType(), nullable=True), + StructField("e_value", FloatType(), nullable=True), + StructField("end", IntegerType(), nullable=True), + StructField("p_value", FloatType(), nullable=True), + StructField("start", IntegerType(), nullable=True), + StructField("strand", StringType(), nullable=True), + StructField("source_database", StringType(), nullable=True), + StructField("protocol_id", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + ] + ), + "Feature_x_Protein": StructType( + [ + StructField("feature_id", StringType(), nullable=False), + StructField("protein_id", StringType(), nullable=False), + ] + ), + "FundingReference": StructType( + [ + StructField("funding_reference_id", StringType(), nullable=False), + StructField("funder", StringType(), nullable=True), + StructField("grant_id", StringType(), nullable=True), + StructField("grant_title", StringType(), nullable=True), + StructField("grant_url", StringType(), nullable=True), + ] + ), + "Geolocation": StructType( + [ + StructField("latitude", FloatType(), nullable=False), + StructField("longitude", FloatType(), nullable=False), + StructField("raw_value", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + StructField("attribute_cv_id", StringType(), nullable=True), + StructField("attribute_cv_label", StringType(), nullable=True), + StructField("attribute_string", StringType(), nullable=True), + StructField("entity_id", StringType(), nullable=False), + ] + ), + "GoldEnvironmentalContext": StructType( + [ + StructField("gold_environmental_context_id", StringType(), nullable=False), + StructField("ecosystem", StringType(), nullable=True), + StructField("ecosystem_category", StringType(), nullable=True), + StructField("ecosystem_subtype", StringType(), nullable=True), + StructField("ecosystem_type", StringType(), nullable=True), + StructField("specific_ecosystem", StringType(), nullable=True), + ] + ), + "Identifier": StructType( + [ + StructField("entity_id", StringType(), nullable=False), + StructField("identifier", StringType(), nullable=False), + StructField("description", StringType(), nullable=True), + StructField("source", StringType(), nullable=True), + StructField("relationship", StringType(), nullable=True), + ] + ), + "License": StructType( + [ + StructField("license_id", StringType(), nullable=False), + StructField("id", StringType(), nullable=True), + StructField("name", StringType(), nullable=True), + StructField("url", StringType(), nullable=True), + ] + ), + "Measurement": StructType( + [ + StructField("measurement_id", StringType(), nullable=False), + StructField("measurement_set_id", StringType(), nullable=False), + StructField("experiment_condition_set_id", StringType(), nullable=False), + StructField("value", StringType(), nullable=True), + ] + ), + "MeasurementSet": StructType( + [ + StructField("measurement_set_id", StringType(), nullable=False), + StructField("variable_id", StringType(), nullable=False), + StructField("quality", StringType(), nullable=True), + StructField("created_at", DateType(), nullable=True), + ] + ), + "MixsEnvironmentalContext": StructType( + [ + StructField("mixs_environmental_context_id", StringType(), nullable=False), + StructField("env_broad_scale", StringType(), nullable=True), + StructField("env_local_scale", StringType(), nullable=True), + StructField("env_medium", StringType(), 
nullable=True), + ] + ), + "Name": StructType( + [ + StructField("entity_id", StringType(), nullable=False), + StructField("name", StringType(), nullable=False), + StructField("description", StringType(), nullable=True), + StructField("source", StringType(), nullable=True), + ] + ), + "OrderedProtocolStep": StructType( + [ + StructField("protocol_id", StringType(), nullable=False), + StructField("protocol_step_id", StringType(), nullable=False), + StructField("step_index", IntegerType(), nullable=False), + ] + ), + "Parameter": StructType( + [ + StructField("parameter_id", StringType(), nullable=False), + StructField("name", StringType(), nullable=True), + StructField("description", StringType(), nullable=True), + StructField("value_type", StringType(), nullable=True), + StructField("required", BooleanType(), nullable=True), + StructField("cardinality", StringType(), nullable=True), + StructField("default", StringType(), nullable=True), + StructField("parameter_type", StringType(), nullable=True), + ] + ), + "Prefix": StructType( + [ + StructField("prefix", StringType(), nullable=True), + StructField("base", StringType(), nullable=True), + ] + ), + "Project": StructType( + [ + StructField("project_id", StringType(), nullable=False), + StructField("description", StringType(), nullable=True), + ] + ), + "Protein": StructType( + [ + StructField("protein_id", StringType(), nullable=False), + StructField("hash", StringType(), nullable=True), + StructField("description", StringType(), nullable=True), + StructField("evidence_for_existence", StringType(), nullable=True), + StructField("length", IntegerType(), nullable=True), + StructField("sequence", StringType(), nullable=True), + ] + ), + "Protocol": StructType( + [ + StructField("protocol_id", StringType(), nullable=False), + StructField("name", StringType(), nullable=True), + StructField("description", StringType(), nullable=True), + StructField("doi", StringType(), nullable=True), + StructField("url", StringType(), nullable=True), + StructField("version", StringType(), nullable=True), + ] + ), + "ProtocolExecution": StructType( + [ + StructField("protocol_execution_id", StringType(), nullable=False), + StructField("protocol_id", StringType(), nullable=False), + StructField("name", StringType(), nullable=True), + StructField("description", StringType(), nullable=True), + StructField("created_at", DateType(), nullable=True), + ] + ), + "ProtocolInput": StructType( + [ + StructField("parameter_id", StringType(), nullable=False), + StructField("protocol_input_id", StringType(), nullable=False), + StructField("protocol_execution_id", StringType(), nullable=False), + StructField("value", StringType(), nullable=False), + ] + ), + "ProtocolInputSet": StructType( + [ + StructField("protocol_input_id", StringType(), nullable=False), + StructField("protocol_input_set_id", StringType(), nullable=False), + ] + ), + "ProtocolOutput": StructType( + [ + StructField("protocol_output_id", StringType(), nullable=False), + StructField("protocol_input_set_id", StringType(), nullable=False), + StructField("value", StringType(), nullable=False), + ] + ), + "ProtocolStep": StructType( + [ + StructField("protocol_step_id", StringType(), nullable=False), + StructField("step", StringType(), nullable=True), + ] + ), + "ProtocolVariable": StructType( + [ + StructField("protocol_id", StringType(), nullable=False), + StructField("variable_id", StringType(), nullable=False), + ] + ), + "Publication": StructType( + [ + StructField("publication_id", StringType(), 
nullable=False), + ] + ), + "QuantityRangeValue": StructType( + [ + StructField("maximum_numeric_value", FloatType(), nullable=False), + StructField("minimum_numeric_value", FloatType(), nullable=False), + StructField("unit_cv_id", StringType(), nullable=True), + StructField("unit_cv_label", StringType(), nullable=True), + StructField("unit_string", StringType(), nullable=True), + StructField("raw_value", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + StructField("attribute_cv_id", StringType(), nullable=True), + StructField("attribute_cv_label", StringType(), nullable=True), + StructField("attribute_string", StringType(), nullable=True), + StructField("entity_id", StringType(), nullable=False), + ] + ), + "QuantityValue": StructType( + [ + StructField("numeric_value", FloatType(), nullable=False), + StructField("unit_cv_id", StringType(), nullable=True), + StructField("unit_cv_label", StringType(), nullable=True), + StructField("unit_string", StringType(), nullable=True), + StructField("raw_value", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + StructField("attribute_cv_id", StringType(), nullable=True), + StructField("attribute_cv_label", StringType(), nullable=True), + StructField("attribute_string", StringType(), nullable=True), + StructField("entity_id", StringType(), nullable=False), + ] + ), + "ResourceDescription": StructType( + [ + StructField("resource_description_id", StringType(), nullable=False), + StructField("description_text", StringType(), nullable=False), + StructField("description_type", StringType(), nullable=True), + StructField("language", StringType(), nullable=True), + ] + ), + "ResourceTitle": StructType( + [ + StructField("resource_title_id", StringType(), nullable=False), + StructField("language", StringType(), nullable=True), + StructField("title", StringType(), nullable=False), + StructField("title_type", StringType(), nullable=True), + ] + ), + "Sample": StructType( + [ + StructField("sample_id", StringType(), nullable=False), + StructField("description", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + ] + ), + "Sequence": StructType( + [ + StructField("sequence_id", StringType(), nullable=False), + StructField("entity_id", StringType(), nullable=False), + StructField("type", StringType(), nullable=True), + StructField("length", IntegerType(), nullable=True), + StructField("checksum", StringType(), nullable=True), + ] + ), + "Statement": StructType( + [ + StructField("subject", StringType(), nullable=True), + StructField("predicate", StringType(), nullable=True), + StructField("object", StringType(), nullable=True), + StructField("value", StringType(), nullable=True), + StructField("datatype", StringType(), nullable=True), + StructField("language", StringType(), nullable=True), + ] + ), + "TextValue": StructType( + [ + StructField("text_value", StringType(), nullable=False), + StructField("language", StringType(), nullable=True), + StructField("raw_value", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + StructField("attribute_cv_id", StringType(), nullable=True), + StructField("attribute_cv_label", StringType(), nullable=True), + StructField("attribute_string", StringType(), nullable=True), + StructField("entity_id", StringType(), nullable=False), + ] + ), + "Variable": StructType( + [ + StructField("variable_id", StringType(), nullable=False), + StructField("name", StringType(), nullable=True), + 
StructField("description", StringType(), nullable=True), + StructField("name_cv_id", StringType(), nullable=True), + StructField("unit", StringType(), nullable=True), + StructField("value_type", StringType(), nullable=False), + ] + ), + "VariableValue": StructType( + [ + StructField("variable_value_id", StringType(), nullable=False), + StructField("variable_id", StringType(), nullable=False), + StructField("value_type", StringType(), nullable=True), + ] + ), +} diff --git a/src/cdm_data_loader_utils/parsers/uniref.py b/src/cdm_data_loader_utils/parsers/uniref.py index 6f24bb3..6e1cdf3 100644 --- a/src/cdm_data_loader_utils/parsers/uniref.py +++ b/src/cdm_data_loader_utils/parsers/uniref.py @@ -41,19 +41,18 @@ import os import uuid import xml.etree.ElementTree as ET -from datetime import datetime +from datetime import UTC, datetime +from pathlib import Path from urllib.error import URLError -from datetime import timezone from urllib.request import urlretrieve + import click from delta import configure_spark_with_delta_pip from pyspark.sql import SparkSession from pyspark.sql.types import StringType, StructField, StructType -from pathlib import Path from cdm_data_loader_utils.parsers.xml_utils import get_text, parse_properties - logger = logging.getLogger(__name__) @@ -102,7 +101,7 @@ def get_timestamps( if not uniref_id: raise ValueError("get_timestamps: uniref_id must be a non-empty string") - now_dt = now or datetime.now(timezone.utc) + now_dt = now or datetime.now(UTC) updated_time = now_dt.isoformat(timespec="seconds") created_time = existing_created.get(uniref_id) or updated_time @@ -187,7 +186,6 @@ def get_accession_and_seed(dbref: ET.Element | None, ns: dict[str, str]) -> tupl """ Extract UniProtKB accession and is_seed status from a dbReference element. 
""" - if dbref is None: return None, False diff --git a/tests/parsers/refseq_importer/test_cdm_builders.py b/tests/parsers/refseq_importer/test_cdm_builders.py deleted file mode 100644 index 4cb8772..0000000 --- a/tests/parsers/refseq_importer/test_cdm_builders.py +++ /dev/null @@ -1,131 +0,0 @@ -import pytest -from pyspark.sql import SparkSession - -from cdm_data_loader_utils.parsers.refseq_importer.core.cdm_builders import ( - build_cdm_contig_collection, - build_cdm_entity, - build_cdm_identifier_rows, - build_cdm_name_rows, - build_entity_id, -) - -### pytest cdm_data_loader_utils.parsers.refseq_importer/tests/test_cdm_builders.py ### - - -# ------------------------------------------------------------- -# Spark fixture (session scope — runs once for whole test suite) -# ------------------------------------------------------------- -@pytest.fixture(scope="session") -def spark(): - spark = ( - SparkSession.builder.master("local[1]") - .appName("cdm_data_loader_utils.parsers.refseq_importer_tests") - .getOrCreate() - ) - yield spark - spark.stop() - - -# ============================================================= -# TEST build_entity_id -# ============================================================= -@pytest.mark.requires_spark -@pytest.mark.parametrize("input_key", ["abc", " hello ", "", "123", "GCF_0001"]) -def test_build_entity_id_prefix(input_key) -> None: - eid = build_entity_id(input_key) - assert eid.startswith("CDM:") - assert len(eid) > 10 # UUID v5 non-empty - assert isinstance(eid, str) - - -# ============================================================= -# TEST build_cdm_entity -# ============================================================= -@pytest.mark.requires_spark -def test_build_cdm_entity_basic(spark) -> None: - df, eid = build_cdm_entity(spark, key_for_uuid="ABC123", created_date="2020-01-01") - - row = df.collect()[0] - assert row.entity_id == eid - assert row.entity_type == "contig_collection" - assert row.data_source == "RefSeq" - assert row.created == "2020-01-01" - - -# ============================================================= -# TEST build_cdm_contig_collection -# ============================================================= -@pytest.mark.requires_spark -@pytest.mark.parametrize(("taxid", "expected"), [("1234", "NCBITaxon:1234"), (None, None), ("999", "NCBITaxon:999")]) -def test_build_cdm_contig_collection_param(spark, taxid, expected) -> None: - df = build_cdm_contig_collection(spark, entity_id="CDM:xyz", taxid=taxid) - row = df.collect()[0] - assert row.collection_id == "CDM:xyz" - assert row.ncbi_taxon_id == expected - - -# ============================================================= -# TEST build_cdm_name_rows -# ============================================================= -@pytest.mark.requires_spark -def test_build_cdm_name_rows(spark) -> None: - rep = {"organism": {"name": "Escherichia coli"}, "assembly": {"display_name": "GCF_test_assembly"}} - - df = build_cdm_name_rows(spark, "CDM:abc", rep) - rows = df.collect() - - names = {r.name for r in rows} - assert "Escherichia coli" in names - assert "GCF_test_assembly" in names - - -# ============================================================= -# TEST build_cdm_identifier_rows (parametrize!) 
-# ============================================================= -@pytest.mark.parametrize( - ("rep", "request_taxid", "expected_identifiers"), - [ - # Case 1 – full fields - ( - {"biosample": ["BS1"], "bioproject": ["BP1"], "taxid": "123"}, - "123", - {"Biosample:BS1", "BioProject:BP1", "NCBITaxon:123"}, - ), - # Case 2 – only taxid - ( - {"biosample": [], "bioproject": [], "taxid": "999"}, - "999", - {"NCBITaxon:999"}, - ), - # Case 3 – fallback taxid used - ( - {"biosample": ["X"], "bioproject": [], "taxid": None}, - "888", - {"Biosample:X", "NCBITaxon:888"}, - ), - # Case 4 – GCF/GCA accessions - ( - { - "biosample": ["BS"], - "bioproject": [], - "taxid": "555", - "assembly": {"assembly_accession": ["GCF_0001"], "insdc_assembly_accession": ["GCA_0002"]}, - }, - "555", - {"Biosample:BS", "NCBITaxon:555", "ncbi.assembly:GCF_0001", "insdc.gca:GCA_0002"}, - ), - ], -) -def test_build_cdm_identifier_rows_param(rep, request_taxid, expected_identifiers) -> None: - # Convert mock representation into what extract_assembly_accessions expects - fake_rep = { - "biosample": rep.get("biosample", []), - "bioproject": rep.get("bioproject", []), - "taxid": rep.get("taxid"), - "assembly": rep.get("assembly", {}), - } - - rows = build_cdm_identifier_rows("CDM:123", fake_rep, request_taxid) - identifiers = {r["identifier"] for r in rows} - - assert identifiers == expected_identifiers diff --git a/tests/parsers/refseq_importer/test_extractors.py b/tests/parsers/refseq_importer/test_extractors.py deleted file mode 100644 index 317fa7e..0000000 --- a/tests/parsers/refseq_importer/test_extractors.py +++ /dev/null @@ -1,192 +0,0 @@ -import pytest - -from cdm_data_loader_utils.parsers.refseq_importer.core.extractors import ( - PAT_BIOSAMPLE, - _coalesce, - _deep_collect_regex, - _deep_find_str, - extract_assembly_accessions, - extract_assembly_name, - extract_bioproject_ids, - extract_biosample_ids, - extract_created_date, - extract_organism_name, - extract_taxid, -) - - -# --------------------------------------------- -# _coalesce -# --------------------------------------------- -@pytest.mark.parametrize( - ("vals", "expected"), - [ - (["", " ", "abc"], "abc"), - ([" x ", None, ""], "x"), - ([None, "", " "], None), - (["A", "B"], "A"), - ], -) -def test_coalesce(vals, expected) -> None: - assert _coalesce(*vals) == expected - - -# --------------------------------------------- -# _deep_find_str -# --------------------------------------------- -def test_deep_find_str() -> None: - obj = { - "level1": { - "target": "VALUE", - "other": [{"target": "SECOND"}], - } - } - res = _deep_find_str(obj, {"target"}) - assert res == "VALUE" - - -# --------------------------------------------- -# _deep_collect_regex -# --------------------------------------------- -def test_deep_collect_regex() -> None: - obj = { - "a": "SAMN123", - "b": ["xxx SAMN999 yyy", {"k": "SAMN555"}], - } - result = _deep_collect_regex(obj, PAT_BIOSAMPLE) - assert result == ["SAMN123", "SAMN555", "SAMN999"] - - -# --------------------------------------------- -# extract_created_date -# --------------------------------------------- -def test_extract_created_date_refseq() -> None: - rep = { - "assembly_info": { - "sourceDatabase": "SOURCE_DATABASE_REFSEQ", - "releaseDate": "2020-01-01", - "assemblyDate": "2019-01-01", - "submissionDate": "2018-01-01", - } - } - assert extract_created_date(rep) == "2020-01-01" - - -def test_extract_created_date_genbank() -> None: - rep = { - "assembly_info": { - "sourceDatabase": "SOURCE_DATABASE_GENBANK", - 
"submissionDate": "2018-01-01", - } - } - assert extract_created_date(rep, allow_genbank_date=True) == "2018-01-01" - - -# --------------------------------------------- -# extract_assembly_name -# --------------------------------------------- -@pytest.mark.parametrize( - ("rep", "expected"), - [ - ({"assemblyInfo": {"assemblyName": "ASM1"}}, "ASM1"), - ({"assembly": {"assemblyName": "ASM2"}}, "ASM2"), - ({"assembly": {"display_name": "ASM3"}}, "ASM3"), - ({"assembly": {"displayName": "ASM4"}}, "ASM4"), - ({"assembly": {"nested": {"display_name": "ASM5"}}}, "ASM5"), - ], -) -def test_extract_assembly_name(rep, expected) -> None: - assert extract_assembly_name(rep) == expected - - -# --------------------------------------------- -# extract_organism_name -# --------------------------------------------- -@pytest.mark.parametrize( - ("rep", "expected"), - [ - ({"organism": {"scientificName": "E. coli"}}, "E. coli"), - ({"organism": {"name": "Bacteria X"}}, "Bacteria X"), - ({"assembly": {"organism": {"organismName": "ABC"}}}, "ABC"), - ({"nested": {"organismName": "NNN"}}, "NNN"), - ], -) -def test_extract_organism_name(rep, expected) -> None: - assert extract_organism_name(rep) == expected - - -# --------------------------------------------- -# extract_taxid -# --------------------------------------------- -@pytest.mark.parametrize( - ("rep", "expected"), - [ - ({"organism": {"taxId": 123}}, "123"), - ({"organism": {"taxid": "456"}}, "456"), - ({"organism": {"taxID": "789"}}, "789"), - ({"nested": {"deep": {"taxid": 999}}}, "999"), - ], -) -def test_extract_taxid(rep, expected) -> None: - assert extract_taxid(rep) == expected - - -# --------------------------------------------- -# extract_biosample_ids -# --------------------------------------------- -@pytest.mark.parametrize( - ("rep", "expected"), - [ - ({"biosample": ["BS1", "BS2"]}, ["BS1", "BS2"]), - ({"assemblyInfo": {"biosample": {"accession": "BS3"}}}, ["BS3"]), - ({"biosample": [{"biosampleAccession": "BS4"}]}, ["BS4"]), - ({"text": "random SAMN111 text"}, ["SAMN111"]), - ], -) -def test_extract_biosample_ids(rep, expected) -> None: - assert extract_biosample_ids(rep) == expected - - -# --------------------------------------------- -# extract_bioproject_ids -# --------------------------------------------- -@pytest.mark.parametrize( - ("rep", "expected"), - [ - ({"bioproject": ["BP1"]}, ["BP1"]), - ({"assemblyInfo": {"bioproject": {"accession": "BP2"}}}, ["BP2"]), - ({"bioproject": [{"bioprojectAccession": "BP3"}]}, ["BP3"]), - ({"text": "abc PRJNA999 xyz"}, ["PRJNA999"]), - ], -) -def test_extract_bioproject_ids(rep, expected) -> None: - assert extract_bioproject_ids(rep) == expected - - -# --------------------------------------------- -# extract_assembly_accessions -# --------------------------------------------- -@pytest.mark.parametrize( - ("rep", "expected_gcf", "expected_gca"), - [ - ( - {"assembly": {"assembly_accession": ["GCF_0001.1"], "insdc_assembly_accession": ["GCA_0002.1"]}}, - ["GCF_0001.1"], - ["GCA_0002.1"], - ), - ( - {"accession": "GCF_1111.1"}, - ["GCF_1111.1"], - [], - ), - ( - {"assembly_info": {"paired_assembly": {"accession": "GCA_2222.1"}}}, - [], - ["GCA_2222.1"], - ), - ], -) -def test_extract_assembly_accessions(rep, expected_gcf, expected_gca) -> None: - gcf, gca = extract_assembly_accessions(rep) - assert gcf == expected_gcf - assert gca == expected_gca diff --git a/tests/parsers/refseq_importer/test_refseq_api_cli.py b/tests/parsers/refseq_importer/test_refseq_api_cli.py deleted file mode 100644 index 
e1cf175..0000000 --- a/tests/parsers/refseq_importer/test_refseq_api_cli.py +++ /dev/null @@ -1,86 +0,0 @@ -from unittest.mock import MagicMock, patch - -from click.testing import CliRunner - -from cdm_data_loader_utils.parsers.refseq_importer.cli.refseq_api_cli import ( - cli, - main, - parse_taxid_args, -) - -# ------------------------------------------------- -# test_parse_taxid_args -# ------------------------------------------------- - - -def test_parse_taxid_args_basic() -> None: - taxids = parse_taxid_args("123, 456x, abc789", None) - assert taxids == ["123", "456", "789"] - - -def test_parse_taxid_args_file(tmp_path) -> None: - file = tmp_path / "ids.txt" - file.write_text("111\n222x\n333\n") - - taxids = parse_taxid_args("123", str(file)) - assert taxids == ["123", "111", "222", "333"] - - -# ------------------------------------------------- -# test main() -# ------------------------------------------------- - - -@patch("cdm_data_loader_utils.parsers.refseq_importer.cli.refseq_api_cli.write_and_preview") -@patch("cdm_data_loader_utils.parsers.refseq_importer.cli.refseq_api_cli.finalize_tables") -@patch("cdm_data_loader_utils.parsers.refseq_importer.cli.refseq_api_cli.process_taxon") -@patch("cdm_data_loader_utils.parsers.refseq_importer.cli.refseq_api_cli.write_delta") -@patch("cdm_data_loader_utils.parsers.refseq_importer.cli.refseq_api_cli.build_spark") -def test_main_end_to_end( - mock_build, - mock_write_delta, - mock_process, - mock_finalize, - mock_preview, -) -> None: - mock_spark = MagicMock() - mock_build.return_value = mock_spark - mock_process.return_value = (["E"], ["C"], ["N"], ["I"]) - mock_finalize.return_value = ("EE", "CC", "NN", "II") - - main( - taxid="123", - api_key=None, - database="refseq_api", - mode="overwrite", - debug=False, - allow_genbank_date=False, - unique_per_taxon=False, - data_dir="/tmp", - ) - - mock_build.assert_called_once() - mock_process.assert_called_once() - mock_write_delta.assert_called() - mock_preview.assert_called_once() - - -# ------------------------------------------------- -# test cli() wrapper -# ------------------------------------------------- - - -def test_cli_invocation() -> None: - with patch("cdm_data_loader_utils.parsers.refseq_importer.cli.refseq_api_cli.main") as mock_main: - runner = CliRunner() - result = runner.invoke( - cli, - [ - "--taxid", - "123", - "--data-dir", - "/tmp", - ], - ) - assert result.exit_code == 0 - mock_main.assert_called_once() diff --git a/tests/parsers/refseq_importer/test_spark_delta.py b/tests/parsers/refseq_importer/test_spark_delta.py deleted file mode 100644 index b5cd9d0..0000000 --- a/tests/parsers/refseq_importer/test_spark_delta.py +++ /dev/null @@ -1,211 +0,0 @@ -import os -import shutil - -import pytest -from pyspark.sql import Row, SparkSession -from pyspark.sql.types import StringType, StructField, StructType - -from cdm_data_loader_utils.parsers.refseq_importer.core.spark_delta import ( - build_spark, - preview_or_skip, - write_delta, -) - - -# ============================================================= -# Spark fixture -# ============================================================= -@pytest.fixture(scope="session") -def spark(): - spark = ( - SparkSession.builder.master("local[1]") - .appName("spark-delta-test") - .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") - .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") - .getOrCreate() - ) - yield spark - spark.stop() - - -# 
============================================================= -# build_spark -# ============================================================= - - -@pytest.mark.requires_spark -def test_build_spark_creates_database(tmp_path) -> None: - db = "testdb" - spark = build_spark(db) - dbs = [d.name for d in spark.catalog.listDatabases()] - assert db in dbs - - -# ============================================================= -# write_delta (managed table) -# ============================================================= - - -@pytest.mark.skip("See tests/utils/test_delta_spark.py") -@pytest.mark.requires_spark -def test_write_delta_managed_table(spark) -> None: - db = "writetest" - spark.sql(f"CREATE DATABASE IF NOT EXISTS {db}") - - df = spark.createDataFrame([Row(a="X", b="Y")]) - - write_delta( - spark=spark, - df=df, - database=db, - table="example", - mode="overwrite", - data_dir=None, - ) - - # Table should exist - assert spark.catalog.tableExists(f"{db}.example") - - # Data should exist - rows = spark.sql(f"SELECT a, b FROM {db}.example").collect() - assert rows[0]["a"] == "X" - assert rows[0]["b"] == "Y" - - -# ============================================================= -# write_delta with external LOCATION -# ============================================================= - - -@pytest.mark.skip("See tests/utils/test_delta_spark.py") -@pytest.mark.requires_spark -def test_write_delta_external_location(spark, tmp_path) -> None: - db = "externaldb" - spark.sql(f"CREATE DATABASE IF NOT EXISTS {db}") - - df = spark.createDataFrame([Row(id="1", value="A")]) - - write_delta( - spark=spark, - df=df, - database=db, - table="exttable", - mode="overwrite", - data_dir=str(tmp_path), - ) - - # Table should be registered - assert spark.catalog.tableExists(f"{db}.exttable") - - # Data should be readable - rows = spark.sql(f"SELECT * FROM {db}.exttable").collect() - assert rows[0]["id"] == "1" - assert rows[0]["value"] == "A" - - -# ============================================================= -# write_delta special schema: contig_collection -# ============================================================= - - -@pytest.mark.skip("See tests/utils/test_delta_spark.py") -@pytest.mark.requires_spark -def test_write_delta_contig_collection_schema(spark) -> None: - db = "cdmdb" - spark.sql(f"CREATE DATABASE IF NOT EXISTS {db}") - - schema = StructType( - [ - StructField("collection_id", StringType(), True), - StructField("contig_collection_type", StringType(), True), - StructField("ncbi_taxon_id", StringType(), True), - StructField("gtdb_taxon_id", StringType(), True), - ] - ) - - df = spark.createDataFrame( - [("C1", "isolate", "NCBITaxon:123", None)], - schema=schema, - ) - - write_delta( - spark=spark, - df=df, - database=db, - table="contig_collection", - mode="overwrite", - data_dir=None, - ) - - result = spark.sql(f"SELECT * FROM {db}.contig_collection").collect()[0] - - assert result["collection_id"] == "C1" - assert result["contig_collection_type"] == "isolate" - assert result["ncbi_taxon_id"] == "NCBITaxon:123" - assert result["gtdb_taxon_id"] is None - - -# ============================================================= -# write_delta skip when empty -# ============================================================= - - -@pytest.mark.requires_spark -def test_write_delta_empty_df(spark, capsys) -> None: - db = "emptydb" - spark.sql(f"CREATE DATABASE IF NOT EXISTS {db}") - - # Create empty df - df = spark.createDataFrame([], schema="a string") - - write_delta( - spark=spark, - df=df, - database=db, - 
table="emptytable", - mode="overwrite", - data_dir=None, - ) - - captured = capsys.readouterr().out - assert "No data to write" in captured - - -# ============================================================= -# preview_or_skip -# ============================================================= - - -@pytest.mark.requires_spark -def test_preview_or_skip_existing(spark, capsys) -> None: - db = "previewdb" - spark.sql(f"CREATE DATABASE IF NOT EXISTS {db}") - - # Drop table - spark.sql(f"DROP TABLE IF EXISTS {db}.t1") - - # Delete physical directory to avoid LOCATION_ALREADY_EXISTS - warehouse_dir = os.path.abspath("spark-warehouse/previewdb.db/t1") - shutil.rmtree(warehouse_dir, ignore_errors=True) - - # Create table again - spark.sql(f"CREATE TABLE {db}.t1 (x STRING)") - - # Insert sample row - spark.sql(f"INSERT INTO {db}.t1 VALUES ('hello')") - - preview_or_skip(spark, db, "t1") - - captured = capsys.readouterr().out - assert "hello" in captured - - -@pytest.mark.requires_spark -def test_preview_or_skip_missing(spark, capsys) -> None: - db = "missingdb" - spark.sql(f"CREATE DATABASE IF NOT EXISTS {db}") - - preview_or_skip(spark, db, "t9999") - - out = capsys.readouterr().out - assert "Skipping preview" in out diff --git a/tests/parsers/refseq_importer/test_tables_finalize.py b/tests/parsers/refseq_importer/test_tables_finalize.py deleted file mode 100644 index c71911c..0000000 --- a/tests/parsers/refseq_importer/test_tables_finalize.py +++ /dev/null @@ -1,119 +0,0 @@ -import pytest -from pyspark.sql import Row, SparkSession -from pyspark.sql.types import StringType, StructField, StructType - -from cdm_data_loader_utils.parsers.refseq_importer.core.tables_finalize import finalize_tables, list_of_dicts_to_spark - - -# ------------------------------------------------------------------- -# Spark fixture -# ------------------------------------------------------------------- -@pytest.fixture(scope="session") -def spark(): - spark = SparkSession.builder.master("local[*]").appName("test-tables-finalize").getOrCreate() - yield spark - spark.stop() - - -# ------------------------------------------------------------------- -# Test list_of_dicts_to_spark -# ------------------------------------------------------------------- -@pytest.mark.requires_spark -def test_list_of_dicts_to_spark(spark) -> None: - schema = StructType( - [ - StructField("a", StringType(), True), - StructField("b", StringType(), True), - ] - ) - - rows = [{"a": "1", "b": "x"}, {"a": "2", "b": "y"}] - df = list_of_dicts_to_spark(spark, rows, schema) - - out = {(r.a, r.b) for r in df.collect()} - assert out == {("1", "x"), ("2", "y")} - - -# ------------------------------------------------------------------- -# Test finalize_tables end-to-end -# ------------------------------------------------------------------- -@pytest.mark.requires_spark -def test_finalize_tables_basic(spark) -> None: - # ---------- entity ---------- - e_schema = StructType( - [ - StructField("entity_id", StringType(), True), - StructField("entity_type", StringType(), True), - StructField("data_source", StringType(), True), - StructField("created", StringType(), True), - StructField("updated", StringType(), True), - ] - ) - - e1 = spark.createDataFrame( - [Row(entity_id="E1", entity_type="genome", data_source="RefSeq", created="2020", updated="2021")], - schema=e_schema, - ) - e2 = spark.createDataFrame( - [Row(entity_id="E2", entity_type="genome", data_source="RefSeq", created="2020", updated="2021")], - schema=e_schema, - ) - - # ---------- contig_collection 
(schema REQUIRED due to None!) ---------- - coll_schema = StructType( - [ - StructField("collection_id", StringType(), True), - StructField("contig_collection_type", StringType(), True), - StructField("ncbi_taxon_id", StringType(), True), - StructField("gtdb_taxon_id", StringType(), True), - ] - ) - - c1 = spark.createDataFrame( - [ - Row( - collection_id="E1", - contig_collection_type="isolate", - ncbi_taxon_id="NCBITaxon:1", - gtdb_taxon_id=None, - ) - ], - schema=coll_schema, - ) - - c2 = spark.createDataFrame( - [ - Row( - collection_id="E2", - contig_collection_type="isolate", - ncbi_taxon_id="NCBITaxon:2", - gtdb_taxon_id=None, - ) - ], - schema=coll_schema, - ) - - # ---------- name ---------- - names = [ - {"entity_id": "E1", "name": "A", "description": "d1", "source": "RefSeq"}, - {"entity_id": "E2", "name": "B", "description": "d2", "source": "RefSeq"}, - ] - - # ---------- identifier ---------- - identifiers = [ - {"entity_id": "E1", "identifier": "BioSample:1", "source": "RefSeq", "description": "bs"}, - {"entity_id": "E2", "identifier": "BioSample:2", "source": "RefSeq", "description": "bs"}, - ] - - df_entity, df_coll, df_name, df_ident = finalize_tables(spark, [e1, e2], [c1, c2], names, identifiers) - - # ---------- Assertions ---------- - assert df_entity.count() == 2 - assert df_coll.count() == 2 - assert df_name.count() == 2 - assert df_ident.count() == 2 - - assert {r.entity_id for r in df_entity.collect()} == {"E1", "E2"} - assert {r.collection_id for r in df_coll.collect()} == {"E1", "E2"} - assert {r.name for r in df_name.collect()} == {"A", "B"} - assert {r.identifier for r in df_ident.collect()} == {"BioSample:1", "BioSample:2"} diff --git a/tests/parsers/test_annotation_parse.py b/tests/parsers/test_annotation_parse.py new file mode 100644 index 0000000..763b8ca --- /dev/null +++ b/tests/parsers/test_annotation_parse.py @@ -0,0 +1,713 @@ +import json + +import pytest + +from cdm_data_loader_utils.parsers.annotation_parse import ( + load_contig_collection_x_feature, + load_contig_collection_x_protein, + load_contig_x_contig_collection, + load_contigs, + load_feature_records, + load_feature_x_protein, + load_identifiers, + load_names, +) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + ( + { + "reports": [ + { + "annotation": { + "gene_id": "1234", + "name": "hypothetical protein", + "relationship": "RefSeq gene symbol", + } + } + ] + }, + [ + ( + "ncbigene:1234", + "1234", + "hypothetical protein", + "RefSeq", + "RefSeq gene symbol", + ) + ], + ), + ( + {"reports": [{"annotation": {"gene_id": "5678", "name": "some protein"}}]}, + [("ncbigene:5678", "5678", "some protein", "RefSeq", None)], + ), + ( + { + "reports": [ + { + "annotation": { + "name": "no gene id here", + "relationship": "RefSeq locus tag", + } + } + ] + }, + [], + ), + ( + { + "reports": [ + { + "annotation": { + "gene_id": "1001", + "name": "abc", + "relationship": "RefSeq gene symbol", + } + }, + {"annotation": {"gene_id": "1002", "name": "xyz"}}, + ] + }, + [ + ("ncbigene:1001", "1001", "abc", "RefSeq", "RefSeq gene symbol"), + ("ncbigene:1002", "1002", "xyz", "RefSeq", None), + ], + ), + ], +) +def test_load_identifiers(tmp_path, input_data, expected_output): + input_file = tmp_path / "test.json" + input_file.write_text(json.dumps(input_data)) + + result = load_identifiers(input_file) + assert result == expected_output + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + # Case 1: all name fields present + ( + { + "reports": [ + { + "annotation": { + 
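# ---------------------------------------------------------------------
# A minimal sketch of the contract the identifier cases above pin down: a report
# contributes an Identifier row only when "gene_id" is present, the id is normalised
# to an "ncbigene:" CURIE, the source is always "RefSeq", and the optional
# "relationship" field passes through as the description (None when absent).
# load_identifiers_sketch is an illustrative name; the real load_identifiers() in
# annotation_parse.py may differ in details, and the row layout (prefixed id, raw id,
# name, source, relationship) is inferred from the expected tuples in these tests.
# ---------------------------------------------------------------------
import json
from pathlib import Path


def load_identifiers_sketch(input_json: Path) -> list[tuple[str, str, str | None, str, str | None]]:
    data = json.loads(input_json.read_text())
    rows = []
    for report in data.get("reports", []):
        ann = report.get("annotation", {})
        gene_id = ann.get("gene_id")
        if not gene_id:
            continue  # reports without a gene_id yield no identifier record
        rows.append((f"ncbigene:{gene_id}", gene_id, ann.get("name"), "RefSeq", ann.get("relationship")))
    return rows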
"gene_id": "1234", + "symbol": "abc", + "name": "ABC protein", + "locus_tag": "LTG_1234", + } + } + ] + }, + [ + ("ncbigene:1234", "abc", "RefSeq gene symbol", "RefSeq"), + ("ncbigene:1234", "ABC protein", "RefSeq gene name", "RefSeq"), + ("ncbigene:1234", "LTG_1234", "RefSeq locus tag", "RefSeq"), + ], + ), + # Case 2: only gene_name present + ( + {"reports": [{"annotation": {"gene_id": "5678", "name": "Hypothetical protein"}}]}, + [ + ( + "ncbigene:5678", + "Hypothetical protein", + "RefSeq gene name", + "RefSeq", + ) + ], + ), + # Case 3: no gene_id + ( + {"reports": [{"annotation": {"name": "Unnamed", "symbol": "XYZ"}}]}, + [], + ), + # Case 4: only locus_tag present + ( + {"reports": [{"annotation": {"gene_id": "8888", "locus_tag": "LTG_8888"}}]}, + [("ncbigene:8888", "LTG_8888", "RefSeq locus tag", "RefSeq")], + ), + # Case 5: multiple reports + ( + { + "reports": [ + {"annotation": {"gene_id": "1001", "symbol": "DEF"}}, + {"annotation": {"gene_id": "1002", "name": "DEF protein"}}, + ] + }, + [ + ("ncbigene:1001", "DEF", "RefSeq gene symbol", "RefSeq"), + ("ncbigene:1002", "DEF protein", "RefSeq gene name", "RefSeq"), + ], + ), + ], +) +def test_load_names(tmp_path, input_data, expected_output): + input_file = tmp_path / "test.json" + input_file.write_text(json.dumps(input_data)) + + result = load_names(input_file) + assert sorted(result) == sorted(expected_output) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + # Case 1: basic valid input with plus strand + ( + { + "reports": [ + { + "annotation": { + "gene_id": "1234", + "genomic_regions": [ + { + "gene_range": { + "range": [ + { + "begin": "100", + "end": "200", + "orientation": "plus", + } + ] + } + } + ], + } + } + ] + }, + [ + ( + "ncbigene:1234", + None, + None, + None, + 200, + None, + 100, + "positive", + "RefSeq", + None, + "gene", + ) + ], + ), + # Case 2: multiple ranges, different strands + ( + { + "reports": [ + { + "annotation": { + "gene_id": "5678", + "genomic_regions": [ + { + "gene_range": { + "range": [ + { + "begin": "300", + "end": "500", + "orientation": "minus", + }, + { + "begin": "600", + "end": "800", + "orientation": "plus", + }, + ] + } + } + ], + } + } + ] + }, + [ + ( + "ncbigene:5678", + None, + None, + None, + 500, + None, + 300, + "negative", + "RefSeq", + None, + "gene", + ), + ( + "ncbigene:5678", + None, + None, + None, + 800, + None, + 600, + "positive", + "RefSeq", + None, + "gene", + ), + ], + ), + # Case 3: missing orientation + ( + { + "reports": [ + { + "annotation": { + "gene_id": "9999", + "genomic_regions": [{"gene_range": {"range": [{"begin": "1", "end": "2"}]}}], + } + } + ] + }, + [ + ( + "ncbigene:9999", + None, + None, + None, + 2, + None, + 1, + "unknown", + "RefSeq", + None, + "gene", + ) + ], + ), + # Case 4: no gene_id + ( + { + "reports": [ + { + "annotation": { + "genomic_regions": [ + { + "gene_range": { + "range": [ + { + "begin": "100", + "end": "200", + "orientation": "plus", + } + ] + } + } + ] + } + } + ] + }, + [], + ), + # Case 5: non-integer start/end + ( + { + "reports": [ + { + "annotation": { + "gene_id": "1111", + "genomic_regions": [ + { + "gene_range": { + "range": [ + { + "begin": "abc", + "end": "xyz", + "orientation": "plus", + } + ] + } + } + ], + } + } + ] + }, + [ + ( + "ncbigene:1111", + None, + None, + None, + None, + None, + None, + "positive", + "RefSeq", + None, + "gene", + ) + ], + ), + ], +) +def test_load_feature_records(tmp_path, input_data, expected_output): + input_file = tmp_path / "features.json" + 
input_file.write_text(json.dumps(input_data)) + + result = load_feature_records(input_file) + assert sorted(result) == sorted(expected_output) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + # Case 1: valid mapping + ( + { + "reports": [ + { + "annotation": { + "gene_id": "12345", + "genomic_regions": [{"gene_range": {"accession_version": "NC_000001.11"}}], + } + } + ] + }, + [("refseq:NC_000001.11", "ncbigene:12345")], + ), + # Case 2: no gene_id + ( + {"reports": [{"annotation": {"genomic_regions": [{"gene_range": {"accession_version": "NC_000002.11"}}]}}]}, + [], + ), + # Case 3: no genomic_regions + ( + {"reports": [{"annotation": {"gene_id": "67890"}}]}, + [], + ), + # Case 4: empty genomic_regions list + ( + {"reports": [{"annotation": {"gene_id": "99999", "genomic_regions": []}}]}, + [], + ), + # Case 5: missing accession_version + ( + { + "reports": [ + { + "annotation": { + "gene_id": "13579", + "genomic_regions": [{"gene_range": {}}], + } + } + ] + }, + [], + ), + ], +) +def test_load_contig_collection_x_feature(tmp_path, input_data, expected_output): + input_file = tmp_path / "contig_feature.json" + input_file.write_text(json.dumps(input_data)) + + result = load_contig_collection_x_feature(input_file) + assert result == expected_output + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + # Case 1: Valid report with multiple proteins + ( + { + "reports": [ + { + "annotation": { + "proteins": [ + {"accession_version": "XP_123"}, + {"accession_version": "XP_456"}, + ], + "annotations": [{"assembly_accession": "GCF_000001"}], + } + } + ] + }, + [ + ("insdc.gcf:GCF_000001", "refseq:XP_123"), + ("insdc.gcf:GCF_000001", "refseq:XP_456"), + ], + ), + # Case 2: No proteins + ( + { + "reports": [ + { + "annotation": { + "proteins": [], + "annotations": [{"assembly_accession": "GCF_000002"}], + } + } + ] + }, + [], + ), + # Case 3: No annotations + ( + {"reports": [{"annotation": {"proteins": [{"accession_version": "XP_789"}]}}]}, + [], + ), + # Case 4: Missing assembly_accession + ( + { + "reports": [ + { + "annotation": { + "proteins": [{"accession_version": "XP_789"}], + "annotations": [{}], + } + } + ] + }, + [], + ), + # Case 5: Some proteins missing accession_version + ( + { + "reports": [ + { + "annotation": { + "proteins": [ + {"accession_version": "XP_111"}, + {}, + {"accession_version": "XP_222"}, + ], + "annotations": [{"assembly_accession": "GCF_000003"}], + } + } + ] + }, + [ + ("insdc.gcf:GCF_000003", "refseq:XP_111"), + ("insdc.gcf:GCF_000003", "refseq:XP_222"), + ], + ), + ], +) +def test_load_contig_collection_x_protein(tmp_path, input_data, expected_output): + input_file = tmp_path / "protein_links.json" + input_file.write_text(json.dumps(input_data)) + + result = load_contig_collection_x_protein(input_file) + assert sorted(result) == sorted(expected_output) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + # Case 1: valid gene with multiple proteins + ( + { + "reports": [ + { + "annotation": { + "gene_id": "4156311", + "proteins": [ + {"accession_version": "XP_001"}, + {"accession_version": "XP_002"}, + ], + } + } + ] + }, + [ + ("ncbigene:4156311", "refseq:XP_001"), + ("ncbigene:4156311", "refseq:XP_002"), + ], + ), + # Case 2: no gene_id + ( + {"reports": [{"annotation": {"proteins": [{"accession_version": "XP_999"}]}}]}, + [], + ), + # Case 3: gene with no proteins + ( + {"reports": [{"annotation": {"gene_id": "4156312"}}]}, + [], + ), + # Case 4: some proteins missing accession_version + ( + { + "reports": [ 
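# ---------------------------------------------------------------------
# The expected pairs above show how the association tables are keyed: contig and
# protein RefSeq accessions get a "refseq:" prefix, assemblies become "insdc.gcf:"
# CURIEs, and genes stay "ncbigene:". A row is emitted only when both ends of the
# link are present, which is why reports lacking an assembly_accession or a protein
# accession_version contribute nothing. The sketch below restates that behaviour for
# the ContigCollection_x_Protein case, assuming the same reports/annotation JSON
# shape as the fixtures; collection_protein_links is an illustrative name.
# ---------------------------------------------------------------------
import json
from pathlib import Path


def collection_protein_links(input_json: Path) -> list[tuple[str, str]]:
    data = json.loads(input_json.read_text())
    links = []
    for report in data.get("reports", []):
        ann = report.get("annotation", {})
        annotations = ann.get("annotations") or []
        assembly = annotations[0].get("assembly_accession") if annotations else None
        if not assembly:
            continue  # no contig collection to link against
        for protein in ann.get("proteins") or []:
            acc = protein.get("accession_version")
            if acc:
                links.append((f"insdc.gcf:{assembly}", f"refseq:{acc}"))
    return links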
+ { + "annotation": { + "gene_id": "4156313", + "proteins": [ + {"accession_version": "XP_777"}, + {}, + {"accession_version": "XP_888"}, + ], + } + } + ] + }, + [ + ("ncbigene:4156313", "refseq:XP_777"), + ("ncbigene:4156313", "refseq:XP_888"), + ], + ), + # Case 5: empty report list + ({"reports": []}, []), + ], +) +def test_load_feature_x_protein(tmp_path, input_data, expected_output): + input_file = tmp_path / "feature_protein.json" + input_file.write_text(json.dumps(input_data)) + + result = load_feature_x_protein(input_file) + assert sorted(result) == sorted(expected_output) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + # Case 1: Valid contig and assembly + ( + { + "reports": [ + { + "annotation": { + "genomic_regions": [{"gene_range": {"accession_version": "NC_000001.11"}}], + "annotations": [{"assembly_accession": "GCF_000001.1"}], + } + } + ] + }, + [("refseq:NC_000001.11", "insdc.gcf:GCF_000001.1")], + ), + # Case 2: Missing genomic_regions + ( + {"reports": [{"annotation": {"annotations": [{"assembly_accession": "GCF_000002.1"}]}}]}, + [], + ), + # Case 3: Missing annotations + ( + {"reports": [{"annotation": {"genomic_regions": [{"gene_range": {"accession_version": "NC_000003.11"}}]}}]}, + [], + ), + # Case 4: Missing accession_version in region + ( + { + "reports": [ + { + "annotation": { + "genomic_regions": [{"gene_range": {}}], + "annotations": [{"assembly_accession": "GCF_000004.1"}], + } + } + ] + }, + [], + ), + # Case 5: Missing assembly_accession in annotations + ( + { + "reports": [ + { + "annotation": { + "genomic_regions": [{"gene_range": {"accession_version": "NC_000005.11"}}], + "annotations": [{}], + } + } + ] + }, + [], + ), + # Case 6: Multiple reports, one valid + ( + { + "reports": [ + { + "annotation": { + "genomic_regions": [{"gene_range": {"accession_version": "NC_000006.11"}}], + "annotations": [{"assembly_accession": "GCF_000006.1"}], + } + }, + { + "annotation": { + "genomic_regions": [{"gene_range": {"accession_version": "NC_000007.11"}}], + "annotations": [{}], + } + }, + ] + }, + [("refseq:NC_000006.11", "insdc.gcf:GCF_000006.1")], + ), + ], +) +def test_load_contig_x_contig_collection(tmp_path, input_data, expected_output): + input_file = tmp_path / "contig_collection.json" + input_file.write_text(json.dumps(input_data)) + + result = load_contig_x_contig_collection(input_file) + assert sorted(result) == sorted(expected_output) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + # Case 1: Valid contig with accession_version + ( + {"reports": [{"annotation": {"genomic_regions": [{"gene_range": {"accession_version": "NC_000001.11"}}]}}]}, + [("refseq:NC_000001.11", None, None, None)], + ), + # Case 2: Multiple contigs, different accession_versions + ( + { + "reports": [ + { + "annotation": { + "genomic_regions": [ + {"gene_range": {"accession_version": "NC_000001.11"}}, + {"gene_range": {"accession_version": "NC_000002.12"}}, + ] + } + } + ] + }, + [ + ("refseq:NC_000001.11", None, None, None), + ("refseq:NC_000002.12", None, None, None), + ], + ), + # Case 3: Duplicate accession versions + ( + { + "reports": [ + { + "annotation": { + "genomic_regions": [ + {"gene_range": {"accession_version": "NC_000003.13"}}, + {"gene_range": {"accession_version": "NC_000003.13"}}, + ] + } + } + ] + }, + [("refseq:NC_000003.13", None, None, None)], + ), + # Case 4: Missing accession_version + ( + {"reports": [{"annotation": {"genomic_regions": [{"gene_range": {}}]}}]}, + [], + ), + # Case 5: Empty reports + ( + {"reports": 
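# ---------------------------------------------------------------------
# Case 3 of the contig cases above (duplicate accession_versions collapsing to a
# single row) implies the contig loader de-duplicates on the prefixed contig id,
# while the remaining three Contig columns stay None at this stage (the fixtures do
# not show what they hold). A self-contained sketch of that behaviour, assuming the
# same reports/annotation JSON layout as the fixtures; contig_rows is an
# illustrative name, not the function under test.
# ---------------------------------------------------------------------
import json
from pathlib import Path


def contig_rows(input_json: Path) -> list[tuple[str, None, None, None]]:
    data = json.loads(input_json.read_text())
    seen: set[str] = set()
    rows: list[tuple[str, None, None, None]] = []
    for report in data.get("reports", []):
        for region in report.get("annotation", {}).get("genomic_regions", []):
            acc = region.get("gene_range", {}).get("accession_version")
            if acc and acc not in seen:
                seen.add(acc)
                rows.append((f"refseq:{acc}", None, None, None))
    return rows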
[]}, + [], + ), + ], +) +def test_load_contigs(tmp_path, input_data, expected_output): + input_file = tmp_path / "contig.json" + input_file.write_text(json.dumps(input_data, indent=2)) + + result = load_contigs(input_file) + assert sorted(result) == sorted(expected_output) diff --git a/tests/parsers/test_gene_association_file.py b/tests/parsers/test_gene_association_file.py deleted file mode 100644 index 5225656..0000000 --- a/tests/parsers/test_gene_association_file.py +++ /dev/null @@ -1,206 +0,0 @@ -""" - -Unit tests for the association_update module. - -Run with: - python3 -m pytest test_association_update.py - -""" - -from pathlib import Path - -import pytest -from delta import configure_spark_with_delta_pip -from pyspark.sql import Row, SparkSession -from pyspark.sql.functions import col - -from cdm_data_loader_utils.parsers.gene_association_file import ( - AGGREGATOR, - ANNOTATION_DATE, - DB, - DB_OBJ_ID, - DB_REF, - EVIDENCE_CODE, - EVIDENCE_TYPE, - NEGATED, - PREDICATE, - PROTOCOL_ID, - PUBLICATIONS, - SUBJECT, - add_metadata, - load_annotation, - load_eco_mapping, - merge_evidence, - normalize_dates, - process_predicates, - write_output, -) - - -@pytest.fixture(scope="session") -def spark() -> SparkSession: - """Spark session fixture.""" - builder = ( - SparkSession.builder.master("local[1]") - .appName("TestAssociationUpdate") - .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse") - .config("spark.ui.enabled", "false") - .config("spark.driver.bindAddress", "127.0.0.1") - .config("spark.driver.host", "127.0.0.1") - .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") - .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") - ) - return configure_spark_with_delta_pip(builder).getOrCreate() - - -@pytest.mark.requires_spark -def test_load_annotation(spark: SparkSession, tmp_path: Path) -> None: - """Test loading annotations.""" - test_csv = tmp_path / "test_input.csv" - test_csv.write_text("""DB,DB_Object_ID,Qualifier,GO_ID,DB_Reference,Evidence_Code,With_From,Date,Assigned_By -UniProtKB,P12345,enables,GO:0008150,PMID:123456,ECO:0000313,GO_REF:0000033,20240101,GO_Curator -""") - - df = load_annotation(spark, str(test_csv)) - row = df.collect()[0] - - assert row["predicate"] == "enables" - assert row["object"] == "GO:0008150" - assert row["publications"] == ["PMID:123456"] - assert row["supporting_objects"] == ["GO_REF:0000033"] - assert str(row["annotation_date"]) == "20240101" - assert row["primary_knowledge_source"] == "GO_Curator" - - -@pytest.mark.requires_spark -@pytest.mark.parametrize(("date_input", "expected"), [("20240101", "2024-01-01"), ("notadate", None)]) -def test_normalize_dates(spark: SparkSession, date_input: str, expected: str | None) -> None: - """Test normalizing dates.""" - df = spark.createDataFrame([(date_input,)], [ANNOTATION_DATE]) - result = normalize_dates(df).collect()[0][ANNOTATION_DATE] - - if expected is None: - assert result is None - else: - assert str(result) == expected - - -@pytest.mark.requires_spark -@pytest.mark.parametrize( - ("predicate_val", "expected_negated", "expected_cleaned_predicate"), - [("NOT|enables", True, "enables"), ("involved_in", False, "involved_in")], -) -def test_process_predicates( - spark: SparkSession, predicate_val: str, expected_negated: bool, expected_cleaned_predicate: str -) -> None: - """Test processing predicates.""" - df = spark.createDataFrame([(predicate_val,)], ["Qualifier"]) - df = df.withColumn(PREDICATE, col("Qualifier")) - result_df = 
process_predicates(df) - row = result_df.collect()[0] - - assert row[NEGATED] == expected_negated - assert row[PREDICATE] == expected_cleaned_predicate - - -@pytest.mark.requires_spark -@pytest.mark.parametrize( - ("db", "db_obj_id", "expected_subject"), - [ - ("UniProtKB", "P12345", "UniProtKB:P12345"), - ("TAIR", "AT1G01010", "TAIR:AT1G01010"), - ("MGI", "MGI:87938", "MGI:MGI:87938"), - ], -) -def test_add_metadata(spark: SparkSession, db: str, db_obj_id: str, expected_subject: str) -> None: - """Test adding metadata.""" - df = spark.createDataFrame([(db, db_obj_id)], [DB, DB_OBJ_ID]) - result_df = add_metadata(df) - row = result_df.collect()[0] - - assert row[AGGREGATOR] == "UniProt" - assert row[PROTOCOL_ID] is None - assert row[SUBJECT] == expected_subject - - -@pytest.mark.requires_spark -@pytest.mark.parametrize( - ("eco_content", "expected_rows"), - [ - ("ECO:0000313\tPMID:123456\tIEA\n", [("ECO:0000313", "PMID:123456", "IEA")]), - ( - "ECO:0000256\tPMID:789012\tEXP\nECO:0000244\tDEFAULT\tTAS\n", - [("ECO:0000256", "PMID:789012", "EXP"), ("ECO:0000244", "DEFAULT", "TAS")], - ), - ], -) -def test_load_eco_mapping_from_file( - spark: SparkSession, tmp_path: Path, eco_content: str, expected_rows: list[tuple[str, str, str]] -) -> None: - """Test loading the ECO mapping from a file.""" - eco_file = tmp_path / "gaf-eco-mapping.txt" - eco_file.write_text(eco_content) - - df = load_eco_mapping(spark, local_path=str(eco_file)) - result = [(row[EVIDENCE_CODE], row[DB_REF], row[EVIDENCE_TYPE]) for row in df.collect()] - assert result == expected_rows - - -@pytest.mark.requires_spark -@pytest.mark.parametrize( - ("annotation_rows", "eco_rows", "expected"), - [ - ( - # annotation df - [Row(evidence_code="ECO:0000313", publications=["PMID:123456"])], - # eco df - [Row(evidence_code="ECO:0000313", db_ref="PMID:123456", evidence_type="IEA")], - # expected result - [("ECO:0000313", "PMID:123456", "IEA")], - ), - ( - # Fallback case - [Row(evidence_code="ECO:0000256", publications=["PMID:999999"])], - [Row(evidence_code="ECO:0000256", db_ref="DEFAULT", evidence_type="EXP")], - [("ECO:0000256", "PMID:999999", "EXP")], - ), - ], -) -def test_merge_evidence( - spark: SparkSession, annotation_rows: Row, eco_rows: Row, expected: list[tuple[str, str, str]] -) -> None: - """Test merging the evidence mapping.""" - annotation_df = spark.createDataFrame(annotation_rows).select( - col("evidence_code").alias(EVIDENCE_CODE), col("publications").alias(PUBLICATIONS) - ) - eco_df = spark.createDataFrame(eco_rows).select( - col("evidence_code").alias(EVIDENCE_CODE), - col("db_ref").alias(DB_REF), - col("evidence_type").alias(EVIDENCE_TYPE), - ) - - result_df = merge_evidence(annotation_df, eco_df) - result = [(row[EVIDENCE_CODE], row[PUBLICATIONS], row[EVIDENCE_TYPE]) for row in result_df.collect()] - assert result == expected - - -@pytest.mark.requires_spark -def test_write_output_and_read_back(spark: SparkSession, tmp_path: Path) -> None: - """Test delta read and write.""" - # Sample test data - data = [("GO:0008150", "UniProtKB", "2024-01-01")] - columns = ["object", "db", "annotation_date"] - df = spark.createDataFrame(data, columns) - - # Write to temporary Delta location - output_path = str(tmp_path / "delta_table") - write_output(df, output_path) - - # Read back and validate - result_df = spark.read.format("delta").load(output_path) - result = result_df.collect() - - assert len(result) == 1 - assert result[0]["object"] == "GO:0008150" - assert result[0]["db"] == "UniProtKB" - assert 
str(result[0]["annotation_date"]) == "2024-01-01" diff --git a/tests/parsers/test_genome_paths.py b/tests/parsers/test_genome_paths.py deleted file mode 100644 index 99adedc..0000000 --- a/tests/parsers/test_genome_paths.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Tests for the genome paths file parser.""" - -import json -import re -from pathlib import Path -from typing import Any - -import pytest - -from cdm_data_loader_utils.parsers.genome_paths import get_genome_paths - -GPF_DIR = "genome_paths_file" - - -def test_get_genome_paths_empty(tmp_path: Path) -> None: - """Test that an error is thrown by a non-existent file.""" - full_path = tmp_path / "some-file" - with pytest.raises( - RuntimeError, - match=re.escape("error parsing genome_paths_file: [Errno 2] No such file or directory"), - ): - get_genome_paths(full_path) - - -def test_get_genome_paths_empty_file(tmp_path: Path) -> None: - """Test that an empty file throws an error.""" - full_path = tmp_path / "some-file" - full_path.touch() - - with pytest.raises( - RuntimeError, - match=re.escape("error parsing genome_paths_file: Expecting value: line 1 column 1 (char 0)"), - ): - get_genome_paths(full_path) - - -format_errors = { - # JSON parse errors - "empty": "error parsing genome_paths_file: Expecting value: line 1 column 1 (char 0)", - "ws": "error parsing genome_paths_file: Expecting value: line 4 column 5 (char 8)", - "str": "error parsing genome_paths_file: Expecting value: line 1 column 1 (char 0)", - "unclosed_str": "error parsing genome_paths_file: Unterminated string starting at: line 1 column 9 (char 8)", - "null_key": "error parsing genome_paths_file: Expecting property name enclosed in double quotes: line 1 column 10 (char 9)", - # wrong format (array) - "array_of_objects": "genome_paths_file is not in the correct format", - "empty_array": "genome_paths_file is not in the correct format", - # no data - "empty_object": "no valid data found in genome_paths_file", -} - -err_types = {"empty_array": TypeError, "array_of_objects": TypeError} - - -@pytest.mark.parametrize( - "params", - [ - pytest.param( - { - "err_msg": format_errors[err_id], - "input": err_id, - }, - id=err_id, - ) - for err_id in format_errors - ], -) -def test_get_genome_paths_invalid_format( - params: dict[str, str], - test_data_dir: Path, - monkeypatch: pytest.MonkeyPatch, - json_test_strings: dict[str, Any], -) -> None: - """Test that invalid JSON structures throw an error.""" - full_path = test_data_dir / GPF_DIR / "valid.json" - - def mockreturn(_) -> dict[str, Any] | list[str | Any]: - return json.loads(json_test_strings[params["input"]]) - - # patch json.load to return the data structure in params - monkeypatch.setattr(json, "load", mockreturn) - - with pytest.raises( - err_types.get(params["input"], RuntimeError), - match=re.escape(params["err_msg"]), - ): - get_genome_paths(full_path) - - -error_list = { - "no_entry": [{"": {"this": "that"}}, 'No ID specified for entry {"this": "that"}', ValueError], - "invalid_entry_format_arr": [{"id": []}, "id: invalid entry format"], - "invalid_entry_format_str": [{"id": "some string"}, "id: invalid entry format"], - "invalid_entry_format_None": [{"id": None}, "id: invalid entry format"], - "no_valid_paths": [{"id": {}}, "id: no valid file types or paths found"], - "invalid_keys": [{"id": {"pap": 1, "pip": 2, "pop": 3}}, "id: invalid keys: pap, pip, pop"], -} - - -@pytest.mark.parametrize( - "params", - [ - pytest.param( - { - "err_msg": error_list[err_id][1], - "input": error_list[err_id][0], - }, - id=err_id, - ) - 
for err_id in error_list - ], -) -def test_get_genome_paths_valid_input_invalid_format( - params: dict[str, str], test_data_dir: Path, monkeypatch: pytest.MonkeyPatch -) -> None: - """Test that invalid JSON structures throw an error.""" - full_path = test_data_dir / GPF_DIR / "valid.json" - - def mockreturn(_) -> dict[str, Any] | list[str | Any]: - return params["input"] - - # patch json.load to return the data structure in params - monkeypatch.setattr(json, "load", mockreturn) - - with pytest.raises( - RuntimeError, - match=f"Please ensure that the genome_paths_file is in the correct format.\n\n{params['err_msg']}", - ): - get_genome_paths(full_path) - - -def test_get_genome_paths(test_data_dir: Path) -> None: - """Test that the genome paths file can be correctly parsed.""" - full_path = test_data_dir / GPF_DIR / "valid.json" - assert get_genome_paths(full_path) == { - "FW305-3-2-15-C-TSA1.1": { - "fna": "tests/data/FW305-3-2-15-C-TSA1/FW305-3-2-15-C-TSA1_scaffolds.fna", - "gff": "tests/data/FW305-3-2-15-C-TSA1/FW305-3-2-15-C-TSA1_genes.gff", - "protein": "tests/data/FW305-3-2-15-C-TSA1/FW305-3-2-15-C-TSA1_genes.faa", - }, - "FW305-C-112.1": { - "fna": "tests/data/FW305-C-112.1/FW305-C-112.1_scaffolds.fna", - "gff": "tests/data/FW305-C-112.1/FW305-C-112.1_genes.gff", - "protein": "tests/data/FW305-C-112.1/FW305-C-112.1_genes.faa", - }, - } diff --git a/tests/parsers/test_shared_identifiers.py b/tests/parsers/test_shared_identifiers.py deleted file mode 100644 index b76e9af..0000000 --- a/tests/parsers/test_shared_identifiers.py +++ /dev/null @@ -1,34 +0,0 @@ -import xml.etree.ElementTree as ET - -from cdm_data_loader_utils.parsers.shared_identifiers import parse_identifiers_generic - - -def test_parse_identifiers_generic_basic() -> None: - # - # P12345 - # Q99999 - # - ns = {"ns": "dummy"} - entry = ET.Element("entry") - - a1 = ET.SubElement(entry, "accession") - a1.text = "P12345" - a2 = ET.SubElement(entry, "accession") - a2.text = "Q99999" - - # Add namespace prefix to match xpath - a1.tag = "{dummy}accession" - a2.tag = "{dummy}accession" - - rows = parse_identifiers_generic( - entry=entry, - xpath="ns:accession", - prefix="UniProt", - ns=ns, - ) - - assert len(rows) == 2 - assert rows[0]["identifier"] == "UniProt:P12345" - assert rows[1]["identifier"] == "UniProt:Q99999" - assert rows[0]["source"] == "UniProt" - assert rows[0]["description"] == "UniProt accession" diff --git a/tests/parsers/test_uniprot.py b/tests/parsers/test_uniprot.py deleted file mode 100644 index 86ffca3..0000000 --- a/tests/parsers/test_uniprot.py +++ /dev/null @@ -1,744 +0,0 @@ -""" - -This file uses pytest to provide parameterized and functional tests for all major -UniProt parsing utility functions, ensuring correct parsing and transformation of -UniProt XML into structured CDM records. - -Coverage: - - generate_cdm_id: Stable CDM entity ID from accession - - build_datasource_record: Datasource provenance and metadata - - parse_identifiers: UniProt accessions to identifier records - - parse_names: Protein names (top-level, recommended, alternative) - - parse_protein_info: EC numbers, existence evidence, sequences, etc. - - parse_evidence_map: Evidence element mapping and supporting objects - - parse_associations: Biological and database associations (taxonomy, PDB, Rhea, ChEBI) - - parse_publications: Supported literature references (PMID, DOI, etc.) 
- - parse_uniprot_entry: Full record parsing, all fields together - -How to run in the terminal: - pytest tests/uniprot_refactor/test_uniprot_parsers.py - -""" - -import datetime -import json -import xml.etree.ElementTree as ET -from pathlib import Path - -import pytest - -from cdm_data_loader_utils.parsers.uniprot import ( - build_datasource_record, - parse_associations, - parse_cross_references, - parse_evidence_map, - parse_identifiers, - parse_names, - parse_protein_info, - save_datasource_record, -) - -NS_URI = "https://uniprot.org/uniprot" - - -@pytest.fixture( - params=[ - "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.xml.gz", - "http://example.org/uniprot_test.xml.gz", - ] -) -def xml_url(request): - return request.param - - -def test_build_datasource_record(xml_url): - record = build_datasource_record(xml_url) - - # ---- basic structure ---- - assert isinstance(record, dict) - - # ---- fixed fields ---- - assert record["name"] == "UniProt import" - assert record["source"] == "UniProt" - assert record["url"] == xml_url - assert record["version"] == 115 - - # ---- accessed field ---- - accessed = record.get("accessed") - assert accessed is not None - - parsed = datetime.datetime.fromisoformat(accessed) - assert parsed.tzinfo is not None - assert parsed.tzinfo == datetime.UTC - - -def test_save_datasource_record(tmp_path: Path, xml_url): - """ - save_datasource_record should: - - create output directory if missing - - write datasource.json - - return the same content that is written to disk - """ - output_dir = tmp_path / "output" - - # ---- call function ---- - result = save_datasource_record(xml_url, str(output_dir)) - - # ---- return value sanity ---- - assert isinstance(result, dict) - assert result["url"] == xml_url - assert result["source"] == "UniProt" - assert result["name"] == "UniProt import" - assert "accessed" in result - assert "version" in result - - # ---- file existence ---- - output_file = output_dir / "datasource.json" - assert output_file.exists() - assert output_file.is_file() - - # ---- file content correctness ---- - with open(output_file, encoding="utf-8") as f: - on_disk = json.load(f) - - assert on_disk == result - - -def make_entry(names=None, protein_names=None): - entry = ET.Element(f"{{{NS_URI}}}entry") - - # - for n in names or []: - e = ET.SubElement(entry, f"{{{NS_URI}}}name") - e.text = n - - # block - if protein_names: - protein = ET.SubElement(entry, f"{{{NS_URI}}}protein") - - for tag, logical in [ - ("recommendedName", "recommended"), - ("alternativeName", "alternative"), - ]: - if logical not in protein_names: - continue - - block = ET.SubElement(protein, f"{{{NS_URI}}}{tag}") - for xml_tag in ["fullName", "shortName"]: - val = protein_names[logical].get(xml_tag.replace("Name", "")) - if val: - e = ET.SubElement(block, f"{{{NS_URI}}}{xml_tag}") - e.text = val - - return entry - - -@pytest.mark.parametrize( - "entry_kwargs, cdm_id, expected", - [ - # Only - ( - {"names": ["ProteinA"]}, - "cdm_1", - { - ("ProteinA", "UniProt entry name"), - }, - ), - # entry name + recommended full name - ( - { - "names": ["ProteinB"], - "protein_names": { - "recommended": {"full": "Rec Full B", "short": None}, - }, - }, - "cdm_2", - { - ("ProteinB", "UniProt entry name"), - ("Rec Full B", "UniProt recommended full name"), - }, - ), - # everything - ( - { - "names": ["ProteinC"], - "protein_names": { - "recommended": {"full": "Rec Full C", "short": "Rec Short C"}, - "alternative": {"full": "Alt Full C", "short": 
"Alt Short C"}, - }, - }, - "cdm_3", - { - ("ProteinC", "UniProt entry name"), - ("Rec Full C", "UniProt recommended full name"), - ("Rec Short C", "UniProt recommended short name"), - ("Alt Full C", "UniProt alternative full name"), - ("Alt Short C", "UniProt alternative short name"), - }, - ), - ], -) -def test_parse_names_parametrized(entry_kwargs, cdm_id, expected): - entry = make_entry(**entry_kwargs) - - rows = parse_names(entry, cdm_id) - - # ---- row count ---- - assert len(rows) == len(expected) - - # ---- content ---- - observed = {(r["name"], r["description"]) for r in rows} - assert observed == expected - - # ---- entity_id and source ---- - for r in rows: - assert r["entity_id"] == cdm_id - assert r["source"] == "UniProt" - - -@pytest.mark.parametrize( - "build_entry, cdm_id, expected", - [ - # -------------------------------------------------- - # Empty entry -> None - # -------------------------------------------------- - ( - lambda: ET.Element(f"{{{NS_URI}}}entry"), - "cdm_1", - None, - ), - # -------------------------------------------------- - # Only EC numbers - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement( - ET.SubElement( - ET.SubElement(entry, f"{{{NS_URI}}}protein"), - f"{{{NS_URI}}}recommendedName", - ), - f"{{{NS_URI}}}ecNumber", - ).__setattr__("text", "1.1.1.1"), - entry, - )[1] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_2", - { - "ec_numbers": "1.1.1.1", - }, - ), - # -------------------------------------------------- - # Only sequence + entry modified - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - entry.set("modified", "2024-01-01"), - ET.SubElement( - entry, - f"{{{NS_URI}}}sequence", - { - "length": "100", - "mass": "12345", - "checksum": "ABC", - "version": "2", - }, - ).__setattr__("text", "MKTIIALSY"), - entry, - )[2] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_3", - { - "length": "100", - "mass": "12345", - "checksum": "ABC", - "sequence_version": "2", - "sequence": "MKTIIALSY", - "entry_modified": "2024-01-01", - }, - ), - # -------------------------------------------------- - # Everything - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - entry.set("modified", "2024-02-02"), - # protein + EC - ET.SubElement( - ET.SubElement( - ET.SubElement(entry, f"{{{NS_URI}}}protein"), - f"{{{NS_URI}}}recommendedName", - ), - f"{{{NS_URI}}}ecNumber", - ).__setattr__("text", "3.5.4.4"), - # proteinExistence - ET.SubElement( - entry, - f"{{{NS_URI}}}proteinExistence", - {"type": "evidence at protein level"}, - ), - # sequence - ET.SubElement( - entry, - f"{{{NS_URI}}}sequence", - { - "length": "250", - "mass": "99999", - "checksum": "XYZ", - "modified": "2023-12-01", - "version": "1", - }, - ).__setattr__("text", "MADEUPSEQUENCE"), - entry, - )[4] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_4", - { - "ec_numbers": "3.5.4.4", - "protein_id": "cdm_4", - "evidence_for_existence": "evidence at protein level", - "length": "250", - "mass": "99999", - "checksum": "XYZ", - "modified": "2023-12-01", - "sequence_version": "1", - "sequence": "MADEUPSEQUENCE", - "entry_modified": "2024-02-02", - }, - ), - ], -) -def test_parse_protein_info(build_entry, cdm_id, expected): - entry = build_entry() - - result = parse_protein_info(entry, cdm_id) - - if expected is None: - assert result is None - else: - assert isinstance(result, dict) - assert result == expected - - -@pytest.mark.parametrize( - "build_xml, expected", - [ - # 
-------------------------------------------------- - # No evidence elements - # -------------------------------------------------- - ( - lambda: ET.Element(f"{{{NS_URI}}}entry"), - {}, - ), - # -------------------------------------------------- - # Evidence without key - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement(entry, f"{{{NS_URI}}}evidence", {"type": "ECO:0000269"}), - entry, - )[1] - )(ET.Element(f"{{{NS_URI}}}entry")), - {}, - ), - # -------------------------------------------------- - # Evidence with key, no source - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement( - entry, - f"{{{NS_URI}}}evidence", - {"key": "1", "type": "ECO:0000313"}, - ), - entry, - )[1] - )(ET.Element(f"{{{NS_URI}}}entry")), - { - "1": { - "evidence_type": "ECO:0000313", - } - }, - ), - # -------------------------------------------------- - # Evidence with PUBMED with other refs - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - lambda ev: ( - ET.SubElement( - ET.SubElement(ev, f"{{{NS_URI}}}source"), - f"{{{NS_URI}}}dbReference", - {"type": "PubMed", "id": "12345"}, - ), - ET.SubElement( - ET.SubElement(ev, f"{{{NS_URI}}}source"), - f"{{{NS_URI}}}dbReference", - {"type": "GO", "id": "GO:0008150"}, - ), - entry, - )[2] - )( - ET.SubElement( - entry, - f"{{{NS_URI}}}evidence", - {"key": "E2", "type": "ECO:0000269"}, - ) - ) - )(ET.Element(f"{{{NS_URI}}}entry")), - { - "E2": { - "evidence_type": "ECO:0000269", - "publications": ["PMID:12345"], - } - }, - ), - ], -) -def test_parse_evidence_map_parametrized(build_xml, expected): - entry = build_xml() - result = parse_evidence_map(entry) - - assert isinstance(result, dict) - assert result == expected - - -@pytest.mark.parametrize( - "build_xml, cdm_id, evidence_map, expected", - [ - # -------------------------------------------------- - # Taxonomy association only - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement( - ET.SubElement(entry, f"{{{NS_URI}}}organism"), - f"{{{NS_URI}}}dbReference", - {"type": "NCBI Taxonomy", "id": "1234"}, - ), - entry, - )[1] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_1", - {}, - [ - { - "subject": "cdm_1", - "object": "NCBITaxon:1234", - "predicate": "in_taxon", - } - ], - ), - # -------------------------------------------------- - # Catalytic activity with evidence - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - lambda comment: ( - lambda reaction: ( - ET.SubElement( - reaction, - f"{{{NS_URI}}}dbReference", - {"type": "Rhea", "id": "RHEA:12345"}, - ), - entry, - )[1] - )( - ET.SubElement( - comment, - f"{{{NS_URI}}}reaction", - {"evidence": "E1"}, - ) - ) - )( - ET.SubElement( - entry, - f"{{{NS_URI}}}comment", - {"type": "catalytic activity"}, - ) - ) - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_2", - { - "E1": { - "evidence_type": "ECO:0000269", - "publications": ["PMID:12345"], - } - }, - [ - { - "subject": "cdm_2", - "predicate": "catalyzes", - "object": "Rhea:RHEA:12345", - "evidence_type": "ECO:0000269", - "publications": ["PMID:12345"], - } - ], - ), - # -------------------------------------------------- - # Cofactor association - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - lambda comment: ( - ET.SubElement( - ET.SubElement( - comment, - f"{{{NS_URI}}}cofactor", - ), - f"{{{NS_URI}}}dbReference", - {"type": "ChEBI", "id": 
"CHEBI:15377"}, - ), - entry, - )[1] - )( - ET.SubElement( - entry, - f"{{{NS_URI}}}comment", - {"type": "cofactor"}, - ) - ) - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_3", - {}, - [ - { - "subject": "cdm_3", - "predicate": "requires_cofactor", - "object": "ChEBI:CHEBI:15377", - } - ], - ), - ], -) -def test_parse_associations_parametrized(build_xml, cdm_id, evidence_map, expected): - entry = build_xml() - - result = parse_associations(entry, cdm_id, evidence_map) - - assert isinstance(result, list) - assert result == expected - - -@pytest.mark.parametrize( - "build_xml, cdm_id, expected", - [ - # -------------------------------------------------- - # No dbReference - # -------------------------------------------------- - ( - lambda: ET.Element(f"{{{NS_URI}}}entry"), - "cdm_1", - [], - ), - # -------------------------------------------------- - # dbReference with CURIE id - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement( - entry, - f"{{{NS_URI}}}dbReference", - {"type": "GO", "id": "GO:0008150"}, - ), - entry, - )[1] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_2", - [ - { - "entity_id": "cdm_2", - "xref_type": "GO", - "xref_value": "GO:0008150", - "xref": "GO:0008150", - } - ], - ), - # -------------------------------------------------- - # dbReference without CURIE (prefix) - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement( - entry, - f"{{{NS_URI}}}dbReference", - {"type": "CDD", "id": "cd04253"}, - ), - entry, - )[1] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_3", - [ - { - "entity_id": "cdm_3", - "xref_type": "CDD", - "xref_value": "cd04253", - "xref": "CDD:cd04253", - } - ], - ), - # -------------------------------------------------- - # Mixed dbReferences - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement( - entry, - f"{{{NS_URI}}}dbReference", - {"type": "GO", "id": "GO:0003674"}, - ), - ET.SubElement( - entry, - f"{{{NS_URI}}}dbReference", - {"type": "PDB", "id": "1ABC"}, - ), - entry, - )[2] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_4", - [ - { - "entity_id": "cdm_4", - "xref_type": "GO", - "xref_value": "GO:0003674", - "xref": "GO:0003674", - }, - { - "entity_id": "cdm_4", - "xref_type": "PDB", - "xref_value": "1ABC", - "xref": "PDB:1ABC", - }, - ], - ), - # -------------------------------------------------- - # Missing type or id - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement( - entry, - f"{{{NS_URI}}}dbReference", - {"type": "GO"}, # missing id - ), - ET.SubElement( - entry, - f"{{{NS_URI}}}dbReference", - {"id": "123"}, # missing type - ), - entry, - )[2] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_5", - [], - ), - ], -) -def test_parse_cross_references_parametrized(build_xml, cdm_id, expected): - entry = build_xml() - - result = parse_cross_references(entry, cdm_id) - - assert isinstance(result, list) - assert result == expected - - -@pytest.mark.parametrize( - "build_xml, cdm_id, expected", - [ - # -------------------------------------------------- - # No accession - # -------------------------------------------------- - ( - lambda: ET.Element(f"{{{NS_URI}}}entry"), - "cdm_1", - [], - ), - # -------------------------------------------------- - # Single accession - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement(entry, f"{{{NS_URI}}}accession").__setattr__("text", "P12345"), - entry, - )[1] 
- )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_2", - [ - { - "entity_id": "cdm_2", - "identifier": "UniProt:P12345", - "source": "UniProt", - "description": "UniProt accession", - } - ], - ), - # -------------------------------------------------- - # Multiple accessions - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement(entry, f"{{{NS_URI}}}accession").__setattr__("text", "Q11111"), - ET.SubElement(entry, f"{{{NS_URI}}}accession").__setattr__("text", "Q22222"), - entry, - )[2] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_3", - [ - { - "entity_id": "cdm_3", - "identifier": "UniProt:Q11111", - "source": "UniProt", - "description": "UniProt accession", - }, - { - "entity_id": "cdm_3", - "identifier": "UniProt:Q22222", - "source": "UniProt", - "description": "UniProt accession", - }, - ], - ), - # -------------------------------------------------- - # parse_identifiers_generic already sets source/description → setdefault - # -------------------------------------------------- - ( - lambda: ( - lambda entry: ( - ET.SubElement(entry, f"{{{NS_URI}}}accession").__setattr__("text", "A0A000"), - entry, - )[1] - )(ET.Element(f"{{{NS_URI}}}entry")), - "cdm_4", - [ - { - "entity_id": "cdm_4", - "identifier": "UniProt:A0A000", - "source": "UniProt", # remains - "description": "UniProt accession", # remains - } - ], - ), - ], -) -def test_parse_identifiers_parametrized(build_xml, cdm_id, expected): - entry = build_xml() - - result = parse_identifiers(entry, cdm_id) - - assert isinstance(result, list) - assert result == expected diff --git a/tests/parsers/test_uniref.py b/tests/parsers/test_uniref.py deleted file mode 100644 index 9ca5360..0000000 --- a/tests/parsers/test_uniref.py +++ /dev/null @@ -1,318 +0,0 @@ -import os -import sys - -sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) - -import gzip -import tempfile -import xml.etree.ElementTree as ET -from datetime import datetime, timezone -import pytest - -from cdm_data_loader_utils.parsers.uniref import ( - cdm_entity_id, - get_timestamps, - extract_cluster, - get_accession_and_seed, - add_cluster_members, - extract_cross_refs, - parse_uniref_xml, -) - -NS = {"ns": "http://uniprot.org/uniref"} - - -# --------------------------------------------------------- -# cdm_entity_id -# --------------------------------------------------------- -@pytest.mark.parametrize( - "value, should_raise", - [ - ("A0A009HJL9", False), - ("UniRef100_A0A009HJL9", False), - ("", True), - (None, True), - ], -) -def test_cdm_entity_id(value, should_raise): - if should_raise: - with pytest.raises(ValueError): - cdm_entity_id(value) - else: - out = cdm_entity_id(value) - assert isinstance(out, str) - assert out.startswith("CDM:") - - -# --------------------------------------------------------- -# get_timestamps -# --------------------------------------------------------- -@pytest.mark.parametrize( - "uniref_id, existing, now, expect_created_same_as_updated", - [ - ( - "UniRef100_A", - {"UniRef100_A": "2024-01-01T00:00:00+00:00"}, - datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), - False, - ), - ( - "UniRef100_B", - {}, - datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), - True, - ), - ( - "UniRef100_C", - {}, - None, - True, - ), - ], -) -def test_get_timestamps(uniref_id, existing, now, expect_created_same_as_updated): - updated, created = get_timestamps(uniref_id, existing, now) - - assert isinstance(updated, str) - assert isinstance(created, str) - assert updated.endswith("+00:00") - - if 
expect_created_same_as_updated: - assert updated == created - else: - assert updated != created - - -@pytest.mark.parametrize("bad_id", ["", None]) -def test_get_timestamps_rejects_empty_uniref_id(bad_id): - with pytest.raises(ValueError): - get_timestamps(bad_id, {}, None) - - -# --------------------------------------------------------- -# add_cluster_members -# --------------------------------------------------------- -@pytest.mark.parametrize( - "repr_xml, member_xmls, expected_count", - [ - ( - """ - - - - - """, - [ - """ - - - - """, - """ - - - - """, - ], - 3, - ), - ( - None, - [ - """ - - - - """, - ], - 1, - ), - (None, [], 0), - ], -) -def test_add_cluster_members(repr_xml, member_xmls, expected_count): - cluster_id = "CDM_CLUSTER" - repr_db = ET.fromstring(repr_xml) if repr_xml else None - - entry = ET.Element("{http://uniprot.org/uniref}entry") - for m in member_xmls: - mem = ET.SubElement(entry, "{http://uniprot.org/uniref}member") - mem.append(ET.fromstring(m)) - - rows = [] - add_cluster_members(cluster_id, repr_db, entry, rows, NS) - - assert len(rows) == expected_count - for r in rows: - assert r[0] == cluster_id - assert r[1].startswith("CDM:") - assert r[4] == "1.0" - - -# --------------------------------------------------------- -# extract_cluster -# --------------------------------------------------------- -@pytest.mark.parametrize( - "xml_str, uniref_id, expected_name", - [ - ( - "Test Cluster Name", - "UniRef100_A", - "Test Cluster Name", - ), - ( - "", - "UniRef100_B", - "UNKNOWN", - ), - ], -) -def test_extract_cluster(xml_str, uniref_id, expected_name): - elem = ET.fromstring(xml_str) - - cluster_id, name = extract_cluster(elem, NS, uniref_id) - - # ---- cluster_id checks ---- - assert isinstance(cluster_id, str) - assert cluster_id.startswith("CDM:") - - # ---- name checks ---- - assert name == expected_name - - -@pytest.mark.parametrize( - "xml_str, expected_acc, expected_is_seed", - [ - # accession + isSeed=true - ( - """ - - - - - """, - "A0A009HJL9", - True, - ), - # accession only - ( - """ - - - - """, - "A0A241V597", - False, - ), - # no accession - ( - """ - - - - """, - None, - False, - ), - # dbref is None - ( - None, - None, - False, - ), - ], -) -def test_get_accession_and_seed(xml_str, expected_acc, expected_is_seed): - dbref = ET.fromstring(xml_str) if xml_str else None - - acc, is_seed = get_accession_and_seed(dbref, NS) - - assert acc == expected_acc - assert is_seed == expected_is_seed - - -# --------------------------------------------------------- -# extract_cross_refs -# --------------------------------------------------------- -@pytest.mark.parametrize( - "props, expected", - [ - ( - [ - ("UniProtKB accession", "A0A1"), - ("UniRef90 ID", "UniRef90_X"), - ("UniParc ID", "UPI0001"), - ], - { - ("UniRef90 ID", "UniRef90_X"), - ("UniParc ID", "UPI0001"), - }, - ), - ( - [ - ("UniProtKB accession", "A0A2"), - ], - set(), - ), - ], -) -def test_extract_cross_refs(props, expected): - dbref = ET.Element("{http://uniprot.org/uniref}dbReference", id="UniProtKB:A0A1") - - for k, v in props: - ET.SubElement( - dbref, - "{http://uniprot.org/uniref}property", - type=k, - value=v, - ) - - rows = [] - extract_cross_refs(dbref, rows, NS) - - got = {(t, v) for _, t, v in rows} - assert got == expected - - for entity_id, _, _ in rows: - assert entity_id is not None - assert isinstance(entity_id, str) - - -# --------------------------------------------------------- -# parse_uniref_xml -# --------------------------------------------------------- 
-@pytest.mark.parametrize("batch_size", [1, 2]) -def test_parse_uniref_xml_batch(batch_size): - xml = """ - - - A - - - - - - - - - B - - - - - - - - """.strip() - - with tempfile.TemporaryDirectory() as tmpdir: - gz_path = f"{tmpdir}/uniref_test.xml.gz" - with gzip.open(gz_path, "wb") as gz: - gz.write(xml.encode("utf-8")) - - result = parse_uniref_xml(gz_path, batch_size, {}) - - assert len(result["cluster_data"]) == batch_size - assert len(result["entity_data"]) == batch_size - assert len(result["cluster_member_data"]) == batch_size - assert len(result["cross_reference_data"]) in (0, batch_size) diff --git a/tests/parsers/test_xml_utils.py b/tests/parsers/test_xml_utils.py deleted file mode 100644 index fc6e3ba..0000000 --- a/tests/parsers/test_xml_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -import xml.etree.ElementTree as ET - -from cdm_data_loader_utils.parsers.xml_utils import ( - clean_dict, - get_attr, - get_text, - parse_db_references, -) - - -def test_get_text_and_get_attr_basic() -> None: - elem = ET.Element("tag", attrib={"id": "123"}) - elem.text = " hello " - - assert get_text(elem) == "hello" - assert get_text(None) is None - assert get_attr(elem, "id") == "123" - assert get_attr(elem, "missing") is None - - -def test_parse_db_references_pub_and_others() -> None: - ns = {"ns": "dummy"} - source = ET.Element("source") - db1 = ET.SubElement(source, "dbReference", attrib={"type": "PubMed", "id": "12345"}) - db2 = ET.SubElement(source, "dbReference", attrib={"type": "DOI", "id": "10.1000/xyz"}) - db3 = ET.SubElement(source, "dbReference", attrib={"type": "PDB", "id": "1ABC"}) - - db1.tag = "{dummy}dbReference" - db2.tag = "{dummy}dbReference" - db3.tag = "{dummy}dbReference" - - pubs, others = parse_db_references(source, ns) - - assert "PUBMED:12345" in pubs - assert "DOI:10.1000/xyz" in pubs - assert "PDB:1ABC" in others - - -def test_clean_dict_removes_nones_and_empty() -> None: - """Test that clean_dict removes None and empty values.""" - d = { - "a": 1, - "b": None, - "c": [], - "d": {}, - "e": "ok", - } - cleaned = clean_dict(d) - assert cleaned == {"a": 1, "e": "ok"} diff --git a/tests/utils/test_spark_delta.py b/tests/utils/test_spark_delta.py deleted file mode 100644 index 86abcc8..0000000 --- a/tests/utils/test_spark_delta.py +++ /dev/null @@ -1,592 +0,0 @@ -import logging -from collections.abc import Generator -from pathlib import Path -from typing import Any - -import pytest -from pyspark.sql import DataFrame, DataFrameWriter, Row, SparkSession - -from cdm_data_loader_utils.utils import spark_delta -from cdm_data_loader_utils.utils.spark_delta import ( - APPEND, - DEFAULT_APP_NAME, - DEFAULT_NAMESPACE, - ERROR, - ERROR_IF_EXISTS, - IGNORE, - OVERWRITE, - WRITE_MODE, - get_spark, - preview_or_skip, - write_delta, -) - -original_set_up_ws_fn = spark_delta.set_up_workspace - -SAVE_DIR = "spark.sql.warehouse.dir" -DEFAULT_WRITE_MODE = ERROR -DEFAULT_SAMPLE_DATA = {"a": "A1", "b": "B1"} -TENANT_NAME = "The_Breakers" - - -@pytest.fixture -def spark(tmp_path: Path) -> Generator[SparkSession, Any]: - """Generate a spark session with spark.sql.warehouse.dir set to the pytest temporary directory.""" - spark = get_spark("test_delta_app", local=True, delta_lake=True, override={SAVE_DIR: tmp_path}) - yield spark - spark.stop() - - -def gen_ns_save_dir(current_save_dir: str, namespace: str, tenant_name: str | None) -> tuple[str, str]: - """Generate the projected namespace and save directory, given a file path, a namespace, and a tenant name.""" - db_location = 
f"tenant/{tenant_name}/{namespace}.db" if tenant_name else f"user/some_user/{namespace}.db" - namespace = db_location.replace("/", "__").replace(".db", "") - save_dir = f"{current_save_dir.replace('file:', '')}/{db_location}" - return (namespace, save_dir) - - -def fake_create_namespace_if_not_exists( - spark: SparkSession, - namespace: str = "default", - append_target: bool = True, - tenant_name: str | None = None, - **kwargs, -) -> str: - """Mock create_namespace_if_not_exists without external calls.""" - current_save_dir = spark.conf.get(SAVE_DIR) - if not current_save_dir: - msg = f"Error setting up fixtures: {SAVE_DIR} not set" - raise ValueError(msg) - - if append_target: - delta_ns, db_location = gen_ns_save_dir(current_save_dir, namespace, tenant_name) - spark.sql(f"CREATE DATABASE IF NOT EXISTS {delta_ns} LOCATION '{db_location}'") - print(f"Namespace {delta_ns} is ready to use at location {db_location}.") - else: - delta_ns = namespace - spark.sql(f"CREATE DATABASE IF NOT EXISTS {delta_ns}") - print(f"Namespace {delta_ns} is ready to use.") - - return delta_ns - - -@pytest.fixture -def spark_db(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Generator[tuple[SparkSession, str, str], Any]: - """Provide a Spark session with a per-test warehouse dir and patched workspace setup.""" - # patch the create_namespace_if_not_exists function - monkeypatch.setattr( - "cdm_data_loader_utils.utils.spark_delta.create_namespace_if_not_exists", - fake_create_namespace_if_not_exists, - ) - - def set_up_test_workspace(*args, **kwargs) -> tuple[SparkSession, str]: - """Local override of set_up_workspace.""" - return original_set_up_ws_fn(*args, local=True, delta_lake=True, override={SAVE_DIR: str(tmp_path)}) - - monkeypatch.setattr(spark_delta, "set_up_workspace", set_up_test_workspace) - - _, save_dir = gen_ns_save_dir(str(tmp_path), DEFAULT_NAMESPACE, TENANT_NAME) - (spark, delta_ns) = spark_delta.set_up_workspace("test_delta_app", DEFAULT_NAMESPACE, TENANT_NAME) - - yield (spark, delta_ns, save_dir) - spark.stop() - - -@pytest.mark.parametrize("app_name", [None, "", "my_fave_app"]) -def test_get_spark(app_name: str | None, monkeypatch: pytest.MonkeyPatch) -> None: - """Test of the get_spark utility's ability to fill in the app name if not provided.""" - - def fake_get_spark_session(*args: str, **kwargs) -> str: - if app_name == "my_fave_app": - assert args[0] == app_name - else: - assert args[0] == DEFAULT_APP_NAME - return "fake spark session" - - monkeypatch.setattr( - "cdm_data_loader_utils.utils.spark_delta.get_spark_session", - fake_get_spark_session, - ) - - spark = get_spark(app_name) - assert spark == "fake spark session" - - -@pytest.mark.requires_spark -@pytest.mark.parametrize("app_name", [None, "", "my_fave_app"]) -def test_get_spark_live(app_name: str | None) -> None: - """Test of the get_spark utility's ability to fill in the app name if not provided. - - Runs against the live spark engine. 
- """ - spark = get_spark(app_name, local=True) - assert isinstance(spark, SparkSession) - assert spark.conf.get("spark.app.name") == "my_fave_app" if app_name == "my_fave_app" else DEFAULT_APP_NAME - - -@pytest.mark.parametrize("app_name", [None, "", "my_fave_app"]) -@pytest.mark.parametrize("tenant_name", [None, "", "some_tenant"]) -@pytest.mark.parametrize("namespace", [None, "", "some_namespace"]) -@pytest.mark.parametrize("data_dir", [None, "", "path/to/ws"]) -def test_set_up_workspace_defaults( - app_name: str | None, - tenant_name: str | None, - namespace: str | None, - data_dir: str | None, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Check the default values when setting up a workspace.""" - - def fake_get_spark_session(*args: str, **kwargs) -> str: - assert args[0] == app_name if app_name else DEFAULT_APP_NAME - return "spark session" - - def fake_create_ns(*args, **kwargs) -> str: - assert args[0] == "spark session" - if namespace: - assert args[1] == namespace - else: - assert args[1] == DEFAULT_NAMESPACE - assert kwargs["tenant_name"] == tenant_name - return "delta namespace" - - monkeypatch.setattr( - "cdm_data_loader_utils.utils.spark_delta.get_spark_session", - fake_get_spark_session, - ) - - monkeypatch.setattr( - "cdm_data_loader_utils.utils.spark_delta.create_namespace_if_not_exists", - fake_create_ns, - ) - - if data_dir: - with pytest.raises(NotImplementedError, match="The data_dir parameter has not been implemented\\."): - spark_delta.set_up_workspace(app_name, namespace, tenant_name, data_dir) - return - - spark, delta_ns = spark_delta.set_up_workspace(app_name, namespace, tenant_name, data_dir) - assert spark == "spark session" - assert delta_ns == "delta namespace" - - -@pytest.mark.requires_spark -@pytest.mark.parametrize("tenant_name", [pytest.param(_, id=f"tenant_{_}") for _ in [None, "some_tenant"]]) -@pytest.mark.parametrize("namespace", [pytest.param(_, id=f"ns_{_}") for _ in [None, "some_namespace"]]) -def test_set_up_workspace_creates_database( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, - tenant_name: str | None, - namespace: str | None, -) -> None: - """Test that setting up a workspace creates the appropriate namespace. - - Mimics the functionality of BERDL's `create_namespace_if_not_exists`. 
- """ - app_name = "test_app" - # expected delta_ns, according to the namespace and tenant_name arguments - delta_ns = { - # namespace - None: { - # tenant - None: f"user__some_user__{DEFAULT_NAMESPACE}", - "some_tenant": f"tenant__some_tenant__{DEFAULT_NAMESPACE}", - }, - "some_namespace": { - None: "user__some_user__some_namespace", - "some_tenant": "tenant__some_tenant__some_namespace", - }, - } - - spark = get_spark(app_name, local=True) - expected = delta_ns[namespace][tenant_name] - assert not spark.catalog.databaseExists(expected) - spark.stop() - - def fake_create_namespace_if_not_exists( - spark: SparkSession, namespace: str = "default", append_target: bool = True, tenant_name: str | None = None - ) -> str: - """Mock create_namespace_if_not_exists without external calls.""" - namespace = f"tenant__{tenant_name}__{namespace}" if tenant_name else f"user__some_user__{namespace}" - assert not spark.catalog.databaseExists(namespace) - spark.sql(f"CREATE DATABASE IF NOT EXISTS {namespace}") - return namespace - - # patch the create_namespace_if_not_exists function - monkeypatch.setattr( - "cdm_data_loader_utils.utils.spark_delta.create_namespace_if_not_exists", - fake_create_namespace_if_not_exists, - ) - - def set_up_test_workspace(*args, **kwargs) -> tuple[SparkSession, str]: - """Local override of set_up_workspace.""" - return original_set_up_ws_fn(*args, local=True, delta_lake=True, override={SAVE_DIR: str(tmp_path)}) - - # patch the set_up_workspace function to add in the various extra kwargs for local use - monkeypatch.setattr(spark_delta, "set_up_workspace", set_up_test_workspace) - - # create a spark session and ensure that the appropriate database has been created - (spark, delta_ns) = spark_delta.set_up_workspace(app_name, namespace, tenant_name) - assert expected == delta_ns - assert spark.catalog.databaseExists(expected) - - -@pytest.mark.requires_spark -@pytest.mark.parametrize("dataframe", [None, [1, 2, 3], {}, True]) -def test_write_delta_no_data( - spark: SparkSession, - dataframe: DataFrame | bool | list | dict | None, # noqa: FBT001 - caplog: pytest.LogCaptureFixture, -) -> None: - """Ensure that the appropriate message is logged if there is no data to save.""" - if isinstance(dataframe, bool): - dataframe = spark.createDataFrame([], "name: string, age: int").show() - - output = write_delta(spark, dataframe, "what", "ever", DEFAULT_WRITE_MODE) # type: ignore - assert output is None - assert len(caplog.records) == 1 - for record in caplog.records: - assert record.levelno == logging.WARNING - assert record.message == "No data to write to what.ever" - - -@pytest.mark.requires_spark -@pytest.mark.parametrize("mode", ["some", "mode", 123, None, "whatever"]) -def test_write_delta_invalid_write_mode(spark: SparkSession, mode: str, caplog: pytest.LogCaptureFixture) -> None: - """Ensure that an error is logged if an invalid write mode is supplied.""" - error_msg = f"Invalid mode supplied for writing delta table: {mode}" - with pytest.raises(ValueError, match=error_msg): - write_delta(spark, {}, "what", "ever", mode) - - assert len(caplog.records) == 1 - for record in caplog.records: - assert record.levelno == logging.ERROR - assert record.message == error_msg - - -def check_query_output(spark: SparkSession, db_table: str, expected: list[dict[str, Any]]) -> None: - """Check that the query output matches the expected output.""" - # ensure that the table exists - assert spark.catalog.tableExists(db_table) - # run the query - results = spark.sql(f"SELECT * FROM 
{db_table}").collect() - results_as_dict = [row.asDict() for row in results] - assert len(results) == len(expected) - # TODO: make this less clunky - for row in results_as_dict: - assert row in expected - for row in expected: - assert row in results_as_dict - - -def check_logger_output_successful_write(records: list[logging.LogRecord], db_table: str, mode: str, rows: int) -> None: - """Check that the logger has emitted the appropriate messages on a successful db write.""" - first_message = records[0] - assert f"Writing table {db_table} in mode {mode} (rows={rows})" in first_message.message - assert first_message.levelno == logging.INFO - last_message = records[-1] - assert f"Saved managed table {db_table} (rows={rows})" in last_message.message - assert last_message.levelno == logging.INFO - - -def check_logger_output_successful_location_write( - records: list[logging.LogRecord], db_table: str, mode: str, rows: int -) -> None: - """Check that the logger has emitted the appropriate messages on a successful db write.""" - first_message = records[0] - assert f"Writing table {db_table} in mode {mode} (rows={rows})" in first_message.message - assert first_message.levelno == logging.INFO - last_message = records[-1] - assert f"Saved external table {db_table} (rows={rows}) to " in last_message.message - assert last_message.levelno == logging.INFO - - -def _check_saved_files(parquet_dir: Path) -> None: - # the directory where table data is expected to be stored - assert parquet_dir.is_dir() - - # use `sorted` to shortcut the iterator - parquet_files = sorted(parquet_dir.glob("*.parquet")) - assert parquet_files - - # the directory where delta table logs are expected to be stored - delta_log_dir = parquet_dir / "_delta_log" - assert delta_log_dir.is_dir() - log_files = sorted(delta_log_dir.glob("*.json")) - assert log_files - - -def check_saved_files(ns_save_dir: str | Path, table: str) -> None: - """Check that the file save operation has saved files in the expected location. - - :param ns_save_dir: save directory for a given namespace - :type ns_save_dir: str | Path - :param table: table name - :type table: str - """ - parquet_dir = Path(ns_save_dir) / table - _check_saved_files(parquet_dir) - - -def populate_db( - spark: SparkSession, caplog: pytest.LogCaptureFixture, delta_ns: str, table: str, ns_save_dir: str -) -> None: - """Populate a database, save it as a delta table, and register it with Hive.""" - db_table = f"{delta_ns}.{table}" - # save a very boring dataframe to a new db_table - write_delta( - spark=spark, - sdf=spark.createDataFrame([DEFAULT_SAMPLE_DATA]), - delta_ns=delta_ns, - table=table, - mode=DEFAULT_WRITE_MODE, - ) - assert spark.catalog.databaseExists(delta_ns) - assert spark.catalog.tableExists(db_table) - # check the db contents are as expected - check_query_output(spark, db_table, [DEFAULT_SAMPLE_DATA]) - # check there are saved files - check_saved_files(ns_save_dir, table) - # check the logger output - check_logger_output_successful_write(caplog.records, db_table, DEFAULT_WRITE_MODE, 1) - - -@pytest.mark.requires_spark -@pytest.mark.parametrize("mode", WRITE_MODE) -def test_write_delta_managed_table( - mode: str, - spark_db: tuple[SparkSession, str, str], - caplog: pytest.LogCaptureFixture, -) -> None: - """Test that a delta table is correctly written and registered in the Hive metastore. - - All valid write modes are tested. 
- """ - spark, delta_ns, ns_save_dir = spark_db - table = f"{mode}_example" - db_table = f"{delta_ns}.{table}" - - df = spark.createDataFrame([DEFAULT_SAMPLE_DATA]) - write_delta( - spark=spark, - sdf=df, - delta_ns=delta_ns, - table=table, - mode=mode, - ) - check_query_output(spark, db_table, [DEFAULT_SAMPLE_DATA]) - assert len(caplog.records) > 1 - check_logger_output_successful_write(caplog.records, db_table, mode, 1) - check_saved_files(ns_save_dir, table) - - -@pytest.mark.requires_spark -def test_write_delta_append_schema_merge( - spark_db: tuple[SparkSession, str, str], caplog: pytest.LogCaptureFixture -) -> None: - """Test adding data to an existing db using 'append' mode. - - Append mode should merge schemas, adding new columns without dropping existing data. - """ - mode = APPEND - spark, delta_ns, ns_save_dir = spark_db - table = f"{mode}_test" - db_table = f"{delta_ns}.{table}" - - populate_db(spark, caplog, delta_ns, table, ns_save_dir) - caplog.clear() - - # second write - two rows with three columns (SCHEMA CHANGE ALERT!) - new_rows = [{"a": "A2", "b": "B2", "c": "C2"}, {"a": "A3", "b": "B3", "c": "C3"}] - write_delta( - spark=spark, - sdf=spark.createDataFrame(new_rows), - delta_ns=delta_ns, - table=table, - mode=mode, - ) - check_saved_files(ns_save_dir, table) - check_logger_output_successful_write(caplog.records, db_table, mode, 2) - check_query_output(spark, db_table, [{**DEFAULT_SAMPLE_DATA, "c": None}, *new_rows]) - - -@pytest.mark.requires_spark -def test_write_delta_overwrite_schema( - spark_db: tuple[SparkSession, str, str], caplog: pytest.LogCaptureFixture -) -> None: - """Test adding data to an existing db using 'overwrite' mode. - - Overwrite mode should overwrite the original schema and replace any existing data. - """ - mode = OVERWRITE - spark, delta_ns, ns_save_dir = spark_db - table = f"{mode}_test" - db_table = f"{delta_ns}.{table}" - - populate_db(spark, caplog, delta_ns, table, ns_save_dir) - caplog.clear() - - # second write - two rows with three columns (SCHEMA CHANGE ALERT!) - write_delta( - spark=spark, - sdf=spark.createDataFrame([Row(x="X2", y="Y2", z="Z2"), Row(x="X3", y="Y3", z="Z3")]), - delta_ns=delta_ns, - table=table, - mode=mode, - ) - check_saved_files(ns_save_dir, table) - check_logger_output_successful_write(caplog.records, db_table, mode, 2) - check_query_output(spark, db_table, [{"x": f"X{n}", "y": f"Y{n}", "z": f"Z{n}"} for n in [2, 3]]) - - -@pytest.mark.requires_spark -@pytest.mark.parametrize("mode", [IGNORE, ERROR, ERROR_IF_EXISTS]) -def test_write_delta_ignore_error( - spark_db: tuple[SparkSession, str, str], caplog: pytest.LogCaptureFixture, mode: str -) -> None: - """Test adding data to an existing db using 'ignore' or either of the 'error' modes. - - "ignore" will write data if no table exists, but will not write anything if the table already exists. - "error" and "error_if_exists" would throw an error if the table already exists, but `write_delta` exits early. - """ - spark, delta_ns, ns_save_dir = spark_db - table = f"{mode}_test" - db_table = f"{delta_ns}.{table}" - - populate_db(spark, caplog, delta_ns, table, ns_save_dir) - caplog.clear() - - # second write - two rows with three columns (SCHEMA CHANGE ALERT!) 
- write_delta( - spark=spark, - sdf=spark.createDataFrame([Row(x="X2", y="Y2", z="Z2"), Row(x="X3", y="Y3", z="Z3")]), - delta_ns=delta_ns, - table=table, - mode=mode, - ) - check_saved_files(ns_save_dir, table) - last_logger_message = caplog.records[-1] - assert last_logger_message.levelno == logging.WARNING - assert ( - last_logger_message.message - == f"Database table {db_table} already exists and writer is set to {mode} mode, so no data would be written. Aborting." - ) - # check the db contents - check_query_output(spark, db_table, [DEFAULT_SAMPLE_DATA]) - - -@pytest.mark.requires_spark -def test_write_delta_raise_error( - spark_db: tuple[SparkSession, str, str], caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch -) -> None: - """Ensure that errors are handled gracefully if something terrible happens during saveAsFile.""" - spark, delta_ns, _ = spark_db - table = "error_handling" - db_table = f"{delta_ns}.{table}" - - def save_as_oh_crap(*args, **kwargs) -> None: - """Local override of set_up_workspace.""" - msg = "Oh crap!" - raise RuntimeError(msg) - - monkeypatch.setattr(DataFrameWriter, "saveAsTable", save_as_oh_crap) - - with pytest.raises(Exception, match="Oh crap!"): - write_delta( - spark=spark, - sdf=spark.createDataFrame([Row(x=2, y=3)]), - delta_ns=delta_ns, - table=table, - mode=DEFAULT_WRITE_MODE, - ) - last_log_record = caplog.records[-1] - assert last_log_record.levelno == logging.ERROR - assert last_log_record.message == f"Error writing managed table {db_table}" - - -@pytest.mark.requires_spark -@pytest.mark.parametrize("mode", WRITE_MODE) -def test_write_delta_uninited_namespace( - mode: str, - spark_db: tuple[SparkSession, str, str], - caplog: pytest.LogCaptureFixture, -) -> None: - """Test that a namespace that has not been registered throws an error.""" - spark, _delta_ns, _ns_save_dir = spark_db - table = f"{mode}_example" - err_msg = "Could not find an appropriate base directory for saving data." - df = spark.createDataFrame([DEFAULT_SAMPLE_DATA]) - with pytest.raises(RuntimeError, match=err_msg): - write_delta( - spark=spark, - sdf=df, - delta_ns="namespace_I_just_made_up", - table=table, - mode=mode, - ) - - assert caplog.records[-1].levelno == logging.ERROR - assert caplog.records[-1].message.startswith(err_msg) - - -@pytest.mark.skip("Not yet implemented") -@pytest.mark.requires_spark -@pytest.mark.parametrize("mode", [APPEND, OVERWRITE]) -def test_write_delta_existing_proposed_path_warning( - mode: str, spark_db: tuple[SparkSession, str, str], caplog: pytest.LogCaptureFixture, tmp_path: Path -) -> None: - """Test that a warning is emitted if there already exists data saved in another location.""" - spark, delta_ns, _ns_save_dir = spark_db - table = f"{mode}_example" - db_table = f"{delta_ns}.{table}" - err_msg = "Existing path does not match the projected base path for the table. Data written to this directory must be tracked manually." - save_dir = tmp_path / "save" / "some" / "data" / "here" - - # set up a save directory for the table - spark.sql(f"CREATE TABLE IF NOT EXISTS {db_table} USING DELTA LOCATION '{save_dir!s}'") - - df = spark.createDataFrame([DEFAULT_SAMPLE_DATA]) - write_delta( - spark=spark, - sdf=df, - delta_ns=delta_ns, - table=table, - mode=mode, - ) - assert caplog.records[0].levelno == logging.WARNING - assert caplog.records[0].message.startswith(err_msg) - - -# END write_delta tests. PHEW! 
- - -@pytest.mark.requires_spark -def test_preview_or_skip_existing( - spark_db: tuple[SparkSession, str, str], caplog: pytest.LogCaptureFixture, capsys: pytest.CaptureFixture -) -> None: - """Test the preview or skip function with an extant db.""" - spark, delta_ns, ns_save_dir = spark_db - table = "preview_test" - db_table = f"{delta_ns}.{table}" - populate_db(spark, caplog, delta_ns, table, ns_save_dir) - caplog.clear() - - preview_or_skip(spark, delta_ns, table) - - assert caplog.records[0].message == f"Preview for {db_table}:" - captured = capsys.readouterr().out - # N.b. this may be fragile if formatting of "show" statements changes - for k, v in DEFAULT_SAMPLE_DATA.items(): - assert f"|{k} " in captured - assert f"|{v} " in captured - - -@pytest.mark.requires_spark -def test_preview_or_skip_missing(spark: SparkSession, caplog: pytest.LogCaptureFixture) -> None: - """Test the preview or skip function with a missing db.""" - db = "missing" - table = "not_found" - preview_or_skip(spark, db, table) - - last_log_message = caplog.records[-1] - assert last_log_message.message == f"Table {db}.{table} not found. Skipping preview." diff --git a/uv b/uv new file mode 100644 index 0000000..e69de29
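Note on the deleted Spark/Delta coverage: test_spark_delta.py was the only round-trip exercise of get_spark, write_delta, and preview_or_skip. Until replacement tests land, a manual smoke check along the following lines can cover the basic write-and-preview path. This is a minimal sketch based only on the signatures visible in the removed tests; it assumes the cdm_data_loader_utils package is installed, that set_up_workspace can reach (or has a stubbed) namespace-creation service, and the app, namespace, and table names below are placeholders, not values used anywhere in this change.

# smoke_check_spark_delta.py -- hypothetical manual check, not part of this PR
from cdm_data_loader_utils.utils import spark_delta
from cdm_data_loader_utils.utils.spark_delta import OVERWRITE, preview_or_skip, write_delta

# set_up_workspace returns (SparkSession, resolved delta namespace), per the removed tests
spark, delta_ns = spark_delta.set_up_workspace("smoke_check", "scratch")

# one-row DataFrame mirroring DEFAULT_SAMPLE_DATA from the deleted suite
sdf = spark.createDataFrame([{"a": "A1", "b": "B1"}])

# write a managed Delta table, then preview it (the deleted tests covered both paths)
write_delta(spark=spark, sdf=sdf, delta_ns=delta_ns, table="smoke_table", mode=OVERWRITE)
preview_or_skip(spark, delta_ns, "smoke_table")

spark.stop()

Running this against a local Delta-enabled session should log a "Saved managed table" message followed by a table preview; a missing namespace should instead surface the "Could not find an appropriate base directory" error that the removed test_write_delta_uninited_namespace asserted on.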