Skip to content

Commit 5095862

Browse files
authored
[AINode]: Integrate toto as a builtin forecasting model (#17322)
1 parent 83addf9 commit 5095862

26 files changed

Lines changed: 3066 additions & 4 deletions

NOTICE

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ grant the users the right to the use of patent under the requirement of Apache 2
1717

1818
============================================================================
1919

20+
This product includes source code derived from the DataDog/toto project:
21+
22+
Toto – Timeseries-Optimized Transformer for Observability
23+
Copyright 2025 Datadog, Inc.
24+
Licensed under the Apache License, Version 2.0
25+
https://github.com/DataDog/toto
26+
27+
============================================================================
28+
2029
Apache Commons Collections
2130
Copyright 2001-2019 The Apache Software Foundation
2231

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ public class AINodeTestUtils {
5858
new AbstractMap.SimpleEntry<>(
5959
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
6060
new AbstractMap.SimpleEntry<>(
61-
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")))
61+
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")),
62+
new AbstractMap.SimpleEntry<>(
63+
"toto", new FakeModelInfo("toto", "toto", "builtin", "active")))
6264
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
6365

6466
public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;

iotdb-core/ainode/build_binary.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,26 +423,28 @@ def verify_poetry_env():
423423
[str(poetry_exe), "lock"],
424424
cwd=str(script_dir),
425425
env=venv_env,
426-
check=True,
426+
check=False,
427427
capture_output=True,
428428
text=True,
429429
)
430430
if result.stdout:
431431
print(result.stdout)
432432
if result.stderr:
433433
print(result.stderr)
434+
if result.returncode != 0:
435+
print(f"ERROR: poetry lock failed with exit code {result.returncode}")
436+
sys.exit(1)
434437
verify_poetry_env() # Verify after lock
435438

436439
accelerator = detect_accelerator()
437440
print(f"Selected accelerator: {accelerator}")
438441

439442
print("Running poetry install...")
440443
subprocess.run(
441-
[str(poetry_exe), "lock"],
444+
[str(poetry_exe), "install", "--no-root"],
442445
cwd=str(script_dir),
443446
env=venv_env,
444447
check=True,
445-
capture_output=True,
446448
text=True,
447449
)
448450
verify_poetry_env() # Verify before install

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,17 @@ def __repr__(self):
160160
},
161161
transformers_registered=True,
162162
),
163+
"toto": ModelInfo(
164+
model_id="toto",
165+
category=ModelCategory.BUILTIN,
166+
state=ModelStates.INACTIVE,
167+
model_type="toto",
168+
pipeline_cls="pipeline_toto.TotoPipeline",
169+
repo_id="Datadog/Toto-Open-Base-1.0",
170+
auto_map={
171+
"AutoConfig": "configuration_toto.TotoConfig",
172+
"AutoModelForCausalLM": "modeling_toto.TotoForPrediction",
173+
},
174+
transformers_registered=True,
175+
),
163176
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
19+
from typing import List, Optional
20+
21+
from transformers import PretrainedConfig
22+
23+
24+
class TotoConfig(PretrainedConfig):
    """
    Configuration class for the Toto time series forecasting model.

    Toto (Time Series Optimized Transformer for Observability) is a foundation model
    for multivariate time series forecasting developed by Datadog. It uses a decoder-only
    architecture with per-variate patch-based causal scaling, proportional time-variate
    factorized attention, and a Student-T mixture prediction head.

    Reference: https://github.com/DataDog/toto
    """

    model_type = "toto"

    def __init__(
        self,
        patch_size: int = 32,
        stride: int = 32,
        embed_dim: int = 1024,
        num_layers: int = 18,
        num_heads: int = 16,
        mlp_hidden_dim: int = 2816,
        dropout: float = 0.0,
        spacewise_every_n_layers: int = 3,
        scaler_cls: str = "per_variate_causal",
        output_distribution_classes: Optional[List[str]] = None,
        output_distribution_kwargs: Optional[dict] = None,
        spacewise_first: bool = True,
        use_memory_efficient_attention: bool = True,
        stabilize_with_global: bool = True,
        scale_factor_exponent: float = 10.0,
        **kwargs,
    ):
        # Patch embedding geometry.
        self.patch_size = patch_size
        self.stride = stride
        # Transformer backbone dimensions.
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_hidden_dim = mlp_hidden_dim
        self.dropout = dropout
        self.spacewise_every_n_layers = spacewise_every_n_layers
        self.scaler_cls = scaler_cls
        # A falsy value (None or empty) falls back to the published default
        # output head: a Student-T mixture.
        if output_distribution_classes:
            self.output_distribution_classes = output_distribution_classes
        else:
            self.output_distribution_classes = ["student_t_mixture"]
        # k_components=5 is the default used by Datadog/Toto-Open-Base-1.0
        if output_distribution_kwargs:
            self.output_distribution_kwargs = output_distribution_kwargs
        else:
            self.output_distribution_kwargs = {"k_components": 5}
        self.spacewise_first = spacewise_first
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

        super().__init__(**kwargs)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# This file includes code derived from DataDog/toto
