Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions arize_toolkit/cli/alert_integrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import click

from arize_toolkit.cli.client_factory import get_client
from arize_toolkit.cli.output import print_result

PROVIDER_CHOICES = ["slack", "pagerduty", "opsgenie"]


@click.group("alert-integrations")
def alert_integrations_group():
"""Manage alert integrations (Slack, PagerDuty, OpsGenie)."""
pass


@alert_integrations_group.command("list")
@click.option(
"--provider",
type=click.Choice(PROVIDER_CHOICES),
default=None,
help="Filter by provider.",
)
@click.option("--search", default=None, help="Search by integration name.")
@click.pass_context
def integrations_list(ctx, provider, search):
"""List alert integrations for the organization."""
client = get_client(ctx)
data = client.list_integrations(provider_name=provider, search=search)
print_result(
data,
columns=["id", "name", "providerName", "channelName", "alertSeverity"],
title="Alert Integrations",
json_mode=ctx.obj["json_mode"],
)
2 changes: 2 additions & 0 deletions arize_toolkit/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import click

from arize_toolkit import __version__
from arize_toolkit.cli.alert_integrations import alert_integrations_group
from arize_toolkit.cli.config_cmd import config_group, get_profile, update_profile
from arize_toolkit.cli.custom_metrics import custom_metrics_group
from arize_toolkit.cli.dashboards import dashboards_group
Expand Down Expand Up @@ -80,6 +81,7 @@ def persist_client_state(ctx, *args, **kwargs):
update_profile(profile_name, **updates)


cli.add_command(alert_integrations_group)
cli.add_command(config_group)
cli.add_command(spaces_group)
cli.add_command(orgs_group)
Expand Down
7 changes: 7 additions & 0 deletions arize_toolkit/cli/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _common_monitor_options(f):
f = click.option("--operator2", type=click.Choice(OPERATOR_CHOICES), default=None, help="Second operator (double mode).")(f)
f = click.option("--email", multiple=True, help="Email addresses for notifications.")(f)
f = click.option("--integration-key-id", multiple=True, help="Integration key IDs for notifications.")(f)
f = click.option("--integration-name", multiple=True, help="Integration names for notifications (resolved to IDs).")(f)
return f


Expand Down Expand Up @@ -88,6 +89,7 @@ def monitors_create_performance(
operator2,
email,
integration_key_id,
integration_name,
):
"""Create a performance monitor."""
client = get_client(ctx)
Expand All @@ -108,6 +110,7 @@ def monitors_create_performance(
operator2=operator2,
email_addresses=list(email) if email else None,
integration_key_ids=list(integration_key_id) if integration_key_id else None,
integration_names=list(integration_name) if integration_name else None,
)
print_url(url, label="Created monitor")

Expand Down Expand Up @@ -138,6 +141,7 @@ def monitors_create_drift(
operator2,
email,
integration_key_id,
integration_name,
):
"""Create a drift monitor."""
client = get_client(ctx)
Expand All @@ -158,6 +162,7 @@ def monitors_create_drift(
operator2=operator2,
email_addresses=list(email) if email else None,
integration_key_ids=list(integration_key_id) if integration_key_id else None,
integration_names=list(integration_name) if integration_name else None,
)
print_url(url, label="Created monitor")

Expand Down Expand Up @@ -190,6 +195,7 @@ def monitors_create_data_quality(
operator2,
email,
integration_key_id,
integration_name,
):
"""Create a data quality monitor."""
client = get_client(ctx)
Expand All @@ -211,6 +217,7 @@ def monitors_create_data_quality(
operator2=operator2,
email_addresses=list(email) if email else None,
integration_key_ids=list(integration_key_id) if integration_key_id else None,
integration_names=list(integration_name) if integration_name else None,
)
print_url(url, label="Created monitor")

