33import tempfile
44from io import BytesIO
55from pathlib import Path
6- from typing import AsyncIterator , Literal
6+ from typing import AsyncIterator , Callable , Literal
77
88from adrf .views import sync_to_async
99from pydicom import Dataset
2121logger = logging .getLogger ("__name__" )
2222
2323
24+ class WadoFetcher :
25+ """Helper class for WADO fetch operations to avoid code duplication."""
26+
27+ def __init__ (self , source_server : DicomServer , query : dict [str , str ]):
28+ self .source_server = source_server
29+ self .query_ds = QueryDataset .from_dict (query )
30+ self .operator = DicomOperator (source_server )
31+
32+ async def fetch (
33+ self , level : Literal ["STUDY" , "SERIES" , "IMAGE" ], callback : Callable [[Dataset ], None ]
34+ ) -> asyncio .Task :
35+ """Create and return a fetch task based on the level."""
36+ if level == "STUDY" :
37+ return asyncio .create_task (
38+ sync_to_async (self .operator .fetch_study , thread_sensitive = False )(
39+ patient_id = self .query_ds .PatientID ,
40+ study_uid = self .query_ds .StudyInstanceUID ,
41+ callback = callback ,
42+ )
43+ )
44+ elif level == "SERIES" :
45+ return asyncio .create_task (
46+ sync_to_async (self .operator .fetch_series , thread_sensitive = False )(
47+ patient_id = self .query_ds .PatientID ,
48+ study_uid = self .query_ds .StudyInstanceUID ,
49+ series_uid = self .query_ds .SeriesInstanceUID ,
50+ callback = callback ,
51+ )
52+ )
53+ elif level == "IMAGE" :
54+ assert self .query_ds .has ("SeriesInstanceUID" )
55+ return asyncio .create_task (
56+ sync_to_async (self .operator .fetch_image , thread_sensitive = False )(
57+ patient_id = self .query_ds .PatientID ,
58+ study_uid = self .query_ds .StudyInstanceUID ,
59+ series_uid = self .query_ds .SeriesInstanceUID ,
60+ image_uid = self .query_ds .SOPInstanceUID ,
61+ callback = callback ,
62+ )
63+ )
64+ else :
65+ raise ValueError (f"Invalid WADO-RS level: { level } ." )
66+
67+ @staticmethod
68+ async def execute_fetch (fetch_task : asyncio .Task ) -> None :
69+ """Execute the fetch task and handle common errors."""
70+ try :
71+ await asyncio .wait ([fetch_task ])
72+ except RetriableDicomError as err :
73+ raise ServiceUnavailableApiError (str (err ))
74+ except DicomError as err :
75+ raise BadGatewayApiError (str (err ))
76+
77+
2478async def wado_retrieve (
2579 source_server : DicomServer ,
2680 query : dict [str , str ],
@@ -33,50 +87,17 @@ async def wado_retrieve(
3387
3488 Mainly converts a sync callback (by the operator) to an async iterator.
3589 """
36- operator = DicomOperator (source_server )
37- query_ds = QueryDataset .from_dict (query )
38-
3990 loop = asyncio .get_running_loop ()
4091 queue = asyncio .Queue [Dataset ]()
41-
4292 dicom_manipulator = DicomManipulator ()
93+ wado_fetcher = WadoFetcher (source_server , query )
4394
4495 def callback (ds : Dataset ) -> None :
4596 dicom_manipulator .manipulate (ds , pseudonym , trial_protocol_id , trial_protocol_name )
46-
4797 loop .call_soon_threadsafe (queue .put_nowait , ds )
4898
4999 try :
50- if level == "STUDY" :
51- fetch_task = asyncio .create_task (
52- sync_to_async (operator .fetch_study , thread_sensitive = False )(
53- patient_id = query_ds .PatientID ,
54- study_uid = query_ds .StudyInstanceUID ,
55- callback = callback ,
56- )
57- )
58- elif level == "SERIES" :
59- fetch_task = asyncio .create_task (
60- sync_to_async (operator .fetch_series , thread_sensitive = False )(
61- patient_id = query_ds .PatientID ,
62- study_uid = query_ds .StudyInstanceUID ,
63- series_uid = query_ds .SeriesInstanceUID ,
64- callback = callback ,
65- )
66- )
67- elif level == "IMAGE" :
68- assert query_ds .has ("SeriesInstanceUID" )
69- fetch_task = asyncio .create_task (
70- sync_to_async (operator .fetch_image , thread_sensitive = False )(
71- patient_id = query_ds .PatientID ,
72- study_uid = query_ds .StudyInstanceUID ,
73- series_uid = query_ds .SeriesInstanceUID ,
74- image_uid = query_ds .SOPInstanceUID ,
75- callback = callback ,
76- )
77- )
78- else :
79- raise ValueError (f"Invalid WADO-RS level: { level } ." )
100+ fetch_task = await wado_fetcher .fetch (level , callback )
80101
81102 while True :
82103 queue_get_task = asyncio .create_task (queue .get ())
@@ -112,46 +133,15 @@ async def wado_retrieve_nifti(
112133 Returns the generated files (NIfTI and JSON) as tuples in the format
113134 (filename, file content).
114135 """
115- operator = DicomOperator (source_server )
116- query_ds = QueryDataset .from_dict (query )
117136 dicom_images : list [Dataset ] = []
137+ wado_fetcher = WadoFetcher (source_server , query )
118138
119139 def callback (ds : Dataset ) -> None :
120140 dicom_images .append (ds )
121141
122142 try :
123- if level == "SERIES" :
124- fetch_task = asyncio .create_task (
125- sync_to_async (operator .fetch_series , thread_sensitive = False )(
126- patient_id = query_ds .PatientID ,
127- study_uid = query_ds .StudyInstanceUID ,
128- series_uid = query_ds .SeriesInstanceUID ,
129- callback = callback ,
130- )
131- )
132- elif level == "STUDY" :
133- fetch_task = asyncio .create_task (
134- sync_to_async (operator .fetch_study , thread_sensitive = False )(
135- patient_id = query_ds .PatientID ,
136- study_uid = query_ds .StudyInstanceUID ,
137- callback = callback ,
138- )
139- )
140- elif level == "IMAGE" :
141- assert query_ds .has ("SeriesInstanceUID" )
142- fetch_task = asyncio .create_task (
143- sync_to_async (operator .fetch_image , thread_sensitive = False )(
144- patient_id = query_ds .PatientID ,
145- study_uid = query_ds .StudyInstanceUID ,
146- series_uid = query_ds .SeriesInstanceUID ,
147- image_uid = query_ds .SOPInstanceUID ,
148- callback = callback ,
149- )
150- )
151- else :
152- raise ValueError (f"Invalid NIFTI-WADO-RS level: { level } ." )
153-
154- await asyncio .wait ([fetch_task ])
143+ fetch_task = await wado_fetcher .fetch (level , callback )
144+ await WadoFetcher .execute_fetch (fetch_task )
155145
156146 with tempfile .TemporaryDirectory () as temp_dir :
157147 temp_path = Path (temp_dir )
0 commit comments