YalaLab/strata

Strata: human-level information extraction from clinical reports with fine-tuned language models

Longchao Liu*, Long Lian*, Yiyan Hao, Aidan Pace, Elaine Kim, Nour Homsi, Yash Pershad, Liheng Lai, Thomas Gracie, Ashwin Kishtagari, Peter R Carroll, Alexander G Bick, Anobel Y Odisho, Maggie Chung, Adam Yala

Strata Usage

Strata is an open-source, low-code library designed to streamline the fine-tuning, evaluation, and deployment of large language models (LLMs) for extracting structured data from free-text clinical reports. It enables researchers to customize LLMs for clinical information extraction while reducing the need for extensive manual annotation. Strata supports local model hosting, which helps ensure privacy, reproducibility, and control over model versions. By simplifying customization and requiring minimal computational resources, Strata accelerates the development of AI tools for clinical research and makes them accessible to users with limited technical expertise. Strata builds on the Hugging Face Transformers library for LLM implementations and the Unsloth library for efficient fine-tuning.

Installation

Docker

Please ensure your Docker installation has GPU support (see the Docker installation guide). We recommend using rootless Docker for security.

docker build --build-arg http_proxy="${http_proxy}" --build-arg https_proxy="${https_proxy}" --build-arg no_proxy="${no_proxy}" -t strata:0.1 .
docker run -it --rm -e http_proxy="${http_proxy}" -e https_proxy="${https_proxy}" -e no_proxy="${no_proxy}" -v .:/root/strata --gpus all strata:0.1
# Then you can run train/inference commands

If you want to use Docker for development, we advise mounting your clone of this repo into the container (via the -v .:/root/strata flag above) so that changes made outside the container are reflected inside it.

Conda

Install Miniconda first, then run the following commands.

# Ensure you are in the clone of this repo
source ./setup_conda.sh
# Every time you log in, please switch to strata env
conda activate strata
# Then you can run train/inference commands

Development

For development, install the package in editable mode and set up pre-commit:

pip install -e . -v
pip install pre-commit
# Initialize pre-commit
pre-commit install

Problem Definition

Data Preprocessing

Strata supports data in CSV and JSON formats. We've created an illustrative example dataset in both formats. Note that all data is synthetic and intended solely to demonstrate how to use Strata.

CSV

Please preprocess your data into CSV format and include only the following columns: "Accession Number", "Report Text", and the ground truth label columns. Ground truth labels are required for training and evaluation; inference runs without them.

Accession Number | Report Text | Source_RB | Source_LB | Cancer_RB | Cancer_LB
985440           | ...         | 1         | 0         | 1         | 0
503958           | ...         | 1         | 1         | 0         | 1
894772           | ...         | 0         | 0         | 0         | 0

Besides the accession numbers and report texts, the example has ground truth labels for two tasks: tissue source and cancer diagnosis. The tissue source can be any combination of the left and right breast. In the Source_LB and Source_RB columns, the value is 1 if that tissue was examined and 0 if it was not. In the Cancer_LB and Cancer_RB columns, the value is 1 if that tissue was examined and cancer was definitively diagnosed, and 0 otherwise.

Additional synthetic example data is available in train_0.8_0.1_0.1_0.csv.

JSON

As with CSV, JSON data should contain "Accession Number", "Report Text", and ground truth values. Here is the JSON equivalent of the example above; a tissue source that was not examined is omitted from the cancer diagnosis:

[
    {
        "Accession Number": 985440,
        "Report Text": ...,
        "Cancer_Diagnosis": "{\"Right Breast\": 1}"
    },
    {
        "Accession Number": 503958,
        "Report Text": ...,
        "Cancer_Diagnosis": "{\"Right Breast\": 0, \"Left Breast\": 1}"
    },
    {
        "Accession Number": 894772,
        "Report Text": ...,
        "Cancer_Diagnosis": "{}"
    }
]

Additional synthetic example data is available in train_0.8_0.1_0.1_0.json.
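The CSV and JSON examples encode the same labels. The sketch below converts one CSV-style row into the JSON Cancer_Diagnosis string, following the rule stated above (sides that were not examined are omitted). The function name is hypothetical and this is not part of Strata; the column-to-key mapping is inferred from the three example rows.

```python
import json


def csv_row_to_cancer_diagnosis(row):
    """Convert CSV label columns into a JSON Cancer_Diagnosis string (hypothetical helper).

    A side is included only if its Source_* column is 1; its value is then
    the matching Cancer_* column.
    """
    sides = [("Right Breast", "RB"), ("Left Breast", "LB")]
    diagnosis = {}
    for side_name, suffix in sides:
        if int(row[f"Source_{suffix}"]) == 1:
            diagnosis[side_name] = int(row[f"Cancer_{suffix}"])
    return json.dumps(diagnosis)
```

Applied to the three example rows, this reproduces the three Cancer_Diagnosis strings shown in the JSON example.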

Splits

To split data into training, validation, and test sets, use split_data.py, which can be run with the bash script run_split.sh.
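The example file name train_0.8_0.1_0.1_0 suggests an 80/10/10 split with seed 0, although that reading is an assumption. As an illustration only, and not the contents of split_data.py, a reproducible random split can be sketched as follows (split_rows is a hypothetical name):

```python
import random


def split_rows(rows, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle rows with a fixed seed and split them into train/val/test lists.

    Rows left over after the train and validation fractions go to the test set.
    """
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError("fractions must be non-negative and sum to at most 1")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # same seed -> same split every run
    n_train = round(len(shuffled) * train_frac)
    n_val = round(len(shuffled) * val_frac)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]
    return train, val, test
```

The rows could be, for example, dictionaries read with csv.DictReader. Using a dedicated random.Random instance keeps the split reproducible without touching the global random state.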

Prompt Creation

For the dataset above, we care about two tasks: the tissue source and the cancer diagnosis associated with each report text. Each task is defined by i) a question to the model, ii) the corresponding ground truth labels, which can span any number of columns, and iii) parse functions, which convert the ground truth labels to the intended LLM response and vice versa. The template file should be formatted as follows:

name # each question should have a unique identifier 
  question: # the text of the question itself 
  gt_columns: # the column names 
  parse_response: # the path to the parse function
  question_name: # question name used for formatting
  response_start: # optional, defaults to "My answer is:"
  preamble: # optional, defaults to "You are an experienced pathologist. Answer the question using the pathology report below. Base the answer on the report only. Do not add any additional information."

Each row in the data is mapped into a prompt with two roles, a user and an assistant. The preamble, the value of "Report Text", and the question field are inserted into the user's message:

{preamble}

{report}

{question}   

The assistant then responds to the user's question, and its response always begins with response_start. During fine-tuning, the response also includes the ground truth answer. During inference, the model is prompted to complete the response with the label it predicts.
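The prompt layout described above can be sketched as a list of chat messages. This is an illustration under assumptions: build_messages is a hypothetical name, the role/content dictionaries follow the common chat-message convention used by Hugging Face chat templates, and the single space between response_start and the answer is inferred from the "My answer is: right breast" example. How Strata applies the model's chat template is not shown here.

```python
DEFAULT_PREAMBLE = (
    "You are an experienced pathologist. Answer the question using the "
    "pathology report below. Base the answer on the report only. "
    "Do not add any additional information."
)
DEFAULT_RESPONSE_START = "My answer is:"


def build_messages(report, question, answer=None,
                   preamble=DEFAULT_PREAMBLE,
                   response_start=DEFAULT_RESPONSE_START):
    """Assemble a two-role chat for one report/question pair (hypothetical helper).

    With an answer (fine-tuning), the assistant turn is response_start plus
    the ground truth response. Without one (inference), the assistant turn
    is only response_start, which the model is expected to continue.
    """
    user_content = f"{preamble}\n\n{report}\n\n{question}"
    assistant_content = response_start
    if answer is not None:
        assistant_content = f"{response_start} {answer}"
    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
```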

Example CSV and JSON templates are provided in this repo; you can download them and modify them for your application.

Template and Parsing Script

The parse_response file should contain two functions: the template function gt_to_response and the parsing function parse_response. During training, the output of gt_to_response is added to the assistant's response after response_start. During inference, the generated prediction is passed to parse_response. An example, used for the source question in the CSV data, is available at parse_functions/example/source.py.

In the example, the template gt_to_response({"Source_RB": 1, "Source_LB": 0}) will output "right breast".

def gt_to_response(gt):
  """
  Args:
    gt (dict): a dictionary with (gt column, gt value) pairs from one row in the data  
  Returns:
    str: the intended response for the LLM
  """

In the example, the parsing script parse_response("My answer is: right breast", "My answer is:") will output {"Source_RB": 1, "Source_LB": 0}.

def parse_response(response, response_start):
  """
  Args:
    response (str): the raw LLM response in its entirety
    response_start (str): the prefix to the response, to be removed from response
  Returns:
    dict: a dictionary with (gt column, gt value) pairs from one row in the data
  """

For JSON data, the default parse function can be used to predict any field directly.
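A JSON-valued answer such as the Cancer_Diagnosis strings above can be parsed by removing response_start and decoding the rest. The sketch below illustrates that idea; it is not Strata's default parse function, and the name parse_json_response and the None-on-failure convention are assumptions.

```python
import json


def parse_json_response(response, response_start):
    """Strip response_start and decode the remainder as a JSON object.

    Returns None when the model output is not a valid JSON object, so the
    caller can treat it as a parse failure.
    """
    text = response
    if text.startswith(response_start):
        text = text[len(response_start):]
    try:
        parsed = json.loads(text.strip())
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None
```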

Experiments

The Unsloth library supports fine-tuning many LLMs, and Strata supports them as well. We have tested the codebase on the following models:

An example config shows how to set the data paths, the model architecture, and the training and inference settings.

Zero-Shot

In test mode, the library predicts the labels, evaluates the predictions against the ground truth values at args.data.val_set_data_path, and logs the overall accuracy and per-question accuracies. args.save_path defaults to outputs/ and args.exp_name defaults to example/llama3.1-8b, so the outputs will be saved at outputs/example/llama3.1-8b.

python -m strata.main test configs/example/llama3.1-8b.yaml --inference-mode zero-shot 

Four files will be saved at outputs/example/llama3.1-8b: inference.csv contains the predicted responses of the model, eval.csv contains the comparisons with the ground truth labels, and scores.csv and scores_per_column.csv contain the evaluation metrics per question and per column respectively.

If ground truth labels are not available, use inference mode instead. The library will then only generate predictions for the report texts at args.data.val_set_data_path.

python -m strata.main inference configs/example/llama3.1-8b.yaml --inference-mode zero-shot 

The predictions are saved in inference.csv at outputs/example/llama3.1-8b.

Fine-tuning

To fine-tune the model on the dataset at args.data.train_set_data_path, run the following command. You will be prompted to log in to WandB, which tracks the training loss. The model checkpoint will be saved at outputs/example/llama3.1-8b-ft.

python -m strata.main train configs/example/llama3.1-8b.yaml --opts exp_name example/llama3.1-8b-ft

Questions to be included in fine-tuning can be selected in args.data.questions.

Using the same config, the library will load the fine-tuned model checkpoint and evaluate the accuracy on args.data.val_set_data_path.

python -m strata.main test configs/example/llama3.1-8b.yaml --inference-mode fine-tuned --load-fine-tuned-model outputs/example/llama3.1-8b-ft --opts exp_name example/llama3.1-8b-ft

As in zero-shot, four files will be saved at outputs/example/llama3.1-8b-ft: inference.csv contains the predicted responses of the model, eval.csv contains the comparisons with the ground truth labels, and scores.csv and scores_per_column.csv contain the evaluation metrics per question and per column respectively.

Sweeps

You can loop through hyperparameters using a bash loop. Here is an example that sweeps through the learning rate:

for lr in 0.001 0.0001 0.00001
do
    python -m strata.main train configs/example/llama3.1-8b.yaml --opts trainer.learning_rate "$lr" exp_name "example/llama3.1-8b-ft-lr$lr"
done
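For sweeps over more than one hyperparameter, the same --opts override pattern can be generated from Python. This is a hypothetical helper, not part of Strata: it only builds the command lists shown in the bash loop, and any key other than trainer.learning_rate is an assumption about your config. Values are passed through str(), so 0.00001 becomes "1e-05".

```python
import itertools
import shlex


def sweep_commands(config, base_exp_name, grid):
    """Yield one `strata.main train` command per point in a hyperparameter grid.

    grid maps an --opts key (e.g. "trainer.learning_rate") to a list of values;
    each run gets its own exp_name so outputs do not overwrite each other.
    """
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        opts = []
        suffix = ""
        for key, value in zip(keys, values):
            opts += [key, str(value)]
            suffix += f"-{key.split('.')[-1]}{value}"
        yield ["python", "-m", "strata.main", "train", config,
               "--opts", *opts, "exp_name", base_exp_name + suffix]


if __name__ == "__main__":
    grid = {"trainer.learning_rate": [0.001, 0.0001, 0.00001]}
    for cmd in sweep_commands("configs/example/llama3.1-8b.yaml",
                              "example/llama3.1-8b-ft", grid):
        print(shlex.join(cmd))  # or subprocess.run(cmd, check=True) to launch
```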

Deployment

Using test set instead of validation set

Set use_test_set to True in the config (or pass use_test_set True via --opts). The library will then generate predictions for the report texts at args.data.test_set_data_path, and the saved CSV files will have the suffix _test_set.

Evaluation

By default, Strata calculates exact match, F1, precision, and recall scores. You can also add custom metrics; each metric function should follow this signature and docstring:

def metric(gt_column, pred_column, inference_results_df):
    """
    Args:
        gt_column (str): the column containing the ground truth labels 
        pred_column (str): the column containing the predicted labels
        inference_results_df (pandas.DataFrame): the DataFrame at inference.csv, which contains the predicted labels and the ground truth labels 
    Returns:
        float: the value of the metric for gt_column and pred_column
    """

Here is an example of a function that calculates mean squared error:

from sklearn.metrics import mean_squared_error

def MSE(gt_column, pred_column, inference_results_df):
    # Identify rows with ground truth labels
    keep_if_labeled = ~inference_results_df[gt_column].isna()
    # Filter for rows with labels
    gt = inference_results_df[gt_column][keep_if_labeled]
    pred = inference_results_df[pred_column][keep_if_labeled]
    # Calculate mean squared error using the scikit-learn library
    return mean_squared_error(gt, pred)

Custom metrics can be added to metrics.py, which also contains the default metrics for reference. The metrics specified in the config will be calculated for the predictions.

Inference

To predict labels for the test set when the ground truth is unknown, use inference mode together with use_test_set. The library will generate predictions for the report texts at args.data.test_set_data_path and save them under args.save_path.

python -m strata.main inference configs/example/llama3.1-8b.yaml --inference-mode fine-tuned --load-fine-tuned-model outputs/example/llama3.1-8b-ft --opts exp_name example/llama3.1-8b-ft use_test_set True 

Examples

CSV

Let's fine-tune on toy data using the CSV example config and template to improve the performance of the model.

First, benchmark zero-shot inference:

python -m strata.main test configs/example/llama3.1-8b.yaml --inference-mode zero-shot

This should yield the following results:

  Question  All Correct     F1  Precision  Recall
0   source        100.0  100.0      100.0   100.0
1   cancer         30.0   41.7       75.0    29.2
2  overall         30.0   70.8       87.5    64.6
Note: for 'All Correct', the 'overall' metric is exact match over all questions. For other columns, the metric is averaged over questions.

    Question Category  All Correct     F1  Precision  Recall
0  Source_LB   source        100.0  100.0      100.0   100.0
1  Source_RB   source        100.0  100.0      100.0   100.0
2  Cancer_LB   cancer         50.0   33.3       50.0    25.0
3  Cancer_RB   cancer         80.0   50.0      100.0    33.3

The model answers the source question perfectly, but its cancer answers are often wrong: without fine-tuning on our dataset, it fails to capture some nuances in the reports.
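The note under the tables distinguishes two aggregation rules, and the printed numbers are consistent with them: the overall F1 of 70.8 is the mean of 100.0 and 41.7, and the cancer F1 of 41.7 is the mean of its columns' 33.3 and 50.0, whereas the cancer All Correct of 30.0 is lower than either column's score because a row counts only if both columns are right. The toy sketch below illustrates those two rules; the row layout and function names are hypothetical, and it does not reproduce Strata's metrics.py.

```python
def all_correct(rows, columns):
    """Percent of rows in which every listed column's prediction matches its label.

    Each row maps a column name to a (ground_truth, prediction) pair.
    """
    hits = sum(all(row[c][0] == row[c][1] for c in columns) for row in rows)
    return 100.0 * hits / len(rows)


def average(scores):
    """Plain mean, as used for the non-'All Correct' overall columns."""
    return sum(scores) / len(scores)


# Toy data (hypothetical): two reports, two cancer columns each.
rows = [
    {"Cancer_LB": (1, 1), "Cancer_RB": (0, 0)},  # both columns correct
    {"Cancer_LB": (1, 0), "Cancer_RB": (1, 1)},  # Cancer_LB wrong
]
```

Here Cancer_LB scores 50.0 and Cancer_RB 100.0 on their own, but the row-wise exact match over both columns is 50.0 rather than their average of 75.0.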

Next, try fine-tuning the model:

python -m strata.main train configs/example/llama3.1-8b.yaml --opts exp_name example/llama3.1-8b-ft

Finally, evaluate the new checkpoint on the validation set; this should take a few minutes.

python -m strata.main test configs/example/llama3.1-8b.yaml --inference-mode fine-tuned --load-fine-tuned-model outputs/example/llama3.1-8b-ft --opts exp_name example/llama3.1-8b-ft

The results should now be:

  Question  All Correct     F1  Precision  Recall
0   source        100.0  100.0      100.0   100.0
1   cancer        100.0  100.0      100.0   100.0
2  overall        100.0  100.0      100.0   100.0
Note: for 'All Correct', the 'overall' metric is exact match over all questions. For other columns, the metric is averaged over questions.

    Question Category  All Correct     F1  Precision  Recall
0  Source_LB   source        100.0  100.0      100.0   100.0
1  Source_RB   source        100.0  100.0      100.0   100.0
2  Cancer_LB   cancer        100.0  100.0      100.0   100.0
3  Cancer_RB   cancer        100.0  100.0      100.0   100.0

The model now achieves perfect accuracy on this toy validation set, indicating that fine-tuning greatly improved its performance.

JSON

Let's fine-tune on toy data using the JSON example config and template to improve the performance of the model.

First, benchmark zero-shot:

python -m strata.main test configs/example/llama3.1-8b_json.yaml --inference-mode zero-shot

This should give a 0 percent exact match. Looking at outputs/example/llama3.1-8b-json/inference.csv, the model appears to always output both lateralities. You could develop a more fine-grained evaluation metric, as described in the Evaluation section, or fine-tune the model so that it follows the prompt better.
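As an example of a more fine-grained metric, the sketch below follows the metric signature from the Evaluation section and gives partial credit per laterality key instead of requiring an exact match. It is hypothetical: the name key_agreement, the scoring rule (keys present in both dicts with equal values, divided by the union of keys), and the assumption that both columns hold JSON object strings are all illustrative choices. It returns a fraction between 0 and 1; whether Strata expects fractions or percentages is not stated here.

```python
import json


def _load_dict(value):
    """Decode a JSON-object string; return None for missing or invalid values."""
    if not isinstance(value, str):
        return None  # e.g. NaN for unlabeled rows
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def key_agreement(gt_column, pred_column, inference_results_df):
    """Average per-row fraction of keys on which ground truth and prediction agree.

    Rows without a usable ground truth are skipped; unparseable predictions
    score 0. A key agrees only if it is in both dicts with the same value.
    """
    scores = []
    for gt_raw, pred_raw in zip(inference_results_df[gt_column],
                                inference_results_df[pred_column]):
        gt = _load_dict(gt_raw)
        if gt is None:
            continue
        pred = _load_dict(pred_raw)
        if pred is None:
            scores.append(0.0)
            continue
        keys = set(gt) | set(pred)
        if not keys:
            scores.append(1.0)  # both empty: nothing to disagree on
            continue
        agree = sum(1 for k in keys if k in gt and k in pred and gt[k] == pred[k])
        scores.append(agree / len(keys))
    return sum(scores) / len(scores) if scores else float("nan")
```

With this rule, a prediction of {"Right Breast": 1, "Left Breast": 0} against a ground truth of {"Right Breast": 1} scores 0.5 instead of 0, which exposes how close the zero-shot model is.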

Here's the command for fine-tuning:

python -m strata.main train configs/example/llama3.1-8b_json.yaml --opts exp_name example/llama3.1-8b-ft_json

Finally, test out the fine-tuned model:

python -m strata.main test configs/example/llama3.1-8b_json.yaml --inference-mode fine-tuned --load-fine-tuned-model outputs/example/llama3.1-8b-ft_json --opts exp_name example/llama3.1-8b-ft_json

The model now has perfect accuracy on the validation set, indicating that fine-tuning greatly improved the model's performance:

  Question  All Correct     F1  Precision  Recall
0   cancer        100.0  100.0      100.0   100.0
1  overall        100.0  100.0      100.0   100.0
Note: for 'All Correct', the 'overall' metric is exact match over all questions. For other columns, the metric is averaged over questions.

           Question Category  All Correct     F1  Precision  Recall
0  Cancer_Diagnosis   cancer        100.0  100.0      100.0   100.0
