From fe46179d0b908602343ae978468aa724b4263fc4 Mon Sep 17 00:00:00 2001 From: Jessica Nash Date: Fri, 20 Sep 2024 12:58:33 -0400 Subject: [PATCH 1/2] add api endpoint for retrieving single molecule by smiles --- backend/app/app/api/v2/endpoints/molecule.py | 156 ++++++++++++++----- 1 file changed, 119 insertions(+), 37 deletions(-) diff --git a/backend/app/app/api/v2/endpoints/molecule.py b/backend/app/app/api/v2/endpoints/molecule.py index f7e658e..87f2fda 100644 --- a/backend/app/app/api/v2/endpoints/molecule.py +++ b/backend/app/app/api/v2/endpoints/molecule.py @@ -50,6 +50,101 @@ def _pandas_to_buffer(df): return buffer +def _query_data(molecule_id, data_type, db): + + # Check for valid data type. + if data_type.lower() not in ["ml", "dft", "xtb", "xtb_ni"]: + raise HTTPException(status_code=400, detail="Invalid data type.") + + table_name = f"{data_type}_data" + query = text(f""" + SELECT t.*, m.SMILES + FROM {table_name} t + JOIN molecule m ON t.molecule_id = m.molecule_id + WHERE t.molecule_id = :molecule_id + """) + + stmt = query.bindparams(molecule_id=molecule_id) + + results = db.execute(stmt).fetchall() + + # Hacky way to get each row as a dictionary. + # do this to generalize for different data sets - column names may vary. + list_of_dicts = [row._asdict() for row in results] + + return list_of_dicts + +def _get_molecule_data_types(molecule_id, db): + """Get the data types (e.g. ML, DFT, xTB) available for a molecule. + + Parameters + ---------- + molecule_id : int + The molecule ID. + db : Session + The database session. + + Returns + ------- + dict + A dictionary with the data types available in the database. `True` or `False` values indicate whether the data type is available for the molecule. + + """ + + # First, fetch all table names that have 'data' in their name + query_tables = text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name LIKE '%data%';") + tables = db.execute(query_tables).fetchall() + + # Prepare a dictionary to store the results + results = {} + + # Loop through each table to check for the molecule_id + for table in tables: + table_name = table[0] + # Assuming the column storing molecule IDs is named 'molecule_id' in all tables. + # You might need to adjust this according to your database schema. + query_check = text(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE molecule_id = :molecule_id) AS exists;") + exists = db.execute(query_check, {"molecule_id": molecule_id}).fetchone()[0] + results[table_name] = exists + + return results + +def _build_multiple_molecule_query(molecule_ids_list, data_type): + + # Use pandas.read_sql_query to get the data. + table_name = f"{data_type}_data" + + # Generating a safe query with placeholders + placeholders = ', '.join([':id' + str(i) for i in range(len(molecule_ids_list))]) + query_parameters = {'id' + str(i): mid for i, mid in enumerate(molecule_ids_list)} + + query = text(f""" + SELECT t.*, m.SMILES + FROM {table_name} t + JOIN molecule m ON t.molecule_id = m.molecule_id + WHERE t.molecule_id IN ({placeholders}) + """) + + return query, query_parameters + +def _get_id_from_smiles(smiles, db): + + canonical_smiles = valid_smiles(smiles) + + query = text("""SELECT molecule_id FROM molecule WHERE canonical_smiles = :smiles""") + + stmt = query.bindparams(smiles=canonical_smiles) + + molecule_id = db.execute(stmt).fetchone() + + try: + molecule_id = molecule_id[0] + except TypeError: + molecule_id = None + + return molecule_id + + def valid_smiles(smiles): """Check to see if a smile string is valid to represent a molecule. @@ -99,21 +194,7 @@ async def get_data_types(db: Session = Depends(deps.get_db)): @router.get("/{molecule_id}/data_types", response_model=Any) async def get_molecule_data_types(molecule_id: int | str, db: Session = Depends(deps.get_db)): - # First, fetch all table names that have 'data' in their name - query_tables = text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name LIKE '%data%';") - tables = db.execute(query_tables).fetchall() - - # Prepare a dictionary to store the results - results = {} - - # Loop through each table to check for the molecule_id - for table in tables: - table_name = table[0] - # Assuming the column storing molecule IDs is named 'molecule_id' in all tables. - # You might need to adjust this according to your database schema. - query_check = text(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE molecule_id = :molecule_id) AS exists;") - exists = db.execute(query_check, {"molecule_id": molecule_id}).fetchone()[0] - results[table_name] = exists + results = _get_molecule_data_types(molecule_id, db) return results @@ -126,23 +207,28 @@ async def get_molecule_data(molecule_id: int | str, if data_type.lower() not in ["ml", "dft", "xtb", "xtb_ni"]: raise HTTPException(status_code=400, detail="Invalid data type.") - table_name = f"{data_type}_data" - query = text(f""" - SELECT t.*, m.SMILES - FROM {table_name} t - JOIN molecule m ON t.molecule_id = m.molecule_id - WHERE t.molecule_id = :molecule_id - """) - - stmt = query.bindparams(molecule_id=molecule_id) + result = _query_data(molecule_id, data_type, db) + + return result - results = db.execute(stmt).fetchall() +@router.get("/data", response_model=Any) +async def get_data(smiles: str, data_type: str="ml", db: Session = Depends(deps.get_db)): + """Get the data for a molecule by its SMILES string.""" - # Hacky way to get each row as a dictionary. - # do this to generalize for different data sets - column names may vary. - list_of_dicts = [row._asdict() for row in results] + molecule_id = _get_id_from_smiles(smiles, db) - return list_of_dicts + if molecule_id is None: + raise HTTPException(status_code=404, detail="Molecule not found") + + # Check that the molecule has the requested data type. + valid_datatypes = _get_molecule_data_types(molecule_id, db) + + if not valid_datatypes.get(f"{data_type}_data"): + raise HTTPException(status_code=404, detail="Data type not found for molecule") + + result = _query_data(molecule_id, data_type, db) + + return result @router.get("/data/export/batch") async def get_molecules_data(molecule_ids: str, @@ -172,14 +258,8 @@ async def get_molecules_data(molecule_ids: str, # Generating a safe query with placeholders placeholders = ', '.join([':id' + str(i) for i in range(len(molecule_ids_list))]) - query_parameters = {'id' + str(i): mid for i, mid in enumerate(molecule_ids_list)} - query = text(f""" - SELECT t.*, m.SMILES - FROM {table_name} t - JOIN molecule m ON t.molecule_id = m.molecule_id - WHERE t.molecule_id IN ({placeholders}) - """) + query, query_parameters = _build_multiple_molecule_query(molecule_ids_list, data_type) df = pd.read_sql_query(query, db.bind, params=query_parameters) @@ -201,12 +281,14 @@ async def get_molecules_data(molecule_ids: str, return response + + @router.get("/data/export/{molecule_id}") async def export_molecule_data(molecule_id: int | str, data_type: str="ml", db: Session = Depends(deps.get_db)): + """Export data for a single molecule as a CSV.""" - # Check for valid data type. if data_type.lower() not in ["ml", "dft", "xtb", "xtb_ni"]: raise HTTPException(status_code=400, detail="Invalid data type.") From 690cb8dc7ff040edc698c4a34a42142a654379f2 Mon Sep 17 00:00:00 2001 From: Jessica Nash Date: Fri, 20 Sep 2024 14:12:32 -0400 Subject: [PATCH 2/2] modify batch download to allow SMILES identifiers --- backend/app/app/api/v2/endpoints/molecule.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/backend/app/app/api/v2/endpoints/molecule.py b/backend/app/app/api/v2/endpoints/molecule.py index 87f2fda..7e5b0b9 100644 --- a/backend/app/app/api/v2/endpoints/molecule.py +++ b/backend/app/app/api/v2/endpoints/molecule.py @@ -231,14 +231,26 @@ async def get_data(smiles: str, data_type: str="ml", db: Session = Depends(deps. return result @router.get("/data/export/batch") -async def get_molecules_data(molecule_ids: str, +async def get_molecules_data(molecule_ids: Optional[str]=None, + molecule_smiles: Optional[str]=None, data_type: str="ml", return_type: str="csv", context: Optional[str]=None, db: Session = Depends(deps.get_db)): + if molecule_ids is None and molecule_smiles is None: + raise HTTPException(status_code=400, detail="No molecule IDs or SMILES provided.") + + if molecule_ids is not None and molecule_smiles is not None: + raise HTTPException(status_code=400, detail="Provide either molecule IDs or SMILES, not both.") + + if molecule_smiles: + molecule_smiles_list = [ x.strip() for x in molecule_smiles.split(",")] + molecule_ids_list = [_get_id_from_smiles(smiles, db) for smiles in molecule_smiles_list] + + if molecule_ids: + molecule_ids_list = [ x.strip() for x in molecule_ids.split(",")] - molecule_ids_list = [ x.strip() for x in molecule_ids.split(",")] first_molecule_id = molecule_ids_list[0] num_molecules = len(molecule_ids_list)