This is a Rust-based research project for semi-supervised plant disease classification with incremental learning, targeting edge devices.
Source/
├── plantvillage_ssl/ # Main SSL library and CLI (Burn 0.20)
│ ├── src/ # Core Rust library
│ └── gui/ # Tauri desktop app (Svelte 5 + TailwindCSS)
├── incremental_learning/ # Incremental learning workspace (Burn 0.14)
│ ├── crates/ # Library crates (plant-core, plant-dataset, etc.)
│ └── tools/ # CLI tools (train, evaluate, experiment-runner)
└── pytorch_reference/ # Python/PyTorch baseline for comparison
cd plantvillage_ssl
cargo build --release # Build with CUDA (default)
cargo build --release --features cpu # Build CPU-only
cargo test # Run all tests
cargo test test_name # Run single test
cargo test -- --nocapture # Run tests with output
cargo clippy # Lint
cargo fmt # Formatcd incremental_learning
cargo build --release # Build all workspace members
cargo test -p plant-core # Test single crate
cargo test -p plant-training -- test_trainer_update_epoch # Single test
cargo run -p train -- --help # Run train tool
cargo run -p evaluate -- --help # Run evaluate tool
cargo run -p experiment-runner -- --help # Run experiment runnercd plantvillage_ssl/gui
bun install # Install dependencies
bun run dev # Development server
bun run tauri:dev # Run Tauri app in dev mode
bun run tauri:build # Build production app
bun run check # Svelte type checkingImports: Group and order imports as follows:
- Standard library (
std::) - External crates (alphabetical)
- Internal crates/modules (
crate::,super::) - Re-exports at module level
use std::path::PathBuf;
use burn::{config::Config, module::Module};
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::utils::error::{PlantVillageError, Result};Module Documentation: Every module file starts with //! doc comments:
//! Training infrastructure for plant disease classification models.
//!
//! This module provides:
//! - Training loop with epoch management
//! - Loss computation and backpropagation
Naming Conventions:
- Types/Structs:
PascalCase(e.g.,PlantClassifier,TrainingConfig) - Functions/Methods:
snake_case(e.g.,forward_softmax,load_state) - Constants:
SCREAMING_SNAKE_CASE(e.g.,NUM_CLASSES,IMAGE_SIZE) - Generic type parameters: Single uppercase letter with trait bound (e.g.,
B: Backend)
Error Handling: Use thiserror for custom error types:
#[derive(Error, Debug)]
pub enum PlantVillageError {
#[error("Dataset error: {0}")]
Dataset(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
pub type Result<T> = std::result::Result<T, PlantVillageError>;Structs with Derive Macros: Order derives consistently:
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState { ... }
#[derive(Config, Debug)] // Burn configs
pub struct PlantClassifierConfig { ... }
#[derive(Module, Debug)] // Burn modules
pub struct PlantClassifier<B: Backend> { ... }Default Implementations: Always implement Default for config structs:
impl Default for TrainerConfig {
fn default() -> Self {
Self {
learning_rate: 0.001,
num_epochs: 100,
batch_size: 32,
}
}
}Tests: Place tests in a #[cfg(test)] module at the end of each file:
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_feature_name() {
// Test implementation
}
}Component Structure: Use Svelte 5 runes syntax:
<script lang="ts">
interface Props {
title?: string;
class?: string;
children?: import('svelte').Snippet;
}
let { title, class: className = '', children }: Props = $props();
</script>
<div class="bg-background-light rounded-xl p-6 {className}">
{#if title}
<h3 class="text-lg font-semibold text-white mb-4">{title}</h3>
{/if}
{@render children?.()}
</div>Styling: Use TailwindCSS classes inline. Custom colors defined in tailwind.config.js.
Tauri Commands: Backend commands in src-tauri/src/commands/. Use #[tauri::command] macro.
| Crate | Version | Purpose |
|---|---|---|
| burn | 0.14/0.20 | ML framework (different versions in workspaces) |
| burn-cuda | matching | CUDA backend for GPU |
| burn-ndarray | matching | CPU backend fallback |
| image | 0.25 | Image loading/processing |
| clap | 4.x | CLI argument parsing |
| serde | 1.0 | Serialization |
| thiserror | 1.0 | Error types |
| tracing | 0.1 | Logging |
| anyhow | 1.0 | Error handling in binaries |
Burn Backend Generics: Models are generic over backend:
pub struct PlantClassifier<B: Backend> {
conv1: Conv2d<B>,
// ...
}
impl<B: Backend> PlantClassifier<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> { ... }
}CLI with Clap: Use derive macros for CLI parsing:
#[derive(Parser, Debug)]
#[command(name = "plantvillage_ssl")]
struct Cli {
#[arg(short, long, default_value = "false")]
verbose: bool,
#[command(subcommand)]
command: Commands,
}Logging: Use tracing macros (info!, warn!, debug!):
use tracing::{info, warn};
info!("Training epoch {}: loss={:.4}", epoch, loss);
warn!("No improvement. Patience: {}/{}", counter, max_patience);Run tests with output visible:
cargo test -- --nocaptureRun a specific test:
cargo test test_plant_classifier_output_shape -- --nocaptureTest a specific crate in workspace:
cargo test -p plant-incremental- Two Burn versions:
plantvillage_ssluses Burn 0.20-pre,incremental_learninguses 0.14 - CUDA default: This project targets GPU. Use
--features cpufor CPU-only builds - Image size: Default is 128x128 or 256x256 depending on context
- 38 classes: PlantVillage dataset has 38 disease classes
- Workspace structure:
incremental_learning/is a Cargo workspace with multiple crates - GUI uses Svelte 5: Modern runes syntax (
$props(),$state(), etc.) - New Plant Diseases Dataset: Uses pre-balanced dataset from Kaggle (~87K images)
The project uses the New Plant Diseases Dataset from Kaggle (vipoooool/new-plant-diseases-dataset):
- ~87,000 images (pre-balanced and augmented)
- Same 38 classes as original PlantVillage
- Pre-split into
train/(~70K) andvalid/(~17K) folders - No balancing needed - dataset is already balanced (~2K images per class)
cd /home/warre/Documents/howest/Semester_5/Research_Project/Source
./download_plantvillage.shOr manually via Kaggle CLI:
kaggle datasets download -d vipoooool/new-plant-diseases-dataset
unzip new-plant-diseases-dataset.zip -d data/plantvillage/data/plantvillage/
├── train/
│ ├── Apple___Apple_scab/
│ │ ├── image1.jpg
│ │ └── ...
│ └── ... (38 classes)
└── valid/
├── Apple___Apple_scab/
└── ... (38 classes)
The loader automatically merges train/ and valid/ folders for SSL training.
When the user asks to train the model for SSL (semi-supervised learning), use this workflow:
cd plantvillage_ssl
cargo run --release --bin plantvillage_ssl -- train \
--epochs 30 \
--cuda \
--labeled-ratio 0.2Key parameters:
--labeled-ratio 0.2= 20% for CNN training, 60% reserved for SSL stream, 10% validation, 10% test--epochs 30= Sufficient for initial model (can increase if needed)--cuda= Use GPU acceleration
cd plantvillage_ssl
cargo run --release --bin plantvillage_ssl -- simulate \
--model "output/models/plant_classifier_TIMESTAMP" \
--data-dir "data/plantvillage" \
--cuda \
--days 0 \
--images-per-day 100 \
--labeled-ratio 0.2 \
--retrain-threshold 200 \
--confidence-threshold 0.9Key parameters:
--days 0= Unlimited - process ALL available SSL stream data--images-per-day 100= Batch size per "day" for streaming simulation--labeled-ratio 0.2= MUST match training! Ensures correct data split--retrain-threshold 200= Retrain after accumulating 200 pseudo-labels--confidence-threshold 0.9= Only accept predictions with >90% confidence
cp plantvillage_ssl/output/simulation/plant_classifier_ssl_TIMESTAMP.mpk plantvillage_ssl/best_model.mpk| Pool | Fraction | Purpose |
|---|---|---|
| Test | 10% | Final evaluation (never seen during training) |
| Validation | 10% | Hyperparameter tuning, early stopping |
| Labeled (CNN) | 20% | Initial supervised training |
| Stream (SSL) | 60% | Unlabeled data for pseudo-labeling |
- Initial CNN training: ~70-75% validation accuracy (with only 20% labeled data)
- After SSL pipeline: ~78-85%+ validation accuracy
- Pseudo-label precision: >95%
# Full SSL workflow (train + simulate)
cd plantvillage_ssl
cargo run --release --bin plantvillage_ssl -- train --epochs 30 --cuda --labeled-ratio 0.2
cargo run --release --bin plantvillage_ssl -- simulate \
--model "output/models/plant_classifier_LATEST" \
--data-dir "data/plantvillage" \
--cuda --days 0 --labeled-ratio 0.2
# Check available commands
cargo run --release --bin plantvillage_ssl -- --help
cargo run --release --bin plantvillage_ssl -- train --help
cargo run --release --bin plantvillage_ssl -- simulate --help