Skip to content

Commit

Permalink
Merge pull request Sergio0694#59 from Sergio0694/dev
Browse files Browse the repository at this point in the history
New APIs, minor improvements
  • Loading branch information
Sergio0694 authored Jan 9, 2018
2 parents b92b399 + 144a970 commit 6764628
Show file tree
Hide file tree
Showing 20 changed files with 361 additions and 421 deletions.
101 changes: 95 additions & 6 deletions NeuralNetwork.NET/APIs/DatasetLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
using JetBrains.Annotations;
using NeuralNetworkNET.APIs.Interfaces.Data;
using NeuralNetworkNET.Extensions;
using NeuralNetworkNET.Helpers;
using NeuralNetworkNET.SupervisedLearning.Data;
using NeuralNetworkNET.SupervisedLearning.Optimization.Parameters;
using NeuralNetworkNET.SupervisedLearning.Optimization.Progress;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;

namespace NeuralNetworkNET.APIs
{
Expand All @@ -30,12 +33,12 @@ public static class DatasetLoader
/// <summary>
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input collection, with the specified batch size
/// </summary>
/// <param name="data">The source collection to use to build the training dataset</param>
/// <param name="data">The source collection to use to build the training dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
/// <param name="size">The desired dataset batch size</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static ITrainingDataset Training([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, int size) => BatchesCollection.From(data, size);
public static ITrainingDataset Training([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, int size) => BatchesCollection.From(data, size);

/// <summary>
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input matrices, with the specified batch size
Expand All @@ -47,6 +50,34 @@ public static class DatasetLoader
[CollectionAccess(CollectionAccessType.Read)]
public static ITrainingDataset Training((float[,] X, float[,] Y) data, int size) => BatchesCollection.From(data, size);

/// <summary>
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input data, where each input sample is an image in a specified format
/// </summary>
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
/// <param name="size">The desired dataset batch size</param>
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static ITrainingDataset Training<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, int size, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
where TPixel : struct, IPixel<TPixel>
=> BatchesCollection.From(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)), size);

/// <summary>
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input data, where each input sample is an image in a specified format
/// </summary>
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
/// <param name="size">The desired dataset batch size</param>
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static ITrainingDataset Training<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, int size, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
where TPixel : struct, IPixel<TPixel>
=> BatchesCollection.From(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())), size);

#endregion

#region Validation
Expand All @@ -66,13 +97,13 @@ public static IValidationDataset Validation([NotNull] IEnumerable<(float[] X, fl
/// <summary>
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
/// </summary>
/// <param name="data">The source collection to use to build the validation dataset</param>
/// <param name="data">The source collection to use to build the validation dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static IValidationDataset Validation([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, float tolerance = 1e-2f, int epochs = 5)
public static IValidationDataset Validation([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, float tolerance = 1e-2f, int epochs = 5)
=> Validation(data.AsParallel().Select(f => f()), tolerance, epochs);

/// <summary>
Expand All @@ -86,6 +117,36 @@ public static IValidationDataset Validation([NotNull] IEnumerable<Func<(float[]
[CollectionAccess(CollectionAccessType.Read)]
public static IValidationDataset Validation((float[,] X, float[,] Y) data, float tolerance = 1e-2f, int epochs = 5) => new ValidationDataset(data, tolerance, epochs);

/// <summary>
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
/// </summary>
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static IValidationDataset Validation<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, float tolerance = 1e-2f, int epochs = 5, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
where TPixel : struct, IPixel<TPixel>
=> Validation(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)).AsParallel(), tolerance, epochs);

/// <summary>
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
/// </summary>
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static IValidationDataset Validation<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, float tolerance = 1e-2f, int epochs = 5, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
where TPixel : struct, IPixel<TPixel>
=> Validation(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())).AsParallel(), tolerance, epochs);

#endregion

#region Test
Expand All @@ -104,12 +165,12 @@ public static ITestDataset Test([NotNull] IEnumerable<(float[] X, float[] Y)> da
/// <summary>
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
/// </summary>
/// <param name="data">The source collection to use to build the test dataset</param>
/// <param name="data">The source collection to use to build the test dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
/// <param name="progress">The optional progress callback to use</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static ITestDataset Test([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null)
public static ITestDataset Test([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null)
=> Test(data.AsParallel().Select(f => f()), progress);

/// <summary>
Expand All @@ -122,6 +183,34 @@ public static ITestDataset Test([NotNull] IEnumerable<Func<(float[] X, float[] Y
[CollectionAccess(CollectionAccessType.Read)]
public static ITestDataset Test((float[,] X, float[,] Y) data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null) => new TestDataset(data, progress);

/// <summary>
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
/// </summary>
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
/// <param name="progress">The optional progress callback to use</param>
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static ITestDataset Test<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
where TPixel : struct, IPixel<TPixel>
=> Test(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)).AsParallel(), progress);

/// <summary>
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
/// </summary>
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
/// <param name="progress">The optional progress callback to use</param>
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
[PublicAPI]
[Pure, NotNull]
[CollectionAccess(CollectionAccessType.Read)]
public static ITestDataset Test<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
where TPixel : struct, IPixel<TPixel>
=> Test(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())).AsParallel(), progress);

#endregion
}
}
27 changes: 18 additions & 9 deletions NeuralNetwork.NET/APIs/Structs/TensorInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using SixLabors.ImageSharp.PixelFormats;

