Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions echo_prime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@
from tqdm import tqdm
import cv2
import pydicom
import sklearn
import sklearn.metrics
import transformers


# Local module imports
import utils

class EchoPrime:
def __init__(self, device=None, lang='en'):
def __init__(self, device=None, lang='en', weights_root_dir=None, assets_root_dir=None):
"""
Initialize EchoPrime

Expand All @@ -35,10 +33,12 @@ def __init__(self, device=None, lang='en'):
"""

# load language specific files
utils.initialize_language(lang)
utils.initialize_language(lang, assets_root_dir=assets_root_dir)
weights_root_dir = os.path.join(weights_root_dir, "model_data") if weights_root_dir is not None else "model_data"
assets_root_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("model_data/weights/echo_prime_encoder.pt",map_location=device)
checkpoint = torch.load(os.path.join(weights_root_dir, "weights", "echo_prime_encoder.pt"), map_location=device)
echo_encoder = torchvision.models.video.mvit_v2_s()
echo_encoder.head[-1] = torch.nn.Linear(echo_encoder.head[-1].in_features, 512)
echo_encoder.load_state_dict(checkpoint)
Expand All @@ -47,7 +47,7 @@ def __init__(self, device=None, lang='en'):
for param in echo_encoder.parameters():
param.requires_grad = False

vc_state_dict = torch.load("model_data/weights/view_classifier.pt")
vc_state_dict = torch.load(os.path.join(weights_root_dir, "weights", "view_classifier.pt"))
view_classifier = torchvision.models.convnext_base()
view_classifier.classifier[-1] = torch.nn.Linear(
view_classifier.classifier[-1].in_features, 11
Expand All @@ -70,19 +70,19 @@ def __init__(self, device=None, lang='en'):
self.lang = lang

# load MIL weights per section
self.MIL_weights = pd.read_csv("assets/MIL_weights.csv")
self.MIL_weights = pd.read_csv(os.path.join(assets_root_dir, "MIL_weights.csv"))
self.non_empty_sections=self.MIL_weights['Section']
self.section_weights=self.MIL_weights.iloc[:,1:].to_numpy()

# Load candidate reports
self.candidate_studies=list(pd.read_csv("model_data/candidates_data/candidate_studies.csv")['Study'])
candidate_embeddings_p1=torch.load("model_data/candidates_data/candidate_embeddings_p1.pt")
candidate_embeddings_p2=torch.load("model_data/candidates_data/candidate_embeddings_p2.pt")
self.candidate_studies=list(pd.read_csv(os.path.join(weights_root_dir, "candidates_data", "candidate_studies.csv"))['Study'])
candidate_embeddings_p1=torch.load(os.path.join(weights_root_dir, "candidates_data", "candidate_embeddings_p1.pt"))
candidate_embeddings_p2=torch.load(os.path.join(weights_root_dir, "candidates_data", "candidate_embeddings_p2.pt"))
self.candidate_embeddings=torch.cat((candidate_embeddings_p1,candidate_embeddings_p2),dim=0)
candidate_reports=pd.read_pickle("model_data/candidates_data/candidate_reports.pkl")
candidate_reports=pd.read_pickle(os.path.join(weights_root_dir, "candidates_data", "candidate_reports.pkl"))
self.candidate_reports = [utils.phrase_decode(vec_phr) for vec_phr in tqdm(candidate_reports)]
self.candidate_labels = pd.read_pickle("model_data/candidates_data/candidate_labels.pkl")
self.section_to_phenotypes = pd.read_pickle("assets/section_to_phenotypes.pkl")
self.candidate_labels = pd.read_pickle(os.path.join(weights_root_dir, "candidates_data", "candidate_labels.pkl"))
self.section_to_phenotypes = pd.read_pickle(os.path.join(assets_root_dir, "section_to_phenotypes.pkl"))

def process_dicoms(self,INPUT):
"""
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ pydicom==2.3.1
pytorch-lightning==2.2.0.post0
PyWavelets==1.4.1
wandb==0.16.3
transformers==4.57.0
transformers==4.57.0
matplotlib
torch
numpy<2.0.0
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from setuptools import setup

# Register the distribution with setuptools; both top-level packages
# (the model package and its helper package) are listed explicitly.
setup(packages=["echo_prime", "utils"], name="echo_prime")
21 changes: 12 additions & 9 deletions utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
import json
import numpy as np
Expand All @@ -9,9 +10,6 @@
import pydicom as dicom
import torch

with open("assets/per_section.json", encoding="utf-8") as f:
json_data = json.load(f)

_ybr_to_rgb_lut = None

COARSE_VIEWS=['A2C',
Expand Down Expand Up @@ -53,7 +51,7 @@
numerical_pattern = r'(\\d+(\\.\\d+)?)' # Escaped backslashes for integers or floats
string_pattern = r'\\b\\w+.*?(?=\\.)'

def initialize_language(lang='en'):
def initialize_language(lang='en', assets_root_dir: str = None):
"""
Initialize language-specific variables.

Expand All @@ -62,13 +60,14 @@ def initialize_language(lang='en'):
"""
all_phrases, t_list = None, None
global phrases_per_section_list, phrases_per_section_list_org, regex_per_section

assets_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets"

if lang == 'en':
phrases_file = "assets/all_phr.json"
phrases_file = os.path.join(assets_dir, "all_phr.json")
elif lang == 'it':
phrases_file = "assets/all_phr_it.json"
phrases_file = os.path.join(assets_dir, "all_phr_it.json")
elif lang == 'bs':
phrases_file = "assets/all_phr_bs.json"
phrases_file = os.path.join(assets_dir, "all_phr_bs.json")
# To add a language, add a branch for its code that points at the translated phrases file:
#elif lang == 'your_language_code':
# phrases_file = os.path.join(assets_dir, "all_phr_{your_language_code}.json")
Expand Down Expand Up @@ -106,11 +105,15 @@ def extract_section(report, section_header):
else:
return "Section not found."

def extract_features(report: str) -> list:
def extract_features(report: str, assets_root_dir: str = None) -> list:
"""
Returns a list of 21 different features
see json_data for a list of features
"""
assets_dir = os.path.join(assets_root_dir, "assets") if assets_root_dir is not None else "assets"
with open(os.path.join(assets_dir, "per_section.json"), encoding="utf-8") as f:
json_data = json.load(f)

sorted_features=['impella',
'ejection_fraction',
'pacemaker',
Expand Down