|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# |
| 18 | +# This file includes code derived from DataDog/toto |
| 19 | +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. |
| 20 | +# Copyright 2025 Datadog, Inc. |
| 21 | + |
| 22 | +from functools import reduce |
| 23 | +from typing import NamedTuple |
| 24 | + |
| 25 | +import numpy as np |
| 26 | +import torch |
| 27 | +import torch.utils.data |
| 28 | +from einops import repeat |
| 29 | +from jaxtyping import Bool, Float, Int, Shaped |
| 30 | + |
| 31 | + |
def pad_array(
    values: "Shaped[torch.Tensor, '*batch variates series_len']",  # noqa: F722
    patch_stride: int,
) -> "Shaped[torch.Tensor, '*batch variates padded_length']":  # noqa: F722
    """
    Left-pad ``values`` with zeros along the last (time) axis so that its
    length becomes divisible by ``patch_stride``.

    Accepts a ``torch.Tensor`` or a ``numpy.ndarray``; an ndarray is
    converted with ``torch.from_numpy`` and therefore shares memory with
    the input array.

    Parameters
    ----------
    values:
        Tensor of any rank >= 1; only the last axis is padded.
        (Generalized from the previous implementation, which supported
        only 2-D and 3-D inputs and raised ``ValueError`` otherwise.)
    patch_stride:
        Divisor the padded series length must satisfy.

    Returns
    -------
    A new tensor whose last-axis length is the smallest multiple of
    ``patch_stride`` that is >= the original length; the original data
    occupies the rightmost positions and the left padding is zero.
    """
    if isinstance(values, np.ndarray):
        values = torch.from_numpy(values)
    series_len = values.shape[-1]
    # Exact integer ceiling division; avoids the float round-trip of
    # np.ceil(series_len / patch_stride), which could misround for very
    # large lengths.
    padded_length = -(-series_len // patch_stride) * patch_stride
    # Build the padded buffer for any leading batch/variate shape instead
    # of special-casing ndim == 2 and ndim == 3.
    padded_values = torch.zeros(
        (*values.shape[:-1], padded_length),
        dtype=values.dtype,
        device=values.device,
    )
    padded_values[..., -series_len:] = values

    return padded_values
| 57 | + |
| 58 | + |
def pad_id_mask(
    id_mask: "Int[torch.Tensor, '*batch variates series_len']",  # noqa: F722
    patch_stride: int,
) -> "Int[torch.Tensor, '*batch variates padded_length']":  # noqa: F722
    """
    Left-pad ``id_mask`` along the last (time) axis so that its length
    becomes divisible by ``patch_stride``, replicating each series'
    left-edge id into the new positions.

    Parameters
    ----------
    id_mask:
        Integer id tensor of any rank >= 1; only the last axis is padded.
        (Generalized from the previous implementation, which supported
        only 2-D and 3-D inputs and raised ``ValueError`` otherwise.)
    patch_stride:
        Divisor the padded series length must satisfy.

    Returns
    -------
    The input tensor unchanged when it is already aligned; otherwise a
    new tensor with the left-edge ids prepended up to the next multiple
    of ``patch_stride``.
    """
    series_len = id_mask.shape[-1]
    # Exact integer ceiling division; avoids a float round-trip.
    padded_length = -(-series_len // patch_stride) * patch_stride
    padding_amount = padded_length - series_len
    if padding_amount == 0:
        # Already aligned; nothing to prepend.
        return id_mask
    # Replicate the left-most id of every series across the padding.
    # expand() creates a zero-copy broadcast view; the data is only
    # materialized by torch.cat.
    padding = id_mask[..., :1].expand(*id_mask.shape[:-1], padding_amount)

    return torch.cat([padding, id_mask], dim=-1)
| 89 | + |
| 90 | + |
class MaskedTimeseries(NamedTuple):
    """
    Immutable bundle of a (possibly batched) multivariate time series and
    its accompanying masks and metadata.
    """

    # Observed values, one row per variate.
    series: Float[torch.Tensor, "*batch variates series_len"]  # noqa: F722
    # Boolean validity mask; presumably True marks real (non-padded)
    # observations — confirm against the consuming model.
    padding_mask: Bool[torch.Tensor, "*batch variates series_len"]  # noqa: F722
    # Integer group ids per position (broadcastable along series_len).
    id_mask: Int[torch.Tensor, "*batch variates #series_len"]  # noqa: F722
    # Per-step timestamps, in seconds.
    timestamp_seconds: Int[torch.Tensor, "*batch variates series_len"]  # noqa: F722
    # Sampling interval per variate, in seconds.
    time_interval_seconds: Int[torch.Tensor, "*batch variates"]  # noqa: F722
    # Count of trailing exogenous variates; kept as a plain int.
    num_exogenous_variables: int = 0

    def to(self, device: torch.device) -> "MaskedTimeseries":
        """Return a copy with every tensor field moved onto ``device``."""
        return self._replace(
            series=self.series.to(device),
            padding_mask=self.padding_mask.to(device),
            id_mask=self.id_mask.to(device),
            timestamp_seconds=self.timestamp_seconds.to(device),
            time_interval_seconds=self.time_interval_seconds.to(device),
        )
| 108 | + |
| 109 | + |
def is_extreme_value(t: torch.Tensor) -> torch.Tensor:
    """
    Return a boolean mask of entries that are inf, NaN, or whose
    magnitude is at least half the maximum representable value of
    ``t``'s dtype.
    """
    info = torch.finfo(t.dtype) if torch.is_floating_point(t) else torch.iinfo(t.dtype)
    near_dtype_limit = t.abs() >= info.max / 2
    return torch.isinf(t) | torch.isnan(t) | near_dtype_limit
| 124 | + |
| 125 | + |
def replace_extreme_values(t: torch.Tensor, replacement: float = 0.0) -> torch.Tensor:
    """
    Return a copy of ``t`` in which every extreme entry (as judged by
    ``is_extreme_value``) is substituted with ``replacement``.
    """
    fill = torch.tensor(replacement, dtype=t.dtype, device=t.device)
    return torch.where(is_extreme_value(t), fill, t)
0 commit comments