Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions sdk/python/modeltrace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
ModelTrace Python SDK
Provides one-line inference logging for ML frameworks.

Usage:
from modeltrace import ModelTraceClient

client = ModelTraceClient(
rpc_url="https://soroban-testnet.stellar.org",
secret_key="S...",
contract_id="C..."
)

# Register a model
model_id = client.register_model(
name="MyLLM-v1",
version="1.0.0",
weights_hash="sha256:abc123...",
training_data_cid="QmXzY7bN9..."
)

# Log an inference
event_id = client.log_inference(
model_id=model_id,
payload_hash="sha256:def456...",
risk_level="LOW"
)
"""

from typing import Optional, Literal

RiskLevel = Literal["LOW", "MEDIUM", "HIGH", "CRITICAL"]


class ModelRecord:
def __init__(self, model_id: int, name: str, version: str,
weights_hash: str, training_data_cid: str,
registered_at: int, risk_level: str):
self.model_id = model_id
self.name = name
self.version = version
self.weights_hash = weights_hash
self.training_data_cid = training_data_cid
self.registered_at = registered_at
self.risk_level = risk_level


class AuditEvent:
def __init__(self, event_id: int, model_id: int, event_type: str,
payload_hash: str, risk_level: str, timestamp: int):
self.event_id = event_id
self.model_id = model_id
self.event_type = event_type
self.payload_hash = payload_hash
self.risk_level = risk_level
self.timestamp = timestamp


class RegistryStats:
def __init__(self, total_models: int, total_events: int, high_risk_flags: int):
self.total_models = total_models
self.total_events = total_events
self.high_risk_flags = high_risk_flags


class ModelTraceClient:
"""
High-level client for the ModelTrace Audit Registry on Stellar Soroban.
"""

def __init__(
self,
rpc_url: str,
secret_key: str,
contract_id: str,
network_passphrase: str = "Public Global Stellar Network ; September 2015",
):
self.rpc_url = rpc_url
self.contract_id = contract_id
self.network_passphrase = network_passphrase
# stellar-sdk integration point
# In production: initialize stellar_sdk.SorobanServer and Keypair

def register_model(
self,
name: str,
version: str,
weights_hash: str,
training_data_cid: str,
) -> int:
"""
Register an AI model in the on-chain Audit Registry.

Args:
name: Human-readable model name (e.g., "GPT-Audit-v2")
version: Semantic version string (e.g., "2.1.0")
weights_hash: SHA-256 hash of the model weights file
training_data_cid: IPFS CID of the training dataset manifest

Returns:
model_id: On-chain ID for subsequent event logging
"""
raise NotImplementedError(
"Install stellar-sdk and implement Soroban transaction builder. "
"See docs/sdk-setup.md for detailed instructions."
)

def log_inference(
self,
model_id: int,
payload_hash: str,
risk_level: RiskLevel = "LOW",
) -> int:
"""
Log an inference event for an already-registered model.

Args:
model_id: The on-chain ID from register_model()
payload_hash: SHA-256 hash of the inference input+output pair
risk_level: Assessed risk level of this inference

Returns:
event_id: On-chain event ID for this inference log
"""
raise NotImplementedError(
"Install stellar-sdk and implement Soroban transaction builder."
)

def flag_bias(self, model_id: int, evidence_hash: str) -> int:
"""
Flag a bias incident for a model with cryptographic evidence.
"""
raise NotImplementedError(
"Install stellar-sdk and implement Soroban transaction builder."
)

def get_model(self, model_id: int) -> ModelRecord:
"""Query model record from on-chain registry."""
raise NotImplementedError()

def get_stats(self) -> RegistryStats:
"""Get aggregate protocol statistics."""
raise NotImplementedError()
Loading