-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_stsb.py
More file actions
125 lines (103 loc) · 4.47 KB
/
evaluate_stsb.py
File metadata and controls
125 lines (103 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
import numpy as np
from datasets import load_dataset
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity
import time
import argparse
from tqdm import tqdm
from byt5 import byt5_batch_encode, get_device
def load_stsb_dataset():
    """Download and return the English subset of the multilingual STS-B benchmark.

    Returns a ``datasets.DatasetDict`` with the usual train/dev/test splits.
    """
    print("Loading STS-B dataset...")
    return load_dataset("stsb_multi_mt", "en")
def compute_embeddings(sentences, model_name, batch_size=32, device=None):
    """Encode ``sentences`` into a 2-D numpy array of embeddings.

    The input is processed in batches of ``batch_size`` to bound peak memory;
    per-batch tensors are concatenated on the compute device and only moved
    to the CPU at the end.

    Args:
        sentences: sequence of strings to embed.
        model_name: ByT5 checkpoint identifier forwarded to the encoder.
        batch_size: number of sentences encoded per call.
        device: torch device to run on; auto-detected when ``None``.

    Returns:
        numpy array of shape (len(sentences), embedding_dim).
    """
    target_device = get_device() if device is None else device
    print(f"Computing embeddings using {model_name} on {target_device}...")
    # NOTE(review): model and tokenizer stay None for every batch, so
    # byt5_batch_encode always receives None here — confirm it caches the
    # loaded model internally, otherwise it is reloaded once per batch
    # despite the apparent intent to load it a single time.
    tokenizer = None
    model = None
    batch_outputs = []
    total = len(sentences)
    for start in tqdm(range(0, total, batch_size), desc="Computing embeddings", unit="batch"):
        chunk = sentences[start:start + batch_size]
        encoded = byt5_batch_encode(
            chunk,
            model_name=model_name,
            model=model,
            tokenizer=tokenizer,
            device=target_device,
        )
        batch_outputs.append(encoded)
    # Stack all batches into one tensor, then hand back a numpy array.
    return torch.cat(batch_outputs, dim=0).cpu().numpy()
def evaluate_sts(model_name="google/byt5-small", batch_size=32, device=None):
    """Evaluate embedding quality on the STS-B test split.

    Embeds both sentences of every test pair, scores each pair by cosine
    similarity, and reports the Spearman rank correlation against the gold
    human similarity judgments.

    Args:
        model_name: ByT5 checkpoint identifier passed to the encoder.
        batch_size: batch size for embedding computation.
        device: torch device to run on; auto-detected downstream when ``None``.

    Returns:
        The Spearman correlation coefficient (float).
    """
    # Load dataset and pull out the test split.
    dataset = load_stsb_dataset()
    test_data = dataset["test"]
    sentences1 = test_data["sentence1"]
    sentences2 = test_data["sentence2"]
    # Gold scores are on a 0-5 scale; normalize to [0, 1].
    gold_scores = np.array(test_data["similarity_score"]) / 5.0
    # Compute embeddings for both sides of every pair, timing the whole pass.
    start_time = time.time()
    embeddings1 = compute_embeddings(sentences1, model_name, batch_size, device)
    embeddings2 = compute_embeddings(sentences2, model_name, batch_size, device)
    embedding_time = time.time() - start_time
    print("Computing cosine similarities...")
    # L2-normalize rows so a plain dot product equals cosine similarity.
    # Assumes no embedding has zero norm (would yield NaN) — TODO confirm.
    print("Normalizing embeddings...")
    embeddings1 = embeddings1 / np.linalg.norm(embeddings1, axis=1, keepdims=True)
    embeddings2 = embeddings2 / np.linalg.norm(embeddings2, axis=1, keepdims=True)
    print("Calculating similarities...")
    # Row-wise dot product of the unit vectors, vectorized with einsum
    # instead of the previous per-pair Python loop (same values, one C pass).
    similarities = np.einsum("ij,ij->i", embeddings1, embeddings2)
    # Spearman rank correlation between predicted and gold similarities.
    correlation, p_value = spearmanr(similarities, gold_scores)
    # Print results
    print("\nEvaluation Results:")
    print(f"Model: {model_name}")
    print(f"Spearman correlation: {correlation:.4f}")
    print(f"p-value: {p_value:.4f}")
    print(f"Time to compute embeddings: {embedding_time:.2f} seconds")
    # Show a handful of example predictions for eyeballing.
    print("\nExample predictions:")
    for i in range(min(5, len(sentences1))):
        print(f"Sentence 1: {sentences1[i]}")
        print(f"Sentence 2: {sentences2[i]}")
        print(f"Predicted similarity: {similarities[i]:.4f}")
        print(f"Gold similarity: {gold_scores[i]:.4f}")
        print("-" * 50)
    return correlation
def main():
    """Command-line entry point: parse arguments and run the STS-B evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate ByT5 model on STS-B dataset")
    parser.add_argument("--model", type=str, default="google/byt5-small",
                        help="ByT5 model name (default: google/byt5-small)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for embedding computation (default: 32)")
    parser.add_argument("--device", type=str, default=None,
                        help="Device to use (cuda, mps, cpu). If not specified, will use the best available device.")
    args = parser.parse_args()
    # Resolve the device: honor an explicit flag, otherwise auto-detect.
    if args.device is None:
        device = get_device()
    else:
        device = torch.device(args.device)
    print(f"Evaluating {args.model} on STS-B dataset using {device}...")
    correlation = evaluate_sts(model_name=args.model, batch_size=args.batch_size, device=device)
    print(f"\nFinal Spearman correlation: {correlation:.4f}")


if __name__ == "__main__":
    main()