This project implements and compares two deep learning models (LSTM and BERT) for hierarchical text classification on the WOS-11967 dataset. The task involves classifying research paper abstracts into two hierarchical levels: Level 1 (7 classes) and Level 2 (5 classes).
- Model Performance: BERT achieves superior validation performance (92.58% L1, 86.21% L2) compared to LSTM (87.26% L1, 73.11% L2)
- Key Advantage: BERT generalizes better, with markedly smaller gaps between training and validation accuracy (roughly 6 points on L1 and 7 on L2, versus 12.5 and 25 points for LSTM)
- Training Efficiency: BERT converges in 5 epochs vs LSTM's 40 epochs
The WOS-11967 dataset contains:
- 11,967 research paper abstracts
- Level 1 labels: 7 classes (0-6)
- Level 2 labels: 5 classes (0-4)
- Training set: 7,658 samples (64%)
- Validation set: 1,915 samples (16%)
- Test set: 2,394 samples (20%)
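The 64/16/20 split above can be reproduced with a shuffled partition. Below is a minimal sketch (the function name and seed are illustrative; the project's actual splitting code may differ):

```python
import random

def split_wos(indices, seed=42):
    """Shuffle indices and partition into 64% train / 16% val / 20% test."""
    idx = list(indices)
    random.Random(seed).shuffle(idx)  # fixed seed for a reproducible split
    n = len(idx)
    n_train = int(n * 0.64)
    n_val = int(n * 0.16)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]
```

Splitting indices rather than the samples themselves keeps the abstracts and both label levels aligned.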
The custom LSTM cell implements the standard LSTM equations:
- Input gate: i = σ(W_i·x + U_i·h + b_i)
- Forget gate: f = σ(W_f·x + U_f·h + b_f)
- Candidate cell state: g = tanh(W_c·x + U_c·h + b_c)
- Output gate: o = σ(W_o·x + U_o·h + b_o)
- Final states: c_t = f ⊙ c_prev + i ⊙ g, h_t = o ⊙ tanh(c_t)

where ⊙ denotes element-wise multiplication.
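The cell equations above can be sketched in PyTorch as follows (a minimal sketch; the class and weight names are illustrative, and the project's actual implementation may differ in layout):

```python
import torch
import torch.nn as nn

class CustomLSTMCell(nn.Module):
    """Standard LSTM cell: one linear map per gate over the concatenated [x, h]."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)  # input gate
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)  # forget gate
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)  # candidate cell
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)  # output gate

    def forward(self, x, state):
        h_prev, c_prev = state
        z = torch.cat([x, h_prev], dim=-1)
        i = torch.sigmoid(self.W_i(z))   # input gate
        f = torch.sigmoid(self.W_f(z))   # forget gate
        g = torch.tanh(self.W_c(z))      # candidate cell state
        o = torch.sigmoid(self.W_o(z))   # output gate
        c_t = f * c_prev + i * g         # new cell state (element-wise)
        h_t = o * torch.tanh(c_t)        # new hidden state
        return h_t, c_t
```

Each `nn.Linear` folds the x- and h-weight matrices and the bias of one gate into a single affine map.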
- Uses BERT-base-uncased (12 layers, 12 attention heads, 768 hidden size)
- Pooled output is passed through dropout and two linear layers for L1 and L2 classification
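The dual-head design described above can be sketched as follows (a minimal sketch; the class name and head names are illustrative, and `encoder` stands in for the Hugging Face BERT-base-uncased model):

```python
import torch
import torch.nn as nn

class DualHeadClassifier(nn.Module):
    """Pooled encoder output -> dropout -> two linear heads (L1: 7 classes, L2: 5)."""
    def __init__(self, encoder, hidden=768, n_l1=7, n_l2=5, p=0.3):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(p)
        self.head_l1 = nn.Linear(hidden, n_l1)
        self.head_l2 = nn.Linear(hidden, n_l2)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(out.pooler_output)  # [CLS]-based pooled representation
        return self.head_l1(pooled), self.head_l2(pooled)
```

Both heads share the same pooled representation, so the encoder is trained jointly on the two levels.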
LSTM configuration:
- Architecture: Custom 2-layer LSTM built on the custom LSTM cell implementation
- Embeddings: GloVe 6B.100d pre-trained embeddings
- Hidden size: 128
- Vocabulary size: 10,000
- Sequence length: 256
- Dropout: 0.3
- Training epochs: 40
- Learning rate: 0.001
- Batch size: 32
- Loss weighting: α = 0.7 (L1), 1-α = 0.3 (L2)
BERT configuration:
- Base model: BERT-base-uncased
- Sequence length: 256
- Dropout: 0.3
- Training epochs: 5
- Learning rate: 2e-5
- Batch size: 32
- Loss weighting: α = 0.7 (L1), 1-α = 0.3 (L2)
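Both models use the same α-weighted objective, a weighted sum of the two per-level cross-entropy losses. A minimal sketch (the function name is illustrative):

```python
import torch
import torch.nn.functional as F

def hierarchical_loss(logits_l1, logits_l2, y1, y2, alpha=0.7):
    """Weighted sum of the two cross-entropy terms: alpha * L1 + (1 - alpha) * L2."""
    return (alpha * F.cross_entropy(logits_l1, y1)
            + (1 - alpha) * F.cross_entropy(logits_l2, y2))
```

With α = 0.7 the coarse Level 1 labels dominate the gradient, while Level 2 still contributes 30% of the signal.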
| Metric | LSTM | BERT |
|---|---|---|
| Training Loss | 0.0239 | 0.1181 |
| L1 Training Accuracy | 0.9978 | 0.9851 |
| L2 Training Accuracy | 0.9832 | 0.9360 |
| L1 Validation Accuracy | 0.8726 | 0.9258 |
| L2 Validation Accuracy | 0.7311 | 0.8621 |
LSTM training progression:
- Initial Performance (Epoch 1): Loss: 1.7342, Train L1: 25.44%, Train L2: 21.55%, Val L1: 19.74%, Val L2: 21.83%
- Mid Training (Epoch 20): Loss: 0.0961, Train L1: 99.31%, Train L2: 92.50%, Val L1: 87.05%, Val L2: 71.07%
- Final Performance (Epoch 40): Loss: 0.0239, Train L1: 99.78%, Train L2: 98.32%, Val L1: 87.26%, Val L2: 73.11%
BERT training progression:
- Initial Performance (Epoch 1): Loss: 0.9968, Train L1: 75.58%, Train L2: 37.69%, Val L1: 90.39%, Val L2: 68.51%
- Final Performance (Epoch 5): Loss: 0.1181, Train L1: 98.51%, Train L2: 93.60%, Val L1: 92.58%, Val L2: 86.21%
Training accuracy progression for both L1 and L2 labels across epochs for LSTM and BERT models.
Training loss curves comparing LSTM's gradual decrease over 40 epochs vs BERT's rapid convergence in 5 epochs.
The attention visualizations provide insights into what the BERT model focuses on when making predictions.
Each attention visualization includes:
- Heatmap: Full attention matrix showing token-to-token attention weights
- Top 10 Tokens: The most attended tokens by the [CLS] token, indicating key information used for classification
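Extracting the most-attended tokens from the [CLS] row can be sketched as follows (a minimal sketch; `attentions` is assumed to be the per-layer tuple returned by a Hugging Face BERT called with `output_attentions=True`, each entry shaped `[batch, heads, seq, seq]`):

```python
import torch

def top_cls_tokens(attentions, tokens, layer=-1, k=10):
    """Average the chosen layer's heads and return the k tokens most attended
    by [CLS] (position 0), as (token, weight) pairs."""
    att = attentions[layer][0].mean(dim=0)  # average over heads -> [seq, seq]
    cls_row = att[0]                        # attention paid by [CLS] to each token
    weights, idx = torch.topk(cls_row, k)
    return [(tokens[i], weights[j].item()) for j, i in enumerate(idx.tolist())]
```

Averaging over heads gives one score per token; inspecting individual heads instead would show more specialized patterns.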
Correct Predictions:
Text: Medical abstract about allergic rhinitis and asthma treatment
- True Labels: L1: 5 (Medical Science), L2: 1
- Predicted Labels: L1: 5 (Medical Science), L2: 1
Early layer attention patterns showing initial token interactions
Intermediate semantic understanding
Final layer attention before classification, showing the most refined understanding
Incorrect Predictions:
Text: Abstract about a Zn-based metal-organic framework synthesis
- True Labels: L1: 5 (Medical Science), L2: 3
- Predicted Labels: L1: 6 (Biochemistry), L2: 2
- Early layers (layer 0, nearest the input) focus on local token relationships, capturing initial token-level interactions
- Middle layers (e.g., layer 8) represent intermediate semantic understanding and contextual relationships
- Final layers (layer 11, nearest the output) show refined attention to the tokens most relevant for classification
- Correct predictions show more focused attention patterns, whereas incorrect predictions tend to show scattered attention
- Different layers capture different levels of abstraction, from local patterns to global semantic understanding