diff --git a/docs/api/processors.rst b/docs/api/processors.rst index 25de2fece..ccaa4bfd4 100644 --- a/docs/api/processors.rst +++ b/docs/api/processors.rst @@ -33,6 +33,7 @@ Available Processors - ``ImageProcessor``: For image data (e.g., chest X-rays) - ``TextProcessor``: For text/clinical notes data +- ``TupleTimeTextProcessor``: For text paired with temporal information (time-aware text) - ``AudioProcessor``: For audio signal data - ``SignalProcessor``: For general signal data (e.g., EEG, ECG) - ``TimeseriesProcessor``: For time-series data @@ -266,6 +267,7 @@ Common string keys for automatic processor selection: - ``"multilabel"``: For multi-label classification - ``"regression"``: For regression targets - ``"text"``: For text data +- ``"tuple_time_text"``: For text paired with temporal information - ``"image"``: For image data - ``"audio"``: For audio data - ``"signal"``: For signal data @@ -456,6 +458,7 @@ API Reference processors/pyhealth.processors.RegressionLabelProcessor processors/pyhealth.processors.ImageProcessor processors/pyhealth.processors.TextProcessor + processors/pyhealth.processors.TupleTimeTextProcessor processors/pyhealth.processors.AudioProcessor processors/pyhealth.processors.SignalProcessor processors/pyhealth.processors.TimeseriesProcessor diff --git a/docs/api/processors/pyhealth.processors.TupleTimeTextProcessor.rst b/docs/api/processors/pyhealth.processors.TupleTimeTextProcessor.rst new file mode 100644 index 000000000..b83dc0349 --- /dev/null +++ b/docs/api/processors/pyhealth.processors.TupleTimeTextProcessor.rst @@ -0,0 +1,67 @@ +pyhealth.processors.TupleTimeTextProcessor +============================================ + +Processor for tuple time-based text data with temporal information. + +.. autoclass:: pyhealth.processors.TupleTimeTextProcessor + :members: + :undoc-members: + :show-inheritance: + +**Overview** + +``TupleTimeTextProcessor`` handles clinical text paired with temporal information (time differences), enabling automatic modality routing in multimodal fusion pipelines. + +**Input/Output** + +- **Input:** ``Tuple[List[str], List[float]]`` (texts, time differences) +- **Output:** ``Tuple[List[str], torch.Tensor, str]`` (texts, 1D time tensor, modality tag) + +**Use Case** + +The ``type_tag`` parameter enables automatic modality routing without hardcoding feature names in multimodal pipelines: + +- ``type_tag="note"`` routes to text encoder +- ``type_tag="image"`` routes to vision encoder +- ``type_tag="ehr"`` routes to EHR encoder + +**Example Usage** + +.. code-block:: python + + from pyhealth.processors import TupleTimeTextProcessor + + # Initialize processor with modality tag + processor = TupleTimeTextProcessor(type_tag="clinical_note") + + # Patient notes with time differences (hours since admission) + texts = [ + "Patient admitted with chest pain.", + "Follow-up: symptoms improved.", + "Discharge: stable condition." + ] + time_diffs = [0.0, 24.0, 72.0] + + # Process tuple + processed_texts, time_tensor, modality_tag = processor.process((texts, time_diffs)) + + print(time_tensor) # tensor([0., 24., 72.]) + print(modality_tag) # "clinical_note" + +**Multimodal Fusion** + +Use different type tags for automatic routing in multimodal models: + +.. code-block:: python + + # Different modalities with different type tags + note_processor = TupleTimeTextProcessor(type_tag="note") + ehr_processor = TupleTimeTextProcessor(type_tag="ehr") + + # Process different data types + note_texts, note_times, note_tag = note_processor.process((notes, note_time_diffs)) + ehr_texts, ehr_times, ehr_tag = ehr_processor.process((events, event_time_diffs)) + + # Tags enable automatic routing to appropriate encoders + # note_tag="note" -> TextEmbedding encoder + # ehr_tag="ehr" -> EHR encoder diff --git a/examples/text_embedding_tutorial.ipynb b/examples/text_embedding_tutorial.ipynb index 75f405632..112d230ed 100644 --- a/examples/text_embedding_tutorial.ipynb +++ b/examples/text_embedding_tutorial.ipynb @@ -13,7 +13,8 @@ "- Initialize TextEmbedding with Bio_ClinicalBERT\n", "- Demonstrate 128-token chunking for long clinical notes\n", "- Show different pooling modes (none, cls, mean)\n", - "- Verify mask output and backward compatibility" + "- Verify mask output and backward compatibility\n", + "- Use TupleTimeTextProcessor for temporal text data" ] }, { @@ -309,6 +310,89 @@ "print(f\"Outputs identical in eval mode: {is_equal}\")" ] }, + { + "cell_type": "markdown", + "id": "tuple_time_processor", + "metadata": {}, + "source": [ + "## 10. TupleTimeTextProcessor for Temporal Data\n", + "\n", + "The `TupleTimeTextProcessor` handles clinical text paired with temporal information (time differences). This is useful for multimodal models that need to automatically route different modality types." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "tuple_time_basic", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.processors import TupleTimeTextProcessor\n", + "\n", + "# Initialize processor with type tag for modality routing\n", + "processor = TupleTimeTextProcessor(type_tag=\"clinical_note\")\n", + "\n", + "# Clinical notes with time differences (e.g., hours since admission)\n", + "texts = [\n", + " \"Patient admitted with chest pain.\",\n", + " \"Follow-up: symptoms improved.\",\n", + " \"Discharge: stable condition.\"\n", + "]\n", + "time_diffs = [0.0, 24.0, 72.0] # hours\n", + "\n", + "# Process tuple\n", + "processed_texts, time_tensor, modality_tag = processor.process((texts, time_diffs))\n", + "\n", + "print(f\"Texts: {processed_texts}\")\n", + "print(f\"Time tensor: {time_tensor}\")\n", + "print(f\"Time tensor shape: {time_tensor.shape}\")\n", + "print(f\"Modality tag: '{modality_tag}'\")\n", + "print(f\"\\nProcessor repr: {repr(processor)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "tuple_time_multimodal", + "metadata": {}, + "source": [ + "### Multimodal Fusion Example\n", + "\n", + "The `type_tag` enables automatic routing in multimodal pipelines without hardcoding feature names:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "tuple_time_fusion", + "metadata": {}, + "outputs": [], + "source": [ + "# Different modality types with different processors\n", + "note_processor = TupleTimeTextProcessor(type_tag=\"note\")\n", + "ehr_processor = TupleTimeTextProcessor(type_tag=\"ehr\")\n", + "observation_processor = TupleTimeTextProcessor(type_tag=\"observation\")\n", + "\n", + "# Process different data types\n", + "notes = [\"Admission note\", \"Progress note\"]\n", + "note_times = [0.0, 24.0]\n", + "_, note_tensor, note_tag = note_processor.process((notes, note_times))\n", + "\n", + "ehr_events = [\"Lab ordered\", \"Medication given\"]\n", + "ehr_times = [2.0, 6.0]\n", + "_, ehr_tensor, ehr_tag = ehr_processor.process((ehr_events, ehr_times))\n", + "\n", + "# Tags can be used for automatic routing in models\n", + "print(f\"Note modality tag: '{note_tag}'\")\n", + "print(f\"EHR modality tag: '{ehr_tag}'\")\n", + "print(f\"\\nNote times: {note_tensor}\")\n", + "print(f\"EHR times: {ehr_tensor}\")\n", + "\n", + "# Can combine tensors for temporal modeling\n", + "combined_times = torch.cat([note_tensor, ehr_tensor])\n", + "print(f\"\\nCombined times: {combined_times}\")\n", + "print(f\"Combined shape: {combined_times.shape}\")" + ] + }, { "cell_type": "markdown", "id": "summary", @@ -316,7 +400,7 @@ "source": [ "## Summary\n", "\n", - "The `TextEmbedding` module provides:\n", + "The `TextEmbedding` module and `TupleTimeTextProcessor` provide:\n", "\n", "| Feature | Description |\n", "|---------|-------------|\n", @@ -324,7 +408,9 @@ "| **Pooling** | none/cls/mean modes for different use cases |\n", "| **Mask** | Boolean tensor compatible with TransformerLayer |\n", "| **Guardrails** | max_chunks prevents OOM on long texts |\n", - "| **Compatibility** | return_mask=False for legacy code |" + "| **Compatibility** | return_mask=False for legacy code |\n", + "| **Temporal Processing** | TupleTimeTextProcessor for time-aware text data |\n", + "| **Modality Routing** | Type tags enable automatic multimodal fusion |" ] } ], diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index 15512c2d7..dde39f17d 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -46,6 +46,7 @@ def get_processor(name: str): from .timeseries_processor import TimeseriesProcessor from .audio_processor import AudioProcessor from .ignore_processor import IgnoreProcessor +from .tuple_time_text_processor import TupleTimeTextProcessor # Expose public API __all__ = [ @@ -64,4 +65,5 @@ def get_processor(name: str): "TextProcessor", "TimeseriesProcessor", "AudioProcessor", + "TupleTimeTextProcessor", ] diff --git a/pyhealth/processors/tuple_time_text_processor.py b/pyhealth/processors/tuple_time_text_processor.py new file mode 100644 index 000000000..ecc59c030 --- /dev/null +++ b/pyhealth/processors/tuple_time_text_processor.py @@ -0,0 +1,106 @@ +"""Processor for tuple time-based text data with temporal information. + +This processor handles clinical notes or text entries paired with temporal +information (time differences), preparing them for multimodal fusion where +different modality types need to be distinguished automatically. + +Input/Output: + Input: Tuple[List[str], List[float]] + - List[str]: Clinical text entries (e.g., discharge notes, progress notes) + - List[float]: Time differences between entries (in any time unit) + + Output: Tuple[List[str], torch.Tensor, str] + - List[str]: Same text entries (unmodified) + - torch.Tensor: 1D float tensor of time differences + - str: Type tag for automatic modality routing (default: "note") + +Use Case: + This processor enables automatic modality bucketing in multimodal pipelines. + The type_tag allows downstream models to automatically route different feature + types to appropriate encoders without hardcoding feature names: + + - type_tag="note" routes to text encoder + - type_tag="image" routes to vision encoder + - type_tag="ehr" routes to EHR encoder + + This design eliminates the need to manually map task schema feature_keys to + specific model components. + +Example: + >>> from pyhealth.processors import TupleTimeTextProcessor + >>> processor = TupleTimeTextProcessor(type_tag="note") + >>> + >>> # Clinical notes with time differences + >>> texts = [ + ... "Patient admitted with chest pain.", + ... "Follow-up: symptoms improved.", + ... "Discharge: stable condition." + ... ] + >>> time_diffs = [0.0, 2.5, 5.0] # hours since admission + >>> + >>> result = processor.process((texts, time_diffs)) + >>> texts_out, time_tensor, tag = result + >>> print(f"Texts: {texts_out}") + >>> print(f"Time tensor: {time_tensor}") + >>> print(f"Type tag: {tag}") + +Args: + type_tag (str): Modality identifier for automatic routing in multimodal + models. Common values: "note", "image", "ehr", "signal". + Default: "note" +""" + +from typing import Any, List, Tuple +import torch +from .base_processor import FeatureProcessor +from . import register_processor + + +@register_processor("tuple_time_text") +class TupleTimeTextProcessor(FeatureProcessor): + """Processes (text, time_diff) tuples for multimodal temporal fusion. + + Converts paired text and temporal data into a format suitable for models + that need to distinguish between different modality types automatically. + """ + + def __init__(self, type_tag: str = "note"): + """Initialize the processor. + + Args: + type_tag: Modality identifier for automatic routing. Default: "note" + """ + super().__init__() + self.type_tag = type_tag + + def process(self, value: Tuple[List[str], List[float]]) -> Tuple[List[str], torch.Tensor, str]: + """Process a tuple of texts and time differences. + + Args: + value: Tuple containing: + - List[str]: Text entries (clinical notes, observations, etc.) + - List[float]: Time differences corresponding to each text entry + + Returns: + Tuple containing: + - List[str]: Original text entries (unmodified) + - torch.Tensor: 1D float tensor of time differences [shape: (N,)] + - str: Type tag for modality routing + + Example: + >>> processor = TupleTimeTextProcessor(type_tag="clinical_note") + >>> texts = ["Note 1", "Note 2"] + >>> times = [0.0, 24.0] # hours + >>> result = processor.process((texts, times)) + >>> print(result[1]) # tensor([0., 24.]) + """ + texts, time_diffs = value + time_tensor = torch.tensor(time_diffs, dtype=torch.float32) + return texts, time_tensor, self.type_tag + + def size(self): + """Return the size of the processor vocabulary (not applicable for this processor).""" + return None + + def __repr__(self): + return f"TupleTimeTextProcessor(type_tag='{self.type_tag}')" diff --git a/tests/test_tuple_time_text_processor.py b/tests/test_tuple_time_text_processor.py new file mode 100644 index 000000000..7a8dfd5b5 --- /dev/null +++ b/tests/test_tuple_time_text_processor.py @@ -0,0 +1,31 @@ +"""Unit test for TupleTimeTextProcessor.""" +import torch +from pyhealth.processors import TupleTimeTextProcessor + + +def test_tuple_time_text_processor(): + """Test TupleTimeTextProcessor basic functionality.""" + # Test default initialization + processor = TupleTimeTextProcessor() + assert processor.type_tag == "note" + + # Test custom type tag + processor = TupleTimeTextProcessor(type_tag="clinical_note") + assert processor.type_tag == "clinical_note" + + # Test processing + texts = ["Patient admitted", "Follow-up visit", "Discharge"] + time_diffs = [0.0, 24.0, 72.0] + result_texts, time_tensor, tag = processor.process((texts, time_diffs)) + + # Verify outputs + assert result_texts == texts + assert isinstance(time_tensor, torch.Tensor) + assert time_tensor.shape == (3,) + assert torch.equal(time_tensor, torch.tensor([0.0, 24.0, 72.0])) + assert tag == "clinical_note" + + # Test registration + from pyhealth.processors import get_processor + ProcessorClass = get_processor("tuple_time_text") + assert ProcessorClass is TupleTimeTextProcessor