Skip to content

Commit 4d519f9

Browse files
committed
feat: implement SEP-1577 sampling with tools support
1 parent 8d09f88 commit 4d519f9

12 files changed

Lines changed: 1787 additions & 141 deletions

crates/rmcp/src/model.rs

Lines changed: 301 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,18 +1209,292 @@ pub enum Role {
12091209
Assistant,
12101210
}
12111211

1212+
/// Tool selection mode for sampling requests (SEP-1577).
1213+
///
1214+
/// Controls how the model should handle tool calling during message generation.
1215+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1216+
#[serde(rename_all = "lowercase")]
1217+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1218+
pub enum ToolChoiceMode {
1219+
/// Let the model decide whether to call tools
1220+
Auto,
1221+
/// The model must call at least one tool
1222+
Required,
1223+
/// The model must not call any tools
1224+
None,
1225+
}
1226+
1227+
impl Default for ToolChoiceMode {
1228+
fn default() -> Self {
1229+
Self::Auto
1230+
}
1231+
}
1232+
1233+
/// Tool choice configuration for sampling requests (SEP-1577).
1234+
///
1235+
/// Controls how the model should select and use tools when generating responses.
1236+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
1237+
#[serde(rename_all = "camelCase")]
1238+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1239+
pub struct ToolChoice {
1240+
/// The mode for tool selection
1241+
#[serde(skip_serializing_if = "Option::is_none")]
1242+
pub mode: Option<ToolChoiceMode>,
1243+
}
1244+
1245+
impl ToolChoice {
1246+
/// Create a new ToolChoice with auto mode
1247+
pub fn auto() -> Self {
1248+
Self {
1249+
mode: Some(ToolChoiceMode::Auto),
1250+
}
1251+
}
1252+
1253+
/// Create a new ToolChoice with required mode
1254+
pub fn required() -> Self {
1255+
Self {
1256+
mode: Some(ToolChoiceMode::Required),
1257+
}
1258+
}
1259+
1260+
/// Create a new ToolChoice with none mode
1261+
pub fn none() -> Self {
1262+
Self {
1263+
mode: Some(ToolChoiceMode::None),
1264+
}
1265+
}
1266+
}
1267+
1268+
/// Content for a sampling message that can be a single item or an array (SEP-1577).
1269+
///
1270+
/// This wrapper type handles the fact that content can be either a single
1271+
/// content block or an array of content blocks, depending on the context.
1272+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1273+
#[serde(untagged)]
1274+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1275+
pub enum SamplingContent<T> {
1276+
/// A single content block
1277+
Single(T),
1278+
/// Multiple content blocks
1279+
Multiple(Vec<T>),
1280+
}
1281+
1282+
impl<T> SamplingContent<T> {
1283+
/// Convert to a Vec regardless of whether it's single or multiple
1284+
pub fn into_vec(self) -> Vec<T> {
1285+
match self {
1286+
SamplingContent::Single(item) => vec![item],
1287+
SamplingContent::Multiple(items) => items,
1288+
}
1289+
}
1290+
1291+
/// Check if the content is empty
1292+
pub fn is_empty(&self) -> bool {
1293+
match self {
1294+
SamplingContent::Single(_) => false,
1295+
SamplingContent::Multiple(items) => items.is_empty(),
1296+
}
1297+
}
1298+
1299+
/// Get the number of content items
1300+
pub fn len(&self) -> usize {
1301+
match self {
1302+
SamplingContent::Single(_) => 1,
1303+
SamplingContent::Multiple(items) => items.len(),
1304+
}
1305+
}
1306+
}
1307+
1308+
impl<T> Default for SamplingContent<T> {
1309+
fn default() -> Self {
1310+
SamplingContent::Multiple(Vec::new())
1311+
}
1312+
}
1313+
1314+
impl<T> SamplingContent<T> {
1315+
/// Get the first item if present
1316+
pub fn first(&self) -> Option<&T> {
1317+
match self {
1318+
SamplingContent::Single(item) => Some(item),
1319+
SamplingContent::Multiple(items) => items.first(),
1320+
}
1321+
}
1322+
1323+
/// Iterate over all content items
1324+
pub fn iter(&self) -> impl Iterator<Item = &T> {
1325+
let items: Vec<&T> = match self {
1326+
SamplingContent::Single(item) => vec![item],
1327+
SamplingContent::Multiple(items) => items.iter().collect(),
1328+
};
1329+
items.into_iter()
1330+
}
1331+
}
1332+
1333+
impl SamplingMessageContent {
1334+
/// Get the text content if this is a Text variant
1335+
pub fn as_text(&self) -> Option<&RawTextContent> {
1336+
match self {
1337+
SamplingMessageContent::Text(text) => Some(text),
1338+
_ => None,
1339+
}
1340+
}
1341+
1342+
/// Get the tool use content if this is a ToolUse variant
1343+
pub fn as_tool_use(&self) -> Option<&ToolUseContent> {
1344+
match self {
1345+
SamplingMessageContent::ToolUse(tool_use) => Some(tool_use),
1346+
_ => None,
1347+
}
1348+
}
1349+
1350+
/// Get the tool result content if this is a ToolResult variant
1351+
pub fn as_tool_result(&self) -> Option<&ToolResultContent> {
1352+
match self {
1353+
SamplingMessageContent::ToolResult(tool_result) => Some(tool_result),
1354+
_ => None,
1355+
}
1356+
}
1357+
}
1358+
1359+
impl<T> From<T> for SamplingContent<T> {
1360+
fn from(item: T) -> Self {
1361+
SamplingContent::Single(item)
1362+
}
1363+
}
1364+
1365+
impl<T> From<Vec<T>> for SamplingContent<T> {
1366+
fn from(items: Vec<T>) -> Self {
1367+
SamplingContent::Multiple(items)
1368+
}
1369+
}
1370+
12121371
/// A message in a sampling conversation, containing a role and content.
12131372
///
12141373
/// This represents a single message in a conversation flow, used primarily
12151374
/// in LLM sampling requests where the conversation history is important
12161375
/// for generating appropriate responses.
1376+
///
1377+
/// Per SEP-1577, the content can be a single content block or an array of
1378+
/// content blocks. For user messages, content can include tool results.
1379+
/// For assistant messages, content can include tool use requests.
12171380
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
12181381
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
12191382
pub struct SamplingMessage {
12201383
/// The role of the message sender (User or Assistant)
12211384
pub role: Role,
1222-
/// The actual content of the message (text, image, etc.)
1223-
pub content: Content,
1385+
/// The actual content of the message (text, image, tool use/result, etc.)
1386+
/// Can be a single content or an array of contents (SEP-1577)
1387+
pub content: SamplingContent<SamplingMessageContent>,
1388+
}
1389+
1390+
/// Content types for sampling messages (SEP-1577).
1391+
///
1392+
/// This enum represents all possible content types that can appear in
1393+
/// sampling messages. The appropriate content types depend on the role:
1394+
/// - User messages: text, image, audio, tool_result
1395+
/// - Assistant messages: text, image, audio, tool_use
1396+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1397+
#[serde(tag = "type", rename_all = "snake_case")]
1398+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1399+
pub enum SamplingMessageContent {
1400+
/// Plain text content
1401+
Text(RawTextContent),
1402+
/// Image content with base64-encoded data
1403+
Image(RawImageContent),
1404+
/// Audio content with base64-encoded data
1405+
Audio(RawAudioContent),
1406+
/// A request to use a tool (assistant messages only, SEP-1577)
1407+
ToolUse(ToolUseContent),
1408+
/// The result of a tool call (user messages only, SEP-1577)
1409+
ToolResult(ToolResultContent),
1410+
}
1411+
1412+
impl SamplingMessageContent {
1413+
/// Create a text content
1414+
pub fn text(text: impl Into<String>) -> Self {
1415+
Self::Text(RawTextContent {
1416+
text: text.into(),
1417+
meta: None,
1418+
})
1419+
}
1420+
1421+
/// Create a tool use content
1422+
pub fn tool_use(id: impl Into<String>, name: impl Into<String>, input: JsonObject) -> Self {
1423+
Self::ToolUse(ToolUseContent::new(id, name, input))
1424+
}
1425+
1426+
/// Create a tool result content
1427+
pub fn tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
1428+
Self::ToolResult(ToolResultContent::new(tool_use_id, content))
1429+
}
1430+
}
1431+
1432+
impl SamplingMessage {
1433+
/// Create a new sampling message with a single content
1434+
pub fn new(role: Role, content: impl Into<SamplingMessageContent>) -> Self {
1435+
Self {
1436+
role,
1437+
content: SamplingContent::Single(content.into()),
1438+
}
1439+
}
1440+
1441+
/// Create a new sampling message with multiple contents
1442+
pub fn new_multiple(role: Role, contents: Vec<SamplingMessageContent>) -> Self {
1443+
Self {
1444+
role,
1445+
content: SamplingContent::Multiple(contents),
1446+
}
1447+
}
1448+
1449+
/// Create a user message with text content
1450+
pub fn user_text(text: impl Into<String>) -> Self {
1451+
Self::new(Role::User, SamplingMessageContent::text(text))
1452+
}
1453+
1454+
/// Create an assistant message with text content
1455+
pub fn assistant_text(text: impl Into<String>) -> Self {
1456+
Self::new(Role::Assistant, SamplingMessageContent::text(text))
1457+
}
1458+
1459+
/// Create a user message with tool result content
1460+
pub fn user_tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
1461+
Self::new(
1462+
Role::User,
1463+
SamplingMessageContent::tool_result(tool_use_id, content),
1464+
)
1465+
}
1466+
1467+
/// Create an assistant message with tool use content
1468+
pub fn assistant_tool_use(
1469+
id: impl Into<String>,
1470+
name: impl Into<String>,
1471+
input: JsonObject,
1472+
) -> Self {
1473+
Self::new(
1474+
Role::Assistant,
1475+
SamplingMessageContent::tool_use(id, name, input),
1476+
)
1477+
}
1478+
}
1479+
1480+
// Conversion from RawTextContent to SamplingMessageContent
1481+
impl From<RawTextContent> for SamplingMessageContent {
1482+
fn from(text: RawTextContent) -> Self {
1483+
SamplingMessageContent::Text(text)
1484+
}
1485+
}
1486+
1487+
// Conversion from String to SamplingMessageContent (as text)
1488+
impl From<String> for SamplingMessageContent {
1489+
fn from(text: String) -> Self {
1490+
SamplingMessageContent::text(text)
1491+
}
1492+
}
1493+
1494+
impl From<&str> for SamplingMessageContent {
1495+
fn from(text: &str) -> Self {
1496+
SamplingMessageContent::text(text)
1497+
}
12241498
}
12251499

