From 33f22868b13192b14ea70e6b68cb52d2021afda4 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Mon, 4 May 2026 22:44:44 -0700 Subject: [PATCH] ML Model Download Progress & Logo Simpler version of #433: - Show download progress from MLFlow - Add the Genesis AMSC Logo and make it link to MLFlow model catalogue --- dashboard/app.py | 101 ++++++++++-- dashboard/logos/AmSC_300px.png | Bin 0 -> 14569 bytes dashboard/model_manager.py | 291 ++++++++++++++++++++++++++++----- dashboard/state_manager.py | 5 + 4 files changed, 344 insertions(+), 53 deletions(-) create mode 100644 dashboard/logos/AmSC_300px.png diff --git a/dashboard/app.py b/dashboard/app.py index cb5108ff..1278cee0 100644 --- a/dashboard/app.py +++ b/dashboard/app.py @@ -1,3 +1,4 @@ +import asyncio from bson.objectid import ObjectId import os import re @@ -6,7 +7,12 @@ from trame.ui.vuetify3 import SinglePageWithDrawerLayout from trame.widgets import plotly, router, vuetify3 as vuetify, html -from model_manager import ModelManager, model_type_dict +from model_manager import ( + ModelManager, + is_model_available_on_mlflow, + load_model_from_mlflow_with_progress, + model_type_dict, +) from outputs_manager import OutputManager from optimization_manager import OptimizationManager from parameters_manager import ParametersManager @@ -52,6 +58,7 @@ def update( reset_gui_route_nersc=True, reset_gui_route_chat=True, reset_gui_layout=True, + preloaded_model_manager=None, **kwargs, ): print("Updating...") @@ -69,6 +76,9 @@ def update( # derive execution mode from execution_mode in the experiment configuration file execution_mode = config_dict.get("execution_mode") or {} state.model_training_mode = execution_mode.get("ml_training", "local") + state.model_mlflow_tracking_uri = (config_dict.get("mlflow") or {}).get( + "tracking_uri" + ) db = load_database(config_dict) exp_data, sim_data = load_data(db, state.experiment, state.experiment_date_range) # reset output @@ -79,10 +89,15 @@ def update( cal_manager = SimulationCalibrationManager(simulation_calibration) # reset model if reset_model: - mod_manager = ModelManager( - config_dict=config_dict, - model_type=model_type_dict[state.model_type_verbose], - ) + state.model_available = False + if preloaded_model_manager is None: + mod_manager = ModelManager( + config_dict=config_dict, + model_type=model_type_dict[state.model_type_verbose], + ) + else: + mod_manager = preloaded_model_manager + state.model_available = mod_manager.avail() opt_manager = OptimizationManager(mod_manager) # reset parameters if reset_parameters: @@ -113,6 +128,67 @@ def update( ctrl.figure_update(fig) +async def update_with_model_download_indicator(**update_kwargs): + """Run a dashboard update with visible download feedback for large MLflow models.""" + load_error = None + experiment = state.experiment + model_type_verbose = state.model_type_verbose + config_dict = load_config_dict(experiment) + model_type = model_type_dict[model_type_verbose] + state.model_available = False + state.model_mlflow_tracking_uri = (config_dict.get("mlflow") or {}).get( + "tracking_uri" + ) + state.model_downloading = True + state.model_download_status = "Downloading model from MLflow..." + state.model_download_progress = None + state.flush() + await asyncio.sleep(0.05) + try: + loaded_model = await asyncio.to_thread( + load_model_from_mlflow_with_progress, + config_dict, + model_type, + asyncio.get_running_loop(), + ) + except Exception as e: + loaded_model = None + load_error = e + if state.experiment != experiment or state.model_type_verbose != model_type_verbose: + return + if load_error is not None: + title = f"Unable to load model {model_type}" + msg = f"Error occurred when loading model from MLflow: {load_error}" + add_error(title, msg) + print(msg) + update_kwargs["preloaded_model_manager"] = ModelManager( + config_dict=config_dict, + model_type=model_type, + loaded_model=loaded_model, + ) + try: + update(**update_kwargs) + finally: + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None + state.flush() + + +def update_model_selection(**update_kwargs): + config_dict = load_config_dict(state.experiment) + model_type = model_type_dict[state.model_type_verbose] + if update_kwargs.get("reset_model", True) and is_model_available_on_mlflow( + config_dict, model_type + ): + asyncio.create_task(update_with_model_download_indicator(**update_kwargs)) + else: + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None + update(**update_kwargs) + + @state.change( "experiment", "experiment_date_range", @@ -128,17 +204,18 @@ def update( "use_inferred_calibration", ) def reset(**kwargs): + modified_keys = set(state.modified_keys) # skip if triggered on server ready (all state variables marked as modified) - if len(state.modified_keys) == 1: - print(f"Reacting to state change in {state.modified_keys}...") + if len(modified_keys) == 1: + print(f"Reacting to state change in {modified_keys}...") if any( - key in state.modified_keys + key in modified_keys for key in [ "experiment", "experiment_date_range", ] ): - update( + update_model_selection( reset_model=True, reset_output=True, reset_parameters=True, @@ -150,13 +227,13 @@ def reset(**kwargs): reset_gui_layout=False, ) elif any( - key in state.modified_keys + key in modified_keys for key in [ "model_type_verbose", "model_training_time", ] ): - update( + update_model_selection( reset_model=True, reset_output=False, reset_parameters=False, @@ -168,7 +245,7 @@ def reset(**kwargs): reset_gui_layout=False, ) elif any( - key in state.modified_keys + key in modified_keys for key in [ "displayed_output", "parameters", diff --git a/dashboard/logos/AmSC_300px.png b/dashboard/logos/AmSC_300px.png new file mode 100644 index 0000000000000000000000000000000000000000..a002c3abe97b4f577fc83750c4f9b0af066d2a32 GIT binary patch literal 14569 zcmeHtcT`is*6#@|^e!M>s`Su1h=}ywI{^ZO-g^lh1p%c?RjSlTF9H!nP?`dQAiX09 zNJrX>-oD;l>sxod_x*Q6Le81l^PAbf*|W3voOz_9rA&lRiw^(*k*dlAJpcglqOWss zanS!Wh`?ayV}jOLf8sP6-o{p27jr1pHWvGu;q!qJ>P22rrjtsf{BuBfP^s;KxkmC%ahg+1+5@71B~)7DR9xdl#4kg#{k-+9TB+=Q=G zLFyc9#GqmI3tN-+86~yy+or|Ee%mi^`%I0&#tNjDz)uoVP9Nw<;|_C=R~KLvZ8hhpviDhp8}5WD;erTRld^|+?X;ea?bpg>l)JGYx+HqP%{6v!|EHeD>6-D@TUNe4r zy!M+iTPsX{SLAJ!-iD}w*9cmXfm)_ZlJL<J+$ow!qGNnqNyQe z=i$Z=wfC@v@dvtjqKyv#WaI-qp>{4XUuIjFBivn9F;yy2d{a1VFp-%^GBV39tgVm-*Nv%`fs}b5k^aCYDzuuu=D#pJkLpM?;C*N=n*8?O}YvA~0b-5eacSK3h9`2|jU%gOH%Gpo6GAOz01k zy`9uU4{tXpdOG25P)C@6r@P~yir<7wDd?!mvIy}*{!yai3iWkBH;`q~fV=wz{-eSG z?grEMh5lw!P+VMGOk7w@SXfk4LPAX9A3{bjZy&TJ{>Bu9@C%FmsrfxEQfOt+#6o}D zDH`BUIa(JfMQ<3?*TdVu!^2gU<@b=7f4BU5wI+Hw*+YGy51_sgNRE({z=}$9_|qIKS}>yJj^nGnQ|4l54!)LKSh68QTi~izuf)h))oF| zDKRtuSrk%GyT7R51NDd5{~0G5>n~MyPEdD87}|UM)m;Bp4*xf%01>v8v=tGD@WJdH zB>6<35->irZQ1dO3EN3Z3Q5=tiwlYWE4q(|gKq%T8>Zlh<`K;m+C2Yo#eDbAL~;GA zbbu4=_bi|Z9S{}I3crR%?R{YMP^N5cQLuK&{YA2IMB3IEr+{=dD!%z)|x1%tquANc`!^nI`UBV-#w+6PqIrLjyBnD240qU;}N zv3$g3q{*?Ty8z+j62^gX?2u1hzol70xIs8_nNj4~%+Fv^WVophNOj4{1Mtho`{we) zBYcU>zlp<7Pof@#$rql_&JX#H9%%$<8v6N(Gvh}6$M5RMiaP;{YKG^h8EU2NLODsC zb*iK>3b#yHm6HtyGx^NHC^^6%j3h!l0Va?aA3cNiGjPl+uSIrrJJXmG_kQ<^ z#0Ui71OmzLQ3IB%Cu9g0kSnjYK50z&TfaFg#FKtToDO%`y+Mn=k5Dv+ij^IY*3 zQviP=iAl^wOyC9t$@aI^;w|MEEE! zU-23Ds1fQy<6g3nh4sCj25yTM19D|qq)P7zGub4h3@(}a}(?$?( zt*{Y}YtzCKGbwEJYwx4^?O=&yeO-P1;>pzd5*}{uytnP)+!@{%nek%&@DwToY|m6hDsHo|~DV1a=;{icb5 zY&M9$ucBh9&eX&SO4_P>sIdprkD!wa4dlY`mMw_abLQJV%Q;#KyM;Pa2rRk442+(c zw0{Nd6C{4Sgfk3i@htNoZy7oiD1#BvuJzM=0;YEw9I4iwsJb2${&th@WX^6+2T zemGaWR~$xw%2az*NyhrHg2noM3g6rDB&?CZv0z4l5! zEZS^OU$zxuQ!SKH>4TvUVPAvtJ|=rOe;M7tnW>J$Ydxu3;K1Ms!iF=ywdr1512O>< ztVMRf7!@PJwsv0~qQM$T**|$V>2rIDh7to2g|fq`dl$hX{}@f#GtZ{NKECL!Z5r+_ zo-wJ0l1^7q1V$AYwXQm5Y-A(_Sf!AgXliQWqli{}OmN1hMLy*{;7@bUE(p4FXW{+< zBPs(J(JCK*{iLO<%L>#wM@}5yS6VEv<(uj5ZyiGT>b0y##+o5;p_|I<0Wn-%FoEXZ z{16CHdJ)j>9^s662SXjRl^se2j&=G}j+6JTB)yXH0SsDG(z$iNAYRwvYZQ1fj_rP8 zW6P+?OgVYw$-+J8=;~Lw%%8kfgelJoSs(v^(%*zXXV+`PDIpVo7|N@xR9c!G84*_^ zZ&O2cev+%`=H$twG=o z?xGg?j2PY9q`ITRQ(x}W7C-~gQCi=MuHBxe6McM-%doCV4urCrs1RL~9QMk4Rj_d! z11~7p(Bv2nU==!D2c0P^DfLmv%F0@6pj3Q{7)Wd+)^30{LWys0eVvlPI_hf$FVfpJ zT{>l6p~lR_9>Td9v}z>%GxHl zA)1C{@j>?a(szItYrj$NyMK*emESHg;9s`jfqnkDvZ~5ZS-Bsp(%2lN((9W%jNvJ8 zMYOE5s+=mL`oqn*y`_$2ulQOxXS9ll4jlT(T~E#MAdQxe?iK>`6+}2#Q zlZiX0oc?o-XB5(>nCgnnqP)BQ5+hRn+)-EU*B3IAzicGgR2uHUeyYPhhCFqvvWPJR{O(<$GG&IY{b+2Fc>VEAwxN1fFJ(AISkqGTT zkCxXW+NUp)290S8^w~j=2>$#_L$S}4I)hJ+(7wDH;*Y2|@ zp7S=$kCOwrXO9Lr&nWTZtvc0UOGK;9Y#52O$o}N8dr{Biqf1ptJ8m%}wUBawoN;WH zqvS&JMc&jrNPiBu8Tn>_11oN04%>!0bDgs4a_=)ZYop<7z5Pq9Ov}2L%{Sbo-&wIH z4-lrcdS{wNK$%~{@WavEqRCHBstr6s> zbohKBC|g4-#QLy3!!Xl-hMEm|`8b?Jt5EVDJ((&;?%Pp~N8NMpj?Zb}C{KgFG1!`4s*WUol(++Knd3ZDll{z_hG2J0nVQ zdAn{1d4|P+aPT4^w~QJ40@A#7>UXfT>vZfN4>!3m-Mdulc%mzJc@<#J)Z5QnbAbRT zPIPM{>ps{82*9wzvbbI8H|j+WxmtaTu&Ab@%9Qe2xRx>1a!rdn9z{H2J&;Y!Kc%Hm zv#{|?Rkf4u8!zEI+X(C5QvvoRzeU);JKnR|$nWjCoVDl_shtwhu+gpP)u;t|9X7-y z7g=E>Vtw}^dwvDa43cgq!ANxRIRx$4$9G}vH^ZXXbEfz-Y_1A^M5KMc_sPd>kb8kd z5o^*8Z&soCWM1G%^4IfVd$Ez5s`&}CfOl*wplEex7Z*QcW8>Mnwy(f zf9VTvIen2m_;v5tZsd!~{6#zKMJFgHLSdpIw7;>X#TcCK7_O-~?D*qHY=`XVAc|1v zG>plxF+vfToeri)u->ojX1};H6uO%KxaKX8_Q?ydC~cCvbF713fa5HeO9{h>O4rQKaBYcL<$22hzCm(yJ1 zQy8s0TRcbaNuO3hmdMBTUrj_bOuLS#9foAzGXTCuf4spc%`wu~Khz2d2^l*$^m;;~ z%dlyw+hnD>xBXy!9&dR%^~%Fjx2VaA52H9Nt#^SxSu!GN&G_83XL~m5-txCur?95M z7i7(>Y?<-&SvLfWzAhd8lgNRRFp@s)&hZiY5{usFWL`IyE?cL_-A7|(*KLCbk;C|s z9=C>Mu8W^Z%&Bf)N=Aj?`5avn^BVTD)YM1Cv~=IJXass4I)Xoz8xRdn?MjoD(9V{i5)15J1qivfb6$+n`n@e}AyhaI~?@@01 zfX6fIf!*W|pzqhtZIW*_9C!f)r7+G`#Alw=i0 zzDNXD=pNtdYPRFl*!OrgZFSQd-d&eIW9V{4lxM^xTf1&Fn&2!(%@Z@-s2~Jm8MoA4alGrZUgNk;iij6f3EMRM4DW5WLclVi5WO9B>)fXA)BOj-orXIeAMH9O`y(a zhsp}Vd#3%Q{uX}6L>tNWPtqF&ouH_pv@R!;ZB{0sPYkO!%i+kFeERKFuA`{qg3-ir ztMmCGqvV-v1>>&tIVlt63ly|W+(db&LtEi^qodzNmy=54e54VbCLS6eFcNKh*O6CQN+t44 z9Az6qI*sx!Fqy)^QV4E$1}-de~H&_L7qC z(x2N{9=hyc9=K&o*1C>BHyo!gNGMF+Sk)v|ly?+)KF(8hd4po9S|yB?T!u=52C&p#-4Oim%dU*h1JIQI2~x4 zoGDm~@S?DO92C5^!*IZmP3JC-m_KIQxqcFAvo$PZlD%AcW_(Q|h_P)@rWE#gK5>Qd z!+|p^-zxzO7k#YJIe!O)k$E+0CO0N#um9U;HQ|wZs%8XU*nZPK31;IQmjwjr$iRWz z2^OX6>>t}jgoftg>_0U;opg$95?$^%yn>r(1<;oZLA$;cd%lm>#$@r1W?aQpKkjPM zSfjtSY`^=WzVlv&U54SVPb$e1r)JwnBP|EPOdRW*^p!D%=0ObB9BH4#JFJ>TmN`F= z1d$ENoX=XAq4GAVO-`@059;VuuIU)rBQ%mwW=sy;EH{_22G?OWWo_$3+pqNUcd1BB zHX84ie*eXd+_dYF);B^CO3sjQEQgCppWK1_<`#KIW@>NB{!qW~dq?_2Cw*M@rB#!@ zY-`56SIhA{r}9LOsJe}=#iv6M(YaEvsx=r1i!z}!H{nAvexIT!>zs97?Z~LAK6=Ke zQlaVqve0rk>{Y->?8=-nCz5eJ>phpb*FZ>{c_UhoQN&vq<(HK<9iScMn{qzmn4}um zSo#6mZ7|h_dpQQfC19I1L^ELlNy#*sJQ>g7YTMsy`h117s~MESKFj?GWqbTlU8ip3 zFgm20mzGK;8%s9tLib%#;VT1frLH6+oPox>1=pwpc>;8{Z{$MjUSAdK9VQXaV9A`R z;`>9Z!qCa6BaE0qfj|hCl6m)o)1Q)I>8`?Sj{u(`n%A64%2?IxUFGVlCy+i;-7{;?w?NF(g7Ddgv!uY6uu zWOaLGMy1m=Ko@uB*Cj>*dN39P*@qk&^k3bHZp_52;4Fz?EW*oUh!3R>+uD{b8aW~v3AO#OG^)31oUq5N{Q^~x2RBUeuM))vpe5U zqldjosx_tWrKqd6dRo<9aRisKsyGd4m+I}|U?tBm&DHh9#QEwn+-nEAGUVh zu>yw-exg2c`2InH#KCn|5w`%&3|`EFDiUC{4yi?eu@p#IT5-b6A7qT)_!!(5CWTn` z$b>GhMC#~bE0{caVU@zEPN{b-qq4!VMP#1`o!(@?P2i`-f*zMy8!fPRTybT4f(wbW zD{M8Jk5qtL5V=BaXqjo)d%rpA;}&7s>m=}#J5i&~^4_BtRl436%EqY&4r_Q2FAV~+or_>RJn9dW zXxW~??8WYifp6|Kg_HO*Gj<~$IG+CRil7RlF=IAl%j2Y7dAeh%@#;>gubFZ0-zXHH zi^ttAI%mV@T%5X7wqgzem4lBrgf^m>2`wj5Tk%ROn8}e0!l$nvq5G1^ z2gG7bJ?awe#eR7f7l(tr=vDQ)lQhlsnO_&q{wY>ghY<6esf9iW$+nnR!3QpSTK|>n zp65F^CeYW(_EDzSE6^twp3awE=Z+u?Wel#z3>agKjI79Y^Nm9K!>2z-bGbyVHwE;r za?Zy})x|N8bxk3hfH_8BM?y8RvDz1x8uMOS6aFme%7Jg@SVcVR^x9l&@Hv-Crf%p= z15X(!>CPb$AO?80=sCZcZ0A`D)5P)=K&bt?@c{xr*-9ASs{pN%?N(r9*r9;|xE5d; z^}ybdzKvp4lBsh^*Gqxd!h# z^DNEDq{n8=Bh6hhSs%1WcjUo|7|(uqpeed-xz7PwZIM((Cx&l zMY-QG_FW!J#W~Y-kAl}2$K=;FS>ddt-yoWgfA|c<2jUCc?O}Lwir=C*>qdlb+`~sb zDN(}jdCrPN1#begP{YuLc=B&2IS+NP7q>w@OD=$zE$xH;B6)*i*<@&kT!MOW(>w)2 z-o+(L@B4MD2zm(^x7A7~H|y`LsAusrs7EmZavG(20`0XIQV7XLk3(0n4xcC%OOi#+ zU!64@K^rRLWH`~`UhO=klappl4{!L5)ihcAk6GDtDnHaUFeV*JKkfi5a8_&zFcuxQ z&t8|jJ?nf_&i-b&*e35rR|b;;|LJ99nV!2?t}KZIHW^@2lpf5D^73dPy0}_5bE#>n zZ)y@1Qqa~;>DO3i)+*V+0Zw5J)AU#2h#`acp*Iozv*y%@29S(`-@0vfa^(3~OpIvc z_|t%%c^E@T0)w%|Z&b}Y_(9MkcY1bfmtb7JAuPqn z*!ZHXR$?yp%Q#zgWu;zbTYw|t{i=_e#yyty_I8;}Wo&T*pSD@7RYlxgd)fyP#VSZf zN4GSaNhd5sDZqhwW{h`!8YHMTddLCLsNalKy?Pat-&x!Odh_sxc*;y$EBi<6KmZ2) z*yqoJrq5H6W(3tEBO^Zn|B?vvbUp2IJ+7&RbiEslMd`GDm{5)-g?$JU+!*kr$5TQb z5p$0*@A4fObz@UZ+-~P8Wqho1LyO3&abuh#unAt@Ros(kgv?Cu>B0hU&Bh0!6S|g_ z?r05;6%(^r)t#J&pEYP-S_jk;JX)uY`4X3#9ki&6*(~nvel&HkdiUZ$n=g&2X; zLeQz4PsE1Px zCn~AsFm~czF*-daXdW5SB3hIM78JwTXNT*jce1h0Wk(wuo}%6PC>ESJcI%A{zTt}atRvuQ^%Rx8e_`z9-N-$`CuXF-GU`vroohISd>rl|td-6X{4sdXs;A7D zh*my5A&(nxoE(INPsg5@>TU7Flxp_64Ntsqx_{V4nG$3jrx^H3cG-b$dz3$Zz3VZ#xF zBM+dS`G_E^zHX<=YYAV3X*`~HcyvXCmzMm1OuFkC4o22XDEk~Sw{CAR{@b4pPsay! zE+^bJWivFhN{PZg!(z?WIzd|j>ex?ZeIn=)ip?$IQ%t@$eJkUYjdv2-&5;ulNekraYG zdBq^Gphe&yI(Gv&6JeR-7>E}SjoPd@Cwf)AvNbOZ5!M?rP?IHpJ#(lv5Pype=KrpN zpalISsm83ynGR6y88T@X(>Qw<3J7qQUaOMrE9?;9bLpye8h!VMs=bn`IeuzT_g<$U z5o!brCgqr5Phv(sMDqR;56KfU|7k42ORm`=k zAXDZ8f;ly+^-TJzV;_a@B@L7_IB1m>z3mHleOTL&JpAA%*PdOmCcr%WyC=?B*eTJd{lDV zbE%Qg0cXU0yL5y&2XLo^>s-#df3$m*(T%fTh~RR@QL}*q3wzP1U-&~W>rWtgRKe~i zmr!l2B-~Ap$R_FFx(9Pwr>z*Vy$LeF6t)JgNj>T%)FR?_?aEYDCZHfsf_fj#i`;1s zk*l#~JBqB=5-3imebLlauy#R%P^a(D*=Tp?Py3J*dAnytTfM$%@)!Op_~KXPStro- zutQ39{Fe)N#xDbx8t=AQ31CCd3LiCMwoQkA;-IbSy*$Q3cG(x_(8*qT~0qpn0mJH&oWG9YNA!Zv&S6iPBti+pbkp#=mLOFG=N8GSImAxP=SmPe9#DhEOeCgG z7Xy@o@<=pSpRArYce0(rfkl4AhO8Co1&HW*J1jH$?G53nC+J{Bo(Y}QPATVI zSOZzen{Mi*QCi&f9$y{a;+x#YN&N>qgba~=A+N=h!kwwi?EOJU5OD4y{l8` z{W^l!!*sCdGk+n@x-z$UK)%}_kcW*P;C$yzh;?v0rqfT<`#zb?ranM;H2CWSkiKWK zTHaV-YHE6JTM}_63CO5FX^rc8tpCB_x!QVrKjjhUkn0;B2bPN|ixB*aV z{ay~pcJgDg5S*aGg)T*u#}StBmlj<$qH(i;#IxI>L@s1W(O>HfZ$UcaPSX&<*eFrp zq!jPiRSep_KvU%C$#s6`WwW~7P|sf4BZ6#Gi53mdInC+4sX>#1MDp~F1CuYWJLf>= zs^BtCpQ5GF&3tssSF$a7u+cV6>f{*e-soI&e3wjCQLz^z$14#^|1tN@LWrJ{+*1aS zMM;wt#v(O3aF}wfmzPW27y||4xQtT;Xi*CE4(UcvWgM0na5qIHo%bBN^4SGFCnZ~W z_)d(bMl`jM6`i$&l78n5&=K%(bJMJp^&WoXg`~^Tb7q&-wpLYnw0T>pX|diHH;tmWlZHOh<7$l)+RiE!G{cuc+b;yTSeGE zIo?aE0M&?K0$SroR%CUyIo}wt9*k^@QA*BYm2F>$!z$9z@zm^I8BRE#hInxLr+wq` z=?Yx8Pb7_&WG{~|pEb*&cgjZ`8WL^d2tXX9J`bZx?C$BKEr5zMCrE7FZn;4L1}ml5 zP(B=mbyE{Y`oEt_+IcO){{r{pdM``iL?lEu8f-2o!CZcDxVkc3V8})t1m5D!w4Z|Z z#<90`7eW?>Zk5gkw-qQD8h#1WP^lqWsUwKa@;i;mYK_e}zPK_KY)5DRl)uOVG_#LP zZy1{Kv!yOYd_jFW#!k&^LIQkSl1aIhNfI^Qy)j{M-I50K#T@lI@!!^`slJQAqyJ|92 zURZi+(#)F5IZIrOJ6jjOYyF^XyPfG|?N{SK|K4=1&)0sw3gR-As^h&S!zn-OC!d^) zrI}+1Yueaq@W#xaPKX+}w9rjlWZk;c4>eV;^v$xWr4?-(dI6 zte=ti_6QOz%5n2Xx_U)5emyT=t#r0^m}$s`XxNmDvOwW}X5A&d7`?p!CGqNYuIwQe zUQ{&nI8^352eW`kvKJ(C=r_f5#hnt6oA){CycWw8DA0cL#I{BXF_c7ASm>OS}$)~myiOZ zPjpx7-J2t!4Js}qu&C{Yktjtq(%f|k{pNCq<8u=E$HU?hfpodM;$YmxBFQ2~+@F4; z)Ikv^MMrbwlU2mPMx5W*7!5;;X1?ajI0XzI_ra%3nlXmAH&c*tQyF?IO_K|p`YGU) zD67S7z};AFj<)gM#=+X4sA1)8w!6uZc-A^Y)cYWWJa=h0HL|E;Pdzv%{M*Xzdir9w z`A)}(&3#B{gLq}`6?wLkTf|sr$`6VP@m-wkCAtk%OxS*)XheAV@3z{}Np9;sSk+X$ zFH1Pb9)ByxY|Rpx#jM_5+caHB)hgc`wHHztT@_tb6nY#qFLzTG>?bv}`AD-mxCsNL z2)V=HfzEj2GuZVF4gH)Q9XEPQ`hFz$&tl8D{$z-tIiLPm$RylS3XX?n9GlxO+p5njF4E`SqA1~g3>=X+F{boPS Rq7N$qs!Cc9YVX@b{U1l>r&$01 literal 0 HcmV?d00001 diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index 6f8095ae..b7fc08fb 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import contextmanager from datetime import datetime from pathlib import Path import tempfile @@ -6,22 +7,166 @@ import yaml import re import mlflow +import mlflow.store.artifact.artifact_repo as mlflow_artifact_repo +import mlflow.store.artifact.cloud_artifact_repo as mlflow_cloud_artifact_repo +import mlflow.utils.file_utils as mlflow_file_utils +from mlflow.exceptions import MlflowException +from trame.assets.local import LocalFileManager from sfapi_client import AsyncClient from sfapi_client.compute import Machine -from trame.widgets import vuetify3 as vuetify +from trame.widgets import vuetify3 as vuetify, html from utils import timer, load_config_dict, create_date_filter from calibration_manager import build_inferred_calibration from error_manager import add_error from sfapi_manager import monitor_sfapi_job from state_manager import state +LOGO_DIR = Path(__file__).parent / "logos" +AMSC_MLFLOW_URL = "https://mlflow.american-science-cloud.org" +MODEL_TYPE_GP = "Gaussian Process" +MODEL_TYPE_NN_SINGLE = "Neural Network (single)" +MODEL_TYPE_NN_ENSEMBLE = "Neural Network (ensemble)" +AMSC_LOGO_PATH = LOGO_DIR / "AmSC_300px.png" +AMSC_LOGO_URL = ( + LocalFileManager(LOGO_DIR).url("amsc_logo", AMSC_LOGO_PATH) + if AMSC_LOGO_PATH.is_file() + else None +) +MODEL_DOWNLOAD_ACTIVE_EXPR = "model_downloading" +AMSC_MLFLOW_LINK_ACTIVE_EXPR = ( + f"model_available && model_mlflow_tracking_uri === '{AMSC_MLFLOW_URL}'" +) +AMSC_MLFLOW_MODEL_URL_EXPR = ( + f"'{AMSC_MLFLOW_URL}/#/models/synapse-' + experiment + '_' + " + f"(model_type_verbose === '{MODEL_TYPE_GP}' ? 'GP' : " + f"model_type_verbose === '{MODEL_TYPE_NN_SINGLE}' ? 'NN' : " + f"model_type_verbose === '{MODEL_TYPE_NN_ENSEMBLE}' ? 'ensemble_NN' : " + "model_type_verbose)" +) + model_type_dict = { - "Gaussian Process": "GP", - "Neural Network (single)": "NN", - "Neural Network (ensemble)": "ensemble_NN", + MODEL_TYPE_GP: "GP", + MODEL_TYPE_NN_SINGLE: "NN", + MODEL_TYPE_NN_ENSEMBLE: "ensemble_NN", } +_NO_PRELOADED_MODEL = object() + + +def build_mlflow_model_name(config_dict, model_type): + """Return the registered MLflow model name for an experiment and model type.""" + return f"synapse-{config_dict['experiment']}_{model_type}" + + +def configure_mlflow_tracking(config_dict): + """Configure MLflow tracking for an experiment when MLflow is available.""" + mlflow_cfg = config_dict.get("mlflow") or {} + tracking_uri = mlflow_cfg.get("tracking_uri") + if not tracking_uri: + msg = ( + "No mlflow.tracking_uri in configuration file for " + f"{config_dict['experiment']}; cannot load model from MLflow." + ) + print(msg) + return False + + mlflow.set_tracking_uri(tracking_uri) + # When using the AmSC MLflow, inject the X-Api-Key to authenticate. + # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) + if tracking_uri == AMSC_MLFLOW_URL: + enable_amsc_x_api_key(config_dict) + return True + + +def load_model_from_mlflow(config_dict, model_type): + """Load the latest registered MLflow model for an experiment configuration.""" + if not configure_mlflow_tracking(config_dict): + return None + + model_name = build_mlflow_model_name(config_dict, model_type) + return ( + mlflow.pyfunc.load_model(f"models:/{model_name}/latest") + .unwrap_python_model() + .model + ) + + +def is_model_available_on_mlflow(config_dict, model_type): + """Return whether MLflow has a registered model version to download.""" + if not configure_mlflow_tracking(config_dict): + return False + + model_name = build_mlflow_model_name(config_dict, model_type) + try: + versions = mlflow.MlflowClient().search_model_versions( + f"name='{model_name}'", + max_results=1, + ) + except MlflowException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + return False + print(f"Unable to check MLflow model availability for {model_name}: {e}") + return False + return bool(versions) + + +@contextmanager +def mlflow_artifact_progress_to_state(loop): + """Expose MLflow artifact download progress through dashboard state.""" + progress_bar_modules = [ + mlflow_file_utils, + mlflow_artifact_repo, + mlflow_cloud_artifact_repo, + ] + original_progress_bars = { + module: module.ArtifactProgressBar for module in progress_bar_modules + } + original_progress_bar = mlflow_file_utils.ArtifactProgressBar + + def set_download_progress(progress, total): + """Publish the current download completion percentage to the GUI.""" + + def update_progress_state(): + if total: + state.model_download_progress = min(100, progress / total * 100) + else: + state.model_download_progress = None + state.flush() + + loop.call_soon_threadsafe(update_progress_state) + + class TrameArtifactProgressBar(original_progress_bar): + def __init__(self, desc, total, step, **kwargs): + super().__init__(desc, total, step, **kwargs) + self.trame_progress = 0 + if desc.startswith("Downloading"): + set_download_progress(self.trame_progress, self.total) + + def update(self): + super().update() + self.trame_progress = min( + self.total, + self.trame_progress + self.step, + ) + if self.desc.startswith("Downloading"): + set_download_progress(self.trame_progress, self.total) + + for module in progress_bar_modules: + module.ArtifactProgressBar = TrameArtifactProgressBar + try: + yield + finally: + for module, progress_bar in original_progress_bars.items(): + module.ArtifactProgressBar = progress_bar + + +def load_model_from_mlflow_with_progress(config_dict, model_type, loop): + """Load an MLflow model while reporting artifact download progress.""" + with mlflow_artifact_progress_to_state(loop): + return load_model_from_mlflow(config_dict, model_type) + + def enable_amsc_x_api_key(config_dict): """ MLflow authentication helper for the AmSC MLflow server. @@ -49,7 +194,10 @@ def enable_amsc_x_api_key(config_dict): add_error(title, msg) print(msg) return - _orig = rest_utils.http_request + if getattr(rest_utils.http_request, "_synapse_amsc_api_key", None) == api_key: + return + + _orig = getattr(rest_utils, "_synapse_http_request", rest_utils.http_request) def patched(host_creds, endpoint, method, *args, **kwargs): if "headers" in kwargs and kwargs["headers"] is not None: @@ -62,40 +210,25 @@ def patched(host_creds, endpoint, method, *args, **kwargs): kwargs["extra_headers"] = h return _orig(host_creds, endpoint, method, *args, **kwargs) + patched._synapse_amsc_api_key = api_key + rest_utils._synapse_http_request = _orig rest_utils.http_request = patched class ModelManager: - def __init__(self, config_dict, model_type): + def __init__(self, config_dict, model_type, loaded_model=_NO_PRELOADED_MODEL): print("Initializing model manager...") self.__model = None self.__model_type = model_type - if "mlflow" not in config_dict or not config_dict["mlflow"].get("tracking_uri"): - print( - f"No mlflow.tracking_uri in configuration file for {config_dict['experiment']}; cannot load model from MLflow." - ) - return - - mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"]) - # When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server - # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) - if ( - config_dict["mlflow"]["tracking_uri"] - == "https://mlflow.american-science-cloud.org" - ): - enable_amsc_x_api_key(config_dict) - - experiment = config_dict["experiment"] - model_name = f"synapse-{experiment}_{model_type}" - try: - # Download model from MLflow server self.__model = ( - mlflow.pyfunc.load_model(f"models:/{model_name}/latest") - .unwrap_python_model() - .model + load_model_from_mlflow(config_dict, model_type) + if loaded_model is _NO_PRELOADED_MODEL + else loaded_model ) + if self.__model is None: + return if model_type not in ("NN", "ensemble_NN", "GP"): raise ValueError(f"Unsupported model type: {model_type}") # Populate inferred calibration in physics units for GUI @@ -353,37 +486,113 @@ def panel(self): print("Setting model card...") # list of available model types model_type_list = [ - "Gaussian Process", - "Neural Network (single)", - "Neural Network (ensemble)", + MODEL_TYPE_GP, + MODEL_TYPE_NN_SINGLE, + MODEL_TYPE_NN_ENSEMBLE, ] + model_type_cols = 8 if AMSC_LOGO_URL else 12 with vuetify.VExpansionPanels(v_model=("expand_panel_control_model", 0)): with vuetify.VExpansionPanel( title="Control: Models", style="font-size: 20px; font-weight: 500;", ): with vuetify.VExpansionPanelText(): - with vuetify.VRow(): - with vuetify.VCol(): + with vuetify.VRow(align="center"): + with vuetify.VCol(cols=model_type_cols): vuetify.VSelect( v_model=("model_type_verbose",), label="Model type", items=(model_type_list,), dense=True, ) - with vuetify.VCol(): - vuetify.VTextField( - v_model_number=("model_training_status",), - label="Training status", - readonly=True, + if AMSC_LOGO_URL: + with vuetify.VCol( + cols=4, + classes="d-flex align-center justify-end", + ): + with html.A( + v_if=(AMSC_MLFLOW_LINK_ACTIVE_EXPR,), + href=(AMSC_MLFLOW_MODEL_URL_EXPR,), + target="_blank", + rel="noopener noreferrer", + title="Open selected model in AmSC MLflow", + style=( + "display: block; width: 100%; " + "max-width: 300px; margin-left: auto; " + "cursor: pointer;" + ), + ): + vuetify.VImg( + src=AMSC_LOGO_URL, + alt="AmSC", + max_width=300, + max_height=72, + contain=True, + style="width: 100%;", + ) + vuetify.VImg( + v_if=(f"!({AMSC_MLFLOW_LINK_ACTIVE_EXPR})",), + src=AMSC_LOGO_URL, + alt="AmSC", + max_width=300, + max_height=72, + contain=True, + title=( + "Selected model is not available in AmSC MLflow" + ), + style=( + "width: 100%; max-width: 300px; " + "margin-left: auto;" + ), + ) + with vuetify.VRow( + v_if=(MODEL_DOWNLOAD_ACTIVE_EXPR,), + no_gutters=True, + align="center", + style="margin-top: -8px; margin-bottom: 8px;", + ): + with vuetify.VCol(cols=model_type_cols): + with html.Div( + classes=( + "d-flex align-center text-caption " + "text-medium-emphasis mb-1" + ) + ): + vuetify.VIcon( + "mdi-cloud-download-outline", + size=16, + classes="mr-1", + ) + html.Span(v_text=("model_download_status",)) + vuetify.VSpacer() + html.Span( + v_if=("model_download_progress !== null",), + v_text=( + "`${Math.round(model_download_progress)}%`", + ), + ) + vuetify.VProgressLinear( + indeterminate=("model_download_progress === null",), + model_value=("model_download_progress",), + color="primary", + height=4, + rounded=True, ) - with vuetify.VRow(): - with vuetify.VCol(): + with vuetify.VRow(align="center"): + with vuetify.VCol(cols="auto"): vuetify.VBtn( "Train", click=self.training_trigger, disabled=( - "model_training || (model_training_mode === 'sfapi' && perlmutter_status !== 'active')", + "model_training || " + "(model_training_mode === 'sfapi' && " + "perlmutter_status !== 'active')", ), style="text-transform: none", ) + with vuetify.VCol(cols=6, style="margin-left: auto;"): + vuetify.VTextField( + v_model_number=("model_training_status",), + label="Training status", + readonly=True, + ) diff --git a/dashboard/state_manager.py b/dashboard/state_manager.py index 8d8d31ec..10680967 100644 --- a/dashboard/state_manager.py +++ b/dashboard/state_manager.py @@ -32,6 +32,11 @@ def initialize_state(): state.model_training_mode = "local" state.model_training_status = None state.model_training_time = None + state.model_available = False + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None + state.model_mlflow_tracking_uri = None # Optimization state.optimization_type = "Maximize" state.optimization_status = None