Skip to content

Commit c2da172

Browse files
committed
feat: implement SEP-1577 sampling with tools support
1 parent 8d09f88 commit c2da172

12 files changed

Lines changed: 1737 additions & 181 deletions

crates/rmcp/src/model.rs

Lines changed: 257 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,152 @@ pub enum Role {
12091209
Assistant,
12101210
}
12111211

1212+
/// Tool selection mode (SEP-1577).
1213+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1214+
#[serde(rename_all = "lowercase")]
1215+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1216+
pub enum ToolChoiceMode {
1217+
/// Model decides whether to use tools
1218+
Auto,
1219+
/// Model must use at least one tool
1220+
Required,
1221+
/// Model must not use tools
1222+
None,
1223+
}
1224+
1225+
impl Default for ToolChoiceMode {
1226+
fn default() -> Self {
1227+
Self::Auto
1228+
}
1229+
}
1230+
1231+
/// Tool choice configuration (SEP-1577).
1232+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
1233+
#[serde(rename_all = "camelCase")]
1234+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1235+
pub struct ToolChoice {
1236+
#[serde(skip_serializing_if = "Option::is_none")]
1237+
pub mode: Option<ToolChoiceMode>,
1238+
}
1239+
1240+
impl ToolChoice {
1241+
pub fn auto() -> Self {
1242+
Self {
1243+
mode: Some(ToolChoiceMode::Auto),
1244+
}
1245+
}
1246+
1247+
pub fn required() -> Self {
1248+
Self {
1249+
mode: Some(ToolChoiceMode::Required),
1250+
}
1251+
}
1252+
1253+
pub fn none() -> Self {
1254+
Self {
1255+
mode: Some(ToolChoiceMode::None),
1256+
}
1257+
}
1258+
}
1259+
1260+
/// Single or array content wrapper (SEP-1577).
1261+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1262+
#[serde(untagged)]
1263+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1264+
pub enum SamplingContent<T> {
1265+
Single(T),
1266+
Multiple(Vec<T>),
1267+
}
1268+
1269+
impl<T> SamplingContent<T> {
1270+
/// Convert to a Vec regardless of whether it's single or multiple
1271+
pub fn into_vec(self) -> Vec<T> {
1272+
match self {
1273+
SamplingContent::Single(item) => vec![item],
1274+
SamplingContent::Multiple(items) => items,
1275+
}
1276+
}
1277+
1278+
/// Check if the content is empty
1279+
pub fn is_empty(&self) -> bool {
1280+
match self {
1281+
SamplingContent::Single(_) => false,
1282+
SamplingContent::Multiple(items) => items.is_empty(),
1283+
}
1284+
}
1285+
1286+
/// Get the number of content items
1287+
pub fn len(&self) -> usize {
1288+
match self {
1289+
SamplingContent::Single(_) => 1,
1290+
SamplingContent::Multiple(items) => items.len(),
1291+
}
1292+
}
1293+
}
1294+
1295+
impl<T> Default for SamplingContent<T> {
1296+
fn default() -> Self {
1297+
SamplingContent::Multiple(Vec::new())
1298+
}
1299+
}
1300+
1301+
impl<T> SamplingContent<T> {
1302+
/// Get the first item if present
1303+
pub fn first(&self) -> Option<&T> {
1304+
match self {
1305+
SamplingContent::Single(item) => Some(item),
1306+
SamplingContent::Multiple(items) => items.first(),
1307+
}
1308+
}
1309+
1310+
/// Iterate over all content items
1311+
pub fn iter(&self) -> impl Iterator<Item = &T> {
1312+
let items: Vec<&T> = match self {
1313+
SamplingContent::Single(item) => vec![item],
1314+
SamplingContent::Multiple(items) => items.iter().collect(),
1315+
};
1316+
items.into_iter()
1317+
}
1318+
}
1319+
1320+
impl SamplingMessageContent {
1321+
/// Get the text content if this is a Text variant
1322+
pub fn as_text(&self) -> Option<&RawTextContent> {
1323+
match self {
1324+
SamplingMessageContent::Text(text) => Some(text),
1325+
_ => None,
1326+
}
1327+
}
1328+
1329+
/// Get the tool use content if this is a ToolUse variant
1330+
pub fn as_tool_use(&self) -> Option<&ToolUseContent> {
1331+
match self {
1332+
SamplingMessageContent::ToolUse(tool_use) => Some(tool_use),
1333+
_ => None,
1334+
}
1335+
}
1336+
1337+
/// Get the tool result content if this is a ToolResult variant
1338+
pub fn as_tool_result(&self) -> Option<&ToolResultContent> {
1339+
match self {
1340+
SamplingMessageContent::ToolResult(tool_result) => Some(tool_result),
1341+
_ => None,
1342+
}
1343+
}
1344+
}
1345+
1346+
impl<T> From<T> for SamplingContent<T> {
1347+
fn from(item: T) -> Self {
1348+
SamplingContent::Single(item)
1349+
}
1350+
}
1351+
1352+
impl<T> From<Vec<T>> for SamplingContent<T> {
1353+
fn from(items: Vec<T>) -> Self {
1354+
SamplingContent::Multiple(items)
1355+
}
1356+
}
1357+
12121358
/// A message in a sampling conversation, containing a role and content.
12131359
///
12141360
/// This represents a single message in a conversation flow, used primarily
@@ -1219,8 +1365,106 @@ pub enum Role {
12191365
pub struct SamplingMessage {
12201366
/// The role of the message sender (User or Assistant)
12211367
pub role: Role,
1222-
/// The actual content of the message (text, image, etc.)
1223-
pub content: Content,
1368+
/// The actual content of the message (text, image, audio, tool use, or tool result)
1369+
pub content: SamplingContent<SamplingMessageContent>,
1370+
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
1371+
pub meta: Option<Meta>,
1372+
}
1373+
1374+
/// Content types for sampling messages (SEP-1577).
1375+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1376+
#[serde(tag = "type", rename_all = "snake_case")]
1377+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1378+
pub enum SamplingMessageContent {
1379+
Text(RawTextContent),
1380+
Image(RawImageContent),
1381+
Audio(RawAudioContent),
1382+
/// Assistant only
1383+
ToolUse(ToolUseContent),
1384+
/// User only
1385+
ToolResult(ToolResultContent),
1386+
}
1387+
1388+
impl SamplingMessageContent {
1389+
/// Create a text content
1390+
pub fn text(text: impl Into<String>) -> Self {
1391+
Self::Text(RawTextContent {
1392+
text: text.into(),
1393+
meta: None,
1394+
})
1395+
}
1396+
1397+
pub fn tool_use(id: impl Into<String>, name: impl Into<String>, input: JsonObject) -> Self {
1398+
Self::ToolUse(ToolUseContent::new(id, name, input))
1399+
}
1400+
1401+
pub fn tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
1402+
Self::ToolResult(ToolResultContent::new(tool_use_id, content))
1403+
}
1404+
}
1405+
1406+
impl SamplingMessage {
1407+
pub fn new(role: Role, content: impl Into<SamplingMessageContent>) -> Self {
1408+
Self {
1409+
role,
1410+
content: SamplingContent::Single(content.into()),
1411+
meta: None,
1412+
}
1413+
}
1414+
1415+
pub fn new_multiple(role: Role, contents: Vec<SamplingMessageContent>) -> Self {
1416+
Self {
1417+
role,
1418+
content: SamplingContent::Multiple(contents),
1419+
meta: None,
1420+
}
1421+
}
1422+
1423+
pub fn user_text(text: impl Into<String>) -> Self {
1424+
Self::new(Role::User, SamplingMessageContent::text(text))
1425+
}
1426+
1427+
pub fn assistant_text(text: impl Into<String>) -> Self {
1428+
Self::new(Role::Assistant, SamplingMessageContent::text(text))
1429+
}
1430+
1431+
pub fn user_tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
1432+
Self::new(
1433+
Role::User,
1434+
SamplingMessageContent::tool_result(tool_use_id, content),
1435+
)
1436+
}
1437+
1438+
pub fn assistant_tool_use(
1439+
id: impl Into<String>,
1440+
name: impl Into<String>,
1441+
input: JsonObject,
1442+
) -> Self {
1443+
Self::new(
1444+
Role::Assistant,
1445+
SamplingMessageContent::tool_use(id, name, input),
1446+
)
1447+
}
1448+
}
1449+
1450+
// Conversion from RawTextContent to SamplingMessageContent
1451+
impl From<RawTextContent> for SamplingMessageContent {
1452+
fn from(text: RawTextContent) -> Self {
1453+
SamplingMessageContent::Text(text)
1454+
}
1455+
}
1456+
1457+
// Conversion from String to SamplingMessageContent (as text)
1458+
impl From<String> for SamplingMessageContent {
1459+
fn from(text: String) -> Self {
1460+
SamplingMessageContent::text(text)
1461+
}
1462+
}
1463+
1464+
impl From<&str> for SamplingMessageContent {
1465+
fn from(text: &str) -> Self {
1466+
SamplingMessageContent::text(text)
1467+
}
12241468
}
12251469

