Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,18 @@
* @see ServerTransportSecurityValidator
* @see ServerTransportSecurityException
*/
public class DefaultServerTransportSecurityValidator implements ServerTransportSecurityValidator {
public final class DefaultServerTransportSecurityValidator implements ServerTransportSecurityValidator {

private static final String ORIGIN_HEADER = "Origin";

private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403,
"Invalid Origin header");

private final List<String> allowedOrigins;

/**
* Creates a new validator with the specified allowed origins.
* @param allowedOrigins List of allowed origin patterns. Supports exact matches
* (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*")
*/
public DefaultServerTransportSecurityValidator(List<String> allowedOrigins) {
private DefaultServerTransportSecurityValidator(List<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "allowedOrigins must not be null");
this.allowedOrigins = allowedOrigins;
}
Expand Down Expand Up @@ -79,7 +76,7 @@ else if (allowed.endsWith(":*")) {

}

throw INVALID_ORIGIN;
throw new ServerTransportSecurityException(403, "Invalid Origin header");
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2026-2026 the original author or authors.
*/

package io.modelcontextprotocol.server.transport;

import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import jakarta.servlet.http.HttpServletRequest;

/**
* Utility methods for working with {@link HttpServletRequest}. For internal use only.
*
* @author Daniel Garnier-Moiroux
*/
final class HttpServletRequestUtils {

private HttpServletRequestUtils() {
}

/**
* Extracts all headers from the HTTP request into a map.
* @param request The HTTP servlet request
* @return A map of header names to their values
*/
static Map<String, List<String>> extractHeaders(HttpServletRequest request) {
Map<String, List<String>> headers = new HashMap<>();
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
headers.put(name, Collections.list(request.getHeaders(name)));
}
return headers;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
Expand Down Expand Up @@ -258,7 +255,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = extractHeaders(request);
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
}
catch (ServerTransportSecurityException e) {
Expand Down Expand Up @@ -332,7 +329,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = extractHeaders(request);
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
}
catch (ServerTransportSecurityException e) {
Expand Down Expand Up @@ -440,21 +437,6 @@ private void sendEvent(PrintWriter writer, String eventType, String data) throws
}
}

/**
* Extracts all headers from the HTTP servlet request into a map.
* @param request The HTTP servlet request
* @return A map of header names to their values
*/
private Map<String, List<String>> extractHeaders(HttpServletRequest request) {
Map<String, List<String>> headers = new HashMap<>();
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
headers.put(name, Collections.list(request.getHeaders(name)));
}
return headers;
}

/**
* Cleans up resources when the servlet is being destroyed.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -137,7 +134,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = extractHeaders(request);
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
}
catch (ServerTransportSecurityException e) {
Expand Down Expand Up @@ -232,21 +229,6 @@ private void responseError(HttpServletResponse response, int httpCode, McpError
writer.flush();
}

/**
* Extracts all headers from the HTTP servlet request into a map.
* @param request The HTTP servlet request
* @return A map of header names to their values
*/
private Map<String, List<String>> extractHeaders(HttpServletRequest request) {
Map<String, List<String>> headers = new HashMap<>();
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
headers.put(name, Collections.list(request.getHeaders(name)));
}
return headers;
}

/**
* Cleans up resources when the servlet is being destroyed.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
import java.io.PrintWriter;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -262,7 +259,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = extractHeaders(request);
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
}
catch (ServerTransportSecurityException e) {
Expand Down Expand Up @@ -398,7 +395,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = extractHeaders(request);
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
}
catch (ServerTransportSecurityException e) {
Expand Down Expand Up @@ -570,7 +567,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
}

try {
Map<String, List<String>> headers = extractHeaders(request);
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
}
catch (ServerTransportSecurityException e) {
Expand Down Expand Up @@ -628,21 +625,6 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m
return;
}

/**
* Extracts all headers from the HTTP servlet request into a map.
* @param request The HTTP servlet request
* @return A map of header names to their values
*/
private Map<String, List<String>> extractHeaders(HttpServletRequest request) {
Map<String, List<String>> headers = new HashMap<>();
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
headers.put(name, Collections.list(request.getHeaders(name)));
}
return headers;
}

/**
* Sends an SSE event to a client with a specific ID.
* @param writer The writer to send the event through
Expand Down