19+
# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
20+
# Copyright 2025 Datadog, Inc.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# This file includes code derived from DataDog/toto
19+
# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
20+
# Copyright 2025 Datadog, Inc.
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# This file includes code derived from DataDog/toto
19+
# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
20+
# Copyright 2025 Datadog, Inc.
21+
22+
from functools import reduce
23+
from typing import NamedTuple
24+
25+
import numpy as np
26+
import torch
27+
import torch.utils.data
28+
from einops import repeat
29+
from jaxtyping import Bool, Float, Int, Shaped
30+
31+
32+
def pad_array(values: torch.Tensor, patch_stride: int) -> torch.Tensor:
    """
    Left-pad *values* with zeros so the last (time) dimension becomes
    divisible by ``patch_stride``; the original data ends up flush against
    the right edge.

    Accepts a 2-D ``(variates, series_len)`` or 3-D
    ``(batch, variates, series_len)`` tensor; a ``np.ndarray`` input is
    converted to a tensor first.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.
    """
    if isinstance(values, np.ndarray):
        values = torch.from_numpy(values)
    if values.ndim not in (2, 3):
        raise ValueError(f"Unsupported number of dimensions: {values.ndim}")
    series_len = values.shape[-1]
    # Round the time dimension up to the next multiple of patch_stride.
    padded_length = int(np.ceil(series_len / patch_stride) * patch_stride)
    padded = torch.zeros(
        (*values.shape[:-1], padded_length),
        dtype=values.dtype,
        device=values.device,
    )
    # Zeros form the left padding; data occupies the right-most slots.
    padded[..., -series_len:] = values
    return padded
57+
58+
59+
def pad_id_mask(id_mask: torch.Tensor, patch_stride: int) -> torch.Tensor:
    """
    Left-pad *id_mask* so the last (time) dimension becomes divisible by
    ``patch_stride``.

    Unlike :func:`pad_array`, the padding is not zeros: each row's left-edge
    id value is replicated into the padded region, so the padding belongs to
    the same id group as the first real timestep.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.
    """
    if id_mask.ndim not in (2, 3):
        raise ValueError(f"Unsupported number of dimensions: {id_mask.ndim}")
    series_len = id_mask.shape[-1]
    padded_length = int(np.ceil(series_len / patch_stride) * patch_stride)
    padding_amount = padded_length - series_len
    # Broadcast the first timestep across the padding width (equivalent to
    # einops.repeat in the upstream implementation, without a copy).
    padding = id_mask[..., :1].expand(*id_mask.shape[:-1], padding_amount)
    return torch.cat([padding, id_mask], dim=-1)
89+
90+
91+
class MaskedTimeseries(NamedTuple):
    """
    Immutable bundle of a (possibly batched) multivariate time series plus
    its masks and timestamp metadata, shaped ``(*batch, variates, series_len)``
    per the upstream annotations.
    """

    # Observed values, float tensor.
    series: torch.Tensor
    # Boolean mask over the same shape — presumably True marks real
    # (non-padding) entries; confirm against the caller.
    padding_mask: torch.Tensor
    # Integer id mask grouping variates/timesteps.
    id_mask: torch.Tensor
    # Per-point timestamps, integer seconds.
    timestamp_seconds: torch.Tensor
    # Sampling interval per variate, integer seconds.
    time_interval_seconds: torch.Tensor
    num_exogenous_variables: int = 0

    def to(self, device: torch.device) -> "MaskedTimeseries":
        """Return a copy with every tensor field moved to *device*."""
        # The first five fields are tensors; the trailing int is passed through.
        moved = [field.to(device) for field in self[:5]]
        return MaskedTimeseries(*moved, self.num_exogenous_variables)
108+
109+
110+
def is_extreme_value(t: torch.Tensor) -> torch.Tensor:
    """
    Return a boolean tensor flagging entries that are inf, NaN, or whose
    magnitude is at least half the maximum representable value of *t*'s
    dtype (float or integer).
    """
    if torch.is_floating_point(t):
        dtype_info = torch.finfo(t.dtype)
    else:
        dtype_info = torch.iinfo(t.dtype)
    # Values within a factor of two of the dtype limit count as extreme.
    near_limit = t.abs() >= dtype_info.max / 2
    return torch.isinf(t) | torch.isnan(t) | near_limit
124+
125+
126+
def replace_extreme_values(t: torch.Tensor, replacement: float = 0.0) -> torch.Tensor:
    """
    Return a copy of *t* in which extreme entries (per ``is_extreme_value``)
    are substituted with *replacement*, cast to *t*'s dtype and device.
    """
    mask = is_extreme_value(t)
    fill = torch.tensor(replacement, dtype=t.dtype, device=t.device)
    return torch.where(mask, fill, t)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# This file includes code derived from DataDog/toto
19+
# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
20+
# Copyright 2025 Datadog, Inc.

0 commit comments

Comments
 (0)