-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcli.py
More file actions
70 lines (54 loc) · 2.33 KB
/
cli.py
File metadata and controls
70 lines (54 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import click
import yaml
import sys
from chebifier.ensemble.base_ensemble import BaseEnsemble
from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble
@click.group()
def cli() -> None:
    """Command line interface for Chebifier."""
    # Intentionally empty: this function only serves as the root click group
    # that subcommands attach to via ``@cli.command()``. The docstring above
    # is what click shows as the group's help text.
# Registry of selectable ensembles: each key is a value accepted by the
# ``--ensemble-type`` option of ``predict`` and maps to the class that is
# instantiated with the parsed YAML configuration ("mv" is the default,
# described in the option help as Majority Voting).
ENSEMBLES = {
    "mv": BaseEnsemble,
    "wmv-ppvnpv": WMVwithPPVNPVEnsemble,
    "wmv-f1": WMVwithF1Ensemble,
}
@cli.command()
@click.argument('config_file', type=click.Path(exists=True))
@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
@click.option('--ensemble-type', '-e', type=click.Choice(list(ENSEMBLES)), default='mv', help='Type of ensemble to use (default: Majority Voting)')
def predict(config_file, smiles, smiles_file, output, ensemble_type):
    """Predict ChEBI classes for SMILES strings using an ensemble model.
    CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
    """
    # NOTE: click renders the docstring above as the command's --help text, so
    # developer-facing documentation is kept in comments instead.
    #
    # Parameters (all supplied by click from the command line):
    #   config_file   -- path to an existing YAML file; its parsed content is
    #                    passed as the single argument to the ensemble class.
    #   smiles        -- tuple of SMILES strings from repeated -s/--smiles.
    #   smiles_file   -- optional path to a text file with one SMILES per line;
    #                    blank lines are skipped, surrounding whitespace removed.
    #   output        -- optional path; when given, results are written there
    #                    as JSON instead of being printed.
    #   ensemble_type -- key into ENSEMBLES selecting the ensemble class.
    import json  # only needed for --output; kept local to the command

    # Collect the inputs first so that we can bail out before building the
    # ensemble when there is nothing to predict.
    smiles_list = list(smiles)
    if smiles_file:
        # Explicit encoding: do not depend on the platform's locale default.
        with open(smiles_file, 'r', encoding='utf-8') as f:
            # Strip every line and drop the ones that end up empty.
            smiles_list.extend(filter(None, map(str.strip, f)))

    if not smiles_list:
        click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
        return

    # Load configuration from the YAML file (safe_load: no arbitrary objects).
    with open(config_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Instantiate the selected ensemble and run it on all collected inputs.
    ensemble = ENSEMBLES[ensemble_type](config)
    predictions = ensemble.predict_smiles_list(smiles_list)

    if output:
        # Save as a JSON object keyed by SMILES string. A SMILES that occurs
        # more than once keeps only its last prediction (dict semantics).
        with open(output, 'w', encoding='utf-8') as f:
            json.dump(dict(zip(smiles_list, predictions)), f, indent=2)
        return

    # No output file requested: print one result section per input.
    # The loop variable is deliberately not called ``smiles`` to avoid
    # rebinding the function parameter of that name.
    for smiles_string, prediction in zip(smiles_list, predictions):
        click.echo(f"Result for: {smiles_string}")
        if prediction:
            click.echo(f" Predicted classes: {', '.join(map(str, prediction))}")
        else:
            click.echo(" No predictions")
# Run the click command group only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()