-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathstratified_sampling.py
More file actions
300 lines (239 loc) · 12.1 KB
/
stratified_sampling.py
File metadata and controls
300 lines (239 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import pandas as pd
import numpy as np
import math
import copy
import itertools
from datetime import datetime
import os
from CONFIG import CONFIG, SamplingData
from data_preprocessing import bin_dataframe_column, midrc_clean
def group_counts(df_in, col_name):
    """
    Summarize how many rows fall into each distinct value of a column.

    Parameters:
    - df_in (pandas.DataFrame): The DataFrame containing the column to be counted.
    - col_name (str): The name of the column to be counted.

    Returns:
    - pandas.DataFrame: Two columns — the distinct values (under the original
      column name) and their row counts ('GroupCount'), sorted by value.
    """
    # value_counts gives one entry per distinct value with its frequency
    value_freq = df_in[col_name].value_counts()

    summary = pd.DataFrame({
        col_name: value_freq.index,
        'GroupCount': value_freq.values,
    })

    # Present the groups in the natural order of the column's values
    return summary.sort_values(by=col_name).reset_index(drop=True)
def check_for_duplicates(df_in: pd.DataFrame, uid_col: str) -> bool:
    """
    Check for duplicates in the 'uid_col' column and print a per-uid report.

    Parameters:
    - df_in (pandas.DataFrame): The DataFrame containing the column to be checked.
    - uid_col (str): The name of the column to be checked.

    Returns:
    - bool: True if there are duplicates, False otherwise.
    """
    # Mark all duplicates, including first occurrences
    dupes = df_in[uid_col].duplicated(keep=False)
    has_dupes = bool(dupes.any())
    if has_dupes:
        # Count the number of duplicated rows
        num_dupes = dupes.sum()
        print(f"WARNING: {num_dupes} duplicate cases in batch \n")
        # Get the duplicate rows
        datadup = df_in[dupes]
        print(f"Duplicate rows: {datadup.shape[0]}")
        # BUG FIX: the previous code zipped unique() (appearance order) against
        # value_counts() (descending-count order), which could pair a uid with
        # another uid's count. value_counts().to_dict() keeps each uid with
        # its own count.
        counts = datadup[uid_col].value_counts().to_dict()
        # Print the per-uid occurrence counts
        for uid, count in counts.items():
            print(f"{uid}: {count}")
    return has_dupes
def stratified_sampling(data_in: pd.DataFrame, sampling_data: SamplingData, view_stats=False) -> pd.DataFrame:
    """
    Perform stratified sampling on a DataFrame, assigning every row to one of
    the datasets configured in ``sampling_data.datasets``.

    Rows are grouped into strata by every combination of the feature columns
    (numeric features are first binned into categories); within each stratum,
    rows are shuffled and split across the datasets in proportion to the
    configured fractions, with fractional leftovers assigned randomly.

    Parameters:
    - data_in (pandas.DataFrame): The DataFrame to be sampled.
      NOTE(review): this frame is modified in place — the uid column and the
      feature columns are re-cast, and binned "_CUT" columns are added via
      bin_dataframe_column; pass a copy if the caller needs the original intact.
    - sampling_data (SamplingData): The sampling configuration: feature columns,
      uid column, numeric-bin definitions, dataset-name -> fraction mapping,
      and the name of the output assignment column.
    - view_stats (bool): Whether to print the per-feature group counts.

    Returns:
    - pandas.DataFrame: The sampled DataFrame with
      ``sampling_data.dataset_column`` filled with the assigned dataset name.
    """
    numeric_cols = sampling_data.numeric_cols
    """
    # I don't think this is necessary anymore
    if len(sampling_data.numeric_cols) == 0:
        numeric_cols = {'age_at_index':
                            {'bins': None,
                             'labels': None}}
    """
    uid_col = sampling_data.uid_col
    cols = sampling_data.features
    # UIDs are compared as strings throughout (isin / == lookups below)
    data_in[uid_col] = data_in[uid_col].astype(str)
    # Check for duplicates - If warning presents, go to merge batch
    check_for_duplicates(data_in, uid_col)
    # Convert numeric columns to numeric type and non-numeric columns to string type
    # (errors='coerce' turns unparseable numeric entries into NaN)
    for col_name in cols:
        if col_name in numeric_cols:
            data_in[col_name] = pd.to_numeric(data_in[col_name], errors='coerce')
        else:
            data_in[col_name] = data_in[col_name].astype(str)
    # Copy the original data to a new dataframe; assignments below write the
    # dataset column into final_table while data_in is used for filtering.
    # NOTE(review): copy.copy is a shallow copy — confirm this is intentional
    # vs. data_in.copy() if pandas copy-on-write behavior changes.
    final_table = copy.copy(data_in)
    # Separate numeric groups into categories based on bin cutoff values;
    # each numeric feature gets a parallel categorical column named <col>_CUT
    cut_suffix = "_CUT" if len(numeric_cols) > 0 else ""
    for col_name, bin_info in numeric_cols.items():
        data_in = bin_dataframe_column(data_in,
                                       column_name=col_name,
                                       cut_column_name=col_name + cut_suffix,
                                       bins=bin_info['bins'],
                                       labels=bin_info['labels'])
        # We can use this to check the distribution of the binned column
        # print(data[col_name + cut_suffix].value_counts(dropna=False))

    ## Stratified sampling process
    # Gather stats using a dictionary comprehension; numeric features are
    # keyed by their binned "<col>_CUT" column, all others by the raw column
    stats_dict = {
        (f"{col_name}{cut_suffix}" if col_name in numeric_cols else col_name): group_counts(data_in,
            f"{col_name}{cut_suffix}" if col_name in numeric_cols else col_name)
        for col_name in cols
    }

    if view_stats:
        for val in stats_dict.values():
            print(val)
            print('\n')

    # Generate all possible combinations of variables in the dataset
    # (Cartesian product of the observed values of every feature = the strata)
    possible_combos = list(itertools.product(*(stats_dict[stat].iloc[:, 0].to_list() for stat in stats_dict)))
    # print(f'There are a total of {len(possible_combos)} combinations of variables in this dataset.')
    # print('Beginning stratified sampling.')
    for var_selections in possible_combos:
        # Filter the data based on the current combination of variable selections
        temp_df = data_in
        for j, col_name in enumerate(cols):
            filter_col = col_name + cut_suffix if col_name in numeric_cols else col_name
            temp_df = temp_df.loc[temp_df[filter_col] == var_selections[j]]
        if not temp_df.empty:
            # Compute each dataset's integer share of this stratum plus the
            # fractional remainder (used below for leftover assignment)
            total_fraction = sum(sampling_data.datasets.values())
            dataset_split_dict = {}
            for dataset, fraction in sampling_data.datasets.items():
                item_split = fraction * len(temp_df) / total_fraction
                dataset_split_dict[dataset] = {'num_items': math.floor(item_split),
                                               'remainder': item_split - math.floor(item_split),
                                               }
            # Shuffle the DataFrame so contiguous slices are random samples
            temp_df_shuffled = temp_df.sample(frac=1).reset_index(drop=True)
            start_index = 0
            for dataset, split_dict in dataset_split_dict.items():
                split_index = start_index + split_dict['num_items']
                dataset_ids = temp_df_shuffled.iloc[start_index:split_index][uid_col]
                # Vectorized assignment to final_table based on dataset_ids
                final_table.loc[final_table[uid_col].isin(dataset_ids), sampling_data.dataset_column] = dataset
                start_index = split_index
            # Handle the remainder of the dataset if any items are left:
            # assign each leftover row to a dataset chosen at random with
            # probability proportional to that dataset's fractional remainder.
            # Popping the chosen dataset means each dataset receives at most
            # one leftover row per stratum.
            while start_index < len(temp_df_shuffled):
                total_remainder = sum([v['remainder'] for v in dataset_split_dict.values()])
                single_choice = np.random.choice(
                    list(dataset_split_dict.keys()),
                    p=[v['remainder']/total_remainder for v in dataset_split_dict.values()]
                )
                final_table.loc[final_table[uid_col] == temp_df_shuffled.iloc[start_index][uid_col], sampling_data.dataset_column] = single_choice
                dataset_split_dict.pop(single_choice)
                start_index += 1
    # print('Sampling complete. Saving Results...')
    # print(FinalTable[sampling_data.dataset_column].value_counts(dropna=False))
    # Check for unassigned cases (rows whose dataset column is still "");
    # fall back to assigning them all to the first configured dataset
    idx = final_table.index[final_table[sampling_data.dataset_column] == ""].tolist()
    if len(idx) > 0:
        first_dataset = list(sampling_data.datasets.keys())[0]
        print("Warning: " + str(len(idx)) + " cases did not fall in sequestration criteria \n")
        print("Assigning to " + first_dataset + " dataset \n")
        final_table.loc[idx, sampling_data.dataset_column] = first_dataset
        print('Total number of cases in this category after assignment: ',
              str(len(final_table[final_table[sampling_data.dataset_column] == first_dataset])))
    return final_table
def generate_output_filename(input_filename, *, extension: str = 'tsv', use_timestamp: bool = True,
                             prefix: str = 'COMPLETED_', suffix: str = '', timestamp_in_prefix: bool = False) -> str:
    """
    Generate a filename for the output file based on the input filename and extension.

    Parameters:
    - input_filename (str): The input filename (may include a folder path).
    - extension (str): The extension of the output file (without the dot).
    - use_timestamp (bool): Whether to include a timestamp in the filename.
    - prefix (str): The prefix to be added to the filename.
    - suffix (str): The suffix to be added to the filename.
    - timestamp_in_prefix (bool): Whether to include the timestamp in the prefix instead of the suffix.

    Returns:
    - str: The generated filename, in the same folder as the input file.
    """
    # Current timestamp formatted as '_YYYYMMDD_HHMMSS', or empty when disabled
    timestamp = '_' + datetime.now().strftime('%Y%m%d_%H%M%S') if use_timestamp else ''
    if timestamp_in_prefix:
        prefix += timestamp
    else:
        suffix += timestamp
    # Split out the folder and filename from the input filename
    folder_name, file_name_no_folder = os.path.split(input_filename)
    # BUG FIX: os.path.splitext instead of rsplit('.', 1) — rsplit raised
    # ValueError for extension-less filenames; splitext returns ('', name).
    base_name, _ = os.path.splitext(file_name_no_folder)
    # BUG FIX: os.path.join instead of f"{folder_name}/..." — the f-string
    # produced an absolute path ("/COMPLETED_...") when the input filename
    # had no folder component.
    return os.path.join(folder_name, f"{prefix}{base_name}{suffix}.{extension}")
if __name__ == '__main__':
    """
    Script entry point: run stratified sampling for every configuration in the
    CONFIG sampling dictionary and save each result as a TSV file next to its
    input file.
    """
    config = CONFIG()
    # config.set_filename('CONFIG_stratified_sampling.yaml')  # Uncomment to use a different config file
    sampling_dict = config.sampling_dict

    seed = 0  # Set random seed at user preference (makes the sampling reproducible)
    np.random.seed(seed)

    # Track the last file read so consecutive configs sharing one input file
    # reuse the already-loaded DataFrame instead of re-reading it
    last_filename = None
    df = None
    # Iterate over the sampling configurations
    for key, sampling_data in sampling_dict.items():
        # Check if the DataFrame needs to be read from a file
        if df is None or sampling_data.filename != last_filename:
            try:
                # Map file extensions to corresponding pandas read functions
                read_functions = {
                    '.xlsx': pd.read_excel,
                    '.xls': pd.read_excel,
                    '.csv': pd.read_csv,
                    '.tsv': lambda file: pd.read_csv(file, sep='\t'),
                }
                # Get the file extension (includes the leading dot)
                file_ext = sampling_data.filename[sampling_data.filename.rfind('.'):]
                # Check if the extension is supported and read the file
                if file_ext in read_functions:
                    data = read_functions[file_ext](sampling_data.filename)
                else:
                    raise ValueError(f"Unsupported file format: {sampling_data.filename}")
                # Process the data (project-specific cleaning step)
                df = midrc_clean(data, sampling_data)
                last_filename = sampling_data.filename
            except FileNotFoundError as e:
                # Skip this configuration if its input file is missing
                print(f"Error reading file: {sampling_data.filename}. {e}")
                continue
            except ValueError as e:
                # Skip this configuration on unsupported format / bad data
                print(f"ValueError: {e}")
                continue

        # Perform stratified sampling
        df = stratified_sampling(df, sampling_data)
        # We can use this to check the distribution of the dataset column
        # print(df[sampling_data.dataset_column].value_counts(dropna=False))

        prefix = 'COMPLETED_'
        # Add the key to the filename if there are multiple sampling configurations
        suffix = f'_{key}' if len(sampling_dict) > 1 else ''
        use_timestamp = False  # Set to True to add a timestamp to the filename
        # Generate the output filename with the prefix, suffix, and timestamp as specified above
        file_name = generate_output_filename(sampling_data.filename,
                                             extension='tsv',
                                             use_timestamp=use_timestamp,
                                             prefix=prefix,
                                             suffix=suffix,
                                             )
        # Save the DataFrame to a TSV file
        df.to_csv(file_name, sep='\t', encoding='utf-8', index=False)