diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java index 04a57b08812d..58b1f90aac47 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java @@ -16,6 +16,7 @@ package org.springframework.http.server; +import java.nio.charset.Charset; import java.util.AbstractSet; import java.util.ArrayList; import java.util.Collection; @@ -33,10 +34,13 @@ import org.jspecify.annotations.Nullable; import org.springframework.http.HttpHeaders; +import org.springframework.http.InvalidMediaTypeException; +import org.springframework.http.MediaType; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; /** * {@code MultiValueMap} implementation for wrapping Servlet request headers. @@ -48,6 +52,11 @@ final class ServletRequestHeadersAdapter implements MultiValueMap get(Object key) { if (key instanceof String headerName) { + if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(headerName)) { + String contentType = getContentType(); + return (contentType != null ? Collections.singletonList(contentType) : null); + } Enumeration values = this.request.getHeaders(headerName); if (values.hasMoreElements()) { String value = values.nextElement(); @@ -178,6 +194,44 @@ public Set>> entrySet() { throw httpHeadersMapException(); } + /** + * Return the Content-Type header value, appending the charset from + * {@link HttpServletRequest#getCharacterEncoding()} if the Content-Type does not + * already include a {@code charset} parameter and the media type is not + * {@code application/json}. + *

The computed value is cached to avoid repeated string building. + */ + private @Nullable String getContentType() { + String stringContentType = this.cachedContentType; + if (stringContentType != null) { + return stringContentType; + } + + stringContentType = this.request.getContentType(); + try { + MediaType contentType = stringContentType != null ? MediaType.parseMediaType(stringContentType) : null; + if (contentType != null && contentType.getCharset() == null) { + String requestEncoding = this.request.getCharacterEncoding(); + if (StringUtils.hasLength(requestEncoding)) { + Charset charset = Charset.forName(requestEncoding); + Map params = new LinkedCaseInsensitiveMap<>(); + params.putAll(contentType.getParameters()); + if (!MediaType.APPLICATION_JSON.equals(contentType)) { + params.put("charset", charset.toString()); + } + MediaType mediaType = new MediaType(contentType.getType(), contentType.getSubtype(), params); + stringContentType = mediaType.toString(); + } + } + } + catch (InvalidMediaTypeException ex) { + // Ignore: simply not exposing an invalid content type in HttpHeaders... + } + + this.cachedContentType = stringContentType; + return stringContentType; + } + private static UnsupportedOperationException immutableRequestException() { return new UnsupportedOperationException("Request headers are immutable"); } diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java index d2a2d3c32bb9..7534cc0c6db6 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java @@ -45,10 +45,8 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; -import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; import org.springframework.util.Assert; -import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.StringUtils; /** @@ -160,33 +158,6 @@ public HttpHeaders getHeaders() { // HttpServletRequest exposes some headers as properties: // we should include those if not already present - try { - MediaType contentType = this.headers.getContentType(); - if (contentType == null) { - String requestContentType = this.servletRequest.getContentType(); - if (StringUtils.hasLength(requestContentType)) { - contentType = MediaType.parseMediaType(requestContentType); - if (contentType.isConcrete()) { - this.headers.setContentType(contentType); - } - } - } - if (contentType != null && contentType.getCharset() == null) { - String requestEncoding = this.servletRequest.getCharacterEncoding(); - if (StringUtils.hasLength(requestEncoding)) { - Charset charset = Charset.forName(requestEncoding); - Map params = new LinkedCaseInsensitiveMap<>(); - params.putAll(contentType.getParameters()); - params.put("charset", charset.toString()); - MediaType mediaType = new MediaType(contentType.getType(), contentType.getSubtype(), params); - this.headers.setContentType(mediaType); - } - } - } - catch (InvalidMediaTypeException ex) { - // Ignore: simply not exposing an invalid content type in HttpHeaders... - } - if (this.headers.getContentLength() < 0) { int requestContentLength = this.servletRequest.getContentLength(); if (requestContentLength != -1) { diff --git a/spring-web/src/test/java/org/springframework/http/server/ServletRequestHeadersAdapterTests.java b/spring-web/src/test/java/org/springframework/http/server/ServletRequestHeadersAdapterTests.java index 560d5410efe9..03ef8ecb5776 100644 --- a/spring-web/src/test/java/org/springframework/http/server/ServletRequestHeadersAdapterTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/ServletRequestHeadersAdapterTests.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; import org.springframework.util.MultiValueMap; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -44,4 +45,38 @@ void caseSensitiveOverride() { assertThat(headersAdapter.get("foo")).containsExactly("override value"); } + @Test // gh-36426 + void contentTypeCharsetAppendedForTextType() { + request.setContentType("text/plain"); + request.setCharacterEncoding("UTF-8"); + + assertThat(headersAdapter.getFirst(HttpHeaders.CONTENT_TYPE)).isEqualTo("text/plain;charset=UTF-8"); + assertThat(headersAdapter.get(HttpHeaders.CONTENT_TYPE)).containsExactly("text/plain;charset=UTF-8"); + } + + @Test // gh-36426 + void contentTypeCharsetNotAppendedForApplicationJson() { + request.setContentType("application/json"); + request.setCharacterEncoding("UTF-8"); + + assertThat(headersAdapter.getFirst(HttpHeaders.CONTENT_TYPE)).isEqualTo("application/json"); + assertThat(headersAdapter.get(HttpHeaders.CONTENT_TYPE)).containsExactly("application/json"); + } + + @Test // gh-36426 + void contentTypeCharsetNotAppendedWhenAlreadyPresent() { + request.setContentType("text/plain;charset=ISO-8859-1"); + request.setCharacterEncoding("UTF-8"); + + assertThat(headersAdapter.getFirst(HttpHeaders.CONTENT_TYPE)).isEqualTo("text/plain;charset=ISO-8859-1"); + assertThat(headersAdapter.get(HttpHeaders.CONTENT_TYPE)).containsExactly("text/plain;charset=ISO-8859-1"); + } + + @Test // gh-36426 + void contentTypeCharsetNotAppendedWhenNoEncoding() { + request.setContentType("text/plain"); + + assertThat(headersAdapter.getFirst(HttpHeaders.CONTENT_TYPE)).isEqualTo("text/plain"); + assertThat(headersAdapter.get(HttpHeaders.CONTENT_TYPE)).containsExactly("text/plain"); + } }