Adding discrete control functionality to libemg #129
base: latest
Pull request overview
This pull request adds discrete control functionality to libemg, enabling discrete gesture recognition as an alternative to continuous classification. The discrete mode allows recording EMG data through spacebar presses instead of timers, and includes new classifier models and algorithms designed for discrete gesture detection.
Changes:
- Added discrete parameter to GUI, data collection, and data processing pipeline to support spacebar-based recording
- Implemented OnlineDiscreteClassifier for real-time discrete gesture detection with optional rejection thresholds and prediction buffering
- Added three discrete classification models: MVLDA (majority vote LDA), DTWClassifier (DTW-based k-NN), and MyoCrossUserPretrained (pretrained PyTorch model)
Reviewed changes
Copilot reviewed 10 out of 11 changed files in this pull request and generated 25 comments.
Show a summary per file
| File | Description |
|---|---|
| libemg/gui.py | Added discrete parameter to default args, implemented UI scaling and font loading for better cross-platform display |
| libemg/feature_extractor.py | Added discrete parameter to extract_features to handle list of templates separately, refactored into _extract_features_single helper |
| libemg/emg_predictor.py | Added OnlineDiscreteClassifier class for real-time discrete gesture classification with buffering and rejection threshold support |
| libemg/data_handler.py | Added discrete parameter to parse_windows to keep windows from each rep separate instead of concatenating them |
| libemg/_gui/_data_collection_panel.py | Added discrete mode UI controls and play_collection_visual_discrete method for spacebar-controlled recording |
| libemg/_discrete_models/__init__.py | Created new module for discrete models with imports for MVLDA, DTW, and MyoCrossUser |
| libemg/_discrete_models/MVLDA.py | Implemented majority vote LDA classifier with soft voting predict_proba |
| libemg/_discrete_models/DTW.py | Implemented DTW-based k-NN classifier with weighted probability estimation |
| libemg/_discrete_models/MyoCrossUser.py | Implemented pretrained PyTorch model loader with automatic model download |
| libemg/__init__.py | Added import for _discrete_models module |
| .gitignore | Added *.model to ignore downloaded model files |
| import torch | ||
| import torch.nn as nn | ||
| import numpy as np |
Copilot
AI
Jan 20, 2026
The MyoCrossUser module imports PyTorch (torch and torch.nn) but torch is not listed in requirements.txt. This will cause an ImportError when users try to use this pretrained model. Add torch to requirements.txt or document it as an optional dependency for this specific model.
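If torch is kept as an optional dependency, the import could be deferred until the pretrained model is actually constructed, so that the rest of libemg still imports without it. The following is a minimal sketch under that assumption; `require_optional` is a hypothetical helper and not part of libemg.

```python
import importlib


def require_optional(module_name, feature):
    """Import an optional dependency, or raise an ImportError naming the feature that needs it.

    Hypothetical helper for illustration only.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{feature} requires the optional package '{module_name}'. "
            f"Install it with: pip install {module_name}"
        ) from exc


# Inside the model's constructor one could then write (assumed usage):
# torch = require_optional("torch", "MyoCrossUserPretrained")
```

With this approach the failure surfaces only when the pretrained model is requested, and the message tells the user which package to install.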
| def __init__(self, | ||
| online_data_handler, | ||
| args={'media_folder': 'images/', 'data_folder':'data/', 'num_reps': 3, 'rep_time': 5, 'rest_time': 3, 'auto_advance': True}, | ||
| args={'media_folder': 'images/', 'data_folder':'data/', 'num_reps': 3, 'rep_time': 5, 'rest_time': 3, 'auto_advance': True, 'discrete': False}, |
Copilot
AI
Jan 20, 2026
Using a mutable default argument (dictionary) is a Python anti-pattern. If the same default object is modified, it will persist across function calls. Change to 'args=None' and handle the default dictionary creation in the function body.
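The `args=None` pattern suggested above might look like the sketch below. `Panel` is a hypothetical stand-in for the GUI class in the diff, and merging the caller's overrides into the defaults is an assumption: the original code replaces the whole dictionary instead.

```python
# Default values copied from the diff above.
DEFAULT_ARGS = {
    'media_folder': 'images/', 'data_folder': 'data/', 'num_reps': 3,
    'rep_time': 5, 'rest_time': 3, 'auto_advance': True, 'discrete': False,
}


class Panel:  # hypothetical stand-in for the class under review
    def __init__(self, online_data_handler, args=None):
        self.online_data_handler = online_data_handler
        # Build a fresh dict per instance so no state is shared between calls,
        # then let any caller-supplied keys override the defaults.
        self.args = {**DEFAULT_ARGS, **(args or {})}


a = Panel(None)
b = Panel(None, {'discrete': True})
a.args['num_reps'] = 10  # does not leak into b or into DEFAULT_ARGS
```

Merging also means a caller who passes only `{'discrete': True}` still gets the remaining defaults, which is a behavior change worth confirming with the maintainers.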
| return | ||
|
|
||
| model_dir = os.path.dirname(self.model_path) | ||
| os.makedirs(model_dir, exist_ok=True) |
Copilot
AI
Jan 20, 2026
When model_dir is an empty string (which happens when model_path is a bare filename such as 'Discrete.model'), os.makedirs(model_dir, exist_ok=True) raises FileNotFoundError, because an empty path is not a valid directory name. Create the directory only when model_dir is non-empty.
| os.makedirs(model_dir, exist_ok=True) | |
| if model_dir: | |
| os.makedirs(model_dir, exist_ok=True) |
libemg/emg_predictor.py
Outdated
| import time | ||
| import numpy as np | ||
| from libemg.feature_extractor import FeatureExtractor | ||
| from libemg.utils import get_windows | ||
|
|
Copilot
AI
Jan 20, 2026
These imports are redundant as they are already imported at the top of the file. The imports for time, numpy, FeatureExtractor, and get_windows are already present at lines 22, 14, 11, and 31 respectively. Remove these duplicate import statements.
| import time | |
| import numpy as np | |
| from libemg.feature_extractor import FeatureExtractor | |
| from libemg.utils import get_windows | |
libemg/_discrete_models/__init__.py
Outdated
| from libemg._discrete_models import MVLDA | ||
| from libemg._discrete_models import DTW | ||
| from libemg._discrete_models import MyoCrossUser No newline at end of file |
Copilot
AI
Jan 20, 2026
These statements import the submodules rather than the classes they define, so a name such as libemg._discrete_models.MVLDA refers to a module, not to the classifier. Import the classes directly instead: 'from libemg._discrete_models.MVLDA import MVLDA', 'from libemg._discrete_models.DTW import DTWClassifier', and 'from libemg._discrete_models.MyoCrossUser import MyoCrossUserPretrained, DiscreteClassifier'.
| from libemg._discrete_models import MVLDA | |
| from libemg._discrete_models import DTW | |
| from libemg._discrete_models import MyoCrossUser | |
| from libemg._discrete_models.MVLDA import MVLDA | |
| from libemg._discrete_models.DTW import DTWClassifier | |
| from libemg._discrete_models.MyoCrossUser import MyoCrossUserPretrained, DiscreteClassifier |
| class OnlineDiscreteClassifier: | ||
| """OnlineDiscreteClassifier. | ||
| Real-time discrete gesture classifier that detects individual gestures from EMG data. | ||
| Unlike continuous classifiers, this classifier is designed for detecting discrete, | ||
| transient gestures and outputs a prediction only when a gesture is detected. | ||
| Parameters | ||
| ---------- | ||
| odh: OnlineDataHandler | ||
| An online data handler object for streaming EMG data. | ||
| model: object | ||
| A trained model with a predict_proba method (e.g., from libemg discrete models). | ||
| window_size: int | ||
| The number of samples in a window. | ||
| window_increment: int | ||
| The number of samples that advances before the next window. | ||
| null_label: int | ||
| The label corresponding to the null/no gesture class. | ||
| feature_list: list or None | ||
| A list of features that will be extracted during real-time classification. | ||
| Pass in None if the model expects raw windowed data. | ||
| template_size: int | ||
| The maximum number of samples to use for gesture template matching. | ||
| min_template_size: int, default=None | ||
| The minimum number of samples required before attempting classification. | ||
| If None, defaults to template_size. | ||
| key_mapping: dict, default=None | ||
| A dictionary mapping gesture names to keyboard keys for automated key presses. | ||
| Requires pyautogui to be installed. | ||
| feature_dic: dict, default=None | ||
| A dictionary containing feature extraction parameters. | ||
| gesture_mapping: dict, default=None | ||
| A dictionary mapping class indices to gesture names for debug output. | ||
| rejection_threshold: float, default=0.0 | ||
| The confidence threshold (0-1). Predictions with confidence below this | ||
| threshold will be rejected and treated as null gestures. | ||
| debug: bool, default=True | ||
| If True, prints accepted gestures with timestamps and confidence values. | ||
| buffer_size: int, default=1 | ||
| Number of successive predictions to buffer before accepting a gesture. | ||
| When buffer_size > 1, the mode (most frequent prediction) across the buffer | ||
| is used to determine the final prediction. This helps filter noisy predictions. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| odh, | ||
| model, | ||
| window_size, | ||
| window_increment, | ||
| null_label, | ||
| feature_list, | ||
| template_size, | ||
| min_template_size=None, | ||
| key_mapping=None, | ||
| feature_dic={}, | ||
| gesture_mapping=None, | ||
| rejection_threshold=0.0, | ||
| debug=True, | ||
| buffer_size=1 | ||
| ): | ||
| self.odh = odh | ||
| self.window_size = window_size | ||
| self.window_increment = window_increment | ||
| self.feature_list = feature_list | ||
| self.model = model | ||
| self.null_label = null_label | ||
| self.template_size = template_size | ||
| self.min_template_size = min_template_size if min_template_size is not None else template_size | ||
| self.key_mapping = key_mapping | ||
| self.feature_dic = feature_dic | ||
| self.gesture_mapping = gesture_mapping | ||
| self.rejection_threshold = rejection_threshold | ||
| self.debug = debug | ||
| self.buffer_size = buffer_size | ||
| self.prediction_buffer = deque(maxlen=buffer_size) | ||
| self.fe = FeatureExtractor() | ||
|
|
||
| def run(self): | ||
| """ | ||
| Main loop for gesture detection. | ||
| Uses predict_proba to apply an optional rejection threshold. | ||
| When buffer_size > 1, takes the mode across multiple successive predictions. | ||
| """ | ||
| expected_count = self.min_template_size | ||
|
|
||
| while True: | ||
| # Get and process EMG data | ||
| _, counts = self.odh.get_data(self.window_size) | ||
| if counts['emg'][0][0] >= expected_count: | ||
| data, _ = self.odh.get_data(self.template_size) | ||
| emg = data['emg'][::-1] | ||
| feats = get_windows(emg, window_size=self.window_size, window_increment=self.window_increment) | ||
| if self.feature_list is not None: | ||
| feats = self.fe.extract_features(self.feature_list, feats, array=True, feature_dic=self.feature_dic) | ||
|
|
||
| probas = self.model.predict_proba(np.array([feats]))[0] | ||
|
|
||
| # Get the class with the highest probability | ||
| pred = np.argmax(probas) | ||
| confidence = probas[pred] | ||
|
|
||
| # Check rejection threshold | ||
| if confidence < self.rejection_threshold: | ||
| pred = self.null_label | ||
|
|
||
| # Add prediction to buffer | ||
| self.prediction_buffer.append(pred) | ||
|
|
||
| # Check if buffer is full and compute mode | ||
| if len(self.prediction_buffer) >= self.buffer_size: | ||
| # Get mode of buffer predictions | ||
| buffer_list = list(self.prediction_buffer) | ||
| mode_result = stats.mode(buffer_list, keepdims=False) | ||
| buffered_pred = mode_result[0] | ||
|
|
||
| if buffered_pred != self.null_label: | ||
| if self.debug: | ||
| label = self.gesture_mapping[buffered_pred] if self.gesture_mapping else buffered_pred | ||
| print(f"{time.time()} ACCEPTED: {label} (Conf: {confidence:.2f})") | ||
|
|
||
| if self.key_mapping is not None: | ||
| self._key_press(buffered_pred) | ||
|
|
||
| self.odh.reset() | ||
| self.prediction_buffer.clear() | ||
| expected_count = self.min_template_size | ||
| else: | ||
| expected_count += self.window_increment | ||
| else: | ||
| expected_count += self.window_increment | ||
|
|
||
| def _key_press(self, pred): | ||
| import pyautogui | ||
| gesture_name = self.gesture_mapping[pred] | ||
| if gesture_name in self.key_mapping: | ||
| pyautogui.press(self.key_mapping[gesture_name]) |
Copilot
AI
Jan 20, 2026
The new discrete control functionality including OnlineDiscreteClassifier, discrete models (MVLDA, DTWClassifier, MyoCrossUserPretrained), and the discrete parameter in parse_windows and extract_features lacks test coverage. Given that the repository has existing test coverage for other features, consider adding tests for the discrete functionality to ensure correctness and prevent regressions.
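One way to make the rejection and buffering logic of `run()` testable without a live data stream is to factor it into a pure function. The sketch below assumes such a refactor; `vote` and the test names are hypothetical, and it uses `collections.Counter` rather than the `stats.mode` call in the diff, so tie-breaking may differ.

```python
from collections import Counter, deque


def vote(probas, buffer, null_label, rejection_threshold, buffer_size):
    """Apply the rejection threshold, buffer the prediction, and return the
    most frequent buffered label, or None while the buffer is still filling."""
    pred = max(range(len(probas)), key=probas.__getitem__)  # argmax
    if probas[pred] < rejection_threshold:
        pred = null_label  # low-confidence predictions are treated as null
    buffer.append(pred)
    if len(buffer) < buffer_size:
        return None
    return Counter(buffer).most_common(1)[0][0]


def test_low_confidence_is_rejected():
    buf = deque(maxlen=1)
    assert vote([0.1, 0.5, 0.4], buf, 0, 0.6, 1) == 0


def test_mode_across_buffer():
    buf = deque(maxlen=3)
    assert vote([0.0, 1.0], buf, 0, 0.0, 3) is None  # buffer holds 1 of 3
    assert vote([1.0, 0.0], buf, 0, 0.0, 3) is None  # buffer holds 2 of 3
    assert vote([0.0, 1.0], buf, 0, 0.0, 3) == 1     # label 1 wins 2 to 1
```

Tests of this shape run under pytest without hardware, which may make them easier to add alongside the existing test files.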
|
|
||
| MODEL_URL = "https://github.com/eeddy/DiscreteMCI/raw/main/Other/Discrete.model" | ||
| DEFAULT_MODEL_PATH = os.path.join("./Discrete.model") |
Copilot
AI
Jan 20, 2026
The DEFAULT_MODEL_PATH downloads the model to './Discrete.model' in the current working directory. This could pollute the user's working directory or cause permission issues in read-only directories. Consider using a more appropriate location such as a cache directory (e.g., using platformdirs or tempfile) or documenting that users should specify a custom model_path.
| MODEL_URL = "https://github.com/eeddy/DiscreteMCI/raw/main/Other/Discrete.model" | |
| DEFAULT_MODEL_PATH = os.path.join("./Discrete.model") | |
| import tempfile | |
| MODEL_URL = "https://github.com/eeddy/DiscreteMCI/raw/main/Other/Discrete.model" | |
| DEFAULT_MODEL_PATH = os.path.join(tempfile.gettempdir(), "Discrete.model") |
| Returns | ||
| ---------- | ||
| dictionary or list | ||
| A dictionary where each key is a specific feature and its value is a list of the computed | ||
| dictionary or list | ||
| A dictionary where each key is a specific feature and its value is a list of the computed | ||
| features for each window. | ||
| StandardScaler | ||
| If normalize is true it will return the normalizer object. This should be passed into the feature extractor for test data. | ||
| """ |
Copilot
AI
Jan 20, 2026
The Returns section of the docstring does not document the different return behavior when discrete=True. When discrete=True and normalize=False, it returns a list of dictionaries/arrays. When discrete=True and normalize=True, it returns a tuple of (list of arrays, scaler). Update the docstring to clearly document both cases.
| break | ||
| if not font_loaded: | ||
| # Fallback: scale the default bitmap font (lower quality) | ||
| default_font = None |
Copilot
AI
Jan 20, 2026
Variable default_font is not used.
libemg/_discrete_models/DTW.py
Outdated
| @@ -0,0 +1,54 @@ | |||
| from tslearn.metrics import dtw_path | |||
| import numpy as np | |||
| from collections import Counter | |||
Copilot
AI
Jan 20, 2026
Import of 'Counter' is not used.
Pull request overview
Copilot reviewed 11 out of 13 changed files in this pull request and generated 16 comments.
| sys.modules['DiscreteClassifier'] = sys.modules[__name__] | ||
|
|
||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
| self.model = torch.load(self.model_path, map_location=device, weights_only=False) |
Copilot
AI
Jan 20, 2026
Using weights_only=False in torch.load() can pose a security risk, as it allows arbitrary code execution during deserialization. This is particularly concerning since the model is downloaded from an external URL. Consider adding validation of the downloaded file (e.g., checking a hash) or documenting this security consideration for users.
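The hash check mentioned above could be sketched as follows. This assumes the maintainers publish a SHA-256 digest for the model file; no such digest appears in this pull request, so `expected_sha256` is a placeholder the caller must supply, and both function names are hypothetical.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checksum(path, expected_sha256):
    """Raise ValueError if the file at `path` does not match the expected digest."""
    actual = sha256_of(path)
    if actual != expected_sha256.lower():
        raise ValueError(
            f"Checksum mismatch for {path}: expected {expected_sha256}, got {actual}"
        )


# Assumed usage, before calling torch.load on the downloaded file:
# verify_checksum(self.model_path, EXPECTED_MODEL_SHA256)  # placeholder constant
```

A matching digest only confirms that the file is the one the maintainers published. It does not make unpickling with weights_only=False safe in general.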
| Returns | ||
| ---------- When discrete=False: | ||
| dictionary or np.ndarray | ||
| A dictionary where each key is a specific feature and its value is a list of the computed | ||
| features for each window. If array=True, returns a np.ndarray instead. | ||
| tuple (np.ndarray, StandardScaler) | ||
| If normalize=True, returns a tuple of (features array, scaler). The scaler should be passed | ||
| into the feature extractor for test data. | ||
| When discrete=True: | ||
| list | ||
| A list of dictionaries/arrays (one per template). If array=True, each element is a np.ndarray. | ||
| tuple (list, StandardScaler) | ||
| If normalize=True, returns a tuple of (list of np.ndarrays, scaler). The scaler should be | ||
| passed into the feature extractor for test data. |
Copilot
AI
Jan 20, 2026
The Returns section in the docstring is malformed with "When discrete=False:" and "When discrete=True:" appearing on the same line as "----------". The formatting should have the type/description headers on their own lines for proper rendering in documentation tools.
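For reference, a numpydoc-style layout keeps the section title and its dashed underline on their own lines and moves the discrete/non-discrete cases into the entry descriptions. The stub below is only a layout sketch: the signature is abbreviated and the wording is an assumption, not the final docstring.

```python
def extract_features(feature_list, windows, array=False, normalize=False, discrete=False):
    """Extracts a list of features (signature abbreviated for illustration).

    Returns
    -------
    dict, np.ndarray, or list
        When discrete=False, a dictionary where each key is a feature and its
        value holds the computed features for each window (an np.ndarray if
        array=True). When discrete=True, a list with one such element per template.
    StandardScaler
        Only returned when normalize=True, as the second element of a
        (features, scaler) tuple. Pass it to the feature extractor for test data.
    """
    raise NotImplementedError("docstring layout sketch only")
```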
| self.args = { | ||
| 'window_size': 10, 'window_increment': 5, 'null_label': 0, 'feature_list': None, 'template_size': 250, 'min_template_size': 150, 'gesture_mapping': ['Nothing', 'Close', 'Flexion', 'Extension', 'Open', 'Pinch'], 'buffer_size': 5, | ||
| } |
Copilot
AI
Jan 20, 2026
The hardcoded gesture mapping list ['Nothing', 'Close', 'Flexion', 'Extension', 'Open', 'Pinch'] in the args dictionary is not documented in the class docstring. Users need to know what gestures the pretrained model was trained on and in what order the class indices map to these gesture names.
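To make the index-to-name relationship explicit, the list could be turned into a dictionary keyed by class index, which also matches the `gesture_mapping: dict` description in the OnlineDiscreteClassifier docstring. This sketch assumes the list order in the diff is the class-index order of the pretrained model, which the pull request does not state outright.

```python
# Gesture names as listed in the diff; their order is assumed to match class indices.
GESTURES = ['Nothing', 'Close', 'Flexion', 'Extension', 'Open', 'Pinch']

# {0: 'Nothing', 1: 'Close', ...}
gesture_mapping = dict(enumerate(GESTURES))

# Consistent with 'null_label': 0 in the args shown above.
null_label = GESTURES.index('Nothing')
```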
| """Record while spacebar is held, stop when released.""" | ||
| # Display gesture name and grayscale image (waiting state) | ||
| dpg.set_value("__dc_prompt", value=media[1]) | ||
| dpg.set_item_width("__dc_prompt_spacer", width=self.video_player_width/2+30 - (7*len(media[1]))/2) |
Copilot
AI
Jan 20, 2026
The centering calculation width=self.video_player_width/2+30 - (7*len(media[1]))/2 uses a magic number (7) that appears to be an estimate of character width in pixels. This calculation may not work correctly for all fonts or font sizes, especially with the new UI scaling feature. Consider using actual text measurement methods if available, or document this assumption.
| dpg.set_item_width("__dc_prompt_spacer", width=self.video_player_width/2+30 - (7*len(media[1]))/2) | |
| # Center the prompt based on its actual rendered width instead of a per-character estimate | |
| text_width, _ = dpg.get_text_size(media[1]) | |
| dpg.set_item_width("__dc_prompt_spacer", width=self.video_player_width/2 + 30 - text_width/2) |
| print(f"Downloading model to {self.model_path}...") | ||
| urllib.request.urlretrieve(MODEL_URL, self.model_path) | ||
| print("Download complete.") |
Copilot
AI
Jan 20, 2026
The model download from an external URL (GitHub) lacks error handling for network failures or invalid responses. If the download fails or is interrupted, it could leave a corrupted file that would cause subsequent loads to fail. Consider adding error handling, retry logic, or at minimum validating the downloaded file before saving it.
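One possible way to avoid leaving a corrupted file behind is to download into a temporary file in the destination directory and move it into place only after the transfer completes. The sketch below uses only the standard library; `download_atomically` is a hypothetical name, and the timeout and the empty-file check are assumptions rather than requirements from the pull request.

```python
import os
import tempfile
import urllib.request


def download_atomically(url, dest_path, timeout=30):
    """Download `url` to `dest_path`, never leaving a partial file at `dest_path`."""
    dest_dir = os.path.dirname(os.path.abspath(dest_path))
    os.makedirs(dest_dir, exist_ok=True)
    # Write to a temporary file in the same directory so the final rename is atomic.
    fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as out, urllib.request.urlopen(url, timeout=timeout) as resp:
            while True:
                chunk = resp.read(1 << 16)
                if not chunk:
                    break
                out.write(chunk)
        if os.path.getsize(tmp_path) == 0:
            raise OSError(f"Downloaded file from {url} is empty")
        os.replace(tmp_path, dest_path)  # only now does dest_path appear
    except Exception:
        # Clean up the partial download, then let the caller see the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

Network errors from `urlopen` surface as `urllib.error.URLError`, a subclass of `OSError`, so a caller can catch `OSError` and report a clear message or retry.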
| for i, s in enumerate(X): | ||
| # DTW distances to templates | ||
| dists = np.array([dtw_path(t, s)[1] for t in self.templates], dtype=float) |
Copilot
AI
Jan 20, 2026
The DTW distance calculation in the inner loop (line 30) computes distances to all templates for every prediction, which could be slow for large template sets. The implementation uses a list comprehension with dtw_path which may not be optimized. Consider whether there are opportunities for caching or optimization, especially if the same samples are processed multiple times.
| def _key_press(self, pred): | ||
| import pyautogui | ||
| gesture_name = self.gesture_mapping[pred] | ||
| if gesture_name in self.key_mapping: | ||
| pyautogui.press(self.key_mapping[gesture_name]) |
Copilot
AI
Jan 20, 2026
The pyautogui import is done inside the method rather than at the module level, but there's no error handling if the library is not installed. Since key_mapping is an optional feature, consider adding a try-except block to provide a clear error message if pyautogui is not available, or document it as a required dependency for this feature.
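An alternative to wrapping the import inside `_key_press` is to fail fast in the constructor, before the prediction loop starts. The sketch below checks availability with `importlib.util.find_spec` without importing the package; `check_optional_dependency` is a hypothetical helper and not part of libemg.

```python
import importlib.util


def check_optional_dependency(option_value, module_name, option_name):
    """Raise early if an optional feature is requested but its package is missing.

    Hypothetical helper; the names are illustrative only.
    """
    if option_value is None:
        return  # Feature not requested, nothing to check.
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"'{option_name}' was provided, which requires the '{module_name}' package. "
            f"Install it with: pip install {module_name}"
        )


# In OnlineDiscreteClassifier.__init__ this might be called as (assumed usage):
# check_optional_dependency(key_mapping, "pyautogui", "key_mapping")
```

Checking at construction time means a missing package is reported immediately rather than on the first accepted gesture.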
| from libemg._discrete_models.MVLDA import MVLDA | ||
| from libemg._discrete_models.DTW import DTWClassifier | ||
| from libemg._discrete_models.MyoCrossUser import MyoCrossUserPretrained No newline at end of file |
Copilot
AI
Jan 20, 2026
There is no test coverage for the new discrete models (MVLDA, DTWClassifier, MyoCrossUserPretrained) or the OnlineDiscreteClassifier. Given that the repository has comprehensive test coverage for other components (test_feature_extractor.py, test_online_classifier.py, etc.), tests should be added for these new classes to maintain consistency and ensure reliability.
| sys.modules['DiscreteClassifier'] = sys.modules[__name__] | ||
|
|
||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
| self.model = torch.load(self.model_path, map_location=device, weights_only=False) | ||
| self.model.eval() | ||
|
|
Copilot
AI
Jan 20, 2026
The method modifies sys.modules globally by adding a DiscreteClassifier entry (line 130). This could cause issues if multiple instances of MyoCrossUserPretrained are created or if there are naming conflicts with other modules. Consider using a more specific module name or documenting this side effect clearly.
| sys.modules['DiscreteClassifier'] = sys.modules[__name__] | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| self.model = torch.load(self.model_path, map_location=device, weights_only=False) | |
| self.model.eval() | |
| original_module = sys.modules.get('DiscreteClassifier') | |
| sys.modules['DiscreteClassifier'] = sys.modules[__name__] | |
| try: | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| self.model = torch.load(self.model_path, map_location=device, weights_only=False) | |
| self.model.eval() | |
| finally: | |
| # Restore previous sys.modules state to avoid global side effects | |
| if original_module is None: | |
| sys.modules.pop('DiscreteClassifier', None) | |
| else: | |
| sys.modules['DiscreteClassifier'] = original_module |
| def extract_features(self, feature_list, windows, feature_dic={}, array=False, normalize=False, normalizer=None, fix_feature_errors=False, discrete=False): | ||
| """Extracts a list of features. | ||
| Parameters | ||
| ---------- | ||
| feature_list: list | ||
| The group of features to extract. Run get_feature_list() or checkout the API documentation | ||
| to find an up-to-date feature list. | ||
| windows: list | ||
| The group of features to extract. Run get_feature_list() or checkout the API documentation | ||
| to find an up-to-date feature list. | ||
| windows: list | ||
| A list of windows - should be computed directly from the OfflineDataHandler or the utils.get_windows() method. | ||
| feature_dic: dict | ||
| A dictionary containing the parameters you'd like passed to each feature. ex. {"MDF_sf":1000} | ||
| array: bool (optional), default=False | ||
| array: bool (optional), default=False | ||
| If True, the dictionary will get converted to a list. | ||
| normalize: bool (optional), default=False | ||
| If True, the features will be standardized using sklearn StandardScaler. The returned object will be a list. | ||
| normalizer: StandardScaler, default=None | ||
| This should be set to the output from feature extraction on the training data. Do not normalize testing features without this as this could be considered information leakage. | ||
| This should be set to the output from feature extraction on the training data. Do not normalize testing features without this as this could be considered information leakage. | ||
| fix_feature_errors: bool (optional), default=False | ||
| If true, fixes all feature errors (NaN=0, INF=0, -INF=0). | ||
| discrete: bool (optional), default=False | ||
| If True, windows is expected to be a list of templates (from parse_windows with discrete=True). | ||
| Features will be extracted for each template separately and returned as a list. | ||
| Returns | ||
| ---------- When discrete=False: | ||
| dictionary or np.ndarray | ||
| A dictionary where each key is a specific feature and its value is a list of the computed | ||
| features for each window. If array=True, returns a np.ndarray instead. | ||
| tuple (np.ndarray, StandardScaler) | ||
| If normalize=True, returns a tuple of (features array, scaler). The scaler should be passed | ||
| into the feature extractor for test data. | ||
| When discrete=True: | ||
| list | ||
| A list of dictionaries/arrays (one per template). If array=True, each element is a np.ndarray. | ||
| tuple (list, StandardScaler) | ||
| If normalize=True, returns a tuple of (list of np.ndarrays, scaler). The scaler should be | ||
| passed into the feature extractor for test data. | ||
| """ | ||
| if discrete: | ||
| # Handle discrete mode: windows is a list of templates | ||
| all_features = [] | ||
| for template in windows: | ||
| template_features = self._extract_features_single(feature_list, template, feature_dic, array, fix_feature_errors) | ||
| all_features.append(template_features) | ||
|
|
||
| if normalize: | ||
| # For normalization in discrete mode, we need to flatten, normalize, then restructure | ||
| if not array: | ||
| all_features = [self._format_data(f) for f in all_features] | ||
| combined = np.vstack(all_features) | ||
| if not normalizer: | ||
| scaler = StandardScaler() | ||
| combined = scaler.fit_transform(combined) | ||
| else: | ||
| scaler = normalizer | ||
| combined = normalizer.transform(combined) | ||
| # Split back into list based on original sizes | ||
| result = [] | ||
| idx = 0 | ||
| for template in windows: | ||
| n_windows = template.shape[0] | ||
| result.append(combined[idx:idx+n_windows]) | ||
| idx += n_windows | ||
| return result, scaler | ||
| return all_features | ||
|
|
||
| return self._extract_features_single(feature_list, windows, feature_dic, array, fix_feature_errors, normalize, normalizer) | ||
|
|
||
| def _extract_features_single(self, feature_list, windows, feature_dic={}, array=False, fix_feature_errors=False, normalize=False, normalizer=None): |
Copilot
AI
Jan 20, 2026
Using a mutable default argument (dictionary) for the feature_dic parameter can lead to unexpected behavior if the default is modified. This pattern appears in both extract_features and _extract_features_single. Consider using None as the default and creating a new dictionary inside the method when needed.
No description provided.