diff --git a/src/view.py b/src/view.py index bb9e9a04d..c714583b7 100644 --- a/src/view.py +++ b/src/view.py @@ -7,6 +7,7 @@ import pyopenms as poms from src.common.common import show_fig, display_large_dataframe from typing import Union +from plotly.subplots import make_subplots def get_df(file: Union[str, Path]) -> pd.DataFrame: @@ -166,41 +167,112 @@ def plot_ms_spectrum(df, title, bin_peaks, num_x_bins): ) return fig - @st.fragment def view_peak_map(): df = st.session_state.view_ms1 + if "view_peak_map_selection" in st.session_state: box = st.session_state.view_peak_map_selection.selection.box if box: df = st.session_state.view_ms1.copy() - df = df[df["RT"] > box[0]["x"][0]] + df = df[df["RT"] > max(box[0]["x"][0], 0)] df = df[df["mz"] > box[0]["y"][1]] df = df[df["mz"] < box[0]["y"][0]] - df = df[df["RT"] < box[0]["x"][1]] + df = df[df["RT"] < min(box[0]["x"][1], 1500)] + peak_map = df.plot( kind="peakmap", x="RT", y="mz", z="inty", - title=st.session_state.view_selected_file, + xlabel="Retention Time (s)", + ylabel="m/z", grid=False, show_plot=False, bin_peaks=True, backend="ms_plotly", aggregate_duplicates=True, ) - peak_map.update_layout(template="simple_white", dragmode="select") + + df_tic = df.groupby("RT").sum().reset_index() + + tic_fig = df_tic.plot( + kind="chromatogram", + x="RT", + y="inty", + xlabel="Retention Time (s)", + ylabel="TIC", + grid=False, + show_plot=False, + backend="ms_plotly", + ) + + + tic_fig.update_layout( + height=200, + margin=dict(l=0, r=0, t=0, b=0), + plot_bgcolor="rgb(255,255,255)", + xaxis=dict( + title="Retention Time (s)", + rangeslider=dict(visible=False), + showgrid=False, + fixedrange=True, + minallowed="0", + maxallowed="1000" + ), + yaxis=dict( + title="TIC", + autorange="reversed", + fixedrange=True + ), + dragmode=False + ) + + combined_fig = make_subplots( + rows=2, + cols=1, + shared_xaxes=True, + row_heights=[0.8, 0.2], + vertical_spacing=0.05 + ) + + for trace in peak_map.data: + combined_fig.add_trace(trace, row=1, col=1) + + for trace in tic_fig.data: + combined_fig.add_trace(trace, row=2, col=1) + + combined_fig.update_layout( + template="simple_white", + dragmode="zoom", + xaxis=dict(title="Retention Time (s)", showgrid=False, minallowed="0", maxallowed="1000"), + yaxis=dict(title="m/z", fixedrange=True), + yaxis2=dict(title="TIC", autorange="reversed", fixedrange=True), + height=850, + margin=dict(t=100, b=100), + title=dict( + text=st.session_state.view_selected_file, + x=0.5, + y=0.99, + xanchor="center", + yanchor="top", + font=dict(size=18, family="Arial, sans-serif") + ) + ) + c1, c2 = st.columns(2) + with c1: st.info( - "💡 Zoom in via rectangular selection for more details and 3D plot. Double click plot to zoom back out." + "💡 Select ranges on the TIC plot below for more details. " + "Double click to reset the selection." ) show_fig( - peak_map, + combined_fig, f"peak_map_{st.session_state.view_selected_file}", selection_session_state_key="view_peak_map_selection", ) + with c2: if df.shape[0] < 2500: peak_map_3D = df.plot( @@ -211,7 +283,8 @@ def view_peak_map(): y="mz", z="inty", zlabel="Intensity", - title="", + xlabel="Retention Time (s)", + ylabel="m/z", show_plot=False, grid=False, bin_peaks=st.session_state.spectrum_bin_peaks, @@ -220,9 +293,17 @@ def view_peak_map(): width=900, aggregate_duplicates=True, ) - st.plotly_chart(peak_map_3D, use_container_width=True) + peak_map_3D.update_layout( + scene=dict( + xaxis=dict(title="Retention Time (s)", minallowed="0", maxallowed="1000"), + yaxis=dict(title="m/z", fixedrange=True), + zaxis=dict(title="Intensity"), + dragmode="orbit" + ) + ) + st.plotly_chart(peak_map_3D, use_container_width=True) @st.fragment def view_spectrum(): cols = st.columns([0.34, 0.66]) @@ -316,4 +397,27 @@ def view_bpc_tic(): key="view_eic_ppm", ) fig = plot_bpc_tic() - show_fig(fig, f"BPC-TIC-{st.session_state.view_selected_file}") + + fig.update_layout( + xaxis=dict( + rangeslider=dict(visible=True), # Range slider for x-axis zoom + rangeselector=dict( # Zoom buttons + buttons=list([ + dict(count=1, label="1s", step="second", stepmode="backward"), + dict(count=10, label="10s", step="second", stepmode="backward"), + dict(count=1, label="1m", step="minute", stepmode="backward"), + dict(step="all") + ]) + ) + ), + margin=dict(l=0, r=0, t=0, b=0), # Remove margins for full-width display + height=700, # Increase height for better visualization + hovermode="x unified", # Unified hover tooltip + modebar=dict( + orientation='h', # Horizontal toolbar + bgcolor='rgba(255,255,255,0.7)', # Toolbar background color + ) + ) + + # Display full-width chromatogram + st.plotly_chart(fig, use_container_width=True)