From 63107213a76f079ab76cc6bda2d3e7fdb85365d2 Mon Sep 17 00:00:00 2001 From: Marek Grzesiuk Date: Fri, 8 May 2026 13:33:33 +0100 Subject: [PATCH] make EchoPrime usable in other python projects Allow the user to specify where model weights and assets are stored and add setup.py to allow installation via pip, this helps users to use EchoPrime as part of other projects more easily. --- echo_prime/model.py | 26 +++++++++++++------------- requirements.txt | 5 ++++- setup.py | 6 ++++++ utils/utils.py | 21 ++++++++++++--------- 4 files changed, 35 insertions(+), 23 deletions(-) create mode 100644 setup.py diff --git a/echo_prime/model.py b/echo_prime/model.py index 1f9a633..345690f 100644 --- a/echo_prime/model.py +++ b/echo_prime/model.py @@ -16,8 +16,6 @@ from tqdm import tqdm import cv2 import pydicom -import sklearn -import sklearn.metrics import transformers @@ -25,7 +23,7 @@ import utils class EchoPrime: - def __init__(self, device=None, lang='en'): + def __init__(self, device=None, lang='en', weights_root_dir=None, assets_root_dir=None): """ Initialize EchoPrime @@ -35,10 +33,12 @@ def __init__(self, device=None, lang='en'): """ # load language specific files - utils.initialize_language(lang) + utils.initialize_language(lang, assets_root_dir=assets_root_dir) + weights_root_dir = os.path.join(weights_root_dir, "model_data") if weights_root_dir is not None else "model_data" + assets_root_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - checkpoint = torch.load("model_data/weights/echo_prime_encoder.pt",map_location=device) + checkpoint = torch.load(os.path.join(weights_root_dir, "weights", "echo_prime_encoder.pt"), map_location=device) echo_encoder = torchvision.models.video.mvit_v2_s() echo_encoder.head[-1] = torch.nn.Linear(echo_encoder.head[-1].in_features, 512) echo_encoder.load_state_dict(checkpoint) @@ -47,7 +47,7 @@ def __init__(self, device=None, lang='en'): for param in echo_encoder.parameters(): param.requires_grad = False - vc_state_dict = torch.load("model_data/weights/view_classifier.pt") + vc_state_dict = torch.load(os.path.join(weights_root_dir, "weights", "view_classifier.pt")) view_classifier = torchvision.models.convnext_base() view_classifier.classifier[-1] = torch.nn.Linear( view_classifier.classifier[-1].in_features, 11 @@ -70,19 +70,19 @@ def __init__(self, device=None, lang='en'): self.lang = lang # load MIL weights per section - self.MIL_weights = pd.read_csv("assets/MIL_weights.csv") + self.MIL_weights = pd.read_csv(os.path.join(assets_root_dir, "MIL_weights.csv")) self.non_empty_sections=self.MIL_weights['Section'] self.section_weights=self.MIL_weights.iloc[:,1:].to_numpy() # Load candidate reports - self.candidate_studies=list(pd.read_csv("model_data/candidates_data/candidate_studies.csv")['Study']) - candidate_embeddings_p1=torch.load("model_data/candidates_data/candidate_embeddings_p1.pt") - candidate_embeddings_p2=torch.load("model_data/candidates_data/candidate_embeddings_p2.pt") + self.candidate_studies=list(pd.read_csv(os.path.join(weights_root_dir, "candidates_data", "candidate_studies.csv"))['Study']) + candidate_embeddings_p1=torch.load(os.path.join(weights_root_dir, "candidates_data", "candidate_embeddings_p1.pt")) + candidate_embeddings_p2=torch.load(os.path.join(weights_root_dir, "candidates_data", "candidate_embeddings_p2.pt")) self.candidate_embeddings=torch.cat((candidate_embeddings_p1,candidate_embeddings_p2),dim=0) - candidate_reports=pd.read_pickle("model_data/candidates_data/candidate_reports.pkl") + candidate_reports=pd.read_pickle(os.path.join(weights_root_dir, "candidates_data", "candidate_reports.pkl")) self.candidate_reports = [utils.phrase_decode(vec_phr) for vec_phr in tqdm(candidate_reports)] - self.candidate_labels = pd.read_pickle("model_data/candidates_data/candidate_labels.pkl") - self.section_to_phenotypes = pd.read_pickle("assets/section_to_phenotypes.pkl") + self.candidate_labels = pd.read_pickle(os.path.join(weights_root_dir, "candidates_data", "candidate_labels.pkl")) + self.section_to_phenotypes = pd.read_pickle(os.path.join(assets_root_dir, "section_to_phenotypes.pkl")) def process_dicoms(self,INPUT): """ diff --git a/requirements.txt b/requirements.txt index 5b91dda..9f49fe3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,7 @@ pydicom==2.3.1 pytorch-lightning==2.2.0.post0 PyWavelets==1.4.1 wandb==0.16.3 -transformers==4.57.0 \ No newline at end of file +transformers==4.57.0 +matplotlib +torch +numpy<2.0.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..17762ca --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +from setuptools import setup + +setup( + name="echo_prime", + packages=["echo_prime", "utils"] +) \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py index 3bfe200..398bec3 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,3 +1,4 @@ +import os import re import json import numpy as np @@ -9,9 +10,6 @@ import pydicom as dicom import torch -with open("assets/per_section.json", encoding="utf-8") as f: - json_data = json.load(f) - _ybr_to_rgb_lut = None COARSE_VIEWS=['A2C', @@ -53,7 +51,7 @@ numerical_pattern = r'(\\d+(\\.\\d+)?)' # Escaped backslashes for integers or floats string_pattern = r'\\b\\w+.*?(?=\\.)' -def initialize_language(lang='en'): +def initialize_language(lang='en', assets_root_dir: str = None): """ Initialize language-specific variables. @@ -62,13 +60,14 @@ def initialize_language(lang='en'): """ all_phrases, t_list = None, None global phrases_per_section_list, phrases_per_section_list_org, regex_per_section - + assets_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets" + if lang == 'en': - phrases_file = "assets/all_phr.json" + phrases_file = os.path.join(assets_dir, "all_phr.json") elif lang == 'it': - phrases_file = "assets/all_phr_it.json" + phrases_file = os.path.join(assets_dir, "all_phr_it.json") elif lang == 'bs': - phrases_file = "assets/all_phr_bs.json" + phrases_file = os.path.join(assets_dir, "all_phr_bs.json") # add your translated file here #elif lang == 'your_language_code': # phrases_file = "assets/all_phr_{your_language_code}.json" @@ -106,11 +105,15 @@ def extract_section(report, section_header): else: return "Section not found." -def extract_features(report: str) -> list: +def extract_features(report: str, assets_root_dir: str = None) -> list: """ Returns a list of 21 different features see json_data for a list of features """ + assets_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets" + with open(os.path.join(assets_dir, "per_section.json"), encoding="utf-8") as f: + json_data = json.load(f) + sorted_features=['impella', 'ejection_fraction', 'pacemaker',