From e5afb14b62fabda3d6c9ac94d6703425c3436a82 Mon Sep 17 00:00:00 2001
From: Tennessee Leeuwenburg <tennessee.leeuwenburg@bom.gov.au>
Date: Sat, 7 Jun 2025 18:42:46 +1000
Subject: [PATCH 1/2] Commit WIP on temporal climate data handling

---
 .../src/pyearthtools/data/indexes/indexes.py  | 21 +++++--
 .../src/pyearthtools/data/save/dataset.py     |  2 -
 packages/data/src/pyearthtools/data/time.py   | 10 +++-
 packages/models/README.md                     |  5 ++
 packages/models/pyproject.toml                | 57 +++++++++++++++++++
 .../src/site_archive_nci/_CMIP5.py            | 12 +++-
 .../pipeline/branching/branching.py           | 14 +++--
 .../src/pyearthtools/pipeline/controller.py   | 33 ++++++++--
 .../modifications/idx_modification.py         | 53 ++++++++++++----
 .../operations/xarray/_recode_calendar.py     |  2 +-
 .../pipeline/operations/xarray/join.py        | 18 +++++-
 .../operations/xarray/normalisation.py        |  4 --
 .../src/pyearthtools/pipeline/parallel.py     | 31 ++++----
 .../src/pyearthtools/pipeline/step.py         |  7 ++-
 .../training/data/lightning/datasets.py       |  8 ++-
 15 files changed, 228 insertions(+), 49 deletions(-)
 create mode 100644 packages/models/README.md
 create mode 100644 packages/models/pyproject.toml

diff --git a/packages/data/src/pyearthtools/data/indexes/indexes.py b/packages/data/src/pyearthtools/data/indexes/indexes.py
index 264fffa9..fab3bc4c 100644
--- a/packages/data/src/pyearthtools/data/indexes/indexes.py
+++ b/packages/data/src/pyearthtools/data/indexes/indexes.py
@@ -37,6 +37,8 @@
 from pathlib import Path
 from typing import Any, Callable, Iterable, Literal, Optional
 
+import pandas as pd
+import cftime
 import xarray as xr
 
 import pyearthtools.data
@@ -482,15 +484,26 @@ def retrieve(
         if time_dim not in data.dims and time_dim in data.coords:
             data = data.expand_dims(time_dim)
 
+        time_query = str(Petdt(querytime))
+        if isinstance(data.coords[time_dim].values[0], cftime.datetime):
+            time_query = cftime.datetime(querytime.year,
+                                         querytime.month,
+                                         querytime.day,
+                                         calendar='noleap',
+                                         has_year_zero=True)
+            self._round = True
+            round = True
+            # time_query = pd.to_datetime(time_query)
+
         if select and time_dim in data:
             try:
                 data = data.sel(
-                    **{time_dim: str(Petdt(querytime))},
+                    **{time_dim: time_query},
                     method="nearest" if round else None,
                 )
             except KeyError:
                 warnings.warn(
-                    f"Could not find time in dataset to select on. {querytime!r}",
+                    f"Could not find time in dataset to select on. {time_query!r}",
                     IndexWarning,
                 )
 
@@ -535,14 +548,14 @@ def series(
             Loaded series of data
         """
 
-        interval = self._get_interval(interval)
+        _interval = self._get_interval(interval)
         tolerance = kwargs.pop("tolerance", getattr(self, "data_interval", None))
 
         return index_routines.series(
             self,
             start,
             end,
-            interval,
+            _interval,
             transforms=transforms,
             tolerance=tolerance,
             **kwargs,
diff --git a/packages/data/src/pyearthtools/data/save/dataset.py b/packages/data/src/pyearthtools/data/save/dataset.py
index 7d080d00..bb67fc20 100644
--- a/packages/data/src/pyearthtools/data/save/dataset.py
+++ b/packages/data/src/pyearthtools/data/save/dataset.py
@@ -209,6 +209,4 @@ def to_zarr(
         save_kwargs.pop("append_dim", None)
         save_kwargs.pop("region", None)
 
-    # print(save_kwargs)
-    # print(dataset)
     dataset.to_zarr(zarr_file, **save_kwargs)
diff --git a/packages/data/src/pyearthtools/data/time.py b/packages/data/src/pyearthtools/data/time.py
index 17496445..a66830aa 100644
--- a/packages/data/src/pyearthtools/data/time.py
+++ b/packages/data/src/pyearthtools/data/time.py
@@ -459,9 +459,13 @@ def __add__(self, other: Petdt | TimeDelta | int) -> Petdt:
                 If int, add to last level of resolution
 
         """
+        # import pudb; pudb.set_trace()
+        resolution = TimeResolution("year")
 
         if isinstance(other, _MonthTimeDelta):
-            return NotImplemented
+            if isinstance(other, int):
+                raise NotImplementedError
+
 
         if isinstance(other, int):
             if other < 0:
@@ -644,6 +648,7 @@ def __init__(self, timedelta: Any, *args) -> None:
             0 days 00:10:00
         """
         resolution = None
+        # import pudb; pudb.set_trace()
         if args:
             timedelta = (timedelta, *args)
         self._input_timedelta = timedelta
@@ -658,6 +663,8 @@ def __init__(self, timedelta: Any, *args) -> None:
                 timedelta = (1, timedelta)
             elif len(multisplit(timedelta, [" ", ",", "-"])) == 2:
                 timedelta = tuple(x.strip() for x in multisplit(timedelta, [" ", ",", "-"]))
+            else:
+                raise ValueError(f"Unrecognised time interval {timedelta} not in {RESOLUTION_COMPONENTS}")
 
         if isinstance(timedelta, (list, tuple)) and len(timedelta) == 2:
             if isinstance(timedelta[1], str) and timedelta[1].strip().removesuffix("s") in RESOLUTION_COMPONENTS:
@@ -809,6 +816,7 @@ def __init__(self, timedelta: Any, *args):
             _resolution = TimeResolution("year")
             modified_time_delta[0] = int(modified_time_delta[0]) * 12
 
+        # import pudb; pudb.set_trace()
         super().__init__((int(modified_time_delta[0]) * 30, "days"))
 
         self._input_timedelta = timedelta
diff --git a/packages/models/README.md b/packages/models/README.md
new file mode 100644
index 00000000..aee46bd4
--- /dev/null
+++ b/packages/models/README.md
@@ -0,0 +1,5 @@
+# PyEarthTools Models
+
+This is the models sub-package, which forms part of the [PyEarthTools package](https://github.com/ACCESS-Community-Hub/PyEarthTools).
+
+Documentation for the PyEarthTools package is available [here](https://pyearthtools.readthedocs.io/en/latest/).
diff --git a/packages/models/pyproject.toml b/packages/models/pyproject.toml
new file mode 100644
index 00000000..d9c2f17e
--- /dev/null
+++ b/packages/models/pyproject.toml
@@ -0,0 +1,57 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+
+[project]
+name = "pyearthtools-models"
+description = "Blueprint model implementations which can be used in PyEarthTools"
+requires-python = ">=3.9"
+keywords = ["pyearthtools"]
+maintainers = [
+    {name = "Tennessee Leeuwenburg", email = "tennessee.leeuwenburg@bom.gov.au"}
+]
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: Apache Software License",
+    "Operating System :: OS Independent",
+]
+dependencies = [
+    "xarray[complete]",
+    "geopandas",
+    "shapely",
+    "tqdm",
+    "pyyaml",
+]
+readme = {file = "README.md", content-type = "text/markdown"}
+dynamic = ["version"]
+
+
+[project.optional-dependencies]
+all = [
+    "pyearthtools-models",
+]
+
+[project.urls]
+homepage = "https://pyearthtools.readthedocs.io/"
+documentation = "https://pyearthtools.readthedocs.io/"
+repository = "https://github.com/ACCESS-Community-Hub/PyEarthTools"
+
+[tool.isort]
+profile = "black"
+
+[tool.black]
+line-length = 120
+
+[tool.ruff]
+line-length = 120
+
+[tool.mypy]
+warn_return_any = true
+warn_unused_configs = true
+
+[tool.hatch.version]
+path = "src/pyearthtools/models/__init__.py"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/pyearthtools/"]
diff --git a/packages/nci_site_archive/src/site_archive_nci/_CMIP5.py b/packages/nci_site_archive/src/site_archive_nci/_CMIP5.py
index 1eab770d..c2370181 100644
--- a/packages/nci_site_archive/src/site_archive_nci/_CMIP5.py
+++ b/packages/nci_site_archive/src/site_archive_nci/_CMIP5.py
@@ -163,8 +163,13 @@ def __init__(
 
         warnings.warn(UNDER_DEV_MSG)
 
+        base_transforms = TransformCollection()
+        base_transforms += pyearthtools.data.transforms.variables.variable_trim(variables)
+        base_transforms += pyearthtools.data.transforms.coordinates.Drop(coordinates='height')
+
         super().__init__(
-            transforms=TransformCollection(),
+            transforms=base_transforms,
+            data_interval=(1, "month"),
         )
         self.record_initialisation()
@@ -208,6 +213,9 @@ def quick_walk(self):
 
         self.walk_cache = walk_cache
 
+
+
+
     def filesystem(self, query_dictionary={}):
         """
         Given the supplied query, return all filenames which contain the data necessary to extract the data
@@ -248,7 +256,7 @@ def match_path(self, path, query):
         if path["model"] not in self.models:
             match = False
 
-        if path["interval"] not in self.interval:
+        if path["interval"] not in self.interval[0]:
             match = False
 
         if path["scenario"] not in self.scenarios:
diff --git a/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py b/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py
index ea6a1af3..079a090a 100644
--- a/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py
+++ b/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py
@@ -127,12 +127,18 @@ def __init__(
         self.sub_pipelines = list(map(lambda x: Pipeline(*x), __incoming_steps))
 
     def __getitem__(self, idx: Any) -> tuple:
-        """Get result from each branch"""
-        results = []
+        """
+        Go through each upstream sub-pipeline, fetch its result, and return the results as a tuple
+        """
+        queries = []
+        query_interface = self.parallel_interface  # This is a property and might be a serial interface
+
         for pipe in self.sub_pipelines:
-            results.append(self.parallel_interface.submit(pipe.__getitem__, idx))
+            query = query_interface.submit(pipe.__getitem__, idx)
+            queries.append(query)
 
-        return tuple(self.parallel_interface.collect(results))
+        samples = query_interface.collect(queries)
+        return tuple(samples)
 
     def apply(self, sample):
         """Apply each branch on the sample"""
diff --git a/packages/pipeline/src/pyearthtools/pipeline/controller.py b/packages/pipeline/src/pyearthtools/pipeline/controller.py
index de7dbfbb..440aed46 100644
--- a/packages/pipeline/src/pyearthtools/pipeline/controller.py
+++ b/packages/pipeline/src/pyearthtools/pipeline/controller.py
@@ -460,7 +460,9 @@ def has_source(self) -> bool:
 
     def _get_initial_sample(self, idx: Any) -> tuple[Any, int]:
         """
-        Get sample from first pipeline step or first working back index
+        Get a data sample from the first pipeline step, or an intermediate generator:
+        e.g. a source accessor, an intermediate cache,
+        or a temporal retrieval modifier.
 
         Returns:
             (tuple[Any, int]):
@@ -473,25 +475,44 @@ def _get_initial_sample(self, idx: Any) -> tuple[Any, int]:
         for index, step in enumerate(self.steps[::-1]):
             if isinstance(step, PipelineIndex):
                 LOG.debug(f"Getting initial sample from {step} at {idx}")
-                return step[idx], len(self.steps) - (index + 1)
+                sample = step[idx]
+                whereinthesequence = len(self.steps) - (index + 1)
+                return sample, whereinthesequence
 
+        # Confirm that the start of the pipeline is an accessor, and then fetch from it
         if isinstance(self.steps[0], (_Pipeline, Index)):
             LOG.debug(f"Getting initial sample from {self.steps[0]} at {idx}")
-            return self.steps[0][idx], 0
+            sample = self.steps[0][idx]
+            whereinthesequence = 0
+            return sample, whereinthesequence
 
+        # If we haven't returned something by now, that's an error condition, raise it.
         raise TypeError(f"Cannot find an `Index` to get data from. Found {type(self.steps[0]).__qualname__}")
 
     def __getitem__(self, idx: Any):
-        """Retrieve from pipeline at `idx`"""
+        """
+        Retrieve from pipeline at `idx`
+
+        Called by users when accessing the pipeline
+        Also called by modifications (such as temporal retrieval)
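+
+        For example (an illustrative index value, for a pipeline whose source is time-indexed):
+
+            sample = pipeline["2020-01-01T00"]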
+        """
 
         if isinstance(idx, slice):
             indexes = self.iterator[idx]
             LOG.debug(f"Call pipeline __getitem__ for {indexes = }")
             return map(self.__getitem__, indexes)
 
+        # Start the pipeline with the raw/initial data.
+        # `idx` here is the index of the data, not the position of the step in the list of steps.
+        # "Initial" just means untransformed by the pipeline.
         sample, step_index = self._get_initial_sample(idx)
-        LOG.debug(f"Call pipeline __getitem__ for {idx = }")
+        unmodified_sample = sample
+        LOG.debug(f"Call pipeline __getitem__ for {idx}")
 
+        # Apply each pipeline step to the sample
         for step in self.steps[step_index + 1 :]:
             if not isinstance(step, (Pipeline, PipelineStep, Transform, TransformCollection)):
                 raise TypeError(f"When iterating through pipeline steps, found a {type(step)} which cannot be parsed.")
@@ -504,6 +525,8 @@ def __getitem__(self, idx: Any):
                 sample = step.apply(sample)
             else:
                 sample = step(sample)  # type: ignore
+
+        # We've done all the pipeline steps, return the value
         return sample
 
     def __call__(self, obj):
diff --git a/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py b/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py
index 42ba7022..01dd0317 100644
--- a/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py
+++ b/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py
@@ -189,30 +189,49 @@ def trim(s):
 
         # FIXME this is just a debugging workaround
         self._merge_kwargs.pop("axis")
+        # import pudb; pudb.set_trace()
         result = merge_function(sample, **self._merge_kwargs)
+
         return result
 
     def _get_tuple(self, idx, mod: tuple[Any, ...], layer: int) -> Union[tuple[Any], Any]:
         """
-        Collect all elements from tuple of modification
+        Given a sequence of modifications (e.g. temporal retrievals), go through
+        and unpack them, then ultimately call the parent pipeline's getitem method against
+        each modification request. Will descend through nested tuples.
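+
+        For example (illustrative values), a modification of ("-6 hours", "-12 hours")
+        retrieves samples at `idx - 6h` and `idx - 12h`, merging them if configured to.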
""" super_get = self.parent_pipeline().__getitem__ + query_interface = self.parallel_interface # Dynamically fetches via property - samples = [] + queries = [] + if len(mod) > 1: + import pudb; pudb.set_trace() + + # If using the parallel interface, this will prepare the queries + # If using the serial interface, the actual results will be fetched up-front + # and stored on the queries with a ResultCache for m in mod: if isinstance(m, tuple): - samples.append(self.parallel_interface.submit(self._get_tuple, idx, m, layer + 1)) + # Tuple-unpacking pathway + query = query_interface.submit(self._get_tuple, idx, m, layer + 1) else: - samples.append(self.parallel_interface.submit(super_get, idx + m)) + # Retrieve from pipeline pathway + # THIS MAY BE WHERE THE BUG IS HAPPENING + query = query_interface.submit(super_get, idx + m) + queries.append(query) - samples = tuple(self.parallel_interface.collect(samples)) + # Retrieve the results of the queries, either in parallel or in series + samples = tuple(query_interface.collect(queries)) - # def trim(s): - # if isinstance(s, tuple) and len(s) == 1: - # return s[0] - # return s + # If the whole thing just fetched one sample, return it without any merging + if isinstance(samples, tuple): + if len(samples) == 1: + return samples[0] if layer >= self._merge: return self._run_merge(samples) @@ -220,10 +236,15 @@ def _get_tuple(self, idx, mod: tuple[Any, ...], layer: int) -> Union[tuple[Any], def __getitem__(self, idx: Any): + # If we do not have any unpacking to do, get the upstream sample + # and apply our modification if not isinstance(self._modification, tuple): - return self.parent_pipeline()[idx + self._modification] + result = self.parent_pipeline()[idx + self._modification] + return result - return self._get_tuple(idx, self._modification, 0) + # Do unpacking of the tuple + result = self._get_tuple(idx, self._modification, 0) + return result class TimeIdxModifier(IdxModifier): @@ -458,15 +479,20 @@ def __init__( def map_to_tuple(mod): if isinstance(mod, tuple): return tuple(map(map_to_tuple, mod)) + return pyearthtools.data.TimeDelta((mod, delta_unit)) if delta_unit is not None: self._modification = map_to_tuple(self._modification) def __getitem__(self, idx: Any): + + # Try to convert the index to a datetime for temporal retrievals if not isinstance(idx, pyearthtools.data.Petdt): if not pyearthtools.data.Petdt.is_time(idx): raise TypeError(f"Cannot convert {idx!r} to `pyearthtools.data.Petdt`.") idx = pyearthtools.data.Petdt(idx) - return super().__getitem__(idx) + # Fetch the item using a Petdt index + result = super().__getitem__(idx) + return result diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_recode_calendar.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_recode_calendar.py index c37f83bb..20b98ffe 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_recode_calendar.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_recode_calendar.py @@ -55,7 +55,7 @@ def apply_func(self, data: xr.Dataset) -> xr.Dataset: Sorted dataset """ - recoded = data.indexes["time"].to_datetimeindex() + recoded = data.indexes["time"].to_datetimeindex(time_unit='us') data["time"] = recoded return data diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py index 6a527438..b0b82c54 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py +++ 
@@ -72,6 +72,7 @@ def __init__(
         self.interpolation_method = interpolation_method
         self.time_dimension = time_dimension
         self._merge_kwargs = merge_kwargs
+        import pudb; pudb.set_trace()
 
     def _join_two_datasets(self, sample_a: xr.Dataset, sample_b: xr.Dataset) -> xr.Dataset:
         """
@@ -84,15 +85,28 @@ def _join_two_datasets(self, sample_a: xr.Dataset, sample_b: xr.Dataset) -> xr.Dataset:
         if self.time_dimension not in sample_b.coords:
             raise ValueError(f"Time dimension missing from {str(sample_b)}")
+
+        # We need to make interp_like ignore the time dimension
+
+        if sample_a is self.reference_dataset:
+            interped_a = sample_a
+        else:
+            interped_a = sample_a.interp_like(self.reference_dataset, method="nearest")
+
+        if sample_b is self.reference_dataset:
+            interped_b = sample_b
+        else:
+            interped_b = sample_b.interp_like(self.reference_dataset, method="nearest")
 
-        interped_a = sample_a.interp_like(self.reference_dataset, method="nearest")
-        interped_b = sample_b.interp_like(self.reference_dataset, method="nearest")
         merged = xr.merge([interped_a, interped_b])
+        # import pudb; pudb.set_trace()
 
         return merged
 
     def join(self, sample: tuple[Union[xr.Dataset, xr.DataArray], ...]) -> xr.Dataset:
         """Join sample"""
+        import pudb; pudb.set_trace()
+
+        # Obtain the reference dataset
         if self.reference_dataset is None:
             if self.reference_index is not None:
diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/normalisation.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/normalisation.py
index e7725f11..4a96489e 100644
--- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/normalisation.py
+++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/normalisation.py
@@ -93,9 +93,6 @@ def __init__(self, cache_dir=".", samples_needed=20):
         super().__init__()
         self.record_initialisation()
         self.vars = {}
-        # import random
-        # myid = random.randint(0, 20)
-        # print(f"Initialising {myid}")
 
         self.means_filename = os.path.join(cache_dir, "magic_means.nc")
         self.deviation_filename = os.path.join(cache_dir, "magic_std.nc")
@@ -106,7 +103,6 @@ def __init__(self, cache_dir=".", samples_needed=20):
         self.deviation = None
 
         if os.path.exists(self.means_filename):
-            # print(f"Found file for {myid})")
             self.mean = xr.load_dataset(self.means_filename)
             self.deviation = xr.load_dataset(self.deviation_filename)
             self.samples_needed = 0
diff --git a/packages/pipeline/src/pyearthtools/pipeline/parallel.py b/packages/pipeline/src/pyearthtools/pipeline/parallel.py
index b5fa2ced..00ec9899 100644
--- a/packages/pipeline/src/pyearthtools/pipeline/parallel.py
+++ b/packages/pipeline/src/pyearthtools/pipeline/parallel.py
@@ -67,14 +67,24 @@ def __repr__(self):
 
 PARALLEL_INTERFACES = Literal["Futures", "Delayed", "Serial"]
 
 
-class FutureFaker:
+class ResultCache:
+    '''
+    A small in-memory class which can be used to substitute for
+    anything which calculates a result.
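+
+    Illustrative usage (the stored object is returned unchanged, mirroring
+    the `result()` method of a real future):
+
+        >>> ResultCache(42).result()
+        42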
+    '''
+
     def __init__(self, obj):
         self._obj = obj
 
     def result(self, *args):
         return self._obj
 
-
 class ParallelInterface:
     """
     Interface for parallel computation.
@@ -125,13 +135,14 @@ def config(self):
         return self._interface_kwargs.get("Serial", {})
 
     def submit(self, func, *args, **kwargs):
-        return FutureFaker(func(*args, **kwargs))
+        result = func(*args, **kwargs)
+        return ResultCache(result)
 
     def map(self, func, iterables, *iter, **kwargs) -> Future:
-        return tuple(map(lambda i: FutureFaker(func(i, **kwargs)), iterables, *iter))  # type: ignore
+        return tuple(map(lambda i: ResultCache(func(i, **kwargs)), iterables, *iter))  # type: ignore
 
     def gather(self, futures, *args, **kwargs):
-        if isinstance(futures, FutureFaker):
+        if isinstance(futures, ResultCache):
             return futures.result()
         return type(futures)(map(lambda x: x.result(), futures))
 
@@ -139,7 +150,7 @@ def wait(self, futures, **kwargs):
         return futures
 
     def collect(self, futures):
-        if isinstance(futures, FutureFaker):
+        if isinstance(futures, ResultCache):
             return futures.result()
         return type(futures)(map(lambda x: x.result(), futures))
 
@@ -264,13 +275,13 @@ def run_delayed(self, func, *args, **kwargs):
         return delayed(func, name=name, pure=pure)(*args, **kwargs)
 
     def submit(self, func, *args, **kwargs):
-        return FutureFaker(self.run_delayed(func, *args, **kwargs))
+        return ResultCache(self.run_delayed(func, *args, **kwargs))
 
     def map(self, func, iterables, *iter, **kwargs) -> Future:
-        return tuple(map(lambda i: FutureFaker(self.run_delayed(func, i, **kwargs)), iterables, *iter))  # type: ignore
+        return tuple(map(lambda i: ResultCache(self.run_delayed(func, i, **kwargs)), iterables, *iter))  # type: ignore
 
     def gather(self, futures):
-        if isinstance(futures, FutureFaker):
+        if isinstance(futures, ResultCache):
             return futures.result()
         return type(futures)(map(lambda x: x.result(), futures))
 
@@ -278,7 +289,7 @@ def wait(self, futures):
         return futures
 
     def collect(self, futures):
-        if isinstance(futures, FutureFaker):
+        if isinstance(futures, ResultCache):
             return futures.result()
         return type(futures)(map(lambda x: x.result(), futures))
 
diff --git a/packages/pipeline/src/pyearthtools/pipeline/step.py b/packages/pipeline/src/pyearthtools/pipeline/step.py
index 64c0f5a8..7aa62872 100644
--- a/packages/pipeline/src/pyearthtools/pipeline/step.py
+++ b/packages/pipeline/src/pyearthtools/pipeline/step.py
@@ -91,6 +91,9 @@ def _split_tuples_call(
     ):
         """
         Split `sample` if it is a tuple and apply `_function` of `self` to each.
+        This is mainly for the purposes of allowing multiprocessing of pipelines,
+        which may be relevant if a merged pipeline is drawing from two sources which
+        could each be doing data processing in parallel for each time step.
         """
 
         if allow_parallel:
@@ -115,7 +118,9 @@ def _split_tuples_call(
             )
             return tuple(parallel_interface.collect(parallel_interface.map(func, sample)))
 
-        return parallel_interface.collect(parallel_interface.submit(func, sample))
+        query = parallel_interface.submit(func, sample)
+        result = parallel_interface.collect(query)
+        return result
 
     def check_type(
         self,
diff --git a/packages/training/src/pyearthtools/training/data/lightning/datasets.py b/packages/training/src/pyearthtools/training/data/lightning/datasets.py
index d0ee7998..87d5c02c 100644
--- a/packages/training/src/pyearthtools/training/data/lightning/datasets.py
+++ b/packages/training/src/pyearthtools/training/data/lightning/datasets.py
@@ -87,4 +87,10 @@ def __len__(self):
         return len(self._pipeline.iteration_order)
 
     def __getitem__(self, idx):
-        return self._pipeline[self._pipeline.iteration_order[idx]]
+
+        try:
+            return self._pipeline[self._pipeline.iteration_order[idx]]
+        except Exception:
+            print(idx)
+            print(self._pipeline.iteration_order[idx])
+            raise

From 26e1339b79913925385901327ca9b67d0a383ecb Mon Sep 17 00:00:00 2001
From: Tennessee Leeuwenburg <tennessee.leeuwenburg@bom.gov.au>
Date: Sat, 7 Jun 2025 21:29:35 +1000
Subject: [PATCH 2/2] Further WIP on improvement to temporal retrieval with
 the need for aggregation

---
 docs/newproject.md                            |   33 +
 .../tutorial/Working_with_Climate_Data.ipynb  | 1356 ++++++++++++++++-
 packages/data/src/pyearthtools/data/time.py   |    4 -
 .../modifications/idx_modification.py         |    3 -
 .../pipeline/operations/xarray/join.py        |   15 +-
 5 files changed, 1317 insertions(+), 94 deletions(-)

diff --git a/docs/newproject.md b/docs/newproject.md
index 5ffe34d9..7aa9e69f 100644
--- a/docs/newproject.md
+++ b/docs/newproject.md
@@ -1,5 +1,38 @@
 # New Project Guide
 
+## A Quick Hands-On Approach
+
+This guide is suitable for scientists or anyone else who wants to start trying things quickly to establish their first model and make a first attempt. More detail on the nuances of, and alternatives to, each step is provided below.
+
+1. Use [https://pyearthtools.readthedocs.io/en/latest/notebooks/tutorial/FourCastMini_Demo.html](https://pyearthtools.readthedocs.io/en/latest/notebooks/tutorial/FourCastMini_Demo.html) as a template for what to do.
+2. Determine the parameters you want to model, such as `temperature` or `wind`. When these become part of the neural network, they will be called *channels*.
+3. Determine the data source they come from, such as ERA5 or another model or re-analysis source.
+4. Develop a `pipeline` which includes data normalisation.
+5. Using a bundled model, configure that model to the size required. This may only require adjusting `img_size`, `in_channels` and `out_channels` to match the size of your data. The grid dimension must be a multiple of four for this model, so you may need to crop or regrid your data to match. In future, a standard approach without this limitation will be added.
+6. Run some number of training steps (using the `.fit` method) and visualise the outputs, as in the sketch below. Visualising predictions from the trained model every 3000 steps or so provides useful insight into the training process, as well as helping to see when the model might be fully trained. *There is no definite answer to how much training will be required. If your model isn't showing any progress at all after a couple of epochs, there may be a problem. Some models will start to show progress after 3000 steps.*
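+
+As a concrete illustration of the last two steps, here is a minimal sketch. It assumes a bundled FourCastMini-style model and a Lightning-style trainer, as in the demo notebook linked above; the class names and arguments are illustrative and should be checked against the current API.
+
+```python
+# Hypothetical names, following the FourCastMini demo: adjust to your setup.
+# The pipeline is assumed to yield normalised arrays on a 240 x 360 grid
+# (both grid dimensions are multiples of four, as this model requires).
+model = FourCastMini(
+    img_size=(240, 360),  # (lat, lon) grid size of your data
+    in_channels=2,        # e.g. temperature and wind
+    out_channels=2,       # predict the same channels at the next time step
+)
+trainer.fit(model, train_dataloader)  # visualise predictions every ~3000 steps
+```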
+
+This approach should be a usable starting point for any gridded inputs and outputs. The example is based on global modelling, but could reasonably be applied to nowcasting, observational data, limited area modelling, or just about anything you can represent in an xarray on a grid. You could even add a grid containing data from a weather station at each grid point and see what happens.
+
+Getting a neural network to perform well and make optimal predictions is very hard, with many nuances. Getting started, though, should be reasonably simple.
+
+The sections below go into more detail on how to treat source data, how to develop the most suitable pipeline for your project, how to use alternative neural network architectures, how to manage the training process, and how to perform a more thorough evaluation of the outputs.
+
+## Methodological Information
+
 This guide offers a simple, repeatable process for undertaking a machine learning project. Experts in machine learning will recognise this as a standard approach, but of course it can be adapted as required in the project. Completing a project (whether using PyEarthTools or not) comprises the following steps:
 
 1. Identify the sources of data that you wish to work with
diff --git a/notebooks/tutorial/Working_with_Climate_Data.ipynb b/notebooks/tutorial/Working_with_Climate_Data.ipynb
index f448862f..f50d4a70 100644
--- a/notebooks/tutorial/Working_with_Climate_Data.ipynb
+++ b/notebooks/tutorial/Working_with_Climate_Data.ipynb
@@ -461,8 +461,8 @@
     "     parent_experiment: historical\n",
     "     modeling_realm: atmos\n",
     "     realization: 1\n",
-    "     cmor_version: 2.5.6