Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.io.Serializable;

import com.alibaba.com.caucho.hessian.io.Deserializer;
import com.alibaba.com.caucho.hessian.io.HessianProtocolException;
import com.alibaba.com.caucho.hessian.io.InputStreamDeserializer;
import com.alibaba.com.caucho.hessian.io.JavaDeserializer;
import com.alibaba.com.caucho.hessian.io.JavaSerializer;
Expand All @@ -47,18 +48,21 @@ public Class<?> loadSerializedClass(String className) throws ClassNotFoundExcept
return defaultSerializeClassChecker.loadClass(getClassLoader(), className);
}

@Override
public Serializer getSerializer(Class cl) throws HessianProtocolException {
// SerializerFactory handles writeReplace before getDefaultSerializer(), so enforce Dubbo's
// Serializable requirement here without expanding strict allow-list checks to JDK replacements.
if (JavaSerializer.getWriteReplace(cl) != null) {
checkSerializable(cl);
}
return super.getSerializer(cl);
}

@Override
protected Serializer getDefaultSerializer(Class cl) {
if (_defaultSerializer != null) return _defaultSerializer;

try {
// pre-check if class is allow
defaultSerializeClassChecker.loadClass(getClassLoader(), cl.getName());
} catch (ClassNotFoundException e) {
// ignore
}

checkSerializable(cl);
checkClass(cl);

if (isEnableUnsafeSerializer() && JavaSerializer.getWriteReplace(cl) == null) {
return UnsafeSerializer.create(cl);
Expand All @@ -71,14 +75,7 @@ protected Deserializer getDefaultDeserializer(Class cl) {
return InputStreamDeserializer.DESER;
}

try {
// pre-check if class is allow
defaultSerializeClassChecker.loadClass(getClassLoader(), cl.getName());
} catch (ClassNotFoundException e) {
// ignore
}

checkSerializable(cl);
checkClass(cl);

if (RecordUtil.isRecord(cl)) {
return new RecordDeserializer(cl, getFieldDeserializerFactory());
Expand All @@ -89,6 +86,17 @@ protected Deserializer getDefaultDeserializer(Class cl) {
}
}

private void checkClass(Class<?> cl) {
try {
// pre-check if class is allow
defaultSerializeClassChecker.loadClass(getClassLoader(), cl.getName());
} catch (ClassNotFoundException e) {
// ignore
}

checkSerializable(cl);
}

private void checkSerializable(Class<?> cl) {
// If class is Serializable => ok
// If class has not implement Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedList;
Expand All @@ -37,6 +39,7 @@

import com.example.test.TestPojo;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;

import static org.apache.dubbo.common.constants.CommonConstants.DubboProperty.DUBBO_HESSIAN_ALLOW_NON_SERIALIZABLE;
Expand Down Expand Up @@ -654,4 +657,177 @@ void testLimit5() throws IOException, ClassNotFoundException {
frameworkModel.destroy();
}
}

@Test
void testWriteReplaceSelfStillRequiresSerializable() throws IOException {
FrameworkModel frameworkModel = new FrameworkModel();
try {
Serialization serialization =
frameworkModel.getExtensionLoader(Serialization.class).getExtension("hessian2");
URL url = URL.valueOf("").setScopeModel(frameworkModel);
ObjectOutput objectOutput = serialization.serialize(url, new ByteArrayOutputStream());

Assertions.assertThrows(
IOException.class, () -> objectOutput.writeObject(new NonSerializableWriteReplace("self")));
} finally {
frameworkModel.destroy();
}
}

@Test
void testWriteReplaceReplacementStillRequiresSerializable() throws IOException {
FrameworkModel frameworkModel = new FrameworkModel();
try {
Serialization serialization =
frameworkModel.getExtensionLoader(Serialization.class).getExtension("hessian2");
URL url = URL.valueOf("").setScopeModel(frameworkModel);
ObjectOutput objectOutput = serialization.serialize(url, new ByteArrayOutputStream());

Assertions.assertThrows(
IOException.class,
() -> objectOutput.writeObject(new SerializableWriteReplaceToNonSerializableReplacement("target")));
} finally {
frameworkModel.destroy();
}
}

@Test
void testWriteReplaceSerializableReplacementStillWorks() throws IOException, ClassNotFoundException {
FrameworkModel frameworkModel = new FrameworkModel();
try {
Serialization serialization =
frameworkModel.getExtensionLoader(Serialization.class).getExtension("hessian2");
URL url = URL.valueOf("").setScopeModel(frameworkModel);

ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
ObjectOutput objectOutput = serialization.serialize(url, outputStream);
objectOutput.writeObject(new SerializableWriteReplaceToSerializableReplacement("allowed"));
objectOutput.flushBuffer();

ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray());
ObjectInput objectInput = serialization.deserialize(url, inputStream);
Assertions.assertEquals(new SerializableReplacement("allowed"), objectInput.readObject());
} finally {
frameworkModel.destroy();
}
}

@Test
void testJdkImmutableListWithWriteReplaceCanSerializeInStrictMode() throws Exception {
List<String> immutableList = newJdkImmutableList("one", "two");
Assumptions.assumeTrue(immutableList != null, "JDK immutable collections are available since Java 9");

FrameworkModel frameworkModel = new FrameworkModel();
try {
Serialization serialization =
frameworkModel.getExtensionLoader(Serialization.class).getExtension("hessian2");
frameworkModel
.getBeanFactory()
.getBean(SerializeSecurityManager.class)
.setCheckStatus(SerializeCheckStatus.STRICT);
URL url = URL.valueOf("").setScopeModel(frameworkModel);

ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
ObjectOutput objectOutput = serialization.serialize(url, outputStream);
Assertions.assertDoesNotThrow(() -> {
objectOutput.writeObject(immutableList);
objectOutput.flushBuffer();
});
} finally {
frameworkModel.destroy();
}
}

@SuppressWarnings("unchecked")
private static List<String> newJdkImmutableList(String first, String second) throws Exception {
try {
Method of = List.class.getMethod("of", Object.class, Object.class);
return (List<String>) of.invoke(null, first, second);
} catch (NoSuchMethodException e) {
return null;
}
}

public static class NonSerializableWriteReplace {
private String value;

public NonSerializableWriteReplace() {}

public NonSerializableWriteReplace(String value) {
this.value = value;
}

private Object writeReplace() {
return this;
}
}

public static class SerializableWriteReplaceToNonSerializableReplacement implements Serializable {
private String value;

public SerializableWriteReplaceToNonSerializableReplacement() {}

public SerializableWriteReplaceToNonSerializableReplacement(String value) {
this.value = value;
}

private Object writeReplace() {
return new NonSerializableReplacement(value);
}
}

public static class NonSerializableReplacement {
private String value;

public NonSerializableReplacement() {}

public NonSerializableReplacement(String value) {
this.value = value;
}

private Object writeReplace() {
return this;
}
}

public static class SerializableWriteReplaceToSerializableReplacement implements Serializable {
private String value;

public SerializableWriteReplaceToSerializableReplacement() {}

public SerializableWriteReplaceToSerializableReplacement(String value) {
this.value = value;
}

private Object writeReplace() {
return new SerializableReplacement(value);
}
}

public static class SerializableReplacement implements Serializable {
private String value;

public SerializableReplacement() {}

public SerializableReplacement(String value) {
this.value = value;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SerializableReplacement)) {
return false;
}
SerializableReplacement that = (SerializableReplacement) o;
return value.equals(that.value);
}

@Override
public int hashCode() {
return value.hashCode();
}
}
}
Loading