19 changes: 19 additions & 0 deletions docs/newproject.md
@@ -1,5 +1,24 @@
# New Project Guide

## A Quick Hands-On Approach

This guide is suitable for scientists, or anyone else, who wants to start experimenting quickly and establish a first model. More detail on the nuances and alternatives for each step is provided below.

1. Use [https://pyearthtools.readthedocs.io/en/latest/notebooks/tutorial/FourCastMini_Demo.html](https://pyearthtools.readthedocs.io/en/latest/notebooks/tutorial/FourCastMini_Demo.html) as a template for what to do.
2. Determine the parameters you want to model, such as `temperature` or `wind`. When these become part of the neural network, they will be called *channels*.
3. Determine the data source they come from, such as ERA5 or another model or reanalysis source.
4. Develop a `pipeline` which includes data normalisation.
5. Using a bundled model, configure that model to the size required. This may only require adjusting `img_size`, `in_channels` and `out_channels` to match the size of your data (see the sketch after this list). The grid dimensions must be multiples of four for this model, so you may need to crop or regrid your data to match. In future, a standard approach without this limitation will be added.
6. Run some number of training steps (using the `.fit` method) and visualise the outputs. Visualising predictions from the trained model every 3000 steps or so provides useful insight into the training process, as well as helping you see when the model might be fully trained. *There is no definite answer to how much training will be required. If your model isn't showing any progress at all after a couple of epochs, there may be a problem. Some models will start to show progress after 3000 steps.*
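
As a rough sketch of steps 5 and 6, the configuration might look like the code below. This is illustrative only: the class name `FourCastMini`, its import path, and the exact `fit` call are assumptions based on the linked demo notebook rather than a confirmed API, so treat the notebook itself as authoritative.

```python
# Hypothetical sketch of steps 5 and 6. The import path, class name and
# fit() signature are assumed from the FourCastMini demo, not a verified API.
from pyearthtools.models import FourCastMini

model = FourCastMini(
    img_size=(720, 1440),  # (lat, lon) grid points; each must be a multiple of four
    in_channels=2,         # e.g. temperature and wind as the input channels
    out_channels=2,        # predict the same two channels
)

# `pipeline` is the normalising data pipeline built in step 4.
model.fit(pipeline)
```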

This approach should be a usable starting point for any gridded inputs and outputs. The example is based on global modelling, but could reasonably be applied to nowcasting, observational data, limited-area modelling, or anything else you can represent in an xarray on a grid. You could even build a grid containing data from a weather station at each grid point and see what happens.
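
As an illustration of "anything you can represent in an xarray on a grid", the sketch below builds a toy gridded temperature field using only `numpy` and `xarray`; the values are synthetic stand-ins for real station or reanalysis data.

```python
import numpy as np
import xarray as xr

# Synthetic hourly temperatures on an 8 x 12 lat/lon grid; any data arranged
# on a (time, lat, lon) grid like this can feed the same workflow.
temperature = xr.DataArray(
    np.random.default_rng(0).normal(15.0, 3.0, size=(24, 8, 12)),
    dims=("time", "lat", "lon"),
    coords={
        "time": np.arange("2024-01-01", "2024-01-02", dtype="datetime64[h]"),
        "lat": np.linspace(-30.0, -16.0, 8),
        "lon": np.linspace(140.0, 162.0, 12),
    },
    name="temperature",
)
dataset = temperature.to_dataset()
```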

Getting a neural network to perform well and make optimal predictions is very hard, with many nuances. Getting started should be reasonably simple.

The sections below go into more detail on how to treat source data, how to develop the most suitable pipeline for your project, how to use alternative neural network architectures, how to manage the training process, and how to perform a more thorough evaluation of the outputs.

## Methodological Information

This guide offers a simple, repeatable process for undertaking a machine learning project. Experts in machine learning will recognise this as a standard approach, but it can of course be adapted as required. Completing a project (whether using PyEarthTools or not) comprises the following steps:

1. Identify the sources of data that you wish to work with
1,356 changes: 1,277 additions & 79 deletions notebooks/tutorial/Working_with_Climate_Data.ipynb

Large diffs are not rendered by default.

21 changes: 17 additions & 4 deletions packages/data/src/pyearthtools/data/indexes/indexes.py
@@ -37,6 +37,8 @@
from pathlib import Path
from typing import Any, Callable, Iterable, Literal, Optional

import pandas as pd
import cftime
import xarray as xr

import pyearthtools.data
@@ -482,15 +484,26 @@ def retrieve(
if time_dim not in data.dims and time_dim in data.coords:
    data = data.expand_dims(time_dim)

time_query = str(Petdt(querytime))
if isinstance(data.coords[time_dim].values[0], cftime.datetime):
    time_query = cftime.datetime(querytime.year,
                                 querytime.month,
                                 querytime.day,
                                 calendar='noleap',
                                 has_year_zero=True)
    self._round = True
    round = True
    # time_query = pd.to_datetime(time_query)

if select and time_dim in data:
    try:
        data = data.sel(
            **{time_dim: str(Petdt(querytime))},
            **{time_dim: time_query},
            method="nearest" if round else None,
        )
    except KeyError:
        warnings.warn(
            f"Could not find time in dataset to select on. {querytime!r}",
            f"Could not find time in dataset to select on. {time_query!r}",
            IndexWarning,
        )
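
For context on the cftime branch above: a plain string or pandas timestamp does not always match a `noleap` time axis, which is why the query is rebuilt as a `cftime.datetime`. A minimal, self-contained illustration with made-up data (only `xarray` and `cftime` needed):

```python
import cftime
import xarray as xr

# A tiny dataset on a noleap calendar, similar to CMIP-style model output.
times = xr.cftime_range("2000-01-01", periods=3, freq="D", calendar="noleap")
ds = xr.Dataset({"tas": ("time", [280.0, 281.5, 279.8])}, coords={"time": times})

# Querying with a matching cftime timestamp selects cleanly, including
# nearest-neighbour lookups such as the one in `retrieve` above.
query = cftime.datetime(2000, 1, 2, calendar="noleap")
print(ds.sel(time=query, method="nearest").tas.item())  # 281.5
```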

@@ -535,14 +548,14 @@ def series(
        Loaded series of data
    """

    interval = self._get_interval(interval)
    _interval = self._get_interval(interval)
    tolerance = kwargs.pop("tolerance", getattr(self, "data_interval", None))

    return index_routines.series(
        self,
        start,
        end,
        interval,
        _interval,
        transforms=transforms,
        tolerance=tolerance,
        **kwargs,
2 changes: 0 additions & 2 deletions packages/data/src/pyearthtools/data/save/dataset.py
@@ -209,6 +209,4 @@ def to_zarr(
save_kwargs.pop("append_dim", None)
save_kwargs.pop("region", None)

# print(save_kwargs)
# print(dataset)
dataset.to_zarr(zarr_file, **save_kwargs)
6 changes: 5 additions & 1 deletion packages/data/src/pyearthtools/data/time.py
@@ -461,7 +461,9 @@ def __add__(self, other: Petdt | TimeDelta | int) -> Petdt:

    resolution = TimeResolution("year")
    if isinstance(other, _MonthTimeDelta):
        return NotImplemented
    if isinstance(other, int):
        raise NotImplementedError


    if isinstance(other, int):
        if other < 0:
@@ -658,6 +660,8 @@ def __init__(self, timedelta: Any, *args) -> None:
        timedelta = (1, timedelta)
    elif len(multisplit(timedelta, [" ", ",", "-"])) == 2:
        timedelta = tuple(x.strip() for x in multisplit(timedelta, [" ", ",", "-"]))
    else:
        raise ValueError(f"Unrecognised time interval {timedelta} not in {RESOLUTION_COMPONENTS}")

    if isinstance(timedelta, (list, tuple)) and len(timedelta) == 2:
        if isinstance(timedelta[1], str) and timedelta[1].strip().removesuffix("s") in RESOLUTION_COMPONENTS:
5 changes: 5 additions & 0 deletions packages/models/README.md
@@ -0,0 +1,5 @@
# PyEarthTools Models

This is the models sub-package which forms a part of the [PyEarthTools package](https://github.com/ACCESS-Community-Hub/PyEarthTools).

Documentation for the PyEarthTools package is available [here](https://pyearthtools.readthedocs.io/en/latest/).
57 changes: 57 additions & 0 deletions packages/models/pyproject.toml
@@ -0,0 +1,57 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"


[project]
name = "pyearthtools-models"
[Inline review comment from the collaborator author]: This isn't relevant to this PR and was an accidental include

description = "Blueprint model implementations which can be used in PyEarthTools"
requires-python = ">=3.9"
keywords = ["pyearthtools"]
maintainers = [
    {name = "Tennessee Leeuwenburg", email = "tennessee.leeuwenburg@bom.gov.au"}
]
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
]
dependencies = [
    "xarray[complete]",
    "geopandas",
    "shapely",
    "tqdm",
    "pyyaml",
]
dynamic = ["version", "readme"]

[tool.setuptools.dynamic]
readme = {file = ["README.md"], content-type = "text/markdown"}

all = [
    "pyearthtools-models",
]

[project.urls]
homepage = "https://pyearthtools.readthedocs.io/"
documentation = "https://pyearthtools.readthedocs.io/"
repository = "https://github.com/ACCESS-Community-Hub/PyEarthTools"

[tool.isort]
profile = "black"

[tool.black]
line-length = 120

[tool.ruff]
line-length = 120

[tool.mypy]
warn_return_any = true
warn_unused_configs = true

[tool.hatch.version]
path = "src/pyearthtools/models/__init__.py"

[tool.hatch.build.targets.wheel]
packages = ["src/pyearthtools/"]
12 changes: 10 additions & 2 deletions packages/nci_site_archive/src/site_archive_nci/_CMIP5.py
@@ -163,8 +163,13 @@ def __init__(

warnings.warn(UNDER_DEV_MSG)

base_transforms = TransformCollection()
base_transforms += pyearthtools.data.transforms.variables.variable_trim(variables)
base_transforms += pyearthtools.data.transforms.coordinates.Drop(coordinates='height')

super().__init__(
    transforms=TransformCollection(),
    transforms=base_transforms,
    data_interval=(1, "month")
)

self.record_initialisation()
@@ -208,6 +213,9 @@ def quick_walk(self):

    self.walk_cache = walk_cache




def filesystem(self, query_dictionary={}):
    """
    Given the supplied query, return all filenames which contain the data necessary to extract the data
@@ -248,7 +256,7 @@ def match_path(self, path, query):
if path["model"] not in self.models:
match = False

if path["interval"] not in self.interval:
if path["interval"] not in self.interval[0]:
match = False

if path["scenario"] not in self.scenarios:
14 changes: 10 additions & 4 deletions packages/pipeline/src/pyearthtools/pipeline/branching/branching.py
@@ -127,12 +127,18 @@ def __init__(
    self.sub_pipelines = list(map(lambda x: Pipeline(*x), __incoming_steps))

def __getitem__(self, idx: Any) -> tuple:
    """Get result from each branch"""
    results = []
    """
    Go through each upstream sub-pipeline, fetch their results, and return them as a tuple
    """
    queries = []
    query_interface = self.parallel_interface  # This is a property and might be a serial interface

    for pipe in self.sub_pipelines:
        results.append(self.parallel_interface.submit(pipe.__getitem__, idx))
        query = query_interface.submit(pipe.__getitem__, idx)
        queries.append(query)

    return tuple(self.parallel_interface.collect(results))
    samples = query_interface.collect(queries)
    return tuple(samples)

def apply(self, sample):
    """Apply each branch on the sample"""
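
The submit-then-collect shape of the loop above mirrors the standard futures idiom. As a generic sketch using only the standard library (where `fetch` stands in for `pipe.__getitem__`; this is not PyEarthTools' actual interface):

```python
from concurrent.futures import ThreadPoolExecutor

def fetch(i: int) -> int:
    return i * i  # stand-in for pipe.__getitem__(idx)

with ThreadPoolExecutor() as pool:
    queries = [pool.submit(fetch, i) for i in range(4)]   # submit phase
    samples = tuple(f.result() for f in queries)          # collect phase

print(samples)  # (0, 1, 4, 9)
```
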
29 changes: 24 additions & 5 deletions packages/pipeline/src/pyearthtools/pipeline/controller.py
@@ -460,7 +460,9 @@ def has_source(self) -> bool:

def _get_initial_sample(self, idx: Any) -> tuple[Any, int]:
    """
    Get sample from first pipeline step or first working back index
    Get a data sample from the first pipeline step, or from an intermediate generator:
    e.g. a source accessor, an intermediate cache, or a temporal retrieval modifier

    Returns:
        (tuple[Any, int]):
@@ -473,25 +475,40 @@ def _get_initial_sample(self, idx: Any) -> tuple[Any, int]:
    for index, step in enumerate(self.steps[::-1]):
        if isinstance(step, PipelineIndex):
            LOG.debug(f"Getting initial sample from {step} at {idx}")
            return step[idx], len(self.steps) - (index + 1)
            sample = step[idx]
            whereinthesequence = len(self.steps) - (index + 1)
            return sample, whereinthesequence

    # Confirm that the start of the pipeline is an accessor, and then fetch from it
    if isinstance(self.steps[0], (_Pipeline, Index)):
        LOG.debug(f"Getting initial sample from {self.steps[0]} at {idx}")
        return self.steps[0][idx], 0
        sample = self.steps[0][idx]
        whereinthesequence = 0
        return sample, whereinthesequence

    # If we haven't returned something by now, that's an error condition, raise it.
    raise TypeError(f"Cannot find an `Index` to get data from. Found {type(self.steps[0]).__qualname__}")

def __getitem__(self, idx: Any):
    """Retrieve from pipeline at `idx`"""
    """
    Retrieve from pipeline at `idx`

    Called by users when accessing the pipeline
    Also called by modifications (such as temporal retrieval)
    """
    if isinstance(idx, slice):
        indexes = self.iterator[idx]
        LOG.debug(f"Call pipeline __getitem__ for {indexes = }")
        return map(self.__getitem__, indexes)

    # Start the pipeline with the raw/initial data
    # Index here is the index of the data, not the position of the step in the list of steps
    # Initial just means untransformed by the pipeline
    sample, step_index = self._get_initial_sample(idx)
    LOG.debug(f"Call pipeline __getitem__ for {idx = }")
    unmodified_sample = sample
    LOG.debug(f"Call pipeline __getitem__ for {idx}")

    # Apply each pipeline step to the sample
    for step in self.steps[step_index + 1 :]:
        if not isinstance(step, (Pipeline, PipelineStep, Transform, TransformCollection)):
            raise TypeError(f"When iterating through pipeline steps, found a {type(step)} which cannot be parsed.")
@@ -504,6 +521,8 @@ def __getitem__(self, idx: Any):
            sample = step.apply(sample)
        else:
            sample = step(sample)  # type: ignore

    # We've done all the pipeline steps, return the value
    return sample

def __call__(self, obj):
@@ -190,40 +190,58 @@ def trim(s):
        self._merge_kwargs.pop("axis")

    result = merge_function(sample, **self._merge_kwargs)

    return result

def _get_tuple(self, idx, mod: tuple[Any, ...], layer: int) -> Union[tuple[Any], Any]:
    """
    Collect all elements from tuple of modification
    Given a sequence of modifications (e.g. temporal retrievals), go through
    and unpack them, then ultimately call the parent pipeline's getitem method against
    each modification request.

    Will descend through nested tuples.
    """
    super_get = self.parent_pipeline().__getitem__
    query_interface = self.parallel_interface  # Dynamically fetches via property

    samples = []
    queries = []

    # If using the parallel interface, this will prepare the queries
    # If using the serial interface, the actual results will be fetched up-front
    # and stored on the queries with a ResultCache
    for m in mod:
        if isinstance(m, tuple):
            samples.append(self.parallel_interface.submit(self._get_tuple, idx, m, layer + 1))
            # Tuple-unpacking pathway
            query = query_interface.submit(self._get_tuple, idx, m, layer + 1)
        else:
            samples.append(self.parallel_interface.submit(super_get, idx + m))
            # Retrieve from pipeline pathway
            # THIS MAY BE WHERE THE BUG IS HAPPENING
            query = query_interface.submit(super_get, idx + m)
        queries.append(query)

    samples = tuple(self.parallel_interface.collect(samples))
    # Retrieve the results of the queries, either in parallel or in series
    samples = tuple(query_interface.collect(queries))

    # def trim(s):
    #     if isinstance(s, tuple) and len(s) == 1:
    #         return s[0]
    #     return s
    # If the whole thing just fetched one sample, return it without any merging
    if isinstance(samples, tuple):
        if len(samples) == 1:
            return samples[0]

    if layer >= self._merge:
        return self._run_merge(samples)
    return samples

def __getitem__(self, idx: Any):

    # If we do not have any unpacking to do, get the upstream sample
    # and apply our modification
    if not isinstance(self._modification, tuple):
        return self.parent_pipeline()[idx + self._modification]
        result = self.parent_pipeline()[idx + self._modification]
        return result

    return self._get_tuple(idx, self._modification, 0)
    # Do unpacking of the tuple
    result = self._get_tuple(idx, self._modification, 0)
    return result


class TimeIdxModifier(IdxModifier):
@@ -458,15 +476,20 @@ def __init__(
    def map_to_tuple(mod):
        if isinstance(mod, tuple):
            return tuple(map(map_to_tuple, mod))

        return pyearthtools.data.TimeDelta((mod, delta_unit))

    if delta_unit is not None:
        self._modification = map_to_tuple(self._modification)

def __getitem__(self, idx: Any):

    # Try to convert the index to a datetime for temporal retrievals
    if not isinstance(idx, pyearthtools.data.Petdt):
        if not pyearthtools.data.Petdt.is_time(idx):
            raise TypeError(f"Cannot convert {idx!r} to `pyearthtools.data.Petdt`.")
        idx = pyearthtools.data.Petdt(idx)

    return super().__getitem__(idx)
    # Fetch the item using a Petdt index
    result = super().__getitem__(idx)
    return result
@@ -55,7 +55,7 @@ def apply_func(self, data: xr.Dataset) -> xr.Dataset:
        Sorted dataset
    """

    recoded = data.indexes["time"].to_datetimeindex()
    recoded = data.indexes["time"].to_datetimeindex(time_unit='us')
    data["time"] = recoded

    return data
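
For reference, `to_datetimeindex` converts a `CFTimeIndex` to a pandas `DatetimeIndex`, and the `time_unit` argument (available in recent xarray releases) pins the precision of the result. A standalone sketch with made-up times:

```python
import xarray as xr

# Build a small CFTimeIndex on a real-world calendar, then convert it.
times = xr.cftime_range("2000-01-01", periods=3, freq="D", calendar="standard")
ds = xr.Dataset(coords={"time": times})

# Requires an xarray version whose to_datetimeindex accepts `time_unit`.
recoded = ds.indexes["time"].to_datetimeindex(time_unit="us")
print(recoded)
```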