Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 0 additions & 13 deletions bindings/node/__tests__/cancellation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,4 @@ describe('JsCancellationToken', () => {
expect(token.isImmediate).toBe(false)
})

it('requestGraceful accepts optional reason string', () => {
const token = new JsCancellationToken()
token.requestGraceful('user requested stop')
expect(token.isCancelled).toBe(true)
expect(token.isGraceful).toBe(true)
})

it('requestImmediate accepts optional reason string', () => {
const token = new JsCancellationToken()
token.requestImmediate('timeout exceeded')
expect(token.isCancelled).toBe(true)
expect(token.isImmediate).toBe(true)
})
})
11 changes: 0 additions & 11 deletions bindings/node/__tests__/types.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {
ContextInjectionRole,
ApprovalDefault,
UserMessageLevel,
Role,
} from '../index.js'

describe('enum types', () => {
Expand Down Expand Up @@ -51,14 +50,4 @@ describe('enum types', () => {
})
})

describe('Role', () => {
it('has all expected variants with correct string values', () => {
expect(Role.System).toBe('System')
expect(Role.Developer).toBe('Developer')
expect(Role.User).toBe('User')
expect(Role.Assistant).toBe('Assistant')
expect(Role.Function).toBe('Function')
expect(Role.Tool).toBe('Tool')
})
})
})
25 changes: 2 additions & 23 deletions bindings/node/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,6 @@ export const enum UserMessageLevel {
Warning = 'Warning',
Error = 'Error'
}
export const enum Role {
System = 'System',
Developer = 'Developer',
User = 'User',
Assistant = 'Assistant',
Function = 'Function',
Tool = 'Tool'
}
export interface JsToolResult {
success: boolean
output?: string
error?: string
}
export interface JsToolSpec {
name: string
description?: string
parametersJson: string
}
export interface JsHookResult {
action: HookAction
reason?: string
Expand All @@ -63,9 +45,6 @@ export interface JsHookResult {
approvalTimeout?: number
approvalDefault?: ApprovalDefault
}
export interface JsSessionConfig {
configJson: string
}
/** Structured error object returned to JS with a typed `code` property. */
export interface JsAmplifierError {
code: string
Expand Down Expand Up @@ -134,8 +113,8 @@ export declare class JsCancellationToken {
get isCancelled(): boolean
get isGraceful(): boolean
get isImmediate(): boolean
requestGraceful(reason?: string | undefined | null): void
requestImmediate(reason?: string | undefined | null): void
requestGraceful(): void
requestImmediate(): void
reset(): void
}
/**
Expand Down
3 changes: 1 addition & 2 deletions bindings/node/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,14 @@ if (!nativeBinding) {
throw new Error(`Failed to load native binding`)
}

const { hello, HookAction, SessionState, ContextInjectionRole, ApprovalDefault, UserMessageLevel, Role, JsCancellationToken, JsHookRegistry, JsCoordinator, JsAmplifierSession, JsToolBridge, amplifierErrorToJs, resolveModule, loadWasmFromPath } = nativeBinding
const { hello, HookAction, SessionState, ContextInjectionRole, ApprovalDefault, UserMessageLevel, JsCancellationToken, JsHookRegistry, JsCoordinator, JsAmplifierSession, JsToolBridge, amplifierErrorToJs, resolveModule, loadWasmFromPath } = nativeBinding

module.exports.hello = hello
module.exports.HookAction = HookAction
module.exports.SessionState = SessionState
module.exports.ContextInjectionRole = ContextInjectionRole
module.exports.ApprovalDefault = ApprovalDefault
module.exports.UserMessageLevel = UserMessageLevel
module.exports.Role = Role
module.exports.JsCancellationToken = JsCancellationToken
module.exports.JsHookRegistry = JsHookRegistry
module.exports.JsCoordinator = JsCoordinator
Expand Down
36 changes: 5 additions & 31 deletions bindings/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,6 @@ pub enum UserMessageLevel {
Error,
}

#[napi(string_enum)]
pub enum Role {
System,
Developer,
User,
Assistant,
Function,
Tool,
}

// ---------------------------------------------------------------------------
// Bidirectional From conversions: HookAction <-> amplifier_core::models::HookAction
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -215,20 +205,6 @@ impl From<ApprovalDefault> for core_models::ApprovalDefault {
// Structs — exported as TypeScript interfaces via #[napi(object)]
// ---------------------------------------------------------------------------

#[napi(object)]
pub struct JsToolResult {
pub success: bool,
pub output: Option<String>,
pub error: Option<String>,
}

#[napi(object)]
pub struct JsToolSpec {
pub name: String,
pub description: Option<String>,
pub parameters_json: String,
}

#[napi(object)]
pub struct JsHookResult {
pub action: HookAction,
Expand All @@ -245,11 +221,6 @@ pub struct JsHookResult {
pub approval_default: Option<ApprovalDefault>,
}

#[napi(object)]
pub struct JsSessionConfig {
pub config_json: String,
}

// ---------------------------------------------------------------------------
// Classes — exported as TypeScript classes via #[napi]
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -298,12 +269,12 @@ impl JsCancellationToken {
}

#[napi]
pub fn request_graceful(&self, _reason: Option<String>) {
pub fn request_graceful(&self) {
self.inner.request_graceful();
}

#[napi]
pub fn request_immediate(&self, _reason: Option<String>) {
pub fn request_immediate(&self) {
self.inner.request_immediate();
}

Expand Down Expand Up @@ -937,6 +908,9 @@ pub fn resolve_module(path: String) -> Result<JsModuleManifest> {
};

let (artifact_type, artifact_path, endpoint, package_name) = match &manifest.artifact {
amplifier_core::module_resolver::ModuleArtifact::WasmPath(path) => {
("wasm", Some(path.to_string_lossy().to_string()), None, None)
}
amplifier_core::module_resolver::ModuleArtifact::WasmBytes { path, .. } => {
("wasm", Some(path.to_string_lossy().to_string()), None, None)
}
Expand Down
33 changes: 4 additions & 29 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2644,6 +2644,10 @@ fn resolve_module(py: Python<'_>, path: String) -> PyResult<Py<PyDict>> {
dict.set_item("module_type", type_str)?;

match &manifest.artifact {
amplifier_core::module_resolver::ModuleArtifact::WasmPath(path) => {
dict.set_item("artifact_type", "wasm")?;
dict.set_item("artifact_path", path.to_string_lossy().as_ref())?;
}
amplifier_core::module_resolver::ModuleArtifact::WasmBytes { path, .. } => {
dict.set_item("artifact_type", "wasm")?;
dict.set_item("artifact_path", path.to_string_lossy().as_ref())?;
Expand All @@ -2661,31 +2665,6 @@ fn resolve_module(py: Python<'_>, path: String) -> PyResult<Py<PyDict>> {
Ok(dict.unbind())
}

/// Ensure WASM bytes are loaded from disk when `amplifier.toml` deferred loading.
///
/// `parse_amplifier_toml` stores `bytes: Vec::new()` with a path, deferring
/// the actual file read to the transport layer. This helper fills in the
/// bytes before handing the manifest to `load_module`.
fn ensure_wasm_bytes_loaded(
manifest: &mut amplifier_core::module_resolver::ModuleManifest,
) -> PyResult<()> {
if let amplifier_core::module_resolver::ModuleArtifact::WasmBytes {
ref mut bytes,
ref path,
} = manifest.artifact
{
if bytes.is_empty() && path.is_file() {
*bytes = std::fs::read(path).map_err(|e| {
PyErr::new::<PyRuntimeError, _>(format!(
"Failed to read WASM bytes from {}: {e}",
path.display()
))
})?;
}
}
Ok(())
}

/// Load a WASM module from a resolved manifest path.
///
/// Returns a dict with "status" = "loaded" and "module_type" on success.
Expand All @@ -2703,8 +2682,6 @@ fn load_wasm_from_path(py: Python<'_>, path: String) -> PyResult<Py<PyDict>> {
)));
}

ensure_wasm_bytes_loaded(&mut manifest)?;

let engine = amplifier_core::wasm_engine::WasmEngine::new().map_err(|e| {
PyErr::new::<PyRuntimeError, _>(format!("WASM engine creation failed: {e}"))
})?;
Expand Down Expand Up @@ -2856,8 +2833,6 @@ fn load_and_mount_wasm(
)));
}

ensure_wasm_bytes_loaded(&mut manifest)?;

let engine = amplifier_core::wasm_engine::WasmEngine::new().map_err(|e| {
PyErr::new::<PyRuntimeError, _>(format!("WASM engine creation failed: {e}"))
})?;
Expand Down
3 changes: 2 additions & 1 deletion crates/amplifier-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ tonic = "0.12"
tokio-stream = { version = "0.1", features = ["net"] }
wasmtime = { version = "42", optional = true, features = ["component-model"] }
wasmtime-wasi = { version = "42", optional = true }
sha2 = { version = "0.10", optional = true }

[features]
default = []
wasm = ["wasmtime", "wasmtime-wasi"]
wasm = ["wasmtime", "wasmtime-wasi", "sha2"]

[dev-dependencies]
tempfile = "3"
Expand Down
41 changes: 41 additions & 0 deletions crates/amplifier-core/src/bridges/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,44 @@ pub fn create_wasm_engine() -> Result<Arc<Engine>, Box<dyn std::error::Error + S

Ok(engine)
}

// ── Shared WASM function lookup ───────────────────────────────────────────────

/// Look up a typed function export from a component instance.
///
/// Component Model exports may be at the root level or nested inside an
/// exported interface instance. This helper tries:
/// 1. Direct root-level export by `func_name`
/// 2. Nested inside the `interface_name` exported instance
///
/// Shared by all WASM bridge modules to avoid duplicating identical lookup
/// logic across `wasm_tool`, `wasm_provider`, `wasm_context`, `wasm_hook`,
/// `wasm_orchestrator`, and `wasm_approval`.
#[cfg(feature = "wasm")]
pub(crate) fn get_typed_func<Params, Results>(
instance: &wasmtime::component::Instance,
store: &mut wasmtime::Store<wasm_tool::WasmState>,
func_name: &str,
interface_name: &str,
) -> Result<wasmtime::component::TypedFunc<Params, Results>, Box<dyn std::error::Error + Send + Sync>>
where
Params: wasmtime::component::Lower + wasmtime::component::ComponentNamedList,
Results: wasmtime::component::Lift + wasmtime::component::ComponentNamedList,
{
// Try direct root-level export first.
if let Ok(f) = instance.get_typed_func::<Params, Results>(&mut *store, func_name) {
return Ok(f);
}
// Try nested inside interface-exported instance.
let iface_idx = instance
.get_export_index(&mut *store, None, interface_name)
.ok_or_else(|| format!("export instance '{interface_name}' not found"))?;
let func_idx = instance
.get_export_index(&mut *store, Some(&iface_idx), func_name)
.ok_or_else(|| format!("export function '{func_name}' not found in '{interface_name}'"))?;
instance
.get_typed_func::<Params, Results>(&mut *store, &func_idx)
.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> {
format!("typed func lookup failed for '{func_name}': {e}").into()
})
}
45 changes: 5 additions & 40 deletions crates/amplifier-core/src/bridges/wasm_approval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,54 +13,17 @@ use std::pin::Pin;
use std::sync::Arc;

use wasmtime::component::Component;
use wasmtime::{Engine, Store};
use wasmtime::Engine;

use crate::errors::{AmplifierError, SessionError};
use crate::models::{ApprovalRequest, ApprovalResponse};
use crate::traits::ApprovalProvider;

use super::wasm_tool::{create_linker_and_store, WasmState};
use super::wasm_tool::create_linker_and_store;

/// The WIT interface name used by `cargo component` for approval provider exports.
const INTERFACE_NAME: &str = "amplifier:modules/approval-provider@1.0.0";

/// Convenience alias for the wasmtime typed function handle: takes (bytes) → result(bytes, string).
type RequestApprovalFunc = wasmtime::component::TypedFunc<(Vec<u8>,), (Result<Vec<u8>, String>,)>;

/// Shorthand for the fallible return type used by helper functions.
type WasmResult<T> = Result<T, Box<dyn std::error::Error + Send + Sync>>;

/// Look up the `request-approval` typed function export from a component instance.
///
/// Tries:
/// 1. Direct root-level export `"request-approval"`
/// 2. Nested inside the [`INTERFACE_NAME`] exported instance
fn get_request_approval_func(
instance: &wasmtime::component::Instance,
store: &mut Store<WasmState>,
) -> WasmResult<RequestApprovalFunc> {
// Try direct root-level export first.
if let Ok(f) = instance
.get_typed_func::<(Vec<u8>,), (Result<Vec<u8>, String>,)>(&mut *store, "request-approval")
{
return Ok(f);
}

// Try nested inside the interface-exported instance.
let iface_idx = instance
.get_export_index(&mut *store, None, INTERFACE_NAME)
.ok_or_else(|| format!("export instance '{INTERFACE_NAME}' not found"))?;
let func_idx = instance
.get_export_index(&mut *store, Some(&iface_idx), "request-approval")
.ok_or_else(|| {
format!("export function 'request-approval' not found in '{INTERFACE_NAME}'")
})?;
let func = instance
.get_typed_func::<(Vec<u8>,), (Result<Vec<u8>, String>,)>(&mut *store, &func_idx)
.map_err(|e| format!("typed func lookup failed for 'request-approval': {e}"))?;
Ok(func)
}

/// Helper: call the `request-approval` export on a fresh component instance.
///
/// The request bytes must be a JSON-serialized `ApprovalRequest`.
Expand All @@ -72,7 +35,9 @@ fn call_request_approval(
let (linker, mut store) = create_linker_and_store(engine, &super::WasmLimits::default())?;
let instance = linker.instantiate(&mut store, component)?;

let func = get_request_approval_func(&instance, &mut store)?;
let func = super::get_typed_func::<(Vec<u8>,), (Result<Vec<u8>, String>,)>(
&instance, &mut store, "request-approval", INTERFACE_NAME,
)?;
let (result,) = func.call(&mut store, (request_bytes,))?;
match result {
Ok(bytes) => Ok(bytes),
Expand Down
Loading
Loading