Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions MaskHIT/maskhit/configs/config_gram_stains.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# MaskHIT training configuration for the Gram-stain classification experiment.
# NOTE(review): the leading indentation of this file was lost in the paste;
# the 2-space nesting below was reconstructed from the section headers —
# confirm it matches the committed file before relying on it.
dataset:
  # Pickled slide/metadata tables produced by SlidePreprocessing.
  # NOTE(review): absolute user-specific paths — confirm they exist on the
  # machine that runs training.
  meta_svs: !!str /pool2/users/jackm/Dartmouth_experiments/SlidePreprocessing/for_vit/meta/dartmouth_svs-2.pickle
  meta_all: !!str /pool2/users/jackm/Dartmouth_experiments/SlidePreprocessing/for_vit/meta/dartmouth_meta-2.pickle

  timestr_model: !!str 2024_2_6-vit
  outcome: !!str morphology
  outcome_type: !!str classification
  study: !!str dartmouth
  is_cancer: !!bool False
  disease: !!str dartmouth
  # Class labels as one comma-separated string; presumably GN = Gram-negative,
  # GP = Gram-positive — TODO confirm with the dataset owner.
  classes: !!str GN, GP

patch:
  num_patches: !!int 400
  magnification: !!float 20.0

model:
  weighted_loss: !!bool True
  # Explicit !!str with no value yields an empty string (not null): by default
  # no checkpoint is resumed.
  resume: !!str
  resume_epoch: !!str BEST

  regions_per_svs: !!int 64

  # Weight Decays
  wd_attn: !!float 1e-3
  wd_fuse: !!float 1e-2
  wd_loss: !!float 1e-2
  wd_pred: !!float 1e-2

  # Learning Rates
  lr_attn: !!float 1e-3
  lr_fuse: !!float 1e-3
  lr_loss: !!float 1e-3
  lr_pred: !!float 1e-3

  performance_measure: !!str f1

  accumulation_steps: !!int 1
  dropout: !!float 0.3
  batch_size: !!int 8
  override_logs: !!bool True
  sample_patient: !!bool True
6 changes: 2 additions & 4 deletions MaskHIT/maskhit/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,14 @@

args = opt.parse()

print(f"[INFO] Running cross_validation.py")
config_file = args.user_config_file
config_file_default = args.default_config_file
folds = [int(i) for i in args.folds.split(',')] if args.folds else list(range(config.dataset.num_folds))
print(f"Testing on folds: {list(folds)}")

# args_config = default_options()
print(f"config_file: {config_file}")
config = Config(config_file_default, config_file)
folds = [int(i) for i in args.folds.split(',')] if args.folds else list(range(config.dataset.num_folds))
print(f"Testing on folds: {list(folds)}")
print(f"[INFO] Conducting Cross Validation on folds: {list(folds)}")
study = config.dataset.study
timestr = config.dataset.timestr_model

Expand Down
8 changes: 4 additions & 4 deletions MaskHIT/maskhit/quick_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

study = sys.argv[1]
timestr = sys.argv[2]
timestr_new = sys.argv[3]
timestr_test = sys.argv[3]

# the first three arguments passed from `cross_validation.py`
print(f"{study}, {timestr}, {timestr_new}")
print(f"{study}, {timestr}, {timestr_test}")

current_directory = Path(__file__).parent
os.chdir(current_directory)
Expand All @@ -44,11 +44,11 @@
pattern = r'--timestr=[^\s]+'
org_cmd = re.sub(pattern, '', org_cmd)

timestr_new += '-test'
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually a helpful suffix for distinguishing a test run from a training run, so please revert this.
The variable name timestr_new sounds terrible, though. Could you rename it to timestr_test?

Copy link
Copy Markdown
Collaborator Author

@jack-mcmahon jack-mcmahon Feb 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I tested the original version, it was appending '-test' an additional time for each fold, so that by the last fold the timestr was, for example, '2023_2_6-vit-test-test-test-test-test'. To fix this, I removed that code and instead append "-test" to the timestr passed from cross_validation.py. Let me know if you'd still recommend changes there; I'll add the rename to timestr_test now, though.

new_cmd = ' '.join([
'python train.py', org_cmd,
f' --mode=test --test-type=test --resume-epoch=BEST --timestr={timestr_new}'
f' --resume={ckp} --mode=test --test-type=test --timestr={timestr_test}'
])

