From f55eecadd1de59932c9c39a394fff01c53e95d5d Mon Sep 17 00:00:00 2001 From: Sierd de Vries Date: Thu, 13 Nov 2025 16:42:30 +0100 Subject: [PATCH 1/2] Gui v0.2 added (#264) (#266) * add wind plotting functionality * Refactor GUI: Complete modular architecture with all GUI tabs extracted, export functionality, and utilities (#263) * Initial plan * Phase 1: Add constants, utility functions, and improve documentation * Phase 2: Extract helper methods and reduce code duplication * Phase 3: Add variable label/title constants and improve docstrings * Final: Add comprehensive refactoring documentation and summary * Add export functionality: PNG and MP4 animations for all visualizations * Phase 4: Begin code organization - extract utils module and create gui package structure * Add comprehensive additional improvements proposal document * bugfixes related to import and animattion functionality * updated structure for further refactoring * Refactor: Extract DomainVisualizer and rename gui_app_backup.py to application.py * bugfix * bugfix on loading domain * Refactor: Extract WindVisualizer to modular architecture * Refactor: Extract Output2DVisualizer for 2D NetCDF visualization * Refactor: Extract Output1DVisualizer - Complete modular architecture achieved! * bugfixes loading files * removed netcdf check * bugfixes after refractoring * bugfixes with domain overview * Speeding up complex drawing * hold on functionality added * Tab to run code added. * Update aeolis/gui/application.py * Update aeolis/gui/application.py * Update aeolis/gui/visualizers/domain.py * Update aeolis/gui/visualizers/domain.py * Update aeolis/gui/main.py * Update aeolis/gui/visualizers/output_2d.py * Apply suggestions from code review * Rename visualizers folder to gui_tabs and update all imports * bigfixes related to refactoring * reducing code lenght by omitting some redundancies * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review --------- --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Sierd <14054272+Sierd@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- REFACTORING_SUMMARY.md | 2 +- aeolis/gui/application.py | 152 +++++++-- aeolis/gui/gui_tabs/__init__.py | 16 + aeolis/gui/gui_tabs/domain.py | 307 ++++++++++++++++++ aeolis/gui/gui_tabs/model_runner.py | 177 ++++++++++ aeolis/gui/gui_tabs/output_1d.py | 463 ++++++++++++++++++++++++++ aeolis/gui/gui_tabs/output_2d.py | 482 ++++++++++++++++++++++++++++ aeolis/gui/gui_tabs/wind.py | 313 ++++++++++++++++++ aeolis/gui/main.py | 2 +- 9 files changed, 1889 insertions(+), 25 deletions(-) create mode 100644 aeolis/gui/gui_tabs/__init__.py create mode 100644 aeolis/gui/gui_tabs/domain.py create mode 100644 aeolis/gui/gui_tabs/model_runner.py create mode 100644 aeolis/gui/gui_tabs/output_1d.py create mode 100644 aeolis/gui/gui_tabs/output_2d.py create mode 100644 aeolis/gui/gui_tabs/wind.py diff --git a/REFACTORING_SUMMARY.md b/REFACTORING_SUMMARY.md index 03e1fae4..ea845ddc 100644 --- a/REFACTORING_SUMMARY.md +++ b/REFACTORING_SUMMARY.md @@ -211,7 +211,7 @@ The refactoring focused on code quality without changing functionality. Here are 1. **Phase 4 (Suggested)**: Split into multiple modules - `gui/main.py` - Main entry point - `gui/config_manager.py` - Configuration I/O - - `gui/visualizers.py` - Plotting functions + - `gui/gui_tabs/` - Tab modules for different visualizations - `gui/utils.py` - Utility functions 2. **Phase 5 (Suggested)**: Add unit tests diff --git a/aeolis/gui/application.py b/aeolis/gui/application.py index f50f6fa1..b1840bfe 100644 --- a/aeolis/gui/application.py +++ b/aeolis/gui/application.py @@ -7,7 +7,7 @@ - Plotting wind input data and wind roses - Visualizing model output (2D and 1D transects) -This is the main application module that coordinates the GUI and visualizers. +This is the main application module that coordinates the GUI and tab modules. """ import aeolis @@ -15,32 +15,24 @@ from tkinter import ttk, filedialog, messagebox import os import numpy as np -import traceback import netCDF4 -import matplotlib.pyplot as plt from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from matplotlib.figure import Figure from aeolis.constants import DEFAULT_CONFIG # Import utilities from gui package from aeolis.gui.utils import ( - # Constants - HILLSHADE_AZIMUTH, HILLSHADE_ALTITUDE, HILLSHADE_AMBIENT, - TIME_UNIT_THRESHOLDS, TIME_UNIT_DIVISORS, - OCEAN_DEPTH_THRESHOLD, OCEAN_DISTANCE_THRESHOLD, SUBSAMPLE_RATE_DIVISOR, - NC_COORD_VARS, VARIABLE_LABELS, VARIABLE_TITLES, - # Utility functions - resolve_file_path, make_relative_path, determine_time_unit, - extract_time_slice, apply_hillshade + VARIABLE_LABELS, VARIABLE_TITLES, + resolve_file_path, make_relative_path ) -# Import visualizers -from aeolis.gui.visualizers.domain import DomainVisualizer -from aeolis.gui.visualizers.wind import WindVisualizer -from aeolis.gui.visualizers.output_2d import Output2DVisualizer -from aeolis.gui.visualizers.output_1d import Output1DVisualizer +# Import GUI tabs +from aeolis.gui.gui_tabs.domain import DomainVisualizer +from aeolis.gui.gui_tabs.wind import WindVisualizer +from aeolis.gui.gui_tabs.output_2d import Output2DVisualizer +from aeolis.gui.gui_tabs.output_1d import Output1DVisualizer +from aeolis.gui.gui_tabs.model_runner import ModelRunner -from windrose import WindroseAxes # Initialize with default configuration configfile = "No file selected" @@ -105,6 +97,7 @@ def create_widgets(self): self.create_input_file_tab(tab_control) self.create_domain_tab(tab_control) self.create_wind_input_tab(tab_control) + self.create_run_model_tab(tab_control) self.create_plot_output_2d_tab(tab_control) self.create_plot_output_1d_tab(tab_control) # Pack the tab control to expand and fill the available space @@ -118,6 +111,8 @@ def create_widgets(self): def on_tab_changed(self, event): """Handle tab change event to auto-plot domain/wind when tab is selected""" + global configfile + # Get the currently selected tab index selected_tab = self.tab_control.index(self.tab_control.select()) @@ -158,6 +153,12 @@ def on_tab_changed(self, event): except Exception as e: # Silently fail if plotting doesn't work (e.g., file doesn't exist) pass + + # Run Model tab is at index 3 (0: Input file, 1: Domain, 2: Wind, 3: Run Model, 4: Output 2D, 5: Output 1D) + elif selected_tab == 3: + # Update config file label + if hasattr(self, 'model_runner_visualizer'): + self.model_runner_visualizer.update_config_display(configfile) def create_label_entry(self, tab, text, value, row): # Create a label and entry widget for a given tab @@ -408,7 +409,7 @@ def load_new_config(self): wind_file = self.wind_file_entry.get() if wind_file and wind_file.strip(): self.load_and_plot_wind() - except: + except Exception: pass # Silently fail if tabs not yet initialized messagebox.showinfo("Success", f"Configuration loaded from:\n{file_path}") @@ -476,8 +477,8 @@ def toggle_y_limits(self): self.ymax_entry_1d.config(state='normal') # Update plot if data is loaded - if hasattr(self, 'nc_data_cache_1d') and self.nc_data_cache_1d is not None: - self.update_1d_plot() + if hasattr(self, 'output_1d_visualizer') and self.output_1d_visualizer.nc_data_cache_1d is not None: + self.output_1d_visualizer.update_plot() def load_and_plot_wind(self): """ @@ -631,7 +632,7 @@ def create_plot_output_2d_tab(self, tab_control): # Browse button for NC file nc_browse_btn = ttk.Button(file_frame, text="Browse...", - command=lambda: self.browse_nc_file()) + command=self.browse_nc_file) nc_browse_btn.grid(row=0, column=2, sticky=W, pady=2) # Variable selection dropdown @@ -787,7 +788,7 @@ def create_plot_output_1d_tab(self, tab_control): # Browse button for NC file nc_browse_btn_1d = ttk.Button(file_frame_1d, text="Browse...", - command=lambda: self.browse_nc_file_1d()) + command=self.browse_nc_file_1d) nc_browse_btn_1d.grid(row=0, column=2, sticky=W, pady=2) # Variable selection dropdown @@ -909,6 +910,16 @@ def create_plot_output_1d_tab(self, tab_control): self.time_slider_1d.pack(side=LEFT, fill=X, expand=1, padx=5) self.time_slider_1d.set(0) + # Hold On button + self.hold_on_btn_1d = ttk.Button(slider_frame_1d, text="Hold On", + command=self.toggle_hold_on_1d) + self.hold_on_btn_1d.pack(side=LEFT, padx=5) + + # Clear Held Plots button + self.clear_held_btn_1d = ttk.Button(slider_frame_1d, text="Clear Held", + command=self.clear_held_plots_1d) + self.clear_held_btn_1d.pack(side=LEFT, padx=5) + # Initialize 1D output visualizer (after all UI components are created) self.output_1d_visualizer = Output1DVisualizer( self.output_1d_ax, self.output_1d_overview_ax, @@ -918,7 +929,8 @@ def create_plot_output_1d_tab(self, tab_control): self.variable_var_1d, self.transect_direction_var, self.nc_file_entry_1d, self.variable_dropdown_1d, self.output_1d_overview_canvas, - self.get_config_dir, self.get_variable_label, self.get_variable_title + self.get_config_dir, self.get_variable_label, self.get_variable_title, + self.auto_ylimits_var, self.ymin_entry_1d, self.ymax_entry_1d ) # Update slider commands to use visualizer @@ -993,6 +1005,21 @@ def update_1d_plot(self): """ if hasattr(self, 'output_1d_visualizer'): self.output_1d_visualizer.update_plot() + + def toggle_hold_on_1d(self): + """ + Toggle hold on for the 1D transect plot. + This allows overlaying multiple time steps on the same plot. + """ + if hasattr(self, 'output_1d_visualizer'): + self.output_1d_visualizer.toggle_hold_on() + + def clear_held_plots_1d(self): + """ + Clear all held plots from the 1D transect visualization. + """ + if hasattr(self, 'output_1d_visualizer'): + self.output_1d_visualizer.clear_held_plots() def get_variable_label(self, var_name): """ @@ -1338,6 +1365,85 @@ def enable_overlay_vegetation(self): current_time = int(self.time_slider.get()) self.update_time_step(current_time) + def create_run_model_tab(self, tab_control): + """Create the 'Run Model' tab for executing AeoLiS simulations""" + tab_run = ttk.Frame(tab_control) + tab_control.add(tab_run, text='Run Model') + + # Configure grid weights + tab_run.columnconfigure(0, weight=1) + tab_run.rowconfigure(1, weight=1) + + # Create control frame + control_frame = ttk.LabelFrame(tab_run, text="Model Control", padding=10) + control_frame.grid(row=0, column=0, padx=10, pady=10, sticky=(N, W, E)) + + # Config file display + config_label = ttk.Label(control_frame, text="Config file:") + config_label.grid(row=0, column=0, sticky=W, pady=5) + + run_config_label = ttk.Label(control_frame, text="No file selected", + foreground="gray") + run_config_label.grid(row=0, column=1, sticky=W, pady=5, padx=(10, 0)) + + # Start/Stop buttons + button_frame = ttk.Frame(control_frame) + button_frame.grid(row=1, column=0, columnspan=2, pady=10) + + start_model_btn = ttk.Button(button_frame, text="Start Model", width=15) + start_model_btn.pack(side=LEFT, padx=5) + + stop_model_btn = ttk.Button(button_frame, text="Stop Model", + width=15, state=DISABLED) + stop_model_btn.pack(side=LEFT, padx=5) + + # Progress bar + model_progress = ttk.Progressbar(control_frame, mode='indeterminate', length=400) + model_progress.grid(row=2, column=0, columnspan=2, pady=5, sticky=(W, E)) + + # Status label + model_status_label = ttk.Label(control_frame, text="Ready", foreground="blue") + model_status_label.grid(row=3, column=0, columnspan=2, sticky=W, pady=5) + + # Create output frame for logging + output_frame = ttk.LabelFrame(tab_run, text="Model Output / Logging", padding=10) + output_frame.grid(row=1, column=0, padx=10, pady=(0, 10), sticky=(N, S, E, W)) + output_frame.rowconfigure(0, weight=1) + output_frame.columnconfigure(0, weight=1) + + # Create Text widget with scrollbar for terminal output + output_scroll = ttk.Scrollbar(output_frame) + output_scroll.grid(row=0, column=1, sticky=(N, S)) + + model_output_text = Text(output_frame, wrap=WORD, + yscrollcommand=output_scroll.set, + height=20, width=80, + bg='black', fg='lime', + font=('Courier', 9)) + model_output_text.grid(row=0, column=0, sticky=(N, S, E, W)) + output_scroll.config(command=model_output_text.yview) + + # Add clear button + clear_btn = ttk.Button(output_frame, text="Clear Output", + command=lambda: model_output_text.delete(1.0, END)) + clear_btn.grid(row=1, column=0, columnspan=2, pady=(5, 0)) + + # Initialize model runner visualizer + self.model_runner_visualizer = ModelRunner( + start_model_btn, stop_model_btn, model_progress, + model_status_label, model_output_text, run_config_label, + self.root, self.get_current_config_file + ) + + # Connect button commands + start_model_btn.config(command=self.model_runner_visualizer.start_model) + stop_model_btn.config(command=self.model_runner_visualizer.stop_model) + + def get_current_config_file(self): + """Get the current config file path""" + global configfile + return configfile + def save(self): # Save the current entries to the configuration dictionary for field, entry in self.entries.items(): diff --git a/aeolis/gui/gui_tabs/__init__.py b/aeolis/gui/gui_tabs/__init__.py new file mode 100644 index 00000000..a12c4774 --- /dev/null +++ b/aeolis/gui/gui_tabs/__init__.py @@ -0,0 +1,16 @@ +""" +GUI Tabs package for AeoLiS GUI. + +This package contains specialized tab modules for different types of data: +- domain: Domain setup visualization (bed, vegetation, etc.) +- wind: Wind input visualization (time series, wind roses) +- output_2d: 2D output visualization +- output_1d: 1D transect visualization +""" + +from aeolis.gui.gui_tabs.domain import DomainVisualizer +from aeolis.gui.gui_tabs.wind import WindVisualizer +from aeolis.gui.gui_tabs.output_2d import Output2DVisualizer +from aeolis.gui.gui_tabs.output_1d import Output1DVisualizer + +__all__ = ['DomainVisualizer', 'WindVisualizer', 'Output2DVisualizer', 'Output1DVisualizer'] diff --git a/aeolis/gui/gui_tabs/domain.py b/aeolis/gui/gui_tabs/domain.py new file mode 100644 index 00000000..d6039afe --- /dev/null +++ b/aeolis/gui/gui_tabs/domain.py @@ -0,0 +1,307 @@ +""" +Domain Visualizer Module + +Handles visualization of domain setup including: +- Bed elevation +- Vegetation distribution +- Ne (erodibility) parameter +- Combined bed + vegetation views +""" + +import os +import numpy as np +import traceback +from tkinter import messagebox +from aeolis.gui.utils import resolve_file_path + + +class DomainVisualizer: + """ + Visualizer for domain setup data (bed elevation, vegetation, etc.). + + Parameters + ---------- + ax : matplotlib.axes.Axes + The matplotlib axes to plot on + canvas : FigureCanvasTkAgg + The canvas to draw on + fig : matplotlib.figure.Figure + The figure containing the axes + get_entries_func : callable + Function to get entry widgets dictionary + get_config_dir_func : callable + Function to get configuration directory + """ + + def __init__(self, ax, canvas, fig, get_entries_func, get_config_dir_func): + self.ax = ax + self.canvas = canvas + self.fig = fig + self.get_entries = get_entries_func + self.get_config_dir = get_config_dir_func + self.colorbar = None + + def _load_grid_data(self, xgrid_file, ygrid_file, config_dir): + """ + Load x and y grid data if available. + + Parameters + ---------- + xgrid_file : str + Path to x-grid file (may be relative or absolute) + ygrid_file : str + Path to y-grid file (may be relative or absolute) + config_dir : str + Base directory for resolving relative paths + + Returns + ------- + tuple + (x_data, y_data) numpy arrays or (None, None) if not available + """ + x_data = None + y_data = None + + if xgrid_file: + xgrid_file_path = resolve_file_path(xgrid_file, config_dir) + if xgrid_file_path and os.path.exists(xgrid_file_path): + x_data = np.loadtxt(xgrid_file_path) + + if ygrid_file: + ygrid_file_path = resolve_file_path(ygrid_file, config_dir) + if ygrid_file_path and os.path.exists(ygrid_file_path): + y_data = np.loadtxt(ygrid_file_path) + + return x_data, y_data + + def _get_colormap_and_label(self, file_key): + """ + Get appropriate colormap and label for a given file type. + + Parameters + ---------- + file_key : str + File type key ('bed_file', 'ne_file', 'veg_file', etc.) + + Returns + ------- + tuple + (colormap_name, label_text) + """ + colormap_config = { + 'bed_file': ('terrain', 'Elevation (m)'), + 'ne_file': ('viridis', 'Ne'), + 'veg_file': ('Greens', 'Vegetation'), + } + return colormap_config.get(file_key, ('viridis', 'Value')) + + def _update_or_create_colorbar(self, im, label): + """ + Update existing colorbar or create a new one. + + Parameters + ---------- + im : mappable + The image/mesh object returned by pcolormesh or imshow + label : str + Colorbar label + + Returns + ------- + Colorbar + The updated or newly created colorbar + """ + if self.colorbar is not None: + try: + # Update existing colorbar + self.colorbar.update_normal(im) + self.colorbar.set_label(label) + return self.colorbar + except Exception: + # If update fails, create new one + pass + + # Create new colorbar + self.colorbar = self.fig.colorbar(im, ax=self.ax, label=label) + return self.colorbar + + def plot_data(self, file_key, title): + """ + Plot data from specified file (bed_file, ne_file, or veg_file). + + Parameters + ---------- + file_key : str + Key for the file entry (e.g., 'bed_file', 'ne_file', 'veg_file') + title : str + Plot title + """ + try: + # Clear the previous plot + self.ax.clear() + + # Get the file paths from the entries + entries = self.get_entries() + xgrid_file = entries['xgrid_file'].get() + ygrid_file = entries['ygrid_file'].get() + data_file = entries[file_key].get() + + # Check if files are specified + if not data_file: + messagebox.showwarning("Warning", f"No {file_key} specified!") + return + + # Get the directory of the config file to resolve relative paths + config_dir = self.get_config_dir() + + # Load the data file + data_file_path = resolve_file_path(data_file, config_dir) + if not data_file_path or not os.path.exists(data_file_path): + messagebox.showerror("Error", f"File not found: {data_file_path}") + return + + # Load data + z_data = np.loadtxt(data_file_path) + + # Try to load x and y grid data if available + x_data, y_data = self._load_grid_data(xgrid_file, ygrid_file, config_dir) + + # Choose colormap based on data type + cmap, label = self._get_colormap_and_label(file_key) + + # Use pcolormesh for 2D grid data with coordinates + im = self.ax.pcolormesh(x_data, y_data, z_data, shading='auto', cmap=cmap) + self.ax.set_xlabel('X (m)') + self.ax.set_ylabel('Y (m)') + + self.ax.set_title(title) + + # Handle colorbar properly to avoid shrinking + self.colorbar = self._update_or_create_colorbar(im, label) + + # Enforce equal aspect ratio in domain visualization + self.ax.set_aspect('equal', adjustable='box') + + # Redraw the canvas + self.canvas.draw() + + except Exception as e: + error_msg = f"Failed to plot {file_key}: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + + def plot_combined(self): + """Plot bed elevation with vegetation overlay.""" + try: + # Clear the previous plot + self.ax.clear() + + # Get the file paths from the entries + entries = self.get_entries() + xgrid_file = entries['xgrid_file'].get() + ygrid_file = entries['ygrid_file'].get() + bed_file = entries['bed_file'].get() + veg_file = entries['veg_file'].get() + + # Check if files are specified + if not bed_file: + messagebox.showwarning("Warning", "No bed_file specified!") + return + if not veg_file: + messagebox.showwarning("Warning", "No veg_file specified!") + return + + # Get the directory of the config file to resolve relative paths + config_dir = self.get_config_dir() + + # Load the bed file + bed_file_path = resolve_file_path(bed_file, config_dir) + if not bed_file_path or not os.path.exists(bed_file_path): + messagebox.showerror("Error", f"Bed file not found: {bed_file_path}") + return + + # Load the vegetation file + veg_file_path = resolve_file_path(veg_file, config_dir) + if not veg_file_path or not os.path.exists(veg_file_path): + messagebox.showerror("Error", f"Vegetation file not found: {veg_file_path}") + return + + # Load data + bed_data = np.loadtxt(bed_file_path) + veg_data = np.loadtxt(veg_file_path) + + # Try to load x and y grid data if available + x_data, y_data = self._load_grid_data(xgrid_file, ygrid_file, config_dir) + + # Use pcolormesh for 2D grid data with coordinates + im = self.ax.pcolormesh(x_data, y_data, bed_data, shading='auto', cmap='terrain') + self.ax.set_xlabel('X (m)') + self.ax.set_ylabel('Y (m)') + + # Overlay vegetation as contours where vegetation exists + veg_mask = veg_data > 0 + if np.any(veg_mask): + # Create contour lines for vegetation + self.ax.contour(x_data, y_data, veg_data, levels=[0.5], + colors='darkgreen', linewidths=2) + # Fill vegetation areas with semi-transparent green + self.ax.contourf(x_data, y_data, veg_data, levels=[0.5, veg_data.max()], + colors=['green'], alpha=0.3) + + self.ax.set_title('Bed Elevation with Vegetation') + + # Handle colorbar properly to avoid shrinking + self.colorbar = self._update_or_create_colorbar(im, 'Elevation (m)') + + # Enforce equal aspect ratio in domain visualization + self.ax.set_aspect('equal', adjustable='box') + + # Redraw the canvas + self.canvas.draw() + + except Exception as e: + error_msg = f"Failed to plot combined view: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + + def export_png(self, default_filename="domain_plot.png"): + """ + Export the current domain plot as PNG. + + Parameters + ---------- + default_filename : str + Default filename for the export dialog + + Returns + ------- + str or None + Path to saved file, or None if cancelled/failed + """ + from tkinter import filedialog + + # Open file dialog for saving + file_path = filedialog.asksaveasfilename( + initialdir=self.get_config_dir(), + title="Save plot as PNG", + defaultextension=".png", + initialfile=default_filename, + filetypes=(("PNG files", "*.png"), ("All files", "*.*")) + ) + + if file_path: + try: + # Ensure canvas is drawn before saving + self.canvas.draw() + # Use tight layout to ensure everything fits + self.fig.tight_layout() + # Save the figure + self.fig.savefig(file_path, dpi=300, bbox_inches='tight') + messagebox.showinfo("Success", f"Plot exported to:\n{file_path}") + return file_path + except Exception as e: + error_msg = f"Failed to export plot: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + + return None diff --git a/aeolis/gui/gui_tabs/model_runner.py b/aeolis/gui/gui_tabs/model_runner.py new file mode 100644 index 00000000..9c83705d --- /dev/null +++ b/aeolis/gui/gui_tabs/model_runner.py @@ -0,0 +1,177 @@ +""" +Model Runner Module + +Handles running AeoLiS model simulations from the GUI including: +- Model execution in separate thread +- Real-time logging output capture +- Start/stop controls +- Progress indication +""" + +import os +import threading +import logging +import traceback +from tkinter import messagebox, END, NORMAL, DISABLED + + +class ModelRunner: + """ + Model runner for executing AeoLiS simulations from GUI. + + Handles model execution in a separate thread with real-time logging + output and user controls for starting/stopping the model. + """ + + def __init__(self, start_btn, stop_btn, progress_bar, status_label, + output_text, config_label, root, get_config_func): + """Initialize the model runner.""" + self.start_btn = start_btn + self.stop_btn = stop_btn + self.progress_bar = progress_bar + self.status_label = status_label + self.output_text = output_text + self.config_label = config_label + self.root = root + self.get_config = get_config_func + + self.model_runner = None + self.model_thread = None + self.model_running = False + + def start_model(self): + """Start the AeoLiS model run in a separate thread""" + configfile = self.get_config() + + # Check if config file is selected + if not configfile or configfile == "No file selected": + messagebox.showerror("Error", "Please select a configuration file first in the 'Read/Write Inputfile' tab.") + return + + if not os.path.exists(configfile): + messagebox.showerror("Error", f"Configuration file not found:\n{configfile}") + return + + # Update UI + self.config_label.config(text=os.path.basename(configfile), foreground="black") + self.status_label.config(text="Initializing model...", foreground="orange") + self.start_btn.config(state=DISABLED) + self.stop_btn.config(state=NORMAL) + self.progress_bar.start(10) + + # Clear output text + self.output_text.delete(1.0, END) + self.append_output("="*60 + "\n") + self.append_output(f"Starting AeoLiS model\n") + self.append_output(f"Config file: {configfile}\n") + self.append_output("="*60 + "\n\n") + + # Run model in separate thread to prevent GUI freezing + self.model_running = True + self.model_thread = threading.Thread(target=self.run_model_thread, + args=(configfile,), daemon=True) + self.model_thread.start() + + def stop_model(self): + """Stop the running model""" + if self.model_running: + self.model_running = False + self.status_label.config(text="Stopping model...", foreground="red") + self.append_output("\n" + "="*60 + "\n") + self.append_output("STOP requested by user\n") + self.append_output("="*60 + "\n") + + def run_model_thread(self, configfile): + """Run the model in a separate thread""" + try: + # Import here to avoid issues if aeolis.model is not available + from aeolis.model import AeoLiSRunner + + # Create custom logging handler to capture output + class TextHandler(logging.Handler): + def __init__(self, text_widget, gui_callback): + super().__init__() + self.text_widget = text_widget + self.gui_callback = gui_callback + + def emit(self, record): + msg = self.format(record) + # Schedule GUI update from main thread + self.gui_callback(msg + "\n") + + # Update status + self.root.after(0, lambda: self.status_label.config( + text="Running model...", foreground="green")) + + # Create model runner + self.model_runner = AeoLiSRunner(configfile=configfile) + + # Set up logging to capture to GUI + logger = logging.getLogger('aeolis') + text_handler = TextHandler(self.output_text, self.append_output_threadsafe) + text_handler.setLevel(logging.INFO) + text_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', + datefmt='%H:%M:%S')) + logger.addHandler(text_handler) + + # Run the model with a callback to check for stop requests + def check_stop(model): + if not self.model_running: + raise KeyboardInterrupt("Model stopped by user") + + try: + self.model_runner.run(callback=check_stop) + + # Model completed successfully + self.root.after(0, lambda: self.status_label.config( + text="Model completed successfully!", foreground="green")) + self.append_output_threadsafe("\n" + "="*60 + "\n") + self.append_output_threadsafe("Model run completed successfully!\n") + self.append_output_threadsafe("="*60 + "\n") + + except KeyboardInterrupt: + self.root.after(0, lambda: self.status_label.config( + text="Model stopped by user", foreground="red")) + except Exception as e: + error_msg = f"Model error: {str(e)}" + self.append_output_threadsafe(f"\nERROR: {error_msg}\n") + self.append_output_threadsafe(traceback.format_exc()) + self.root.after(0, lambda: self.status_label.config( + text="Model failed - see output", foreground="red")) + finally: + # Clean up + logger.removeHandler(text_handler) + + except Exception as e: + error_msg = f"Failed to start model: {str(e)}\n{traceback.format_exc()}" + self.append_output_threadsafe(error_msg) + self.root.after(0, lambda: self.status_label.config( + text="Failed to start model", foreground="red")) + + finally: + # Reset UI + self.model_running = False + self.root.after(0, self.reset_ui) + + def append_output(self, text): + """Append text to the output widget (must be called from main thread)""" + self.output_text.insert(END, text) + self.output_text.see(END) + self.output_text.update_idletasks() + + def append_output_threadsafe(self, text): + """Thread-safe version of append_output""" + self.root.after(0, lambda: self.append_output(text)) + + def reset_ui(self): + """Reset the UI elements after model run""" + self.start_btn.config(state=NORMAL) + self.stop_btn.config(state=DISABLED) + self.progress_bar.stop() + + def update_config_display(self, configfile): + """Update the config file display label""" + if configfile and configfile != "No file selected": + self.config_label.config(text=os.path.basename(configfile), foreground="black") + else: + self.config_label.config(text="No file selected", foreground="gray") diff --git a/aeolis/gui/gui_tabs/output_1d.py b/aeolis/gui/gui_tabs/output_1d.py new file mode 100644 index 00000000..a5e54ac4 --- /dev/null +++ b/aeolis/gui/gui_tabs/output_1d.py @@ -0,0 +1,463 @@ +""" +1D Output Visualizer Module + +Handles visualization of 1D transect data from NetCDF output including: +- Cross-shore and along-shore transects +- Time evolution with slider control +- Domain overview with transect indicator +- PNG and MP4 animation export +""" + +import os +import numpy as np +import traceback +import netCDF4 +from tkinter import messagebox, filedialog, Toplevel +from tkinter import ttk + + +from aeolis.gui.utils import ( + NC_COORD_VARS, + resolve_file_path, extract_time_slice +) + + +class Output1DVisualizer: + """ + Visualizer for 1D transect data from NetCDF output. + + Handles loading, plotting, and exporting 1D transect visualizations + with support for time evolution and domain overview. + """ + + def __init__(self, transect_ax, overview_ax, transect_canvas, transect_fig, + time_slider_1d, time_label_1d, transect_slider, transect_label, + variable_var_1d, direction_var, nc_file_entry_1d, + variable_dropdown_1d, overview_canvas, get_config_dir_func, + get_variable_label_func, get_variable_title_func, + auto_ylimits_var=None, ymin_entry=None, ymax_entry=None): + """Initialize the 1D output visualizer.""" + self.transect_ax = transect_ax + self.overview_ax = overview_ax + self.transect_canvas = transect_canvas + self.transect_fig = transect_fig + self.overview_canvas = overview_canvas + self.time_slider_1d = time_slider_1d + self.time_label_1d = time_label_1d + self.transect_slider = transect_slider + self.transect_label = transect_label + self.variable_var_1d = variable_var_1d + self.direction_var = direction_var + self.nc_file_entry_1d = nc_file_entry_1d + self.variable_dropdown_1d = variable_dropdown_1d + self.get_config_dir = get_config_dir_func + self.get_variable_label = get_variable_label_func + self.get_variable_title = get_variable_title_func + self.auto_ylimits_var = auto_ylimits_var + self.ymin_entry = ymin_entry + self.ymax_entry = ymax_entry + + self.nc_data_cache_1d = None + self.held_plots = [] # List of tuples: (time_idx, transect_data, x_data) + + def load_and_plot(self): + """Load NetCDF file and plot 1D transect data.""" + try: + nc_file = self.nc_file_entry_1d.get() + if not nc_file: + messagebox.showwarning("Warning", "No NetCDF file specified!") + return + + config_dir = self.get_config_dir() + nc_file_path = resolve_file_path(nc_file, config_dir) + if not nc_file_path or not os.path.exists(nc_file_path): + messagebox.showerror("Error", f"NetCDF file not found: {nc_file_path}") + return + + # Open NetCDF file and cache data + with netCDF4.Dataset(nc_file_path, 'r') as nc: + available_vars = list(nc.variables.keys()) + + # Get coordinates + x_data = nc.variables['x'][:] if 'x' in nc.variables else None + y_data = nc.variables['y'][:] if 'y' in nc.variables else None + + # Load variables + var_data_dict = {} + n_times = 1 + + for var_name in available_vars: + if var_name in NC_COORD_VARS: + continue + + var = nc.variables[var_name] + if 'time' in var.dimensions: + var_data = var[:] + if var_data.ndim < 3: + continue + n_times = max(n_times, var_data.shape[0]) + else: + if var.ndim != 2: + continue + var_data = np.expand_dims(var[:, :], axis=0) + + var_data_dict[var_name] = var_data + + if not var_data_dict: + messagebox.showerror("Error", "No valid variables found in NetCDF file!") + return + + # Update UI + candidate_vars = list(var_data_dict.keys()) + self.variable_dropdown_1d['values'] = sorted(candidate_vars) + if candidate_vars: + self.variable_var_1d.set(candidate_vars[0]) + + # Cache data + self.nc_data_cache_1d = { + 'file_path': nc_file_path, + 'vars': var_data_dict, + 'x': x_data, + 'y': y_data, + 'n_times': n_times + } + + # Get grid dimensions + first_var = list(var_data_dict.values())[0] + n_transects = first_var.shape[1] if self.direction_var.get() == 'cross-shore' else first_var.shape[2] + + # Setup sliders + self.time_slider_1d.config(to=n_times - 1) + self.time_slider_1d.set(0) + self.time_label_1d.config(text=f"Time step: 0 / {n_times-1}") + + self.transect_slider.config(to=n_transects - 1) + self.transect_slider.set(n_transects // 2) + self.transect_label.config(text=f"Transect: {n_transects // 2} / {n_transects-1}") + + # Plot initial data + self.update_plot() + + except Exception as e: + error_msg = f"Failed to load NetCDF: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + + def update_transect_position(self, value): + """Update transect position from slider.""" + if not self.nc_data_cache_1d: + return + + transect_idx = int(float(value)) + first_var = list(self.nc_data_cache_1d['vars'].values())[0] + n_transects = first_var.shape[1] if self.direction_var.get() == 'cross-shore' else first_var.shape[2] + self.transect_label.config(text=f"Transect: {transect_idx} / {n_transects-1}") + + # Clear held plots when transect changes (they're from different transect) + self.held_plots = [] + + self.update_plot() + + def update_time_step(self, value): + """Update time step from slider.""" + if not self.nc_data_cache_1d: + return + + time_idx = int(float(value)) + n_times = self.nc_data_cache_1d['n_times'] + self.time_label_1d.config(text=f"Time step: {time_idx} / {n_times-1}") + + self.update_plot() + + def update_plot(self): + """Update the 1D transect plot with current settings.""" + if not self.nc_data_cache_1d: + return + + try: + # Always clear the axis to redraw + self.transect_ax.clear() + + time_idx = int(self.time_slider_1d.get()) + transect_idx = int(self.transect_slider.get()) + var_name = self.variable_var_1d.get() + direction = self.direction_var.get() + + if var_name not in self.nc_data_cache_1d['vars']: + messagebox.showwarning("Warning", f"Variable '{var_name}' not found!") + return + + # Get data + var_data = self.nc_data_cache_1d['vars'][var_name] + z_data = extract_time_slice(var_data, time_idx) + + # Extract transect + if direction == 'cross-shore': + transect_data = z_data[transect_idx, :] + x_data = self.nc_data_cache_1d['x'][transect_idx, :] if self.nc_data_cache_1d['x'].ndim == 2 else self.nc_data_cache_1d['x'] + xlabel = 'Cross-shore distance (m)' + else: # along-shore + transect_data = z_data[:, transect_idx] + x_data = self.nc_data_cache_1d['y'][:, transect_idx] if self.nc_data_cache_1d['y'].ndim == 2 else self.nc_data_cache_1d['y'] + xlabel = 'Along-shore distance (m)' + + # Redraw held plots first (if any) + if self.held_plots: + for held_time_idx, held_data, held_x_data in self.held_plots: + if held_x_data is not None: + self.transect_ax.plot(held_x_data, held_data, '--', linewidth=1.5, + alpha=0.7, label=f'Time: {held_time_idx}') + else: + self.transect_ax.plot(held_data, '--', linewidth=1.5, + alpha=0.7, label=f'Time: {held_time_idx}') + + # Plot current transect + if x_data is not None: + self.transect_ax.plot(x_data, transect_data, 'b-', linewidth=2, + label=f'Time: {time_idx}' if self.held_plots else None) + self.transect_ax.set_xlabel(xlabel) + else: + self.transect_ax.plot(transect_data, 'b-', linewidth=2, + label=f'Time: {time_idx}' if self.held_plots else None) + self.transect_ax.set_xlabel('Grid Index') + + ylabel = self.get_variable_label(var_name) + self.transect_ax.set_ylabel(ylabel) + + title = self.get_variable_title(var_name) + if self.held_plots: + self.transect_ax.set_title(f'{title} - {direction.capitalize()} (Transect: {transect_idx}) - Multiple Time Steps') + else: + self.transect_ax.set_title(f'{title} - {direction.capitalize()} (Time: {time_idx}, Transect: {transect_idx})') + self.transect_ax.grid(True, alpha=0.3) + + # Add legend if there are held plots + if self.held_plots: + self.transect_ax.legend(loc='best') + + # Apply Y-axis limits if not auto + if self.auto_ylimits_var is not None and self.ymin_entry is not None and self.ymax_entry is not None: + if not self.auto_ylimits_var.get(): + try: + ymin_str = self.ymin_entry.get().strip() + ymax_str = self.ymax_entry.get().strip() + if ymin_str and ymax_str: + ymin = float(ymin_str) + ymax = float(ymax_str) + self.transect_ax.set_ylim([ymin, ymax]) + except ValueError: + pass # Invalid input, keep auto limits + + # Update overview + self.update_overview(transect_idx) + + self.transect_canvas.draw_idle() + + except Exception as e: + error_msg = f"Failed to update 1D plot: {str(e)}\n\n{traceback.format_exc()}" + print(error_msg) + + def update_overview(self, transect_idx): + """Update the domain overview showing transect position.""" + if not self.nc_data_cache_1d: + return + + try: + self.overview_ax.clear() + + time_idx = int(self.time_slider_1d.get()) + var_name = self.variable_var_1d.get() + direction = self.direction_var.get() + + if var_name not in self.nc_data_cache_1d['vars']: + return + + # Get data for overview + var_data = self.nc_data_cache_1d['vars'][var_name] + z_data = extract_time_slice(var_data, time_idx) + + x_data = self.nc_data_cache_1d['x'] + y_data = self.nc_data_cache_1d['y'] + + # Plot domain overview with pcolormesh + self.overview_ax.pcolormesh(x_data, y_data, z_data, shading='auto', cmap='terrain') + + # Draw transect line + if direction == 'cross-shore': + if x_data.ndim == 2: + x_line = x_data[transect_idx, :] + y_line = y_data[transect_idx, :] + else: + x_line = x_data + y_line = np.full_like(x_data, y_data[transect_idx] if y_data.ndim == 1 else y_data[transect_idx, 0]) + else: # along-shore + if y_data.ndim == 2: + x_line = x_data[:, transect_idx] + y_line = y_data[:, transect_idx] + else: + y_line = y_data + x_line = np.full_like(y_data, x_data[transect_idx] if x_data.ndim == 1 else x_data[0, transect_idx]) + + self.overview_ax.plot(x_line, y_line, 'r-', linewidth=2, label='Transect') + self.overview_ax.set_xlabel('X (m)') + self.overview_ax.set_ylabel('Y (m)') + + self.overview_ax.set_title('Domain Overview') + self.overview_ax.legend() + + # Redraw the overview canvas + self.overview_canvas.draw_idle() + + except Exception as e: + error_msg = f"Failed to update overview: {str(e)}" + print(error_msg) + + def _add_current_to_held_plots(self): + """Helper method to add the current time step to held plots.""" + if not self.nc_data_cache_1d: + return + + time_idx = int(self.time_slider_1d.get()) + transect_idx = int(self.transect_slider.get()) + var_name = self.variable_var_1d.get() + direction = self.direction_var.get() + + if var_name not in self.nc_data_cache_1d['vars']: + return + + # Check if this time step is already in held plots + for held_time, _, _ in self.held_plots: + if held_time == time_idx: + return # Already held, don't add duplicate + + var_data = self.nc_data_cache_1d['vars'][var_name] + z_data = extract_time_slice(var_data, time_idx) + + # Extract transect + if direction == 'cross-shore': + transect_data = z_data[transect_idx, :] + x_data = self.nc_data_cache_1d['x'][transect_idx, :] if self.nc_data_cache_1d['x'].ndim == 2 else self.nc_data_cache_1d['x'] + else: # along-shore + transect_data = z_data[:, transect_idx] + x_data = self.nc_data_cache_1d['y'][:, transect_idx] if self.nc_data_cache_1d['y'].ndim == 2 else self.nc_data_cache_1d['y'] + + # Add to held plots + self.held_plots.append((time_idx, transect_data.copy(), x_data.copy() if x_data is not None else None)) + + def toggle_hold_on(self): + """ + Add the current plot to the collection of held plots. + This allows overlaying multiple time steps on the same plot. + """ + if not self.nc_data_cache_1d: + messagebox.showwarning("Warning", "Please load data first!") + return + + # Add current plot to held plots + self._add_current_to_held_plots() + self.update_plot() + + def clear_held_plots(self): + """Clear all held plots.""" + self.held_plots = [] + self.update_plot() + + def export_png(self, default_filename="output_1d.png"): + """Export current 1D plot as PNG.""" + if not self.transect_fig: + messagebox.showwarning("Warning", "No plot to export.") + return None + + file_path = filedialog.asksaveasfilename( + initialdir=self.get_config_dir(), + title="Save plot as PNG", + defaultextension=".png", + initialfile=default_filename, + filetypes=(("PNG files", "*.png"), ("All files", "*.*")) + ) + + if file_path: + try: + self.transect_fig.savefig(file_path, dpi=300, bbox_inches='tight') + messagebox.showinfo("Success", f"Plot exported to:\n{file_path}") + return file_path + except Exception as e: + error_msg = f"Failed to export: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + return None + + def export_animation_mp4(self, default_filename="output_1d_animation.mp4"): + """Export 1D transect animation as MP4.""" + if not self.nc_data_cache_1d or self.nc_data_cache_1d['n_times'] <= 1: + messagebox.showwarning("Warning", "Need multiple time steps for animation.") + return None + + file_path = filedialog.asksaveasfilename( + initialdir=self.get_config_dir(), + title="Save animation as MP4", + defaultextension=".mp4", + initialfile=default_filename, + filetypes=(("MP4 files", "*.mp4"), ("All files", "*.*")) + ) + + if file_path: + try: + from matplotlib.animation import FuncAnimation, FFMpegWriter + + n_times = self.nc_data_cache_1d['n_times'] + progress_window = Toplevel() + progress_window.title("Exporting Animation") + progress_window.geometry("300x100") + progress_label = ttk.Label(progress_window, text="Creating animation...\nThis may take a few minutes.") + progress_label.pack(pady=20) + progress_bar = ttk.Progressbar(progress_window, mode='determinate', maximum=n_times) + progress_bar.pack(pady=10, padx=20, fill='x') + progress_window.update() + + original_time = int(self.time_slider_1d.get()) + + def update_frame(frame_num): + self.time_slider_1d.set(frame_num) + self.update_plot() + try: + if progress_window.winfo_exists(): + progress_bar['value'] = frame_num + 1 + progress_window.update() + except: + pass # Window may have been closed + return [] + + ani = FuncAnimation(self.transect_fig, update_frame, frames=n_times, + interval=200, blit=False, repeat=False) + writer = FFMpegWriter(fps=5, bitrate=1800) + ani.save(file_path, writer=writer) + + # Stop the animation by deleting the animation object + del ani + + self.time_slider_1d.set(original_time) + self.update_plot() + + try: + if progress_window.winfo_exists(): + progress_window.destroy() + except Exception: + pass # Window already destroyed + + messagebox.showinfo("Success", f"Animation exported to:\n{file_path}") + return file_path + + except ImportError: + messagebox.showerror("Error", "Animation export requires ffmpeg.") + except Exception as e: + error_msg = f"Failed to export animation: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + finally: + try: + if 'progress_window' in locals() and progress_window.winfo_exists(): + progress_window.destroy() + except Exception: + pass # Window already destroyed + return None diff --git a/aeolis/gui/gui_tabs/output_2d.py b/aeolis/gui/gui_tabs/output_2d.py new file mode 100644 index 00000000..7cdff72d --- /dev/null +++ b/aeolis/gui/gui_tabs/output_2d.py @@ -0,0 +1,482 @@ +""" +2D Output Visualizer Module + +Handles visualization of 2D NetCDF output data including: +- Variable selection and plotting +- Time slider control +- Colorbar customization +- Special renderings (hillshade, quiver plots) +- PNG and MP4 export +""" + +import os +import numpy as np +import traceback +import netCDF4 +from tkinter import messagebox, filedialog, Toplevel +from tkinter import ttk +from matplotlib.cm import ScalarMappable +from matplotlib.colors import Normalize + +from aeolis.gui.utils import ( + NC_COORD_VARS, + resolve_file_path, extract_time_slice, apply_hillshade +) + + +class Output2DVisualizer: + """ + Visualizer for 2D NetCDF output data. + + Handles loading, plotting, and exporting 2D output visualizations with + support for multiple variables, time evolution, and special renderings. + """ + + def __init__(self, output_ax, output_canvas, output_fig, + output_colorbar_ref, time_slider, time_label, + variable_var_2d, colormap_var, auto_limits_var, + vmin_entry, vmax_entry, overlay_veg_var, + nc_file_entry, variable_dropdown_2d, + get_config_dir_func, get_variable_label_func, get_variable_title_func): + """Initialize the 2D output visualizer.""" + self.output_ax = output_ax + self.output_canvas = output_canvas + self.output_fig = output_fig + self.output_colorbar_ref = output_colorbar_ref + self.time_slider = time_slider + self.time_label = time_label + self.variable_var_2d = variable_var_2d + self.colormap_var = colormap_var + self.auto_limits_var = auto_limits_var + self.vmin_entry = vmin_entry + self.vmax_entry = vmax_entry + self.overlay_veg_var = overlay_veg_var + self.nc_file_entry = nc_file_entry + self.variable_dropdown_2d = variable_dropdown_2d + self.get_config_dir = get_config_dir_func + self.get_variable_label = get_variable_label_func + self.get_variable_title = get_variable_title_func + + self.nc_data_cache = None + + def on_variable_changed(self, event=None): + """Handle variable selection change.""" + self.update_plot() + + def load_and_plot(self): + """Load NetCDF file and plot 2D data.""" + try: + nc_file = self.nc_file_entry.get() + if not nc_file: + messagebox.showwarning("Warning", "No NetCDF file specified!") + return + + config_dir = self.get_config_dir() + nc_file_path = resolve_file_path(nc_file, config_dir) + if not nc_file_path or not os.path.exists(nc_file_path): + messagebox.showerror("Error", f"NetCDF file not found: {nc_file_path}") + return + + # Open NetCDF file and cache data + with netCDF4.Dataset(nc_file_path, 'r') as nc: + available_vars = list(nc.variables.keys()) + + # Get coordinates + x_data = nc.variables['x'][:] if 'x' in nc.variables else None + y_data = nc.variables['y'][:] if 'y' in nc.variables else None + + # Load variables + var_data_dict = {} + n_times = 1 + veg_data = None + + for var_name in available_vars: + if var_name in NC_COORD_VARS: + continue + + var = nc.variables[var_name] + if 'time' in var.dimensions: + var_data = var[:] + if var_data.ndim < 3: + continue + n_times = max(n_times, var_data.shape[0]) + else: + if var.ndim != 2: + continue + var_data = np.expand_dims(var[:, :], axis=0) + + var_data_dict[var_name] = var_data + + # Load vegetation if requested + if self.overlay_veg_var.get(): + for veg_name in ['rhoveg', 'vegetated', 'hveg', 'vegfac']: + if veg_name in available_vars: + veg_var = nc.variables[veg_name] + veg_data = veg_var[:] if 'time' in veg_var.dimensions else np.expand_dims(veg_var[:, :], axis=0) + break + + if not var_data_dict: + messagebox.showerror("Error", "No valid variables found in NetCDF file!") + return + + # Add special options + candidate_vars = list(var_data_dict.keys()) + if 'zb' in var_data_dict and 'rhoveg' in var_data_dict: + candidate_vars.append('zb+rhoveg') + if 'ustarn' in var_data_dict and 'ustars' in var_data_dict: + candidate_vars.append('ustar quiver') + + # Update UI + self.variable_dropdown_2d['values'] = sorted(candidate_vars) + if candidate_vars: + self.variable_var_2d.set(candidate_vars[0]) + + # Cache data + self.nc_data_cache = { + 'file_path': nc_file_path, + 'vars': var_data_dict, + 'x': x_data, + 'y': y_data, + 'n_times': n_times, + 'veg': veg_data + } + + # Setup time slider + self.time_slider.config(to=n_times - 1) + self.time_slider.set(0) + self.time_label.config(text=f"Time step: 0 / {n_times-1}") + + # Plot initial data + self.update_plot() + + except Exception as e: + error_msg = f"Failed to load NetCDF: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + + def update_plot(self): + """Update the 2D plot with current settings.""" + if not self.nc_data_cache: + return + + try: + self.output_ax.clear() + time_idx = int(self.time_slider.get()) + var_name = self.variable_var_2d.get() + + # Update time label + n_times = self.nc_data_cache.get('n_times', 1) + self.time_label.config(text=f"Time step: {time_idx} / {n_times-1}") + + # Special renderings + if var_name == 'zb+rhoveg': + self._render_zb_rhoveg_shaded(time_idx) + return + if var_name == 'ustar quiver': + self._render_ustar_quiver(time_idx) + return + + if var_name not in self.nc_data_cache['vars']: + messagebox.showwarning("Warning", f"Variable '{var_name}' not found!") + return + + # Get data + var_data = self.nc_data_cache['vars'][var_name] + z_data = extract_time_slice(var_data, time_idx) + x_data = self.nc_data_cache['x'] + y_data = self.nc_data_cache['y'] + + # Get colorbar limits + vmin, vmax = None, None + if not self.auto_limits_var.get(): + try: + vmin_str = self.vmin_entry.get().strip() + vmax_str = self.vmax_entry.get().strip() + vmin = float(vmin_str) if vmin_str else None + vmax = float(vmax_str) if vmax_str else None + except ValueError: + messagebox.showwarning( + "Invalid Input", + "Colorbar limits must be valid numbers. Using automatic limits instead." + ) + + cmap = self.colormap_var.get() + + # Plot with pcolormesh (x and y always exist in AeoLiS NetCDF files) + im = self.output_ax.pcolormesh(x_data, y_data, z_data, shading='auto', + cmap=cmap, vmin=vmin, vmax=vmax) + self.output_ax.set_xlabel('X (m)') + self.output_ax.set_ylabel('Y (m)') + + title = self.get_variable_title(var_name) + self.output_ax.set_title(f'{title} (Time step: {time_idx})') + + # Update colorbar + self._update_colorbar(im, var_name) + + # Overlay vegetation + if self.overlay_veg_var.get() and self.nc_data_cache['veg'] is not None: + veg_slice = self.nc_data_cache['veg'] + veg_data = veg_slice[time_idx, :, :] if veg_slice.ndim == 3 else veg_slice[:, :] + self.output_ax.pcolormesh(x_data, y_data, veg_data, shading='auto', + cmap='Greens', vmin=0, vmax=1, alpha=0.4) + + self.output_canvas.draw_idle() + + except Exception as e: + error_msg = f"Failed to update 2D plot: {str(e)}\n\n{traceback.format_exc()}" + print(error_msg) + + def _update_colorbar(self, im, var_name): + """Update or create colorbar.""" + cbar_label = self.get_variable_label(var_name) + if self.output_colorbar_ref[0] is not None: + try: + self.output_colorbar_ref[0].update_normal(im) + self.output_colorbar_ref[0].set_label(cbar_label) + except Exception: + self.output_colorbar_ref[0] = self.output_fig.colorbar(im, ax=self.output_ax, label=cbar_label) + else: + self.output_colorbar_ref[0] = self.output_fig.colorbar(im, ax=self.output_ax, label=cbar_label) + + def export_png(self, default_filename="output_2d.png"): + """Export current 2D plot as PNG.""" + if not self.output_fig: + messagebox.showwarning("Warning", "No plot to export.") + return None + + file_path = filedialog.asksaveasfilename( + initialdir=self.get_config_dir(), + title="Save plot as PNG", + defaultextension=".png", + initialfile=default_filename, + filetypes=(("PNG files", "*.png"), ("All files", "*.*")) + ) + + if file_path: + try: + self.output_fig.savefig(file_path, dpi=300, bbox_inches='tight') + messagebox.showinfo("Success", f"Plot exported to:\n{file_path}") + return file_path + except Exception as e: + error_msg = f"Failed to export: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + return None + + def export_animation_mp4(self, default_filename="output_2d_animation.mp4"): + """Export 2D plot animation as MP4.""" + if not self.nc_data_cache or self.nc_data_cache['n_times'] <= 1: + messagebox.showwarning("Warning", "Need multiple time steps for animation.") + return None + + file_path = filedialog.asksaveasfilename( + initialdir=self.get_config_dir(), + title="Save animation as MP4", + defaultextension=".mp4", + initialfile=default_filename, + filetypes=(("MP4 files", "*.mp4"), ("All files", "*.*")) + ) + + if file_path: + try: + from matplotlib.animation import FuncAnimation, FFMpegWriter + + n_times = self.nc_data_cache['n_times'] + progress_window = Toplevel() + progress_window.title("Exporting Animation") + progress_window.geometry("300x100") + progress_label = ttk.Label(progress_window, text="Creating animation...\nThis may take a few minutes.") + progress_label.pack(pady=20) + progress_bar = ttk.Progressbar(progress_window, mode='determinate', maximum=n_times) + progress_bar.pack(pady=10, padx=20, fill='x') + progress_window.update() + + original_time = int(self.time_slider.get()) + + def update_frame(frame_num): + self.time_slider.set(frame_num) + self.update_plot() + try: + if progress_window.winfo_exists(): + progress_bar['value'] = frame_num + 1 + progress_window.update() + except: + pass # Window may have been closed + return [] + + ani = FuncAnimation(self.output_fig, update_frame, frames=n_times, + interval=200, blit=False, repeat=False) + writer = FFMpegWriter(fps=5, bitrate=1800) + ani.save(file_path, writer=writer) + + # Stop the animation by deleting the animation object + del ani + + self.time_slider.set(original_time) + self.update_plot() + + try: + if progress_window.winfo_exists(): + progress_window.destroy() + except Exception: + pass # Window already destroyed + + messagebox.showinfo("Success", f"Animation exported to:\n{file_path}") + return file_path + + except ImportError: + messagebox.showerror("Error", "Animation export requires ffmpeg.") + except Exception as e: + error_msg = f"Failed to export animation: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + finally: + try: + if 'progress_window' in locals() and progress_window.winfo_exists(): + progress_window.destroy() + except Exception: + pass # Window already destroyed + return None + + def _render_zb_rhoveg_shaded(self, time_idx): + """Render combined bed + vegetation with hillshading matching Anim2D_ShadeVeg.py.""" + try: + zb_data = extract_time_slice(self.nc_data_cache['vars']['zb'], time_idx) + rhoveg_data = extract_time_slice(self.nc_data_cache['vars']['rhoveg'], time_idx) + x_data = self.nc_data_cache['x'] + y_data = self.nc_data_cache['y'] + + # Normalize vegetation to [0,1] + veg_max = np.nanmax(rhoveg_data) + veg_norm = rhoveg_data / veg_max if (veg_max is not None and veg_max > 0) else np.clip(rhoveg_data, 0.0, 1.0) + veg_norm = np.clip(veg_norm, 0.0, 1.0) + + # Apply hillshade + x1d = x_data[0, :] if x_data.ndim == 2 else x_data + y1d = y_data[:, 0] if y_data.ndim == 2 else y_data + hillshade = apply_hillshade(zb_data, x1d, y1d, az_deg=155.0, alt_deg=5.0) + + # Color definitions + sand = np.array([1.0, 239.0/255.0, 213.0/255.0]) # light sand + darkgreen = np.array([34/255, 139/255, 34/255]) + ocean = np.array([70/255, 130/255, 180/255]) # steelblue + + # Create RGB array (ny, nx, 3) + ny, nx = zb_data.shape + rgb = np.zeros((ny, nx, 3), dtype=float) + + # Base color: blend sand and vegetation + for i in range(3): # R, G, B channels + rgb[:, :, i] = sand[i] * (1.0 - veg_norm) + darkgreen[i] * veg_norm + + # Apply ocean mask: zb < -0.5 and x < 200 + if x_data is not None: + X2d = x_data if x_data.ndim == 2 else np.meshgrid(x1d, y1d)[0] + ocean_mask = (zb_data < -0.5) & (X2d < 200) + rgb[ocean_mask] = ocean + + # Apply shading to all RGB channels + rgb *= hillshade[:, :, np.newaxis] + rgb = np.clip(rgb, 0.0, 1.0) + + # Plot RGB image + extent = [x1d.min(), x1d.max(), y1d.min(), y1d.max()] + self.output_ax.imshow(rgb, origin='lower', extent=extent, + interpolation='nearest', aspect='auto') + self.output_ax.set_xlabel('X (m)') + self.output_ax.set_ylabel('Y (m)') + + self.output_ax.set_title(f'Bed + Vegetation (Time step: {time_idx})') + + # Get colorbar limits for vegetation + vmin, vmax = 0, veg_max + if not self.auto_limits_var.get(): + try: + vmin_str = self.vmin_entry.get().strip() + vmax_str = self.vmax_entry.get().strip() + vmin = float(vmin_str) if vmin_str else 0 + vmax = float(vmax_str) if vmax_str else veg_max + except ValueError: + pass # Use default limits if invalid input + + # Create a ScalarMappable for the colorbar (showing vegetation density) + norm = Normalize(vmin=vmin, vmax=vmax) + sm = ScalarMappable(cmap='Greens', norm=norm) + sm.set_array(rhoveg_data) + + # Add colorbar for vegetation density + self._update_colorbar(sm, 'rhoveg') + + self.output_canvas.draw_idle() + except Exception as e: + print(f"Failed to render zb+rhoveg: {e}") + traceback.print_exc() + + def _render_ustar_quiver(self, time_idx): + """Render quiver plot of shear velocity with magnitude background.""" + try: + ustarn = extract_time_slice(self.nc_data_cache['vars']['ustarn'], time_idx) + ustars = extract_time_slice(self.nc_data_cache['vars']['ustars'], time_idx) + x_data = self.nc_data_cache['x'] + y_data = self.nc_data_cache['y'] + + # Calculate magnitude for background coloring + ustar_mag = np.sqrt(ustarn**2 + ustars**2) + + # Subsample for quiver + step = max(1, min(ustarn.shape) // 25) + + # Get colormap and limits + cmap = self.colormap_var.get() + vmin, vmax = None, None + if not self.auto_limits_var.get(): + try: + vmin_str = self.vmin_entry.get().strip() + vmax_str = self.vmax_entry.get().strip() + vmin = float(vmin_str) if vmin_str else None + vmax = float(vmax_str) if vmax_str else None + except ValueError: + pass # Use auto limits + + # Plot background field (magnitude) + im = self.output_ax.pcolormesh(x_data, y_data, ustar_mag, + shading='auto', cmap=cmap, + vmin=vmin, vmax=vmax, alpha=0.7) + + # Calculate appropriate scaling for arrows + x1d = x_data[0, :] if x_data.ndim == 2 else x_data + y1d = y_data[:, 0] if y_data.ndim == 2 else y_data + x_range = x1d.max() - x1d.min() + y_range = y1d.max() - y1d.min() + + # Calculate typical velocity magnitude (handle masked arrays) + valid_mag = np.asarray(ustar_mag[ustar_mag > 0]) + typical_vel = np.percentile(valid_mag, 75) if valid_mag.size > 0 else 1.0 + arrow_scale = typical_vel * 20 # Scale factor to make arrows visible + + # Add quiver plot with black arrows + Q = self.output_ax.quiver(x_data[::step, ::step], y_data[::step, ::step], + ustars[::step, ::step], ustarn[::step, ::step], + scale=arrow_scale, color='black', width=0.004, + headwidth=3, headlength=4, headaxislength=3.5, + zorder=10) + + # Add quiver key (legend for arrow scale) - placed to the right, above colorbar + self.output_ax.quiverkey(Q, 1.1, 1.05, typical_vel, + f'{typical_vel:.2f} m/s', + labelpos='N', coordinates='axes', + color='black', labelcolor='black', + fontproperties={'size': 9}) + + self.output_ax.set_xlabel('X (m)') + self.output_ax.set_ylabel('Y (m)') + self.output_ax.set_title(f'Shear Velocity (Time step: {time_idx})') + + # Update colorbar for magnitude + self._update_colorbar(im, 'ustar magnitude') + + self.output_canvas.draw_idle() + except Exception as e: + print(f"Failed to render ustar quiver: {e}") + traceback.print_exc() diff --git a/aeolis/gui/gui_tabs/wind.py b/aeolis/gui/gui_tabs/wind.py new file mode 100644 index 00000000..f4b7aa0e --- /dev/null +++ b/aeolis/gui/gui_tabs/wind.py @@ -0,0 +1,313 @@ +""" +Wind Visualizer Module + +Handles visualization of wind input data including: +- Wind speed time series +- Wind direction time series +- Wind rose diagrams +- PNG export for wind plots +""" + +import os +import numpy as np +import traceback +from tkinter import messagebox, filedialog +import matplotlib.patches as mpatches +from windrose import WindroseAxes +from aeolis.gui.utils import resolve_file_path, determine_time_unit + + +class WindVisualizer: + """ + Visualizer for wind input data (time series and wind rose). + + Parameters + ---------- + wind_speed_ax : matplotlib.axes.Axes + Axes for wind speed time series + wind_dir_ax : matplotlib.axes.Axes + Axes for wind direction time series + wind_ts_canvas : FigureCanvasTkAgg + Canvas for time series plots + wind_ts_fig : matplotlib.figure.Figure + Figure containing time series + windrose_fig : matplotlib.figure.Figure + Figure for wind rose + windrose_canvas : FigureCanvasTkAgg + Canvas for wind rose + get_wind_file_func : callable + Function to get wind file entry widget + get_entries_func : callable + Function to get all entry widgets + get_config_dir_func : callable + Function to get configuration directory + get_dic_func : callable + Function to get configuration dictionary + """ + + def __init__(self, wind_speed_ax, wind_dir_ax, wind_ts_canvas, wind_ts_fig, + windrose_fig, windrose_canvas, get_wind_file_func, get_entries_func, + get_config_dir_func, get_dic_func): + self.wind_speed_ax = wind_speed_ax + self.wind_dir_ax = wind_dir_ax + self.wind_ts_canvas = wind_ts_canvas + self.wind_ts_fig = wind_ts_fig + self.windrose_fig = windrose_fig + self.windrose_canvas = windrose_canvas + self.get_wind_file = get_wind_file_func + self.get_entries = get_entries_func + self.get_config_dir = get_config_dir_func + self.get_dic = get_dic_func + self.wind_data_cache = None + + def load_and_plot(self): + """Load wind file and plot time series and wind rose.""" + try: + # Get the wind file path + wind_file = self.get_wind_file().get() + + if not wind_file: + messagebox.showwarning("Warning", "No wind file specified!") + return + + # Get the directory of the config file to resolve relative paths + config_dir = self.get_config_dir() + + # Resolve wind file path + wind_file_path = resolve_file_path(wind_file, config_dir) + if not wind_file_path or not os.path.exists(wind_file_path): + messagebox.showerror("Error", f"Wind file not found: {wind_file_path}") + return + + # Check if we already loaded this file (avoid reloading) + if self.wind_data_cache and self.wind_data_cache.get('file_path') == wind_file_path: + # Data already loaded, just return (don't reload) + return + + # Load wind data (time, speed, direction) + wind_data = np.loadtxt(wind_file_path) + + # Check data format + if wind_data.ndim != 2 or wind_data.shape[1] < 3: + messagebox.showerror("Error", "Wind file must have at least 3 columns: time, speed, direction") + return + + time = wind_data[:, 0] + speed = wind_data[:, 1] + direction = wind_data[:, 2] + + # Get wind convention from config + dic = self.get_dic() + wind_convention = dic.get('wind_convention', 'nautical') + + # Cache the wind data along with file path and convention + self.wind_data_cache = { + 'file_path': wind_file_path, + 'time': time, + 'speed': speed, + 'direction': direction, + 'convention': wind_convention + } + + # Determine appropriate time unit based on simulation time (tstart and tstop) + tstart = 0 + tstop = 0 + use_sim_limits = False + + try: + entries = self.get_entries() + tstart_entry = entries.get('tstart') + tstop_entry = entries.get('tstop') + + if tstart_entry and tstop_entry: + tstart = float(tstart_entry.get() or 0) + tstop = float(tstop_entry.get() or 0) + if tstop > tstart: + sim_duration = tstop - tstart # in seconds + use_sim_limits = True + else: + sim_duration = time[-1] - time[0] if len(time) > 0 else 0 + else: + sim_duration = time[-1] - time[0] if len(time) > 0 else 0 + except (ValueError, AttributeError, TypeError): + sim_duration = time[-1] - time[0] if len(time) > 0 else 0 + + # Choose appropriate time unit and convert using utility function + time_unit, time_divisor = determine_time_unit(sim_duration) + time_converted = time / time_divisor + + # Plot wind speed time series + self.wind_speed_ax.clear() + self.wind_speed_ax.plot(time_converted, speed, 'b-', linewidth=1.5, zorder=2, label='Wind Speed') + self.wind_speed_ax.set_xlabel(f'Time ({time_unit})') + self.wind_speed_ax.set_ylabel('Wind Speed (m/s)') + self.wind_speed_ax.set_title('Wind Speed Time Series') + self.wind_speed_ax.grid(True, alpha=0.3, zorder=1) + + # Calculate axis limits with 10% padding and add shading + if use_sim_limits: + tstart_converted = tstart / time_divisor + tstop_converted = tstop / time_divisor + axis_range = tstop_converted - tstart_converted + padding = 0.1 * axis_range + xlim_min = tstart_converted - padding + xlim_max = tstop_converted + padding + + self.wind_speed_ax.set_xlim([xlim_min, xlim_max]) + self.wind_speed_ax.axvspan(xlim_min, tstart_converted, alpha=0.15, color='gray', zorder=3) + self.wind_speed_ax.axvspan(tstop_converted, xlim_max, alpha=0.15, color='gray', zorder=3) + + shaded_patch = mpatches.Patch(color='gray', alpha=0.15, label='Outside simulation time') + self.wind_speed_ax.legend(handles=[shaded_patch], loc='upper right', fontsize=8) + + # Plot wind direction time series + self.wind_dir_ax.clear() + self.wind_dir_ax.plot(time_converted, direction, 'r-', linewidth=1.5, zorder=2, label='Wind Direction') + self.wind_dir_ax.set_xlabel(f'Time ({time_unit})') + self.wind_dir_ax.set_ylabel('Wind Direction (degrees)') + self.wind_dir_ax.set_title(f'Wind Direction Time Series ({wind_convention} convention)') + self.wind_dir_ax.set_ylim([0, 360]) + self.wind_dir_ax.grid(True, alpha=0.3, zorder=1) + + if use_sim_limits: + self.wind_dir_ax.set_xlim([xlim_min, xlim_max]) + self.wind_dir_ax.axvspan(xlim_min, tstart_converted, alpha=0.15, color='gray', zorder=3) + self.wind_dir_ax.axvspan(tstop_converted, xlim_max, alpha=0.15, color='gray', zorder=3) + + shaded_patch = mpatches.Patch(color='gray', alpha=0.15, label='Outside simulation time') + self.wind_dir_ax.legend(handles=[shaded_patch], loc='upper right', fontsize=8) + + # Redraw time series canvas + self.wind_ts_canvas.draw() + + # Plot wind rose + self.plot_windrose(speed, direction, wind_convention) + + except Exception as e: + error_msg = f"Failed to load and plot wind data: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + + def force_reload(self): + """Force reload of wind data by clearing cache.""" + self.wind_data_cache = None + self.load_and_plot() + + def plot_windrose(self, speed, direction, convention='nautical'): + """ + Plot wind rose diagram. + + Parameters + ---------- + speed : array + Wind speed values + direction : array + Wind direction values in degrees + convention : str + 'nautical' or 'cartesian' + """ + try: + # Clear the windrose figure + self.windrose_fig.clear() + + # Convert direction based on convention to meteorological standard + if convention == 'cartesian': + direction_met = (270 - direction) % 360 + else: + direction_met = direction + + # Create windrose axes + ax = WindroseAxes.from_ax(fig=self.windrose_fig) + ax.bar(direction_met, speed, normed=True, opening=0.8, edgecolor='white') + ax.set_legend(title='Wind Speed (m/s)') + ax.set_title(f'Wind Rose ({convention} convention)', fontsize=14, fontweight='bold') + + # Redraw windrose canvas + self.windrose_canvas.draw() + + except Exception as e: + error_msg = f"Failed to plot wind rose: {str(e)}\n\n{traceback.format_exc()}" + print(error_msg) + # Create a simple text message instead + self.windrose_fig.clear() + ax = self.windrose_fig.add_subplot(111) + ax.text(0.5, 0.5, 'Wind rose plot failed.\nSee console for details.', + ha='center', va='center', transform=ax.transAxes) + ax.axis('off') + self.windrose_canvas.draw() + + def export_timeseries_png(self, default_filename="wind_timeseries.png"): + """ + Export the wind time series plot as PNG. + + Parameters + ---------- + default_filename : str + Default filename for the export dialog + + Returns + ------- + str or None + Path to saved file, or None if cancelled/failed + """ + if self.wind_ts_fig is None: + messagebox.showwarning("Warning", "No wind plot to export. Please load wind data first.") + return None + + file_path = filedialog.asksaveasfilename( + initialdir=self.get_config_dir(), + title="Save wind time series as PNG", + defaultextension=".png", + initialfile=default_filename, + filetypes=(("PNG files", "*.png"), ("All files", "*.*")) + ) + + if file_path: + try: + self.wind_ts_fig.savefig(file_path, dpi=300, bbox_inches='tight') + messagebox.showinfo("Success", f"Wind time series exported to:\n{file_path}") + return file_path + except Exception as e: + error_msg = f"Failed to export plot: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + + return None + + def export_windrose_png(self, default_filename="wind_rose.png"): + """ + Export the wind rose plot as PNG. + + Parameters + ---------- + default_filename : str + Default filename for the export dialog + + Returns + ------- + str or None + Path to saved file, or None if cancelled/failed + """ + if self.windrose_fig is None: + messagebox.showwarning("Warning", "No wind rose plot to export. Please load wind data first.") + return None + + file_path = filedialog.asksaveasfilename( + initialdir=self.get_config_dir(), + title="Save wind rose as PNG", + defaultextension=".png", + initialfile=default_filename, + filetypes=(("PNG files", "*.png"), ("All files", "*.*")) + ) + + if file_path: + try: + self.windrose_fig.savefig(file_path, dpi=300, bbox_inches='tight') + messagebox.showinfo("Success", f"Wind rose exported to:\n{file_path}") + return file_path + except Exception as e: + error_msg = f"Failed to export plot: {str(e)}\n\n{traceback.format_exc()}" + messagebox.showerror("Error", error_msg) + print(error_msg) + + return None diff --git a/aeolis/gui/main.py b/aeolis/gui/main.py index 5b249435..10155a8b 100644 --- a/aeolis/gui/main.py +++ b/aeolis/gui/main.py @@ -23,7 +23,7 @@ def launch_gui(): root = Tk() # Create an instance of the AeolisGUI class - app = AeolisGUI(root, dic) + AeolisGUI(root, dic) # Bring window to front and give it focus root.lift() From 6a4a652631dcdd107e8326970d043119a7ef1be1 Mon Sep 17 00:00:00 2001 From: Sierd de Vries Date: Thu, 13 Nov 2025 16:46:20 +0100 Subject: [PATCH 2/2] Revert "Gui v0.2 added (#264) (#266)" This reverts commit f55eecadd1de59932c9c39a394fff01c53e95d5d. --- REFACTORING_SUMMARY.md | 2 +- aeolis/gui/application.py | 152 ++------- aeolis/gui/gui_tabs/__init__.py | 16 - aeolis/gui/gui_tabs/domain.py | 307 ------------------ aeolis/gui/gui_tabs/model_runner.py | 177 ---------- aeolis/gui/gui_tabs/output_1d.py | 463 -------------------------- aeolis/gui/gui_tabs/output_2d.py | 482 ---------------------------- aeolis/gui/gui_tabs/wind.py | 313 ------------------ aeolis/gui/main.py | 2 +- 9 files changed, 25 insertions(+), 1889 deletions(-) delete mode 100644 aeolis/gui/gui_tabs/__init__.py delete mode 100644 aeolis/gui/gui_tabs/domain.py delete mode 100644 aeolis/gui/gui_tabs/model_runner.py delete mode 100644 aeolis/gui/gui_tabs/output_1d.py delete mode 100644 aeolis/gui/gui_tabs/output_2d.py delete mode 100644 aeolis/gui/gui_tabs/wind.py diff --git a/REFACTORING_SUMMARY.md b/REFACTORING_SUMMARY.md index ea845ddc..03e1fae4 100644 --- a/REFACTORING_SUMMARY.md +++ b/REFACTORING_SUMMARY.md @@ -211,7 +211,7 @@ The refactoring focused on code quality without changing functionality. Here are 1. **Phase 4 (Suggested)**: Split into multiple modules - `gui/main.py` - Main entry point - `gui/config_manager.py` - Configuration I/O - - `gui/gui_tabs/` - Tab modules for different visualizations + - `gui/visualizers.py` - Plotting functions - `gui/utils.py` - Utility functions 2. **Phase 5 (Suggested)**: Add unit tests diff --git a/aeolis/gui/application.py b/aeolis/gui/application.py index b1840bfe..f50f6fa1 100644 --- a/aeolis/gui/application.py +++ b/aeolis/gui/application.py @@ -7,7 +7,7 @@ - Plotting wind input data and wind roses - Visualizing model output (2D and 1D transects) -This is the main application module that coordinates the GUI and tab modules. +This is the main application module that coordinates the GUI and visualizers. """ import aeolis @@ -15,24 +15,32 @@ from tkinter import ttk, filedialog, messagebox import os import numpy as np +import traceback import netCDF4 +import matplotlib.pyplot as plt from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from matplotlib.figure import Figure from aeolis.constants import DEFAULT_CONFIG # Import utilities from gui package from aeolis.gui.utils import ( - VARIABLE_LABELS, VARIABLE_TITLES, - resolve_file_path, make_relative_path + # Constants + HILLSHADE_AZIMUTH, HILLSHADE_ALTITUDE, HILLSHADE_AMBIENT, + TIME_UNIT_THRESHOLDS, TIME_UNIT_DIVISORS, + OCEAN_DEPTH_THRESHOLD, OCEAN_DISTANCE_THRESHOLD, SUBSAMPLE_RATE_DIVISOR, + NC_COORD_VARS, VARIABLE_LABELS, VARIABLE_TITLES, + # Utility functions + resolve_file_path, make_relative_path, determine_time_unit, + extract_time_slice, apply_hillshade ) -# Import GUI tabs -from aeolis.gui.gui_tabs.domain import DomainVisualizer -from aeolis.gui.gui_tabs.wind import WindVisualizer -from aeolis.gui.gui_tabs.output_2d import Output2DVisualizer -from aeolis.gui.gui_tabs.output_1d import Output1DVisualizer -from aeolis.gui.gui_tabs.model_runner import ModelRunner +# Import visualizers +from aeolis.gui.visualizers.domain import DomainVisualizer +from aeolis.gui.visualizers.wind import WindVisualizer +from aeolis.gui.visualizers.output_2d import Output2DVisualizer +from aeolis.gui.visualizers.output_1d import Output1DVisualizer +from windrose import WindroseAxes # Initialize with default configuration configfile = "No file selected" @@ -97,7 +105,6 @@ def create_widgets(self): self.create_input_file_tab(tab_control) self.create_domain_tab(tab_control) self.create_wind_input_tab(tab_control) - self.create_run_model_tab(tab_control) self.create_plot_output_2d_tab(tab_control) self.create_plot_output_1d_tab(tab_control) # Pack the tab control to expand and fill the available space @@ -111,8 +118,6 @@ def create_widgets(self): def on_tab_changed(self, event): """Handle tab change event to auto-plot domain/wind when tab is selected""" - global configfile - # Get the currently selected tab index selected_tab = self.tab_control.index(self.tab_control.select()) @@ -153,12 +158,6 @@ def on_tab_changed(self, event): except Exception as e: # Silently fail if plotting doesn't work (e.g., file doesn't exist) pass - - # Run Model tab is at index 3 (0: Input file, 1: Domain, 2: Wind, 3: Run Model, 4: Output 2D, 5: Output 1D) - elif selected_tab == 3: - # Update config file label - if hasattr(self, 'model_runner_visualizer'): - self.model_runner_visualizer.update_config_display(configfile) def create_label_entry(self, tab, text, value, row): # Create a label and entry widget for a given tab @@ -409,7 +408,7 @@ def load_new_config(self): wind_file = self.wind_file_entry.get() if wind_file and wind_file.strip(): self.load_and_plot_wind() - except Exception: + except: pass # Silently fail if tabs not yet initialized messagebox.showinfo("Success", f"Configuration loaded from:\n{file_path}") @@ -477,8 +476,8 @@ def toggle_y_limits(self): self.ymax_entry_1d.config(state='normal') # Update plot if data is loaded - if hasattr(self, 'output_1d_visualizer') and self.output_1d_visualizer.nc_data_cache_1d is not None: - self.output_1d_visualizer.update_plot() + if hasattr(self, 'nc_data_cache_1d') and self.nc_data_cache_1d is not None: + self.update_1d_plot() def load_and_plot_wind(self): """ @@ -632,7 +631,7 @@ def create_plot_output_2d_tab(self, tab_control): # Browse button for NC file nc_browse_btn = ttk.Button(file_frame, text="Browse...", - command=self.browse_nc_file) + command=lambda: self.browse_nc_file()) nc_browse_btn.grid(row=0, column=2, sticky=W, pady=2) # Variable selection dropdown @@ -788,7 +787,7 @@ def create_plot_output_1d_tab(self, tab_control): # Browse button for NC file nc_browse_btn_1d = ttk.Button(file_frame_1d, text="Browse...", - command=self.browse_nc_file_1d) + command=lambda: self.browse_nc_file_1d()) nc_browse_btn_1d.grid(row=0, column=2, sticky=W, pady=2) # Variable selection dropdown @@ -910,16 +909,6 @@ def create_plot_output_1d_tab(self, tab_control): self.time_slider_1d.pack(side=LEFT, fill=X, expand=1, padx=5) self.time_slider_1d.set(0) - # Hold On button - self.hold_on_btn_1d = ttk.Button(slider_frame_1d, text="Hold On", - command=self.toggle_hold_on_1d) - self.hold_on_btn_1d.pack(side=LEFT, padx=5) - - # Clear Held Plots button - self.clear_held_btn_1d = ttk.Button(slider_frame_1d, text="Clear Held", - command=self.clear_held_plots_1d) - self.clear_held_btn_1d.pack(side=LEFT, padx=5) - # Initialize 1D output visualizer (after all UI components are created) self.output_1d_visualizer = Output1DVisualizer( self.output_1d_ax, self.output_1d_overview_ax, @@ -929,8 +918,7 @@ def create_plot_output_1d_tab(self, tab_control): self.variable_var_1d, self.transect_direction_var, self.nc_file_entry_1d, self.variable_dropdown_1d, self.output_1d_overview_canvas, - self.get_config_dir, self.get_variable_label, self.get_variable_title, - self.auto_ylimits_var, self.ymin_entry_1d, self.ymax_entry_1d + self.get_config_dir, self.get_variable_label, self.get_variable_title ) # Update slider commands to use visualizer @@ -1005,21 +993,6 @@ def update_1d_plot(self): """ if hasattr(self, 'output_1d_visualizer'): self.output_1d_visualizer.update_plot() - - def toggle_hold_on_1d(self): - """ - Toggle hold on for the 1D transect plot. - This allows overlaying multiple time steps on the same plot. - """ - if hasattr(self, 'output_1d_visualizer'): - self.output_1d_visualizer.toggle_hold_on() - - def clear_held_plots_1d(self): - """ - Clear all held plots from the 1D transect visualization. - """ - if hasattr(self, 'output_1d_visualizer'): - self.output_1d_visualizer.clear_held_plots() def get_variable_label(self, var_name): """ @@ -1365,85 +1338,6 @@ def enable_overlay_vegetation(self): current_time = int(self.time_slider.get()) self.update_time_step(current_time) - def create_run_model_tab(self, tab_control): - """Create the 'Run Model' tab for executing AeoLiS simulations""" - tab_run = ttk.Frame(tab_control) - tab_control.add(tab_run, text='Run Model') - - # Configure grid weights - tab_run.columnconfigure(0, weight=1) - tab_run.rowconfigure(1, weight=1) - - # Create control frame - control_frame = ttk.LabelFrame(tab_run, text="Model Control", padding=10) - control_frame.grid(row=0, column=0, padx=10, pady=10, sticky=(N, W, E)) - - # Config file display - config_label = ttk.Label(control_frame, text="Config file:") - config_label.grid(row=0, column=0, sticky=W, pady=5) - - run_config_label = ttk.Label(control_frame, text="No file selected", - foreground="gray") - run_config_label.grid(row=0, column=1, sticky=W, pady=5, padx=(10, 0)) - - # Start/Stop buttons - button_frame = ttk.Frame(control_frame) - button_frame.grid(row=1, column=0, columnspan=2, pady=10) - - start_model_btn = ttk.Button(button_frame, text="Start Model", width=15) - start_model_btn.pack(side=LEFT, padx=5) - - stop_model_btn = ttk.Button(button_frame, text="Stop Model", - width=15, state=DISABLED) - stop_model_btn.pack(side=LEFT, padx=5) - - # Progress bar - model_progress = ttk.Progressbar(control_frame, mode='indeterminate', length=400) - model_progress.grid(row=2, column=0, columnspan=2, pady=5, sticky=(W, E)) - - # Status label - model_status_label = ttk.Label(control_frame, text="Ready", foreground="blue") - model_status_label.grid(row=3, column=0, columnspan=2, sticky=W, pady=5) - - # Create output frame for logging - output_frame = ttk.LabelFrame(tab_run, text="Model Output / Logging", padding=10) - output_frame.grid(row=1, column=0, padx=10, pady=(0, 10), sticky=(N, S, E, W)) - output_frame.rowconfigure(0, weight=1) - output_frame.columnconfigure(0, weight=1) - - # Create Text widget with scrollbar for terminal output - output_scroll = ttk.Scrollbar(output_frame) - output_scroll.grid(row=0, column=1, sticky=(N, S)) - - model_output_text = Text(output_frame, wrap=WORD, - yscrollcommand=output_scroll.set, - height=20, width=80, - bg='black', fg='lime', - font=('Courier', 9)) - model_output_text.grid(row=0, column=0, sticky=(N, S, E, W)) - output_scroll.config(command=model_output_text.yview) - - # Add clear button - clear_btn = ttk.Button(output_frame, text="Clear Output", - command=lambda: model_output_text.delete(1.0, END)) - clear_btn.grid(row=1, column=0, columnspan=2, pady=(5, 0)) - - # Initialize model runner visualizer - self.model_runner_visualizer = ModelRunner( - start_model_btn, stop_model_btn, model_progress, - model_status_label, model_output_text, run_config_label, - self.root, self.get_current_config_file - ) - - # Connect button commands - start_model_btn.config(command=self.model_runner_visualizer.start_model) - stop_model_btn.config(command=self.model_runner_visualizer.stop_model) - - def get_current_config_file(self): - """Get the current config file path""" - global configfile - return configfile - def save(self): # Save the current entries to the configuration dictionary for field, entry in self.entries.items(): diff --git a/aeolis/gui/gui_tabs/__init__.py b/aeolis/gui/gui_tabs/__init__.py deleted file mode 100644 index a12c4774..00000000 --- a/aeolis/gui/gui_tabs/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -GUI Tabs package for AeoLiS GUI. - -This package contains specialized tab modules for different types of data: -- domain: Domain setup visualization (bed, vegetation, etc.) -- wind: Wind input visualization (time series, wind roses) -- output_2d: 2D output visualization -- output_1d: 1D transect visualization -""" - -from aeolis.gui.gui_tabs.domain import DomainVisualizer -from aeolis.gui.gui_tabs.wind import WindVisualizer -from aeolis.gui.gui_tabs.output_2d import Output2DVisualizer -from aeolis.gui.gui_tabs.output_1d import Output1DVisualizer - -__all__ = ['DomainVisualizer', 'WindVisualizer', 'Output2DVisualizer', 'Output1DVisualizer'] diff --git a/aeolis/gui/gui_tabs/domain.py b/aeolis/gui/gui_tabs/domain.py deleted file mode 100644 index d6039afe..00000000 --- a/aeolis/gui/gui_tabs/domain.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -Domain Visualizer Module - -Handles visualization of domain setup including: -- Bed elevation -- Vegetation distribution -- Ne (erodibility) parameter -- Combined bed + vegetation views -""" - -import os -import numpy as np -import traceback -from tkinter import messagebox -from aeolis.gui.utils import resolve_file_path - - -class DomainVisualizer: - """ - Visualizer for domain setup data (bed elevation, vegetation, etc.). - - Parameters - ---------- - ax : matplotlib.axes.Axes - The matplotlib axes to plot on - canvas : FigureCanvasTkAgg - The canvas to draw on - fig : matplotlib.figure.Figure - The figure containing the axes - get_entries_func : callable - Function to get entry widgets dictionary - get_config_dir_func : callable - Function to get configuration directory - """ - - def __init__(self, ax, canvas, fig, get_entries_func, get_config_dir_func): - self.ax = ax - self.canvas = canvas - self.fig = fig - self.get_entries = get_entries_func - self.get_config_dir = get_config_dir_func - self.colorbar = None - - def _load_grid_data(self, xgrid_file, ygrid_file, config_dir): - """ - Load x and y grid data if available. - - Parameters - ---------- - xgrid_file : str - Path to x-grid file (may be relative or absolute) - ygrid_file : str - Path to y-grid file (may be relative or absolute) - config_dir : str - Base directory for resolving relative paths - - Returns - ------- - tuple - (x_data, y_data) numpy arrays or (None, None) if not available - """ - x_data = None - y_data = None - - if xgrid_file: - xgrid_file_path = resolve_file_path(xgrid_file, config_dir) - if xgrid_file_path and os.path.exists(xgrid_file_path): - x_data = np.loadtxt(xgrid_file_path) - - if ygrid_file: - ygrid_file_path = resolve_file_path(ygrid_file, config_dir) - if ygrid_file_path and os.path.exists(ygrid_file_path): - y_data = np.loadtxt(ygrid_file_path) - - return x_data, y_data - - def _get_colormap_and_label(self, file_key): - """ - Get appropriate colormap and label for a given file type. - - Parameters - ---------- - file_key : str - File type key ('bed_file', 'ne_file', 'veg_file', etc.) - - Returns - ------- - tuple - (colormap_name, label_text) - """ - colormap_config = { - 'bed_file': ('terrain', 'Elevation (m)'), - 'ne_file': ('viridis', 'Ne'), - 'veg_file': ('Greens', 'Vegetation'), - } - return colormap_config.get(file_key, ('viridis', 'Value')) - - def _update_or_create_colorbar(self, im, label): - """ - Update existing colorbar or create a new one. - - Parameters - ---------- - im : mappable - The image/mesh object returned by pcolormesh or imshow - label : str - Colorbar label - - Returns - ------- - Colorbar - The updated or newly created colorbar - """ - if self.colorbar is not None: - try: - # Update existing colorbar - self.colorbar.update_normal(im) - self.colorbar.set_label(label) - return self.colorbar - except Exception: - # If update fails, create new one - pass - - # Create new colorbar - self.colorbar = self.fig.colorbar(im, ax=self.ax, label=label) - return self.colorbar - - def plot_data(self, file_key, title): - """ - Plot data from specified file (bed_file, ne_file, or veg_file). - - Parameters - ---------- - file_key : str - Key for the file entry (e.g., 'bed_file', 'ne_file', 'veg_file') - title : str - Plot title - """ - try: - # Clear the previous plot - self.ax.clear() - - # Get the file paths from the entries - entries = self.get_entries() - xgrid_file = entries['xgrid_file'].get() - ygrid_file = entries['ygrid_file'].get() - data_file = entries[file_key].get() - - # Check if files are specified - if not data_file: - messagebox.showwarning("Warning", f"No {file_key} specified!") - return - - # Get the directory of the config file to resolve relative paths - config_dir = self.get_config_dir() - - # Load the data file - data_file_path = resolve_file_path(data_file, config_dir) - if not data_file_path or not os.path.exists(data_file_path): - messagebox.showerror("Error", f"File not found: {data_file_path}") - return - - # Load data - z_data = np.loadtxt(data_file_path) - - # Try to load x and y grid data if available - x_data, y_data = self._load_grid_data(xgrid_file, ygrid_file, config_dir) - - # Choose colormap based on data type - cmap, label = self._get_colormap_and_label(file_key) - - # Use pcolormesh for 2D grid data with coordinates - im = self.ax.pcolormesh(x_data, y_data, z_data, shading='auto', cmap=cmap) - self.ax.set_xlabel('X (m)') - self.ax.set_ylabel('Y (m)') - - self.ax.set_title(title) - - # Handle colorbar properly to avoid shrinking - self.colorbar = self._update_or_create_colorbar(im, label) - - # Enforce equal aspect ratio in domain visualization - self.ax.set_aspect('equal', adjustable='box') - - # Redraw the canvas - self.canvas.draw() - - except Exception as e: - error_msg = f"Failed to plot {file_key}: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - - def plot_combined(self): - """Plot bed elevation with vegetation overlay.""" - try: - # Clear the previous plot - self.ax.clear() - - # Get the file paths from the entries - entries = self.get_entries() - xgrid_file = entries['xgrid_file'].get() - ygrid_file = entries['ygrid_file'].get() - bed_file = entries['bed_file'].get() - veg_file = entries['veg_file'].get() - - # Check if files are specified - if not bed_file: - messagebox.showwarning("Warning", "No bed_file specified!") - return - if not veg_file: - messagebox.showwarning("Warning", "No veg_file specified!") - return - - # Get the directory of the config file to resolve relative paths - config_dir = self.get_config_dir() - - # Load the bed file - bed_file_path = resolve_file_path(bed_file, config_dir) - if not bed_file_path or not os.path.exists(bed_file_path): - messagebox.showerror("Error", f"Bed file not found: {bed_file_path}") - return - - # Load the vegetation file - veg_file_path = resolve_file_path(veg_file, config_dir) - if not veg_file_path or not os.path.exists(veg_file_path): - messagebox.showerror("Error", f"Vegetation file not found: {veg_file_path}") - return - - # Load data - bed_data = np.loadtxt(bed_file_path) - veg_data = np.loadtxt(veg_file_path) - - # Try to load x and y grid data if available - x_data, y_data = self._load_grid_data(xgrid_file, ygrid_file, config_dir) - - # Use pcolormesh for 2D grid data with coordinates - im = self.ax.pcolormesh(x_data, y_data, bed_data, shading='auto', cmap='terrain') - self.ax.set_xlabel('X (m)') - self.ax.set_ylabel('Y (m)') - - # Overlay vegetation as contours where vegetation exists - veg_mask = veg_data > 0 - if np.any(veg_mask): - # Create contour lines for vegetation - self.ax.contour(x_data, y_data, veg_data, levels=[0.5], - colors='darkgreen', linewidths=2) - # Fill vegetation areas with semi-transparent green - self.ax.contourf(x_data, y_data, veg_data, levels=[0.5, veg_data.max()], - colors=['green'], alpha=0.3) - - self.ax.set_title('Bed Elevation with Vegetation') - - # Handle colorbar properly to avoid shrinking - self.colorbar = self._update_or_create_colorbar(im, 'Elevation (m)') - - # Enforce equal aspect ratio in domain visualization - self.ax.set_aspect('equal', adjustable='box') - - # Redraw the canvas - self.canvas.draw() - - except Exception as e: - error_msg = f"Failed to plot combined view: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - - def export_png(self, default_filename="domain_plot.png"): - """ - Export the current domain plot as PNG. - - Parameters - ---------- - default_filename : str - Default filename for the export dialog - - Returns - ------- - str or None - Path to saved file, or None if cancelled/failed - """ - from tkinter import filedialog - - # Open file dialog for saving - file_path = filedialog.asksaveasfilename( - initialdir=self.get_config_dir(), - title="Save plot as PNG", - defaultextension=".png", - initialfile=default_filename, - filetypes=(("PNG files", "*.png"), ("All files", "*.*")) - ) - - if file_path: - try: - # Ensure canvas is drawn before saving - self.canvas.draw() - # Use tight layout to ensure everything fits - self.fig.tight_layout() - # Save the figure - self.fig.savefig(file_path, dpi=300, bbox_inches='tight') - messagebox.showinfo("Success", f"Plot exported to:\n{file_path}") - return file_path - except Exception as e: - error_msg = f"Failed to export plot: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - - return None diff --git a/aeolis/gui/gui_tabs/model_runner.py b/aeolis/gui/gui_tabs/model_runner.py deleted file mode 100644 index 9c83705d..00000000 --- a/aeolis/gui/gui_tabs/model_runner.py +++ /dev/null @@ -1,177 +0,0 @@ -""" -Model Runner Module - -Handles running AeoLiS model simulations from the GUI including: -- Model execution in separate thread -- Real-time logging output capture -- Start/stop controls -- Progress indication -""" - -import os -import threading -import logging -import traceback -from tkinter import messagebox, END, NORMAL, DISABLED - - -class ModelRunner: - """ - Model runner for executing AeoLiS simulations from GUI. - - Handles model execution in a separate thread with real-time logging - output and user controls for starting/stopping the model. - """ - - def __init__(self, start_btn, stop_btn, progress_bar, status_label, - output_text, config_label, root, get_config_func): - """Initialize the model runner.""" - self.start_btn = start_btn - self.stop_btn = stop_btn - self.progress_bar = progress_bar - self.status_label = status_label - self.output_text = output_text - self.config_label = config_label - self.root = root - self.get_config = get_config_func - - self.model_runner = None - self.model_thread = None - self.model_running = False - - def start_model(self): - """Start the AeoLiS model run in a separate thread""" - configfile = self.get_config() - - # Check if config file is selected - if not configfile or configfile == "No file selected": - messagebox.showerror("Error", "Please select a configuration file first in the 'Read/Write Inputfile' tab.") - return - - if not os.path.exists(configfile): - messagebox.showerror("Error", f"Configuration file not found:\n{configfile}") - return - - # Update UI - self.config_label.config(text=os.path.basename(configfile), foreground="black") - self.status_label.config(text="Initializing model...", foreground="orange") - self.start_btn.config(state=DISABLED) - self.stop_btn.config(state=NORMAL) - self.progress_bar.start(10) - - # Clear output text - self.output_text.delete(1.0, END) - self.append_output("="*60 + "\n") - self.append_output(f"Starting AeoLiS model\n") - self.append_output(f"Config file: {configfile}\n") - self.append_output("="*60 + "\n\n") - - # Run model in separate thread to prevent GUI freezing - self.model_running = True - self.model_thread = threading.Thread(target=self.run_model_thread, - args=(configfile,), daemon=True) - self.model_thread.start() - - def stop_model(self): - """Stop the running model""" - if self.model_running: - self.model_running = False - self.status_label.config(text="Stopping model...", foreground="red") - self.append_output("\n" + "="*60 + "\n") - self.append_output("STOP requested by user\n") - self.append_output("="*60 + "\n") - - def run_model_thread(self, configfile): - """Run the model in a separate thread""" - try: - # Import here to avoid issues if aeolis.model is not available - from aeolis.model import AeoLiSRunner - - # Create custom logging handler to capture output - class TextHandler(logging.Handler): - def __init__(self, text_widget, gui_callback): - super().__init__() - self.text_widget = text_widget - self.gui_callback = gui_callback - - def emit(self, record): - msg = self.format(record) - # Schedule GUI update from main thread - self.gui_callback(msg + "\n") - - # Update status - self.root.after(0, lambda: self.status_label.config( - text="Running model...", foreground="green")) - - # Create model runner - self.model_runner = AeoLiSRunner(configfile=configfile) - - # Set up logging to capture to GUI - logger = logging.getLogger('aeolis') - text_handler = TextHandler(self.output_text, self.append_output_threadsafe) - text_handler.setLevel(logging.INFO) - text_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', - datefmt='%H:%M:%S')) - logger.addHandler(text_handler) - - # Run the model with a callback to check for stop requests - def check_stop(model): - if not self.model_running: - raise KeyboardInterrupt("Model stopped by user") - - try: - self.model_runner.run(callback=check_stop) - - # Model completed successfully - self.root.after(0, lambda: self.status_label.config( - text="Model completed successfully!", foreground="green")) - self.append_output_threadsafe("\n" + "="*60 + "\n") - self.append_output_threadsafe("Model run completed successfully!\n") - self.append_output_threadsafe("="*60 + "\n") - - except KeyboardInterrupt: - self.root.after(0, lambda: self.status_label.config( - text="Model stopped by user", foreground="red")) - except Exception as e: - error_msg = f"Model error: {str(e)}" - self.append_output_threadsafe(f"\nERROR: {error_msg}\n") - self.append_output_threadsafe(traceback.format_exc()) - self.root.after(0, lambda: self.status_label.config( - text="Model failed - see output", foreground="red")) - finally: - # Clean up - logger.removeHandler(text_handler) - - except Exception as e: - error_msg = f"Failed to start model: {str(e)}\n{traceback.format_exc()}" - self.append_output_threadsafe(error_msg) - self.root.after(0, lambda: self.status_label.config( - text="Failed to start model", foreground="red")) - - finally: - # Reset UI - self.model_running = False - self.root.after(0, self.reset_ui) - - def append_output(self, text): - """Append text to the output widget (must be called from main thread)""" - self.output_text.insert(END, text) - self.output_text.see(END) - self.output_text.update_idletasks() - - def append_output_threadsafe(self, text): - """Thread-safe version of append_output""" - self.root.after(0, lambda: self.append_output(text)) - - def reset_ui(self): - """Reset the UI elements after model run""" - self.start_btn.config(state=NORMAL) - self.stop_btn.config(state=DISABLED) - self.progress_bar.stop() - - def update_config_display(self, configfile): - """Update the config file display label""" - if configfile and configfile != "No file selected": - self.config_label.config(text=os.path.basename(configfile), foreground="black") - else: - self.config_label.config(text="No file selected", foreground="gray") diff --git a/aeolis/gui/gui_tabs/output_1d.py b/aeolis/gui/gui_tabs/output_1d.py deleted file mode 100644 index a5e54ac4..00000000 --- a/aeolis/gui/gui_tabs/output_1d.py +++ /dev/null @@ -1,463 +0,0 @@ -""" -1D Output Visualizer Module - -Handles visualization of 1D transect data from NetCDF output including: -- Cross-shore and along-shore transects -- Time evolution with slider control -- Domain overview with transect indicator -- PNG and MP4 animation export -""" - -import os -import numpy as np -import traceback -import netCDF4 -from tkinter import messagebox, filedialog, Toplevel -from tkinter import ttk - - -from aeolis.gui.utils import ( - NC_COORD_VARS, - resolve_file_path, extract_time_slice -) - - -class Output1DVisualizer: - """ - Visualizer for 1D transect data from NetCDF output. - - Handles loading, plotting, and exporting 1D transect visualizations - with support for time evolution and domain overview. - """ - - def __init__(self, transect_ax, overview_ax, transect_canvas, transect_fig, - time_slider_1d, time_label_1d, transect_slider, transect_label, - variable_var_1d, direction_var, nc_file_entry_1d, - variable_dropdown_1d, overview_canvas, get_config_dir_func, - get_variable_label_func, get_variable_title_func, - auto_ylimits_var=None, ymin_entry=None, ymax_entry=None): - """Initialize the 1D output visualizer.""" - self.transect_ax = transect_ax - self.overview_ax = overview_ax - self.transect_canvas = transect_canvas - self.transect_fig = transect_fig - self.overview_canvas = overview_canvas - self.time_slider_1d = time_slider_1d - self.time_label_1d = time_label_1d - self.transect_slider = transect_slider - self.transect_label = transect_label - self.variable_var_1d = variable_var_1d - self.direction_var = direction_var - self.nc_file_entry_1d = nc_file_entry_1d - self.variable_dropdown_1d = variable_dropdown_1d - self.get_config_dir = get_config_dir_func - self.get_variable_label = get_variable_label_func - self.get_variable_title = get_variable_title_func - self.auto_ylimits_var = auto_ylimits_var - self.ymin_entry = ymin_entry - self.ymax_entry = ymax_entry - - self.nc_data_cache_1d = None - self.held_plots = [] # List of tuples: (time_idx, transect_data, x_data) - - def load_and_plot(self): - """Load NetCDF file and plot 1D transect data.""" - try: - nc_file = self.nc_file_entry_1d.get() - if not nc_file: - messagebox.showwarning("Warning", "No NetCDF file specified!") - return - - config_dir = self.get_config_dir() - nc_file_path = resolve_file_path(nc_file, config_dir) - if not nc_file_path or not os.path.exists(nc_file_path): - messagebox.showerror("Error", f"NetCDF file not found: {nc_file_path}") - return - - # Open NetCDF file and cache data - with netCDF4.Dataset(nc_file_path, 'r') as nc: - available_vars = list(nc.variables.keys()) - - # Get coordinates - x_data = nc.variables['x'][:] if 'x' in nc.variables else None - y_data = nc.variables['y'][:] if 'y' in nc.variables else None - - # Load variables - var_data_dict = {} - n_times = 1 - - for var_name in available_vars: - if var_name in NC_COORD_VARS: - continue - - var = nc.variables[var_name] - if 'time' in var.dimensions: - var_data = var[:] - if var_data.ndim < 3: - continue - n_times = max(n_times, var_data.shape[0]) - else: - if var.ndim != 2: - continue - var_data = np.expand_dims(var[:, :], axis=0) - - var_data_dict[var_name] = var_data - - if not var_data_dict: - messagebox.showerror("Error", "No valid variables found in NetCDF file!") - return - - # Update UI - candidate_vars = list(var_data_dict.keys()) - self.variable_dropdown_1d['values'] = sorted(candidate_vars) - if candidate_vars: - self.variable_var_1d.set(candidate_vars[0]) - - # Cache data - self.nc_data_cache_1d = { - 'file_path': nc_file_path, - 'vars': var_data_dict, - 'x': x_data, - 'y': y_data, - 'n_times': n_times - } - - # Get grid dimensions - first_var = list(var_data_dict.values())[0] - n_transects = first_var.shape[1] if self.direction_var.get() == 'cross-shore' else first_var.shape[2] - - # Setup sliders - self.time_slider_1d.config(to=n_times - 1) - self.time_slider_1d.set(0) - self.time_label_1d.config(text=f"Time step: 0 / {n_times-1}") - - self.transect_slider.config(to=n_transects - 1) - self.transect_slider.set(n_transects // 2) - self.transect_label.config(text=f"Transect: {n_transects // 2} / {n_transects-1}") - - # Plot initial data - self.update_plot() - - except Exception as e: - error_msg = f"Failed to load NetCDF: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - - def update_transect_position(self, value): - """Update transect position from slider.""" - if not self.nc_data_cache_1d: - return - - transect_idx = int(float(value)) - first_var = list(self.nc_data_cache_1d['vars'].values())[0] - n_transects = first_var.shape[1] if self.direction_var.get() == 'cross-shore' else first_var.shape[2] - self.transect_label.config(text=f"Transect: {transect_idx} / {n_transects-1}") - - # Clear held plots when transect changes (they're from different transect) - self.held_plots = [] - - self.update_plot() - - def update_time_step(self, value): - """Update time step from slider.""" - if not self.nc_data_cache_1d: - return - - time_idx = int(float(value)) - n_times = self.nc_data_cache_1d['n_times'] - self.time_label_1d.config(text=f"Time step: {time_idx} / {n_times-1}") - - self.update_plot() - - def update_plot(self): - """Update the 1D transect plot with current settings.""" - if not self.nc_data_cache_1d: - return - - try: - # Always clear the axis to redraw - self.transect_ax.clear() - - time_idx = int(self.time_slider_1d.get()) - transect_idx = int(self.transect_slider.get()) - var_name = self.variable_var_1d.get() - direction = self.direction_var.get() - - if var_name not in self.nc_data_cache_1d['vars']: - messagebox.showwarning("Warning", f"Variable '{var_name}' not found!") - return - - # Get data - var_data = self.nc_data_cache_1d['vars'][var_name] - z_data = extract_time_slice(var_data, time_idx) - - # Extract transect - if direction == 'cross-shore': - transect_data = z_data[transect_idx, :] - x_data = self.nc_data_cache_1d['x'][transect_idx, :] if self.nc_data_cache_1d['x'].ndim == 2 else self.nc_data_cache_1d['x'] - xlabel = 'Cross-shore distance (m)' - else: # along-shore - transect_data = z_data[:, transect_idx] - x_data = self.nc_data_cache_1d['y'][:, transect_idx] if self.nc_data_cache_1d['y'].ndim == 2 else self.nc_data_cache_1d['y'] - xlabel = 'Along-shore distance (m)' - - # Redraw held plots first (if any) - if self.held_plots: - for held_time_idx, held_data, held_x_data in self.held_plots: - if held_x_data is not None: - self.transect_ax.plot(held_x_data, held_data, '--', linewidth=1.5, - alpha=0.7, label=f'Time: {held_time_idx}') - else: - self.transect_ax.plot(held_data, '--', linewidth=1.5, - alpha=0.7, label=f'Time: {held_time_idx}') - - # Plot current transect - if x_data is not None: - self.transect_ax.plot(x_data, transect_data, 'b-', linewidth=2, - label=f'Time: {time_idx}' if self.held_plots else None) - self.transect_ax.set_xlabel(xlabel) - else: - self.transect_ax.plot(transect_data, 'b-', linewidth=2, - label=f'Time: {time_idx}' if self.held_plots else None) - self.transect_ax.set_xlabel('Grid Index') - - ylabel = self.get_variable_label(var_name) - self.transect_ax.set_ylabel(ylabel) - - title = self.get_variable_title(var_name) - if self.held_plots: - self.transect_ax.set_title(f'{title} - {direction.capitalize()} (Transect: {transect_idx}) - Multiple Time Steps') - else: - self.transect_ax.set_title(f'{title} - {direction.capitalize()} (Time: {time_idx}, Transect: {transect_idx})') - self.transect_ax.grid(True, alpha=0.3) - - # Add legend if there are held plots - if self.held_plots: - self.transect_ax.legend(loc='best') - - # Apply Y-axis limits if not auto - if self.auto_ylimits_var is not None and self.ymin_entry is not None and self.ymax_entry is not None: - if not self.auto_ylimits_var.get(): - try: - ymin_str = self.ymin_entry.get().strip() - ymax_str = self.ymax_entry.get().strip() - if ymin_str and ymax_str: - ymin = float(ymin_str) - ymax = float(ymax_str) - self.transect_ax.set_ylim([ymin, ymax]) - except ValueError: - pass # Invalid input, keep auto limits - - # Update overview - self.update_overview(transect_idx) - - self.transect_canvas.draw_idle() - - except Exception as e: - error_msg = f"Failed to update 1D plot: {str(e)}\n\n{traceback.format_exc()}" - print(error_msg) - - def update_overview(self, transect_idx): - """Update the domain overview showing transect position.""" - if not self.nc_data_cache_1d: - return - - try: - self.overview_ax.clear() - - time_idx = int(self.time_slider_1d.get()) - var_name = self.variable_var_1d.get() - direction = self.direction_var.get() - - if var_name not in self.nc_data_cache_1d['vars']: - return - - # Get data for overview - var_data = self.nc_data_cache_1d['vars'][var_name] - z_data = extract_time_slice(var_data, time_idx) - - x_data = self.nc_data_cache_1d['x'] - y_data = self.nc_data_cache_1d['y'] - - # Plot domain overview with pcolormesh - self.overview_ax.pcolormesh(x_data, y_data, z_data, shading='auto', cmap='terrain') - - # Draw transect line - if direction == 'cross-shore': - if x_data.ndim == 2: - x_line = x_data[transect_idx, :] - y_line = y_data[transect_idx, :] - else: - x_line = x_data - y_line = np.full_like(x_data, y_data[transect_idx] if y_data.ndim == 1 else y_data[transect_idx, 0]) - else: # along-shore - if y_data.ndim == 2: - x_line = x_data[:, transect_idx] - y_line = y_data[:, transect_idx] - else: - y_line = y_data - x_line = np.full_like(y_data, x_data[transect_idx] if x_data.ndim == 1 else x_data[0, transect_idx]) - - self.overview_ax.plot(x_line, y_line, 'r-', linewidth=2, label='Transect') - self.overview_ax.set_xlabel('X (m)') - self.overview_ax.set_ylabel('Y (m)') - - self.overview_ax.set_title('Domain Overview') - self.overview_ax.legend() - - # Redraw the overview canvas - self.overview_canvas.draw_idle() - - except Exception as e: - error_msg = f"Failed to update overview: {str(e)}" - print(error_msg) - - def _add_current_to_held_plots(self): - """Helper method to add the current time step to held plots.""" - if not self.nc_data_cache_1d: - return - - time_idx = int(self.time_slider_1d.get()) - transect_idx = int(self.transect_slider.get()) - var_name = self.variable_var_1d.get() - direction = self.direction_var.get() - - if var_name not in self.nc_data_cache_1d['vars']: - return - - # Check if this time step is already in held plots - for held_time, _, _ in self.held_plots: - if held_time == time_idx: - return # Already held, don't add duplicate - - var_data = self.nc_data_cache_1d['vars'][var_name] - z_data = extract_time_slice(var_data, time_idx) - - # Extract transect - if direction == 'cross-shore': - transect_data = z_data[transect_idx, :] - x_data = self.nc_data_cache_1d['x'][transect_idx, :] if self.nc_data_cache_1d['x'].ndim == 2 else self.nc_data_cache_1d['x'] - else: # along-shore - transect_data = z_data[:, transect_idx] - x_data = self.nc_data_cache_1d['y'][:, transect_idx] if self.nc_data_cache_1d['y'].ndim == 2 else self.nc_data_cache_1d['y'] - - # Add to held plots - self.held_plots.append((time_idx, transect_data.copy(), x_data.copy() if x_data is not None else None)) - - def toggle_hold_on(self): - """ - Add the current plot to the collection of held plots. - This allows overlaying multiple time steps on the same plot. - """ - if not self.nc_data_cache_1d: - messagebox.showwarning("Warning", "Please load data first!") - return - - # Add current plot to held plots - self._add_current_to_held_plots() - self.update_plot() - - def clear_held_plots(self): - """Clear all held plots.""" - self.held_plots = [] - self.update_plot() - - def export_png(self, default_filename="output_1d.png"): - """Export current 1D plot as PNG.""" - if not self.transect_fig: - messagebox.showwarning("Warning", "No plot to export.") - return None - - file_path = filedialog.asksaveasfilename( - initialdir=self.get_config_dir(), - title="Save plot as PNG", - defaultextension=".png", - initialfile=default_filename, - filetypes=(("PNG files", "*.png"), ("All files", "*.*")) - ) - - if file_path: - try: - self.transect_fig.savefig(file_path, dpi=300, bbox_inches='tight') - messagebox.showinfo("Success", f"Plot exported to:\n{file_path}") - return file_path - except Exception as e: - error_msg = f"Failed to export: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - return None - - def export_animation_mp4(self, default_filename="output_1d_animation.mp4"): - """Export 1D transect animation as MP4.""" - if not self.nc_data_cache_1d or self.nc_data_cache_1d['n_times'] <= 1: - messagebox.showwarning("Warning", "Need multiple time steps for animation.") - return None - - file_path = filedialog.asksaveasfilename( - initialdir=self.get_config_dir(), - title="Save animation as MP4", - defaultextension=".mp4", - initialfile=default_filename, - filetypes=(("MP4 files", "*.mp4"), ("All files", "*.*")) - ) - - if file_path: - try: - from matplotlib.animation import FuncAnimation, FFMpegWriter - - n_times = self.nc_data_cache_1d['n_times'] - progress_window = Toplevel() - progress_window.title("Exporting Animation") - progress_window.geometry("300x100") - progress_label = ttk.Label(progress_window, text="Creating animation...\nThis may take a few minutes.") - progress_label.pack(pady=20) - progress_bar = ttk.Progressbar(progress_window, mode='determinate', maximum=n_times) - progress_bar.pack(pady=10, padx=20, fill='x') - progress_window.update() - - original_time = int(self.time_slider_1d.get()) - - def update_frame(frame_num): - self.time_slider_1d.set(frame_num) - self.update_plot() - try: - if progress_window.winfo_exists(): - progress_bar['value'] = frame_num + 1 - progress_window.update() - except: - pass # Window may have been closed - return [] - - ani = FuncAnimation(self.transect_fig, update_frame, frames=n_times, - interval=200, blit=False, repeat=False) - writer = FFMpegWriter(fps=5, bitrate=1800) - ani.save(file_path, writer=writer) - - # Stop the animation by deleting the animation object - del ani - - self.time_slider_1d.set(original_time) - self.update_plot() - - try: - if progress_window.winfo_exists(): - progress_window.destroy() - except Exception: - pass # Window already destroyed - - messagebox.showinfo("Success", f"Animation exported to:\n{file_path}") - return file_path - - except ImportError: - messagebox.showerror("Error", "Animation export requires ffmpeg.") - except Exception as e: - error_msg = f"Failed to export animation: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - finally: - try: - if 'progress_window' in locals() and progress_window.winfo_exists(): - progress_window.destroy() - except Exception: - pass # Window already destroyed - return None diff --git a/aeolis/gui/gui_tabs/output_2d.py b/aeolis/gui/gui_tabs/output_2d.py deleted file mode 100644 index 7cdff72d..00000000 --- a/aeolis/gui/gui_tabs/output_2d.py +++ /dev/null @@ -1,482 +0,0 @@ -""" -2D Output Visualizer Module - -Handles visualization of 2D NetCDF output data including: -- Variable selection and plotting -- Time slider control -- Colorbar customization -- Special renderings (hillshade, quiver plots) -- PNG and MP4 export -""" - -import os -import numpy as np -import traceback -import netCDF4 -from tkinter import messagebox, filedialog, Toplevel -from tkinter import ttk -from matplotlib.cm import ScalarMappable -from matplotlib.colors import Normalize - -from aeolis.gui.utils import ( - NC_COORD_VARS, - resolve_file_path, extract_time_slice, apply_hillshade -) - - -class Output2DVisualizer: - """ - Visualizer for 2D NetCDF output data. - - Handles loading, plotting, and exporting 2D output visualizations with - support for multiple variables, time evolution, and special renderings. - """ - - def __init__(self, output_ax, output_canvas, output_fig, - output_colorbar_ref, time_slider, time_label, - variable_var_2d, colormap_var, auto_limits_var, - vmin_entry, vmax_entry, overlay_veg_var, - nc_file_entry, variable_dropdown_2d, - get_config_dir_func, get_variable_label_func, get_variable_title_func): - """Initialize the 2D output visualizer.""" - self.output_ax = output_ax - self.output_canvas = output_canvas - self.output_fig = output_fig - self.output_colorbar_ref = output_colorbar_ref - self.time_slider = time_slider - self.time_label = time_label - self.variable_var_2d = variable_var_2d - self.colormap_var = colormap_var - self.auto_limits_var = auto_limits_var - self.vmin_entry = vmin_entry - self.vmax_entry = vmax_entry - self.overlay_veg_var = overlay_veg_var - self.nc_file_entry = nc_file_entry - self.variable_dropdown_2d = variable_dropdown_2d - self.get_config_dir = get_config_dir_func - self.get_variable_label = get_variable_label_func - self.get_variable_title = get_variable_title_func - - self.nc_data_cache = None - - def on_variable_changed(self, event=None): - """Handle variable selection change.""" - self.update_plot() - - def load_and_plot(self): - """Load NetCDF file and plot 2D data.""" - try: - nc_file = self.nc_file_entry.get() - if not nc_file: - messagebox.showwarning("Warning", "No NetCDF file specified!") - return - - config_dir = self.get_config_dir() - nc_file_path = resolve_file_path(nc_file, config_dir) - if not nc_file_path or not os.path.exists(nc_file_path): - messagebox.showerror("Error", f"NetCDF file not found: {nc_file_path}") - return - - # Open NetCDF file and cache data - with netCDF4.Dataset(nc_file_path, 'r') as nc: - available_vars = list(nc.variables.keys()) - - # Get coordinates - x_data = nc.variables['x'][:] if 'x' in nc.variables else None - y_data = nc.variables['y'][:] if 'y' in nc.variables else None - - # Load variables - var_data_dict = {} - n_times = 1 - veg_data = None - - for var_name in available_vars: - if var_name in NC_COORD_VARS: - continue - - var = nc.variables[var_name] - if 'time' in var.dimensions: - var_data = var[:] - if var_data.ndim < 3: - continue - n_times = max(n_times, var_data.shape[0]) - else: - if var.ndim != 2: - continue - var_data = np.expand_dims(var[:, :], axis=0) - - var_data_dict[var_name] = var_data - - # Load vegetation if requested - if self.overlay_veg_var.get(): - for veg_name in ['rhoveg', 'vegetated', 'hveg', 'vegfac']: - if veg_name in available_vars: - veg_var = nc.variables[veg_name] - veg_data = veg_var[:] if 'time' in veg_var.dimensions else np.expand_dims(veg_var[:, :], axis=0) - break - - if not var_data_dict: - messagebox.showerror("Error", "No valid variables found in NetCDF file!") - return - - # Add special options - candidate_vars = list(var_data_dict.keys()) - if 'zb' in var_data_dict and 'rhoveg' in var_data_dict: - candidate_vars.append('zb+rhoveg') - if 'ustarn' in var_data_dict and 'ustars' in var_data_dict: - candidate_vars.append('ustar quiver') - - # Update UI - self.variable_dropdown_2d['values'] = sorted(candidate_vars) - if candidate_vars: - self.variable_var_2d.set(candidate_vars[0]) - - # Cache data - self.nc_data_cache = { - 'file_path': nc_file_path, - 'vars': var_data_dict, - 'x': x_data, - 'y': y_data, - 'n_times': n_times, - 'veg': veg_data - } - - # Setup time slider - self.time_slider.config(to=n_times - 1) - self.time_slider.set(0) - self.time_label.config(text=f"Time step: 0 / {n_times-1}") - - # Plot initial data - self.update_plot() - - except Exception as e: - error_msg = f"Failed to load NetCDF: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - - def update_plot(self): - """Update the 2D plot with current settings.""" - if not self.nc_data_cache: - return - - try: - self.output_ax.clear() - time_idx = int(self.time_slider.get()) - var_name = self.variable_var_2d.get() - - # Update time label - n_times = self.nc_data_cache.get('n_times', 1) - self.time_label.config(text=f"Time step: {time_idx} / {n_times-1}") - - # Special renderings - if var_name == 'zb+rhoveg': - self._render_zb_rhoveg_shaded(time_idx) - return - if var_name == 'ustar quiver': - self._render_ustar_quiver(time_idx) - return - - if var_name not in self.nc_data_cache['vars']: - messagebox.showwarning("Warning", f"Variable '{var_name}' not found!") - return - - # Get data - var_data = self.nc_data_cache['vars'][var_name] - z_data = extract_time_slice(var_data, time_idx) - x_data = self.nc_data_cache['x'] - y_data = self.nc_data_cache['y'] - - # Get colorbar limits - vmin, vmax = None, None - if not self.auto_limits_var.get(): - try: - vmin_str = self.vmin_entry.get().strip() - vmax_str = self.vmax_entry.get().strip() - vmin = float(vmin_str) if vmin_str else None - vmax = float(vmax_str) if vmax_str else None - except ValueError: - messagebox.showwarning( - "Invalid Input", - "Colorbar limits must be valid numbers. Using automatic limits instead." - ) - - cmap = self.colormap_var.get() - - # Plot with pcolormesh (x and y always exist in AeoLiS NetCDF files) - im = self.output_ax.pcolormesh(x_data, y_data, z_data, shading='auto', - cmap=cmap, vmin=vmin, vmax=vmax) - self.output_ax.set_xlabel('X (m)') - self.output_ax.set_ylabel('Y (m)') - - title = self.get_variable_title(var_name) - self.output_ax.set_title(f'{title} (Time step: {time_idx})') - - # Update colorbar - self._update_colorbar(im, var_name) - - # Overlay vegetation - if self.overlay_veg_var.get() and self.nc_data_cache['veg'] is not None: - veg_slice = self.nc_data_cache['veg'] - veg_data = veg_slice[time_idx, :, :] if veg_slice.ndim == 3 else veg_slice[:, :] - self.output_ax.pcolormesh(x_data, y_data, veg_data, shading='auto', - cmap='Greens', vmin=0, vmax=1, alpha=0.4) - - self.output_canvas.draw_idle() - - except Exception as e: - error_msg = f"Failed to update 2D plot: {str(e)}\n\n{traceback.format_exc()}" - print(error_msg) - - def _update_colorbar(self, im, var_name): - """Update or create colorbar.""" - cbar_label = self.get_variable_label(var_name) - if self.output_colorbar_ref[0] is not None: - try: - self.output_colorbar_ref[0].update_normal(im) - self.output_colorbar_ref[0].set_label(cbar_label) - except Exception: - self.output_colorbar_ref[0] = self.output_fig.colorbar(im, ax=self.output_ax, label=cbar_label) - else: - self.output_colorbar_ref[0] = self.output_fig.colorbar(im, ax=self.output_ax, label=cbar_label) - - def export_png(self, default_filename="output_2d.png"): - """Export current 2D plot as PNG.""" - if not self.output_fig: - messagebox.showwarning("Warning", "No plot to export.") - return None - - file_path = filedialog.asksaveasfilename( - initialdir=self.get_config_dir(), - title="Save plot as PNG", - defaultextension=".png", - initialfile=default_filename, - filetypes=(("PNG files", "*.png"), ("All files", "*.*")) - ) - - if file_path: - try: - self.output_fig.savefig(file_path, dpi=300, bbox_inches='tight') - messagebox.showinfo("Success", f"Plot exported to:\n{file_path}") - return file_path - except Exception as e: - error_msg = f"Failed to export: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - return None - - def export_animation_mp4(self, default_filename="output_2d_animation.mp4"): - """Export 2D plot animation as MP4.""" - if not self.nc_data_cache or self.nc_data_cache['n_times'] <= 1: - messagebox.showwarning("Warning", "Need multiple time steps for animation.") - return None - - file_path = filedialog.asksaveasfilename( - initialdir=self.get_config_dir(), - title="Save animation as MP4", - defaultextension=".mp4", - initialfile=default_filename, - filetypes=(("MP4 files", "*.mp4"), ("All files", "*.*")) - ) - - if file_path: - try: - from matplotlib.animation import FuncAnimation, FFMpegWriter - - n_times = self.nc_data_cache['n_times'] - progress_window = Toplevel() - progress_window.title("Exporting Animation") - progress_window.geometry("300x100") - progress_label = ttk.Label(progress_window, text="Creating animation...\nThis may take a few minutes.") - progress_label.pack(pady=20) - progress_bar = ttk.Progressbar(progress_window, mode='determinate', maximum=n_times) - progress_bar.pack(pady=10, padx=20, fill='x') - progress_window.update() - - original_time = int(self.time_slider.get()) - - def update_frame(frame_num): - self.time_slider.set(frame_num) - self.update_plot() - try: - if progress_window.winfo_exists(): - progress_bar['value'] = frame_num + 1 - progress_window.update() - except: - pass # Window may have been closed - return [] - - ani = FuncAnimation(self.output_fig, update_frame, frames=n_times, - interval=200, blit=False, repeat=False) - writer = FFMpegWriter(fps=5, bitrate=1800) - ani.save(file_path, writer=writer) - - # Stop the animation by deleting the animation object - del ani - - self.time_slider.set(original_time) - self.update_plot() - - try: - if progress_window.winfo_exists(): - progress_window.destroy() - except Exception: - pass # Window already destroyed - - messagebox.showinfo("Success", f"Animation exported to:\n{file_path}") - return file_path - - except ImportError: - messagebox.showerror("Error", "Animation export requires ffmpeg.") - except Exception as e: - error_msg = f"Failed to export animation: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - finally: - try: - if 'progress_window' in locals() and progress_window.winfo_exists(): - progress_window.destroy() - except Exception: - pass # Window already destroyed - return None - - def _render_zb_rhoveg_shaded(self, time_idx): - """Render combined bed + vegetation with hillshading matching Anim2D_ShadeVeg.py.""" - try: - zb_data = extract_time_slice(self.nc_data_cache['vars']['zb'], time_idx) - rhoveg_data = extract_time_slice(self.nc_data_cache['vars']['rhoveg'], time_idx) - x_data = self.nc_data_cache['x'] - y_data = self.nc_data_cache['y'] - - # Normalize vegetation to [0,1] - veg_max = np.nanmax(rhoveg_data) - veg_norm = rhoveg_data / veg_max if (veg_max is not None and veg_max > 0) else np.clip(rhoveg_data, 0.0, 1.0) - veg_norm = np.clip(veg_norm, 0.0, 1.0) - - # Apply hillshade - x1d = x_data[0, :] if x_data.ndim == 2 else x_data - y1d = y_data[:, 0] if y_data.ndim == 2 else y_data - hillshade = apply_hillshade(zb_data, x1d, y1d, az_deg=155.0, alt_deg=5.0) - - # Color definitions - sand = np.array([1.0, 239.0/255.0, 213.0/255.0]) # light sand - darkgreen = np.array([34/255, 139/255, 34/255]) - ocean = np.array([70/255, 130/255, 180/255]) # steelblue - - # Create RGB array (ny, nx, 3) - ny, nx = zb_data.shape - rgb = np.zeros((ny, nx, 3), dtype=float) - - # Base color: blend sand and vegetation - for i in range(3): # R, G, B channels - rgb[:, :, i] = sand[i] * (1.0 - veg_norm) + darkgreen[i] * veg_norm - - # Apply ocean mask: zb < -0.5 and x < 200 - if x_data is not None: - X2d = x_data if x_data.ndim == 2 else np.meshgrid(x1d, y1d)[0] - ocean_mask = (zb_data < -0.5) & (X2d < 200) - rgb[ocean_mask] = ocean - - # Apply shading to all RGB channels - rgb *= hillshade[:, :, np.newaxis] - rgb = np.clip(rgb, 0.0, 1.0) - - # Plot RGB image - extent = [x1d.min(), x1d.max(), y1d.min(), y1d.max()] - self.output_ax.imshow(rgb, origin='lower', extent=extent, - interpolation='nearest', aspect='auto') - self.output_ax.set_xlabel('X (m)') - self.output_ax.set_ylabel('Y (m)') - - self.output_ax.set_title(f'Bed + Vegetation (Time step: {time_idx})') - - # Get colorbar limits for vegetation - vmin, vmax = 0, veg_max - if not self.auto_limits_var.get(): - try: - vmin_str = self.vmin_entry.get().strip() - vmax_str = self.vmax_entry.get().strip() - vmin = float(vmin_str) if vmin_str else 0 - vmax = float(vmax_str) if vmax_str else veg_max - except ValueError: - pass # Use default limits if invalid input - - # Create a ScalarMappable for the colorbar (showing vegetation density) - norm = Normalize(vmin=vmin, vmax=vmax) - sm = ScalarMappable(cmap='Greens', norm=norm) - sm.set_array(rhoveg_data) - - # Add colorbar for vegetation density - self._update_colorbar(sm, 'rhoveg') - - self.output_canvas.draw_idle() - except Exception as e: - print(f"Failed to render zb+rhoveg: {e}") - traceback.print_exc() - - def _render_ustar_quiver(self, time_idx): - """Render quiver plot of shear velocity with magnitude background.""" - try: - ustarn = extract_time_slice(self.nc_data_cache['vars']['ustarn'], time_idx) - ustars = extract_time_slice(self.nc_data_cache['vars']['ustars'], time_idx) - x_data = self.nc_data_cache['x'] - y_data = self.nc_data_cache['y'] - - # Calculate magnitude for background coloring - ustar_mag = np.sqrt(ustarn**2 + ustars**2) - - # Subsample for quiver - step = max(1, min(ustarn.shape) // 25) - - # Get colormap and limits - cmap = self.colormap_var.get() - vmin, vmax = None, None - if not self.auto_limits_var.get(): - try: - vmin_str = self.vmin_entry.get().strip() - vmax_str = self.vmax_entry.get().strip() - vmin = float(vmin_str) if vmin_str else None - vmax = float(vmax_str) if vmax_str else None - except ValueError: - pass # Use auto limits - - # Plot background field (magnitude) - im = self.output_ax.pcolormesh(x_data, y_data, ustar_mag, - shading='auto', cmap=cmap, - vmin=vmin, vmax=vmax, alpha=0.7) - - # Calculate appropriate scaling for arrows - x1d = x_data[0, :] if x_data.ndim == 2 else x_data - y1d = y_data[:, 0] if y_data.ndim == 2 else y_data - x_range = x1d.max() - x1d.min() - y_range = y1d.max() - y1d.min() - - # Calculate typical velocity magnitude (handle masked arrays) - valid_mag = np.asarray(ustar_mag[ustar_mag > 0]) - typical_vel = np.percentile(valid_mag, 75) if valid_mag.size > 0 else 1.0 - arrow_scale = typical_vel * 20 # Scale factor to make arrows visible - - # Add quiver plot with black arrows - Q = self.output_ax.quiver(x_data[::step, ::step], y_data[::step, ::step], - ustars[::step, ::step], ustarn[::step, ::step], - scale=arrow_scale, color='black', width=0.004, - headwidth=3, headlength=4, headaxislength=3.5, - zorder=10) - - # Add quiver key (legend for arrow scale) - placed to the right, above colorbar - self.output_ax.quiverkey(Q, 1.1, 1.05, typical_vel, - f'{typical_vel:.2f} m/s', - labelpos='N', coordinates='axes', - color='black', labelcolor='black', - fontproperties={'size': 9}) - - self.output_ax.set_xlabel('X (m)') - self.output_ax.set_ylabel('Y (m)') - self.output_ax.set_title(f'Shear Velocity (Time step: {time_idx})') - - # Update colorbar for magnitude - self._update_colorbar(im, 'ustar magnitude') - - self.output_canvas.draw_idle() - except Exception as e: - print(f"Failed to render ustar quiver: {e}") - traceback.print_exc() diff --git a/aeolis/gui/gui_tabs/wind.py b/aeolis/gui/gui_tabs/wind.py deleted file mode 100644 index f4b7aa0e..00000000 --- a/aeolis/gui/gui_tabs/wind.py +++ /dev/null @@ -1,313 +0,0 @@ -""" -Wind Visualizer Module - -Handles visualization of wind input data including: -- Wind speed time series -- Wind direction time series -- Wind rose diagrams -- PNG export for wind plots -""" - -import os -import numpy as np -import traceback -from tkinter import messagebox, filedialog -import matplotlib.patches as mpatches -from windrose import WindroseAxes -from aeolis.gui.utils import resolve_file_path, determine_time_unit - - -class WindVisualizer: - """ - Visualizer for wind input data (time series and wind rose). - - Parameters - ---------- - wind_speed_ax : matplotlib.axes.Axes - Axes for wind speed time series - wind_dir_ax : matplotlib.axes.Axes - Axes for wind direction time series - wind_ts_canvas : FigureCanvasTkAgg - Canvas for time series plots - wind_ts_fig : matplotlib.figure.Figure - Figure containing time series - windrose_fig : matplotlib.figure.Figure - Figure for wind rose - windrose_canvas : FigureCanvasTkAgg - Canvas for wind rose - get_wind_file_func : callable - Function to get wind file entry widget - get_entries_func : callable - Function to get all entry widgets - get_config_dir_func : callable - Function to get configuration directory - get_dic_func : callable - Function to get configuration dictionary - """ - - def __init__(self, wind_speed_ax, wind_dir_ax, wind_ts_canvas, wind_ts_fig, - windrose_fig, windrose_canvas, get_wind_file_func, get_entries_func, - get_config_dir_func, get_dic_func): - self.wind_speed_ax = wind_speed_ax - self.wind_dir_ax = wind_dir_ax - self.wind_ts_canvas = wind_ts_canvas - self.wind_ts_fig = wind_ts_fig - self.windrose_fig = windrose_fig - self.windrose_canvas = windrose_canvas - self.get_wind_file = get_wind_file_func - self.get_entries = get_entries_func - self.get_config_dir = get_config_dir_func - self.get_dic = get_dic_func - self.wind_data_cache = None - - def load_and_plot(self): - """Load wind file and plot time series and wind rose.""" - try: - # Get the wind file path - wind_file = self.get_wind_file().get() - - if not wind_file: - messagebox.showwarning("Warning", "No wind file specified!") - return - - # Get the directory of the config file to resolve relative paths - config_dir = self.get_config_dir() - - # Resolve wind file path - wind_file_path = resolve_file_path(wind_file, config_dir) - if not wind_file_path or not os.path.exists(wind_file_path): - messagebox.showerror("Error", f"Wind file not found: {wind_file_path}") - return - - # Check if we already loaded this file (avoid reloading) - if self.wind_data_cache and self.wind_data_cache.get('file_path') == wind_file_path: - # Data already loaded, just return (don't reload) - return - - # Load wind data (time, speed, direction) - wind_data = np.loadtxt(wind_file_path) - - # Check data format - if wind_data.ndim != 2 or wind_data.shape[1] < 3: - messagebox.showerror("Error", "Wind file must have at least 3 columns: time, speed, direction") - return - - time = wind_data[:, 0] - speed = wind_data[:, 1] - direction = wind_data[:, 2] - - # Get wind convention from config - dic = self.get_dic() - wind_convention = dic.get('wind_convention', 'nautical') - - # Cache the wind data along with file path and convention - self.wind_data_cache = { - 'file_path': wind_file_path, - 'time': time, - 'speed': speed, - 'direction': direction, - 'convention': wind_convention - } - - # Determine appropriate time unit based on simulation time (tstart and tstop) - tstart = 0 - tstop = 0 - use_sim_limits = False - - try: - entries = self.get_entries() - tstart_entry = entries.get('tstart') - tstop_entry = entries.get('tstop') - - if tstart_entry and tstop_entry: - tstart = float(tstart_entry.get() or 0) - tstop = float(tstop_entry.get() or 0) - if tstop > tstart: - sim_duration = tstop - tstart # in seconds - use_sim_limits = True - else: - sim_duration = time[-1] - time[0] if len(time) > 0 else 0 - else: - sim_duration = time[-1] - time[0] if len(time) > 0 else 0 - except (ValueError, AttributeError, TypeError): - sim_duration = time[-1] - time[0] if len(time) > 0 else 0 - - # Choose appropriate time unit and convert using utility function - time_unit, time_divisor = determine_time_unit(sim_duration) - time_converted = time / time_divisor - - # Plot wind speed time series - self.wind_speed_ax.clear() - self.wind_speed_ax.plot(time_converted, speed, 'b-', linewidth=1.5, zorder=2, label='Wind Speed') - self.wind_speed_ax.set_xlabel(f'Time ({time_unit})') - self.wind_speed_ax.set_ylabel('Wind Speed (m/s)') - self.wind_speed_ax.set_title('Wind Speed Time Series') - self.wind_speed_ax.grid(True, alpha=0.3, zorder=1) - - # Calculate axis limits with 10% padding and add shading - if use_sim_limits: - tstart_converted = tstart / time_divisor - tstop_converted = tstop / time_divisor - axis_range = tstop_converted - tstart_converted - padding = 0.1 * axis_range - xlim_min = tstart_converted - padding - xlim_max = tstop_converted + padding - - self.wind_speed_ax.set_xlim([xlim_min, xlim_max]) - self.wind_speed_ax.axvspan(xlim_min, tstart_converted, alpha=0.15, color='gray', zorder=3) - self.wind_speed_ax.axvspan(tstop_converted, xlim_max, alpha=0.15, color='gray', zorder=3) - - shaded_patch = mpatches.Patch(color='gray', alpha=0.15, label='Outside simulation time') - self.wind_speed_ax.legend(handles=[shaded_patch], loc='upper right', fontsize=8) - - # Plot wind direction time series - self.wind_dir_ax.clear() - self.wind_dir_ax.plot(time_converted, direction, 'r-', linewidth=1.5, zorder=2, label='Wind Direction') - self.wind_dir_ax.set_xlabel(f'Time ({time_unit})') - self.wind_dir_ax.set_ylabel('Wind Direction (degrees)') - self.wind_dir_ax.set_title(f'Wind Direction Time Series ({wind_convention} convention)') - self.wind_dir_ax.set_ylim([0, 360]) - self.wind_dir_ax.grid(True, alpha=0.3, zorder=1) - - if use_sim_limits: - self.wind_dir_ax.set_xlim([xlim_min, xlim_max]) - self.wind_dir_ax.axvspan(xlim_min, tstart_converted, alpha=0.15, color='gray', zorder=3) - self.wind_dir_ax.axvspan(tstop_converted, xlim_max, alpha=0.15, color='gray', zorder=3) - - shaded_patch = mpatches.Patch(color='gray', alpha=0.15, label='Outside simulation time') - self.wind_dir_ax.legend(handles=[shaded_patch], loc='upper right', fontsize=8) - - # Redraw time series canvas - self.wind_ts_canvas.draw() - - # Plot wind rose - self.plot_windrose(speed, direction, wind_convention) - - except Exception as e: - error_msg = f"Failed to load and plot wind data: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - - def force_reload(self): - """Force reload of wind data by clearing cache.""" - self.wind_data_cache = None - self.load_and_plot() - - def plot_windrose(self, speed, direction, convention='nautical'): - """ - Plot wind rose diagram. - - Parameters - ---------- - speed : array - Wind speed values - direction : array - Wind direction values in degrees - convention : str - 'nautical' or 'cartesian' - """ - try: - # Clear the windrose figure - self.windrose_fig.clear() - - # Convert direction based on convention to meteorological standard - if convention == 'cartesian': - direction_met = (270 - direction) % 360 - else: - direction_met = direction - - # Create windrose axes - ax = WindroseAxes.from_ax(fig=self.windrose_fig) - ax.bar(direction_met, speed, normed=True, opening=0.8, edgecolor='white') - ax.set_legend(title='Wind Speed (m/s)') - ax.set_title(f'Wind Rose ({convention} convention)', fontsize=14, fontweight='bold') - - # Redraw windrose canvas - self.windrose_canvas.draw() - - except Exception as e: - error_msg = f"Failed to plot wind rose: {str(e)}\n\n{traceback.format_exc()}" - print(error_msg) - # Create a simple text message instead - self.windrose_fig.clear() - ax = self.windrose_fig.add_subplot(111) - ax.text(0.5, 0.5, 'Wind rose plot failed.\nSee console for details.', - ha='center', va='center', transform=ax.transAxes) - ax.axis('off') - self.windrose_canvas.draw() - - def export_timeseries_png(self, default_filename="wind_timeseries.png"): - """ - Export the wind time series plot as PNG. - - Parameters - ---------- - default_filename : str - Default filename for the export dialog - - Returns - ------- - str or None - Path to saved file, or None if cancelled/failed - """ - if self.wind_ts_fig is None: - messagebox.showwarning("Warning", "No wind plot to export. Please load wind data first.") - return None - - file_path = filedialog.asksaveasfilename( - initialdir=self.get_config_dir(), - title="Save wind time series as PNG", - defaultextension=".png", - initialfile=default_filename, - filetypes=(("PNG files", "*.png"), ("All files", "*.*")) - ) - - if file_path: - try: - self.wind_ts_fig.savefig(file_path, dpi=300, bbox_inches='tight') - messagebox.showinfo("Success", f"Wind time series exported to:\n{file_path}") - return file_path - except Exception as e: - error_msg = f"Failed to export plot: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - - return None - - def export_windrose_png(self, default_filename="wind_rose.png"): - """ - Export the wind rose plot as PNG. - - Parameters - ---------- - default_filename : str - Default filename for the export dialog - - Returns - ------- - str or None - Path to saved file, or None if cancelled/failed - """ - if self.windrose_fig is None: - messagebox.showwarning("Warning", "No wind rose plot to export. Please load wind data first.") - return None - - file_path = filedialog.asksaveasfilename( - initialdir=self.get_config_dir(), - title="Save wind rose as PNG", - defaultextension=".png", - initialfile=default_filename, - filetypes=(("PNG files", "*.png"), ("All files", "*.*")) - ) - - if file_path: - try: - self.windrose_fig.savefig(file_path, dpi=300, bbox_inches='tight') - messagebox.showinfo("Success", f"Wind rose exported to:\n{file_path}") - return file_path - except Exception as e: - error_msg = f"Failed to export plot: {str(e)}\n\n{traceback.format_exc()}" - messagebox.showerror("Error", error_msg) - print(error_msg) - - return None diff --git a/aeolis/gui/main.py b/aeolis/gui/main.py index 10155a8b..5b249435 100644 --- a/aeolis/gui/main.py +++ b/aeolis/gui/main.py @@ -23,7 +23,7 @@ def launch_gui(): root = Tk() # Create an instance of the AeolisGUI class - AeolisGUI(root, dic) + app = AeolisGUI(root, dic) # Bring window to front and give it focus root.lift()