Hierarchical Text Classification: LSTM vs BERT

This project implements and compares two deep learning models (LSTM and BERT) for hierarchical text classification on the WOS-11967 dataset. The task involves classifying research paper abstracts into two hierarchical levels: Level 1 (7 classes) and Level 2 (5 classes).

Quick Summary

  • Model Performance: BERT achieves superior validation performance (92.58% L1, 86.21% L2) compared to LSTM (87.26% L1, 73.11% L2)
  • Key Advantage: BERT generalizes better, with a train-to-validation gap of about 5.9% on L1 accuracy (98.51% vs 92.58%) compared to LSTM's 12.5% (99.78% vs 87.26%)
  • Training Efficiency: BERT converges in 5 epochs vs LSTM's 40 epochs

Dataset

The WOS-11967 dataset contains:

  • 11,967 research paper abstracts
  • Level 1 labels: 7 classes (0-6)
  • Level 2 labels: 5 classes (0-4)

Data Split

  • Training set: 7,658 samples (64%)
  • Validation set: 1,915 samples (16%)
  • Test set: 2,394 samples (20%)
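A minimal sketch of such a 64/16/20 split over sample indices (the seed and rounding here are assumptions, so exact counts may differ by one or two from those above):

```python
import random

# Indices stand in for the 11,967 abstracts.
random.seed(0)  # illustrative seed, not taken from the repository
idx = list(range(11_967))
random.shuffle(idx)

n_test = round(0.20 * len(idx))   # ~20% held-out test set
n_val = round(0.16 * len(idx))    # ~16% validation set
test = idx[:n_test]
val = idx[n_test:n_test + n_val]
train = idx[n_test + n_val:]      # remaining ~64% for training
print(len(train), len(val), len(test))
```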

Model Architecture Details

LSTM Cell Implementation

The custom LSTM cell implements the standard LSTM equations, where ⊙ denotes element-wise multiplication and each gate has its own input weights W, recurrent weights U, and bias b:

  • Input gate: i = σ(Wi·x + Ui·h + bi)
  • Forget gate: f = σ(Wf·x + Uf·h + bf)
  • Candidate cell state: g = tanh(Wg·x + Ug·h + bg)
  • Output gate: o = σ(Wo·x + Uo·h + bo)
  • State updates: ct = f ⊙ c_prev + i ⊙ g, ht = o ⊙ tanh(ct)
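One step of these cell equations can be sketched in NumPy as follows (weight names, initialization scale, and the random inputs are illustrative, not taken from the repository):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_cell(x, h_prev, c_prev, params):
    """One LSTM cell step: gates, candidate state, element-wise state update."""
    Wi, Ui, bi, Wf, Uf, bf, Wg, Ug, bg, Wo, Uo, bo = params
    i = sigmoid(x @ Wi + h_prev @ Ui + bi)   # input gate
    f = sigmoid(x @ Wf + h_prev @ Uf + bf)   # forget gate
    g = np.tanh(x @ Wg + h_prev @ Ug + bg)   # candidate cell state
    o = sigmoid(x @ Wo + h_prev @ Uo + bo)   # output gate
    c = f * c_prev + i * g                   # element-wise state update
    h = o * np.tanh(c)
    return h, c

rng = np.random.default_rng(0)
input_size, hidden = 100, 128                # GloVe 100-d inputs, hidden size 128
params = [rng.standard_normal(s) * 0.1 for s in
          [(input_size, hidden), (hidden, hidden), (hidden,)] * 4]
x = rng.standard_normal((32, input_size))    # batch of 32
h = c = np.zeros((32, hidden))
h, c = lstm_cell(x, h, c, params)
print(h.shape)  # (32, 128)
```

The full model stacks two such cells and unrolls them over the 256-token sequence.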

BERT Architecture

  • Uses BERT-base-uncased (12 layers, 12 attention heads, 768 hidden size)
  • Pooled output is passed through dropout and two linear layers for L1 and L2 classification
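The dual-head classification step on top of the pooled output can be sketched in NumPy as follows (weight shapes and the inverted-dropout detail are assumptions; in the real model this sits on top of BERT and is trained end to end):

```python
import numpy as np

rng = np.random.default_rng(1)
hidden, n_l1, n_l2 = 768, 7, 5          # BERT-base hidden size; 7 L1 / 5 L2 classes

# Hypothetical weights for the two linear classification heads.
W1, b1 = rng.standard_normal((hidden, n_l1)) * 0.02, np.zeros(n_l1)
W2, b2 = rng.standard_normal((hidden, n_l2)) * 0.02, np.zeros(n_l2)

def classify(pooled, p_drop=0.3, train=False):
    """Apply dropout, then the two linear heads, to BERT's pooled output."""
    if train:
        mask = rng.random(pooled.shape) >= p_drop
        pooled = pooled * mask / (1.0 - p_drop)  # inverted dropout
    return pooled @ W1 + b1, pooled @ W2 + b2    # L1 logits, L2 logits

pooled = rng.standard_normal((32, hidden))       # stand-in for BERT pooled output
l1_logits, l2_logits = classify(pooled)
print(l1_logits.shape, l2_logits.shape)          # (32, 7) (32, 5)
```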

Models

1. LSTM Model

  • Architecture: Custom 2-layer LSTM with custom LSTM cell implementation
  • Embeddings: GloVe 6B.100d pre-trained embeddings
  • Hidden size: 128
  • Vocabulary size: 10,000
  • Sequence length: 256
  • Dropout: 0.3
  • Training epochs: 40
  • Learning rate: 0.001
  • Batch size: 32
  • Loss weighting: α = 0.7 (L1), 1-α = 0.3 (L2)

2. BERT Model

  • Base model: BERT-base-uncased
  • Sequence length: 256
  • Dropout: 0.3
  • Training epochs: 5
  • Learning rate: 2e-5
  • Batch size: 32
  • Loss weighting: α = 0.7 (L1), 1-α = 0.3 (L2)
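Both models combine the two per-level losses with the same α = 0.7 weighting. A sketch of that joint loss, assuming standard softmax cross-entropy (the random logits and labels are placeholders):

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean softmax cross-entropy, computed stably via the log-sum-exp trick."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

alpha = 0.7
rng = np.random.default_rng(2)
l1_logits = rng.standard_normal((32, 7))   # 7 L1 classes
l2_logits = rng.standard_normal((32, 5))   # 5 L2 classes
l1_labels = rng.integers(0, 7, 32)
l2_labels = rng.integers(0, 5, 32)

# Weighted joint loss: L1 dominates with alpha = 0.7.
loss = alpha * cross_entropy(l1_logits, l1_labels) \
       + (1 - alpha) * cross_entropy(l2_logits, l2_labels)
print(round(loss, 4))
```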

Results

Final Performance Metrics

Metric                   LSTM     BERT
Training loss            0.0239   0.1181
L1 training accuracy     0.9978   0.9851
L2 training accuracy     0.9832   0.9360
L1 validation accuracy   0.8726   0.9258
L2 validation accuracy   0.7311   0.8621

Training Progress

LSTM Training (40 epochs)

  • Initial Performance (Epoch 1): Loss: 1.7342, Train L1: 25.44%, Train L2: 21.55%, Val L1: 19.74%, Val L2: 21.83%
  • Mid Training (Epoch 20): Loss: 0.0961, Train L1: 99.31%, Train L2: 92.50%, Val L1: 87.05%, Val L2: 71.07%
  • Final Performance (Epoch 40): Loss: 0.0239, Train L1: 99.78%, Train L2: 98.32%, Val L1: 87.26%, Val L2: 73.11%

BERT Training (5 epochs)

  • Initial Performance (Epoch 1): Loss: 0.9968, Train L1: 75.58%, Train L2: 37.69%, Val L1: 90.39%, Val L2: 68.51%
  • Final Performance (Epoch 5): Loss: 0.1181, Train L1: 98.51%, Train L2: 93.60%, Val L1: 92.58%, Val L2: 86.21%

Plots

Training Metrics

Training Accuracy Comparison Training accuracy progression for both L1 and L2 labels across epochs for LSTM and BERT models.

Validation Accuracy Comparison Validation accuracy trends for both L1 and L2 labels across epochs for LSTM and BERT models.

Loss Comparison Training loss curves comparing LSTM's gradual decrease over 40 epochs vs BERT's rapid convergence in 5 epochs.

BERT Attention Analysis

The attention visualizations provide insights into what the BERT model focuses on when making predictions.

Each attention visualization includes:

  • Heatmap: Full attention matrix showing token-to-token attention weights
  • Top 10 Tokens: The most attended tokens by the [CLS] token, indicating key information used for classification
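Ranking tokens by how much attention the [CLS] token pays them can be sketched as follows; the token list and attention weights here are synthetic stand-ins, not values from the repository:

```python
import numpy as np

# A layer's attention matrix (seq_len x seq_len, e.g. averaged over heads):
# row i holds the attention distribution from token i over all tokens.
tokens = ["[CLS]", "allergic", "rhinitis", "and", "asthma", "treatment", "[SEP]"]
rng = np.random.default_rng(3)
attn = rng.random((len(tokens), len(tokens)))
attn /= attn.sum(axis=1, keepdims=True)   # rows sum to 1, like softmax output

cls_weights = attn[0]                     # attention from [CLS] to every token
top = sorted(zip(tokens, cls_weights), key=lambda t: -t[1])[:10]
for tok, w in top:
    print(f"{tok:12s} {w:.3f}")
```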

Correct Predictions:

Correct Prediction Example

Text: Medical abstract about allergic rhinitis and asthma treatment

  • True Labels: L1: 5 (Medical Science), L2: 1
  • Predicted Labels: L1: 5 (Medical Science), L2: 1

Layer 0 Attention - Correct Prediction Early layer attention patterns showing initial token interactions

Layer 8 Attention - Correct Prediction Intermediate semantic understanding

Layer 11 Attention - Correct Prediction Final layer attention before classification, showing the most refined understanding

Incorrect Predictions:

Incorrect Prediction Example

Text: Abstract about a Zn-based metal-organic framework synthesis

  • True Labels: L1: 5 (Medical Science), L2: 3
  • Predicted Labels: L1: 6 (Biochemistry), L2: 2

Layer 0 Attention - Incorrect Prediction Input layer

Layer 8 Attention - Incorrect Prediction Middle layer

Layer 11 Attention - Incorrect Prediction Final layer, showing scattered attention

Attention Analysis:

  • Early layers (Layer 0, the input layer) focus on local token relationships, capturing initial token-level interactions
  • Middle layers (Layer 8) represent intermediate semantic understanding and contextual relationships
  • Final layers (Layer 11, the output layer) show refined attention to key classification-relevant tokens
  • Correct predictions show more focused attention patterns, while incorrect ones show scattered attention
  • Different layers capture different levels of abstraction, from local patterns to global semantic understanding
