From c4f296c937ba0549dbb10d4f964e29ffec3c384c Mon Sep 17 00:00:00 2001 From: deepnetts Date: Thu, 2 Dec 2021 13:25:29 +0100 Subject: [PATCH] Added GPL header to all files --- .../javax/visrec/ri/BufferedImageFactory.java | 123 +++++---- .../AbstractImageClassifier.java | 19 ++ .../FeedForwardNetBinaryClassifier.java | 19 ++ .../ImageClassifierNetwork.java | 123 +++++---- .../MultiClassClassifierNetwork.java | 244 ++++++++++-------- .../ml/classification/ZeroRuleClassifier.java | 19 ++ .../ml/detection/AbstractObjectDetector.java | 179 +++++++------ .../ri/ml/detection/SimpleObjectDetector.java | 177 +++++++------ .../regression/LogisticRegressionNetwork.java | 202 ++++++++------- .../SimpleLinearRegressionNetwork.java | 231 +++++++++-------- .../spi/BufferedImageClassifierFactory.java | 211 ++++++++------- .../spi/DeepNettsImplementationService.java | 61 +++-- .../ri/spi/DefaultImageFactoryService.java | 97 ++++--- .../visrec/ri/spi/DefaultServiceProvider.java | 77 +++--- .../FloatArrayBinaryClassifierFactory.java | 129 +++++---- .../java/javax/visrec/ri/util/DataSets.java | 197 +++++++------- .../spi/DefaultImageFactoryServiceTest.java | 79 +++--- .../ri/util/BuilderConfigurationTest.java | 158 +++++++----- 18 files changed, 1342 insertions(+), 1003 deletions(-) diff --git a/src/main/java/javax/visrec/ri/BufferedImageFactory.java b/src/main/java/javax/visrec/ri/BufferedImageFactory.java index ff31e4b..a2921ff 100644 --- a/src/main/java/javax/visrec/ri/BufferedImageFactory.java +++ b/src/main/java/javax/visrec/ri/BufferedImageFactory.java @@ -1,52 +1,71 @@ -package javax.visrec.ri; - -import javax.imageio.ImageIO; -import javax.visrec.ImageFactory; -import java.awt.image.BufferedImage; -import java.io.IOException; -import java.io.InputStream; -import java.net.URL; -import java.nio.file.Path; - -/** - * {@link ImageFactory} to provide {@link BufferedImage} as return object. - * - */ -public class BufferedImageFactory implements ImageFactory { - - /** - * {@inheritDoc} - */ - @Override - public BufferedImage getImage(Path path) throws IOException { - BufferedImage img = ImageIO.read(path.toFile()); - if (img == null) { - throw new IOException("Failed to transform Path into BufferedImage due to unknown image encoding"); - } - return img; - } - - /** - * {@inheritDoc} - */ - @Override - public BufferedImage getImage(URL file) throws IOException { - BufferedImage img = ImageIO.read(file); - if (img == null) { - throw new IOException("Failed to transform URL into BufferedImage due to unknown image encoding"); - } - return img; - } - - /** - * {@inheritDoc} - */ - @Override - public BufferedImage getImage(InputStream file) throws IOException { - BufferedImage img = ImageIO.read(file); - if (img == null) { - throw new IOException("Failed to transform InputStream into BufferedImage due to unknown image encoding"); - } - return img; - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri; + +import javax.imageio.ImageIO; +import javax.visrec.ImageFactory; +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Path; + +/** + * {@link ImageFactory} to provide {@link BufferedImage} as return object. + * + */ +public class BufferedImageFactory implements ImageFactory { + + /** + * {@inheritDoc} + */ + @Override + public BufferedImage getImage(Path path) throws IOException { + BufferedImage img = ImageIO.read(path.toFile()); + if (img == null) { + throw new IOException("Failed to transform Path into BufferedImage due to unknown image encoding"); + } + return img; + } + + /** + * {@inheritDoc} + */ + @Override + public BufferedImage getImage(URL file) throws IOException { + BufferedImage img = ImageIO.read(file); + if (img == null) { + throw new IOException("Failed to transform URL into BufferedImage due to unknown image encoding"); + } + return img; + } + + /** + * {@inheritDoc} + */ + @Override + public BufferedImage getImage(InputStream file) throws IOException { + BufferedImage img = ImageIO.read(file); + if (img == null) { + throw new IOException("Failed to transform InputStream into BufferedImage due to unknown image encoding"); + } + return img; + } +} diff --git a/src/main/java/javax/visrec/ri/ml/classification/AbstractImageClassifier.java b/src/main/java/javax/visrec/ri/ml/classification/AbstractImageClassifier.java index 47fd76a..197737c 100644 --- a/src/main/java/javax/visrec/ri/ml/classification/AbstractImageClassifier.java +++ b/src/main/java/javax/visrec/ri/ml/classification/AbstractImageClassifier.java @@ -1,3 +1,22 @@ +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + package javax.visrec.ri.ml.classification; import javax.visrec.ImageFactory; diff --git a/src/main/java/javax/visrec/ri/ml/classification/FeedForwardNetBinaryClassifier.java b/src/main/java/javax/visrec/ri/ml/classification/FeedForwardNetBinaryClassifier.java index 5aaec0a..30e8aeb 100644 --- a/src/main/java/javax/visrec/ri/ml/classification/FeedForwardNetBinaryClassifier.java +++ b/src/main/java/javax/visrec/ri/ml/classification/FeedForwardNetBinaryClassifier.java @@ -1,3 +1,22 @@ +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + package javax.visrec.ri.ml.classification; import deepnetts.net.FeedForwardNetwork; diff --git a/src/main/java/javax/visrec/ri/ml/classification/ImageClassifierNetwork.java b/src/main/java/javax/visrec/ri/ml/classification/ImageClassifierNetwork.java index 77ca629..86afdaa 100644 --- a/src/main/java/javax/visrec/ri/ml/classification/ImageClassifierNetwork.java +++ b/src/main/java/javax/visrec/ri/ml/classification/ImageClassifierNetwork.java @@ -1,52 +1,71 @@ -package javax.visrec.ri.ml.classification; - -import deepnetts.data.ExampleImage; -import deepnetts.net.ConvolutionalNetwork; - -import java.awt.image.BufferedImage; -import java.util.HashMap; -import java.util.Map; - -/** - * Implementation of abstract image classifier for BufferedImage-s using - * Convolutional network form Deep Netts. - */ -public class ImageClassifierNetwork extends AbstractImageClassifier { - - // it seems that these are not used at the end, onlz in builder. Do we need them exposed here__ - private int inputWidth, inputHeight; - - public ImageClassifierNetwork(ConvolutionalNetwork network) { - super(BufferedImage.class, network); - } - - @Override - public Map classify(BufferedImage inputImage) { - // create input for neural network from image - ExampleImage exImage = new ExampleImage(inputImage); - - // get underlying ML model, in this case convolutional network - ConvolutionalNetwork neuralNet = getModel(); - // set neural network input and get outputs - neuralNet.setInput(exImage.getInput()); - float[] outputs = neuralNet.getOutput(); - - // get all class labels with corresponding output larger then classification threshold - Map results = new HashMap<>(); - for (int i = 0; i < outputs.length; i++) { - if (outputs[i] >= getThreshold()) - results.put(neuralNet.getOutputLabel(i), outputs[i]); - } - - return results; - } - - public int getInputWidth() { - return inputWidth; - } - - public int getInputHeight() { - return inputHeight; - } - -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.ml.classification; + +import deepnetts.data.ExampleImage; +import deepnetts.net.ConvolutionalNetwork; + +import java.awt.image.BufferedImage; +import java.util.HashMap; +import java.util.Map; + +/** + * Implementation of abstract image classifier for BufferedImage-s using + * Convolutional network form Deep Netts. + */ +public class ImageClassifierNetwork extends AbstractImageClassifier { + + // it seems that these are not used at the end, onlz in builder. Do we need them exposed here__ + private int inputWidth, inputHeight; + + public ImageClassifierNetwork(ConvolutionalNetwork network) { + super(BufferedImage.class, network); + } + + @Override + public Map classify(BufferedImage inputImage) { + // create input for neural network from image + ExampleImage exImage = new ExampleImage(inputImage); + + // get underlying ML model, in this case convolutional network + ConvolutionalNetwork neuralNet = getModel(); + // set neural network input and get outputs + neuralNet.setInput(exImage.getInput()); + float[] outputs = neuralNet.getOutput(); + + // get all class labels with corresponding output larger then classification threshold + Map results = new HashMap<>(); + for (int i = 0; i < outputs.length; i++) { + if (outputs[i] >= getThreshold()) + results.put(neuralNet.getOutputLabel(i), outputs[i]); + } + + return results; + } + + public int getInputWidth() { + return inputWidth; + } + + public int getInputHeight() { + return inputHeight; + } + +} diff --git a/src/main/java/javax/visrec/ri/ml/classification/MultiClassClassifierNetwork.java b/src/main/java/javax/visrec/ri/ml/classification/MultiClassClassifierNetwork.java index 3bc1d9d..7f3cab0 100644 --- a/src/main/java/javax/visrec/ri/ml/classification/MultiClassClassifierNetwork.java +++ b/src/main/java/javax/visrec/ri/ml/classification/MultiClassClassifierNetwork.java @@ -1,113 +1,131 @@ -package javax.visrec.ri.ml.classification; - -import deepnetts.data.MLDataItem; -import deepnetts.net.FeedForwardNetwork; -import deepnetts.net.layers.activation.ActivationType; -import deepnetts.net.loss.LossType; -import deepnetts.net.train.BackpropagationTrainer; -import deepnetts.util.Tensor; - -import javax.visrec.ml.classification.AbstractMultiClassClassifier; -import java.util.HashMap; -import java.util.Map; -import javax.visrec.ml.model.ModelCreationException; -import javax.visrec.ml.data.DataSet; - -public class MultiClassClassifierNetwork extends AbstractMultiClassClassifier { - - @Override - public Map classify(float[] input) { - FeedForwardNetwork model = getModel(); - model.setInput(Tensor.create(1, input.length, input)); //TODO: put array to input tensor placeholder - float[] outputs = model.getOutput(); - String[] labels = model.getOutputLabels(); - Map result = new HashMap<>(); - for(int i=0; i { - private MultiClassClassifierNetwork building = new MultiClassClassifierNetwork(); - - private float learningRate = 0.01f; - private float maxError = 0.03f; - private long maxEpochs = Long.MAX_VALUE; - private int inputsNum; - private int outputsNum; - private int[] hiddenLayers; - - private DataSet trainingSet; - - @Override - public MultiClassClassifierNetwork build() { - // Network architecture as Map/properties, json? - FeedForwardNetwork.Builder builder = FeedForwardNetwork.builder() - .addInputLayer(inputsNum); - for(int h : hiddenLayers) { - builder.addFullyConnectedLayer(h, ActivationType.TANH); - } - - builder.addOutputLayer(outputsNum, ActivationType.SOFTMAX) - .lossFunction(LossType.CROSS_ENTROPY) - .hiddenActivationFunction(ActivationType.TANH); - - FeedForwardNetwork model = builder.build(); - - // aslo can be replaced with model.getTrainer() - BackpropagationTrainer trainer = new BackpropagationTrainer(model); // model as param in constructor - trainer.setLearningRate(learningRate) - .setMaxError(maxError) - .setMaxEpochs(maxEpochs); - - if (trainingSet!=null) - trainer.train(trainingSet); // move model to constructor - - building.setModel(model); - - return building; - } - - public Builder learningRate(float learningRate) { - this.learningRate = learningRate; - return this; - } - - public Builder maxError(float maxError) { - this.maxError = maxError; - return this; - } - - public Builder maxEpochs(int maxEpochs) { - this.maxEpochs = maxEpochs; - return this; - } - - public Builder inputsNum(int inputsNum) { - this.inputsNum = inputsNum; - return this; - } - - public Builder outputsNum(int outputsNum) { - this.outputsNum = outputsNum; - return this; - } - - public Builder hiddenLayers(int... hiddenLayers) { - this.hiddenLayers = hiddenLayers; - return this; - } - - public Builder trainingSet(DataSet trainingSet) { - this.trainingSet = trainingSet; - return this; - } - - - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.ml.classification; + +import deepnetts.data.MLDataItem; +import deepnetts.net.FeedForwardNetwork; +import deepnetts.net.layers.activation.ActivationType; +import deepnetts.net.loss.LossType; +import deepnetts.net.train.BackpropagationTrainer; +import deepnetts.util.Tensor; + +import javax.visrec.ml.classification.AbstractMultiClassClassifier; +import java.util.HashMap; +import java.util.Map; +import javax.visrec.ml.data.DataSet; + +public class MultiClassClassifierNetwork extends AbstractMultiClassClassifier { + + @Override + public Map classify(float[] input) { + FeedForwardNetwork model = getModel(); + model.setInput(Tensor.create(1, input.length, input)); //TODO: put array to input tensor placeholder + float[] outputs = model.getOutput(); + String[] labels = model.getOutputLabels(); + Map result = new HashMap<>(); + for(int i=0; i { + private MultiClassClassifierNetwork building = new MultiClassClassifierNetwork(); + + private float learningRate = 0.01f; + private float maxError = 0.03f; + private long maxEpochs = Long.MAX_VALUE; + private int inputsNum; + private int outputsNum; + private int[] hiddenLayers; + + private DataSet trainingSet; + + @Override + public MultiClassClassifierNetwork build() { + // Network architecture as Map/properties, json? + FeedForwardNetwork.Builder builder = FeedForwardNetwork.builder() + .addInputLayer(inputsNum); + for(int h : hiddenLayers) { + builder.addFullyConnectedLayer(h, ActivationType.TANH); + } + + builder.addOutputLayer(outputsNum, ActivationType.SOFTMAX) + .lossFunction(LossType.CROSS_ENTROPY) + .hiddenActivationFunction(ActivationType.TANH); + + FeedForwardNetwork model = builder.build(); + + // aslo can be replaced with model.getTrainer() + BackpropagationTrainer trainer = new BackpropagationTrainer(model); // model as param in constructor + trainer.setLearningRate(learningRate) + .setMaxError(maxError) + .setMaxEpochs(maxEpochs); + + if (trainingSet!=null) + trainer.train(trainingSet); // move model to constructor + + building.setModel(model); + + return building; + } + + public Builder learningRate(float learningRate) { + this.learningRate = learningRate; + return this; + } + + public Builder maxError(float maxError) { + this.maxError = maxError; + return this; + } + + public Builder maxEpochs(int maxEpochs) { + this.maxEpochs = maxEpochs; + return this; + } + + public Builder inputsNum(int inputsNum) { + this.inputsNum = inputsNum; + return this; + } + + public Builder outputsNum(int outputsNum) { + this.outputsNum = outputsNum; + return this; + } + + public Builder hiddenLayers(int... hiddenLayers) { + this.hiddenLayers = hiddenLayers; + return this; + } + + public Builder trainingSet(DataSet trainingSet) { + this.trainingSet = trainingSet; + return this; + } + + + } +} diff --git a/src/main/java/javax/visrec/ri/ml/classification/ZeroRuleClassifier.java b/src/main/java/javax/visrec/ri/ml/classification/ZeroRuleClassifier.java index ddbe930..33353a9 100644 --- a/src/main/java/javax/visrec/ri/ml/classification/ZeroRuleClassifier.java +++ b/src/main/java/javax/visrec/ri/ml/classification/ZeroRuleClassifier.java @@ -1,3 +1,22 @@ +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + package javax.visrec.ri.ml.classification; import deepnetts.data.MLDataItem; diff --git a/src/main/java/javax/visrec/ri/ml/detection/AbstractObjectDetector.java b/src/main/java/javax/visrec/ri/ml/detection/AbstractObjectDetector.java index 75a9d5b..9886773 100644 --- a/src/main/java/javax/visrec/ri/ml/detection/AbstractObjectDetector.java +++ b/src/main/java/javax/visrec/ri/ml/detection/AbstractObjectDetector.java @@ -1,80 +1,99 @@ -package javax.visrec.ri.ml.detection; - -import javax.visrec.ri.ml.classification.AbstractImageClassifier; -import javax.visrec.ml.classification.ClassificationException; -import javax.visrec.ml.detection.BoundingBox; -import javax.visrec.ml.detection.ObjectDetector; -import java.awt.image.BufferedImage; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Path; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -/** - * Abstract object detector which implements {@link ObjectDetector} to return the positions - * of an object within the given image. - */ -public abstract class AbstractObjectDetector implements ObjectDetector { - - private AbstractImageClassifier imageClassifier; - - /** - * Creates an instance of the object detector - * - * @param imageClassifier A {@link AbstractImageClassifier} which may not be null - */ - public AbstractObjectDetector(AbstractImageClassifier imageClassifier) { - Objects.requireNonNull(imageClassifier, "A classifier is required for the object detector."); - this.imageClassifier = imageClassifier; - } - - /** - * Scan entire image and return positions where object is detected - * - * @param image {@code IMAGE_CLASS} image - * @return {@code Map} of {@link BoundingBox} of where the object - * has been detected. - */ - @Override - public abstract Map> detectObject(BufferedImage image) throws ClassificationException; - - /** - * Detect the object based on the given {@code File}. - * - * @param path Image file. - * @return {@code Map} of {@link BoundingBox} of where the object - * has been detected. - * @throws IOException if the image couldn't be retrieved from storage. - * @throws ClassificationException when the detector was unable to classify and detect the input - */ - public Map> detect(Path path) throws IOException, ClassificationException { - BufferedImage image = imageClassifier.getImageFactory().getImage(path); - return detectObject(image); - } - - /** - * Detect the object based on the given {@code InputStream}. - * - * @param inStream {@code InputStream} of the image - * @return {@code Map} of {@link BoundingBox} of where the object - * has been detected. - * @throws IOException if the image couldn't be retrieved from storage. - * @throws ClassificationException when the detector was unable to classify and detect the input - */ - public Map> detect(InputStream inStream) throws IOException, ClassificationException { - BufferedImage image = imageClassifier.getImageFactory().getImage(inStream); - return detectObject(image); - } - - /** - * Subclasses should use this method to use the underlying image classifier - * - * @return configured {@link AbstractImageClassifier} of the {@link AbstractObjectDetector} - */ - public AbstractImageClassifier getImageClassifier() { - return imageClassifier; - } - -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.ml.detection; + +import javax.visrec.ri.ml.classification.AbstractImageClassifier; +import javax.visrec.ml.classification.ClassificationException; +import javax.visrec.ml.detection.BoundingBox; +import javax.visrec.ml.detection.ObjectDetector; +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Abstract object detector which implements {@link ObjectDetector} to return the positions + * of an object within the given image. + */ +public abstract class AbstractObjectDetector implements ObjectDetector { + + private AbstractImageClassifier imageClassifier; + + /** + * Creates an instance of the object detector + * + * @param imageClassifier A {@link AbstractImageClassifier} which may not be null + */ + public AbstractObjectDetector(AbstractImageClassifier imageClassifier) { + Objects.requireNonNull(imageClassifier, "A classifier is required for the object detector."); + this.imageClassifier = imageClassifier; + } + + /** + * Scan entire image and return positions where object is detected + * + * @param image {@code IMAGE_CLASS} image + * @return {@code Map} of {@link BoundingBox} of where the object + * has been detected. + */ + @Override + public abstract Map> detectObject(BufferedImage image) throws ClassificationException; + + /** + * Detect the object based on the given {@code File}. + * + * @param path Image file. + * @return {@code Map} of {@link BoundingBox} of where the object + * has been detected. + * @throws IOException if the image couldn't be retrieved from storage. + * @throws ClassificationException when the detector was unable to classify and detect the input + */ + public Map> detect(Path path) throws IOException, ClassificationException { + BufferedImage image = imageClassifier.getImageFactory().getImage(path); + return detectObject(image); + } + + /** + * Detect the object based on the given {@code InputStream}. + * + * @param inStream {@code InputStream} of the image + * @return {@code Map} of {@link BoundingBox} of where the object + * has been detected. + * @throws IOException if the image couldn't be retrieved from storage. + * @throws ClassificationException when the detector was unable to classify and detect the input + */ + public Map> detect(InputStream inStream) throws IOException, ClassificationException { + BufferedImage image = imageClassifier.getImageFactory().getImage(inStream); + return detectObject(image); + } + + /** + * Subclasses should use this method to use the underlying image classifier + * + * @return configured {@link AbstractImageClassifier} of the {@link AbstractObjectDetector} + */ + public AbstractImageClassifier getImageClassifier() { + return imageClassifier; + } + +} diff --git a/src/main/java/javax/visrec/ri/ml/detection/SimpleObjectDetector.java b/src/main/java/javax/visrec/ri/ml/detection/SimpleObjectDetector.java index 64d8754..e5fee7f 100644 --- a/src/main/java/javax/visrec/ri/ml/detection/SimpleObjectDetector.java +++ b/src/main/java/javax/visrec/ri/ml/detection/SimpleObjectDetector.java @@ -1,79 +1,98 @@ -package javax.visrec.ri.ml.detection; - -import javax.visrec.ri.ml.classification.AbstractImageClassifier; -import javax.visrec.ml.classification.ClassificationException; -import javax.visrec.ml.detection.BoundingBox; -import java.awt.image.BufferedImage; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * A simple object detector - * - */ -public class SimpleObjectDetector extends AbstractObjectDetector { - - private double threshold = 0.5; - - /** - * Creates an instance - * - * @param classifier A {@link AbstractImageClassifier} which may not be null - */ - public SimpleObjectDetector(AbstractImageClassifier classifier) { - super(classifier); - } - - /** - * Scan image using brute force sliding window and return positions where - * classifier returns score greater then threshold. - *

- * get width and height of the image and scan image with classifier - apply - * classifier to each position This is trivial implementation and should be - * replaced with something better - * - * @param image {@code BufferedImage} to scan - * @return A {@code Map} of {@link BoundingBox} which contain - * the positions of the detected object. - */ - @Override - public Map> detectObject(BufferedImage image) throws ClassificationException { - Map> results = new HashMap<>(); - - int boxWidth = 64, boxHeight = 64; - - for (int y = 0; y < image.getHeight() - boxHeight; y++) { - for (int x = 0; x < image.getWidth() - boxWidth; x++) { - - Map results2 = getImageClassifier().classify(image.getSubimage(x, y, boxWidth, boxHeight)); - for (Map.Entry keyValPair : results2.entrySet()) { - if (keyValPair.getValue() > threshold) { - BoundingBox bbox = new BoundingBox(keyValPair.getKey(), keyValPair.getValue(), x, y, boxWidth, boxHeight); - //results.put(keyValPair.getKey(), bboxes); add these to list - } - } - } - } - - return results; - } - - /** - * Get the threshold - * - * @return theshold as {@code double} - */ - public double getThreshold() { - return threshold; - } - - /** - * Set the threshold - * - * @param threshold as {@code double} - */ - public void setThreshold(double threshold) { - this.threshold = threshold; - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.ml.detection; + +import javax.visrec.ri.ml.classification.AbstractImageClassifier; +import javax.visrec.ml.classification.ClassificationException; +import javax.visrec.ml.detection.BoundingBox; +import java.awt.image.BufferedImage; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A simple object detector. + * + */ +public class SimpleObjectDetector extends AbstractObjectDetector { + + private double threshold = 0.5; + + /** + * Creates an instance + * + * @param classifier A {@link AbstractImageClassifier} which may not be null + */ + public SimpleObjectDetector(AbstractImageClassifier classifier) { + super(classifier); + } + + /** + * Scan image using brute force sliding window and return positions where + * classifier returns score greater then threshold. + *

+ * get width and height of the image and scan image with classifier - apply + * classifier to each position This is trivial implementation and should be + * replaced with something better + * + * @param image {@code BufferedImage} to scan + * @return A {@code Map} of {@link BoundingBox} which contain + * the positions of the detected object. + */ + @Override + public Map> detectObject(BufferedImage image) throws ClassificationException { + Map> results = new HashMap<>(); + + int boxWidth = 64, boxHeight = 64; + + for (int y = 0; y < image.getHeight() - boxHeight; y++) { + for (int x = 0; x < image.getWidth() - boxWidth; x++) { + + Map results2 = getImageClassifier().classify(image.getSubimage(x, y, boxWidth, boxHeight)); + for (Map.Entry keyValPair : results2.entrySet()) { + if (keyValPair.getValue() > threshold) { + BoundingBox bbox = new BoundingBox(keyValPair.getKey(), keyValPair.getValue(), x, y, boxWidth, boxHeight); + //results.put(keyValPair.getKey(), bboxes); add these to list + } + } + } + } + + return results; + } + + /** + * Get the threshold + * + * @return theshold as {@code double} + */ + public double getThreshold() { + return threshold; + } + + /** + * Set the threshold + * + * @param threshold as {@code double} + */ + public void setThreshold(double threshold) { + this.threshold = threshold; + } +} diff --git a/src/main/java/javax/visrec/ri/ml/regression/LogisticRegressionNetwork.java b/src/main/java/javax/visrec/ri/ml/regression/LogisticRegressionNetwork.java index 4214480..c824d89 100644 --- a/src/main/java/javax/visrec/ri/ml/regression/LogisticRegressionNetwork.java +++ b/src/main/java/javax/visrec/ri/ml/regression/LogisticRegressionNetwork.java @@ -1,92 +1,110 @@ -package javax.visrec.ri.ml.regression; - -import deepnetts.data.MLDataItem; -import deepnetts.net.FeedForwardNetwork; -import deepnetts.net.layers.activation.ActivationType; -import deepnetts.net.loss.LossType; -import deepnetts.net.train.BackpropagationTrainer; -import deepnetts.util.Tensor; -import javax.visrec.ml.model.ModelCreationException; - -import javax.visrec.ml.classification.LogisticRegression; -import javax.visrec.ml.data.DataSet; - - -/** - * - */ -public class LogisticRegressionNetwork extends LogisticRegression { - - @Override - public Float classify(float[] input) { - FeedForwardNetwork model = getModel(); - model.setInput(Tensor.create(1, input.length, input)); //TODO: put array to input tensor placeholder - return model.getOutput()[0]; - } - - public static Builder builder() { - return new Builder(); - } - - - - public static class Builder implements javax.visrec.ml.model.ModelBuilder { - - private float learningRate = 0.01f; - private float maxError = 0.03f; - private int maxEpochs = 1000; - private int inputsNum; - - private DataSet trainingSet; // replace with DataSet from visrec - - public Builder inputsNum(int inputsNum) { - this.inputsNum = inputsNum; - return this; - } - - public Builder learningRate(float learningRate) { - this.learningRate = learningRate; - return this; - } - - public Builder maxError(float maxError) { - this.maxError = maxError; - return this; - } - - public Builder maxEpochs(int maxEpochs) { - this.maxEpochs = maxEpochs; - return this; - } - - public Builder trainingSet(DataSet trainingSet) { - this.trainingSet = trainingSet; - return this; - } - - // test set - // target accuracy - @Override - public LogisticRegressionNetwork build() { - FeedForwardNetwork model = FeedForwardNetwork.builder() - .addInputLayer(inputsNum) - .addOutputLayer(1, ActivationType.SIGMOID) - .lossFunction(LossType.CROSS_ENTROPY) - .build(); - - BackpropagationTrainer trainer = new BackpropagationTrainer(model); - trainer.setLearningRate(learningRate) - .setMaxEpochs(maxEpochs) - .setMaxError(maxError); - - if (trainingSet != null) { - trainer.train(trainingSet); - } - - LogisticRegressionNetwork product = new LogisticRegressionNetwork(); - product.setModel(model); - return product; - } - - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.ml.regression; + +import deepnetts.data.MLDataItem; +import deepnetts.net.FeedForwardNetwork; +import deepnetts.net.layers.activation.ActivationType; +import deepnetts.net.loss.LossType; +import deepnetts.net.train.BackpropagationTrainer; +import deepnetts.util.Tensor; + +import javax.visrec.ml.classification.LogisticRegression; +import javax.visrec.ml.data.DataSet; + + +/** + * Logistic regresion algorithm implemented by neural network. + */ +public class LogisticRegressionNetwork extends LogisticRegression { + + @Override + public Float classify(float[] input) { + FeedForwardNetwork model = getModel(); + model.setInput(Tensor.create(1, input.length, input)); //TODO: put array to input tensor placeholder + return model.getOutput()[0]; + } + + public static Builder builder() { + return new Builder(); + } + + + + public static class Builder implements javax.visrec.ml.model.ModelBuilder { + + private float learningRate = 0.01f; + private float maxError = 0.03f; + private int maxEpochs = 1000; + private int inputsNum; + + private DataSet trainingSet; // replace with DataSet from visrec + + public Builder inputsNum(int inputsNum) { + this.inputsNum = inputsNum; + return this; + } + + public Builder learningRate(float learningRate) { + this.learningRate = learningRate; + return this; + } + + public Builder maxError(float maxError) { + this.maxError = maxError; + return this; + } + + public Builder maxEpochs(int maxEpochs) { + this.maxEpochs = maxEpochs; + return this; + } + + public Builder trainingSet(DataSet trainingSet) { + this.trainingSet = trainingSet; + return this; + } + + // test set + // target accuracy + @Override + public LogisticRegressionNetwork build() { + FeedForwardNetwork model = FeedForwardNetwork.builder() + .addInputLayer(inputsNum) + .addOutputLayer(1, ActivationType.SIGMOID) + .lossFunction(LossType.CROSS_ENTROPY) + .build(); + + BackpropagationTrainer trainer = new BackpropagationTrainer(model); + trainer.setLearningRate(learningRate) + .setMaxEpochs(maxEpochs) + .setMaxError(maxError); + + if (trainingSet != null) { + trainer.train(trainingSet); + } + + LogisticRegressionNetwork product = new LogisticRegressionNetwork(); + product.setModel(model); + return product; + } + + } +} diff --git a/src/main/java/javax/visrec/ri/ml/regression/SimpleLinearRegressionNetwork.java b/src/main/java/javax/visrec/ri/ml/regression/SimpleLinearRegressionNetwork.java index 086b593..a53af89 100644 --- a/src/main/java/javax/visrec/ri/ml/regression/SimpleLinearRegressionNetwork.java +++ b/src/main/java/javax/visrec/ri/ml/regression/SimpleLinearRegressionNetwork.java @@ -1,106 +1,125 @@ -package javax.visrec.ri.ml.regression; - -import deepnetts.data.MLDataItem; -import deepnetts.net.FeedForwardNetwork; -import deepnetts.net.layers.activation.ActivationType; -import deepnetts.net.loss.LossType; -import deepnetts.net.train.BackpropagationTrainer; -import deepnetts.util.Tensor; - -import javax.visrec.ml.regression.SimpleLinearRegression; -import javax.visrec.ml.data.DataSet; - -/** - * Simple linear regression implemented Feed Forward Neural Network as a back-end. - * - * @see SimpleLinearRegression - */ -public class SimpleLinearRegressionNetwork extends SimpleLinearRegression { - - private final float[] input = new float[1]; - private final Tensor inputTensor = Tensor.create(1, 1, input); - - private float slope; - private float intercept; - - @Override - public Float predict(Float inputs) { - input[0] = inputs; - FeedForwardNetwork ffn = getModel(); - ffn.setInput(inputTensor); - float[] output = ffn.getOutput(); - return output[0]; - } - - public static Builder builder() { - return new Builder(); - } - - @Override - public float getSlope() { - return slope; - } - - @Override - public float getIntercept() { - return intercept; - } - - - public static class Builder implements javax.visrec.ml.model.ModelBuilder { - private SimpleLinearRegressionNetwork buildingBlock = new SimpleLinearRegressionNetwork(); - - private float learningRate = 0.01f; - private float maxError = 0.03f; - private int maxEpochs = 1000; - - private DataSet trainingSet; // replace with DataSet from visrec - - - public Builder learningRate(float learningRate) { - this.learningRate = learningRate; - return this; - } - - public Builder maxError(float maxError) { - this.maxError = maxError; - return this; - } - - public Builder maxEpochs(int maxEpochs) { - this.maxEpochs = maxEpochs; - return this; - } - - public Builder trainingSet(DataSet trainingSet) { - this.trainingSet = trainingSet; - return this; - } - - // test set - // target accuracy - - @Override - public SimpleLinearRegressionNetwork build() { - FeedForwardNetwork model= FeedForwardNetwork.builder() - .addInputLayer(1) - .addOutputLayer(1, ActivationType.LINEAR) - .lossFunction(LossType.MEAN_SQUARED_ERROR) - .build(); - - BackpropagationTrainer trainer = new BackpropagationTrainer(model); - trainer.setLearningRate(learningRate) - .setMaxError(maxError) - .setMaxEpochs(maxEpochs); - trainer.train(trainingSet); - - buildingBlock.intercept = model.getOutputLayer().getBiases()[0]; - buildingBlock.slope = model.getOutputLayer().getWeights().get(0); - - buildingBlock.setModel(model); - return buildingBlock; - } - - - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.ml.regression; + +import deepnetts.data.MLDataItem; +import deepnetts.net.FeedForwardNetwork; +import deepnetts.net.layers.activation.ActivationType; +import deepnetts.net.loss.LossType; +import deepnetts.net.train.BackpropagationTrainer; +import deepnetts.util.Tensor; + +import javax.visrec.ml.regression.SimpleLinearRegression; +import javax.visrec.ml.data.DataSet; + +/** + * Simple linear regression implemented Feed Forward Neural Network as a back-end. + * + * @see SimpleLinearRegression + */ +public class SimpleLinearRegressionNetwork extends SimpleLinearRegression { + + private final float[] input = new float[1]; + private final Tensor inputTensor = Tensor.create(1, 1, input); + + private float slope; + private float intercept; + + @Override + public Float predict(Float inputs) { + input[0] = inputs; + FeedForwardNetwork ffn = getModel(); + ffn.setInput(inputTensor); + float[] output = ffn.getOutput(); + return output[0]; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public float getSlope() { + return slope; + } + + @Override + public float getIntercept() { + return intercept; + } + + + public static class Builder implements javax.visrec.ml.model.ModelBuilder { + private SimpleLinearRegressionNetwork buildingBlock = new SimpleLinearRegressionNetwork(); + + private float learningRate = 0.01f; + private float maxError = 0.03f; + private int maxEpochs = 1000; + + private DataSet trainingSet; // replace with DataSet from visrec + + + public Builder learningRate(float learningRate) { + this.learningRate = learningRate; + return this; + } + + public Builder maxError(float maxError) { + this.maxError = maxError; + return this; + } + + public Builder maxEpochs(int maxEpochs) { + this.maxEpochs = maxEpochs; + return this; + } + + public Builder trainingSet(DataSet trainingSet) { + this.trainingSet = trainingSet; + return this; + } + + // test set + // target accuracy + + @Override + public SimpleLinearRegressionNetwork build() { + FeedForwardNetwork model= FeedForwardNetwork.builder() + .addInputLayer(1) + .addOutputLayer(1, ActivationType.LINEAR) + .lossFunction(LossType.MEAN_SQUARED_ERROR) + .build(); + + BackpropagationTrainer trainer = new BackpropagationTrainer(model); + trainer.setLearningRate(learningRate) + .setMaxError(maxError) + .setMaxEpochs(maxEpochs); + trainer.train(trainingSet); + + buildingBlock.intercept = model.getOutputLayer().getBiases()[0]; + buildingBlock.slope = model.getOutputLayer().getWeights().get(0); + + buildingBlock.setModel(model); + return buildingBlock; + } + + + } +} diff --git a/src/main/java/javax/visrec/ri/spi/BufferedImageClassifierFactory.java b/src/main/java/javax/visrec/ri/spi/BufferedImageClassifierFactory.java index ad31b09..dcc617f 100644 --- a/src/main/java/javax/visrec/ri/spi/BufferedImageClassifierFactory.java +++ b/src/main/java/javax/visrec/ri/spi/BufferedImageClassifierFactory.java @@ -1,96 +1,115 @@ -package javax.visrec.ri.spi; - -import deepnetts.data.ImageSet; -import deepnetts.net.ConvolutionalNetwork; -import deepnetts.net.train.BackpropagationTrainer; -import deepnetts.net.train.opt.OptimizerType; -import deepnetts.util.DeepNettsException; -import deepnetts.util.FileIO; - -import javax.visrec.ml.classification.ImageClassifier; -import javax.visrec.ml.classification.NeuralNetImageClassifier; -import javax.visrec.ml.model.ModelCreationException; -import javax.visrec.ri.ml.classification.ImageClassifierNetwork; -import javax.visrec.spi.ImageClassifierFactory; -import java.awt.image.BufferedImage; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.util.logging.Logger; - -public class BufferedImageClassifierFactory implements ImageClassifierFactory { - - private static final Logger LOGGER = Logger.getLogger(BufferedImageClassifierFactory.class.getName()); - - @Override - public Class getImageClass() { - return BufferedImage.class; - } - - @Override - public ImageClassifier create(NeuralNetImageClassifier.BuildingBlock block) throws ModelCreationException { - if (block.getImportPath() != null) { - return onImport(block); - } - return onCreate(block); - } - - private ImageClassifier onImport(NeuralNetImageClassifier.BuildingBlock block) throws ModelCreationException { - try { - ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(block.getImportPath().toFile())); - ConvolutionalNetwork model = (ConvolutionalNetwork) inputStream.readObject(); - return new ImageClassifierNetwork(model); - } catch (IOException | ClassNotFoundException e) { - throw new ModelCreationException("Failed to import existing model", e); - } - } - - private ImageClassifier onCreate(NeuralNetImageClassifier.BuildingBlock block) throws ModelCreationException { - ImageSet imageSet = new ImageSet(block.getImageWidth(), block.getImageHeight()); - LOGGER.info("Loading images..."); - - imageSet.loadLabels(block.getLabelsPath().toFile()); - try { - imageSet.loadImages(block.getTrainingPath().toFile()); - imageSet.shuffle(); - } catch (DeepNettsException | FileNotFoundException ex) { - throw new ModelCreationException("Failed to load images from dataset", ex); - } - - LOGGER.info("Done!"); - LOGGER.info("Creating neural network..."); - - ConvolutionalNetwork neuralNet = null; - try { - neuralNet = (ConvolutionalNetwork) FileIO.createFromJson(block.getNetworkArchitecture().toFile()); - neuralNet.setOutputLabels(imageSet.getTargetColumnsNames()); - } catch (IOException ex) { - throw new ModelCreationException("Failed to create convolutional network from JSON file", ex); - } - - LOGGER.info("Done!"); - LOGGER.info("Training neural network"); - - // create a set of convolutional networks and do training, crossvalidation and performance evaluation - BackpropagationTrainer trainer = new BackpropagationTrainer(neuralNet) - .setLearningRate(block.getLearningRate()) - .setMomentum(0.7f) - .setMaxError(block.getMaxError()) - .setMaxEpochs(block.getMaxEpochs()) - .setBatchMode(false) - .setOptimizer(OptimizerType.SGD); - trainer.train(imageSet); - - ImageClassifierNetwork imageClassifier = new ImageClassifierNetwork(neuralNet); - try { - FileIO.writeToFile(neuralNet, block.getExportPath().toFile().getAbsolutePath()); - } catch (IOException ex) { - throw new ModelCreationException("Failed to write trained model to file", ex); - } - - imageClassifier.setThreshold(block.getThreshold()); - - return imageClassifier; - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.spi; + +import deepnetts.data.ImageSet; +import deepnetts.net.ConvolutionalNetwork; +import deepnetts.net.train.BackpropagationTrainer; +import deepnetts.net.train.opt.OptimizerType; +import deepnetts.util.DeepNettsException; +import deepnetts.util.FileIO; + +import javax.visrec.ml.classification.ImageClassifier; +import javax.visrec.ml.classification.NeuralNetImageClassifier; +import javax.visrec.ml.model.ModelCreationException; +import javax.visrec.ri.ml.classification.ImageClassifierNetwork; +import javax.visrec.spi.ImageClassifierFactory; +import java.awt.image.BufferedImage; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.util.logging.Logger; + +public class BufferedImageClassifierFactory implements ImageClassifierFactory { + + private static final Logger LOGGER = Logger.getLogger(BufferedImageClassifierFactory.class.getName()); + + @Override + public Class getImageClass() { + return BufferedImage.class; + } + + @Override + public ImageClassifier create(NeuralNetImageClassifier.BuildingBlock block) throws ModelCreationException { + if (block.getImportPath() != null) { + return onImport(block); + } + return onCreate(block); + } + + private ImageClassifier onImport(NeuralNetImageClassifier.BuildingBlock block) throws ModelCreationException { + try { + ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(block.getImportPath().toFile())); + ConvolutionalNetwork model = (ConvolutionalNetwork) inputStream.readObject(); + return new ImageClassifierNetwork(model); + } catch (IOException | ClassNotFoundException e) { + throw new ModelCreationException("Failed to import existing model", e); + } + } + + private ImageClassifier onCreate(NeuralNetImageClassifier.BuildingBlock block) throws ModelCreationException { + ImageSet imageSet = new ImageSet(block.getImageWidth(), block.getImageHeight()); + LOGGER.info("Loading images..."); + + imageSet.loadLabels(block.getLabelsPath().toFile()); + try { + imageSet.loadImages(block.getTrainingPath().toFile()); + imageSet.shuffle(); + } catch (DeepNettsException | FileNotFoundException ex) { + throw new ModelCreationException("Failed to load images from dataset", ex); + } + + LOGGER.info("Done!"); + LOGGER.info("Creating neural network..."); + + ConvolutionalNetwork neuralNet = null; + try { + neuralNet = (ConvolutionalNetwork) FileIO.createFromJson(block.getNetworkArchitecture().toFile()); + neuralNet.setOutputLabels(imageSet.getTargetColumnsNames()); + } catch (IOException ex) { + throw new ModelCreationException("Failed to create convolutional network from JSON file", ex); + } + + LOGGER.info("Done!"); + LOGGER.info("Training neural network"); + + // create a set of convolutional networks and do training, crossvalidation and performance evaluation + BackpropagationTrainer trainer = new BackpropagationTrainer(neuralNet) + .setLearningRate(block.getLearningRate()) + .setMomentum(0.7f) + .setMaxError(block.getMaxError()) + .setMaxEpochs(block.getMaxEpochs()) + .setBatchMode(false) + .setOptimizer(OptimizerType.SGD); + trainer.train(imageSet); + + ImageClassifierNetwork imageClassifier = new ImageClassifierNetwork(neuralNet); + try { + FileIO.writeToFile(neuralNet, block.getExportPath().toFile().getAbsolutePath()); + } catch (IOException ex) { + throw new ModelCreationException("Failed to write trained model to file", ex); + } + + imageClassifier.setThreshold(block.getThreshold()); + + return imageClassifier; + } +} diff --git a/src/main/java/javax/visrec/ri/spi/DeepNettsImplementationService.java b/src/main/java/javax/visrec/ri/spi/DeepNettsImplementationService.java index e9ffbed..09b2eea 100644 --- a/src/main/java/javax/visrec/ri/spi/DeepNettsImplementationService.java +++ b/src/main/java/javax/visrec/ri/spi/DeepNettsImplementationService.java @@ -1,21 +1,40 @@ -package javax.visrec.ri.spi; - -import javax.visrec.spi.ImplementationService; - -/** - * DeepNetts' {@link ImplementationService} - */ -public class DeepNettsImplementationService extends ImplementationService { - - /** {@inheritDoc} */ - @Override - public String getName() { - return "DeepNetts"; - } - - /** {@inheritDoc} */ - @Override - public String getVersion() { - return "1.1"; - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.spi; + +import javax.visrec.spi.ImplementationService; + +/** + * DeepNetts' {@link ImplementationService} + */ +public class DeepNettsImplementationService extends ImplementationService { + + /** {@inheritDoc} */ + @Override + public String getName() { + return "DeepNetts"; + } + + /** {@inheritDoc} */ + @Override + public String getVersion() { + return "1.1"; + } +} diff --git a/src/main/java/javax/visrec/ri/spi/DefaultImageFactoryService.java b/src/main/java/javax/visrec/ri/spi/DefaultImageFactoryService.java index c0d29f5..4e30c70 100644 --- a/src/main/java/javax/visrec/ri/spi/DefaultImageFactoryService.java +++ b/src/main/java/javax/visrec/ri/spi/DefaultImageFactoryService.java @@ -1,39 +1,58 @@ -package javax.visrec.ri.spi; - -import javax.visrec.ri.BufferedImageFactory; - -import javax.visrec.ImageFactory; -import javax.visrec.spi.ImageFactoryService; -import java.awt.image.BufferedImage; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; - -/** - * Default implementation of {@link ImageFactoryService} which serves the implementations of {@link ImageFactory}. - * - */ -public final class DefaultImageFactoryService implements ImageFactoryService { - - private static final Map, ImageFactory> imageFactories; - static { - imageFactories = new HashMap<>(); - imageFactories.put(BufferedImage.class, new BufferedImageFactory()); - } - - /** - * Get the {@link ImageFactory} by image type. - * @param imageCls image type in {@link Class} object which is able to - * be processed by the image factory implementation. - * @param image type. - * @return {@link ImageFactory} wrapped in {@link Optional}. If the {@link ImageFactory} could not be - * found then the {@link Optional} would contain null. - */ - @Override - public Optional> getByImageType(Class imageCls) { - Objects.requireNonNull(imageCls, "imageCls == null"); - ImageFactory imageFactory = imageFactories.get(imageCls); - return Optional.ofNullable((ImageFactory) imageFactory); - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.spi; + +import javax.visrec.ri.BufferedImageFactory; + +import javax.visrec.ImageFactory; +import javax.visrec.spi.ImageFactoryService; +import java.awt.image.BufferedImage; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * Default implementation of {@link ImageFactoryService} which serves the implementations of {@link ImageFactory}. + * + */ +public final class DefaultImageFactoryService implements ImageFactoryService { + + private static final Map, ImageFactory> imageFactories; + static { + imageFactories = new HashMap<>(); + imageFactories.put(BufferedImage.class, new BufferedImageFactory()); + } + + /** + * Get the {@link ImageFactory} by image type. + * @param imageCls image type in {@link Class} object which is able to + * be processed by the image factory implementation. + * @param image type. + * @return {@link ImageFactory} wrapped in {@link Optional}. If the {@link ImageFactory} could not be + * found then the {@link Optional} would contain null. + */ + @Override + public Optional> getByImageType(Class imageCls) { + Objects.requireNonNull(imageCls, "imageCls == null"); + ImageFactory imageFactory = imageFactories.get(imageCls); + return Optional.ofNullable((ImageFactory) imageFactory); + } +} diff --git a/src/main/java/javax/visrec/ri/spi/DefaultServiceProvider.java b/src/main/java/javax/visrec/ri/spi/DefaultServiceProvider.java index 45c71d3..15ba41e 100644 --- a/src/main/java/javax/visrec/ri/spi/DefaultServiceProvider.java +++ b/src/main/java/javax/visrec/ri/spi/DefaultServiceProvider.java @@ -1,29 +1,48 @@ -package javax.visrec.ri.spi; - -import javax.visrec.spi.ImageFactoryService; -import javax.visrec.spi.ImplementationService; -import javax.visrec.spi.ServiceProvider; - -/** - * Default {@link ServiceProvider} of the implementation of the visual recognition API - * - */ -public final class DefaultServiceProvider extends ServiceProvider { - - /** - * {@inheritDoc} - */ - @Override - public ImageFactoryService getImageFactoryService() { - return new DefaultImageFactoryService(); - } - - /** - * {@inheritDoc} - */ - @Override - public ImplementationService getImplementationService() { - return new DeepNettsImplementationService(); - } - -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.spi; + +import javax.visrec.spi.ImageFactoryService; +import javax.visrec.spi.ImplementationService; +import javax.visrec.spi.ServiceProvider; + +/** + * Default {@link ServiceProvider} of the implementation of the visual recognition API + * + */ +public final class DefaultServiceProvider extends ServiceProvider { + + /** + * {@inheritDoc} + */ + @Override + public ImageFactoryService getImageFactoryService() { + return new DefaultImageFactoryService(); + } + + /** + * {@inheritDoc} + */ + @Override + public ImplementationService getImplementationService() { + return new DeepNettsImplementationService(); + } + +} diff --git a/src/main/java/javax/visrec/ri/spi/FloatArrayBinaryClassifierFactory.java b/src/main/java/javax/visrec/ri/spi/FloatArrayBinaryClassifierFactory.java index 8ba9a4a..d442af0 100644 --- a/src/main/java/javax/visrec/ri/spi/FloatArrayBinaryClassifierFactory.java +++ b/src/main/java/javax/visrec/ri/spi/FloatArrayBinaryClassifierFactory.java @@ -1,55 +1,74 @@ -package javax.visrec.ri.spi; - -import deepnetts.data.MLDataItem; -import deepnetts.data.TabularDataSet; -import deepnetts.net.FeedForwardNetwork; -import deepnetts.net.layers.activation.ActivationType; -import deepnetts.net.loss.LossType; - -import javax.visrec.ml.classification.BinaryClassifier; -import javax.visrec.ml.model.ModelCreationException; -import javax.visrec.ml.classification.NeuralNetBinaryClassifier; -import javax.visrec.ri.ml.classification.FeedForwardNetBinaryClassifier; -import javax.visrec.ri.util.DataSets; -import javax.visrec.spi.BinaryClassifierFactory; -import java.io.IOException; - -public class FloatArrayBinaryClassifierFactory implements BinaryClassifierFactory { - - @Override - public Class getTargetClass() { - return float[].class; - } - - @Override - public BinaryClassifier create(NeuralNetBinaryClassifier.BuildingBlock block) throws ModelCreationException { - FeedForwardNetwork.Builder ffnBuilder = FeedForwardNetwork.builder(); - ffnBuilder.addInputLayer(block.getInputsNum()); - - for (int h : block.getHiddenLayers()) { - ffnBuilder.addFullyConnectedLayer(h); - } - - ffnBuilder.addOutputLayer(1, ActivationType.SIGMOID) - .lossFunction(LossType.CROSS_ENTROPY); - - FeedForwardNetwork ffn = ffnBuilder.build(); - ffn.getTrainer() - .setMaxEpochs(block.getMaxEpochs()) - .setMaxError(block.getMaxError()) - .setLearningRate(block.getLearningRate()); - - TabularDataSet trainingSet = null; - try { - trainingSet = DataSets.readCsv(block.getTrainingPath().toFile(), block.getInputsNum(), 1, true, ","); - //deepnetts.data.DataSets.normalizeMax(trainingSet); - } catch (IOException e) { - throw new ModelCreationException("Failed to create training set based on training file", e); - } - ffn.train(trainingSet); - FeedForwardNetBinaryClassifier ffnbc = new FeedForwardNetBinaryClassifier(ffn); - ffnbc.setThreshold(block.getThreshold()); - - return ffnbc; - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package javax.visrec.ri.spi; + +import deepnetts.data.MLDataItem; +import deepnetts.data.TabularDataSet; +import deepnetts.net.FeedForwardNetwork; +import deepnetts.net.layers.activation.ActivationType; +import deepnetts.net.loss.LossType; + +import javax.visrec.ml.classification.BinaryClassifier; +import javax.visrec.ml.model.ModelCreationException; +import javax.visrec.ml.classification.NeuralNetBinaryClassifier; +import javax.visrec.ri.ml.classification.FeedForwardNetBinaryClassifier; +import javax.visrec.ri.util.DataSets; +import javax.visrec.spi.BinaryClassifierFactory; +import java.io.IOException; + +public class FloatArrayBinaryClassifierFactory implements BinaryClassifierFactory { + + @Override + public Class getTargetClass() { + return float[].class; + } + + @Override + public BinaryClassifier create(NeuralNetBinaryClassifier.BuildingBlock block) throws ModelCreationException { + FeedForwardNetwork.Builder ffnBuilder = FeedForwardNetwork.builder(); + ffnBuilder.addInputLayer(block.getInputsNum()); + + for (int h : block.getHiddenLayers()) { + ffnBuilder.addFullyConnectedLayer(h); + } + + ffnBuilder.addOutputLayer(1, ActivationType.SIGMOID) + .lossFunction(LossType.CROSS_ENTROPY); + + FeedForwardNetwork ffn = ffnBuilder.build(); + ffn.getTrainer() + .setMaxEpochs(block.getMaxEpochs()) + .setMaxError(block.getMaxError()) + .setLearningRate(block.getLearningRate()); + + TabularDataSet trainingSet = null; + try { + trainingSet = DataSets.readCsv(block.getTrainingPath().toFile(), block.getInputsNum(), 1, true, ","); + //deepnetts.data.DataSets.normalizeMax(trainingSet); + } catch (IOException e) { + throw new ModelCreationException("Failed to create training set based on training file", e); + } + ffn.train(trainingSet); + FeedForwardNetBinaryClassifier ffnbc = new FeedForwardNetBinaryClassifier(ffn); + ffnbc.setThreshold(block.getThreshold()); + + return ffnbc; + } +} diff --git a/src/main/java/javax/visrec/ri/util/DataSets.java b/src/main/java/javax/visrec/ri/util/DataSets.java index b461511..791c9cf 100644 --- a/src/main/java/javax/visrec/ri/util/DataSets.java +++ b/src/main/java/javax/visrec/ri/util/DataSets.java @@ -1,89 +1,108 @@ -package javax.visrec.ri.util; - -import deepnetts.data.MLDataItem; -import deepnetts.data.TabularDataSet; -import deepnetts.util.DeepNettsException; - -import java.io.*; - -public class DataSets { - - private DataSets() { - // Prevent instantiation - } - - /** - * Creates and returns data set from specified CSV file. Empty lines are - * skipped - * - * @param csvFile CSV file - * @param numInputs number of input values in a row - * @param numOutputs number of output values in a row - * @param hasColumnNames true if first row contains column names - * @param delimiter delimiter used to separate values - * @return instance of data set with values loaded from file - * - * @throws FileNotFoundException if file was not found - * @throws IOException if there was an error reading file - * - * TODO: Detect if there are labels in the first line, if there are no - * labels, set class1, class2, class3 in classifier evaluation! and detect - * type of attributes Move this method to some factory class or something? - * or as a default method in data set? - * - * TODO: should I wrap IO with DeepNetts Exception? - * Autodetetect delimiter; header and column type - * - */ - public static TabularDataSet readCsv(File csvFile, int numInputs, int numOutputs, boolean hasColumnNames, String delimiter) throws FileNotFoundException, IOException { - TabularDataSet dataSet = new TabularDataSet<>(numInputs, numOutputs); - BufferedReader br = new BufferedReader(new FileReader(csvFile)); - String line=null; - // auto detect column names - ako sadrzi slova onda ima imena. Sta ako su atributi nominalni? U ovoj fazi se pretpostavlja d anisu... - // i ako u redovima ispod takodje ima stringova u istoj koloni - detect header - if (hasColumnNames) { // get col names from the first line - line = br.readLine().trim(); - String[] colNames = line.split(delimiter); - // todo checsk number of col names - dataSet.setColumnNames(colNames); - } else { - String[] colNames = new String[numInputs+numOutputs]; - for(int i=0; i readCsv(File csvFile, int numInputs, int numOutputs, boolean hasColumnNames, String delimiter) throws FileNotFoundException, IOException { + TabularDataSet dataSet = new TabularDataSet<>(numInputs, numOutputs); + BufferedReader br = new BufferedReader(new FileReader(csvFile)); + String line=null; + // auto detect column names - ako sadrzi slova onda ima imena. Sta ako su atributi nominalni? U ovoj fazi se pretpostavlja d anisu... + // i ako u redovima ispod takodje ima stringova u istoj koloni - detect header + if (hasColumnNames) { // get col names from the first line + line = br.readLine().trim(); + String[] colNames = line.split(delimiter); + // todo checsk number of col names + dataSet.setColumnNames(colNames); + } else { + String[] colNames = new String[numInputs+numOutputs]; + for(int i=0; i> imageFactory = ServiceProvider.current().getImageFactoryService().getByImageType(BufferedImage.class); - assertTrue(imageFactory.isPresent()); - // If the casting fails, the implementation is incorrect and it will fail the test. - BufferedImageFactory.class.cast(imageFactory.get()); - } -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package visrec.ri.spi; + +import org.junit.jupiter.api.Test; + +import javax.visrec.ImageFactory; +import javax.visrec.spi.ImageFactoryService; +import javax.visrec.spi.ServiceProvider; +import java.awt.image.BufferedImage; +import java.util.Optional; +import javax.visrec.ri.BufferedImageFactory; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests regarding the {@link DefaultImageFactoryService} + * @author Kevin Berendsen + */ +public class DefaultImageFactoryServiceTest { + + /** + * Test the instantiation of {@link BufferedImageFactory} through the {@link ImageFactoryService} + */ + @Test + public void testBufferedImageImageFactoryInstantiation() { + Optional> imageFactory = ServiceProvider.current().getImageFactoryService().getByImageType(BufferedImage.class); + assertTrue(imageFactory.isPresent()); + // If the casting fails, the implementation is incorrect and it will fail the test. + BufferedImageFactory.class.cast(imageFactory.get()); + } +} diff --git a/src/test/java/visrec/ri/util/BuilderConfigurationTest.java b/src/test/java/visrec/ri/util/BuilderConfigurationTest.java index ae1f383..c89eec8 100644 --- a/src/test/java/visrec/ri/util/BuilderConfigurationTest.java +++ b/src/test/java/visrec/ri/util/BuilderConfigurationTest.java @@ -1,70 +1,88 @@ -package visrec.ri.util; - -import org.junit.jupiter.api.Test; - -import javax.visrec.ml.model.ModelCreationException; -import javax.visrec.ml.model.InvalidConfigurationException; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; -import javax.visrec.ml.model.ModelBuilder; - -/** - * @author Kevin Berendsen - */ -public class BuilderConfigurationTest { - - /** - * Successfully build the output of the builder. - */ - @Test - public void testReflectionInvocationBuild() throws ModelCreationException { - Map trainingSet = new HashMap<>(); - trainingSet.put("hello", "world"); - trainingSet.put("lorem", "ipsum"); - Map configuration = new HashMap<>(); - configuration.put("trainingSet", trainingSet); - - BuilderImpl builder = new BuilderImpl(); - String output = builder.build(configuration); - assertEquals("{lorem=ipsum, hello=world}", output); - } - - /** - * The trainingSet method is invoked without the valid parameter and should - * throw an exception. - */ - @Test - public void testInvalidParameterForMethod() { - String trainingSet = "invalid"; - Map configuration = new HashMap<>(); - configuration.put("trainingSet", trainingSet); - - BuilderImpl builder = new BuilderImpl(); - try { - builder.build(configuration); - fail("The configuration is invalid and should throw the InvalidBuilderConfigurationException"); - } catch (ModelCreationException e) { - /* Expected */ - } - } - - - - public static class BuilderImpl implements ModelBuilder { - - private Map trainingSet; - - public void trainingSet(Map trainingSet) { - this.trainingSet = trainingSet; - } - - @Override - public String build() { - return trainingSet.toString(); - } - } - -} +/** + * Visual Recognition API for Java, JSR381 + * Copyright (C) 2020 Zoran Sevarac, Frank Greco + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +package visrec.ri.util; + +import org.junit.jupiter.api.Test; + +import javax.visrec.ml.model.ModelCreationException; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import javax.visrec.ml.model.ModelBuilder; + +/** + * @author Kevin Berendsen + */ +public class BuilderConfigurationTest { + + /** + * Successfully build the output of the builder. + */ + @Test + public void testReflectionInvocationBuild() throws ModelCreationException { + Map trainingSet = new HashMap<>(); + trainingSet.put("hello", "world"); + trainingSet.put("lorem", "ipsum"); + Map configuration = new HashMap<>(); + configuration.put("trainingSet", trainingSet); + + BuilderImpl builder = new BuilderImpl(); + String output = builder.build(configuration); + assertEquals("{lorem=ipsum, hello=world}", output); + } + + /** + * The trainingSet method is invoked without the valid parameter and should + * throw an exception. + */ + @Test + public void testInvalidParameterForMethod() { + String trainingSet = "invalid"; + Map configuration = new HashMap<>(); + configuration.put("trainingSet", trainingSet); + + BuilderImpl builder = new BuilderImpl(); + try { + builder.build(configuration); + fail("The configuration is invalid and should throw the InvalidBuilderConfigurationException"); + } catch (ModelCreationException e) { + /* Expected */ + } + } + + + + public static class BuilderImpl implements ModelBuilder { + + private Map trainingSet; + + public void trainingSet(Map trainingSet) { + this.trainingSet = trainingSet; + } + + @Override + public String build() { + return trainingSet.toString(); + } + } + +}