-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
139 lines (111 loc) · 4.22 KB
/
utils.py
File metadata and controls
139 lines (111 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# --------------------------------------- utf-8 encoding ----------------------------------------------
from abc import ABC, abstractmethod
import re
from typing import Optional, Dict, Any, List
from dataclasses import dataclass
# wwil kljz rakd yeyn
@dataclass()
class MetricResult:
    """One metric observation extracted from a single line of log text.

    Only ``name`` and ``value`` are mandatory; every other field falls back
    to ``None`` when the producing parser does not supply it.
    """

    # Label of the metric (the parser's configured name, or e.g. 'WER').
    name: str
    # Numeric value parsed out of the line.
    value: float
    # Optional epoch index, if the producer knows it.
    epoch: Optional[int] = None
    # Optional step/iteration index, if the producer knows it.
    step: Optional[int] = None
    # Free-form details; WERMetricParser stores its substitution/deletion/
    # insertion counts here (values may be None when absent from the line).
    extra_info: Optional[Dict[str, Any]] = None
class BaseMetricParser(ABC):
    """Interface that every metric-parser plugin implements.

    A parser is fed log text one line at a time through :meth:`parse`; the
    concrete parsers in this module keep the values they recognise so that
    :meth:`get_plot_data` can later return them as named series.
    """

    @abstractmethod
    def parse(self, line: str) -> Optional[MetricResult]:
        """Inspect a single line of text.

        Return a ``MetricResult`` when the line carries this parser's metric,
        otherwise ``None``.
        """

    @abstractmethod
    def get_plot_data(self) -> Dict[str, List[float]]:
        """Return the collected values as ``{series_name: values}`` for plotting."""
class StandardMetricParser(BaseMetricParser):
    """Regex-driven parser for a single scalar metric such as loss or accuracy.

    ``pattern`` is searched against the *lower-cased* line, so it should be
    written in lower case. Its first capture group must hold the number.
    Every successfully parsed value is appended to an internal series that
    :meth:`get_plot_data` exposes under ``metric_name``.
    """

    def __init__(self, metric_name: str, pattern: str):
        """Create a parser.

        Args:
            metric_name: Name used for returned results and the plot series.
            pattern: Regular expression whose first group captures the value.
        """
        self.metric_name = metric_name
        self.pattern = pattern
        self.values: List[float] = []

    def parse(self, line: str) -> Optional[MetricResult]:
        """Extract the metric from ``line`` if the pattern matches.

        Returns:
            A ``MetricResult`` for the parsed value, or ``None`` when the
            pattern does not match or the captured text is not a valid float.
            An example of invalid text is a greedy ``[0-9.]+`` group that
            swallows a sentence-ending period, as in ``"loss: 0.5."``.

        Raises:
            IndexError: if ``pattern`` has no capture group.
        """
        match = re.search(self.pattern, line.lower())
        if match is None:
            return None
        try:
            value = float(match.group(1))
        except (TypeError, ValueError):
            # TypeError: an optional group did not participate (group is None).
            # ValueError: captured text such as "0.5." or "." is not a number.
            # Skip this line instead of aborting the whole log parse.
            return None
        self.values.append(value)
        return MetricResult(name=self.metric_name, value=value)

    def get_plot_data(self) -> Dict[str, List[float]]:
        """Return ``{metric_name: values}``; the list is the live internal series."""
        return {self.metric_name: self.values}
class WERMetricParser(BaseMetricParser):
    """Parser for Word Error Rate (WER) lines and their error breakdown.

    The line is lower-cased first, so matching is case-insensitive. A line
    counts as a WER line only when it contains a ``wer`` token followed by a
    number. Substitution, deletion and insertion counts found on the same
    line are recorded alongside it. Lines with components but no WER value
    are ignored.

    Note: each component series only grows when that component is present,
    so those lists may be shorter than, and not index-aligned with,
    ``wer_values``.
    """

    def __init__(self):
        self.wer_values: List[float] = []
        self.substitutions: List[float] = []
        self.deletions: List[float] = []
        self.insertions: List[float] = []
        # Patterns for different WER components, applied to the lower-cased
        # line.
        # 'wer': the (?<![a-z]) lookbehind stops "wer" matching inside words
        # like "answer: 42" or "power: 5", while still allowing prefixes such
        # as "val_wer". The number form \d*\.?\d+ never ends with a '.', so a
        # sentence-ending period ("wer: 12.5.") is not captured and cannot
        # make float() fail.
        self.patterns = {
            'wer': r'(?<![a-z])wer[:\s]+(\d*\.?\d+)',
            'sub': r'substitutions[:\s]+(\d+)',
            'del': r'deletions[:\s]+(\d+)',
            'ins': r'insertions[:\s]+(\d+)'
        }

    def parse(self, line: str) -> Optional[MetricResult]:
        """Parse a WER line and record its values.

        Returns:
            A ``MetricResult`` named ``'WER'`` whose ``extra_info`` holds the
            substitution/deletion/insertion counts. Each count is ``None``
            when that count is absent from the line. Returns ``None`` when
            the line has no WER value.
        """
        line = line.lower()
        metrics = {}
        for key, pattern in self.patterns.items():
            match = re.search(pattern, line)
            if match:
                metrics[key] = float(match.group(1))
        if 'wer' not in metrics:
            return None
        self.wer_values.append(metrics['wer'])
        # Track error components only when this line reports them.
        for key, series in (('sub', self.substitutions),
                            ('del', self.deletions),
                            ('ins', self.insertions)):
            if key in metrics:
                series.append(metrics[key])
        return MetricResult(
            name='WER',
            value=metrics['wer'],
            extra_info={
                'substitutions': metrics.get('sub'),
                'deletions': metrics.get('del'),
                'insertions': metrics.get('ins')
            }
        )

    def get_plot_data(self) -> Dict[str, List[float]]:
        """Return the WER series, plus each component series that is non-empty."""
        data = {'WER': self.wer_values}
        if self.substitutions:
            data['Substitutions'] = self.substitutions
        if self.deletions:
            data['Deletions'] = self.deletions
        if self.insertions:
            data['Insertions'] = self.insertions
        return data
class MetricPluginManager:
    """Registry that fans each log line out to a set of metric parsers."""

    def __init__(self):
        self.parsers: List[BaseMetricParser] = []

    def add_parser(self, parser: BaseMetricParser):
        """Register ``parser``; parsers are consulted in registration order."""
        self.parsers.append(parser)

    def parse_line(self, line: str) -> List[MetricResult]:
        """Feed ``line`` to every registered parser.

        Returns:
            The non-``None`` results, in registration order. The list is
            empty when no parser matched. An exception raised by a parser
            propagates to the caller.
        """
        results = []
        for parser in self.parsers:
            result = parser.parse(line)
            if result is not None:
                results.append(result)
        return results

    def get_all_plot_data(self) -> Dict[str, List[float]]:
        """Merge every parser's plot data into one flat ``{name: values}`` dict.

        Series names are not namespaced per parser. When two parsers report
        the same name, the parser registered later silently wins.
        """
        plot_data: Dict[str, List[float]] = {}
        for parser in self.parsers:
            plot_data.update(parser.get_plot_data())
        return plot_data