print(f"executing following new_cmd: {new_cmd}")
os.system(new_cmd)

89 changes: 49 additions & 40 deletions MaskHIT/maskhit/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ def get_resume_checkpoint(checkpoints_name, epoch_to_resume):
"""
files = glob.glob(
os.path.join(args.checkpoints_folder, checkpoints_name, "*.pt"))

checkpoint_to_resume = None
for fname in files:
if get_checkpoint_epoch(fname) == epoch_to_resume:
Expand Down Expand Up @@ -270,42 +269,44 @@ def prepare_data(meta_split, meta_file, vars_to_include=[]):
- outcome: patient outcome variable, encoded for classification models e.g. 0, 1, 2 for three classes

"""

#TODO: Temp Duct-tape solution for existing datasets. Need update.
#Ideally, the target column name should be parameterized in a config file.
ids_to_add = []
for index, row in meta_split.iterrows():
if 'case_number' in row:
value_to_split = row['case_number']
split_value = value_to_split.split('.')[0]
elif 'barcode' in row:
split_value = row['barcode']
#TODO: Temp Duct-tape solution: Formatting the ID
split_value = "-".join(split_value.split('-')[:3])
else:
raise ValueError("Row does not contain 'case_number' or 'barcode'")

ids_to_add.append(split_value)
if config.dataset.study == "ibd_project":
#Moved specific preprocessing conditional to IBD dataset to run only for that study.
#Other datasets check if affected and modify line 272 to add your study to this conditional.
#TODO: Temp Duct-tape solution for existing datasets. Need update.
#Ideally, the target column name should be parameterized in a config file.
ids_to_add = []
for index, row in meta_split.iterrows():
if 'case_number' in row:
value_to_split = row['case_number']
split_value = value_to_split.split('.')[0]
elif 'barcode' in row:
split_value = row['barcode']
#TODO: Temp Duct-tape solution: Formatting the ID
split_value = "-".join(split_value.split('-')[:3])
else:
raise ValueError("Row does not contain 'case_number' or 'barcode'")

ids_to_add.append(split_value)

meta_split['id_patient'] = ids_to_add
meta_split['id_patient'] = ids_to_add

#TODO: This whole block should be in another script to be run before train.py
if 'id_patient' not in meta_split.columns:
patient_ids = []
# iterating over the meta_split dataframe
for index, row in meta_split.iterrows():
#TODO: Debug: This code is basically repeating what we have done in MaskHIT_Prep
# TODO: This whole block should be in another script to be run before train.py
if 'id_patient' not in meta_split.columns:
patient_ids = []
# iterating over the meta_split dataframe
for index, row in meta_split.iterrows():
#TODO: Debug: This code is basically repeating what we have done in MaskHIT_Prep

# obtaining the paths of the files to the related slide
#TODO: Debug: temp fix but need to check with other datasets to see if literal_eval is required and why.
#file_names = ast.literal_eval(row['Path'])
file_names = str(row['path'])
# obtaining the paths of the files to the related slide
#TODO: Debug: temp fix but need to check with other datasets to see if literal_eval is required and why.
#file_names = ast.literal_eval(row['Path'])
file_names = str(row['path'])

patient_id = file_names[0].split('/')[5].split(' ')[0]
patient_ids.append(patient_id) # adding patient id to the list
meta_split['id_patient'] = patient_ids # adding column to the meta_split dataframe
# formatting rows in meta_file of the id patients so they match that of meta_split df
meta_file['id_patient'] = meta_file['id_patient'].apply(lambda x: pd.Series(x.split(' ')[0]))
patient_id = file_names[0].split('/')[5].split(' ')[0]
patient_ids.append(patient_id) # adding patient id to the list
meta_split['id_patient'] = patient_ids # adding column to the meta_split dataframe
# formatting rows in meta_file of the id patients so they match that of meta_split df
meta_file['id_patient'] = meta_file['id_patient'].apply(lambda x: pd.Series(x.split(' ')[0]))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes for lines 272-307 seem to make sense for non-IBD users. The next step, then, is that whoever is affected by this change should correctly update their metadata files in the SlidePrep library. The assumption here seems to be that both meta files have an id_patient column with consistent values, so they can be merged at line 318. Could you add this note after line 272 (the if-branch) so people know what should be fixed in case this change breaks some non-IBD users' code?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added



