diff --git a/sigllm/pipelines/detector/custom_detector.json b/sigllm/pipelines/detector/custom_detector.json new file mode 100644 index 0000000..aa8128e --- /dev/null +++ b/sigllm/pipelines/detector/custom_detector.json @@ -0,0 +1,94 @@ +{ + "primitives": [ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate", + "sklearn.impute.SimpleImputer", + "sigllm.primitives.transformation.Float2Scalar", + "sigllm.primitives.forecasting.custom.rolling_window_sequences", + "sigllm.primitives.transformation.format_as_string", + "sigllm.primitives.forecasting.custom.CustomForecast", + "sigllm.primitives.transformation.format_as_integer", + "sigllm.primitives.transformation.Scalar2Float", + "sigllm.primitives.postprocessing.aggregate_rolling_window", + "numpy.reshape", + "orion.primitives.timeseries_errors.regression_errors", + "orion.primitives.timeseries_anomalies.find_anomalies" + ], + "init_params": { + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 21600, + "method": "mean" + }, + "sigllm.primitives.transformation.Float2Scalar#1": { + "decimal": 2, + "rescale": true + }, + "sigllm.primitives.forecasting.custom.rolling_window_sequences#1": { + "target_column": 0, + "window_size": 140, + "target_size": 1 + }, + "sigllm.primitives.transformation.format_as_string#1": { + "space": true + }, + "sigllm.primitives.forecasting.custom.CustomForecast#1": { + "steps": 5 + }, + "sigllm.primitives.transformation.format_as_integer#1": { + "trunc": 1, + "errors": "coerce" + }, + "sigllm.primitives.postprocessing.aggregate_rolling_window#1": { + "agg": "median", + "remove_outliers": true + }, + "orion.primitives.timeseries_anomalies.find_anomalies#1": { + "window_size_portion": 0.3, + "window_step_size_portion": 0.1, + "fixed_threshold": true + } + }, + "input_names": { + "sigllm.primitives.transformation.Float2Scalar#1": { + "X": "y" + }, + "sigllm.primitives.transformation.format_as_integer#1": { + "X": "y_hat" + }, + "sigllm.primitives.transformation.Scalar2Float#1": { + "X": "y_hat" + }, + "sigllm.primitives.postprocessing.aggregate_rolling_window#1": { + "y": "y_hat" + }, + "numpy.reshape#1": { + "X": "y_hat" + }, + "orion.primitives.timeseries_anomalies.find_anomalies#1": { + "index": "target_index" + } + }, + "output_names": { + "sklearn.impute.SimpleImputer#1": { + "X": "y" + }, + "sigllm.primitives.forecasting.custom.CustomForecast#1": { + "y": "y_hat" + }, + "sigllm.primitives.transformation.format_as_integer#1": { + "X": "y_hat" + }, + "sigllm.primitives.transformation.Scalar2Float#1": { + "X": "y_hat" + }, + "sigllm.primitives.postprocessing.aggregate_rolling_window#1": { + "y": "y_hat" + }, + "numpy.reshape#1": { + "X": "y_hat" + }, + "orion.primitives.timeseries_anomalies.find_anomalies#1": { + "y": "anomalies" + } + } +} diff --git a/sigllm/pipelines/detector/mistral_detector.json b/sigllm/pipelines/detector/mistral_detector.json index 5200762..ecd7d3e 100644 --- a/sigllm/pipelines/detector/mistral_detector.json +++ b/sigllm/pipelines/detector/mistral_detector.json @@ -3,7 +3,7 @@ "mlstars.custom.timeseries_preprocessing.time_segments_aggregate", "sklearn.impute.SimpleImputer", "sigllm.primitives.transformation.Float2Scalar", - "mlstars.custom.timeseries_preprocessing.rolling_window_sequences", + "sigllm.primitives.forecasting.custom.rolling_window_sequences", "sigllm.primitives.transformation.format_as_string", "sigllm.primitives.forecasting.huggingface.HF", "sigllm.primitives.transformation.format_as_integer", @@ 
-23,7 +23,7 @@ "decimal": 2, "rescale": true }, - "mlstars.custom.timeseries_preprocessing.rolling_window_sequences#1": { + "sigllm.primitives.forecasting.custom.rolling_window_sequences#1": { "target_column": 0, "window_size": 140, "target_size": 1 diff --git a/sigllm/primitives/forecasting/custom.py b/sigllm/primitives/forecasting/custom.py new file mode 100644 index 0000000..c13764a --- /dev/null +++ b/sigllm/primitives/forecasting/custom.py @@ -0,0 +1,171 @@ +"""Custom LLM forecasting primitive. + +This module provides a forecasting primitive that works with any LLM backend +through the BaseLLMClient interface. Companies can plug in their own LLM +by implementing a simple client class. + +For OpenAI or HuggingFace, use the existing gpt.py or huggingface.py primitives. +This module is for custom/internal LLM backends. + +Example usage: + from sigllm.primitives.llm_client import BaseLLMClient + from sigllm.primitives.forecasting.custom import CustomForecast + + class CustomLLM(BaseLLMClient): + def generate(self, prompts, **kwargs): + # Your function to generate responses from the LLM here + return [[response] for response in llm_client.complete(prompts)] + + client = CustomLLM() + forecaster = CustomForecast(client=client, steps=5) + predictions = forecaster.forecast(X_strings) +""" + +import json +import os +from typing import List + +import numpy as np + +from sigllm.primitives.llm_client import BaseLLMClient + +PROMPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'TEMPLATE.json') +PROMPTS = json.load(open(PROMPT_PATH)) + + +class CustomForecast: + """Forecast time series using a custom LLM client. + + This primitive wraps any LLM backend (via BaseLLMClient) to perform + time series forecasting. It handles prompt construction and response + parsing, while delegating the actual LLM calls to the provided client. + + For OpenAI or HuggingFace, use gpt.py or huggingface.py instead. + + Args: + client (BaseLLMClient): + Your custom LLM client instance that implements generate(). + sep (str): + Separator between values in the sequence. Default ",". + steps (int): + Number of steps ahead to forecast. Default 1. + temp (float): + Sampling temperature. Default 1.0. + samples (int): + Number of forecast samples per input. Default 1. + + Example: + class MyLLM(BaseLLMClient): + def generate(self, prompts, **kwargs): + return [[my_api.complete(p)] for p in prompts] + client = MyLLM() + forecaster = CustomForecast(client=client, steps=5) + predictions = forecaster.forecast(X) + """ + + def __init__( + self, + client: BaseLLMClient, + sep: str = ",", + steps: int = 1, + temp: float = 1.0, + samples: int = 1, + ): + if not isinstance(client, BaseLLMClient): + raise TypeError(f"client must be a BaseLLMClient instance, got {type(client)}.") + + self.client = client + self.sep = sep + self.steps = steps + self.temp = temp + self.samples = samples + + def _build_prompt(self, sequence: str) -> str: + """Build the forecasting prompt from a sequence string.""" + return " ".join([PROMPTS["user_message"], sequence, self.sep]) + + def _estimate_max_tokens(self, sequence: str) -> int: + """Estimate max tokens needed for the forecast.""" + values = sequence.split(self.sep) + if len(values) > 0: + avg_len = sum(len(v.strip()) for v in values) / len(values) + return int((avg_len + 2) * self.steps) + 10 + return 5 + + def forecast(self, X: List[str], **kwargs) -> List[List[str]]: + """Forecast future values for each input sequence. 
+ + Args: + X (List[str] or ndarray): + Input sequences as strings. Each string is a comma-separated + sequence of numeric values. + **kwargs: + Additional arguments passed to the LLM client. + + Returns: + List[List[str]]: + For each input sequence, a list of forecast strings. + """ + prompts = [self._build_prompt(seq) for seq in X] + max_tokens = self._estimate_max_tokens(X[0]) if len(X) > 0 else 5 + + responses = self.client.generate( + prompts=prompts, + system_message=PROMPTS["system_message"], + max_tokens=max_tokens, + temperature=self.temp, + n_samples=self.samples, + **kwargs, + ) + + return responses + + +def rolling_window_sequences(X, y, index, window_size, target_size, step_size, target_column): + """Create rolling window sequences out of time series data. + + The function creates an array of input sequences and an array of target sequences by rolling + over the input sequence with a specified window. + Optionally, certain values can be dropped from the sequences. + + Args: + X (ndarray): + N-dimensional input sequence to iterate over. + y (ndarray): + N-dimensional target sequence to iterate over. + index (ndarray): + Array containing the index values of X. + window_size (int): + Length of the input sequences. + target_size (int): + Length of the target sequences. + step_size (int): + Indicating the number of steps to move the window forward each round. + target_column (int): + Indicating which column of X is the target. + + Returns: + ndarray, ndarray, ndarray, ndarray: + * input sequences. + * target sequences. + * first index value of each input sequence. + * first index value of each target sequence. + """ + out_X = list() + out_y = list() + X_index = list() + y_index = list() + target = y[:, target_column] + + start = 0 + max_start = len(X) - window_size - target_size + 1 + while start < max_start: + end = start + window_size + + out_X.append(X[start:end]) + out_y.append(target[end : end + target_size]) + X_index.append(index[start]) + y_index.append(index[end]) + start = start + step_size + + return np.asarray(out_X), np.asarray(out_y), np.asarray(X_index), np.asarray(y_index) diff --git a/sigllm/primitives/jsons/sigllm.primitives.forecasting.custom.CustomForecast.json b/sigllm/primitives/jsons/sigllm.primitives.forecasting.custom.CustomForecast.json new file mode 100644 index 0000000..0eb9b59 --- /dev/null +++ b/sigllm/primitives/jsons/sigllm.primitives.forecasting.custom.CustomForecast.json @@ -0,0 +1,52 @@ +{ + "name": "sigllm.primitives.forecasting.custom.CustomForecast", + "contributors": [ + "Allen Baranov " + ], + "description": "Forecast time series using a custom LLM client. 
For OpenAI/HuggingFace, use gpt.py or huggingface.py instead.", + "classifiers": { + "type": "estimator", + "subtype": "regressor" + }, + "modalities": [], + "primitive": "sigllm.primitives.forecasting.custom.CustomForecast", + "produce": { + "method": "forecast", + "args": [ + { + "name": "X", + "type": "ndarray" + } + ], + "output": [ + { + "name": "y", + "type": "ndarray" + } + ] + }, + "hyperparameters": { + "fixed": { + "client": { + "type": "object", + "description": "BaseLLMClient instance - must be provided" + }, + "sep": { + "type": "str", + "default": "," + }, + "steps": { + "type": "int", + "default": 1 + }, + "temp": { + "type": "float", + "default": 1 + }, + "samples": { + "type": "int", + "default": 1 + } + } + } +} diff --git a/sigllm/primitives/jsons/sigllm.primitives.forecasting.custom.rolling_window_sequences.json b/sigllm/primitives/jsons/sigllm.primitives.forecasting.custom.rolling_window_sequences.json new file mode 100644 index 0000000..597f14b --- /dev/null +++ b/sigllm/primitives/jsons/sigllm.primitives.forecasting.custom.rolling_window_sequences.json @@ -0,0 +1,72 @@ +{ + "name": "sigllm.primitives.forecasting.custom.rolling_window_sequences", + "contributors": [ + "Sarah Alnegheimish " + ], + "description": "Create rolling window sequences out of timeseries data.", + "classifiers": { + "type": "preprocessor", + "subtype": "feature_extractor" + }, + "modalities": [ + "timeseries" + ], + "primitive": "sigllm.primitives.forecasting.custom.rolling_window_sequences", + "produce": { + "args": [ + { + "name": "X", + "type": "ndarray" + }, + { + "name": "y", + "type": "ndarray" + }, + { + "name": "index", + "type": "ndarray" + } + ], + "output": [ + { + "name": "X", + "type": "ndarray" + }, + { + "name": "y", + "type": "ndarray" + }, + { + "name": "index", + "type": "ndarray" + }, + { + "name": "target_index", + "type": "ndarray" + } + ] + }, + "hyperparameters": { + "fixed": { + "window_size": { + "type": "int", + "default": 250 + }, + "target_size": { + "type": "int", + "default": 1 + }, + "step_size": { + "type": "int", + "default": 1 + }, + "target_column": { + "type": "str or int", + "default": 1 + } + } + } +} + + + diff --git a/sigllm/primitives/jsons/sigllm.primitives.prompting.custom.CustomDetect.json b/sigllm/primitives/jsons/sigllm.primitives.prompting.custom.CustomDetect.json new file mode 100644 index 0000000..9fb43a4 --- /dev/null +++ b/sigllm/primitives/jsons/sigllm.primitives.prompting.custom.CustomDetect.json @@ -0,0 +1,52 @@ +{ + "name": "sigllm.primitives.prompting.custom.CustomDetect", + "contributors": [ + "Allen Baranov " + ], + "description": "Detect anomalies in time series using a custom LLM client. 
For OpenAI/HuggingFace, use gpt.py or huggingface.py instead.", + "classifiers": { + "type": "estimator", + "subtype": "detector" + }, + "modalities": [], + "primitive": "sigllm.primitives.prompting.custom.CustomDetect", + "produce": { + "method": "detect", + "args": [ + { + "name": "X", + "type": "ndarray" + } + ], + "output": [ + { + "name": "y", + "type": "ndarray" + } + ] + }, + "hyperparameters": { + "fixed": { + "client": { + "type": "object", + "description": "BaseLLMClient instance - must be provided" + }, + "sep": { + "type": "str", + "default": "," + }, + "anomalous_percent": { + "type": "float", + "default": 0.5 + }, + "temp": { + "type": "float", + "default": 1 + }, + "samples": { + "type": "int", + "default": 10 + } + } + } +} diff --git a/sigllm/primitives/llm_client.py b/sigllm/primitives/llm_client.py new file mode 100644 index 0000000..2b1b6e8 --- /dev/null +++ b/sigllm/primitives/llm_client.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +"""LLM Client abstraction. + +This module provides a simple interface for custom LLM backends. Companies can +implement their own client by subclassing BaseLLMClient and implementing the +generate() method. + +For OpenAI or HuggingFace models, use the existing gpt.py or huggingface.py +primitives directly. This module is for custom/internal LLM backends. + +See tutorials/custom_llm_forecasting_pipeline.ipynb for a full example +with azure gpt. +""" + +from typing import List, Optional + + +class BaseLLMClient: + """Base class for custom LLM clients. + + Subclass this and implement generate() to integrate your own LLM backend. + + Example: + class CustomLLM(BaseLLMClient): + def __init__(self, llm_responder): + self.llm_responder = llm_responder + + def generate(self, prompts, **kwargs): + # Your function to generate responses from the LLM here + return [[response] for response in self.llm_responder.complete(prompts)] + + client = CustomLLM() + forecaster = CustomForecast(client=client, steps=5) + """ + + def generate( + self, + prompts: List[str], + system_message: Optional[str] = None, + max_tokens: int = 100, + temperature: float = 1.0, + n_samples: int = 1, + **kwargs, + ) -> List[List[str]]: + """Generate responses for a batch of prompts. + + Args: + prompts (List[str]): + List of prompt strings to send to the LLM. + system_message (str, optional): + System message to prepend (for chat models). Default None. + max_tokens (int): + Maximum tokens to generate per response. Default 100. + temperature (float): + Sampling temperature. Default 1.0. + n_samples (int): + Number of responses to generate per prompt. Default 1. + **kwargs: + Additional provider-specific arguments. + + Returns: + List[List[str]]: + For each prompt, a list of n_samples response strings. + Shape: (len(prompts), n_samples) + """ + raise NotImplementedError( + "Subclass BaseLLMClient and implement the generate() method." + ) diff --git a/sigllm/primitives/prompting/custom.py b/sigllm/primitives/prompting/custom.py new file mode 100644 index 0000000..137528f --- /dev/null +++ b/sigllm/primitives/prompting/custom.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- + +"""Custom LLM anomaly detection primitive. + +This module provides an anomaly detection primitive that works with any LLM +backend through the BaseLLMClient interface. Companies can plug in their own +LLM by implementing a simple client class. + +For OpenAI or HuggingFace, use the existing gpt.py or huggingface.py primitives. +This module is for custom/internal LLM backends. 
+ +Example usage: + from sigllm.primitives.llm_client import BaseLLMClient + from sigllm.primitives.prompting.custom import CustomDetect + + class CustomLLM(BaseLLMClient): + def generate(self, prompts, **kwargs): + # Your function to generate responses from the LLM here + return [[response] for response in llm_client.complete(prompts)] + + client = CustomLLM() + detector = CustomDetect(client=client, samples=5) + anomalies = detector.detect(X_strings) +""" + +import json +import os +from typing import List + +from sigllm.primitives.llm_client import BaseLLMClient + + +PROMPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'gpt_messages.json') +PROMPTS = json.load(open(PROMPT_PATH)) + + +class CustomDetect: + """Detect anomalies in time series using a custom LLM client. + + This primitive wraps any LLM backend (via BaseLLMClient) to perform + anomaly detection. It handles prompt construction and delegates the + actual LLM calls to the provided client. + + For OpenAI or HuggingFace, use gpt.py or huggingface.py instead. + + Args: + client (BaseLLMClient): + Your custom LLM client instance that implements generate(). + sep (str): + Separator between values in the sequence. Default ",". + anomalous_percent (float): + Expected percentage of time series that are anomalous. + Used to estimate response length. Default 0.5. + temp (float): + Sampling temperature. Default 1.0. + samples (int): + Number of detection samples per input. Default 10. + + Example: + class MyLLM(BaseLLMClient): + def generate(self, prompts, **kwargs): + return [[my_api.complete(p)] for p in prompts] + + client = MyLLM() + detector = CustomDetect(client=client, samples=5) + anomalies = detector.detect(X) + """ + + def __init__( + self, + client: BaseLLMClient, + sep: str = ",", + anomalous_percent: float = 0.5, + temp: float = 1.0, + samples: int = 10, + ): + if not isinstance(client, BaseLLMClient): + raise TypeError( + f"client must be a BaseLLMClient instance, got {type(client)}. " + "For OpenAI, use sigllm.primitives.prompting.gpt.GPT instead." + ) + + self.client = client + self.sep = sep + self.anomalous_percent = anomalous_percent + self.temp = temp + self.samples = samples + + def _build_prompt(self, sequence: str) -> str: + """Build the detection prompt from a sequence string.""" + return " ".join([PROMPTS["user_message"], sequence, self.sep]) + + def _estimate_max_tokens(self, sequence: str) -> int: + """Estimate max tokens needed for the detection response.""" + seq_len = len(sequence) + return int(seq_len * self.anomalous_percent) + 20 + + def detect(self, X: List[str], **kwargs) -> List[List[str]]: + """Detect anomalies in each input sequence. + + Args: + X (List[str] or ndarray): + Input sequences as strings. Each string is a comma-separated + sequence of numeric values. + **kwargs: + Additional arguments passed to the LLM client. + + Returns: + List[List[str]]: + For each input sequence, a list of detection response strings. 
+        """
+        prompts = [self._build_prompt(seq) for seq in X]
+        max_tokens = self._estimate_max_tokens(X[0]) if len(X) > 0 else 100
+
+        responses = self.client.generate(
+            prompts=prompts,
+            system_message=PROMPTS["system_message"],
+            max_tokens=max_tokens,
+            temperature=self.temp,
+            n_samples=self.samples,
+            **kwargs,
+        )
+
+        return responses
diff --git a/tutorials/pipelines/custom_llm_forecasting_pipeline.ipynb b/tutorials/pipelines/custom_llm_forecasting_pipeline.ipynb
new file mode 100644
index 0000000..a89f666
--- /dev/null
+++ b/tutorials/pipelines/custom_llm_forecasting_pipeline.ipynb
@@ -0,0 +1,319 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Integrating a Custom LLM with SigLLM\n",
+    "\n",
+    "This tutorial shows how to use your own LLM backend (such as Azure OpenAI, Anthropic, or an internal API) with SigLLM's forecasting pipeline.\n",
+    "\n",
+    "## Overview\n",
+    "\n",
+    "SigLLM provides a `CustomForecast` primitive that works with any LLM. You just need to:\n",
+    "1. Implement a client class with a `generate()` method\n",
+    "2. Pass it to the `custom_detector` pipeline\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 1: Implement Your LLM Client\n",
+    "\n",
+    "Create a class that inherits from `BaseLLMClient`. Here's an example for Azure OpenAI:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sigllm.primitives.llm_client import BaseLLMClient\n",
+    "from openai import AzureOpenAI\n",
+    "\n",
+    "class AzureGPTClient(BaseLLMClient):\n",
+    "    \"\"\"Azure OpenAI client for SigLLM.\"\"\"\n",
+    "\n",
+    "    def __init__(self, endpoint, api_key, deployment, api_version=\"2024-02-15-preview\"):\n",
+    "        self.endpoint = endpoint\n",
+    "        self.api_key = api_key\n",
+    "        self.deployment = deployment\n",
+    "        self.api_version = api_version\n",
+    "        # Build the underlying Azure OpenAI client used by generate().\n",
+    "        self.client = AzureOpenAI(\n",
+    "            azure_endpoint=endpoint,\n",
+    "            api_key=api_key,\n",
+    "            api_version=api_version,\n",
+    "        )\n",
+    "\n",
+    "    def generate(self, prompts, system_message=None, max_tokens=100,\n",
+    "                 temperature=1.0, n_samples=1, **kwargs):\n",
+    "        \"\"\"Generate responses for a batch of prompts.\"\"\"\n",
+    "        all_responses = []\n",
+    "        for prompt in prompts:\n",
+    "            messages = []\n",
+    "            if system_message:\n",
+    "                messages.append({\"role\": \"system\", \"content\": system_message})\n",
+    "            messages.append({\"role\": \"user\", \"content\": prompt})\n",
+    "\n",
+    "            response = self.client.chat.completions.create(\n",
+    "                model=self.deployment,\n",
+    "                messages=messages,\n",
+    "                max_tokens=max_tokens,\n",
+    "                temperature=temperature,\n",
+    "                n=n_samples,\n",
+    "            )\n",
+    "            texts = [choice.message.content for choice in response.choices]\n",
+    "            all_responses.append(texts)\n",
+    "\n",
+    "        return all_responses\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 2: Create Test Data\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Data shape: (50, 2)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
" + ], + "text/plain": [ + " timestamp value\n", + "0 0 0.0\n", + "1 3600 10.0\n", + "2 7200 20.0\n", + "3 10800 30.0\n", + "4 14400 40.0" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "# Simple linear test data\n", + "np.random.seed(42)\n", + "n_points = 50\n", + "data = pd.DataFrame({\n", + " \"timestamp\": np.arange(n_points) * 3600,\n", + " \"value\": np.linspace(0, n_points-1, n_points) * 10,\n", + "})\n", + "\n", + "print(\"Data shape:\", data.shape)\n", + "data.head()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Run the Pipeline\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pipeline primitives:\n", + "\t0: mlstars.custom.timeseries_preprocessing.time_segments_aggregate\n", + "\t1: sklearn.impute.SimpleImputer\n", + "\t2: sigllm.primitives.transformation.Float2Scalar\n", + "\t3: sigllm.primitives.forecasting.custom.rolling_window_sequences\n", + "\t4: sigllm.primitives.transformation.format_as_string\n", + "\t5: sigllm.primitives.forecasting.custom.CustomForecast\n", + "\t6: sigllm.primitives.transformation.format_as_integer\n", + "\t7: sigllm.primitives.transformation.Scalar2Float\n", + "\t8: sigllm.primitives.postprocessing.aggregate_rolling_window\n", + "\t9: numpy.reshape\n", + "\t10: orion.primitives.timeseries_errors.regression_errors\n", + "\t11: orion.primitives.timeseries_anomalies.find_anomalies\n" + ] + } + ], + "source": [ + "from mlblocks import MLPipeline\n", + "\n", + "client = AzureGPTClient(\n", + " endpoint=AZURE_ENDPOINT,\n", + " api_key=AZURE_API_KEY,\n", + " deployment=AZURE_DEPLOYMENT,\n", + ")\n", + "\n", + "pipeline = MLPipeline(\n", + " \"custom_detector\",\n", + " init_params={\n", + " \"sigllm.primitives.forecasting.custom.CustomForecast#1\": {\n", + " \"client\": client,\n", + " \"steps\": 1,\n", + " \"temp\": 0.3,\n", + " \"samples\": 4,\n", + " },\n", + " \"mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1\": {\n", + " \"time_column\": \"timestamp\",\n", + " \"interval\": 3600,\n", + " },\n", + " \"sigllm.primitives.forecasting.custom.rolling_window_sequences#1\": {\n", + " \"window_size\": 10,\n", + " \"target_column\": 0,\n", + " \"target_size\": 1,\n", + " },\n", + " \"sigllm.primitives.transformation.format_as_integer#1\": {\n", + " \"trunc\": 1,\n", + " \"errors\": \"coerce\"\n", + " }\n", + " }\n", + ")\n", + "\n", + "print(\"Pipeline primitives:\")\n", + "for i, p in enumerate(pipeline.primitives):\n", + " print(f\"\\t{i}: {p}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predictions shape: (40, 4, 1)\n", + "Actuals shape: (40, 1)\n", + "Window 0: pred=100.0, actual=100.0, error=0.0\n", + "Window 1: pred=110.0, actual=110.0, error=0.0\n", + "Window 2: pred=120.0, actual=120.0, error=0.0\n", + "Window 3: pred=130.0, actual=130.0, error=0.0\n", + "Window 4: pred=140.0, actual=140.0, error=0.0\n" + ] + } + ], + "source": [ + "context = pipeline.fit(data, output_=7)\n", + "\n", + "y_hat = context.get('y_hat')\n", + "y = context.get('y')\n", + "\n", + "print(f\"Predictions shape: {y_hat.shape}\")\n", + "print(f\"Actuals shape: {y.shape}\")\n", + "\n", + "for i in range(min(5, len(y_hat))):\n", + " pred = 
float(np.array(y_hat[i]).flatten()[0])\n", + " actual = float(np.array(y[i]).flatten()[0])\n", + " print(f\"Window {i}: pred={pred:.1f}, actual={actual:.1f}, error={pred-actual:.1f}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `generate()` Interface\n", + "\n", + "Your `generate()` method must accept these parameters:\n", + "\n", + "| Parameter | Type | Description |\n", + "|-----------|------|-------------|\n", + "| `prompts` | `List[str]` | List of prompt strings |\n", + "| `system_message` | `str` or `None` | Optional system context |\n", + "| `max_tokens` | `int` | Max tokens per response |\n", + "| `temperature` | `float` | Sampling temperature |\n", + "| `n_samples` | `int` | Responses per prompt |\n", + "| `**kwargs` | | Additional parameters |\n", + "\n", + "**Returns:** `List[List[str]]` — for each prompt, a list of `n_samples` response strings.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "orion310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
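
As a quick, offline sanity check of the `generate()` contract that `CustomForecast` and `CustomDetect` expect, here is a minimal sketch. The `EchoLastValueClient` name and its canned behaviour are illustrative only and not part of this change; the snippet assumes only that `sigllm` (with the new `BaseLLMClient`) is importable, and the commented-out lines additionally assume the prompt template loaded by `forecasting/custom.py` is packaged with the library.

```python
from sigllm.primitives.llm_client import BaseLLMClient


class EchoLastValueClient(BaseLLMClient):
    """Stub client: 'forecasts' by repeating the last numeric value in each prompt."""

    def generate(self, prompts, system_message=None, max_tokens=100,
                 temperature=1.0, n_samples=1, **kwargs):
        responses = []
        for prompt in prompts:
            # Prompts end with the formatted sequence, e.g. "... 10, 20, 30 ,"
            tokens = [t for t in prompt.replace(',', ' ').split() if t.isdigit()]
            last = tokens[-1] if tokens else '0'
            # One list of n_samples response strings per prompt.
            responses.append([last] * n_samples)
        return responses  # shape: (len(prompts), n_samples)


client = EchoLastValueClient()
print(client.generate(['Forecast the next value: 10, 20, 30,'], n_samples=2))
# [['30', '30']]

# Plugging it into the new primitive (requires the packaged prompt template):
# from sigllm.primitives.forecasting.custom import CustomForecast
# forecaster = CustomForecast(client=client, steps=1, samples=2)
# print(forecaster.forecast(['10, 20, 30']))
```

Returning a list of `n_samples` strings per prompt matters because the downstream `format_as_integer` and `aggregate_rolling_window` steps of `custom_detector` consume exactly that shape.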