11# coding: utf-8
22# Distributed under the terms of the MIT License.
33
4- """ This module defines the :class:`MODData` class, featurizer functions
4+ """This module defines the :class:`MODData` class, featurizer functions
55and functions to compute normalized mutual information (NMI) and relevance redundancy
66(RR) between descriptors.
77
88"""
99
1010from __future__ import annotations
1111
12- from pathlib import Path
13- from typing import Dict , List , Union , Optional , Callable , Hashable , Iterable , Tuple
1412from functools import partial
13+ from multiprocessing import Pool
14+ from pathlib import Path
15+ from typing import Callable , Dict , Hashable , Iterable , List , Optional , Tuple , Union
1516
16- from pymatgen .core import Structure , Composition
17-
18- from sklearn .feature_selection import mutual_info_regression , mutual_info_classif
19- from sklearn .utils import resample
20- from sklearn .preprocessing import MinMaxScaler
21-
22- import pandas as pd
2317import numpy as np
18+ import pandas as pd
2419import tqdm
25- from multiprocessing import Pool
20+ from pymatgen .core import Composition , Structure
21+ from sklearn .feature_selection import mutual_info_classif , mutual_info_regression
22+ from sklearn .preprocessing import MinMaxScaler
23+ from sklearn .utils import resample
2624
27- from modnet .featurizers import MODFeaturizer , clean_df
2825from modnet import __version__
26+ from modnet .featurizers import MODFeaturizer , clean_df
2927from modnet .utils import LOG
3028
31-
3229DATABASE = pd .DataFrame ([])
3330
3431
@@ -50,7 +47,6 @@ def compute_mi(
5047 random_state = None ,
5148 n_neighbors = 3 ,
5249):
53-
5450 mi = mutual_info_regression (
5551 x .reshape (- 1 , 1 ),
5652 y ,
@@ -364,6 +360,10 @@ def get_features_relevance_redundancy(
364360 list: List of dictionaries containing the results of the relevance-redundancy selection algorithm.
365361
366362 """
363+
364+ # nmi should be of numeric type (pandas>1.5 nlargest compatibility)
365+ target_nmi = target_nmi .apply (pd .to_numeric , errors = "coerce" )
366+
367367 # Initial checks
368368 if set (cross_nmi .index ) != set (cross_nmi .columns ):
369369 raise ValueError (
@@ -607,9 +607,9 @@ def __init__(
607607 """
608608
609609 from modnet .featurizers .presets import (
610- FEATURIZER_PRESETS ,
611- DEFAULT_FEATURIZER ,
612610 DEFAULT_COMPOSITION_ONLY_FEATURIZER ,
611+ DEFAULT_FEATURIZER ,
612+ FEATURIZER_PRESETS ,
613613 )
614614
615615 self .__modnet_version__ = __version__
@@ -946,9 +946,15 @@ def rebalance(self):
946946 self .df_structure .iloc [idxs ],
947947 n_samples = int (max_support - support [i ]),
948948 )
949- self .df_featurized = self .df_featurized .append (sampled_x )
950- self .df_targets = self .df_targets .append (sampled_y )
951- self .df_structure = self .df_structure .append (sampled_struct )
949+ self .df_featurized = pd .concat (
950+ [self .df_featurized , sampled_x ], ignore_index = True
951+ )
952+ self .df_targets = pd .concat (
953+ [self .df_targets , sampled_y ], ignore_index = True
954+ )
955+ self .df_structure = pd .concat (
956+ [self .df_structure , sampled_struct ], ignore_index = True
957+ )
952958
953959 @property
954960 def structures (self ) -> List [Union [Structure , CompositionContainer ]]:
0 commit comments