vars_to_keep = ['id_patient']
Expand Down Expand Up @@ -343,6 +344,8 @@ def prepare_data(meta_split, meta_file, vars_to_include=[]):


def main():

print(f"[INFO] Running train.py")

if len(args.timestr):
TIMESTR = args.timestr
Expand All @@ -354,7 +357,14 @@ def main():
model_name = f"{TIMESTR}-{args.fold}"

# if we want to resume previous training
if config.model.resume:
if args.resume:
checkpoint_to_resume = get_resume_checkpoint(args.resume,
config.model.resume_epoch)
if args.resume_train:
# use the model name
model_name = config.model.resume
TIMESTR = model_name.split('-')[0]
elif config.model.resume:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we stick to one of them, rather than maintaining compatibility for both, if both args are meant to be the same thing?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran into a problem with that setup when using cross_validation.py. My dataset config file doesn't resume a pretrained model, but I then need to load a different model for each fold when testing. Since the config file doesn't change between test folds, I had to specify which model to evaluate (i.e., resume) in the form of args to train.py. This seemed like the best way to fix the issue when running cross validation while also allowing for fine-tuning a pretrained model specified in the config — let me know if you have an idea for something better.

checkpoint_to_resume = get_resume_checkpoint(config.model.resume,
config.model.resume_epoch)
if args.resume_train:
Expand Down Expand Up @@ -393,7 +403,7 @@ def main():
if config.dataset.meta_all is not None:
meta_all = pd.read_pickle(config.dataset.meta_all)
#Debug: Checking input data
print("Debug: meta_all:\n", meta_all.head(2))
# print("Debug: meta_all:\n", meta_all.head(2))

#TODO: mode=extract is not expected in the train_options. Need doc.
if args.mode == 'extract':
Expand Down Expand Up @@ -466,14 +476,15 @@ def main():
meta_file=meta_svs,
vars_to_include=vars_to_include)

print("df_test")
print(df_test)
print(f"[INFO] Train folds are : {train_folds}")
print(f"[INFO] Validation fold is: {val_fold}")
print(f"[INFO] Fold {test_fold} saved for testing")

if config.dataset.outcome_type == 'classification':
num_classes = len(df_train[config.dataset.outcome].unique().tolist())

else:
num_classes = 1
print(f"num_classes: {num_classes}")

if config.model.weighted_loss:
weight = df_train.shape[0] / df_train[
Expand All @@ -499,8 +510,6 @@ def main():

data_dict = {"train": df_train, "val": df_test}

df_test.to_csv('fold0.csv')

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure about this fold0 either, so I'm okay with removing it. Based on the file name, it's probably left over from debugging in the early stages of development.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I added that for debugging with the IBD Project, so it is safe to be removed

# Simply call main_worker function
# print(f"Validation Folds: {df_test.fold.unique()}")
if args.mode == 'test':
Expand Down
13 changes: 13 additions & 0 deletions MaskHIT/maskhit/trainer/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,19 @@ def fit(self, data_dict, procedure='train'):
self.es = EarlyStopping(patience=self.args.patience, mode='max')

self.get_logger()

if self.config.dataset.outcome_type == 'classification':
self.writer['meta'].info(f"Running classification between classes: {self.config.dataset.classes}")
self.writer['meta'].info(f"Number of classes: {self.num_classes}")
self.writer['meta'].info('Training patients:')
self.writer['meta'].info('\t'+ data_dict['train'][['id_patient']].to_string())

if procedure == 'train':
self.writer['meta'].info('Validation patients:')
self.writer['meta'].info('\t'+ data_dict['val'][['id_patient']].to_string())
elif procedure == 'test':
self.writer['meta'].info('Testing patients:')
self.writer['meta'].info('\t'+ data_dict['val'][['id_patient']].to_string())
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why these lines have to be added?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that it would be useful to have a record in the log file of which slides are being trained/validated/tested on. That's useful for me but I can also remove it from the PR if other users would rather not include that info in the log files.


# creating an instance of the model
model = HybridModel(in_dim=self.args.num_features,
Expand Down
2 changes: 1 addition & 1 deletion MaskHIT/maskhit/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@ def _nested_dicts_to_objects(self, config_dict: Dict) -> Dict:
result[key] = obj
else:
result[key] = value
return result
return result