12261470
/// Specifies how much context should be included in sampling requests.
@@ -1281,6 +1525,12 @@ pub struct CreateMessageRequestParams {
12811525
/// Additional metadata for the request
12821526
#[serde(skip_serializing_if = "Option::is_none")]
12831527
pub metadata: Option<Value>,
1528+
/// Tools available for the model to call (SEP-1577)
1529+
#[serde(skip_serializing_if = "Option::is_none")]
1530+
pub tools: Option<Vec<Tool>>,
1531+
/// Tool selection behavior (SEP-1577)
1532+
#[serde(skip_serializing_if = "Option::is_none")]
1533+
pub tool_choice: Option<ToolChoice>,
12841534
}
12851535

12861536
impl RequestParamsMeta for CreateMessageRequestParams {
@@ -1926,6 +2176,7 @@ pub type CallToolRequestParam = CallToolRequestParams;
19262176
/// Request to call a specific tool
19272177
pub type CallToolRequest = Request<CallToolRequestMethod, CallToolRequestParams>;
19282178

2179+
/// Result of sampling/createMessage (SEP-1577).
19292180
/// The result of a sampling/createMessage request containing the generated response.
19302181
///
19312182
/// This structure contains the generated message along with metadata about
@@ -1948,6 +2199,7 @@ impl CreateMessageResult {
19482199
pub const STOP_REASON_END_TURN: &str = "endTurn";
19492200
pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence";
19502201
pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens";
2202+
pub const STOP_REASON_TOOL_USE: &str = "toolUse";
19512203
}
19522204

19532205
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -2476,7 +2728,9 @@ mod tests {
24762728
..
24772729
}) => {
24782730
assert_eq!(capabilities.roots.unwrap().list_changed, Some(true));
2479-
assert_eq!(capabilities.sampling.unwrap().len(), 0);
2731+
let sampling = capabilities.sampling.unwrap();
2732+
assert_eq!(sampling.tools, None);
2733+
assert_eq!(sampling.context, None);
24802734
assert_eq!(client_info.name, "ExampleClient");
24812735
assert_eq!(client_info.version, "1.0.0");
24822736
}

