Skip to content

Commit

Permalink
Added '--pmml-schema' (aka '--schema') command-line option
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Aug 5, 2024
1 parent d8ba364 commit dc6f3f5
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 2 deletions.
Binary file modified sklearn2pmml/resources/sklearn2pmml-1.0-SNAPSHOT.jar
Binary file not shown.
64 changes: 62 additions & 2 deletions src/main/java/com/sklearn2pmml/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,33 @@
package com.sklearn2pmml;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;

import javax.xml.transform.TransformerFactory;
import javax.xml.transform.sax.SAXTransformerFactory;
import javax.xml.transform.sax.TransformerHandler;
import javax.xml.transform.stream.StreamResult;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Version;
import org.jpmml.converter.Application;
import org.jpmml.model.SAXUtil;
import org.jpmml.model.filters.ExportFilter;
import org.jpmml.model.metro.MetroJAXBUtil;
import org.jpmml.model.visitors.VersionInspector;
import org.jpmml.python.PickleUtil;
import org.jpmml.python.Storage;
import org.jpmml.python.StorageUtil;
import org.jpmml.sklearn.Encodable;
import org.jpmml.sklearn.EncodableUtil;
import org.jpmml.sklearn.SkLearnException;
import org.jpmml.sklearn.SkLearnUtil;
import org.xml.sax.InputSource;

public class Main extends Application {

Expand All @@ -48,6 +61,12 @@ public class Main extends Application {
)
private File outputFile = null;

@Parameter (
names = {"--pmml-schema", "--schema"},
converter = VersionConverter.class
)
private Version version = null;


static
public void main(String... args) throws Exception {
Expand Down Expand Up @@ -79,6 +98,21 @@ private void run() throws Exception {

PMML pmml = encodable.encodePMML();

if(this.version != null && this.version.compareTo(Version.PMML_4_4) < 0){
VersionInspector versionInspector = new VersionInspector();
versionInspector.applyTo(pmml);

Version minVersion = versionInspector.getMinimum();
if(minVersion.compareTo(this.version) > 0){
throw new SkLearnException("The generated markup requires PMML schema version " + minVersion.getVersion() + " or newer");
}

Version maxVersion = versionInspector.getMaximum();
if(maxVersion.compareTo(this.version) < 0){
throw new SkLearnException("The generated markup requires PMML schema version " + maxVersion.getVersion() + " or older");
}
} // End if

if(!this.outputFile.exists()){
File absoluteOutputFile = this.outputFile.getAbsoluteFile();

Expand All @@ -88,8 +122,34 @@ private void run() throws Exception {
}
}

try(OutputStream os = new FileOutputStream(this.outputFile)){
MetroJAXBUtil.marshalPMML(pmml, os);
if(this.version != null && this.version.compareTo(Version.PMML_4_4) < 0){
File tempFile = File.createTempFile("sklearn2pmml-", ".pmml");

try(OutputStream os = new FileOutputStream(tempFile)){
MetroJAXBUtil.marshalPMML(pmml, os);
}

SAXTransformerFactory transformerFactory = (SAXTransformerFactory)TransformerFactory.newInstance();

try(OutputStream os = new FileOutputStream(this.outputFile)){
TransformerHandler transformerHandler = transformerFactory.newTransformerHandler();
transformerHandler.setResult(new StreamResult(os));

ExportFilter exportFilter = new ExportFilter(SAXUtil.createXMLReader(), this.version);
exportFilter.setContentHandler(transformerHandler);

try(InputStream is = new FileInputStream(tempFile)){
exportFilter.parse(new InputSource(is));
}
}

tempFile.delete();
} else

{
try(OutputStream os = new FileOutputStream(this.outputFile)){
MetroJAXBUtil.marshalPMML(pmml, os);
}
}
}

Expand Down
45 changes: 45 additions & 0 deletions src/main/java/com/sklearn2pmml/VersionConverter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2024 Villu Ruusmann
*
* This file is part of SkLearn2PMML
*
* SkLearn2PMML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* SkLearn2PMML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with SkLearn2PMML. If not, see <http://www.gnu.org/licenses/>.
*/
package com.sklearn2pmml;

import java.util.Objects;

import com.beust.jcommander.IStringConverter;
import org.dmg.pmml.Version;

public class VersionConverter implements IStringConverter<Version> {

@Override
public Version convert(String string){
Version[] versions = Version.values();

for(Version version : versions){

if(!version.isStandard()){
continue;
} // End if

if(Objects.equals(version.getNamespaceURI(), string) || Objects.equals(version.getVersion(), string)){
return version;
}
}

throw new IllegalArgumentException(string);
}
}

0 comments on commit dc6f3f5

Please sign in to comment.