Expand Down
94 changes: 94 additions & 0 deletions arize_toolkit/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
GetEvaluatorQuery,
GetEvaluatorsQuery,
)
from arize_toolkit.queries.integration_queries import GetIntegrationKeysQuery
from arize_toolkit.queries.llm_utils_queries import (
CreateAnnotationMutation,
CreatePromptMutation,
Expand Down Expand Up @@ -1810,6 +1811,66 @@ def _resolve_llm_integration_id(self, llm_integration_name: Optional[str] = None
available = [r.name for r in results]
raise ValueError(f"No LLM integration found with name '{llm_integration_name}'. " f"Available integrations: {available}")

def list_integrations(
self,
provider_name: Optional[str] = None,
search: Optional[str] = None,
) -> List[dict]:
"""List alert integration keys (slack, pagerduty, opsgenie) for the organization.

Args:
provider_name (Optional[str]): Filter by provider name (e.g. "slack", "pagerduty", "opsgenie")
search (Optional[str]): Search by integration name

Returns:
List[dict]: List of integration keys with id, name, providerName, channelName, alertSeverity, etc.

Example:
```python
# List all integrations
integrations = client.list_integrations()

# List only slack integrations
slack_integrations = client.list_integrations(provider_name="slack")
```
"""
results = GetIntegrationKeysQuery.run_graphql_query_to_list(
self._graphql_client,
organization_id=self.org_id,
providerName=provider_name,
search=search,
)
return [r.to_dict() for r in results]

def _resolve_integration_key_ids(
self,
integration_names: Union[str, List[str]],
) -> List[str]:
"""Resolve alert integration names to their IDs.

Args:
integration_names: Name(s) of alert integrations (slack, pagerduty, opsgenie) to look up.

Returns:
List[str]: The IDs of the matching integrations.

Raises:
ValueError: If any name does not match an existing integration.
"""
if isinstance(integration_names, str):
integration_names = [integration_names]
results = GetIntegrationKeysQuery.run_graphql_query_to_list(
self._graphql_client,
organization_id=self.org_id,
)
available = {r.name: r.id for r in results}
resolved_ids = []
for name in integration_names:
if name not in available:
raise ValueError(f"No alert integration found with name '{name}'. " f"Available integrations: {list(available.keys())}")
resolved_ids.append(available[name])
return resolved_ids

def get_evaluators(
self,
search: Optional[str] = None,
Expand Down Expand Up @@ -2864,6 +2925,7 @@ def create_performance_monitor(
std_dev_multiplier2: Optional[float] = None,
email_addresses: Optional[Union[str, List[str]]] = None,
integration_key_ids: Optional[Union[str, List[str]]] = None,
integration_names: Optional[Union[str, List[str]]] = None,
filters: Optional[Union[List[Dict], List[DimensionFilterInput]]] = None,
) -> str:
"""Creates a new performance metric monitor for a model.
Expand Down Expand Up @@ -2894,6 +2956,8 @@ def create_performance_monitor(
std_dev_multiplier2 (Optional[float]): Standard deviation multiplier for the second threshold (only used if threshold_mode is "double")
email_addresses (Optional[Union[str, List[str]]]): Email address(es) to notify when the monitor is triggered
integration_key_ids (Optional[Union[str, List[str]]]): ID(s) of integration key(s) to notify when the monitor is triggered
integration_names (Optional[Union[str, List[str]]]): Name(s) of alert integrations (slack, pagerduty, opsgenie) to notify.
These are resolved to integration key IDs automatically. Use this as an alternative to integration_key_ids.
filters (Optional[Union[List[Dict], List[DimensionFilterInput]]]): Filters to apply to the monitor
- filterType (FilterRowType): Type of filter to apply (featureLabel, tagLabel, actuals, predictionScore, etc)
- operator (ComparisonOperator): Comparison operator to apply (equals, notEquals, greaterThan, lessThan, greaterThanOrEqual, lessThanOrEqual)
Expand All @@ -2908,6 +2972,14 @@ def create_performance_monitor(
"""
if performance_metric is None and custom_metric_id is None:
raise ValueError("Either performance_metric or custom_metric_id must be provided")
if integration_names:
resolved_ids = self._resolve_integration_key_ids(integration_names)
if integration_key_ids:
if isinstance(integration_key_ids, str):
integration_key_ids = [integration_key_ids]
integration_key_ids = list(integration_key_ids) + resolved_ids
else:
integration_key_ids = resolved_ids
contacts = []
if email_addresses:
if isinstance(email_addresses, str):
Expand Down Expand Up @@ -2985,6 +3057,7 @@ def create_drift_monitor(
std_dev_multiplier2: Optional[float] = 2.0,
email_addresses: Optional[Union[str, List[str]]] = None,
integration_key_ids: Optional[Union[str, List[str]]] = None,
integration_names: Optional[Union[str, List[str]]] = None,
filters: Optional[Union[List[Dict], List[DimensionFilterInput]]] = None,
) -> str:
"""Creates a new drift monitor for a model.
Expand Down Expand Up @@ -3013,6 +3086,8 @@ def create_drift_monitor(
std_dev_multiplier2 (Optional[float]): Standard deviation multiplier for the second threshold (default is 2.0 if threshold_mode is "double" and a threshold2 is not provided)
email_addresses (Optional[List[str]]): Email addresses to notify when the monitor is triggered
integration_key_ids (Optional[List[str]]): IDs of integration keys to notify when the monitor is triggered
integration_names (Optional[Union[str, List[str]]]): Name(s) of alert integrations (slack, pagerduty, opsgenie) to notify.
These are resolved to integration key IDs automatically. Use this as an alternative to integration_key_ids.
filters (Optional[Union[List[Dict], List[DimensionFilterInput]]]): Filters to apply to the monitor
- filterType (FilterRowType): Type of filter to apply (featureLabel, tagLabel, actuals, predictionScore, etc)
- operator (ComparisonOperator): Comparison operator to apply (equals, notEquals, greaterThan, lessThan, greaterThanOrEqual, lessThanOrEqual)
Expand All @@ -3026,6 +3101,14 @@ def create_drift_monitor(
ArizeAPIException: If monitor creation fails or there is an API error

"""
if integration_names:
resolved_ids = self._resolve_integration_key_ids(integration_names)
if integration_key_ids:
if isinstance(integration_key_ids, str):
integration_key_ids = [integration_key_ids]
integration_key_ids = list(integration_key_ids) + resolved_ids
else:
integration_key_ids = resolved_ids
contacts = []
if email_addresses:
if isinstance(email_addresses, str):
Expand Down Expand Up @@ -3102,6 +3185,7 @@ def create_data_quality_monitor(
std_dev_multiplier2: Optional[float] = 2.0,
email_addresses: Optional[Union[str, List[str]]] = None,
integration_key_ids: Optional[Union[str, List[str]]] = None,
integration_names: Optional[Union[str, List[str]]] = None,
filters: Optional[Union[List[Dict], List[DimensionFilterInput]]] = None,
) -> str:
"""Creates a new data quality monitor for a model.
Expand Down Expand Up @@ -3131,6 +3215,8 @@ def create_data_quality_monitor(
std_dev_multiplier2 (Optional[float]): Standard deviation multiplier for the second threshold (default is 2.0 if threshold_mode is "double" and a threshold2 is not provided)
email_addresses (Optional[Union[str, List[str]]]): Email address(es) to notify when the monitor is triggered
integration_key_ids (Optional[Union[str, List[str]]]): ID(s) of integration key(s) to notify when the monitor is triggered
integration_names (Optional[Union[str, List[str]]]): Name(s) of alert integrations (slack, pagerduty, opsgenie) to notify.
These are resolved to integration key IDs automatically. Use this as an alternative to integration_key_ids.
filters (Optional[Union[List[Dict], List[DimensionFilterInput]]]): Filters to apply to the monitor
- filterType (FilterRowType): Type of filter to apply (featureLabel, tagLabel, actuals, predictionScore, etc)
- operator (ComparisonOperator): Comparison operator to apply (equals, notEquals, greaterThan, lessThan, greaterThanOrEqual, lessThanOrEqual)
Expand All @@ -3144,6 +3230,14 @@ def create_data_quality_monitor(
ArizeAPIException: If monitor creation fails or there is an API error

"""
if integration_names:
resolved_ids = self._resolve_integration_key_ids(integration_names)
if integration_key_ids:
if isinstance(integration_key_ids, str):
integration_key_ids = [integration_key_ids]
integration_key_ids = list(integration_key_ids) + resolved_ids
else:
integration_key_ids = resolved_ids
contacts = []
if email_addresses:
if isinstance(email_addresses, str):
Expand Down
17 changes: 14 additions & 3 deletions arize_toolkit/models/monitor_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,28 @@

from arize_toolkit.models.base_models import BaseNode, Dimension, DimensionFilterInput, DimensionValue, User
from arize_toolkit.models.custom_metrics_models import CustomMetric
from arize_toolkit.types import ComparisonOperator, DataQualityMetric, DimensionCategory, DriftMetric, FilterRowType, ModelEnvironment, MonitorCategory, PerformanceMetric
from arize_toolkit.types import (
ComparisonOperator,
DataQualityMetric,
DimensionCategory,
DriftMetric,
FilterRowType,
IntegrationAlertSeverity,
IntegrationProvider,
ModelEnvironment,
MonitorCategory,
PerformanceMetric,
)
from arize_toolkit.utils import GraphQLModel

## Monitor GraphQL Models ##


class IntegrationKey(BaseNode):
providerName: Literal["slack", "pagerduty", "opsgenie"]
providerName: IntegrationProvider
createdAt: Optional[datetime] = Field(default=None)
channelName: Optional[str] = Field(default=None)
alertSeverity: Optional[str] = Field(default=None)
alertSeverity: Optional[IntegrationAlertSeverity] = Field(default=None)


class MonitorContact(GraphQLModel):
Expand Down
42 changes: 42 additions & 0 deletions arize_toolkit/queries/integration_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import List, Optional, Tuple

from arize_toolkit.models.monitor_models import IntegrationKey
from arize_toolkit.queries.basequery import ArizeAPIException, BaseQuery, BaseResponse, BaseVariables


class GetIntegrationKeysQuery(BaseQuery):
graphql_query = (
"""
query getIntegrationKeys($organization_id: ID!, $providerName: IntegrationProvider, $search: String) {
node(id: $organization_id) {
... on AccountOrganization {
integrations(providerName: $providerName, search: $search) { """
+ IntegrationKey.to_graphql_fields()
+ """
}
}
}
}
"""
)
query_description = "Get all alert integration keys for an organization"

class Variables(BaseVariables):
organization_id: str
providerName: Optional[str] = None
search: Optional[str] = None

class QueryException(ArizeAPIException):
message: str = "Error getting integration keys"

class QueryResponse(IntegrationKey):
pass

@classmethod
def _parse_graphql_result(cls, result: dict) -> Tuple[List[BaseResponse], bool, Optional[str]]:
integrations = result.get("node", {}).get("integrations", [])
return (
[cls.QueryResponse(**integration) for integration in integrations],
False,
None,
)
22 changes: 22 additions & 0 deletions arize_toolkit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,28 @@ class PromptVersionInputVariableFormatEnum(InputValidationEnum):
MUSTACHE = "MUSTACHE", "Mustache", "mustache", "{{}}"


class IntegrationProvider(InputValidationEnum):
"""Alert integration providers for monitor notifications (slack, pagerduty, opsgenie)"""

slack = "slack", "Slack", "SLACK"
pagerduty = "pagerduty", "PagerDuty", "Pagerduty", "PAGERDUTY"
opsgenie = "opsgenie", "OpsGenie", "Opsgenie", "OPSGENIE"


class IntegrationAlertSeverity(InputValidationEnum):
"""Alert severity levels for PagerDuty and OpsGenie integrations"""

opsgenieP1 = "opsgenieP1", "OpsGenie P1", "P1"
opsgenieP2 = "opsgenieP2", "OpsGenie P2", "P2"
opsgenieP3 = "opsgenieP3", "OpsGenie P3", "P3"
opsgenieP4 = "opsgenieP4", "OpsGenie P4", "P4"
opsgenieP5 = "opsgenieP5", "OpsGenie P5", "P5"
pagerdutycritical = "pagerdutycritical", "PagerDuty Critical", "critical"
pagerdutyerror = "pagerdutyerror", "PagerDuty Error", "error"
pagerdutywarning = "pagerdutywarning", "PagerDuty Warning", "warning"
pagerdutyinfo = "pagerdutyinfo", "PagerDuty Info", "info"


class LLMIntegrationProvider(InputValidationEnum):
"""The LLM provider used for execution with the prompt"""

Expand Down
Loading
Loading