diff --git a/.gitignore b/.gitignore index 6e0a7cce..56b311f0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ __pycache__ venv *-checkpoint.ipynb requirements.txt +python/circuitsvis/.vscode # Node node_modules/ diff --git a/python/circuitsvis/activations.py b/python/circuitsvis/activations.py index 1f634c21..8abd4277 100644 --- a/python/circuitsvis/activations.py +++ b/python/circuitsvis/activations.py @@ -13,6 +13,9 @@ def text_neuron_activations( second_dimension_name: Optional[str] = "Neuron", first_dimension_labels: Optional[List[str]] = None, second_dimension_labels: Optional[List[str]] = None, + first_dimension_default: Optional[int] = 0, + second_dimension_default: Optional[int] = 0, + show_selectors: Optional[bool] = True, ) -> RenderedHTML: """Show activations (colored by intensity) for each token in a text or set of texts. @@ -54,4 +57,7 @@ def text_neuron_activations( secondDimensionName=second_dimension_name, firstDimensionLabels=first_dimension_labels, secondDimensionLabels=second_dimension_labels, + firstDimensionDefault=first_dimension_default, + secondDimensionDefault=second_dimension_default, + showSelectors=show_selectors, ) diff --git a/python/circuitsvis/tests/snapshots/snap_test_activations.py b/python/circuitsvis/tests/snapshots/snap_test_activations.py index aaaf8501..8f9fad15 100644 --- a/python/circuitsvis/tests/snapshots/snap_test_activations.py +++ b/python/circuitsvis/tests/snapshots/snap_test_activations.py @@ -7,22 +7,10 @@ snapshots = Snapshot() -snapshots['TestTextNeuronActivations.test_multi_matches_snapshot 1'] = '''
- ''' +snapshots[ + "TestTextNeuronActivations.test_multi_matches_snapshot 1" +] = """
\n """ -snapshots['TestTextNeuronActivations.test_single_matches_snapshot 1'] = '''
- ''' +snapshots[ + "TestTextNeuronActivations.test_single_matches_snapshot 1" +] = """
\n """ diff --git a/python/circuitsvis/topk_samples.py b/python/circuitsvis/topk_samples.py index 92d772ff..45058671 100644 --- a/python/circuitsvis/topk_samples.py +++ b/python/circuitsvis/topk_samples.py @@ -1,12 +1,13 @@ """Activations visualizations""" -from typing import List, Optional +import torch +from typing import List, Optional, Union from circuitsvis.utils.render import RenderedHTML, render def topk_samples( tokens: List[List[List[List[str]]]], - activations: List[List[List[List[float]]]], + activations: Union[List[List[List[List[float]]]], torch.Tensor], zeroth_dimension_name: Optional[str] = "Layer", first_dimension_name: Optional[str] = "Neuron", zeroth_dimension_labels: Optional[List[str]] = None, diff --git a/react/src/activations/TextNeuronActivations.tsx b/react/src/activations/TextNeuronActivations.tsx index 7540b0f0..e645462f 100644 --- a/react/src/activations/TextNeuronActivations.tsx +++ b/react/src/activations/TextNeuronActivations.tsx @@ -36,7 +36,10 @@ export function TextNeuronActivations({ firstDimensionName = "Layer", secondDimensionName = "Neuron", firstDimensionLabels, - secondDimensionLabels + secondDimensionLabels, + firstDimensionDefault = 0, + secondDimensionDefault = 0, + showSelectors = true }: TextNeuronActivationsProps) { // If there is only one sample (i.e. if tokens is an array of strings), cast tokens and activations to an array with // a single element @@ -68,8 +71,10 @@ export function TextNeuronActivations({ const [sampleNumbers, setSampleNumbers] = useState([ ...Array(samplesPerPage).keys() ]); - const [layerNumber, setLayerNumber] = useState(0); - const [neuronNumber, setNeuronNumber] = useState(0); + const [layerNumber, setLayerNumber] = useState(firstDimensionDefault); + const [neuronNumber, setNeuronNumber] = useState( + secondDimensionDefault + ); useEffect(() => { // When the user changes the samplesPerPage, update the sampleNumbers @@ -96,77 +101,79 @@ export function TextNeuronActivations({ return ( - - - - - - - - - - - - - - - {/* Only show the sample selector if there is more than one sample */} - {numberOfSamples > 1 && ( + {showSelectors && ( + + - - )} - - - {/* Only show the sample per page selector if there is more than one sample */} - {numberOfSamples > 1 && ( - - )} - - + {/* Only show the sample selector if there is more than one sample */} + {numberOfSamples > 1 && ( + + + + + + + )} + + + {/* Only show the sample per page selector if there is more than one sample */} + {numberOfSamples > 1 && ( + + + + + + + )} + + + )}