diff --git a/.gitignore b/.gitignore index 6e0a7cce..56b311f0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ __pycache__ venv *-checkpoint.ipynb requirements.txt +python/circuitsvis/.vscode # Node node_modules/ diff --git a/python/circuitsvis/activations.py b/python/circuitsvis/activations.py index 1f634c21..8abd4277 100644 --- a/python/circuitsvis/activations.py +++ b/python/circuitsvis/activations.py @@ -13,6 +13,9 @@ def text_neuron_activations( second_dimension_name: Optional[str] = "Neuron", first_dimension_labels: Optional[List[str]] = None, second_dimension_labels: Optional[List[str]] = None, + first_dimension_default: Optional[int] = 0, + second_dimension_default: Optional[int] = 0, + show_selectors: Optional[bool] = True, ) -> RenderedHTML: """Show activations (colored by intensity) for each token in a text or set of texts. @@ -54,4 +57,7 @@ def text_neuron_activations( secondDimensionName=second_dimension_name, firstDimensionLabels=first_dimension_labels, secondDimensionLabels=second_dimension_labels, + firstDimensionDefault=first_dimension_default, + secondDimensionDefault=second_dimension_default, + showSelectors=show_selectors, ) diff --git a/python/circuitsvis/tests/snapshots/snap_test_activations.py b/python/circuitsvis/tests/snapshots/snap_test_activations.py index aaaf8501..8f9fad15 100644 --- a/python/circuitsvis/tests/snapshots/snap_test_activations.py +++ b/python/circuitsvis/tests/snapshots/snap_test_activations.py @@ -7,22 +7,10 @@ snapshots = Snapshot() -snapshots['TestTextNeuronActivations.test_multi_matches_snapshot 1'] = '''
- ''' +snapshots[ + "TestTextNeuronActivations.test_multi_matches_snapshot 1" +] = """\n """ -snapshots['TestTextNeuronActivations.test_single_matches_snapshot 1'] = ''' - ''' +snapshots[ + "TestTextNeuronActivations.test_single_matches_snapshot 1" +] = """\n """ diff --git a/python/circuitsvis/topk_samples.py b/python/circuitsvis/topk_samples.py index 92d772ff..45058671 100644 --- a/python/circuitsvis/topk_samples.py +++ b/python/circuitsvis/topk_samples.py @@ -1,12 +1,13 @@ """Activations visualizations""" -from typing import List, Optional +import torch +from typing import List, Optional, Union from circuitsvis.utils.render import RenderedHTML, render def topk_samples( tokens: List[List[List[List[str]]]], - activations: List[List[List[List[float]]]], + activations: Union[List[List[List[List[float]]]], torch.Tensor], zeroth_dimension_name: Optional[str] = "Layer", first_dimension_name: Optional[str] = "Neuron", zeroth_dimension_labels: Optional[List[str]] = None, diff --git a/react/src/activations/TextNeuronActivations.tsx b/react/src/activations/TextNeuronActivations.tsx index 7540b0f0..e645462f 100644 --- a/react/src/activations/TextNeuronActivations.tsx +++ b/react/src/activations/TextNeuronActivations.tsx @@ -36,7 +36,10 @@ export function TextNeuronActivations({ firstDimensionName = "Layer", secondDimensionName = "Neuron", firstDimensionLabels, - secondDimensionLabels + secondDimensionLabels, + firstDimensionDefault = 0, + secondDimensionDefault = 0, + showSelectors = true }: TextNeuronActivationsProps) { // If there is only one sample (i.e. if tokens is an array of strings), cast tokens and activations to an array with // a single element @@ -68,8 +71,10 @@ export function TextNeuronActivations({ const [sampleNumbers, setSampleNumbers] = useState