crates/rmcp/src/model/capabilities.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,19 @@ pub struct ElicitationCapability {
172172
pub schema_validation: Option<bool>,
173173
}
174174

175+
/// Sampling capability with optional sub-capabilities (SEP-1577).
176+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
177+
#[serde(rename_all = "camelCase")]
178+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
179+
pub struct SamplingCapability {
180+
/// Support for `tools` and `toolChoice` parameters
181+
#[serde(skip_serializing_if = "Option::is_none")]
182+
pub tools: Option<JsonObject>,
183+
/// Support for `includeContext` (soft-deprecated)
184+
#[serde(skip_serializing_if = "Option::is_none")]
185+
pub context: Option<JsonObject>,
186+
}
187+
175188
///
176189
/// # Builder
177190
/// ```rust
@@ -189,8 +202,9 @@ pub struct ClientCapabilities {
189202
pub experimental: Option<ExperimentalCapabilities>,
190203
#[serde(skip_serializing_if = "Option::is_none")]
191204
pub roots: Option<RootsCapabilities>,
205+
/// Capability for LLM sampling requests (SEP-1577)
192206
#[serde(skip_serializing_if = "Option::is_none")]
193-
pub sampling: Option<JsonObject>,
207+
pub sampling: Option<SamplingCapability>,
194208
/// Capability to handle elicitation requests from servers for interactive user input
195209
#[serde(skip_serializing_if = "Option::is_none")]
196210
pub elicitation: Option<ElicitationCapability>,
@@ -392,7 +406,7 @@ builder! {
392406
ClientCapabilities{
393407
experimental: ExperimentalCapabilities,
394408
roots: RootsCapabilities,
395-
sampling: JsonObject,
409+
sampling: SamplingCapability,
396410
elicitation: ElicitationCapability,
397411
tasks: TasksCapability,
398412
}
@@ -409,6 +423,26 @@ impl<const E: bool, const S: bool, const EL: bool, const TASKS: bool>
409423
}
410424
}
411425

426+
impl<const E: bool, const R: bool, const EL: bool, const TASKS: bool>
427+
ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<E, R, true, EL, TASKS>>
428+
{
429+
/// Enable tool calling in sampling requests
430+
pub fn enable_sampling_tools(mut self) -> Self {
431+
if let Some(c) = self.sampling.as_mut() {
432+
c.tools = Some(JsonObject::default());
433+
}
434+
self
435+
}
436+
437+
/// Enable context inclusion in sampling (soft-deprecated)
438+
pub fn enable_sampling_context(mut self) -> Self {
439+
if let Some(c) = self.sampling.as_mut() {
440+
c.context = Some(JsonObject::default());
441+
}
442+
self
443+
}
444+
}
445+
412446
#[cfg(feature = "elicitation")]
413447
impl<const E: bool, const R: bool, const S: bool, const TASKS: bool>
414448
ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<E, R, S, true, TASKS>>

0 commit comments

Comments
 (0)