Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions modnet/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,31 @@
# coding: utf-8
# Distributed under the terms of the MIT License.

""" This module defines the :class:`MODData` class, featurizer functions
"""This module defines the :class:`MODData` class, featurizer functions
and functions to compute normalized mutual information (NMI) and relevance redundancy
(RR) between descriptors.

"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Union, Optional, Callable, Hashable, Iterable, Tuple
from functools import partial
from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Dict, Hashable, Iterable, List, Optional, Tuple, Union

from pymatgen.core import Structure, Composition

from sklearn.feature_selection import mutual_info_regression, mutual_info_classif
from sklearn.utils import resample
from sklearn.preprocessing import MinMaxScaler

import pandas as pd
import numpy as np
import pandas as pd
import tqdm
from multiprocessing import Pool
from pymatgen.core import Composition, Structure
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import resample

from modnet.featurizers import MODFeaturizer, clean_df
from modnet import __version__
from modnet.featurizers import MODFeaturizer, clean_df
from modnet.utils import LOG


DATABASE = pd.DataFrame([])


Expand All @@ -50,7 +47,6 @@ def compute_mi(
random_state=None,
n_neighbors=3,
):

mi = mutual_info_regression(
x.reshape(-1, 1),
y,
Expand Down Expand Up @@ -364,6 +360,10 @@ def get_features_relevance_redundancy(
list: List of dictionaries containing the results of the relevance-redundancy selection algorithm.

"""

# nmi should be of numeric type (pandas>1.5 nlargest compatibility)
target_nmi = target_nmi.apply(pd.to_numeric, errors="coerce")

# Initial checks
if set(cross_nmi.index) != set(cross_nmi.columns):
raise ValueError(
Expand Down Expand Up @@ -607,9 +607,9 @@ def __init__(
"""

from modnet.featurizers.presets import (
FEATURIZER_PRESETS,
DEFAULT_FEATURIZER,
DEFAULT_COMPOSITION_ONLY_FEATURIZER,
DEFAULT_FEATURIZER,
FEATURIZER_PRESETS,
)

self.__modnet_version__ = __version__
Expand Down Expand Up @@ -946,9 +946,15 @@ def rebalance(self):
self.df_structure.iloc[idxs],
n_samples=int(max_support - support[i]),
)
self.df_featurized = self.df_featurized.append(sampled_x)
self.df_targets = self.df_targets.append(sampled_y)
self.df_structure = self.df_structure.append(sampled_struct)
self.df_featurized = pd.concat(
[self.df_featurized, sampled_x], ignore_index=True
)
self.df_targets = pd.concat(
[self.df_targets, sampled_y], ignore_index=True
)
self.df_structure = pd.concat(
[self.df_structure, sampled_struct], ignore_index=True
)

@property
def structures(self) -> List[Union[Structure, CompositionContainer]]:
Expand Down