Skip to content
43 changes: 29 additions & 14 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ def _generate_bad_legs_error_message(set_vals, ligpair):
def _parse_raw_units(results: dict) -> list[tuple]:
# grab individual unit results from master results dict
# returns list of (estimate, uncertainty) tuples
list_of_pur = list(results['protocol_result']['data'].values())[0]
list_of_pur = list(results['protocol_result']['data'].values())

return [(pu['outputs']['unit_estimate'],
pu['outputs']['unit_estimate_error'])
# could add to each tuple pu[0]["source_key"] for ID
return [(pu[0]['outputs']['unit_estimate'],
pu[0]['outputs']['unit_estimate_error'])
for pu in list_of_pur]


Expand Down Expand Up @@ -178,10 +179,10 @@ def _get_ddgs(legs, error_on_missing=True):
return DDGs


def _write_ddg(legs, writer, allow_partial):
def _write_ddg(legs, writer, allow_partial): # unc
DDGs = _get_ddgs(legs, error_on_missing=not allow_partial)
writer.writerow(["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)",
"uncertainty (kcal/mol)"])
"uncertainty (kcal/mol)"]) # unc])
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
if DDGbind is not None:
DDGbind, bind_unc = format_estimate_uncertainty(DDGbind, bind_unc)
Expand All @@ -191,19 +192,19 @@ def _write_ddg(legs, writer, allow_partial):
writer.writerow([ligA, ligB, DDGhyd, hyd_unc])


def _write_raw(legs, writer, allow_partial=True):
writer.writerow(["leg", "ligand_i", "ligand_j", "DG(i->j) (kcal/mol)",
"MBAR uncertainty (kcal/mol)"])
def _write_raw(legs, writer, allow_partial=True): # *args?
writer.writerow(["leg", "repeat", "ligand_i", "ligand_j",
"DG(i->j) (kcal/mol)", "MBAR uncertainty (kcal/mol)"])

for ligpair, vals in sorted(legs.items()):
for simtype, repeats in sorted(vals.items()):
for m, u in repeats:
for rep, (m, u) in enumerate(repeats, 1):
if m is None:
m, u = 'NaN', 'NaN'
else:
m, u = format_estimate_uncertainty(m.m, u.m)

writer.writerow([simtype, *ligpair, m, u])
writer.writerow([simtype, rep, *ligpair, m, u])


def _write_dg_raw(legs, writer, allow_partial): # pragma: no-cover
Expand All @@ -218,7 +219,7 @@ def _write_dg_raw(legs, writer, allow_partial): # pragma: no-cover
writer.writerow([simtype, *ligpair, m, u])


def _write_dg_mle(legs, writer, allow_partial):
def _write_dg_mle(legs, writer, allow_partial): # unc
import networkx as nx
import numpy as np
from cinnabar.stats import mle
Expand Down Expand Up @@ -264,7 +265,7 @@ def _write_dg_mle(legs, writer, allow_partial):
MLEs.append((ligname, f, df))

writer.writerow(["ligand", "DG(MLE) (kcal/mol)",
"uncertainty (kcal/mol)"])
"uncertainty (kcal/mol)"]) # unc])
for ligA, DG, unc_DG in MLEs:
DG, unc_DG = format_estimate_uncertainty(DG, unc_DG)
writer.writerow([ligA, DG, unc_DG])
Expand Down Expand Up @@ -336,6 +337,9 @@ def gather(rootdir, output, report, allow_partial):
# 3) pair legs of simulations together into dict of dicts
legs = defaultdict(dict)

######## CHECK IF ALL RESULTS HAVE SAME # OF PROTOCOLUNITS?
# MBAR_errors = True

for result_fn in result_fns:
result = load_results(result_fn)
if result is None:
Expand All @@ -344,6 +348,8 @@ def gather(rootdir, output, report, allow_partial):
click.echo(f"WARNING: Calculations for {result_fn} did not finish successfully!",
err=True)



try:
names = get_names(result)
except KeyError:
Expand All @@ -353,8 +359,15 @@ def gather(rootdir, output, report, allow_partial):
except KeyError:
simtype = legacy_get_type(result_fn)

raw_units = _parse_raw_units(result)
######## CHECK IF ALL RESULTS HAVE SAME # OF PROTOCOLUNITS?
# if MBAR_errors and len(raw_units) > 1:
# MBAR_errors = False

if report.lower() == 'raw':
legs[names][simtype] = _parse_raw_units(result)
legs[names][simtype] = raw_units
elif len(raw_units) == 1:
legs[names][simtype] = raw_units[0]
else:
legs[names][simtype] = result['estimate'], result['uncertainty']

Expand All @@ -364,6 +377,8 @@ def gather(rootdir, output, report, allow_partial):
lineterminator="\n", # to exactly reproduce previous, prefer "\r\n"
)

# unc = "MBAR uncertainty (kcal/mol)" if MBAR_errors else "uncertainty (kcal/mol)"

# 5a) write out MLE values
# 5b) write out DDG values
# 5c) write out each leg
Expand All @@ -373,7 +388,7 @@ def gather(rootdir, output, report, allow_partial):
# 'dg-raw': _write_dg_raw,
'raw': _write_raw,
}[report.lower()]
writing_func(legs, writer, allow_partial)
writing_func(legs, writer, allow_partial) # , unc)


PLUGIN = OFECommandPlugin(
Expand Down