
# Sentence-byt5

This project explores the use of ByT5 models for sentence embeddings, with a focus on evaluating their performance on semantic textual similarity tasks.

## Overview

ByT5 is a byte-level transformer model that operates directly on raw bytes rather than tokens. This project provides utilities for:

1. Generating sentence embeddings using ByT5 models
2. Evaluating these embeddings on the STS-B (Semantic Textual Similarity Benchmark) dataset
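Step 1 can be sketched with the Hugging Face `transformers` library. The sketch below assumes mean pooling over the encoder's output states; the pooling strategy actually implemented in `byt5.py` may differ, and `embed_sentences` is an illustrative name, not a function from this repository.

```python
# Sketch: sentence embeddings via mean pooling over ByT5 encoder states.
# Assumption: mean pooling; the strategy used in byt5.py may differ.
import torch
from transformers import AutoTokenizer, T5EncoderModel

def embed_sentences(sentences, model_name="google/byt5-small", device="cpu"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = T5EncoderModel.from_pretrained(model_name).to(device).eval()
    batch = tokenizer(sentences, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state    # (batch, seq_len, dim)
    mask = batch.attention_mask.unsqueeze(-1)        # (batch, seq_len, 1)
    # Average encoder states over non-padding byte positions only.
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)
```

Because ByT5 operates on raw bytes, no subword vocabulary is involved: each character of the input maps to one or more byte positions in `seq_len`.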

## Installation

Install the required dependencies:

```shell
# PyTorch wheels built for CUDA 11.8; drop the --index-url flag for CPU-only or macOS installs.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install transformers datasets sentence-transformers scikit-learn numpy pandas
```

Note: `transformers` and `datasets` are installed in a separate command because they are not hosted on the PyTorch wheel index, so pinning them to that `--index-url` would fail to resolve.

## Usage

### Evaluating on STS-B

To evaluate a ByT5 model on the STS-B dataset, run:

```shell
python evaluate_stsb.py --model google/byt5-small --batch_size 32 --device cuda
```

Parameters:

- `--model`: the ByT5 model to use (default: `google/byt5-small`)
- `--batch_size`: batch size for embedding computation (default: `32`)
- `--device`: device to use for computation (`cuda`, `mps`, or `cpu`). If omitted, the script automatically picks the best available device: CUDA if available, then MPS, otherwise CPU.
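The fallback order described above amounts to a few lines of PyTorch. This is an illustrative helper, not necessarily how `evaluate_stsb.py` writes it:

```python
import torch

# Illustrative device selection mirroring the CUDA -> MPS -> CPU fallback;
# the actual logic in evaluate_stsb.py may be structured differently.
def pick_device(requested=None):
    if requested is not None:
        return torch.device(requested)          # honor an explicit --device flag
    if torch.cuda.is_available():
        return torch.device("cuda")             # NVIDIA GPUs
    if torch.backends.mps.is_available():
        return torch.device("mps")              # Apple Silicon GPUs
    return torch.device("cpu")
```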

## Available Models

You can use any of the following ByT5 models:

- `google/byt5-small` (300M parameters)
- `google/byt5-base` (580M parameters)
- `google/byt5-large` (1.2B parameters)
- `google/byt5-xl` (3.7B parameters)
- `google/byt5-xxl` (12.9B parameters)

Note that larger models require more memory and computation time.

## How It Works

The evaluation process:

1. Loads the STS-B dataset (English subset)
2. Computes embeddings for each sentence using the ByT5 encoder
3. Calculates cosine similarity between sentence pairs
4. Computes the Spearman correlation with human similarity scores
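Steps 3 and 4 reduce to a few lines of NumPy and SciPy. The embeddings below are toy stand-ins for real ByT5 outputs, used only to show the shape of the computation:

```python
import numpy as np
from scipy.stats import spearmanr  # SciPy is pulled in as a scikit-learn dependency

def pairwise_cosine(a, b):
    # Cosine similarity of each row in `a` against the matching row in `b`.
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return (a * b).sum(axis=1)

# Toy stand-ins for ByT5 embeddings of three sentence pairs.
emb_a = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
emb_b = np.array([[1.0, 0.0], [1.0, 0.0], [2.0, 2.0]])
gold = np.array([5.0, 0.0, 4.5])   # human similarity ratings (0-5 scale)

scores = pairwise_cosine(emb_a, emb_b)
rho, _ = spearmanr(scores, gold)   # the correlation reported for STS-B
```

Spearman correlation is used rather than Pearson because STS-B measures whether the model *ranks* pairs by similarity the same way humans do, not whether the raw cosine values match the 0-5 rating scale.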

## Project Structure

- `byt5.py`: core functions for generating embeddings using ByT5 models
- `evaluate_stsb.py`: script for evaluating ByT5 models on the STS-B dataset