Commit bb81dc2

Add baseline essential-gene classifier with main.py, requirements, and README
1 parent 00a6c93 commit bb81dc2

3 files changed

+266
-0
lines changed

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
# Essential Gene Classification from DNA Sequences

This project implements a baseline machine learning pipeline to classify bacterial genes as essential or non-essential using DNA sequence information from the **macwiatrak/bacbench-essential-genes-dna** dataset (Hugging Face Datasets).

## Project Overview

The notebook:
- Loads the BacBench essential genes dataset (train/validation/test splits).
- Cleans and simplifies the dataset by removing unused metadata columns.
- Encodes DNA sequences into integer representations using a custom nucleotide mapping.
- Extracts non-overlapping 4-mer (length-4 subsequence) count features.
- Trains a Logistic Regression classifier on the resulting feature vectors.
- Evaluates model performance using accuracy and F1 score on validation and test splits.

This serves as a simple, fast baseline for essential-gene prediction from raw DNA sequences.

## Dataset

The project uses the `macwiatrak/bacbench-essential-genes-dna` dataset loaded via `datasets.load_dataset`.
Each split (train, validation, test) originally contains, among others, the following fields:
- `dna_seq`: DNA sequence of the gene.
- `essential`: Label indicating whether a gene is essential (`"Yes"` or `"No"`).
- Several metadata columns (e.g., `genome_name`, `start`, `end`, `protein_id`, `strand`, `product`, `__index_level_0__`).

In this notebook, the unnecessary metadata columns are dropped, and only `dna_seq` and `essential` are retained for modeling.

## Preprocessing

Key preprocessing steps:

- **Label encoding**
  The `essential` field is converted from string to integer:
  - `"Yes"` → `1`
  - `"No"` → `0`

- **DNA character mapping**
  Each base in `dna_seq` is mapped to an integer to prioritize efficiency:
  - `A → 0`, `T → 1`, `C → 2`, `G → 3`
  - Ambiguous bases: `N → 4`, `K → 5`, `R → 6`, `S → 7`, `Y → 8`, `M → 9`, `W → 10`

- **Sequence encoding**
  A helper function converts each DNA string into a list of integers using the mapping above, discarding characters not present in the map.

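
The three preprocessing steps can be sketched in a few lines of plain Python. This is a minimal illustration under the mapping listed above, not the notebook's exact code; `DNA_MAP` and `encode_label` are hypothetical names chosen for this example.

```python
# Mapping values are the ones listed in the Preprocessing section.
# DNA_MAP and encode_label are hypothetical names for this sketch.
DNA_MAP = {
    "A": 0, "T": 1, "C": 2, "G": 3,
    "N": 4, "K": 5, "R": 6, "S": 7,
    "Y": 8, "M": 9, "W": 10,
}


def encode_label(label):
    """Map the string label "Yes"/"No" to 1/0."""
    return 1 if label == "Yes" else 0


def encode_sequence(seq, mapping=DNA_MAP):
    """Map each base to its integer code, skipping unmapped characters."""
    return [mapping[base] for base in seq if base in mapping]


labels = [encode_label(x) for x in ["Yes", "No", "Yes"]]
encoded = encode_sequence("ATCGNW")
with_unknown = encode_sequence("ATXG")  # "X" is not in the map and is dropped

print(labels)
print(encoded)
print(with_unknown)
```

Note that silently dropping a character shifts every later 4-mer window by one position, and that lowercase bases are also absent from the map, so it may be worth counting how many characters are discarded per sequence.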
## Feature Extraction

The feature representation is based on **non-overlapping 4-mers**:

- The number of possible symbols is `NUM_BASES = 11`.
- The total number of distinct 4-mers is `NUM_4MERS = 11^4 = 14641`.
- For each encoded sequence, the notebook:
  - Iterates with step size `STEP = 4` to form non-overlapping 4-mers.
  - Maps each 4-mer to a unique integer index using positional encoding:
    \[
    \text{kmer\_int} = b_0 \cdot 11^3 + b_1 \cdot 11^2 + b_2 \cdot 11 + b_3
    \]
  - Increments the corresponding position in a length-14641 count vector.

The resulting dense feature matrix is then converted to a SciPy CSR sparse matrix for memory efficiency.

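
As a worked example of the positional encoding, the sketch below computes the index of one 4-mer and counts the non-overlapping 4-mers of a short sequence. It uses a `Counter` instead of the length-14641 vector to keep the output readable; `kmer_index` and `non_overlapping_counts` are hypothetical names for this illustration.

```python
from collections import Counter

NUM_BASES = 11  # size of the integer alphabet described above
K = 4


def kmer_index(kmer):
    """Read the integer codes as the digits of a base-11 number."""
    index = 0
    for code in kmer:
        index = index * NUM_BASES + code
    return index


def non_overlapping_counts(encoded, step=K):
    """Count 4-mers starting at positions 0, 4, 8, ...; a trailing partial window is ignored."""
    counts = Counter()
    for i in range(0, len(encoded) - K + 1, step):
        counts[kmer_index(encoded[i:i + K])] += 1
    return counts


atcg = [0, 1, 2, 3]             # "ATCG" under A=0, T=1, C=2, G=3
encoded = atcg + atcg + [0, 1]  # "ATCGATCGAT"; the final "AT" is too short to count

print(kmer_index(atcg))         # 0*1331 + 1*121 + 2*11 + 3 = 146
print(non_overlapping_counts(encoded))
```

The largest possible index is `kmer_index([10, 10, 10, 10]) = 11^4 - 1 = 14640`, which is why a vector of length 14641 is sufficient.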
## Model

The classification model is a **Logistic Regression** from `sklearn.linear_model` with:

- `solver='saga'`
- `max_iter=2000`
- `n_jobs=-1` (requests all CPU cores; scikit-learn uses this setting to parallelize over classes, so it has little effect for a binary label)

Training is performed on the 4-mer count features of the train split.

## Evaluation

Model performance is evaluated on both validation and test splits using:

- **Accuracy** (`sklearn.metrics.accuracy_score`)
- **F1 Score** (`sklearn.metrics.f1_score`)

The notebook prints:

- Validation Accuracy
- Validation F1 Score
- Test Accuracy
- Test F1 Score

These metrics provide an initial benchmark for this simple 4-mer + Logistic Regression approach.

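
To show why both metrics are reported, the sketch below computes accuracy and the positive-class F1 score by hand on made-up labels. It is a pure-Python illustration rather than the scikit-learn implementation; the label vectors are invented for this example.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1_positive(y_true, y_pred):
    """F1 for class 1: 2*TP / (2*TP + FP + FN), defined as 0.0 when the denominator is 0."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# Invented labels: 9 non-essential genes and 1 essential gene.
y_true = [0] * 9 + [1]
always_negative = [0] * 10        # a classifier that never predicts "essential"
one_false_alarm = [0] * 8 + [1, 1]  # finds the essential gene, plus one false positive

print(accuracy(y_true, always_negative), f1_positive(y_true, always_negative))
print(accuracy(y_true, one_false_alarm), f1_positive(y_true, one_false_alarm))
```

Both predictors reach 0.9 accuracy, but only the second one has a non-zero F1, which is why accuracy alone can look good on imbalanced labels.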
## Requirements

Main Python dependencies:

- `pandas`
- `numpy`
- `scipy`
- `datasets` (Hugging Face Datasets)
- `scikit-learn`

Example installation (if running locally):
`pip install pandas numpy scipy datasets scikit-learn`

## How to Run

1. Open the notebook in Google Colab or your preferred environment.
2. Ensure all required packages are installed.
3. Run the cells in order:
   - Dataset loading and column filtering
   - Label encoding
   - DNA mapping and sequence encoding
   - 4-mer feature extraction
   - Model training
   - Evaluation on validation and test splits

## Possible Extensions

- Use overlapping k-mers or different k-mer sizes to capture more sequence context.
- Try more expressive models (e.g., tree-based methods, neural networks).
- Explore alternative encodings (e.g., one-hot, embeddings, or biologically informed encodings).
- Add cross-validation and hyperparameter tuning for more robust performance estimates.

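
The first extension can be illustrated on a toy string. The sketch below compares overlapping windows (step 1) with the non-overlapping windows (step 4) used by the baseline; `kmers` is a hypothetical helper written for this example.

```python
def kmers(seq, k=4, step=1):
    """Return every length-k window of seq, advancing by `step` positions."""
    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, step)]


seq = "ATCGATCG"
overlapping = kmers(seq, step=1)      # windows start at 0, 1, 2, 3, 4
non_overlapping = kmers(seq, step=4)  # windows start at 0 and 4 only

print(overlapping)
print(non_overlapping)
```

With step 1 the windows `TCGA`, `CGAT`, and `GATC` appear as well, so the feature vector no longer depends on where the first window happens to start.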
## Issues with Current Gene Classifier

1. **Class Imbalance**
   - Essential genes (`1`) are much rarer than non-essential genes (`0`).
   - Logistic Regression tends to predict the majority class, lowering F1 score on validation.

2. **Simple Features**
   - Using **non-overlapping 4-mer counts** loses many sequence patterns.
   - Linear combinations of k-mer counts may not capture complex dependencies between nucleotides.

3. **Non-Overlapping k-mers**
   - Step size of 4 skips many overlapping patterns in the DNA sequence.
   - Important motifs or codon patterns might be missed.

4. **Normalization**
   - Raw 4-mer counts vary with sequence length.
   - Longer sequences dominate the feature vectors, potentially biasing the classifier.

5. **Linear Model Limitations**
   - Logistic Regression is a linear classifier.
   - It cannot capture non-linear interactions between k-mers that may be biologically relevant.

6. **Potential Data Leakage**
   - Some sequences in train/test splits may be very similar or overlapping.
   - This can inflate test accuracy artificially, which may explain the high test F1 compared to validation.

7. **Limited Biological Context**
   - Only nucleotide sequences are considered.
   - Other biological features (gene location, GC content, protein info) are ignored, which may be predictive of essentiality.

8. **Sparse Signal**
   - Many 4-mer combinations may never appear, making feature vectors sparse.
   - Sparse linear models may struggle to generalize with limited data for certain patterns.

9. **Mapping**
   - I did not check whether `W`, which is mapped to `10`, is treated as the single value 10 or as the two digits 1 and 0; the latter would essentially derail the classification.

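
The mapping concern in item 9 can be checked directly. The sketch below assumes sequences are stored as lists of integers and indexed in base 11, as described under Feature Extraction; under that assumption the code `10` stays a single symbol, and the ambiguity would only arise if the codes were joined into a text string. `kmer_index` and `decode_index` are hypothetical helpers for this check.

```python
NUM_BASES = 11


def kmer_index(kmer):
    """Read the integer codes as the digits of a base-11 number."""
    index = 0
    for code in kmer:
        index = index * NUM_BASES + code
    return index


def decode_index(index, k=4):
    """Invert kmer_index by peeling off base-11 digits, last digit first."""
    codes = []
    for _ in range(k):
        index, code = divmod(index, NUM_BASES)
        codes.append(code)
    return codes[::-1]


wata = [10, 0, 1, 0]                      # "WATA" under W=10, A=0, T=1
index = kmer_index(wata)                  # 10*1331 + 0*121 + 1*11 + 0 = 13321
round_trip = decode_index(index)

# Joining codes as text is where "10" could be confused with "1" then "0".
joined_w_a = "".join(str(c) for c in [10, 0])       # W, A
joined_t_a_a = "".join(str(c) for c in [1, 0, 0])   # T, A, A

print(index, round_trip)
print(joined_w_a == joined_t_a_a)
```

Because the round trip recovers `[10, 0, 1, 0]`, integer lists with base-11 indexing are unambiguous; only a string-based representation of the codes would collide.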
## Model Evaluation

The baseline Logistic Regression classifier was evaluated on the validation and test splits using **accuracy** and **F1 score**:

| Split      | Accuracy | F1 Score |
|------------|----------|----------|
| Validation | 0.45     | 0.25     |
| Test       | 0.90     | 0.80     |

> ⚠️ Note:
> - Validation F1 is low, likely due to class imbalance and the simple linear model.
> - The high test metrics may be artificially inflated if some sequences are very similar across splits.
> - This baseline serves as a starting point for further improvements.

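
One way to probe the class-imbalance explanation is to reweight the classes. The sketch below computes the weights that scikit-learn documents for `class_weight="balanced"`, namely `n_samples / (n_classes * count_per_class)`, on invented labels. Passing `class_weight="balanced"` to `LogisticRegression` is an option that the baseline script does not use; `balanced_class_weights` is a hypothetical helper for this illustration.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * number of samples in that class)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in counts.items()}


# Invented labels: 9 non-essential genes and 1 essential gene.
y = [0] * 9 + [1]
weights = balanced_class_weights(y)

print(weights)
```

Here the rare class receives a weight of 5.0 and the common class about 0.56, so each misclassified essential gene costs roughly nine times as much during training.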
## Credits

- **Dataset:** [BacBench Essential Genes DNA Dataset](https://huggingface.co/datasets/macwiatrak/bacbench-essential-genes-dna) by Mac Wiatrak et al., hosted on Hugging Face.
- **Libraries & Tools:**
  - [Hugging Face `datasets`](https://huggingface.co/docs/datasets) for data loading and preprocessing
  - [NumPy](https://numpy.org/) for numerical operations
  - [SciPy](https://www.scipy.org/) for scientific computing
  - [scikit-learn](https://scikit-learn.org/) for machine learning models and evaluation metrics
- **Inspired by:** Standard bioinformatics workflows for DNA k-mer feature extraction and baseline classification.
- **Workflow & Model Implementation:** Done by Sharat Doddihal

### Note
This was my first attempt at creating an ML model by myself without relying much on AI. AI was used here, but only to help with debugging.
Overall I am happy with how this turned out, as it was a great learning experience. There are still many fundamental errors that hurt the accuracy.
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
import numpy as np
from datasets import load_dataset
from scipy.sparse import csr_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

DATASET = "macwiatrak/bacbench-essential-genes-dna"
UNUSED_COLUMNS = ['genome_name', 'start', 'end', 'protein_id',
                  'strand', 'product', '__index_level_0__']

# Load one split and drop the metadata columns that are not used for modeling
def load_split(split):
    return load_dataset(DATASET, split=split).remove_columns(UNUSED_COLUMNS)

ds = load_split("validation")
ds1 = load_split("train")
ds2 = load_split("test")

# Convert Yes/No to 1/0
def to_binary(label):
    return 1 if label == "Yes" else 0

def convert(example):
    label = example["essential"]
    if isinstance(label, str):
        # Single "Yes"/"No" string, as documented in the README. Compare the
        # whole string: iterating over it would compare single characters.
        example["essential"] = to_binary(label)
    else:
        # List of "Yes"/"No" strings
        example["essential"] = [to_binary(x) for x in label]
    return example

ds = ds.map(convert)
ds1 = ds1.map(convert)
ds2 = ds2.map(convert)

# DNA base mapping
dna_map = {
    "A": 0, "T": 1, "C": 2, "G": 3,
    "N": 4, "K": 5, "R": 6, "S": 7,
    "Y": 8, "M": 9, "W": 10
}

# Sequences are encoded inside prepare_dataset() with Dataset.map, because
# assigning to ds[i]["dna_seq"] only changes a copy of the row.

# 4-mer encoding utilities
NUM_BASES = len(dna_map)
NUM_4MERS = NUM_BASES ** 4
STEP = 4

def encode_sequence(seq, mapping=dna_map):
    # Characters that are missing from the mapping are skipped
    return [mapping[base] for base in seq if base in mapping]

def sequence_to_4mer_counts(seq, step=STEP):
    counts = np.zeros(NUM_4MERS, dtype=int)
    # The range stops early enough that every window holds a full 4-mer
    for i in range(0, len(seq) - (step - 1), step):
        kmer = seq[i:i + step]
        # Base-11 positional encoding of the four integer codes
        kmer_int = (
            kmer[0] * NUM_BASES ** 3 +
            kmer[1] * NUM_BASES ** 2 +
            kmer[2] * NUM_BASES +
            kmer[3]
        )
        counts[kmer_int] += 1
    return counts

def first_label(value):
    # Scalar labels are used as-is; for list-valued labels the first entry is used
    return value if isinstance(value, int) else value[0]

# Prepare dataset for ML
def prepare_dataset(ds_split):
    def _map_dna_sequence_to_integers(batch):
        return {"dna_seq": [encode_sequence(seq_str) for seq_str in batch["dna_seq"]]}
    ds_processed = ds_split.map(_map_dna_sequence_to_integers, batched=True)
    X_dense = np.array([sequence_to_4mer_counts(item["dna_seq"]) for item in ds_processed])
    y = np.array([first_label(item["essential"]) for item in ds_processed])
    return X_dense, y

X_train_dense, y_train = prepare_dataset(ds1)
X_val_dense, y_val = prepare_dataset(ds)
X_test_dense, y_test = prepare_dataset(ds2)

# sparse conversion
X_train = csr_matrix(X_train_dense)
X_val = csr_matrix(X_val_dense)
X_test = csr_matrix(X_test_dense)

# Train classifier
clf = LogisticRegression(max_iter=2000, solver='saga', n_jobs=-1)
clf.fit(X_train, y_train)

# Evaluation
y_pred_val = clf.predict(X_val)
print("Validation Accuracy:", accuracy_score(y_val, y_pred_val))
print("Validation F1 Score:", f1_score(y_val, y_pred_val))
y_pred_test = clf.predict(X_test)
print("Test Accuracy:", accuracy_score(y_test, y_pred_test))
print("Test F1 Score:", f1_score(y_test, y_pred_test))
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
numpy==1.25.0
pandas==2.0.1
scipy==1.11.0
scikit-learn==1.3.0
datasets==2.16.0
