-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
67 lines (52 loc) · 2.01 KB
/
main.py
File metadata and controls
67 lines (52 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from pathlib import Path

import numpy as np
import pandas as pd

from src.classifier import VariableStarClassifier
from src.stellar_properties import (
    compute_stellar_properties,
    print_summary_statistics,
    save_pulsators_data,
)
from src.visualization import create_all_plots
def _print_section(title, *, leading_newline=True):
    """Print *title* framed above and below by an 80-character ``=`` rule.

    ``leading_newline`` inserts a blank line before the first rule so that
    consecutive sections are visually separated.
    """
    rule = "=" * 80
    print(("\n" if leading_newline else "") + rule)
    print(title)
    print(rule)


def _classify(df):
    """Pipeline step 1: prepare data, train and evaluate the classifier.

    Returns:
        A tuple ``(df, feature_importances, y_test, y_pred)`` where ``df`` is
        the DataFrame returned by ``prepare_data`` and ``feature_importances``
        is the ``(importances, feature_names, indices)`` triple returned by
        ``get_feature_importances()``.
    """
    _print_section("STEP 1: Classifying Variable Stars", leading_newline=False)
    classifier = VariableStarClassifier()

    # prepare_data hands back a DataFrame that the later steps use in place
    # of the raw catalog frame (presumably augmented/filtered -- not visible).
    df, X, y = classifier.prepare_data(df)

    print("\nTraining Random Forest classifier...")
    # The scaled test features returned by train() are not needed downstream.
    _, y_test, y_pred = classifier.train(X, y)

    classifier.evaluate(y_test, y_pred)

    feature_importances = classifier.get_feature_importances()
    return df, feature_importances, y_test, y_pred


def _compute_pulsators(df, output_path):
    """Pipeline step 2: derive pulsating-star properties, report and save them.

    Returns the DataFrame produced by ``compute_stellar_properties``.
    """
    _print_section("STEP 2: Computing Stellar Properties for Pulsating Stars")
    df_pulsators = compute_stellar_properties(df)
    print(f"\nFound {len(df_pulsators)} pulsating stars with good parallax data")

    print_summary_statistics(df_pulsators)

    # Create the destination directory up front. Whether save_pulsators_data
    # does this itself is not visible here, and writing into a missing
    # directory would abort the run after all of the work above.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    save_pulsators_data(df_pulsators, output_path)
    return df_pulsators


def _visualize(df_pulsators, feature_importances, y_test, y_pred):
    """Pipeline step 3: render all plots via ``create_all_plots``."""
    _print_section("STEP 3: Creating Visualizations")
    print("\nGenerating plots...")

    # Use every label present in either the true or the predicted values,
    # sorted -- the same label set scikit-learn's confusion_matrix uses by
    # default. Taking labels from y_test alone would drop classes that only
    # appear among the predictions.
    # NOTE(review): assumes create_all_plots uses `classes` to label a
    # confusion matrix -- confirm in src.visualization.
    classes = np.union1d(y_test, y_pred)

    importances, feature_names, indices = feature_importances
    create_all_plots(df_pulsators, importances, feature_names, indices,
                     y_test=y_test, y_pred=y_pred, classes=classes)


def main(data_path="data/asassn_variables.csv",
         output_path="outputs/pulsators_with_observables.csv"):
    """Run the ASAS-SN variable-star analysis pipeline end to end.

    Loads the catalog, classifies the variable stars, computes stellar
    properties for the pulsating stars (writing them to CSV), and finally
    generates the plots.

    Args:
        data_path: Path of the variable-star catalog CSV to load.
        output_path: Destination CSV for the pulsating-star table; its parent
            directory is created if it does not exist.
    """
    print("Loading data from ASAS-SN catalog...")
    df = pd.read_csv(data_path)
    print(f"Loaded {len(df)} variable stars\n")

    df, feature_importances, y_test, y_pred = _classify(df)
    df_pulsators = _compute_pulsators(df, output_path)
    _visualize(df_pulsators, feature_importances, y_test, y_pred)

    _print_section("Analysis Complete!")


if __name__ == "__main__":
    main()