Conversation
machine-learning/jax-haiku-dask-dataframe-distributed-example.ipynb
| " df_one_partition = ddf_one_partition.compute()\n", | ||
| " scaled_x = jnp.array(df_one_partition[[\"scaled_x\"]].values)\n", | ||
| " y = jnp.array(df_one_partition[[\"y\"]].values)\n", | ||
| " params, opt_state = update(params, opt_state, scaled_x, y)" |
It might be worth taking a look at some of the functionality in dask-ml, which might do some of these things for you already if you're interested.
cc'ing @stsievert and @TomAugspurger
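For context, a minimal sketch of the kind of feature scaling dask-ml already provides, using `dask_ml.preprocessing.StandardScaler`; the DataFrame and column names here are stand-ins, not the notebook's code:

    # Illustrative only: dask-ml can handle the column scaling that the
    # notebook currently does by hand before training.
    import dask.dataframe as dd
    import pandas as pd
    from dask_ml.preprocessing import StandardScaler

    # Stand-in for the notebook's training DataFrame.
    pdf = pd.DataFrame({"x": [float(i) for i in range(100)]})
    ddf_train = dd.from_pandas(pdf, npartitions=4)

    scaler = StandardScaler()
    # fit_transform accepts a Dask DataFrame and returns one lazily, so the
    # scaled column can feed straight into the per-partition training step.
    scaled = scaler.fit_transform(ddf_train[["x"]])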
| " futures = []\n", | ||
| " for ddf_one_partition in ddf_train.partitions:\n", | ||
| " # Compute the gradients in parallel\n", | ||
| " futures.append(client.submit(dask_compute_grads_one_partition_wrapper, ddf_one_partition, params))\n", |
I recommend instead:

    from dask.distributed import futures_of
    futures = futures_of(df.map_partitions(func, **params).persist())
Thanks, I've tried this, but .map_partitions() requires the mapped function to return a DataFrame or Series (I think?). My function returns a set of gradients, grads, which is a Python dictionary (with more Python dicts nested inside, i.e. a tree-like structure), so I don't think it will work in this case (please correct me if I'm mistaken).
You can probably work around that with to_delayed() instead of map_partitions. I can take a closer look later.
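A hedged sketch of what that to_delayed() workaround might look like, reusing the names from the diff above (`ddf_train`, `dask_compute_grads_one_partition_wrapper`, `params`, `client`); this is an illustration, not the notebook's code:

    import dask

    # One Delayed object per partition; the mapped function can return any
    # Python object (e.g. the dict-of-dicts gradient tree), which is the
    # limitation hit with map_partitions above.
    partition_delayeds = ddf_train.to_delayed()

    grad_tasks = [
        dask.delayed(dask_compute_grads_one_partition_wrapper)(part, params)
        for part in partition_delayeds
    ]

    # Turn the delayed graph into futures running on the cluster.
    # Note: each task receives an in-memory pandas DataFrame, so the wrapper
    # would no longer need its internal .compute() call.
    futures = client.compute(grad_tasks)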
| " # Bring the gradients back to the client, and update the model with the optimizer on the client\n", | ||
| " grads = future.result()\n", | ||
| " updates, opt_state = optimizer.update(grads, opt_state)\n", | ||
| " params = optix.apply_updates(params, updates)" |
This is also the kind of thing for which Actors is probably a decent fit.
Yes, I've been trying to work out how to perform training with shared parameters (and optimizer state) across workers via Actors. I haven't quite got my head around how this might work yet.
This might be a start: https://docs.dask.org/en/latest/futures.html#example-parameter-server
That example doesn't run, maybe a bad merge. I've put in a PR to correct that: dask/dask#6449
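For reference, a minimal sketch of the Actor-based parameter-server pattern discussed above, adapted to this notebook's names (`params`, `opt_state`, `optimizer`, `optix` are borrowed from the diff; the `ParameterServer` class itself is hypothetical, not part of Dask or the notebook):

    from dask.distributed import Client
    from jax.experimental import optix  # as used in the notebook; newer code would use optax

    class ParameterServer:
        """Holds the model parameters and optimizer state on a single worker."""

        def __init__(self, params, opt_state, optimizer):
            self.params = params
            self.opt_state = opt_state
            self.optimizer = optimizer

        def get_params(self):
            return self.params

        def apply_grads(self, grads):
            # Same update step as in the notebook, but applied inside the actor
            # so all workers share one copy of params and opt_state.
            updates, self.opt_state = self.optimizer.update(grads, self.opt_state)
            self.params = optix.apply_updates(self.params, updates)

    client = Client()
    ps = client.submit(ParameterServer, params, opt_state, optimizer, actor=True).result()

    # Workers would then pull parameters, compute gradients on their partition,
    # and push them back, e.g.:
    #   current_params = ps.get_params().result()
    #   ps.apply_grads(grads).result()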
This notebook example is a learning exercise from the SciPy 2020 Dask sprint, exploring how Dask might be used to parallelize JAX/dm-haiku deep learning model training and prediction.
I've committed a notebook that works end-to-end and demonstrates a neural network learning the sine function.
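For readers who want the core idea without the Dask layer, here is a minimal single-machine sketch of a Haiku/JAX network learning the sine function; it uses optax (the successor to the jax.experimental.optix used in the notebook), and all names are illustrative rather than the notebook's:

    import jax
    import jax.numpy as jnp
    import haiku as hk
    import optax

    def net_fn(x):
        # Small MLP mapping x -> sin(x).
        mlp = hk.Sequential([hk.Linear(64), jax.nn.relu, hk.Linear(64), jax.nn.relu, hk.Linear(1)])
        return mlp(x)

    model = hk.without_apply_rng(hk.transform(net_fn))

    x = jnp.linspace(-3.0, 3.0, 256).reshape(-1, 1)
    y = jnp.sin(x)

    params = model.init(jax.random.PRNGKey(42), x)
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(params)

    def loss_fn(params, x, y):
        return jnp.mean((model.apply(params, x) - y) ** 2)

    @jax.jit
    def update(params, opt_state, x, y):
        grads = jax.grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), opt_state

    for step in range(2000):
        params, opt_state = update(params, opt_state, x, y)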