Skip to content

Commit 82905dd

Browse files
committed
Fix: formatting issues.
1 parent 5ef944c commit 82905dd

1 file changed

Lines changed: 32 additions & 69 deletions

File tree

src/llm/local_llm.py

Lines changed: 32 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
python extract_metrics.py path/to/text_file.txt --model llama3.1:8b
66
python extract_metrics.py path/to/text_file.txt --output-dir results/
77
8-
This script uses Ollama to extract structured data from preprocessed predator diet
8+
This script uses Ollama to extract structured data from preprocessed predator diet
99
surveys, including species name, study date, location, and stomach content data.
1010
"""
1111

@@ -21,40 +21,22 @@
2121

2222
class PredatorDietMetrics(BaseModel):
2323
"""Structured schema for extracted predator diet survey metrics."""
24-
25-
species_name: Optional[str] = Field(
26-
None,
27-
description="Scientific name of the predator species studied"
28-
)
29-
study_location: Optional[str] = Field(
30-
None,
31-
description="Geographic location where the study was conducted"
32-
)
33-
study_date: Optional[str] = Field(
34-
None,
35-
description="Year or date range when the study was conducted"
36-
)
37-
num_empty_stomachs: Optional[int] = Field(
38-
None,
39-
description="Number of predators with empty stomachs"
40-
)
41-
num_nonempty_stomachs: Optional[int] = Field(
42-
None,
43-
description="Number of predators with non-empty stomachs"
44-
)
45-
sample_size: Optional[int] = Field(
46-
None,
47-
description="Total number of predators surveyed"
48-
)
24+
25+
species_name: Optional[str] = Field(None, description="Scientific name of the predator species studied")
26+
study_location: Optional[str] = Field(None, description="Geographic location where the study was conducted")
27+
study_date: Optional[str] = Field(None, description="Year or date range when the study was conducted")
28+
num_empty_stomachs: Optional[int] = Field(None, description="Number of predators with empty stomachs")
29+
num_nonempty_stomachs: Optional[int] = Field(None, description="Number of predators with non-empty stomachs")
30+
sample_size: Optional[int] = Field(None, description="Total number of predators surveyed")
4931

5032

5133
def extract_metrics_from_text(text: str, model: str = "llama3.1:8b") -> PredatorDietMetrics:
5234
"""Extract structured metrics from text using Ollama.
53-
35+
5436
Args:
5537
text: Preprocessed text content from a scientific publication
5638
model: Name of the Ollama model to use
57-
39+
5840
Returns:
5941
PredatorDietMetrics object with extracted data
6042
"""
@@ -94,24 +76,24 @@ def extract_metrics_from_text(text: str, model: str = "llama3.1:8b") -> Predator
9476
model=model,
9577
format=PredatorDietMetrics.model_json_schema(),
9678
)
97-
79+
9880
metrics = PredatorDietMetrics.model_validate_json(response.message.content)
9981
return metrics
10082

10183

10284
def validate_and_calculate(metrics: dict) -> dict:
10385
"""Validate extracted metrics and calculate derived values.
104-
86+
10587
Args:
10688
metrics: Dictionary of extracted metrics
107-
89+
10890
Returns:
10991
Dictionary with validated metrics and calculated fraction_feeding
11092
"""
11193
empty = metrics.get("num_empty_stomachs")
11294
nonempty = metrics.get("num_nonempty_stomachs")
11395
sample = metrics.get("sample_size")
114-
96+
11597
# Validate and fix sample size if needed
11698
if empty is not None and nonempty is not None:
11799
calculated_sample = empty + nonempty
@@ -122,83 +104,64 @@ def validate_and_calculate(metrics: dict) -> dict:
122104
# LLM made an error, use calculated value
123105
metrics["sample_size"] = calculated_sample
124106
sample = calculated_sample
125-
107+
126108
# Calculate fraction of feeding predators
127109
fraction_feeding = None
128110
if nonempty is not None and sample is not None and sample > 0:
129111
fraction_feeding = round(nonempty / sample, 4)
130-
112+
131113
metrics["fraction_feeding"] = fraction_feeding
132-
114+
133115
return metrics
134116

135117

136118
def main():
137-
parser = argparse.ArgumentParser(
138-
description="Extract predator diet metrics from preprocessed text using LLM"
139-
)
140-
parser.add_argument(
141-
"text_file",
142-
type=str,
143-
help="Path to the preprocessed text file"
144-
)
145-
parser.add_argument(
146-
"--model",
147-
type=str,
148-
default="llama3.1:8b",
149-
help="Ollama model to use (default: llama3.1:8b)"
150-
)
151-
parser.add_argument(
152-
"--output-dir",
153-
type=str,
154-
default="data/results",
155-
help="Output directory for JSON results (default: data/results)"
156-
)
157-
119+
parser = argparse.ArgumentParser(description="Extract predator diet metrics from preprocessed text using LLM")
120+
parser.add_argument("text_file", type=str, help="Path to the preprocessed text file")
121+
parser.add_argument("--model", type=str, default="llama3.1:8b", help="Ollama model to use (default: llama3.1:8b)")
122+
parser.add_argument("--output-dir", type=str, default="data/results", help="Output directory for JSON results (default: data/results)")
123+
158124
args = parser.parse_args()
159-
125+
160126
# Load text file
161127
text_path = Path(args.text_file)
162128
if not text_path.exists():
163129
print(f"[ERROR] File not found: {text_path}", file=sys.stderr)
164130
sys.exit(1)
165-
131+
166132
try:
167133
with open(text_path, "r", encoding="utf-8") as f:
168134
text = f.read()
169135
except Exception as e:
170136
print(f"[ERROR] Failed to read file: {e}", file=sys.stderr)
171137
sys.exit(1)
172-
138+
173139
# Extract metrics
174140
print(f"Extracting metrics from {text_path.name}...", file=sys.stderr)
175141
try:
176142
metrics = extract_metrics_from_text(text, model=args.model)
177143
except Exception as e:
178144
print(f"[ERROR] Extraction failed: {e}", file=sys.stderr)
179145
sys.exit(1)
180-
146+
181147
# Validate and calculate derived metrics
182148
metrics_dict = metrics.model_dump()
183149
metrics_dict = validate_and_calculate(metrics_dict)
184-
150+
185151
# Prepare output
186-
result = {
187-
"source_file": text_path.name,
188-
"metrics": metrics_dict
189-
}
190-
152+
result = {"source_file": text_path.name, "metrics": metrics_dict}
153+
191154
# Generate output filename: input_name_results.json
192155
output_filename = text_path.stem + "_results.json"
193156
output_path = Path(args.output_dir) / output_filename
194-
157+
195158
# Save results
196159
output_path.parent.mkdir(parents=True, exist_ok=True)
197160
with open(output_path, "w", encoding="utf-8") as f:
198161
json.dump(result, f, indent=2)
199-
162+
200163
print(f"Results saved to {output_path}", file=sys.stderr)
201164

202165

203166
if __name__ == "__main__":
204-
main()
167+
main()

0 commit comments

Comments
 (0)