diff --git a/alibaba-rsocket-core/src/main/java/com/alibaba/rsocket/encoding/impl/ObjectEncodingHandlerHessianImpl.java b/alibaba-rsocket-core/src/main/java/com/alibaba/rsocket/encoding/impl/ObjectEncodingHandlerHessianImpl.java index 01ad44f5..159d78df 100644 --- a/alibaba-rsocket-core/src/main/java/com/alibaba/rsocket/encoding/impl/ObjectEncodingHandlerHessianImpl.java +++ b/alibaba-rsocket-core/src/main/java/com/alibaba/rsocket/encoding/impl/ObjectEncodingHandlerHessianImpl.java @@ -4,8 +4,7 @@ import com.alibaba.rsocket.encoding.ObjectEncodingHandler; import com.alibaba.rsocket.metadata.RSocketMimeType; import com.alibaba.rsocket.observability.RsocketErrorCode; -import com.caucho.hessian.io.HessianSerializerInput; -import com.caucho.hessian.io.HessianSerializerOutput; +import com.caucho.hessian.io.*; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.ByteBufOutputStream; @@ -15,7 +14,9 @@ import org.jetbrains.annotations.Nullable; import java.io.IOException; -import java.util.*; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; /** * object encoding handler hessian implementation @@ -23,15 +24,12 @@ * @author leijuan */ public class ObjectEncodingHandlerHessianImpl implements ObjectEncodingHandler { - public static final List BLACK_CLASS_PATTERNS = new ArrayList<>(); public static final Set BLACK_CLASSES = new HashSet<>(); + public static final SerializerFactory serializerFactory = new SerializerFactoryWithBlackList(); static { - BLACK_CLASS_PATTERNS.add("javax.swing."); - BLACK_CLASS_PATTERNS.add("java.awt."); - BLACK_CLASS_PATTERNS.add("javax.naming."); - BLACK_CLASS_PATTERNS.add("java.lang.System"); - BLACK_CLASS_PATTERNS.add("java.lang.Process"); + BLACK_CLASSES.add("org.springframework.context.support.ClassPathXmlApplicationContext"); + BLACK_CLASSES.add("javax.swing.UIDefaults$ProxyLazyValue"); } @NotNull @@ -50,7 +48,6 @@ public ByteBuf encodingParams(@Nullable Object[] args) throws EncodingException @Override public Object decodeParams(ByteBuf data, @Nullable Class... targetClasses) throws EncodingException { - checkDecodingClass(targetClasses[0]); if (data.readableBytes() > 0) { try { return decode(data); @@ -72,7 +69,6 @@ public ByteBuf encodingResult(@Nullable Object result) throws EncodingException @Override public Object decodeResult(ByteBuf data, @Nullable Class targetClass) throws EncodingException { - checkDecodingClass(targetClass); if (data.readableBytes() > 0) { try { return decode(data); @@ -101,21 +97,19 @@ public static Object decode(@Nullable ByteBuf byteBuf) throws IOException { if (byteBuf == null || byteBuf.readableBytes() == 0) { return null; } - return new HessianSerializerInput(new ByteBufInputStream(byteBuf)).readObject(); + final HessianSerializerInput hessianSerializerInput = new HessianSerializerInput(new ByteBufInputStream(byteBuf)); + hessianSerializerInput.setSerializerFactory(serializerFactory); + return hessianSerializerInput.readObject(); + } - protected void checkDecodingClass(Class targetClass) throws EncodingException { - if (targetClass != null) { - String classFullName = targetClass.getCanonicalName(); - if (BLACK_CLASSES.contains(classFullName)) { - throw new EncodingException(RsocketErrorCode.message("RST-700401", targetClass)); - } - for (String pattern : BLACK_CLASS_PATTERNS) { - if (classFullName.startsWith(pattern)) { - BLACK_CLASSES.add(classFullName); - throw new EncodingException(RsocketErrorCode.message("RST-700401", targetClass)); - } + public static class SerializerFactoryWithBlackList extends SerializerFactory { + @Override + public Deserializer getObjectDeserializer(String type, Class cl) throws HessianProtocolException { + if (BLACK_CLASSES.contains(type)) { + throw new HessianProtocolException(RsocketErrorCode.message("RST-700401", type)); } + return super.getObjectDeserializer(type, cl); } } }