namespace NeuralNetworkNET.APIs.Structs
{
Expand Down Expand Up @@ -67,30 +68,38 @@ internal TensorInfo(int height, int width, int channels)
}

/// <summary>
/// Creates a new <see cref="TensorInfo"/> instance for an RGB image
/// Creates a new <see cref="TensorInfo"/> instance for a linear network layer, without keeping track of spatial info
/// </summary>
/// <param name="height">The height of the input image</param>
/// <param name="width">The width of the input image</param>
/// <param name="size">The input size</param>
[PublicAPI]
[Pure]
public static TensorInfo CreateForRgbImage(int height, int width) => new TensorInfo(height, width, 3);
public static TensorInfo Linear(int size) => new TensorInfo(1, 1, size);

/// <summary>
/// Creates a new <see cref="TensorInfo"/> instance for a grayscale image
/// Creates a new <see cref="TensorInfo"/> instance for an image with a user-defined pixel type
/// </summary>
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
/// <param name="height">The height of the input image</param>
/// <param name="width">The width of the input image</param>
[PublicAPI]
[Pure]
public static TensorInfo CreateForGrayscaleImage(int height, int width) => new TensorInfo(height, width, 1);
public static TensorInfo Image<TPixel>(int height, int width) where TPixel : struct, IPixel<TPixel>
{
if (typeof(TPixel) == typeof(Alpha8)) return new TensorInfo(height, width, 1);
if (typeof(TPixel) == typeof(Rgb24)) return new TensorInfo(height, width, 3);
if (typeof(TPixel) == typeof(Argb32)) return new TensorInfo(height, width, 4);
throw new InvalidOperationException($"The {typeof(TPixel).Name} pixel format isn't currently supported");
}

/// <summary>
/// Creates a new <see cref="TensorInfo"/> instance for a linear network layer, without keeping track of spatial info
/// Creates a new <see cref="TensorInfo"/> instance for with a custom 3D shape
/// </summary>
/// <param name="size">The input size</param>
/// <param name="height">The input volume height</param>
/// <param name="width">The input volume width</param>
/// <param name="channels">The number of channels in the input volume</param>
[PublicAPI]
[Pure]
public static TensorInfo CreateLinear(int size) => new TensorInfo(1, 1, size);
public static TensorInfo Volume(int height, int width, int channels) => new TensorInfo(height, width, channels);

#endregion

Expand Down
22 changes: 20 additions & 2 deletions NeuralNetwork.NET/Extensions/MiscExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using JetBrains.Annotations;

Expand All @@ -20,8 +21,10 @@ public static class MiscExtensions
/// <param name="item">The item to cast</param>
[Pure, NotNull]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static TOut To<TIn, TOut>([NotNull] this TIn item) where TOut : class, TIn => item as TOut
?? throw new InvalidOperationException($"The item of type {typeof(TIn)} is a {item.GetType()} instance and can't be cast to {typeof(TOut)}");
public static TOut To<TIn, TOut>([NotNull] this TIn item)
where TIn : class
where TOut : TIn
=> (TOut)item;

/// <summary>
/// Returns the maximum value between two numbers
Expand Down Expand Up @@ -138,5 +141,20 @@ public static void AssertCompleted(in this ParallelLoopResult result)
{
if (!result.IsCompleted) throw new InvalidOperationException("Error while performing the parallel loop");
}

/// <summary>
/// Removes the left spaces from the input verbatim string
/// </summary>
/// <param name="text">The string to trim</param>
[Pure, NotNull]
public static String TrimVerbatim([NotNull] this String text)
{
String[] lines = text.Split(new[] { Environment.NewLine }, StringSplitOptions.None);
return lines.Aggregate(new StringBuilder(), (b, s) =>
{
b.AppendLine(s.Trim());
return b;
}).ToString();
}
}
}
Loading

0 comments on commit 6764628

Please sign in to comment.