Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
# The promptsource templates spuriously fail without this
args: ["--unsafe"]
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 23.7.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.3.1
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.0.278'
rev: 'v0.15.11'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: v2.4.2
hooks:
- id: codespell
# The promptsource templates spuriously get flagged without this
Expand Down
1 change: 1 addition & 0 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions for extracting the hidden states of a model."""

import os
from collections import defaultdict
from dataclasses import InitVar, dataclass, replace
Expand Down Expand Up @@ -202,7 +203,7 @@
a_id = tokenizer.encode(" " + choice, add_special_tokens=False)

# the Llama tokenizer splits off leading spaces
if tokenizer.decode(a_id[0]).strip() == "":

Check failure on line 206 in elk/extraction/extraction.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Cannot access member "strip" for type "list[str]"   Member "strip" is unknown (reportGeneralTypeIssues)

Check failure on line 206 in elk/extraction/extraction.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

Cannot access member "strip" for type "list[str]"   Member "strip" is unknown (reportGeneralTypeIssues)
a_id_without_space = tokenizer.encode(
choice, add_special_tokens=False
)
Expand Down
2 changes: 1 addition & 1 deletion elk/metrics/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def evaluate_preds(
Returns:
dict: A dictionary containing the accuracy, AUROC, and ECE.
"""
(n, v) = y_logits.shape
n, v = y_logits.shape
assert y_true.shape == (n,)

if ensembling == "full":
Expand Down
6 changes: 3 additions & 3 deletions elk/plotting/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def render(
y=dataset_data["auroc_estimate"],
mode="lines",
name=ensemble,
showlegend=False
if dataset_name != unique_datasets[0]
else True,
showlegend=(
False if dataset_name != unique_datasets[0] else True
),
line=dict(color=color_map[ensemble]),
),
row=row,
Expand Down
8 changes: 5 additions & 3 deletions elk/promptsource/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,11 @@ def _escape_pipe(cls, example):
# Replaces any occurrences of the "|||" separator in the example; the
# original separator text is restored after splitting
protected_example = {
key: value.replace("|||", cls.pipe_protector)
if isinstance(value, str)
else value
key: (
value.replace("|||", cls.pipe_protector)
if isinstance(value, str)
else value
)
for key, value in example.items()
}
return protected_example
Expand Down
3 changes: 1 addition & 2 deletions elk/training/platt_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ class PlattMixin(ABC):
scale: nn.Parameter

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
...
def __call__(self, *args: Any, **kwds: Any) -> Any: ...

def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100):
"""Fit the scale and bias terms to data with LBFGS.
Expand Down
2 changes: 1 addition & 1 deletion elk/training/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def train_supervised(
leace = None

for train_data in data.values():
(n, v, d) = train_data.hiddens.shape
n, v, d = train_data.hiddens.shape
train_h = rearrange(train_data.hiddens, "n v d -> (n v) d")

if erase_paraphrases and v > 1:
Expand Down
2 changes: 1 addition & 1 deletion elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def apply_to_layer(
val_dict = self.prepare_data(device, layer, "val")

first_train_data, *rest = train_dict.values()
(_, v, d) = first_train_data.hiddens.shape
_, v, d = first_train_data.hiddens.shape
if not all(other_data.hiddens.shape[-1] == d for other_data in rest):
raise ValueError("All datasets must have the same hidden state size")

Expand Down
Loading