12261500
/// Specifies how much context should be included in sampling requests.
@@ -1267,7 +1541,9 @@ pub struct CreateMessageRequestParams {
12671541
/// System prompt to guide the model's behavior
12681542
#[serde(skip_serializing_if = "Option::is_none")]
12691543
pub system_prompt: Option<String>,
1270-
/// How much context to include from MCP servers
1544+
/// How much context to include from MCP servers.
1545+
/// Note: Values other than "none" are soft-deprecated per SEP-1577
1546+
/// and require `clientCapabilities.sampling.context` support.
12711547
#[serde(skip_serializing_if = "Option::is_none")]
12721548
pub include_context: Option<ContextInclusion>,
12731549
/// Temperature for controlling randomness (0.0 to 1.0)
@@ -1281,6 +1557,14 @@ pub struct CreateMessageRequestParams {
12811557
/// Additional metadata for the request
12821558
#[serde(skip_serializing_if = "Option::is_none")]
12831559
pub metadata: Option<Value>,
1560+
/// Tools available for the model to call (SEP-1577).
1561+
/// Requires `clientCapabilities.sampling.tools` support.
1562+
#[serde(skip_serializing_if = "Option::is_none")]
1563+
pub tools: Option<Vec<Tool>>,
1564+
/// Configuration for tool selection behavior (SEP-1577).
1565+
/// Requires `clientCapabilities.sampling.tools` support.
1566+
#[serde(skip_serializing_if = "Option::is_none")]
1567+
pub tool_choice: Option<ToolChoice>,
12841568
}
12851569

12861570
impl RequestParamsMeta for CreateMessageRequestParams {
@@ -1930,13 +2214,17 @@ pub type CallToolRequest = Request<CallToolRequestMethod, CallToolRequestParams>
19302214
///
19312215
/// This structure contains the generated message along with metadata about
19322216
/// how the generation was performed and why it stopped.
2217+
///
2218+
/// Per SEP-1577, the content can be a single content block or an array of
2219+
/// content blocks when the model uses tools.
19332220
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
19342221
#[serde(rename_all = "camelCase")]
19352222
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
19362223
pub struct CreateMessageResult {
19372224
/// The identifier of the model that generated the response
19382225
pub model: String,
1939-
/// The reason why generation stopped (e.g., "endTurn", "maxTokens")
2226+
/// The reason why generation stopped.
2227+
/// Common values: "endTurn", "stopSequence", "maxTokens", "toolUse"
19402228
#[serde(skip_serializing_if = "Option::is_none")]
19412229
pub stop_reason: Option<String>,
19422230
/// The generated message with role and content
@@ -1945,9 +2233,14 @@ pub struct CreateMessageResult {
19452233
}
19462234

19472235
impl CreateMessageResult {
2236+
/// Stop reason: The model naturally ended its turn
19482237
pub const STOP_REASON_END_TURN: &str = "endTurn";
2238+
/// Stop reason: A stop sequence was encountered
19492239
pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence";
2240+
/// Stop reason: Maximum token limit reached
19502241
pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens";
2242+
/// Stop reason: The model wants to use a tool (SEP-1577)
2243+
pub const STOP_REASON_TOOL_USE: &str = "toolUse";
19512244
}
19522245

19532246
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -2476,7 +2769,10 @@ mod tests {
24762769
..
24772770
}) => {
24782771
assert_eq!(capabilities.roots.unwrap().list_changed, Some(true));
2479-
assert_eq!(capabilities.sampling.unwrap().len(), 0);
2772+
// Empty sampling capability (no tools or context sub-capabilities)
2773+
let sampling = capabilities.sampling.unwrap();
2774+
assert_eq!(sampling.tools, None);
2775+
assert_eq!(sampling.context, None);
24802776
assert_eq!(client_info.name, "ExampleClient");
24812777
assert_eq!(client_info.version, "1.0.0");
24822778
}

0 commit comments

Comments
 (0)