Skip to content

Commit

Permalink
Support configuration of disallowed content-types (#14726)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxsean authored Sep 27, 2024
1 parent 2bc04ba commit 957db22
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ public class RestConfig implements Serializable {
*/
private String jsonFramework;

/**
* The disallowed content-types.
*/
private String[] disallowedContentTypes;

/**
* The cors configuration.
*/
Expand Down Expand Up @@ -133,6 +138,14 @@ public void setJsonFramework(String jsonFramework) {
this.jsonFramework = jsonFramework;
}

public String[] getDisallowedContentTypes() {
return disallowedContentTypes;
}

public void setDisallowedContentTypes(String[] disallowedContentTypes) {
this.disallowedContentTypes = disallowedContentTypes;
}

public CorsConfig getCors() {
return cors;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,10 @@ private void handleHttp1(HttpServletRequest request, HttpServletResponse respons
channel, ServletExchanger.getUrl(), FrameworkModel.defaultModel());
channel.setGrpc(false);
context.setTimeout(resolveTimeout(request, false));
listener.onMetadata(new HttpMetadataAdapter(request));
ServletInputStream is = request.getInputStream();
response.getOutputStream().setWriteListener(new TripleWriteListener(channel));

listener.onMetadata(new HttpMetadataAdapter(request));
listener.onData(new Http1InputMessage(
is.available() == 0 ? StreamUtils.EMPTY : new ByteArrayInputStream(StreamUtils.readBytes(is))));
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,24 @@
package org.apache.dubbo.remoting.http12.message.codec;

import org.apache.dubbo.common.URL;
import org.apache.dubbo.common.config.Configuration;
import org.apache.dubbo.common.config.ConfigurationUtils;
import org.apache.dubbo.common.utils.Assert;
import org.apache.dubbo.common.utils.StringUtils;
import org.apache.dubbo.remoting.http12.exception.UnsupportedMediaTypeException;
import org.apache.dubbo.remoting.http12.message.HttpMessageDecoder;
import org.apache.dubbo.remoting.http12.message.HttpMessageDecoderFactory;
import org.apache.dubbo.remoting.http12.message.HttpMessageEncoder;
import org.apache.dubbo.remoting.http12.message.HttpMessageEncoderFactory;
import org.apache.dubbo.rpc.Constants;
import org.apache.dubbo.rpc.model.FrameworkModel;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

public final class CodecUtils {
Expand All @@ -37,13 +44,18 @@ public final class CodecUtils {
private final List<HttpMessageEncoderFactory> encoderFactories;
private final Map<String, Optional<HttpMessageEncoderFactory>> encoderCache = new ConcurrentHashMap<>();
private final Map<String, Optional<HttpMessageDecoderFactory>> decoderCache = new ConcurrentHashMap<>();
private Set<String> disallowedContentTypes = Collections.emptySet();

public CodecUtils(FrameworkModel frameworkModel) {
this.frameworkModel = frameworkModel;
decoderFactories = frameworkModel.getActivateExtensions(HttpMessageDecoderFactory.class);
encoderFactories = frameworkModel.getActivateExtensions(HttpMessageEncoderFactory.class);
decoderFactories.forEach(f -> decoderCache.putIfAbsent(f.mediaType().getName(), Optional.of(f)));
encoderFactories.forEach(f -> encoderCache.putIfAbsent(f.mediaType().getName(), Optional.of(f)));

Configuration configuration = ConfigurationUtils.getGlobalConfiguration(frameworkModel.defaultApplication());
String contentTypes = configuration.getString(Constants.H2_SETTINGS_DISALLOWED_CONTENT_TYPES, null);
if (contentTypes != null) {
disallowedContentTypes = new HashSet<>(StringUtils.tokenizeToList(contentTypes));
}
}

public HttpMessageDecoder determineHttpMessageDecoder(URL url, String mediaType) {
Expand All @@ -69,9 +81,10 @@ public HttpMessageEncoder determineHttpMessageEncoder(String mediaType) {
public Optional<HttpMessageDecoderFactory> determineHttpMessageDecoderFactory(String mediaType) {
Assert.notNull(mediaType, "mediaType must not be null");
return decoderCache.computeIfAbsent(mediaType, k -> {
for (HttpMessageDecoderFactory decoderFactory : decoderFactories) {
if (decoderFactory.supports(k)) {
return Optional.of(decoderFactory);
for (HttpMessageDecoderFactory factory : decoderFactories) {
if (factory.supports(k)
&& !disallowedContentTypes.contains(factory.mediaType().getName())) {
return Optional.of(factory);
}
}
return Optional.empty();
Expand All @@ -81,9 +94,10 @@ public Optional<HttpMessageDecoderFactory> determineHttpMessageDecoderFactory(St
public Optional<HttpMessageEncoderFactory> determineHttpMessageEncoderFactory(String mediaType) {
Assert.notNull(mediaType, "mediaType must not be null");
return encoderCache.computeIfAbsent(mediaType, k -> {
for (HttpMessageEncoderFactory encoderFactory : encoderFactories) {
if (encoderFactory.supports(k)) {
return Optional.of(encoderFactory);
for (HttpMessageEncoderFactory factory : encoderFactories) {
if (factory.supports(k)
&& !disallowedContentTypes.contains(factory.mediaType().getName())) {
return Optional.of(factory);
}
}
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ public interface Constants {
String H2_SETTINGS_BUILTIN_SERVICE_INIT = "dubbo.tri.builtin.service.init";

String H2_SETTINGS_JSON_FRAMEWORK_NAME = "dubbo.protocol.triple.rest.json-framework";
String H2_SETTINGS_DISALLOWED_CONTENT_TYPES = "dubbo.protocol.triple.rest.disallowed-content-types";

String H2_SETTINGS_VERBOSE_ENABLED = "dubbo.protocol.triple.verbose";
String H2_SETTINGS_SERVLET_ENABLED = "dubbo.protocol.triple.servlet.enabled";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import org.apache.dubbo.rpc.protocol.tri.ExceptionUtils;
import org.apache.dubbo.rpc.protocol.tri.TripleProtocol;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;

import io.netty.buffer.ByteBufOutputStream;

public final class Http1UnaryServerChannelObserver extends Http1ServerChannelObserver {
Expand Down Expand Up @@ -52,7 +55,17 @@ protected void doOnError(Throwable throwable) throws Throwable {
@Override
protected void customizeHeaders(HttpHeaders headers, Throwable throwable, HttpOutputMessage message) {
super.customizeHeaders(headers, throwable, message);
int contentLength = message == null ? 0 : ((ByteBufOutputStream) message.getBody()).writtenBytes();
int contentLength = 0;
if (message != null) {
OutputStream body = message.getBody();
if (body instanceof ByteBufOutputStream) {
contentLength = ((ByteBufOutputStream) body).writtenBytes();
} else if (body instanceof ByteArrayOutputStream) {
contentLength = ((ByteArrayOutputStream) body).size();
} else {
throw new IllegalArgumentException("Unsupported body type: " + body.getClass());
}
}
headers.set(HttpHeaderNames.CONTENT_LENGTH.getName(), String.valueOf(contentLength));
}

Expand Down

0 comments on commit 957db22

Please sign in to comment.