diff --git a/echo_prime/model.py b/echo_prime/model.py index 1f9a633..345690f 100644 --- a/echo_prime/model.py +++ b/echo_prime/model.py @@ -16,8 +16,6 @@ from tqdm import tqdm import cv2 import pydicom -import sklearn -import sklearn.metrics import transformers @@ -25,7 +23,7 @@ import utils class EchoPrime: - def __init__(self, device=None, lang='en'): + def __init__(self, device=None, lang='en', weights_root_dir=None, assets_root_dir=None): """ Initialize EchoPrime @@ -35,10 +33,12 @@ def __init__(self, device=None, lang='en'): """ # load language specific files - utils.initialize_language(lang) + utils.initialize_language(lang, assets_root_dir=assets_root_dir) + weights_root_dir = os.path.join(weights_root_dir, "model_data") if weights_root_dir is not None else "model_data" + assets_root_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - checkpoint = torch.load("model_data/weights/echo_prime_encoder.pt",map_location=device) + checkpoint = torch.load(os.path.join(weights_root_dir, "weights", "echo_prime_encoder.pt"), map_location=device) echo_encoder = torchvision.models.video.mvit_v2_s() echo_encoder.head[-1] = torch.nn.Linear(echo_encoder.head[-1].in_features, 512) echo_encoder.load_state_dict(checkpoint) @@ -47,7 +47,7 @@ def __init__(self, device=None, lang='en'): for param in echo_encoder.parameters(): param.requires_grad = False - vc_state_dict = torch.load("model_data/weights/view_classifier.pt") + vc_state_dict = torch.load(os.path.join(weights_root_dir, "weights", "view_classifier.pt")) view_classifier = torchvision.models.convnext_base() view_classifier.classifier[-1] = torch.nn.Linear( view_classifier.classifier[-1].in_features, 11 @@ -70,19 +70,19 @@ def __init__(self, device=None, lang='en'): self.lang = lang # load MIL weights per section - self.MIL_weights = pd.read_csv("assets/MIL_weights.csv") + self.MIL_weights = pd.read_csv(os.path.join(assets_root_dir, "MIL_weights.csv")) self.non_empty_sections=self.MIL_weights['Section'] self.section_weights=self.MIL_weights.iloc[:,1:].to_numpy() # Load candidate reports - self.candidate_studies=list(pd.read_csv("model_data/candidates_data/candidate_studies.csv")['Study']) - candidate_embeddings_p1=torch.load("model_data/candidates_data/candidate_embeddings_p1.pt") - candidate_embeddings_p2=torch.load("model_data/candidates_data/candidate_embeddings_p2.pt") + self.candidate_studies=list(pd.read_csv(os.path.join(weights_root_dir, "candidates_data", "candidate_studies.csv"))['Study']) + candidate_embeddings_p1=torch.load(os.path.join(weights_root_dir, "candidates_data", "candidate_embeddings_p1.pt")) + candidate_embeddings_p2=torch.load(os.path.join(weights_root_dir, "candidates_data", "candidate_embeddings_p2.pt")) self.candidate_embeddings=torch.cat((candidate_embeddings_p1,candidate_embeddings_p2),dim=0) - candidate_reports=pd.read_pickle("model_data/candidates_data/candidate_reports.pkl") + candidate_reports=pd.read_pickle(os.path.join(weights_root_dir, "candidates_data", "candidate_reports.pkl")) self.candidate_reports = [utils.phrase_decode(vec_phr) for vec_phr in tqdm(candidate_reports)] - self.candidate_labels = pd.read_pickle("model_data/candidates_data/candidate_labels.pkl") - self.section_to_phenotypes = pd.read_pickle("assets/section_to_phenotypes.pkl") + self.candidate_labels = pd.read_pickle(os.path.join(weights_root_dir, "candidates_data", "candidate_labels.pkl")) + self.section_to_phenotypes = pd.read_pickle(os.path.join(assets_root_dir, "section_to_phenotypes.pkl")) def process_dicoms(self,INPUT): """ diff --git a/requirements.txt b/requirements.txt index 5b91dda..9f49fe3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,7 @@ pydicom==2.3.1 pytorch-lightning==2.2.0.post0 PyWavelets==1.4.1 wandb==0.16.3 -transformers==4.57.0 \ No newline at end of file +transformers==4.57.0 +matplotlib +torch +numpy<2.0.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..17762ca --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +from setuptools import setup + +setup( + name="echo_prime", + packages=["echo_prime", "utils"] +) \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py index 3bfe200..398bec3 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,3 +1,4 @@ +import os import re import json import numpy as np @@ -9,9 +10,6 @@ import pydicom as dicom import torch -with open("assets/per_section.json", encoding="utf-8") as f: - json_data = json.load(f) - _ybr_to_rgb_lut = None COARSE_VIEWS=['A2C', @@ -53,7 +51,7 @@ numerical_pattern = r'(\\d+(\\.\\d+)?)' # Escaped backslashes for integers or floats string_pattern = r'\\b\\w+.*?(?=\\.)' -def initialize_language(lang='en'): +def initialize_language(lang='en', assets_root_dir: str = None): """ Initialize language-specific variables. @@ -62,13 +60,14 @@ def initialize_language(lang='en'): """ all_phrases, t_list = None, None global phrases_per_section_list, phrases_per_section_list_org, regex_per_section - + assets_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets" + if lang == 'en': - phrases_file = "assets/all_phr.json" + phrases_file = os.path.join(assets_dir, "all_phr.json") elif lang == 'it': - phrases_file = "assets/all_phr_it.json" + phrases_file = os.path.join(assets_dir, "all_phr_it.json") elif lang == 'bs': - phrases_file = "assets/all_phr_bs.json" + phrases_file = os.path.join(assets_dir, "all_phr_bs.json") # add your translated file here #elif lang == 'your_language_code': # phrases_file = "assets/all_phr_{your_language_code}.json" @@ -106,11 +105,15 @@ def extract_section(report, section_header): else: return "Section not found." -def extract_features(report: str) -> list: +def extract_features(report: str, assets_root_dir: str = None) -> list: """ Returns a list of 21 different features see json_data for a list of features """ + assets_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets" + with open(os.path.join(assets_dir, "per_section.json"), encoding="utf-8") as f: + json_data = json.load(f) + sorted_features=['impella', 'ejection_fraction', 'pacemaker',