+ The TransitIQ is an advanced machine learning tool developed for exoplanet classification research.
+ It leverages state-of-the-art ensemble algorithms to assist astronomers in identifying potential exoplanets.
+
+
+
+
+
Objective
+
+ The primary goal is to automate the classification of transit signals. By analyzing light curves and orbital parameters,
+ the model categorizes candidates into three distinct classes:
+
+
+
CONFIRMED (2): Verified exoplanets.
+
CANDIDATE (1): Potential planets requiring further observation.
+
FALSE POSITIVE (0): Signals caused by other astrophysical phenomena.
+
+
+
+
+
Technical Architecture
+
+ Built on Scikit-learn, the system employs a Stacking Ensemble Classifier.
+ This combines the strengths of Random Forest and XGBoost,
+ orchestrated by a Logistic Regression meta-classifier.
+
+ The backend is powered by FastAPI, serving a robust API that processes user inputs
+ and returns real-time predictions with probability confidence scores.
+
+
+
+
+
Input Parameters
+
The model requires 13 specific orbital and transit features derived from NASA's Kepler mission data:
+
+
+
+
+
+
Feature
+
Unit
+
Description
+
+
+
+
+
Orbital Period
+
Days
+
Time taken to complete one full orbit.
+
+
+
Transit Epoch
+
BJD
+
Time of the center of the first transit.
+
+
+
Transit Depth
+
ppm
+
Fraction of stellar flux lost during transit.
+
+
+
Planet Radius
+
Earth Radii
+
Estimated radius of the planet.
+
+
+
Semi-Major Axis
+
AU
+
Average distance from the host star.
+
+
+
Inclination
+
Degrees
+
Angle of the orbital plane.
+
+
+
Equilibrium Temp
+
Kelvin
+
Theoretical surface temperature.
+
+
+
Insolation Flux
+
Earth Flux
+
Incident solar radiation.
+
+
+
Impact Parameter
+
-
+
Sky-projected distance at conjunction.
+
+
+
Radius Ratio
+
-
+
Ratio of planet radius to star radius.
+
+
+
Stellar Density
+
g/cm³
+
Density of the host star.
+
+
+
Planet-Star Dist
+
Stellar Radii
+
Distance scaled by star size.
+
+
+
Num Transits
+
Count
+
Total number of transit events observed.
+
+
+
+
+
+
+
+
Disclaimer
+
+ While this model achieves high accuracy on the validation set, no ML model is infallible.
+ Predictions should be used as a preliminary screening tool rather than absolute confirmation.
+ The model is optimized for Kepler-like data distributions.
+
+ Harness the power of Machine Learning to classify exoplanets.
+ Input transit data and let our advanced ensemble model determine if it's a
+ Confirmed Planet, Candidate, or False Positive.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Input Parameters
+
Enter the transit details below to generate a prediction.
+
+
+
+
+
+
+
+
+
+
+
+
+
Batch Prediction
+
Upload a CSV file to run predictions on multiple samples at once.
+
+
+
+
CSV File Requirements
+
+
+ Important: Your CSV file must have the following columns in the exact order:
+
+
+
koi_period - Orbital Period (days)
+
koi_time0bk - Transit Epoch (BJD)
+
koi_depth - Transit Depth (ppm)
+
koi_prad - Planet Radius (Earth radii)
+
koi_sma - Semi-Major Axis (AU)
+
koi_incl - Inclination (deg)
+
koi_teq - Equilibrium Temp (K)
+
koi_insol - Insolation Flux (Earth flux)
+
koi_impact - Impact Parameter
+
koi_ror - Planet/Star Radius Ratio
+
koi_srho - Stellar Density (g/cm³)
+
koi_dor - Planet-Star Distance (R★)
+
koi_num_transits - Number of Transits
+
+
+
+ If the CSV file does not have the correct columns in the exact order, predictions will fail or produce incorrect results.
+
- The TransitIQ is an advanced machine learning tool developed for exoplanet classification research.
- It leverages state-of-the-art ensemble algorithms to assist astronomers in identifying potential exoplanets.
-
-
-
-
-
Objective
-
- The primary goal is to automate the classification of transit signals. By analyzing light curves and orbital parameters,
- the model categorizes candidates into three distinct classes:
-
-
-
CONFIRMED (2): Verified exoplanets.
-
CANDIDATE (1): Potential planets requiring further observation.
-
FALSE POSITIVE (0): Signals caused by other astrophysical phenomena.
-
-
-
-
-
Technical Architecture
-
- Built on Scikit-learn, the system employs a Stacking Ensemble Classifier.
- This combines the strengths of Random Forest and XGBoost,
- orchestrated by a Logistic Regression meta-classifier.
-
- The backend is powered by Flask, serving a robust API that processes user inputs
- and returns real-time predictions with probability confidence scores.
-
-
-
-
-
Input Parameters
-
The model requires 13 specific orbital and transit features derived from NASA's Kepler mission data:
-
-
-
-
-
-
Feature
-
Unit
-
Description
-
-
-
-
-
Orbital Period
-
Days
-
Time taken to complete one full orbit.
-
-
-
Transit Epoch
-
BJD
-
Time of the center of the first transit.
-
-
-
Transit Depth
-
ppm
-
Fraction of stellar flux lost during transit.
-
-
-
Planet Radius
-
Earth Radii
-
Estimated radius of the planet.
-
-
-
Semi-Major Axis
-
AU
-
Average distance from the host star.
-
-
-
Inclination
-
Degrees
-
Angle of the orbital plane.
-
-
-
Equilibrium Temp
-
Kelvin
-
Theoretical surface temperature.
-
-
-
Insolation Flux
-
Earth Flux
-
Incident solar radiation.
-
-
-
Impact Parameter
-
-
-
Sky-projected distance at conjunction.
-
-
-
Radius Ratio
-
-
-
Ratio of planet radius to star radius.
-
-
-
Stellar Density
-
g/cm³
-
Density of the host star.
-
-
-
Planet-Star Dist
-
Stellar Radii
-
Distance scaled by star size.
-
-
-
Num Transits
-
Count
-
Total number of transit events observed.
-
-
-
-
-
-
-
-
Disclaimer
-
- While this model achieves high accuracy on the validation set, no ML model is infallible.
- Predictions should be used as a preliminary screening tool rather than absolute confirmation.
- The model is optimized for Kepler-like data distributions.
-
- Harness the power of Machine Learning to classify exoplanets.
- Input transit data and let our advanced ensemble model determine if it's a
- Confirmed Planet, Candidate, or False Positive.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Input Parameters
-
Enter the transit details below to generate a prediction.
-
-
-
-
-
-
-
-
-
-
-
-
-
Analysis Complete
-
CONFIRMED
-
-
-
-
-
-
-
-
-
-{% endblock %}
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..d4839a6
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# Tests package
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..f810f68
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,73 @@
+"""
+conftest.py – session-scoped fixtures for the TransitIQ test suite.
+
+Strategy
+--------
+The FastAPI app loads ML model artifacts (pipe.pkl, column_names.pkl) during
+its lifespan startup via ``initialize_artifacts()``. In CI those files are not
+present, and downloading them from Hugging Face would be slow and fragile.
+
+We therefore:
+1. Stub out the ``models`` package so the bare import doesn't fail.
+2. Patch ``initialize_artifacts`` to return lightweight MagicMock objects.
+3. Expose a ``TestClient`` fixture whose lifespan uses those mocks.
+"""
+
+import sys
+import io
+import pytest
+import numpy as np
+from unittest.mock import MagicMock, patch
+from fastapi.testclient import TestClient
+
+# ── 1. Stub heavy / optional dependencies before the app is imported ─────────
+# Prevents import errors in CI where model artifacts or HF credentials may
+# not be available.
+_models_stub = MagicMock()
+sys.modules.setdefault("models", _models_stub)
+sys.modules.setdefault("models.download_from_hf", _models_stub)
+
+# ── 2. Column order must exactly match app/schema/validate.py ────────────────
+COLUMN_NAMES = np.array([
+ "koi_period",
+ "koi_time0bk",
+ "koi_depth",
+ "koi_prad",
+ "koi_sma",
+ "koi_incl",
+ "koi_teq",
+ "koi_insol",
+ "koi_impact",
+ "koi_ror",
+ "koi_srho",
+ "koi_dor",
+ "koi_num_transits",
+])
+
+
+@pytest.fixture(scope="session")
+def mock_pipe():
+ """A minimal sklearn-compatible pipeline mock."""
+ pipe = MagicMock()
+ # Return a label and probability vector scaled by the number of rows so
+ # both single-row (/predict) and multi-row (/predict/batch) calls work.
+ pipe.predict.side_effect = lambda x: np.array([2] * len(x))
+ pipe.predict_proba.side_effect = (
+ lambda x: np.array([[0.05, 0.15, 0.80]] * len(x))
+ )
+ return pipe
+
+
+@pytest.fixture(scope="session")
+def client(mock_pipe):
+ """
+ Session-scoped TestClient whose lifespan uses the mock pipeline.
+ ``scope="session"`` means the app boots once for the entire test run,
+ which mirrors real usage and keeps the suite fast.
+ """
+ with patch("app.app.initialize_artifacts", return_value=(mock_pipe, COLUMN_NAMES)):
+ # Import deferred so the sys.modules stubs are already in place.
+ from app.app import app # noqa: PLC0415
+
+ with TestClient(app) as c:
+ yield c
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 0000000..16b5a44
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,283 @@
+"""
+test_api.py – integration and schema tests for the TransitIQ FastAPI app.
+
+Run from the project root:
+ pytest tests/ -v
+
+The TestClient and all mocking are set up in conftest.py.
+"""
+
+import io
+import pytest
+import pandas as pd
+from pydantic import ValidationError
+
+# ── Shared test data ──────────────────────────────────────────────────────────
+
+#: Column order matches app/schema/validate.py and conftest.COLUMN_NAMES.
+_COLUMNS = [
+ "koi_period", "koi_time0bk", "koi_depth", "koi_prad", "koi_sma",
+ "koi_incl", "koi_teq", "koi_insol", "koi_impact", "koi_ror",
+ "koi_srho", "koi_dor", "koi_num_transits",
+]
+
+#: A valid, boundary-safe payload for /predict.
+VALID_PAYLOAD: dict = {
+ "koi_period": 10.0,
+ "koi_time0bk": 2454834.0,
+ "koi_depth": 500.0,
+ "koi_prad": 1.5,
+ "koi_sma": 0.1,
+ "koi_incl": 85.0,
+ "koi_teq": 400.0,
+ "koi_insol": 2.0,
+ "koi_impact": 0.3,
+ "koi_ror": 0.05,
+ "koi_srho": 1.2,
+ "koi_dor": 10.0,
+ "koi_num_transits": 5,
+}
+
+_VALID_LABELS = {"FALSE POSITIVE", "CANDIDATE", "CONFIRMED"}
+
+
+def _make_csv(rows: list | None = None) -> bytes:
+ """Build a well-formed CSV payload from a list of row dicts."""
+ if rows is None:
+ rows = [VALID_PAYLOAD]
+ return pd.DataFrame(rows, columns=_COLUMNS).to_csv(index=False).encode()
+
+
+# ── Health check ──────────────────────────────────────────────────────────────
+
+class TestHealthRoute:
+ def test_status_code(self, client):
+ assert client.get("/health").status_code == 200
+
+ def test_body_title(self, client):
+ assert client.get("/health").json()["title"] == "TransitIQ"
+
+ def test_body_status_message(self, client):
+ assert client.get("/health").json()["status"] == "All systems operational"
+
+
+# ── Static / HTML routes ──────────────────────────────────────────────────────
+
+class TestStaticRoutes:
+ def test_home_returns_200(self, client):
+ assert client.get("/").status_code == 200
+
+ def test_home_content_type(self, client):
+ assert "text/html" in client.get("/").headers["content-type"]
+
+ def test_about_returns_200(self, client):
+ assert client.get("/about").status_code == 200
+
+ def test_about_content_type(self, client):
+ assert "text/html" in client.get("/about").headers["content-type"]
+
+
+# ── POST /predict ─────────────────────────────────────────────────────────────
+
+class TestPredictEndpoint:
+ # --- Happy path -----------------------------------------------------------
+
+ def test_valid_payload_returns_201(self, client):
+ resp = client.post("/predict", json=VALID_PAYLOAD)
+ assert resp.status_code == 201
+
+ def test_response_has_required_keys(self, client):
+ data = client.post("/predict", json=VALID_PAYLOAD).json()
+ assert {"status", "prediction", "probabilities"} <= data.keys()
+
+ def test_status_is_success(self, client):
+ data = client.post("/predict", json=VALID_PAYLOAD).json()
+ assert data["status"] == "success"
+
+ def test_prediction_is_valid_label(self, client):
+ data = client.post("/predict", json=VALID_PAYLOAD).json()
+ assert data["prediction"] in _VALID_LABELS
+
+ def test_probabilities_keys_match_labels(self, client):
+ data = client.post("/predict", json=VALID_PAYLOAD).json()
+ assert set(data["probabilities"].keys()) == _VALID_LABELS
+
+ def test_probabilities_sum_to_one(self, client):
+ data = client.post("/predict", json=VALID_PAYLOAD).json()
+ total = sum(data["probabilities"].values())
+ assert abs(total - 1.0) < 0.01
+
+ # --- Validation errors ----------------------------------------------------
+
+ def test_missing_field_returns_422(self, client):
+ payload = {k: v for k, v in VALID_PAYLOAD.items() if k != "koi_period"}
+ assert client.post("/predict", json=payload).status_code == 422
+
+ def test_koi_period_zero_returns_422(self, client):
+ """koi_period must be > 0."""
+ assert client.post(
+ "/predict", json={**VALID_PAYLOAD, "koi_period": 0}
+ ).status_code == 422
+
+ def test_koi_period_negative_returns_422(self, client):
+ assert client.post(
+ "/predict", json={**VALID_PAYLOAD, "koi_period": -5.0}
+ ).status_code == 422
+
+ def test_koi_incl_above_90_returns_422(self, client):
+ """koi_incl must be ≤ 90."""
+ assert client.post(
+ "/predict", json={**VALID_PAYLOAD, "koi_incl": 91.0}
+ ).status_code == 422
+
+ def test_koi_impact_at_one_returns_422(self, client):
+ """koi_impact must be strictly < 1."""
+ assert client.post(
+ "/predict", json={**VALID_PAYLOAD, "koi_impact": 1.0}
+ ).status_code == 422
+
+ def test_koi_impact_above_one_returns_422(self, client):
+ assert client.post(
+ "/predict", json={**VALID_PAYLOAD, "koi_impact": 1.5}
+ ).status_code == 422
+
+ def test_koi_ror_at_one_returns_422(self, client):
+ """koi_ror must be strictly < 1."""
+ assert client.post(
+ "/predict", json={**VALID_PAYLOAD, "koi_ror": 1.0}
+ ).status_code == 422
+
+ def test_koi_num_transits_zero_returns_422(self, client):
+ """koi_num_transits must be ≥ 1."""
+ assert client.post(
+ "/predict", json={**VALID_PAYLOAD, "koi_num_transits": 0}
+ ).status_code == 422
+
+ def test_string_value_for_float_field_returns_422(self, client):
+ assert client.post(
+ "/predict", json={**VALID_PAYLOAD, "koi_period": "not-a-number"}
+ ).status_code == 422
+
+ def test_empty_body_returns_422(self, client):
+ assert client.post("/predict", json={}).status_code == 422
+
+
+# ── POST /predict/batch ───────────────────────────────────────────────────────
+
+class TestBatchPredictEndpoint:
+ # --- Happy path -----------------------------------------------------------
+
+ def test_valid_csv_returns_201(self, client):
+ resp = client.post(
+ "/predict/batch",
+ files={"file": ("data.csv", io.BytesIO(_make_csv()), "text/csv")},
+ )
+ assert resp.status_code == 201
+
+ def test_response_has_required_keys(self, client):
+ data = client.post(
+ "/predict/batch",
+ files={"file": ("data.csv", io.BytesIO(_make_csv()), "text/csv")},
+ ).json()
+ assert "status" in data
+ assert "predicted_labels" in data
+ assert "predction_probability" in data # kept as-is (typo in app.py)
+
+ def test_single_row_prediction_count(self, client):
+ data = client.post(
+ "/predict/batch",
+ files={"file": ("data.csv", io.BytesIO(_make_csv()), "text/csv")},
+ ).json()
+ assert len(data["predicted_labels"]) == 1
+
+ def test_multi_row_prediction_count(self, client):
+ csv_bytes = _make_csv([VALID_PAYLOAD] * 3)
+ data = client.post(
+ "/predict/batch",
+ files={"file": ("data.csv", io.BytesIO(csv_bytes), "text/csv")},
+ ).json()
+ assert len(data["predicted_labels"]) == 3
+
+ def test_labels_are_valid(self, client):
+ data = client.post(
+ "/predict/batch",
+ files={"file": ("data.csv", io.BytesIO(_make_csv()), "text/csv")},
+ ).json()
+ for label in data["predicted_labels"]:
+ assert label in _VALID_LABELS
+
+ # --- Validation errors ----------------------------------------------------
+
+ def test_non_csv_extension_returns_422(self, client):
+ resp = client.post(
+ "/predict/batch",
+ files={"file": ("data.txt", io.BytesIO(b"some text"), "text/plain")},
+ )
+ assert resp.status_code == 422
+
+ def test_json_extension_returns_422(self, client):
+ resp = client.post(
+ "/predict/batch",
+ files={"file": ("data.json", io.BytesIO(b'{"a":1}'), "application/json")},
+ )
+ assert resp.status_code == 422
+
+ def test_wrong_column_names_returns_422(self, client):
+ bad_csv = b"col_a,col_b,col_c\n1.0,2.0,3.0\n"
+ resp = client.post(
+ "/predict/batch",
+ files={"file": ("data.csv", io.BytesIO(bad_csv), "text/csv")},
+ )
+ assert resp.status_code == 422
+
+ def test_extra_columns_returns_422(self, client):
+ """An extra column must be rejected even if valid columns are present."""
+ extra = pd.DataFrame(
+ [{**VALID_PAYLOAD, "extra_col": 99.0}],
+ columns=[*_COLUMNS, "extra_col"]
+ ).to_csv(index=False).encode()
+ resp = client.post(
+ "/predict/batch",
+ files={"file": ("data.csv", io.BytesIO(extra), "text/csv")},
+ )
+ assert resp.status_code == 422
+
+ def test_non_numeric_values_returns_422(self, client):
+ """String cell values must fail numeric validation."""
+ str_rows = [{k: "abc" for k in _COLUMNS}]
+ bad_csv = pd.DataFrame(str_rows, columns=_COLUMNS).to_csv(index=False).encode()
+ resp = client.post(
+ "/predict/batch",
+ files={"file": ("data.csv", io.BytesIO(bad_csv), "text/csv")},
+ )
+ assert resp.status_code == 422
+
+
+# ── Pydantic schema unit tests ────────────────────────────────────────────────
+
+class TestUserInputSchema:
+ """Direct unit tests for the Pydantic model – no HTTP overhead."""
+
+ def test_valid_input_creates_model(self):
+ from app.schema.validate import UserInput
+ obj = UserInput(**VALID_PAYLOAD)
+ assert obj.koi_num_transits == 5
+
+ @pytest.mark.parametrize("field,bad_value", [
+ ("koi_period", 0), # must be > 0
+ ("koi_period", -1.0),
+ ("koi_time0bk", 100.0), # must be > 2_450_000
+ ("koi_depth", 0), # must be > 0
+ ("koi_incl", 0), # must be > 0
+ ("koi_incl", 91.0), # must be ≤ 90
+ ("koi_impact", -0.1), # must be ≥ 0
+ ("koi_impact", 1.0), # must be < 1
+ ("koi_ror", 0), # must be > 0
+ ("koi_ror", 1.0), # must be < 1
+ ("koi_dor", 0.5), # must be > 1
+ ("koi_num_transits", 0), # must be ≥ 1
+ ])
+ def test_invalid_field_raises_validation_error(self, field, bad_value):
+ from app.schema.validate import UserInput
+ with pytest.raises(ValidationError):
+ UserInput(**{**VALID_PAYLOAD, field: bad_value})