diff --git a/src/cnlpt/modeling/models/projection_model.py b/src/cnlpt/modeling/models/projection_model.py index 5a134040..8838ca43 100644 --- a/src/cnlpt/modeling/models/projection_model.py +++ b/src/cnlpt/modeling/models/projection_model.py @@ -293,7 +293,7 @@ def forward( outputs = self.encoder(input_ids, **kwargs) - batch_size, seq_len = input_ids.shape + batch_size, seq_len, _ = outputs.last_hidden_state.shape logits = []