diff --git a/LICENSE b/LICENSE index 07a2ad8da4d4e..81000948cc4ef 100644 --- a/LICENSE +++ b/LICENSE @@ -360,3 +360,13 @@ Project page: https://github.com/SalesforceAIResearch/uni2ts License: https://github.com/SalesforceAIResearch/uni2ts/blob/main/LICENSE.txt -------------------------------------------------------------------------------- + +The following files include code modified from MOMENT project. + +./iotdb-core/ainode/iotdb/ainode/core/model/moment/* + +The MOMENT is open source software licensed under the MIT License +Project page: https://github.com/moment-timeseries-foundation-model/moment +License: https://github.com/moment-timeseries-foundation-model/moment/blob/main/LICENSE + +-------------------------------------------------------------------------------- diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index bf758a083d463..d69a301c06680 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -60,7 +60,9 @@ public class AINodeTestUtils { new AbstractMap.SimpleEntry<>( "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "toto", new FakeModelInfo("toto", "toto", "builtin", "active"))) + "toto", new FakeModelInfo("toto", "toto", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "moment", new FakeModelInfo("moment", "moment", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); public static final Map BUILTIN_MODEL_MAP; diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index 4a2ee543d1f8d..9d93b066c8e5d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -53,6 +53,7 @@ AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { "sundial": 1036 * 1024**2, # 1036 MiB "timer_xl": 856 * 1024**2, # 856 MiB + "moment": 200 * 1024**2, # ~200 MiB (MOMENT-1-small) } # the memory usage of each model in bytes AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.2 # the device space allocated for inference diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 642986c42d21f..b610c0f74b49d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -173,4 +173,17 @@ def __repr__(self): }, transformers_registered=True, ), + "moment": ModelInfo( + model_id="moment", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="moment", + pipeline_cls="pipeline_moment.MomentPipeline", + repo_id="AutonLab/MOMENT-1-small", + auto_map={ + "AutoConfig": "configuration_moment.MomentConfig", + "AutoModelForCausalLM": "modeling_moment.MomentForPrediction", + }, + transformers_registered=True, + ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moment/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/moment/__init__.py new file mode 100644 index 0000000000000..2a1e720805f29 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moment/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moment/configuration_moment.py b/iotdb-core/ainode/iotdb/ainode/core/model/moment/configuration_moment.py new file mode 100644 index 0000000000000..4b72f089d9ca2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moment/configuration_moment.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This file contains code adapted from the MOMENT project +# (https://github.com/moment-timeseries-foundation-model/moment), +# originally licensed under the MIT License. + +from transformers import PretrainedConfig + + +class MomentConfig(PretrainedConfig): + """ + Configuration class for the MOMENT time series foundation model. + + MOMENT (A Family of Open Time-series Foundation Models) is developed by + Auton Lab, Carnegie Mellon University. It uses a T5 encoder-only backbone + with patch-based input embedding and RevIN normalization for multi-task + time series analysis including forecasting, classification, anomaly + detection and imputation. + + Reference: https://arxiv.org/abs/2402.03885 + """ + + model_type = "moment" + + def __init__( + self, + seq_len: int = 512, + patch_len: int = 8, + patch_stride_len: int = 8, + d_model: int = 1024, + transformer_backbone: str = "google/flan-t5-large", + forecast_horizon: int = 96, + revin_affine: bool = False, + **kwargs, + ): + self.seq_len = seq_len + self.patch_len = patch_len + self.patch_stride_len = patch_stride_len + self.d_model = d_model + self.transformer_backbone = transformer_backbone + self.forecast_horizon = forecast_horizon + self.revin_affine = revin_affine + + super().__init__(**kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moment/modeling_moment.py b/iotdb-core/ainode/iotdb/ainode/core/model/moment/modeling_moment.py new file mode 100644 index 0000000000000..ec8caed43394f --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moment/modeling_moment.py @@ -0,0 +1,396 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This file contains code adapted from the MOMENT project +# (https://github.com/moment-timeseries-foundation-model/moment), +# originally licensed under the MIT License. + +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PreTrainedModel, T5Config, T5EncoderModel + +from iotdb.ainode.core.log import Logger + +from .configuration_moment import MomentConfig + +logger = Logger() + + +@dataclass +class MomentOutput: + forecast: Optional[torch.Tensor] = None + reconstruction: Optional[torch.Tensor] = None + embeddings: Optional[torch.Tensor] = None + input_mask: Optional[torch.Tensor] = None + + +class RevIN(nn.Module): + """Reversible Instance Normalization for time series.""" + + def __init__(self, n_features: int, affine: bool = False, eps: float = 1e-5): + super().__init__() + self.n_features = n_features + self.affine = affine + self.eps = eps + if self.affine: + self.affine_weight = nn.Parameter(torch.ones(self.n_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.n_features)) + + def forward(self, x: torch.Tensor, mode: str) -> torch.Tensor: + # x: [batch, n_channels, seq_len] + if mode == "norm": + self._get_statistics(x) + x = self._normalize(x) + elif mode == "denorm": + x = self._denormalize(x) + return x + + def _get_statistics(self, x: torch.Tensor): + self.mean = torch.mean(x, dim=-1, keepdim=True).detach() + self.stdev = torch.sqrt( + torch.var(x, dim=-1, keepdim=True, unbiased=False) + self.eps + ).detach() + + def _normalize(self, x: torch.Tensor) -> torch.Tensor: + x = (x - self.mean) / self.stdev + if self.affine: + x = x * self.affine_weight.unsqueeze(0).unsqueeze(-1) + x = x + self.affine_bias.unsqueeze(0).unsqueeze(-1) + return x + + def _denormalize(self, x: torch.Tensor) -> torch.Tensor: + if self.affine: + x = x - self.affine_bias.unsqueeze(0).unsqueeze(-1) + x = x / (self.affine_weight.unsqueeze(0).unsqueeze(-1) + self.eps) + x = x * self.stdev + x = x + self.mean + return x + + +class Patching(nn.Module): + """Unfold a 1-D time series into fixed-size patches.""" + + def __init__(self, patch_len: int = 8, stride: int = 8): + super().__init__() + self.patch_len = patch_len + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [batch, n_channels, seq_len] + # out: [batch, n_channels, n_patches, patch_len] + x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) + return x + + +class PatchEmbedding(nn.Module): + """Linear projection of patches to model dimension.""" + + def __init__(self, d_model: int, patch_len: int): + super().__init__() + self.proj = nn.Linear(patch_len, d_model) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [batch, n_channels, n_patches, patch_len] + # out: [batch, n_channels, n_patches, d_model] + return self.proj(x) + + +class ForecastingHead(nn.Module): + """Linear head that projects flattened patch embeddings to forecast horizon.""" + + def __init__(self, d_model: int, n_patches: int, forecast_horizon: int): + super().__init__() + self.flatten = nn.Flatten(start_dim=-2) + self.proj = nn.Linear(d_model * n_patches, forecast_horizon) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [batch, n_channels, n_patches, d_model] + x = self.flatten(x) # [batch, n_channels, n_patches * d_model] + return self.proj(x) # [batch, n_channels, forecast_horizon] + + +class MomentBackbone(nn.Module): + """ + Core MOMENT architecture. + + Architecture: + Input [batch, n_channels, seq_len] + -> RevIN normalization + -> Patching (unfold into fixed-size patches) + -> Patch embedding (linear projection to d_model) + -> T5 Encoder (self-attention layers) + -> Task-specific head (forecasting: linear projection) + -> RevIN denormalization + -> Output [batch, n_channels, forecast_horizon] + """ + + def __init__(self, config: MomentConfig, t5_config: Optional[dict] = None): + super().__init__() + self.config = config + self.seq_len = config.seq_len + self.patch_len = config.patch_len + self.patch_stride_len = config.patch_stride_len + self.d_model = config.d_model + self.forecast_horizon = config.forecast_horizon + + self.n_patches = (self.seq_len - self.patch_len) // self.patch_stride_len + 1 + + # RevIN normalization + self.revin = RevIN(n_features=1, affine=config.revin_affine) + + # Patching and embedding + self.patching = Patching( + patch_len=self.patch_len, stride=self.patch_stride_len + ) + self.patch_embedding = PatchEmbedding( + d_model=self.d_model, patch_len=self.patch_len + ) + + # Positional embedding for patches + self.position_embedding = nn.Embedding(self.n_patches, self.d_model) + + # Mask embedding (for masked reconstruction during pre-training) + self.mask_embedding = nn.Parameter(torch.zeros(self.d_model)) + + # T5 encoder backbone + if t5_config is not None: + encoder_config = T5Config(**t5_config) + else: + encoder_config = T5Config.from_pretrained(config.transformer_backbone) + encoder_config.d_model = self.d_model + self.encoder = T5EncoderModel(encoder_config) + + # Layer norm before head + self.layer_norm = nn.LayerNorm(self.d_model) + + # Forecasting head + self.head = ForecastingHead( + d_model=self.d_model, + n_patches=self.n_patches, + forecast_horizon=self.forecast_horizon, + ) + + def forward( + self, + x_enc: torch.Tensor, + input_mask: Optional[torch.Tensor] = None, + ) -> MomentOutput: + """ + Forward pass for forecasting. + + Args: + x_enc: [batch_size, n_channels, seq_len] + input_mask: [batch_size, seq_len] - 1 for observed, 0 for padding + + Returns: + MomentOutput with forecast field + """ + batch_size, n_channels, seq_len = x_enc.shape + + # Handle input_mask + if input_mask is None: + input_mask = torch.ones(batch_size, seq_len, device=x_enc.device) + + # RevIN normalization (channel-independent) + # Reshape to process each channel independently + x = x_enc.reshape(batch_size * n_channels, 1, seq_len) + x = self.revin(x, mode="norm") + x = x.reshape(batch_size, n_channels, seq_len) + + # Patching: [batch, n_channels, n_patches, patch_len] + x = self.patching(x) + + # Patch embedding: [batch, n_channels, n_patches, d_model] + x = self.patch_embedding(x) + + # Apply input mask at patch level + patch_mask = self._create_patch_mask(input_mask) + mask_embed = self.mask_embedding.unsqueeze(0).unsqueeze(0).unsqueeze(0) + x = x * patch_mask.unsqueeze(-1) + mask_embed * ( + 1.0 - patch_mask.unsqueeze(-1) + ) + + # Position embedding + positions = torch.arange(self.n_patches, device=x.device) + x = x + self.position_embedding(positions).unsqueeze(0).unsqueeze(0) + + # Flatten batch and channel dims for T5 encoder + # [batch * n_channels, n_patches, d_model] + x = x.reshape(batch_size * n_channels, self.n_patches, self.d_model) + + # T5 encoder forward + enc_output = self.encoder(inputs_embeds=x).last_hidden_state + + # Layer norm + enc_output = self.layer_norm(enc_output) + + # Restore channel dim: [batch, n_channels, n_patches, d_model] + enc_output = enc_output.reshape( + batch_size, n_channels, self.n_patches, self.d_model + ) + + # Forecasting head: [batch, n_channels, forecast_horizon] + forecast = self.head(enc_output) + + # RevIN denormalization + forecast = forecast.reshape(batch_size * n_channels, 1, self.forecast_horizon) + forecast = self.revin(forecast, mode="denorm") + forecast = forecast.reshape(batch_size, n_channels, self.forecast_horizon) + + return MomentOutput(forecast=forecast, embeddings=enc_output) + + def _create_patch_mask(self, input_mask: torch.Tensor) -> torch.Tensor: + """Convert per-timestep mask to per-patch mask via average pooling.""" + # input_mask: [batch, seq_len] + # output: [batch, 1, n_patches] with values in [0, 1] + mask = input_mask.unsqueeze(1) # [batch, 1, seq_len] + mask = mask.unfold(dimension=-1, size=self.patch_len, step=self.patch_stride_len) + # [batch, 1, n_patches, patch_len] + mask = mask.mean(dim=-1) # [batch, 1, n_patches] + mask = (mask > 0.5).float() + return mask + + +class MomentPreTrainedModel(PreTrainedModel): + """Abstract base class for all MOMENT model variants.""" + + config_class = MomentConfig + base_model_prefix = "moment" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + pass + + +class MomentForPrediction(MomentPreTrainedModel): + """ + MOMENT model for time series forecasting, wrapped as a HuggingFace PreTrainedModel. + + Loads the pre-trained MOMENT backbone from safetensors and configures + the forecasting head for a given horizon. + + Reference: https://huggingface.co/AutonLab/MOMENT-1-large + """ + + def __init__(self, config: MomentConfig): + super().__init__(config) + self.moment = MomentBackbone(config) + self.post_init() + + def forward( + self, + x_enc: torch.Tensor, + input_mask: Optional[torch.Tensor] = None, + ) -> MomentOutput: + return self.moment(x_enc=x_enc, input_mask=input_mask) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + """ + Load MomentForPrediction from a local directory containing + ``config.json`` and ``model.safetensors``. + + The upstream MOMENT checkpoint uses a flat state-dict structure + (keys like ``revin.affine_weight``, ``encoder.encoder.block.0...``). + This method handles mapping those keys into our nested ``moment.*`` + structure. + """ + if not os.path.isdir(pretrained_model_name_or_path): + raise ValueError( + f"pretrained_model_name_or_path must be a local directory, " + f"got: {pretrained_model_name_or_path}" + ) + + config_file = os.path.join(pretrained_model_name_or_path, "config.json") + safetensors_file = os.path.join( + pretrained_model_name_or_path, "model.safetensors" + ) + + # Load config + config_dict = {} + if os.path.exists(config_file): + with open(config_file, "r") as f: + config_dict = json.load(f) + + # Extract t5_config if present in the upstream config + t5_config = config_dict.pop("t5_config", None) + + # Map upstream config fields to our MomentConfig fields + moment_config_kwargs = { + "seq_len": config_dict.get("seq_len", 512), + "patch_len": config_dict.get("patch_len", 8), + "patch_stride_len": config_dict.get("patch_stride_len", 8), + "transformer_backbone": config_dict.get( + "transformer_backbone", "google/flan-t5-large" + ), + "forecast_horizon": kwargs.pop("forecast_horizon", 96), + "revin_affine": config_dict.get("revin_affine", False), + } + + # Infer d_model from t5_config + if t5_config and "d_model" in t5_config: + moment_config_kwargs["d_model"] = t5_config["d_model"] + elif "d_model" in config_dict and config_dict["d_model"] is not None: + moment_config_kwargs["d_model"] = config_dict["d_model"] + + moment_config_kwargs.update(kwargs) + config = MomentConfig(**moment_config_kwargs) + + # Instantiate model (backbone uses t5_config for encoder construction) + # Override backbone init to pass t5_config + instance = cls.__new__(cls) + MomentPreTrainedModel.__init__(instance, config) + instance.moment = MomentBackbone(config, t5_config=t5_config) + instance.post_init() + + # Load weights + if not os.path.exists(safetensors_file): + raise FileNotFoundError( + f"Model checkpoint not found at: {safetensors_file}" + ) + + import safetensors.torch as safetorch + + state_dict = safetorch.load_file(safetensors_file, device="cpu") + + # Map upstream flat keys to our nested moment.* structure + mapped_state_dict = {} + for key, value in state_dict.items(): + new_key = f"moment.{key}" + mapped_state_dict[new_key] = value + + # Load with strict=False to skip mismatched head weights + model_state = instance.state_dict() + filtered = {k: v for k, v in mapped_state_dict.items() if k in model_state} + instance.load_state_dict(filtered, strict=False) + instance.eval() + + logger.info( + f"Loaded MOMENT model from {pretrained_model_name_or_path} " + f"({len(filtered)}/{len(model_state)} keys matched)" + ) + return instance + + @property + def device(self): + return next(self.parameters()).device diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moment/pipeline_moment.py b/iotdb-core/ainode/iotdb/ainode/core/model/moment/pipeline_moment.py new file mode 100644 index 0000000000000..a1343525457f0 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moment/pipeline_moment.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This file contains code adapted from the MOMENT project +# (https://github.com/moment-timeseries-foundation-model/moment), +# originally licensed under the MIT License. + +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_info import ModelInfo + +logger = Logger() + +# MOMENT requires a fixed input length of 512 timesteps +MOMENT_SEQ_LEN = 512 + + +class MomentPipeline(ForecastPipeline): + """ + Inference pipeline for the MOMENT time series foundation model. + + MOMENT processes fixed-length (512) univariate patches through a T5 encoder + and produces forecasts via a single-shot linear head. Each channel/variate + is processed independently (channel-independent design). + + The pipeline handles: + - Padding/truncating inputs to the required 512 length + - Constructing input masks for padded positions + - Iterative forecasting for horizons beyond the model's native capacity + """ + + def __init__(self, model_info: ModelInfo, **model_kwargs): + super().__init__(model_info, **model_kwargs) + + def _preprocess(self, inputs, **infer_kwargs) -> dict: + """ + Preprocess input data for MOMENT. + + Converts the list of input dicts into a single tensor padded/truncated + to MOMENT's required sequence length of 512. Also constructs an + input_mask indicating which timesteps are observed vs padded. + + Parameters + ---------- + inputs : list of dict + Each dict has key ``"targets"`` with a tensor of shape + ``(target_count, input_length)``. + + Returns + ------- + dict + ``"x_enc"``: tensor of shape ``[batch, n_channels, 512]`` + ``"input_mask"``: tensor of shape ``[batch, 512]`` + """ + if inputs[0].get("past_covariates") or inputs[0].get("future_covariates"): + logger.warning( + "MomentPipeline does not support covariates; they will be ignored." + ) + + batch_tensors = [] + batch_masks = [] + + for item in inputs: + targets = item["targets"] # [target_count, input_length] + if targets.ndim == 1: + targets = targets.unsqueeze(0) + + n_channels, input_length = targets.shape + + if input_length >= MOMENT_SEQ_LEN: + # Truncate: take the last MOMENT_SEQ_LEN timesteps + x = targets[:, -MOMENT_SEQ_LEN:] + mask = torch.ones(MOMENT_SEQ_LEN, device=targets.device) + else: + # Left-pad with zeros + pad_len = MOMENT_SEQ_LEN - input_length + x = torch.nn.functional.pad(targets, (pad_len, 0), value=0.0) + mask = torch.cat( + [ + torch.zeros(pad_len, device=targets.device), + torch.ones(input_length, device=targets.device), + ] + ) + + batch_tensors.append(x) + batch_masks.append(mask) + + x_enc = torch.stack(batch_tensors, dim=0) # [batch, n_channels, 512] + input_mask = torch.stack(batch_masks, dim=0) # [batch, 512] + + return {"x_enc": x_enc, "input_mask": input_mask} + + def forecast(self, inputs: dict, **infer_kwargs) -> list[torch.Tensor]: + """ + Run MOMENT forecasting inference. + + For output_length <= model forecast_horizon, a single forward pass + suffices. For longer horizons, iterative (autoregressive) forecasting + is used: each step's predictions are appended to the context window + and fed back as input. + + Parameters + ---------- + inputs : dict + Contains ``"x_enc"`` and ``"input_mask"`` from _preprocess. + infer_kwargs : dict + ``output_length`` (int): desired forecast length, default 96. + + Returns + ------- + list of torch.Tensor + Each tensor has shape ``[n_channels, output_length]``. + """ + output_length = infer_kwargs.get("output_length", 96) + x_enc = inputs["x_enc"].to(self.model.device) + input_mask = inputs["input_mask"].to(self.model.device) + + model_horizon = self.model.config.forecast_horizon + batch_size, n_channels, seq_len = x_enc.shape + + if output_length <= model_horizon: + # Single-shot inference + with torch.no_grad(): + output = self.model(x_enc=x_enc, input_mask=input_mask) + forecasts = output.forecast[:, :, :output_length] + else: + # Iterative forecasting for long horizons + forecasts_list = [] + remaining = output_length + current_x = x_enc + current_mask = input_mask + + while remaining > 0: + with torch.no_grad(): + output = self.model(x_enc=current_x, input_mask=current_mask) + step_forecast = output.forecast[:, :, : min(model_horizon, remaining)] + forecasts_list.append(step_forecast) + remaining -= step_forecast.shape[-1] + + if remaining > 0: + # Slide context window: append forecast, drop oldest + step_len = step_forecast.shape[-1] + current_x = torch.cat( + [current_x[:, :, step_len:], step_forecast], dim=-1 + ) + current_mask = torch.ones( + batch_size, seq_len, device=x_enc.device + ) + + forecasts = torch.cat(forecasts_list, dim=-1) + + # Split batch into list of per-sample tensors + return [forecasts[i] for i in range(batch_size)] + + def _postprocess(self, outputs: list[torch.Tensor], **infer_kwargs) -> list[torch.Tensor]: + """ + Postprocess outputs. Each tensor is already [n_channels, output_length]. + """ + return outputs diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java index bc6d54c37e7fa..c3bb7cb5e971c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java @@ -2595,6 +2595,12 @@ private void analyzeSelectSingleColumn( private Analysis.GroupingSetAnalysis analyzeGroupBy( QuerySpecification node, Scope scope, List outputExpressions) { if (node.getGroupBy().isPresent()) { + + // Handle GROUP BY ALL: infer grouping columns from SELECT expressions + if (node.getGroupBy().get().isAll()) { + return analyzeGroupByAll(node, scope, outputExpressions); + } + ImmutableList.Builder>> cubes = ImmutableList.builder(); ImmutableList.Builder>> rollups = ImmutableList.builder(); ImmutableList.Builder>> sets = ImmutableList.builder(); @@ -2718,6 +2724,69 @@ private Analysis.GroupingSetAnalysis analyzeGroupBy( return result; } + private Analysis.GroupingSetAnalysis analyzeGroupByAll( + QuerySpecification node, Scope scope, List outputExpressions) { + ImmutableList.Builder>> sets = ImmutableList.builder(); + ImmutableList.Builder complexExpressions = ImmutableList.builder(); + ImmutableList.Builder groupingExpressions = ImmutableList.builder(); + FunctionCall gapFillColumn = null; + ImmutableList.Builder gapFillGroupingExpressions = ImmutableList.builder(); + + for (Expression outputExpression : outputExpressions) { + List aggregates = + extractAggregateFunctions(ImmutableList.of(outputExpression)); + List windowFunctions = + extractWindowFunctions(ImmutableList.of(outputExpression)); + if (!aggregates.isEmpty() || !windowFunctions.isEmpty()) { + continue; + } + + analyzeExpression(outputExpression, scope); + ResolvedField field = + analysis.getColumnReferenceFields().get(NodeRef.of(outputExpression)); + if (field != null) { + sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId()))); + } else { + complexExpressions.add(outputExpression); + } + + if (isDateBinGapFill(outputExpression)) { + if (gapFillColumn != null) { + throw new SemanticException("multiple date_bin_gapfill calls not allowed"); + } + gapFillColumn = (FunctionCall) outputExpression; + } else { + gapFillGroupingExpressions.add(outputExpression); + } + + groupingExpressions.add(outputExpression); + } + + List expressions = groupingExpressions.build(); + for (Expression expression : expressions) { + Type type = analysis.getType(expression); + if (!type.isComparable()) { + throw new SemanticException( + String.format( + "%s is not comparable, and therefore cannot be used in GROUP BY", type)); + } + } + + Analysis.GroupingSetAnalysis groupingSets = + new Analysis.GroupingSetAnalysis( + expressions, + ImmutableList.of(), + ImmutableList.of(), + sets.build(), + complexExpressions.build()); + analysis.setGroupingSets(node, groupingSets); + if (gapFillColumn != null) { + analysis.setGapFill(node, gapFillColumn); + analysis.setGapFillGroupingKeys(node, gapFillGroupingExpressions.build()); + } + return groupingSets; + } + private boolean isDateBinGapFill(Expression column) { return column instanceof FunctionCall && DATE_BIN diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/GroupBy.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/GroupBy.java index 00d201648ca8f..2b9e3cfaf511d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/GroupBy.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/GroupBy.java @@ -33,11 +33,13 @@ public class GroupBy extends Node { private static final long INSTANCE_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupBy.class); private final boolean isDistinct; + private final boolean isAll; private final List groupingElements; public GroupBy(boolean isDistinct, List groupingElements) { super(null); this.isDistinct = isDistinct; + this.isAll = false; this.groupingElements = ImmutableList.copyOf(requireNonNull(groupingElements)); } @@ -45,13 +47,25 @@ public GroupBy( NodeLocation location, boolean isDistinct, List groupingElements) { super(requireNonNull(location, "location is null")); this.isDistinct = isDistinct; + this.isAll = false; this.groupingElements = ImmutableList.copyOf(requireNonNull(groupingElements)); } + public GroupBy(NodeLocation location, boolean isAll) { + super(requireNonNull(location, "location is null")); + this.isDistinct = false; + this.isAll = isAll; + this.groupingElements = ImmutableList.of(); + } + public boolean isDistinct() { return isDistinct; } + public boolean isAll() { + return isAll; + } + public List getGroupingElements() { return groupingElements; } @@ -76,18 +90,20 @@ public boolean equals(Object o) { } GroupBy groupBy = (GroupBy) o; return isDistinct == groupBy.isDistinct + && isAll == groupBy.isAll && Objects.equals(groupingElements, groupBy.groupingElements); } @Override public int hashCode() { - return Objects.hash(isDistinct, groupingElements); + return Objects.hash(isDistinct, isAll, groupingElements); } @Override public String toString() { return toStringHelper(this) .add("isDistinct", isDistinct) + .add("isAll", isAll) .add("groupingElements", groupingElements) .toString(); } @@ -98,7 +114,7 @@ public boolean shallowEquals(Node other) { return false; } - return isDistinct == ((GroupBy) other).isDistinct; + return isDistinct == ((GroupBy) other).isDistinct && isAll == ((GroupBy) other).isAll; } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java index 70dc79b6adb3a..5e08488c1bfe8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java @@ -2419,6 +2419,9 @@ public Node visitSelectAll(RelationalSqlParser.SelectAllContext ctx) { @Override public Node visitGroupBy(RelationalSqlParser.GroupByContext ctx) { + if (ctx.ALL() != null && ctx.groupingElement().isEmpty()) { + return new GroupBy(getLocation(ctx), true); + } return new GroupBy( getLocation(ctx), isDistinct(ctx.setQuantifier()), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/util/SqlFormatter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/util/SqlFormatter.java index 0c1d862a886cc..635847a0d93fe 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/util/SqlFormatter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/util/SqlFormatter.java @@ -248,13 +248,18 @@ protected Void visitQuerySpecification(QuerySpecification node, Integer indent) node.getGroupBy() .ifPresent( - groupBy -> + groupBy -> { + if (groupBy.isAll()) { + append(indent, "GROUP BY ALL").append('\n'); + } else { append( indent, "GROUP BY " + (groupBy.isDistinct() ? " DISTINCT " : "") + formatGroupBy(groupBy.getGroupingElements())) - .append('\n')); + .append('\n'); + } + }); node.getHaving() .ifPresent(having -> append(indent, "HAVING " + formatExpression(having)).append('\n')); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/GroupByAllTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/GroupByAllTest.java new file mode 100644 index 0000000000000..407c95003f04d --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/GroupByAllTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.analyzer; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlanTester; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTableScanNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.GapFillNode; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.junit.Test; + +import java.util.Optional; + +import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.TestUtils.assertAnalyzeSemanticException; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.aggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.aggregationFunction; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.output; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.tableScan; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.SINGLE; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Tests for the GROUP BY ALL syntax. Verifies that the analyzer correctly infers grouping columns + * from the SELECT clause and that integration with date_bin_gapfill is preserved. + */ +public class GroupByAllTest { + + // ---- basic column inference ---- + + @Test + public void groupByAllSingleColumnTest() { + // GROUP BY ALL should infer s1 as the grouping key (the only non-aggregate expression) + PlanTester planTester = new PlanTester(); + LogicalQueryPlan plan = + planTester.createPlan("SELECT s1, count(s2) FROM table1 GROUP BY ALL"); + assertPlan( + plan, + output( + aggregation( + singleGroupingSet("s1"), + ImmutableMap.of( + Optional.empty(), + aggregationFunction("count", ImmutableList.of("s2"))), + ImmutableList.of(), + Optional.empty(), + SINGLE, + tableScan( + "testdb.table1", + ImmutableList.of("s1", "s2"), + ImmutableSet.of("s1", "s2"))))); + } + + @Test + public void groupByAllMultipleColumnsTest() { + // GROUP BY ALL should infer tag1, tag2, tag3 as grouping keys. + // The optimizer pushes aggregation into the table scan, producing AggregationTableScanNode. + // Verify by walking the plan tree. + PlanTester planTester = new PlanTester(); + LogicalQueryPlan plan = + planTester.createPlan( + "SELECT tag1, tag2, tag3, count(s2) FROM table1 GROUP BY ALL"); + + PlanNode root = plan.getRootNode(); + AggregationTableScanNode aggScan = findNode(root, AggregationTableScanNode.class); + assertTrue( + "Expected AggregationTableScanNode with grouping keys [tag1, tag2, tag3]", + aggScan != null + && aggScan.getGroupingKeys().stream() + .map(s -> s.getName()) + .collect(java.util.stream.Collectors.toSet()) + .containsAll(ImmutableSet.of("tag1", "tag2", "tag3"))); + } + + @Test + public void groupByAllEquivalenceTest() { + // GROUP BY ALL and explicit GROUP BY should produce the same plan structure + PlanTester tester1 = new PlanTester(); + PlanTester tester2 = new PlanTester(); + LogicalQueryPlan allPlan = + tester1.createPlan("SELECT s1, count(s2) FROM table1 GROUP BY ALL"); + LogicalQueryPlan explicitPlan = + tester2.createPlan("SELECT s1, count(s2) FROM table1 GROUP BY s1"); + + // Both plans should match the same pattern + assertPlan( + allPlan, + output( + aggregation( + singleGroupingSet("s1"), + ImmutableMap.of( + Optional.empty(), + aggregationFunction("count", ImmutableList.of("s2"))), + ImmutableList.of(), + Optional.empty(), + SINGLE, + tableScan( + "testdb.table1", + ImmutableList.of("s1", "s2"), + ImmutableSet.of("s1", "s2"))))); + assertPlan( + explicitPlan, + output( + aggregation( + singleGroupingSet("s1"), + ImmutableMap.of( + Optional.empty(), + aggregationFunction("count", ImmutableList.of("s2"))), + ImmutableList.of(), + Optional.empty(), + SINGLE, + tableScan( + "testdb.table1", + ImmutableList.of("s1", "s2"), + ImmutableSet.of("s1", "s2"))))); + } + + // ---- complex expression in SELECT ---- + + @Test + public void groupByAllWithExpressionTest() { + // GROUP BY ALL with a non-aggregate expression (s1 + 1) should infer it as a grouping key + PlanTester planTester = new PlanTester(); + LogicalQueryPlan plan = + planTester.createPlan("SELECT s1 + 1, count(s2) FROM table1 GROUP BY ALL"); + + // Verify there is an AggregationNode with a non-empty grouping set + PlanNode root = plan.getRootNode(); + AggregationNode aggNode = findNode(root, AggregationNode.class); + assertTrue( + "Expected AggregationNode with non-empty grouping keys for expression grouping", + aggNode != null && !aggNode.getGroupingKeys().isEmpty()); + } + + // ---- global aggregation (all SELECT items are aggregates) ---- + + @Test + public void groupByAllGlobalAggregationTest() { + // When all SELECT expressions are aggregates, GROUP BY ALL is equivalent to no GROUP BY + PlanTester planTester = new PlanTester(); + LogicalQueryPlan plan = + planTester.createPlan("SELECT count(s1), sum(s2) FROM table1 GROUP BY ALL"); + + // Verify global aggregation: grouping keys should be empty + PlanNode root = plan.getRootNode(); + // May be AggregationNode or AggregationTableScanNode + AggregationNode aggNode = findNode(root, AggregationNode.class); + if (aggNode != null) { + assertTrue( + "Expected empty grouping keys for global aggregation", + aggNode.getGroupingKeys().isEmpty()); + } else { + AggregationTableScanNode aggScan = findNode(root, AggregationTableScanNode.class); + assertTrue( + "Expected AggregationTableScanNode with empty grouping keys", + aggScan != null && aggScan.getGroupingKeys().isEmpty()); + } + } + + // ---- date_bin_gapfill integration ---- + + @Test + public void groupByAllWithGapFillTest() { + // GROUP BY ALL with date_bin_gapfill should produce a GapFillNode in the plan + PlanTester planTester = new PlanTester(); + LogicalQueryPlan plan = + planTester.createPlan( + "SELECT date_bin_gapfill(1h, time), tag1, avg(s1) " + + "FROM table1 " + + "WHERE time >= 1 AND time <= 10 " + + "GROUP BY ALL"); + + // The plan should contain a GapFillNode + PlanNode root = plan.getRootNode(); + GapFillNode gapFillNode = findNode(root, GapFillNode.class); + assertTrue("Expected GapFillNode in plan for date_bin_gapfill + GROUP BY ALL", gapFillNode != null); + } + + @Test + public void groupByAllMultipleGapFillTest() { + // Two date_bin_gapfill calls should be rejected even with GROUP BY ALL + assertAnalyzeSemanticException( + "SELECT date_bin_gapfill(1h, time), date_bin_gapfill(2h, time), avg(s1) " + + "FROM table1 WHERE time >= 1 AND time <= 10 GROUP BY ALL", + "multiple date_bin_gapfill calls not allowed"); + } + + // ---- backward compatibility ---- + + @Test + public void groupByAllQuantifierBackwardCompatibilityTest() { + // GROUP BY ALL a, b (ALL as set-quantifier, not GROUP BY ALL feature) + // must still parse and plan correctly — the ALL keyword followed by explicit columns + // should be treated as the set-quantifier form, not the new GROUP BY ALL syntax. + PlanTester planTester = new PlanTester(); + LogicalQueryPlan plan = + planTester.createPlan("SELECT count(s2) FROM table1 GROUP BY ALL s1"); + PlanNode root = plan.getRootNode(); + // The plan should have an aggregation with s1 as a grouping key + AggregationNode aggNode = findNode(root, AggregationNode.class); + AggregationTableScanNode aggScan = findNode(root, AggregationTableScanNode.class); + boolean hasS1Grouping; + if (aggNode != null) { + hasS1Grouping = + aggNode.getGroupingKeys().stream().anyMatch(s -> s.getName().equals("s1")); + } else { + hasS1Grouping = + aggScan != null + && aggScan.getGroupingKeys().stream().anyMatch(s -> s.getName().equals("s1")); + } + assertTrue( + "GROUP BY ALL s1 should parse as ALL-quantifier with explicit grouping key s1", + hasS1Grouping); + } + + // ---- helper ---- + + @SuppressWarnings("unchecked") + private static T findNode(PlanNode root, Class nodeClass) { + if (nodeClass.isInstance(root)) { + return (T) root; + } + for (PlanNode child : root.getChildren()) { + T result = findNode(child, nodeClass); + if (result != null) { + return result; + } + } + return null; + } +} diff --git a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 index 291aa1ea1ae2b..c7fa4524b842d 100644 --- a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 +++ b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 @@ -1012,7 +1012,8 @@ querySpecification ; groupBy - : setQuantifier? groupingElement (',' groupingElement)* + : ALL + | setQuantifier? groupingElement (',' groupingElement)* ; groupingElement diff --git a/scripts/conf/confignode-env.sh b/scripts/conf/confignode-env.sh index edf587b8c2b2b..3b2117209fcf7 100644 --- a/scripts/conf/confignode-env.sh +++ b/scripts/conf/confignode-env.sh @@ -192,7 +192,7 @@ else JAVA=java fi -if [ -z $JAVA ] ; then +if [ -z "$JAVA" ] ; then echo Unable to find java executable. Check JAVA_HOME and PATH environment variables. > /dev/stderr exit 1; fi diff --git a/scripts/conf/datanode-env.sh b/scripts/conf/datanode-env.sh index 31206e550d3a7..f0f7da3e79bb3 100755 --- a/scripts/conf/datanode-env.sh +++ b/scripts/conf/datanode-env.sh @@ -198,7 +198,7 @@ else JAVA=java fi -if [ -z $JAVA ] ; then +if [ -z "$JAVA" ] ; then echo Unable to find java executable. Check JAVA_HOME and PATH environment variables. > /dev/stderr exit 1; fi