Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions ml_peg/app/bulk_crystal/phonons/app_phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
build_serialized_scatter_content,
resolve_scatter_selection,
)
from ml_peg.app.utils.register_callbacks import register_image_download_callbacks
from ml_peg.calcs import CALCS_ROOT

DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / "phonons"
Expand All @@ -54,8 +53,6 @@ class PhononApp(BaseApp):

def register_callbacks(self) -> None:
"""Register scatter/dispersion callbacks via shared helpers."""
register_image_download_callbacks()

with SCATTER_PATH.open(encoding="utf8") as handle:
interactive_data = json.load(handle)

Expand Down
36 changes: 24 additions & 12 deletions ml_peg/app/bulk_crystal/phonons/interactive_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

import matplotlib

from ml_peg.app.utils.build_components import build_image_download_controls

matplotlib.use("Agg")
from dash import dcc, html
from matplotlib import gridspec
Expand Down Expand Up @@ -317,27 +315,41 @@ def render_dispersion_component(
if not image_src:
return None

filename = f"{label}_phonon_dispersion.png" if label else "phonon_dispersion.png"
stem = f"{label}_phonon_dispersion" if label else "phonon_dispersion"
link_style = {
"display": "inline-flex",
"alignItems": "center",
"gap": "8px",
"marginTop": "4px",
"marginBottom": "0px",
}
if uris:
download_controls = build_image_download_controls(label or "dispersion", uris)
download_controls = html.Div(
[
html.A(
fmt.upper(),
href=uri,
download=f"{stem}.{fmt}",
className="download-button plot-download-button",
style={"width": "60px"},
)
for fmt, uri in uris.items()
],
style=link_style,
)
else:
download_controls = html.Div(
html.A(
"Download plot",
href=image_src,
download=filename,
download=f"{stem}.png",
className="download-button plot-download-button",
style={"width": "112px"},
),
style={
"display": "flex",
"justifyContent": "flex-end",
"marginTop": "12px",
"marginBottom": "0px",
},
style=link_style,
)
children = [
html.H4(label),
html.H4(label, style={"marginTop": "10px", "marginBottom": "4px"}),
download_controls,
html.Img(
src=image_src,
Expand Down
46 changes: 8 additions & 38 deletions ml_peg/app/utils/register_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import base64
from copy import deepcopy
from typing import Any, Literal

Expand Down Expand Up @@ -1150,55 +1149,26 @@ def register_image_download_callbacks() -> None:
"""
Register one generic image download callback once per Dash app.

Unlike the table download (which asks the browser to capture the live DOM),
this callback decodes a pre-rendered image already stored as a base64 data
URI in a ``dcc.Store``. The phonon dispersion plot is rendered server-side
via kaleido at analysis time, so the full-resolution export is available
without re-rendering in the browser.
The image payloads are already stored in the browser as data URIs. Keeping
this callback client-side avoids posting large phonon dispersion images back
to Dash, which can exceed request-size limits.
"""
app = dash.get_app()
output = Output({"type": "image-download", "index": MATCH}, "data")
if str(output) in app.callback_map:
return

@callback(
app.clientside_callback(
ClientsideFunction(
namespace="image_download",
function_name="downloadImage",
),
output,
Input({"type": "image-download-button", "index": MATCH}, "n_clicks"),
State({"type": "image-download-format", "index": MATCH}, "value"),
State({"type": "image-download-target", "index": MATCH}, "data"),
prevent_initial_call=True,
optional=True,
)
def _download_image(n_clicks, fmt, uris):
"""
Decode the stored data URI and trigger a browser file download.

Parameters
----------
n_clicks
Number of button clicks.
fmt
Selected download format (``"png"``, ``"svg"``, or ``"json"``).
uris
Mapping of format keys to base64 data URIs.

Returns
-------
dict
Dash ``dcc.send_bytes`` payload for the Download component.
"""
if not n_clicks or not uris or not fmt:
raise PreventUpdate
uri = uris.get(fmt)
if not uri:
raise PreventUpdate
data = base64.b64decode(uri.split(",")[1])
mime = {
"png": "image/png",
"svg": "image/svg+xml",
"json": "application/json",
}.get(fmt, "application/octet-stream")
return dcc.send_bytes(data, f"phonon_dispersion.{fmt}", type=mime)


def register_download_callbacks(table_id: str) -> None:
Expand Down
Loading