diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 24acda5..848ae4f 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -24,10 +24,7 @@ jobs: - "Dockerfile" - ".dockerignore" - "requirements.txt" - - "fit.py" - - "app.py" - - "templates/**" - - "static/**" + - "app/**" - "models/**" docker-build: diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 4b6c79c..2f7c9c2 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -24,13 +24,12 @@ jobs: with: filters: | src: + - ".github/workflows/docker.yml" - ".github/workflows/python-app.yml" - - "data/**" + - "requirements.txt" + - "app/**" - "models/**" - - "static/**" - - "templates/**" - - "app.py" - - "fit.py" + - "tests/**" build: needs: changes @@ -46,7 +45,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest + pip install flake8 pytest httpx if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | @@ -54,9 +53,5 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Run Flask app and test with Curl - run: | - nohup python app.py & - sleep 95 - curl -I http://127.0.0.1:5000 - pkill -f "python app.py" + - name: Run tests with pytest + run: pytest diff --git a/.gitignore b/.gitignore index 6a62caa..4d69ff6 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ instance/ shit/ *.db .vscode/ -*.pkl \ No newline at end of file +*.pkl +data/test_sample.csv \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 8e1ec5f..3f2108d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# Dockerfile (exoplanet_classifier) +# Dockerfile (TransitIQ) FROM python:3.11-slim # Avoid interactive prompts @@ -14,10 +14,9 @@ RUN pip install --upgrade pip && pip install -r requirements.txt # copy app source COPY . . -# example env and port -ENV FLASK_APP=app.py -EXPOSE 5000 +# expose port +EXPOSE 8000 -# run the flask app -CMD ["python", "-m", "gunicorn", "--bind", "0.0.0.0:5000", "app:app", "--workers", "2"] +# run the fastapi app +CMD ["uvicorn", "app.app:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md index 18ce32a..f9a9757 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ This version blends the strengths of **ensemble learning** with extensive prepro NASA’s exoplanet survey missions (Kepler, K2, and others) have generated thousands of data points using **the transit method** — tracking dips in starlight caused by orbiting planets. These datasets contain both **confirmed exoplanets** and **false positives**, and the aim of this project is to build an AI classifier capable of making preliminary predictions on new candidates. -The classifier runs inside a **Flask-powered web interface**, allowing anyone — from students to researchers — to enter transit parameters and instantly receive a prediction. 
+The classifier runs inside a **FastAPI-powered web interface**, allowing anyone — from students to researchers — to enter transit parameters and instantly receive a prediction. The goal is to provide a *scientifically meaningful, intuitive, and educational experience* for users interested in exoplanet research. @@ -68,8 +68,8 @@ The goal is to provide a *scientifically meaningful, intuitive, and educational - **Scikit-learn** – Pipeline, scaling, imputation, model stacking, metrics - **XGBoost** – Gradient boosting-based sub-model for ensemble - **Imbalanced-learn (SMOTE)** – Class balancing for improved fairness -- **Flask** – Backend web framework -- **HTML/CSS/JavaScript** – Frontend for the interactive web UI +- **FastAPI** – Backend web framework +- **HTML/CSS/JavaScript (Vanilla)** – Frontend for the interactive web UI - **Jupyter Notebook** – Used as a sandbox (`research.ipynb`) to experiment with different model architectures, hyperparameters, and feature engineering before finalizing `fit.py`. --- @@ -93,14 +93,14 @@ cd "TransitIQ" pip install -r requirements.txt ``` -3. **Run the Flask app** +3. **Run the FastAPI app** ```bash -python app.py +uvicorn app.app:app --host 0.0.0.0 --port 8000 ``` -4. Open your browser and go to `http://127.0.0.1:5000` to access the web interface. +4. Open your browser and go to `http://127.0.0.1:8000` to access the web interface. -5. If you want to close the server, press `Ctrl + C` in the terminal where you have run `app.py` from. +5. If you want to close the server, press `Ctrl + C` in the terminal. --- @@ -114,14 +114,14 @@ The image is built on both ARM64 and AMD64 architectures, so that it can run on 2. Open Terminal and run: ```bash docker pull bytebard101/exoplanet_classifier -docker run --rm -p 5000:5000 bytebard101/exoplanet_classifier:latest +docker run --rm -p 8000:8000 bytebard101/exoplanet_classifier:latest ``` 3. If your machine faces a port conflict, you will need to assign another port. 
Try to run this: ```bash -docker run --rm -p 5001:5000 bytebard101/exoplanet_classifier:latest +docker run --rm -p 8001:8000 bytebard101/exoplanet_classifier:latest ``` > If you followed Step 2 and the command ran successfully, then **DO NOT** follow this step. -4. The app will be live at localhost:5000. Open your browser and navigate to [http://127.0.0.1:5000](http://127.0.0.1:5000/) (or [http://127.0.0.1:5001](http://127.0.0.1:5000/) if you followed Step 3). +4. The app will be live at localhost:8000. Open your browser and navigate to [http://127.0.0.1:8000](http://127.0.0.1:8000/) (or [http://127.0.0.1:8001](http://127.0.0.1:8001/) if you followed Step 3). Check [Docker Documentation](https://docs.docker.com/) to learn more about Docker and it's commands. @@ -146,31 +146,28 @@ Check [Docker Documentation](https://docs.docker.com/) to learn more about Docke TransitIQ/ ├── .github/ # Folder for GitHub actions │ +├── app/ # FastAPI Application +│ ├── schema/ # Pydantic schemas +│ ├── static/ # Static assets (CSS, JS, images) +│ ├── templates/ # HTML templates (served as static) +│ └── app.py # Main FastAPI entry point +│ ├── data/ │ ├── k2_data.csv │ ├── kepler_data.csv │ └── source.txt │ ├── models/ -│ ├── column_names.pkl # Not included in the repo, run fit.py to generate +│ ├── column_names.pkl │ ├── info.txt -│ └── pipe.pkl # Not included in the repo, run fit.py to generate +│ └── pipe.pkl │ ├── screenshots/ │ -├── static/ -│ ├── materials/ -│ └── script.js -│ -├── templates/ -│ ├── about.html -│ └── index.html -│ ├── .gitignore -├── app.py ├── fit.py ├── LICENSE -├── README.md # You're reading it now +├── README.md ├── requirements.txt └── research.ipynb ``` diff --git a/app.py b/app.py deleted file mode 100644 index 2fe595e..0000000 --- a/app.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import joblib -import pandas as pd -import numpy as np -from flask import Flask, render_template, request, jsonify -from models.download_from_hf import download - -# --- 
Configuration --- -MODEL_DIR = "models" -PIPE_PATH = os.path.join(MODEL_DIR, "pipe.pkl") -COLUMNS_PATH = os.path.join(MODEL_DIR, "column_names.pkl") -reverse_mapping = {0: "FALSE POSITIVE", 1: "CANDIDATE", 2: "CONFIRMED"} - -# --- Self-Heal Function --- -def initialize_artifacts(): - """ - Checks if model artifacts exist. If not, runs the training script. - """ - # 1. Ensure the model directory exists - os.makedirs(MODEL_DIR, exist_ok=True) - - # 2. Check for missing files - pipe_exists = os.path.exists(PIPE_PATH) - columns_exists = os.path.exists(COLUMNS_PATH) - - if not pipe_exists or not columns_exists: - print("--- MODEL ARTIFACTS MISSING ---") - if not pipe_exists: - print(f"Missing: {PIPE_PATH}") - if not columns_exists: - print(f"Missing: {COLUMNS_PATH}") - - print("Downloading the saved models from Hugging Face... This may take a moment.") - try: - # Run the `download` function from `models/download_from_hf.py` - download() - print("Download complete. Artifacts generated successfully.") - print("---------------------------------") - except Exception as e: - print(f"\nFATAL: Error during self-heal downloading: {e}") - print("Application cannot start without model artifacts. Exitting......") - exit(1) # Exit if training fails - else: - print("Model artifacts found. Loading...") - -# --- Application Startup --- - -# Run the self-heal check *before* loading models -initialize_artifacts() - -# Load models -try: - pipe = joblib.load(PIPE_PATH) - column_names = joblib.load(COLUMNS_PATH) - print("Models loaded successfully.") -except Exception as e: - print(f"\nFATAL: Error loading model artifacts: {e}") - print("Files might be corrupt. 
Try deleting the 'models' directory and restarting.") - exit(1) # Exit if loading fails - -# Initialize Flask App -app = Flask(__name__) - -@app.route("/") -def home(): - return render_template("index.html") - -@app.route("/about") -def about(): - return render_template("about.html") - -@app.route("/predict", methods=["POST"]) -def predict(): - try: - # Extract features from the JSON request - raw_features = [ - request.json["orbital-period"], - request.json["transit-epoch"], - request.json["transit-depth"], - request.json["planet-radius"], - request.json["semi-major-axis"], - request.json["inclination"], - request.json["equilibrium-temp"], - request.json["insolation-flux"], - request.json["impact-parameter"], - request.json["radius-ratio"], - request.json["stellar-density"], - request.json["star-distance"], - request.json["num-transits"], - ] - - # Create DataFrame with correct column names - df = pd.DataFrame([raw_features], columns=column_names) - - # Get prediction and probabilities - pred = int(pipe.predict(df)[0]) - proba = pipe.predict_proba(df)[0] - - # Format probabilities for the response - proba_dict = { - reverse_mapping[i]: round(p, 3) for i, p in enumerate(proba) - } - - # Send response - return jsonify( - {"prediction": reverse_mapping[pred], "probabilities": proba_dict} - ) - - except KeyError as e: - print(f"Prediction Error: Missing key in request {e}") - return jsonify({"error": f"Missing feature in request: {e}"}), 400 - except Exception as e: - print(f"Prediction Error: {e}") - return jsonify({"error": str(e)}), 400 - - -if __name__ == "__main__": - app.run(debug=True) \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/app.py b/app/app.py new file mode 100644 index 0000000..35548c0 --- /dev/null +++ b/app/app.py @@ -0,0 +1,189 @@ +from fastapi import FastAPI, Depends, UploadFile +from fastapi.requests import Request +from fastapi.responses import JSONResponse, 
FileResponse +from fastapi.exceptions import HTTPException +from fastapi.staticfiles import StaticFiles +from .schema.validate import UserInput +from models.download_from_hf import download +from sklearn.pipeline import Pipeline +from typing import Tuple, List + +import joblib +import pandas as pd +import numpy as np +from contextlib import asynccontextmanager +import os +from pathlib import Path + +# --- Configuration --- +MODEL_DIR = "models" +PIPE_PATH = os.path.join(MODEL_DIR, "pipe.pkl") +COLUMNS_PATH = os.path.join(MODEL_DIR, "column_names.pkl") +INDEX_PATH = Path("app","templates","index.html") +ABOUT_PATH = Path("app","templates","about.html") +reverse_mapping = {0: "FALSE POSITIVE", 1: "CANDIDATE", 2: "CONFIRMED"} + +# --- Self-Heal Function --- +def initialize_artifacts() -> Tuple[Pipeline,np.ndarray]: + """ + Checks if model artifacts exist. If not, runs the training script. + """ + # 1. Ensure the model directory exists + os.makedirs(MODEL_DIR, exist_ok=True) + + # 2. Check for missing files + pipe_exists = os.path.exists(PIPE_PATH) + columns_exists = os.path.exists(COLUMNS_PATH) + + if not pipe_exists or not columns_exists: + print("--- MODEL ARTIFACTS MISSING ---") + if not pipe_exists: + print(f"Missing: {PIPE_PATH}") + if not columns_exists: + print(f"Missing: {COLUMNS_PATH}") + + print("Downloading the saved models from Hugging Face... This may take a moment.") + try: + # Run the `download` function from `models/download_from_hf.py` + download() + print("Download complete. Artifacts generated successfully.") + print("---------------------------------") + except Exception as e: + print(f"\nFATAL: Error during self-heal downloading: {e}") + print("Application cannot start without model artifacts. Exitting......") + exit(1) # Exit if training fails + else: + print("Model artifacts found. Loading...") + pipe = joblib.load(PIPE_PATH) + column_names = joblib.load(COLUMNS_PATH) + print("Model artifacts are loaded. 
Ready for prediction 🚀") + return pipe,column_names + +@asynccontextmanager +async def lifespan(app:FastAPI): + """ + Loads the models at start + """ + pipe,column_names = initialize_artifacts() + + app.state.pipe = pipe + app.state.column_names = column_names + + yield + +app = FastAPI(title="TransitIQ",version="3.0 (ByteBard58_Fork-FastAPI)",lifespan=lifespan) + +# Mount static files +app.mount(name="static",path="/static",app=StaticFiles( + directory=Path("app","static") +)) + +async def get_artifacts(request:Request) -> Tuple[Pipeline,np.ndarray]: + """ + Helper to serve the artifacts in a route + """ + return request.app.state.pipe, request.app.state.column_names + +def validate_csv(target:pd.DataFrame,expected_columns:List) -> pd.DataFrame: + """ + Helper for validating user-uploaded `.csv` files during batch prediction + """ + if target.columns.to_list() != expected_columns: + raise HTTPException( + status_code=422, + detail="The columns of the uploaded .csv file do not match with the expected list of columns or the order of them" + + ) + try: + target.astype(float) + except Exception: + raise HTTPException( + status_code=422, + detail = "Provided values must be numeric (float-compatible)" + ) + + return target + +@app.get("/health") +def health(): + msg = { + "title":"TransitIQ", + "version":"3.0(ByteBard58_Fork-FastAPI)", + "status":"All systems operational" + } + return JSONResponse(content=msg,status_code=200) + +@app.get("/") +def home(): + return FileResponse(INDEX_PATH) + +@app.get("/about") +def about(): + return FileResponse(ABOUT_PATH,status_code=200) + +reverse_mapping = {0: "FALSE POSITIVE", 1: "CANDIDATE", 2: "CONFIRMED"} + +@app.post("/predict") +def predict_with_manual_inputs( + payload:UserInput, + dep:Tuple[Pipeline,np.ndarray] = Depends(get_artifacts) +): + pipe, column_names = dep + column_names:List = column_names.tolist() + payload:dict = payload.model_dump(mode="json") + + sample = [] + for i,(key,val) in enumerate(payload.items()): + 
if column_names[i] == key: + sample.append(val) + else: + raise ValueError(f"Payload key {key} does not match expected column {column_names[i]}") + sample = np.array(sample).reshape(1,-1) + + label = int(pipe.predict(sample)[0]) + proba:List = pipe.predict_proba(sample)[0].tolist() + + label:str = reverse_mapping.get(label) + proba:dict = {cls:round(proba,3) for cls,proba in zip(reverse_mapping.values(),proba)} + msg = { + "status":"success", + "prediction":label, + "probabilities":proba + } + + return JSONResponse(status_code=201,content=msg) + +@app.post("/predict/batch") +async def predict_with_batch_input( + file:UploadFile, + dep:Tuple[Pipeline,np.ndarray] = Depends(get_artifacts) +): + pipe, column_names = dep + column_names:List = column_names.tolist() + + ext = Path(file.filename).suffix + if ext != ".csv": + raise HTTPException( + status_code=422, detail=f"Only .csv file is allowed as an upload, got {ext} instead" + ) + + try: + df = pd.read_csv(file.file) + except Exception as e: + raise HTTPException(status_code=422, detail=f"Failed to parse CSV file tracking: {str(e)}") + df:np.ndarray = validate_csv(df,column_names).to_numpy() + + sample = df + label:List[float] = pipe.predict(sample).tolist() + proba:List[List[float]] = pipe.predict_proba(sample).tolist() + + label:List[str] = [reverse_mapping.get(l) for l in label] + proba = [[round(value,3) for value in prob] for prob in proba] + proba + msg = { + "status":"batch prediction successful", + "predicted_labels":label, + "predction_probability":proba + } + + return JSONResponse(status_code=201,content=msg) \ No newline at end of file diff --git a/app/schema/__init__.py b/app/schema/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/schema/validate.py b/app/schema/validate.py new file mode 100644 index 0000000..bb0fa59 --- /dev/null +++ b/app/schema/validate.py @@ -0,0 +1,59 @@ +from pydantic import BaseModel, Field +from typing import Annotated + +class UserInput(BaseModel): + 
koi_period : Annotated[float, Field( + ..., gt=0, description="Orbital Period (days)", + examples= [0.837, 2.154, 9.88, 54.3, 365.2] + )] + koi_time0bk : Annotated[float, Field( + ..., gt=2_450_000, description="Transit Epoch (BJD)", + examples=[2454833.0, 2455002.5] + )] + koi_depth : Annotated[float, Field( + ..., gt = 0, le = 1_000_000, + description="Transit Depth (ppm)", + examples=[150.2, 892.5, 3400.0] + )] + koi_prad : Annotated[float, Field( + ..., description="Planet Radius (Earth radii)", + gt=0, examples=[0.84, 1.2, 2.5, 6.8, 14.3] + )] + koi_sma : Annotated[float, Field( + ..., examples=[0.021, 0.085, 0.234], + gt = 0, description="Semi-Major Axis (AU)" + )] + koi_incl : Annotated[float, Field( + ..., description="Inclination (deg)", + gt=0, le=90, examples=[74.5,20.1,86.4] + )] + koi_teq : Annotated[float, Field( + ..., description= "Equilibrium Temperature (K)", + gt = 0, examples=[312.5, 542.0, 876.3] + )] + koi_insol : Annotated[float, Field( + ..., examples=[0.32, 1.02, 4.75, 28.6, 310.0], + description= "Insolation Flux (Earth flux)", gt =0 + )] + koi_impact : Annotated[float, Field( + ..., description="Impact Parameter", + examples=[0.02, 0.18, 0.45, 0.72, 0.95], + ge = 0, lt = 1 + )] + koi_ror : Annotated[float, Field( + ..., description="Planet/Star Radius Ratio", + examples=[0.011, 0.028, 0.065, 0.112, 0.198], + gt = 0, lt = 1 + )] + koi_srho : Annotated[float, Field( + ..., description="Stellar Density (g/cm³)", + gt = 0, examples=[0.18, 0.85, 1.41, 3.72, 18.6] + )] + koi_dor : Annotated[float, Field( + ..., examples=[2.3, 8.7, 21.4, 56.8, 134.2], + gt = 1, description="Planet-Star Distance (R★)" + )] + koi_num_transits : Annotated[int, Field( + ..., description="Number of Transits", + ge=1, examples= [1, 3, 7, 15, 42] + )] \ No newline at end of file diff --git a/static/css/style.css b/app/static/css/style.css similarity index 56% rename from static/css/style.css rename to app/static/css/style.css index 8073bc7..cfe57e6 100644 --- 
a/static/css/style.css +++ b/app/static/css/style.css @@ -180,6 +180,12 @@ h1, h2, h3, h4, h5, h6 { max-width: 600px; } +.hero-buttons { + display: flex; + gap: 1rem; + flex-wrap: wrap; +} + .hero-image { flex: 1; display: flex; @@ -258,9 +264,58 @@ h1, h2, h3, h4, h5, h6 { color: var(--text-main); } +.btn-secondary { + background: transparent; + color: var(--text-main); + border: 2px solid var(--accent-primary); + position: relative; + overflow: hidden; + z-index: 1; +} + +.btn-secondary::before { + content: ''; + position: absolute; + top: 0; + left: 0; + width: 0%; + height: 100%; + background: var(--accent-primary); + transition: var(--transition-smooth); + z-index: -1; +} + +.btn-secondary:hover::before { + width: 100%; +} + +.btn-secondary:hover { + color: var(--text-main); +} + +.btn-secondary-small { + background: transparent; + color: var(--accent-primary); + border: 1px solid var(--accent-primary); + padding: 0.5rem 1rem; + border-radius: var(--radius-sm); + font-weight: 500; + font-size: 0.9rem; + text-decoration: none; + transition: var(--transition-smooth); + display: inline-flex; + align-items: center; + gap: 0.5rem; +} + +.btn-secondary-small:hover { + background: var(--accent-primary); + color: var(--bg-dark); +} + /* ========================================= - Form Section - ========================================= */ + Form Section + ========================================= */ .form-section { padding: 6rem 0; display: flex; @@ -432,8 +487,290 @@ h1, h2, h3, h4, h5, h6 { } /* ========================================= - About Page Styles - ========================================= */ + Tab Navigation + ========================================= */ +.tab-nav { + display: flex; + gap: 0.5rem; +} + +.tab-link { + background: transparent; + border: none; + color: var(--text-muted); + font-family: 'Poppins', sans-serif; + font-size: 0.95rem; + font-weight: 500; + padding: 0.5rem 1rem; + cursor: pointer; + position: relative; + transition: 
var(--transition-fast); + display: inline-flex; + align-items: center; + gap: 0.5rem; +} + +.tab-link::after { + content: ''; + position: absolute; + bottom: -2px; + left: 0; + width: 0; + height: 2px; + background: var(--accent-primary); + transition: var(--transition-smooth); +} + +.tab-link:hover { + color: var(--text-main); +} + +.tab-link.active { + color: var(--text-main); +} + +.tab-link.active::after { + width: 100%; +} + +.tab-link[href] { + color: var(--text-muted); + text-decoration: none; + font-family: 'Poppins', sans-serif; + font-size: 0.95rem; + font-weight: 500; + padding: 0.5rem 1rem; + position: relative; + display: inline-flex; + align-items: center; + gap: 0.5rem; +} + +.tab-link[href]:hover { + color: var(--text-main); +} + +/* Tab Content */ +.tab-content { + display: none; +} + +.tab-content.active { + display: block; +} + +/* ========================================= + CSV Requirements + ========================================= */ +.csv-requirements { + background: rgba(0, 188, 255, 0.05); + border: 1px solid rgba(0, 188, 255, 0.2); + border-radius: var(--radius-md); + padding: 1.5rem; + margin-bottom: 2rem; +} + +.csv-requirements h3 { + color: var(--accent-primary); + margin-bottom: 1rem; + font-size: 1.2rem; +} + +.warning-text { + color: #ffab00; + margin-bottom: 1rem; +} + +.error-text { + color: #ff4b2b; + margin-top: 1rem; + font-size: 0.9rem; +} + +.column-list { + list-style: none; + display: grid; + grid-template-columns: repeat(auto-fill, minmax(250px, 1fr)); + gap: 0.5rem; +} + +.column-list li { + color: var(--text-muted); + font-size: 0.9rem; +} + +.column-list code { + color: var(--accent-primary); + background: rgba(0, 188, 255, 0.1); + padding: 0.2rem 0.5rem; + border-radius: 4px; + font-family: 'Courier New', monospace; +} + +/* ========================================= + File Upload + ========================================= */ +.batch-form { + display: flex; + flex-direction: column; + gap: 2rem; +} + 
+.file-upload-wrapper { + position: relative; +} + +.file-upload-label { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + gap: 1rem; + padding: 3rem; + border: 2px dashed var(--border-color); + border-radius: var(--radius-md); + cursor: pointer; + transition: var(--transition-smooth); + background: rgba(255, 255, 255, 0.02); +} + +.file-upload-label:hover { + border-color: var(--accent-primary); + background: rgba(0, 188, 255, 0.05); +} + +.file-upload-label i { + font-size: 3rem; + color: var(--accent-primary); +} + +.file-upload-label span { + color: var(--text-muted); +} + +.file-upload-wrapper input[type="file"] { + position: absolute; + inset: 0; + opacity: 0; + cursor: pointer; +} + +.file-name { + margin-top: 1rem; + color: var(--text-main); + font-size: 0.9rem; +} + +/* ========================================= + Large Modal for Batch + ========================================= */ +.modal-large { + max-width: 800px; + max-height: 85vh; + overflow-y: auto; +} + +.batch-summary { + margin-bottom: 2rem; + text-align: center; +} + +.batch-summary h4 { + color: var(--text-muted); + margin-bottom: 1rem; + font-size: 1rem; + text-transform: uppercase; + letter-spacing: 1px; +} + +/* Pie Chart */ +.pie-chart-container { + display: flex; + justify-content: center; + align-items: center; + gap: 2rem; + flex-wrap: wrap; +} + +.pie-chart { + width: 180px; + height: 180px; + border-radius: 50%; + position: relative; +} + +.pie-legend { + display: flex; + flex-direction: column; + gap: 0.8rem; +} + +.pie-legend-item { + display: flex; + align-items: center; + gap: 0.5rem; + font-size: 0.9rem; +} + +.pie-legend-color { + width: 16px; + height: 16px; + border-radius: 4px; +} + +/* Batch Result Table */ +.batch-table-wrapper { + margin-top: 1.5rem; +} + +.batch-table-wrapper h4 { + color: var(--text-muted); + margin-bottom: 1rem; + font-size: 1rem; + text-transform: uppercase; + letter-spacing: 1px; +} + +.batch-result-table { + 
width: 100%; + border-collapse: collapse; + font-size: 0.9rem; +} + +.batch-result-table th, +.batch-result-table td { + padding: 0.8rem; + text-align: left; + border-bottom: 1px solid var(--border-color); +} + +.batch-result-table th { + background: rgba(0, 188, 255, 0.1); + color: var(--accent-primary); + font-weight: 600; + position: sticky; + top: 0; +} + +.batch-result-table tr:hover td { + background: rgba(255, 255, 255, 0.02); +} + +.batch-result-table .prediction-cell { + font-weight: 600; +} + +.class-confirmed { color: #00ff88; } +.class-candidate { color: #ffab00; } +.class-false { color: #ff4b2b; } + +.high-confidence { color: #00ff88; } +.medium-confidence { color: #ffab00; } +.low-confidence { color: #ff4b2b; } + +/* ========================================= + About Page Styles + ========================================= */ .about-content { background: var(--bg-card); border: 1px solid var(--glass-border); @@ -482,6 +819,72 @@ h1, h2, h3, h4, h5, h6 { background: rgba(255, 255, 255, 0.02); } +/* Notification System */ +#notification-container { + position: fixed; + top: 20px; + right: 20px; + z-index: 3000; + display: flex; + flex-direction: column; + gap: 10px; +} + +.notification { + width: 320px; + padding: 12px 16px; + border-radius: 14px; + background: rgba(30, 30, 30, 0.7); + backdrop-filter: blur(25px) saturate(180%); + -webkit-backdrop-filter: blur(25px) saturate(180%); + border: 1px solid rgba(255, 255, 255, 0.1); + color: var(--text-main); + box-shadow: 0 15px 35px rgba(0, 0, 0, 0.35), 0 5px 15px rgba(0, 0, 0, 0.5); + display: flex; + align-items: center; + gap: 12px; + transform: translateX(calc(100% + 40px)); + transition: transform 0.5s cubic-bezier(0.19, 1, 0.22, 1); + font-weight: 500; + font-size: 0.85rem; +} + +.notification.active { + transform: translateX(0); +} + +.notification.success { + border-left: 4px solid #00ff88; +} + +.notification.error { + border-left: 4px solid #ff4b2b; +} + +.notification.info { + border-left: 4px 
solid var(--accent-primary); +} + +.notification-icon { + font-size: 1.2rem; +} + +.notification.success .notification-icon { color: #00ff88; } +.notification.error .notification-icon { color: #ff4b2b; } +.notification.info .notification-icon { color: var(--accent-primary); } + +/* Header Link Fix */ +.site-title-link { + display: flex; + align-items: center; + gap: 1rem; + transition: var(--transition-fast); +} + +.site-title-link:hover { + opacity: 0.8; +} + /* ========================================= Responsive Design ========================================= */ diff --git a/app/static/js/main.js b/app/static/js/main.js new file mode 100644 index 0000000..0380a12 --- /dev/null +++ b/app/static/js/main.js @@ -0,0 +1,443 @@ +document.addEventListener('DOMContentLoaded', () => { + // Tab Switching + const tabLinks = document.querySelectorAll('.tab-link:not([href])'); + const tabContents = document.querySelectorAll('.tab-content'); + + tabLinks.forEach(link => { + link.addEventListener('click', () => { + const targetTab = link.dataset.tab; + + tabLinks.forEach(l => l.classList.remove('active')); + tabContents.forEach(c => c.classList.remove('active')); + + link.classList.add('active'); + document.getElementById(`${targetTab}-tab`).classList.add('active'); + }); + }); + + // Smooth Scroll to Form (from Welcome tab) + const scrollBtn = document.getElementById('scroll-to-form'); + if (scrollBtn) { + scrollBtn.addEventListener('click', (e) => { + e.preventDefault(); + // Switch to Predict tab first + tabLinks.forEach(l => l.classList.remove('active')); + tabContents.forEach(c => c.classList.remove('active')); + document.querySelector('[data-tab="home"]').classList.add('active'); + document.getElementById('home-tab').classList.add('active'); + // Then scroll to form + setTimeout(() => { + document.getElementById('form-section').scrollIntoView({ + behavior: 'smooth' + }); + }, 50); + }); + } + + // Go to Batch Tab (from Welcome tab) + const batchBtn = 
document.getElementById('go-to-batch'); + if (batchBtn) { + batchBtn.addEventListener('click', (e) => { + e.preventDefault(); + tabLinks.forEach(l => l.classList.remove('active')); + tabContents.forEach(c => c.classList.remove('active')); + document.querySelector('[data-tab="batch"]').classList.add('active'); + document.getElementById('batch-tab').classList.add('active'); + setTimeout(() => { + document.getElementById('batch-tab').scrollIntoView({ + behavior: 'smooth' + }); + }, 50); + }); + } + + // Input Animation & Label Handling + const inputs = document.querySelectorAll('.input-field'); + inputs.forEach(input => { + if (input.value) { + input.classList.add('has-value'); + } + + input.addEventListener('input', () => { + if (input.value.trim() !== '') { + input.classList.add('has-value'); + } else { + input.classList.remove('has-value'); + } + }); + }); + + // Field Labels for Human-Readable Errors + const fieldLabels = { + 'koi_period': 'Orbital Period', + 'koi_time0bk': 'Transit Epoch', + 'koi_depth': 'Transit Depth', + 'koi_prad': 'Planet Radius', + 'koi_sma': 'Semi-Major Axis', + 'koi_incl': 'Inclination', + 'koi_teq': 'Equilibrium Temp', + 'koi_insol': 'Insolation Flux', + 'koi_impact': 'Impact Parameter', + 'koi_ror': 'Planet/Star Radius Ratio', + 'koi_srho': 'Stellar Density', + 'koi_dor': 'Planet-Star Distance', + 'koi_num_transits': 'Number of Transits' + }; + + // Notification System + function showNotification(message, type = 'info') { + const container = document.getElementById('notification-container'); + if (!container) return; + + const notification = document.createElement('div'); + notification.className = `notification ${type}`; + + const icons = { + success: 'fa-circle-check', + error: 'fa-circle-exclamation', + info: 'fa-circle-info' + }; + + notification.innerHTML = ` + + ${message} + `; + + container.appendChild(notification); + + setTimeout(() => notification.classList.add('active'), 10); + + setTimeout(() => { + 
notification.classList.remove('active'); + setTimeout(() => notification.remove(), 400); + }, 5000); + } + + // Form Submission (Single Prediction) + const form = document.getElementById('predictForm'); + const modal = document.getElementById('resultModal'); + const closeModalBtn = document.getElementById('closeModal'); + const resultsContainer = document.getElementById('resultsContainer'); + + if (form) { + form.addEventListener('submit', async (e) => { + e.preventDefault(); + + const submitBtn = form.querySelector('button[type="submit"]'); + const originalBtnText = submitBtn.innerHTML; + + submitBtn.innerHTML = ' Analyzing...'; + submitBtn.disabled = true; + + const formData = {}; + inputs.forEach(input => { + formData[input.id] = parseFloat(input.value); + }); + + try { + const response = await fetch('/predict', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(formData) + }); + + const data = await response.json(); + + if (!response.ok) { + let errorMsg = 'An error occurred during prediction.'; + if (data.detail) { + if (Array.isArray(data.detail)) { + errorMsg = data.detail.map(d => { + const field = d.loc[d.loc.length - 1]; + const label = fieldLabels[field] || field; + return `${label}: ${d.msg}`; + }).join('\n'); + } else { + errorMsg = data.detail; + } + } + throw new Error(errorMsg); + } + + displayResults(data); + openModal(); + showNotification('Analysis complete! 
Check the results.', 'success'); + + } catch (error) { + console.error('Error:', error); + showNotification(error.message, 'error'); + } finally { + submitBtn.innerHTML = originalBtnText; + submitBtn.disabled = false; + } + }); + } + + // Modal Functions + function openModal() { + modal.classList.add('active'); + document.body.style.overflow = 'hidden'; + } + + function closeModal() { + modal.classList.remove('active'); + document.body.style.overflow = ''; + } + + if (closeModalBtn) { + closeModalBtn.addEventListener('click', closeModal); + } + + if (modal) { + modal.addEventListener('click', (e) => { + if (e.target === modal) { + closeModal(); + } + }); + } + + // Helper: Display Results + function displayResults(data) { + const predictionEl = document.getElementById('predictionResult'); + const barsContainer = document.getElementById('probabilityBars'); + + predictionEl.textContent = data.prediction; + + barsContainer.innerHTML = ''; + + const sortedProbs = Object.entries(data.probabilities) + .sort(([,a], [,b]) => b - a); + + sortedProbs.forEach(([label, prob]) => { + const percentage = (prob * 100).toFixed(1); + + const item = document.createElement('div'); + item.className = 'prob-item'; + + item.innerHTML = ` +
+ ${label} + ${percentage}% +
+
+
+
+ `; + + barsContainer.appendChild(item); + + setTimeout(() => { + item.querySelector('.prob-bar-fill').style.width = `${percentage}%`; + }, 100); + }); + } + + // ========================================= + // Batch Prediction + // ========================================= + + // File input display + const csvFileInput = document.getElementById('csvFile'); + const fileNameDisplay = document.getElementById('fileName'); + + if (csvFileInput) { + csvFileInput.addEventListener('change', (e) => { + const file = e.target.files[0]; + if (file) { + fileNameDisplay.textContent = `Selected: ${file.name}`; + fileNameDisplay.style.color = 'var(--accent-primary)'; + } else { + fileNameDisplay.textContent = ''; + } + }); + } + + // Batch Form Submission + const batchForm = document.getElementById('batchForm'); + const batchModal = document.getElementById('batchResultModal'); + const closeBatchModalBtn = document.getElementById('closeBatchModal'); + + if (batchForm) { + batchForm.addEventListener('submit', async (e) => { + e.preventDefault(); + + const submitBtn = batchForm.querySelector('button[type="submit"]'); + const originalBtnText = submitBtn.innerHTML; + + submitBtn.innerHTML = ' Processing...'; + submitBtn.disabled = true; + + const fileInput = document.getElementById('csvFile'); + const file = fileInput.files[0]; + + if (!file) { + showNotification('Please select a CSV file', 'error'); + submitBtn.innerHTML = originalBtnText; + submitBtn.disabled = false; + return; + } + + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await fetch('/predict/batch', { + method: 'POST', + body: formData + }); + + const data = await response.json(); + + if (!response.ok) { + let errorMsg = 'An error occurred during batch prediction.'; + if (data.detail) { + errorMsg = data.detail; + } + throw new Error(errorMsg); + } + + displayBatchResults(data); + openBatchModal(); + showNotification('Batch prediction complete!', 'success'); + + } catch 
(error) { + console.error('Error:', error); + showNotification(error.message, 'error'); + } finally { + submitBtn.innerHTML = originalBtnText; + submitBtn.disabled = false; + } + }); + } + + function openBatchModal() { + batchModal.classList.add('active'); + document.body.style.overflow = 'hidden'; + } + + function closeBatchModal() { + batchModal.classList.remove('active'); + document.body.style.overflow = ''; + } + + if (closeBatchModalBtn) { + closeBatchModalBtn.addEventListener('click', closeBatchModal); + } + + if (batchModal) { + batchModal.addEventListener('click', (e) => { + if (e.target === batchModal) { + closeBatchModal(); + } + }); + } + + function displayBatchResults(data) { + const pieChartContainer = document.getElementById('pieChartContainer'); + const tableBody = document.getElementById('batchResultsTableBody'); + + pieChartContainer.innerHTML = ''; + tableBody.innerHTML = ''; + + // Transform API response to results array + // Handle both "prediction_probability" and "predction_probability" (API typo) + const probabilities = data.prediction_probability || data.predction_probability || []; + + if (!data.predicted_labels || !Array.isArray(data.predicted_labels) || probabilities.length === 0) { + console.error('Invalid API response:', data); + showNotification('Invalid response from server', 'error'); + return; + } + + const results = data.predicted_labels.map((prediction, index) => { + const probs = probabilities[index] || []; + const maxProb = probs.length > 0 ? 
Math.max(...probs) : 0; + return { prediction, confidence: maxProb }; + }); + + // Count predictions by class + const classCounts = {}; + results.forEach(result => { + classCounts[result.prediction] = (classCounts[result.prediction] || 0) + 1; + }); + + const total = results.length; + + // Colors for each class + const classColors = { + 'CONFIRMED': '#00ff88', + 'CANDIDATE': '#ffab00', + 'FALSE POSITIVE': '#ff4b2b' + }; + + // Calculate angles for pie chart + let currentAngle = 0; + const conicGradientParts = []; + + Object.entries(classCounts).forEach(([className, count]) => { + const percentage = (count / total) * 100; + const angle = (count / total) * 360; + const startAngle = currentAngle; + const endAngle = currentAngle + angle; + const color = classColors[className] || '#888'; + + conicGradientParts.push(`${color} ${startAngle}deg ${endAngle}deg`); + currentAngle = endAngle; + }); + + // Create pie chart + const pieChart = document.createElement('div'); + pieChart.className = 'pie-chart'; + pieChart.style.background = `conic-gradient(${conicGradientParts.join(', ')})`; + pieChartContainer.appendChild(pieChart); + + // Create legend + const legend = document.createElement('div'); + legend.className = 'pie-legend'; + + Object.entries(classCounts).forEach(([className, count]) => { + const percentage = ((count / total) * 100).toFixed(1); + const color = classColors[className] || '#888'; + + const legendItem = document.createElement('div'); + legendItem.className = 'pie-legend-item'; + legendItem.innerHTML = ` +
+ ${className}: ${count} (${percentage}%) + `; + legend.appendChild(legendItem); + }); + + pieChartContainer.appendChild(legend); + + // Populate table + results.forEach((result, index) => { + const row = document.createElement('tr'); + + const confidenceClass = getConfidenceClass(result.confidence); + const predictionClass = getPredictionClass(result.prediction); + + row.innerHTML = ` + ${index + 1} + ${result.prediction} + ${(result.confidence * 100).toFixed(1)}% + `; + + tableBody.appendChild(row); + }); + } + + function getConfidenceClass(confidence) { + if (confidence >= 0.8) return 'high-confidence'; + if (confidence >= 0.5) return 'medium-confidence'; + return 'low-confidence'; + } + + function getPredictionClass(prediction) { + if (prediction === 'CONFIRMED') return 'class-confirmed'; + if (prediction === 'CANDIDATE') return 'class-candidate'; + return 'class-false'; + } +}); diff --git a/static/materials/32247a32f95f2504404c78a8df9ed849.png b/app/static/materials/32247a32f95f2504404c78a8df9ed849.png similarity index 100% rename from static/materials/32247a32f95f2504404c78a8df9ed849.png rename to app/static/materials/32247a32f95f2504404c78a8df9ed849.png diff --git a/static/materials/d821657540b6765c2d915b547bfce9c5 (1).jpg b/app/static/materials/d821657540b6765c2d915b547bfce9c5 (1).jpg similarity index 100% rename from static/materials/d821657540b6765c2d915b547bfce9c5 (1).jpg rename to app/static/materials/d821657540b6765c2d915b547bfce9c5 (1).jpg diff --git a/static/materials/next.png b/app/static/materials/next.png similarity index 100% rename from static/materials/next.png rename to app/static/materials/next.png diff --git a/app/templates/about.html b/app/templates/about.html new file mode 100644 index 0000000..2120401 --- /dev/null +++ b/app/templates/about.html @@ -0,0 +1,187 @@ + + + + + + About - TransitIQ + + + + + + + + + + + + + + + + + +
+ + +
+
+
+

About the Project

+

+ TransitIQ is an advanced machine learning tool developed for exoplanet classification research. + It leverages state-of-the-art ensemble algorithms to assist astronomers in identifying potential exoplanets. +

+
+ +
+

Objective

+

+ The primary goal is to automate the classification of transit signals. By analyzing light curves and orbital parameters, + the model categorizes candidates into three distinct classes: +

+
    +
  • CONFIRMED (2): Verified exoplanets.
  • +
  • CANDIDATE (1): Potential planets requiring further observation.
  • +
  • FALSE POSITIVE (0): Signals caused by other astrophysical phenomena.
  • +
+
+ +
+

Technical Architecture

+

+ Built on Scikit-learn, the system employs a Stacking Ensemble Classifier. + This combines the strengths of Random Forest and XGBoost, + orchestrated by a Logistic Regression meta-classifier. +

+ The backend is powered by FastAPI, serving a robust API that processes user inputs + and returns real-time predictions with per-class probability scores. +

+
+ +
+

Input Parameters

+

The model requires 13 specific orbital and transit features derived from NASA's Kepler mission data:

+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FeatureUnitDescription
Orbital PeriodDaysTime taken to complete one full orbit.
Transit EpochBJDTime of the center of the first transit.
Transit DepthppmFraction of stellar flux lost during transit.
Planet RadiusEarth RadiiEstimated radius of the planet.
Semi-Major AxisAUAverage distance from the host star.
InclinationDegreesAngle of the orbital plane.
Equilibrium TempKelvinTheoretical surface temperature.
Insolation FluxEarth FluxIncident solar radiation.
Impact Parameter-Sky-projected distance at conjunction.
Radius Ratio-Ratio of planet radius to star radius.
Stellar Densityg/cm³Density of the host star.
Planet-Star DistStellar RadiiDistance scaled by star size.
Num TransitsCountTotal number of transit events observed.
+
+
+ +
+

Disclaimer

+

+ While this model achieves high accuracy on the validation set, no ML model is infallible. + Predictions should be used as a preliminary screening tool rather than absolute confirmation. + The model is optimized for Kepler-like data distributions. +

+
+ +
+

+ + View Source on GitHub + +

+

+ © 2025 TransitIQ. Licensed under MIT. +

+ + Back to TransitIQ + +
+
+
+ + + + + \ No newline at end of file diff --git a/app/templates/index.html b/app/templates/index.html new file mode 100644 index 0000000..7693b1d --- /dev/null +++ b/app/templates/index.html @@ -0,0 +1,288 @@ + + + + + + Home - TransitIQ + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+

Welcome to the Future

+

+ Discover New
+ Worlds +

+

+ Harness the power of Machine Learning to classify exoplanets. + Input transit data and let our advanced ensemble model determine whether the signal is a + Confirmed Planet, Candidate, or False Positive. +

+
+ + +
+
+ +
+
+ Exoplanet Art +
+
+
+
+ + +
+ +
+
+
+

Input Parameters

+

Enter the transit details below to generate a prediction.

+
+ +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ +
+
+
+
+
+ + +
+
+
+
+

Batch Prediction

+

Upload a CSV file to run predictions on multiple samples at once.

+
+ +
+

CSV File Requirements

+

+ + Important: Your CSV file must have the following columns in the exact order: +

+
    +
  • koi_period - Orbital Period (days)
  • +
  • koi_time0bk - Transit Epoch (BJD)
  • +
  • koi_depth - Transit Depth (ppm)
  • +
  • koi_prad - Planet Radius (Earth radii)
  • +
  • koi_sma - Semi-Major Axis (AU)
  • +
  • koi_incl - Inclination (deg)
  • +
  • koi_teq - Equilibrium Temp (K)
  • +
  • koi_insol - Insolation Flux (Earth flux)
  • +
  • koi_impact - Impact Parameter
  • +
  • koi_ror - Planet/Star Radius Ratio
  • +
  • koi_srho - Stellar Density (g/cm³)
  • +
  • koi_dor - Planet-Star Distance (R★)
  • +
  • koi_num_transits - Number of Transits
  • +
+

+ + If the CSV file does not have the correct columns in the exact order, predictions will fail or produce incorrect results. +

+
+ +
+
+ + +
+
+ +
+ +
+
+
+
+
+ + + + + + + + + +
+ + + + + diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/sample_generator.py b/data/sample_generator.py new file mode 100644 index 0000000..aad974c --- /dev/null +++ b/data/sample_generator.py @@ -0,0 +1,106 @@ +""" +sample_generator.py + +This script is used to generate samples directly from the main dataset. +These samples are used to test the `/predict/batch` route. +To run it, enter this in your command line: +``` +python -m data.sample_generator +``` +""" + + +import pandas as pd +import numpy as np + + +def get_window(camps, campaign_dates): + if pd.isna(camps) or not camps: + return np.nan, np.nan + + camps = str(camps).split(',') if isinstance(camps, str) else camps + starts, ends = [], [] + + for c in camps: + try: + camp_num = int(c.strip()) + if camp_num in campaign_dates: + start, end = campaign_dates[camp_num] + starts.append(start) + ends.append(end) + except (ValueError, KeyError): + continue + + return (min(starts) if starts else np.nan, max(ends) if ends else np.nan) + + +def create_test_sample(num_samples=10, random_seed=42): + np.random.seed(random_seed) + + kepler_df = pd.read_csv("data/kepler_data.csv", comment="#") + k2_df = pd.read_csv("data/k2_data.csv", comment="#") + + feature_list = [ + "koi_period", "koi_time0bk", "koi_depth", "koi_prad", + "koi_sma", "koi_incl", "koi_teq", "koi_insol", "koi_impact", + "koi_ror", "koi_srho", "koi_dor", "koi_num_transits" + ] + + kepler_subset = kepler_df[feature_list] + + campaign_dates = { + 0: (2456725.0, 2456805.0), 1: (2456808.0, 2456891.0), 2: (2456893.0, 2456975.0), + 3: (2456976.0, 2457064.0), 4: (2457065.0, 2457159.0), 5: (2457159.0, 2457246.0), + 6: (2457250.0, 2457338.0), 7: (2457339.0, 2457420.0), 8: (2457421.0, 2457530.0), + 9: (2457504.0, 2457579.0), 10: (2457577.0, 2457653.0), 11: (2457657.0, 2457732.0), + 12: (2457731.0, 2457819.0), 13: (2457820.0, 2457900.0), 14: (2457898.0, 2457942.0), + 15: (2457941.0, 2458022.0), 16: 
(2458020.0, 2458074.0), 17: (2458074.0, 2458176.0), + 18: (2458151.0, 2458201.0), 19: (2458232.0, 2458348.0) + } + + k2_df['campaigns'] = k2_df['k2_campaigns'] + k2_df[['obs_start_bjd', 'obs_end_bjd']] = k2_df['campaigns'].apply( + lambda x: pd.Series(get_window(x, campaign_dates)) + ) + + k2_df['n_min'] = np.ceil((k2_df['obs_start_bjd'] - k2_df['pl_tranmid']) / k2_df['pl_orbper']) + k2_df['n_max'] = np.floor((k2_df['obs_end_bjd'] - k2_df['pl_tranmid']) / k2_df['pl_orbper']) + k2_df['num_transits'] = (k2_df['n_max'] - k2_df['n_min'] + 1).clip(lower=0) + + k2_mapping = { + "pl_orbper": "koi_period", "pl_tranmid": "koi_time0bk", + "pl_trandep": "koi_depth", "pl_rade": "koi_prad", "pl_orbsmax": "koi_sma", + "pl_orbincl": "koi_incl", "pl_eqt": "koi_teq", "pl_insol": "koi_insol", + "pl_imppar": "koi_impact", "pl_ratror": "koi_ror", "pl_dens": "koi_srho", + "pl_ratdor": "koi_dor", "num_transits": "koi_num_transits" + } + + k2_subset = k2_df[list(k2_mapping.keys())].rename(columns=k2_mapping) + + combined = pd.concat([kepler_subset, k2_subset], ignore_index=True) + + combined = combined.dropna(subset=feature_list) + + sample_indices = np.random.choice(len(combined), size=min(num_samples, len(combined)), replace=False) + sample = combined.iloc[sample_indices].copy() + + noise_factor = 0.15 + for col in feature_list: + col_std = sample[col].std() + if col_std > 0: + noise = np.random.normal(0, col_std * noise_factor, size=len(sample)) + sample[col] = sample[col] + noise + sample[col] = sample[col].clip(lower=0) if col in ["koi_depth", "koi_impact", "koi_ror", "koi_num_transits"] else sample[col] + + sample = sample[feature_list] + + output_path = "data/test_sample.csv" + sample.to_csv(output_path, index=False) + print(f"Created test sample with {len(sample)} rows at {output_path}") + print(f"Columns: {sample.columns.tolist()}") + print(f"\nSample preview:") + print(sample.head()) + + +if __name__ == "__main__": + create_test_sample(num_samples=20) diff --git a/fit.py 
b/models/fit.py similarity index 100% rename from fit.py rename to models/fit.py diff --git a/requirements.txt b/requirements.txt index 0aaa4fe..8a8e2ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,17 @@ annotated-doc==0.0.4 +annotated-types==0.7.0 anyio==4.13.0 blinker==1.9.0 certifi==2026.2.25 click==8.3.1 contourpy==1.3.3 cycler==0.12.1 +docutils==0.22.4 dotenv==0.9.9 +fastapi==0.135.3 filelock==3.25.2 -Flask==3.1.2 fonttools==4.61.1 fsspec==2026.3.0 -gunicorn==23.0.0 h11==0.16.0 hf-xet==1.4.3 httpcore==1.0.9 @@ -19,24 +20,37 @@ huggingface_hub==1.10.1 idna==3.11 imbalanced-learn==0.14.1 itsdangerous==2.2.0 +jedi==0.19.2 Jinja2==3.1.6 joblib==1.5.3 kiwisolver==1.4.9 lightgbm==4.6.0 +loro==1.10.3 +marimo==0.23.1 +Markdown==3.10.2 markdown-it-py==4.0.0 MarkupSafe==3.0.3 matplotlib==3.10.8 mdurl==0.1.2 +msgspec==0.21.1 +narwhals==2.19.0 numpy==2.4.1 packaging==25.0 pandas==2.3.3 +parso==0.8.6 pillow==12.1.0 +psutil==7.2.2 +pydantic==2.12.5 +pydantic_core==2.41.5 Pygments==2.20.0 +pymdown-extensions==10.21.2 pyparsing==3.3.1 python-dateutil==2.9.0.post0 python-dotenv==1.2.1 +python-multipart==0.0.26 pytz==2025.2 PyYAML==6.0.3 +pyzmq==27.1.0 rich==15.0.0 scikit-learn==1.8.0 scipy==1.17.0 @@ -44,10 +58,14 @@ seaborn==0.13.2 shellingham==1.5.4 six==1.17.0 sklearn-compat==0.1.5 +starlette==1.0.0 threadpoolctl==3.6.0 +tomlkit==0.14.0 tqdm==4.67.3 typer==0.24.1 +typing-inspection==0.4.2 typing_extensions==4.15.0 tzdata==2025.3 -Werkzeug==3.1.5 +uvicorn==0.44.0 +websockets==16.0 xgboost==3.1.3 diff --git a/static/js/main.js b/static/js/main.js deleted file mode 100644 index de0f46e..0000000 --- a/static/js/main.js +++ /dev/null @@ -1,147 +0,0 @@ -document.addEventListener('DOMContentLoaded', () => { - // Smooth Scroll - const scrollBtn = document.getElementById('scroll-to-form'); - if (scrollBtn) { - scrollBtn.addEventListener('click', (e) => { - e.preventDefault(); - document.getElementById('form-section').scrollIntoView({ - behavior: 'smooth' - }); - 
}); - } - - // Input Animation & Label Handling - const inputs = document.querySelectorAll('.input-field'); - inputs.forEach(input => { - // Trigger label animation on load if value exists - if (input.value) { - input.classList.add('has-value'); - } - - input.addEventListener('input', () => { - if (input.value.trim() !== '') { - input.classList.add('has-value'); - } else { - input.classList.remove('has-value'); - } - }); - }); - - // Form Submission - const form = document.getElementById('predictForm'); - const modal = document.getElementById('resultModal'); - const closeModalBtn = document.getElementById('closeModal'); - const resultsContainer = document.getElementById('resultsContainer'); - - if (form) { - form.addEventListener('submit', async (e) => { - e.preventDefault(); - - const submitBtn = form.querySelector('button[type="submit"]'); - const originalBtnText = submitBtn.innerHTML; - - // Loading State - submitBtn.innerHTML = ' Analyzing...'; - submitBtn.disabled = true; - - // Collect Data - const formData = {}; - inputs.forEach(input => { - formData[input.id] = parseFloat(input.value); - }); - - try { - const response = await fetch('/predict', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(formData) - }); - - const data = await response.json(); - - if (data.error) { - throw new Error(data.error); - } - - // Populate Results - displayResults(data); - openModal(); - - } catch (error) { - console.error('Error:', error); - alert('An error occurred during prediction. 
Please check your inputs.'); - } finally { - // Reset Button - submitBtn.innerHTML = originalBtnText; - submitBtn.disabled = false; - } - }); - } - - // Modal Functions - function openModal() { - modal.classList.add('active'); - document.body.style.overflow = 'hidden'; - } - - function closeModal() { - modal.classList.remove('active'); - document.body.style.overflow = ''; - } - - if (closeModalBtn) { - closeModalBtn.addEventListener('click', closeModal); - } - - // Close on click outside - if (modal) { - modal.addEventListener('click', (e) => { - if (e.target === modal) { - closeModal(); - } - }); - } - - // Helper: Display Results - function displayResults(data) { - const predictionEl = document.getElementById('predictionResult'); - const barsContainer = document.getElementById('probabilityBars'); - - // Set Prediction Text - predictionEl.textContent = data.prediction; - - // Clear previous bars - barsContainer.innerHTML = ''; - - // Sort probabilities - const sortedProbs = Object.entries(data.probabilities) - .sort(([,a], [,b]) => b - a); - - // Create Bars - sortedProbs.forEach(([label, prob]) => { - const percentage = (prob * 100).toFixed(1); - - const item = document.createElement('div'); - item.className = 'prob-item'; - - item.innerHTML = ` -
- ${label} - ${percentage}% -
-
-
-
- `; - - barsContainer.appendChild(item); - - // Animate bar after a slight delay - setTimeout(() => { - item.querySelector('.prob-bar-fill').style.width = `${percentage}%`; - }, 100); - }); - } -}); diff --git a/static/script.js b/static/script.js deleted file mode 100644 index 3bdb2b8..0000000 --- a/static/script.js +++ /dev/null @@ -1,87 +0,0 @@ -function toPascalCase(str) { - return str - .split(" ") - .map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase()) - .join(""); -} - - -document - .getElementById("predictForm") - .addEventListener("submit", async function (e) { - e.preventDefault(); - - let inputs = document.querySelectorAll("#predictForm input"); - let features = []; - inputs.forEach((input) => features.push(parseFloat(input.value))); - - let res = await fetch("/predict", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ features: features }), - }); - - let data = await res.json(); - const modal = document.getElementById("resultModal"); - const resultsBox = document.getElementById("resultsBox"); - - if (data.prediction) { - let probText = ""; - if (data.probabilities) { - probText = "

Class Probabilities:

"; - for (let key in data.probabilities) { - probText += ` -
- ${toPascalCase(key)}: ${( - data.probabilities[key] * 100 - ).toFixed(2)}% -
-
-
-
- `; - } - } - resultsBox.innerHTML = `

Prediction: ${data.prediction - .toLowerCase() - .split(" ") - .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) - .join("")}

${probText}`; - } else { - resultsBox.innerHTML = `

Error: ${data.error}

`; - } - modal.classList.remove("hidden"); - }); -document.getElementById("closeModal").addEventListener("click", () => { - document.getElementById("resultModal").classList.add("hidden"); -}); - -resultsBox.classList.remove("show"); -void resultsBox.offsetWidth; -resultsBox.classList.add("show"); -const formLabel = document.querySelector(".form-label"); -const formInputs = document.querySelectorAll("#predictForm input"); - -formInputs.forEach((input) => { - input.addEventListener("input", () => { - const anyFilled = Array.from(formInputs).some((i) => i.value.trim() !== ""); - formLabel.textContent = anyFilled ? "Remove details" : "Enter the details"; - formLabel.style.cursor = anyFilled ? "pointer" : "default"; - }); -}); - -formLabel.addEventListener("click", () => { - if (formLabel.textContent === "Remove details") { - formInputs.forEach((input) => (input.value = "")); - formLabel.textContent = "Enter the details"; - } -}); - -document.querySelector("#animated-btn").addEventListener("click", (e) => { - e.preventDefault(); // prevent default if it's a link - document.querySelector("#form-section").scrollIntoView({ - behavior: "smooth", - }); -}); diff --git a/templates/about.html b/templates/about.html deleted file mode 100644 index 840f524..0000000 --- a/templates/about.html +++ /dev/null @@ -1,147 +0,0 @@ -{% extends "base.html" %} - -{% block title %}About - TransitIQ{% endblock %} - -{% block content %} -
-
-

About the Project

-

- The TransitIQ is an advanced machine learning tool developed for exoplanet classification research. - It leverages state-of-the-art ensemble algorithms to assist astronomers in identifying potential exoplanets. -

-
- -
-

Objective

-

- The primary goal is to automate the classification of transit signals. By analyzing light curves and orbital parameters, - the model categorizes candidates into three distinct classes: -

- -
- -
-

Technical Architecture

-

- Built on Scikit-learn, the system employs a Stacking Ensemble Classifier. - This combines the strengths of Random Forest and XGBoost, - orchestrated by a Logistic Regression meta-classifier. -

- The backend is powered by Flask, serving a robust API that processes user inputs - and returns real-time predictions with probability confidence scores. -

-
- -
-

Input Parameters

-

The model requires 13 specific orbital and transit features derived from NASA's Kepler mission data:

- -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureUnitDescription
Orbital PeriodDaysTime taken to complete one full orbit.
Transit EpochBJDTime of the center of the first transit.
Transit DepthppmFraction of stellar flux lost during transit.
Planet RadiusEarth RadiiEstimated radius of the planet.
Semi-Major AxisAUAverage distance from the host star.
InclinationDegreesAngle of the orbital plane.
Equilibrium TempKelvinTheoretical surface temperature.
Insolation FluxEarth FluxIncident solar radiation.
Impact Parameter-Sky-projected distance at conjunction.
Radius Ratio-Ratio of planet radius to star radius.
Stellar Densityg/cm³Density of the host star.
Planet-Star DistStellar RadiiDistance scaled by star size.
Num TransitsCountTotal number of transit events observed.
-
-
- -
-

Disclaimer

-

- While this model achieves high accuracy on the validation set, no ML model is infallible. - Predictions should be used as a preliminary screening tool rather than absolute confirmation. - The model is optimized for Kepler-like data distributions. -

-
- -
-

- - View Source on GitHub - -

-

- © 2025 TransitIQ. Licensed under MIT. -

- - Back to TransitIQ - -
-
-{% endblock %} \ No newline at end of file diff --git a/templates/base.html b/templates/base.html deleted file mode 100644 index b51aed6..0000000 --- a/templates/base.html +++ /dev/null @@ -1,43 +0,0 @@ - - - - - - {% block title %}TransitIQ{% endblock %} - - - - - - - - - - - - - - - - - - -
- {% block content %}{% endblock %} -
- - - - - diff --git a/templates/index.html b/templates/index.html deleted file mode 100644 index 41a4c27..0000000 --- a/templates/index.html +++ /dev/null @@ -1,132 +0,0 @@ -{% extends "base.html" %} - -{% block title %}Home - TransitIQ{% endblock %} - -{% block content %} - -
-
-

Welcome to the Future

-

- Discover New
- Worlds -

-

- Harness the power of Machine Learning to classify exoplanets. - Input transit data and let our advanced ensemble model determine if it's a - Confirmed Planet, Candidate, or False Positive. -

- -
- -
-
- Exoplanet Art -
-
-
- - -
-
-
-

Input Parameters

-

Enter the transit details below to generate a prediction.

-
- -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- -
-
-
-
- - - -{% endblock %} \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d4839a6 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..f810f68 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,73 @@ +""" +conftest.py – session-scoped fixtures for the TransitIQ test suite. + +Strategy +-------- +The FastAPI app loads ML model artifacts (pipe.pkl, column_names.pkl) during +its lifespan startup via ``initialize_artifacts()``. In CI those files are not +present, and downloading them from Hugging Face would be slow and fragile. + +We therefore: +1. Stub out the ``models`` package so the bare import doesn't fail. +2. Patch ``initialize_artifacts`` to return lightweight MagicMock objects. +3. Expose a ``TestClient`` fixture whose lifespan uses those mocks. +""" + +import sys +import io +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +from fastapi.testclient import TestClient + +# ── 1. Stub heavy / optional dependencies before the app is imported ───────── +# Prevents import errors in CI where model artifacts or HF credentials may +# not be available. +_models_stub = MagicMock() +sys.modules.setdefault("models", _models_stub) +sys.modules.setdefault("models.download_from_hf", _models_stub) + +# ── 2. Column order must exactly match app/schema/validate.py ──────────────── +COLUMN_NAMES = np.array([ + "koi_period", + "koi_time0bk", + "koi_depth", + "koi_prad", + "koi_sma", + "koi_incl", + "koi_teq", + "koi_insol", + "koi_impact", + "koi_ror", + "koi_srho", + "koi_dor", + "koi_num_transits", +]) + + +@pytest.fixture(scope="session") +def mock_pipe(): + """A minimal sklearn-compatible pipeline mock.""" + pipe = MagicMock() + # Return a label and probability vector scaled by the number of rows so + # both single-row (/predict) and multi-row (/predict/batch) calls work. 
+ pipe.predict.side_effect = lambda x: np.array([2] * len(x)) + pipe.predict_proba.side_effect = ( + lambda x: np.array([[0.05, 0.15, 0.80]] * len(x)) + ) + return pipe + + +@pytest.fixture(scope="session") +def client(mock_pipe): + """ + Session-scoped TestClient whose lifespan uses the mock pipeline. + ``scope="session"`` means the app boots once for the entire test run, + which mirrors real usage and keeps the suite fast. + """ + with patch("app.app.initialize_artifacts", return_value=(mock_pipe, COLUMN_NAMES)): + # Import deferred so the sys.modules stubs are already in place. + from app.app import app # noqa: PLC0415 + + with TestClient(app) as c: + yield c diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..16b5a44 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,283 @@ +""" +test_api.py – integration and schema tests for the TransitIQ FastAPI app. + +Run from the project root: + pytest tests/ -v + +The TestClient and all mocking are set up in conftest.py. +""" + +import io +import pytest +import pandas as pd +from pydantic import ValidationError + +# ── Shared test data ────────────────────────────────────────────────────────── + +#: Column order matches app/schema/validate.py and conftest.COLUMN_NAMES. +_COLUMNS = [ + "koi_period", "koi_time0bk", "koi_depth", "koi_prad", "koi_sma", + "koi_incl", "koi_teq", "koi_insol", "koi_impact", "koi_ror", + "koi_srho", "koi_dor", "koi_num_transits", +] + +#: A valid, boundary-safe payload for /predict. 
# Canonical request body for /predict. Individual tests copy it and override
# (or drop) a single field to probe one validation rule at a time.
VALID_PAYLOAD: dict = dict(
    koi_period=10.0,
    koi_time0bk=2454834.0,
    koi_depth=500.0,
    koi_prad=1.5,
    koi_sma=0.1,
    koi_incl=85.0,
    koi_teq=400.0,
    koi_insol=2.0,
    koi_impact=0.3,
    koi_ror=0.05,
    koi_srho=1.2,
    koi_dor=10.0,
    koi_num_transits=5,
)

# The only class labels the API is expected to emit.
_VALID_LABELS = {"FALSE POSITIVE", "CANDIDATE", "CONFIRMED"}


def _make_csv(rows: list | None = None) -> bytes:
    """Serialise *rows* (a list of row dicts) to CSV bytes with the ``_COLUMNS`` header.

    When *rows* is omitted, a single row built from ``VALID_PAYLOAD`` is used.
    """
    payload_rows = [VALID_PAYLOAD] if rows is None else rows
    frame = pd.DataFrame(payload_rows, columns=_COLUMNS)
    return frame.to_csv(index=False).encode()


def _post_predict(client, payload):
    """POST *payload* as JSON to /predict and return the raw response."""
    return client.post("/predict", json=payload)


def _predict_status_with(client, **overrides):
    """Return the /predict status code for VALID_PAYLOAD with *overrides* applied."""
    return _post_predict(client, {**VALID_PAYLOAD, **overrides}).status_code


def _post_batch(client, content, filename="data.csv", mime="text/csv"):
    """Upload *content* (bytes) to /predict/batch as the ``file`` form field."""
    upload = (filename, io.BytesIO(content), mime)
    return client.post("/predict/batch", files={"file": upload})


# ---------------------------------------------------------------------------
# GET /health
# ---------------------------------------------------------------------------

class TestHealthRoute:
    """Checks on the JSON health endpoint."""

    def test_status_code(self, client):
        response = client.get("/health")
        assert response.status_code == 200

    def test_body_title(self, client):
        body = client.get("/health").json()
        assert body["title"] == "TransitIQ"

    def test_body_status_message(self, client):
        body = client.get("/health").json()
        assert body["status"] == "All systems operational"


# ---------------------------------------------------------------------------
# HTML pages
# ---------------------------------------------------------------------------

class TestStaticRoutes:
    """The landing and about pages must be served as HTML."""

    def test_home_returns_200(self, client):
        response = client.get("/")
        assert response.status_code == 200

    def test_home_content_type(self, client):
        content_type = client.get("/").headers["content-type"]
        assert "text/html" in content_type

    def test_about_returns_200(self, client):
        response = client.get("/about")
        assert response.status_code == 200

    def test_about_content_type(self, client):
        content_type = client.get("/about").headers["content-type"]
        assert "text/html" in content_type


# ---------------------------------------------------------------------------
# POST /predict
# ---------------------------------------------------------------------------

class TestPredictEndpoint:
    """Single-record prediction: response shape first, then input validation."""

    # Successful requests

    def test_valid_payload_returns_201(self, client):
        response = _post_predict(client, VALID_PAYLOAD)
        assert response.status_code == 201

    def test_response_has_required_keys(self, client):
        body = _post_predict(client, VALID_PAYLOAD).json()
        assert body.keys() >= {"status", "prediction", "probabilities"}

    def test_status_is_success(self, client):
        body = _post_predict(client, VALID_PAYLOAD).json()
        assert body["status"] == "success"

    def test_prediction_is_valid_label(self, client):
        body = _post_predict(client, VALID_PAYLOAD).json()
        assert body["prediction"] in _VALID_LABELS

    def test_probabilities_keys_match_labels(self, client):
        body = _post_predict(client, VALID_PAYLOAD).json()
        assert set(body["probabilities"]) == _VALID_LABELS

    def test_probabilities_sum_to_one(self, client):
        probabilities = _post_predict(client, VALID_PAYLOAD).json()["probabilities"]
        probability_mass = sum(probabilities.values())
        assert abs(probability_mass - 1.0) < 0.01

    # Rejected requests

    def test_missing_field_returns_422(self, client):
        incomplete = dict(VALID_PAYLOAD)
        incomplete.pop("koi_period")
        assert _post_predict(client, incomplete).status_code == 422

    def test_koi_period_zero_returns_422(self, client):
        """A zero period is rejected (koi_period must be > 0)."""
        assert _predict_status_with(client, koi_period=0) == 422

    def test_koi_period_negative_returns_422(self, client):
        assert _predict_status_with(client, koi_period=-5.0) == 422

    def test_koi_incl_above_90_returns_422(self, client):
        """An inclination above 90 is rejected (koi_incl must be ≤ 90)."""
        assert _predict_status_with(client, koi_incl=91.0) == 422

    def test_koi_impact_at_one_returns_422(self, client):
        """The upper bound is exclusive (koi_impact must be < 1)."""
        assert _predict_status_with(client, koi_impact=1.0) == 422

    def test_koi_impact_above_one_returns_422(self, client):
        assert _predict_status_with(client, koi_impact=1.5) == 422

    def test_koi_ror_at_one_returns_422(self, client):
        """The upper bound is exclusive (koi_ror must be < 1)."""
        assert _predict_status_with(client, koi_ror=1.0) == 422

    def test_koi_num_transits_zero_returns_422(self, client):
        """At least one transit is required (koi_num_transits must be ≥ 1)."""
        assert _predict_status_with(client, koi_num_transits=0) == 422

    def test_string_value_for_float_field_returns_422(self, client):
        assert _predict_status_with(client, koi_period="not-a-number") == 422

    def test_empty_body_returns_422(self, client):
        assert _post_predict(client, {}).status_code == 422


# ---------------------------------------------------------------------------
# POST /predict/batch
# ---------------------------------------------------------------------------

class TestBatchPredictEndpoint:
    """CSV-upload prediction: response shape first, then upload validation."""

    # Successful uploads

    def test_valid_csv_returns_201(self, client):
        response = _post_batch(client, _make_csv())
        assert response.status_code == 201

    def test_response_has_required_keys(self, client):
        body = _post_batch(client, _make_csv()).json()
        # "predction_probability" is spelled exactly as the app returns it.
        for expected_key in ("status", "predicted_labels", "predction_probability"):
            assert expected_key in body

    def test_single_row_prediction_count(self, client):
        body = _post_batch(client, _make_csv()).json()
        assert len(body["predicted_labels"]) == 1

    def test_multi_row_prediction_count(self, client):
        three_rows = _make_csv([VALID_PAYLOAD] * 3)
        body = _post_batch(client, three_rows).json()
        assert len(body["predicted_labels"]) == 3

    def test_labels_are_valid(self, client):
        body = _post_batch(client, _make_csv()).json()
        assert all(label in _VALID_LABELS for label in body["predicted_labels"])

    # Rejected uploads

    def test_non_csv_extension_returns_422(self, client):
        response = _post_batch(
            client, b"some text", filename="data.txt", mime="text/plain"
        )
        assert response.status_code == 422

    def test_json_extension_returns_422(self, client):
        response = _post_batch(
            client, b'{"a":1}', filename="data.json", mime="application/json"
        )
        assert response.status_code == 422

    def test_wrong_column_names_returns_422(self, client):
        unknown_header_csv = b"col_a,col_b,col_c\n1.0,2.0,3.0\n"
        response = _post_batch(client, unknown_header_csv)
        assert response.status_code == 422

    def test_extra_columns_returns_422(self, client):
        """A CSV with all expected columns plus one extra must still be rejected."""
        widened_row = {**VALID_PAYLOAD, "extra_col": 99.0}
        widened_frame = pd.DataFrame(
            [widened_row], columns=[*_COLUMNS, "extra_col"]
        )
        widened_csv = widened_frame.to_csv(index=False).encode()
        response = _post_batch(client, widened_csv)
        assert response.status_code == 422

    def test_non_numeric_values_returns_422(self, client):
        """Cells holding text instead of numbers must be rejected."""
        text_row = dict.fromkeys(_COLUMNS, "abc")
        text_frame = pd.DataFrame([text_row], columns=_COLUMNS)
        text_csv = text_frame.to_csv(index=False).encode()
        response = _post_batch(client, text_csv)
        assert response.status_code == 422


# ---------------------------------------------------------------------------
# Pydantic schema (no HTTP round-trip)
# ---------------------------------------------------------------------------

class TestUserInputSchema:
    """Exercise the ``UserInput`` model directly, bypassing the HTTP layer."""

    def test_valid_input_creates_model(self):
        from app.schema.validate import UserInput

        model = UserInput(**VALID_PAYLOAD)
        assert model.koi_num_transits == 5

    @pytest.mark.parametrize("field,bad_value", [
        ("koi_period", 0),          # must be > 0
        ("koi_period", -1.0),
        ("koi_time0bk", 100.0),     # must be > 2_450_000
        ("koi_depth", 0),           # must be > 0
        ("koi_incl", 0),            # must be > 0
        ("koi_incl", 91.0),         # must be ≤ 90
        ("koi_impact", -0.1),       # must be ≥ 0
        ("koi_impact", 1.0),        # must be < 1
        ("koi_ror", 0),             # must be > 0
        ("koi_ror", 1.0),           # must be < 1
        ("koi_dor", 0.5),           # must be > 1
        ("koi_num_transits", 0),    # must be ≥ 1
    ])
    def test_invalid_field_raises_validation_error(self, field, bad_value):
        from app.schema.validate import UserInput

        # Start from the known-good payload and corrupt exactly one field.
        corrupted = dict(VALID_PAYLOAD)
        corrupted[field] = bad_value
        with pytest.raises(ValidationError):
            UserInput(**corrupted)