Skip to content

Commit

Permalink
- added ONNX runtime to run neural networks on spectrograms for class…
Browse files Browse the repository at this point in the history
…ification and other analysis features!
  • Loading branch information
christoph-hart committed Nov 7, 2024
1 parent f64d5ed commit 8448dfb
Show file tree
Hide file tree
Showing 30 changed files with 26,626 additions and 5,318 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,7 @@ extras/hise_dialogs/AdditionalSourceCode/
extras/hise_dialogs/expansion_info.xml
extras/hise_dialogs/project_info.xml
extras/hise_dialogs/user_info.xml
tools/onnx_lib/Builds/
tools/onnx_lib/JuceLibraryCode/
tools/onnx_lib/Source/lib/
tools/onnx_lib/onnx_hise_library.dll
21 changes: 21 additions & 0 deletions hi_core/hi_core/MainController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,27 @@ bool MainController::shouldUseSoftBypassRamps() const noexcept
#endif
}

/** Returns the ONNX loader owned by this controller, creating it on first use.
*
*   The loader is constructed with the full path of the directory it should use:
*   the "tools/onnx_lib" child of the HisePath compiler setting, or (in the
*   currently disabled branch) the frontend's app data directory.
*
*   NOTE(review): the `|| 1` in the preprocessor condition makes the first branch
*   unconditional, so the FrontendHandler branch is never compiled. This looks
*   like a temporary development switch - confirm before shipping frontend builds.
*/
ONNXLoader::Ptr MainController::getONNXLoader()
{
    // Already created by an earlier call: hand out the cached instance.
    if (onnxLoader != nullptr)
        return onnxLoader;

#if USE_BACKEND || 1
    const File hiseRoot(GET_HISE_SETTING(getMainSynthChain(), HiseSettings::Compiler::HisePath).toString());
    const auto libraryPath = hiseRoot.getChildFile("tools/onnx_lib");
#else
    const auto libraryPath = FrontendHandler::getAppDataDirectory(this);
#endif

    onnxLoader = new ONNXLoader(libraryPath.getFullPathName());
    return onnxLoader;
}

/** Returns the value of the currentPreview member (the pointer is handed out as-is, no ownership is transferred by this call). */
MarkdownContentProcessor* MainController::getCurrentMarkdownPreview()
{
return currentPreview;
}

void callOnAllChildren(Component* c, const std::function<void(Component*)>& f)
{
f(c);
Expand Down
9 changes: 5 additions & 4 deletions hi_core/hi_core/MainController.h
Original file line number Diff line number Diff line change
Expand Up @@ -1974,10 +1974,9 @@ class MainController: public GlobalScriptCompileBroadcaster,
defaultPresetHandler = ownedHandler;
}

MarkdownContentProcessor* getCurrentMarkdownPreview()
{
return currentPreview;
}
ONNXLoader::Ptr getONNXLoader();

MarkdownContentProcessor* getCurrentMarkdownPreview();

MultiChannelAudioBuffer::XYZPool* getXYZPool()
{
Expand Down Expand Up @@ -2285,6 +2284,8 @@ class MainController: public GlobalScriptCompileBroadcaster,

DebugLogger debugLogger;

hise::ONNXLoader::Ptr onnxLoader;

#if USE_BACKEND
Component::SafePointer<ScriptWatchTable> scriptWatchTable;
Array<Component::SafePointer<ScriptComponentEditPanel>> scriptComponentEditPanels;
Expand Down
106 changes: 86 additions & 20 deletions hi_scripting/scripting/api/ScriptingApiObjects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ struct ScriptingObjects::ScriptFile::Wrapper
API_METHOD_WRAPPER_0(ScriptFile, loadAsString);
API_METHOD_WRAPPER_0(ScriptFile, loadAsObject);
API_METHOD_WRAPPER_0(ScriptFile, loadAsAudioFile);
API_METHOD_WRAPPER_0(ScriptFile, loadAsBase64String);
API_METHOD_WRAPPER_0(ScriptFile, getNonExistentSibling);
API_METHOD_WRAPPER_0(ScriptFile, deleteFileOrDirectory);
API_METHOD_WRAPPER_1(ScriptFile, loadEncryptedObject);
Expand Down Expand Up @@ -276,6 +277,7 @@ ScriptingObjects::ScriptFile::ScriptFile(ProcessorWithScriptingContent* p, const
ADD_API_METHOD_1(loadEncryptedObject);
ADD_API_METHOD_0(loadMidiMetadata);
ADD_API_METHOD_0(loadAudioMetadata);
ADD_API_METHOD_0(loadAsBase64String);
ADD_API_METHOD_1(rename);
ADD_API_METHOD_1(move);
ADD_API_METHOD_1(copy);
Expand Down Expand Up @@ -728,6 +730,13 @@ bool ScriptingObjects::ScriptFile::writeEncryptedObject(var jsonData, String key
return f.replaceWithText(out.toBase64Encoding());
}

/** Reads the whole file into memory and returns it encoded with MemoryBlock::toBase64Encoding().
*
*   No compression is applied. The resulting string is in the format produced by
*   MemoryBlock::toBase64Encoding(), so it has to be decoded with
*   MemoryBlock::fromBase64Encoding().
*
*   @returns the encoded file content, or an empty string if the file could not
*            be read completely.
*/
String ScriptingObjects::ScriptFile::loadAsBase64String() const
{
    MemoryBlock fileContent;

    // loadFileAsData() reports failure through its return value. Encoding the
    // block anyway would silently hand out data for a missing or partially
    // read file, so signal the failure with an empty string instead.
    if (!f.loadFileAsData(fileContent))
        return {};

    return fileContent.toBase64Encoding();
}

String ScriptingObjects::ScriptFile::loadAsString() const
{
return f.loadFileAsString();
Expand Down Expand Up @@ -2220,6 +2229,7 @@ ScriptingObjects::ScriptingSamplerSound::ScriptingSamplerSound(ProcessorWithScri
sampleIds.add(SampleIds::LoopEnd);
sampleIds.add(SampleIds::LoopXFade);
sampleIds.add(SampleIds::LoopEnabled);
sampleIds.add(SampleIds::ReleaseStart);
sampleIds.add(SampleIds::LowerVelocityXFade);
sampleIds.add(SampleIds::UpperVelocityXFade);
sampleIds.add(SampleIds::SampleState);
Expand Down Expand Up @@ -5009,6 +5019,8 @@ struct ScriptingObjects::ScriptNeuralNetwork::Wrapper
API_VOID_METHOD_WRAPPER_1(ScriptNeuralNetwork, loadTensorFlowModel);
API_VOID_METHOD_WRAPPER_1(ScriptNeuralNetwork, loadPytorchModel);
API_METHOD_WRAPPER_1(ScriptNeuralNetwork, createModelJSONFromTextFile);
API_METHOD_WRAPPER_2(ScriptNeuralNetwork, loadOnnxModel);
API_METHOD_WRAPPER_3(ScriptNeuralNetwork, processFFTSpectrum);
};

ScriptingObjects::ScriptNeuralNetwork::ScriptNeuralNetwork(ProcessorWithScriptingContent* p, const String& name):
Expand All @@ -5023,6 +5035,8 @@ ScriptingObjects::ScriptNeuralNetwork::ScriptNeuralNetwork(ProcessorWithScriptin
ADD_API_METHOD_1(loadTensorFlowModel);
ADD_API_METHOD_1(loadPytorchModel);
ADD_API_METHOD_0(getModelJSON);
ADD_API_METHOD_2(loadOnnxModel);
ADD_API_METHOD_3(processFFTSpectrum);

#if HISE_INCLUDE_RT_NEURAL
nn = p->getMainController_()->getNeuralNetworks().getOrCreate(Identifier(name));
Expand Down Expand Up @@ -5243,6 +5257,62 @@ void ScriptingObjects::ScriptNeuralNetwork::loadPytorchModel(const var& modelJSO
#endif
}

bool ScriptingObjects::ScriptNeuralNetwork::loadOnnxModel(const var& base64Data, int numOutputs)
{
if(onnx == nullptr)
onnx = getScriptProcessor()->getMainController_()->getONNXLoader();

onnxOutput.resize(numOutputs);

for(auto& v: onnxOutput)
v = 0.0f;

MemoryBlock mb;
mb.fromBase64Encoding(base64Data.toString());
auto ok = onnx->loadModel(mb);

if(ok.failed())
{
reportScriptError(ok.getErrorMessage());
RETURN_IF_NO_THROW(false);
}

return true;
}

/** Runs the loaded ONNX model on the spectrum image of the given FFT object.
*
*   @param fftObject      must hold a ScriptFFT object, otherwise a script error is reported.
*   @param numFreqPixels  forwarded to ScriptFFT::getRescaledAndRotatedSpectrum().
*   @param numTimePixels  forwarded to ScriptFFT::getRescaledAndRotatedSpectrum().
*   @returns an array with one entry per element of the output buffer that was
*            sized by loadOnnxModel(). If an error was reported and did not throw,
*            an empty array is returned.
*/
var ScriptingObjects::ScriptNeuralNetwork::processFFTSpectrum(var fftObject, int numFreqPixels, int numTimePixels)
{
    auto* fft = dynamic_cast<ScriptFFT*>(fftObject.getObject());

    if (fft == nullptr)
    {
        reportScriptError("fftObject is not a FFT object.");
    }
    else if (onnx == nullptr)
    {
        reportScriptError("ONNX model is not loaded. use loadOnnxModel() before calling this method");
    }
    else
    {
        // The spectrum of the FFT input (getOutput == false) is used as model input.
        auto spectrumImage = fft->getRescaledAndRotatedSpectrum(false, numFreqPixels, numTimePixels);

        // A ColourScheme value of zero is passed on to the loader as "greyscale".
        auto spectrumSettings = fft->getSpectrum2DParameters();
        const bool greyscale = static_cast<int>(spectrumSettings["ColourScheme"]) == 0;

        onnx->run(spectrumImage, onnxOutput, greyscale);

        // Copy the float buffer into a script array.
        Array<var> result;

        for (auto& outputValue : onnxOutput)
            result.add(outputValue);

        return var(result);
    }

    RETURN_IF_NO_THROW(var(Array<var>()));
}

var ScriptingObjects::ScriptNeuralNetwork::getModelJSON()
{
#if HISE_INCLUDE_RT_NEURAL
Expand Down Expand Up @@ -7553,6 +7623,21 @@ var ScriptingObjects::ScriptFFT::getSpectrum2DParameters() const
return d;
}

/** Returns a rescaled and rotated copy of one of the spectrum images.
*
*   The image returned by getSpectrum(getOutput) is rescaled to
*   numFreqPixels x numTimePixels (width x height) with high resampling quality
*   and then copied into a new RGB image with swapped dimensions, where
*
*       rotated(x, y) = rescaled(rescaledWidth - 1 - y, x)
*
*   i.e. the rescaled image is rotated by 90 degrees counter-clockwise.
*
*   @param getOutput      selects the image, see getSpectrum().
*   @param numFreqPixels  width of the intermediate rescaled image (height of the result).
*   @param numTimePixels  height of the intermediate rescaled image (width of the result).
*/
Image ScriptingObjects::ScriptFFT::getRescaledAndRotatedSpectrum(bool getOutput, int numFreqPixels, int numTimePixels)
{
    const auto rescaled = getSpectrum(getOutput).rescaled(numFreqPixels, numTimePixels, Graphics::ResamplingQuality::highResamplingQuality);

    // The rotation swaps width and height.
    const int rotatedWidth = rescaled.getHeight();
    const int rotatedHeight = rescaled.getWidth();

    // No need to clear the image: every pixel is written by the loop below.
    Image rotated(Image::PixelFormat::RGB, rotatedWidth, rotatedHeight, false);

    // Pixels are written through setPixelAt(), so no BitmapData access is
    // opened on the target image here.
    for (int y = 0; y < rotatedHeight; y++)
    {
        for (int x = 0; x < rotatedWidth; x++)
            rotated.setPixelAt(x, y, rescaled.getPixelAt(rotatedHeight - y - 1, x));
    }

    return rotated;
}

bool ScriptingObjects::ScriptFFT::dumpSpectrum(var file, bool output, int numFreqPixels, int numTimePixels)
{
auto img = output ? outputSpectrum : spectrum;
Expand All @@ -7561,27 +7646,8 @@ bool ScriptingObjects::ScriptFFT::dumpSpectrum(var file, bool output, int numFre
{
sf->f.deleteFile();
FileOutputStream fos(sf->f);

auto rotated = getRescaledAndRotatedSpectrum(output, numFreqPixels, numTimePixels);
PNGImageFormat f;

auto thisImg = img.rescaled(numFreqPixels, numTimePixels, Graphics::ResamplingQuality::highResamplingQuality);

Image rotated(Image::PixelFormat::RGB, thisImg.getHeight(), thisImg.getWidth(), false);

Image::BitmapData r(rotated, Image::BitmapData::writeOnly);

for(int y = 0; y < rotated.getHeight(); y++)
{
auto line = r.getLinePointer(y);

for(int x = 0; x < rotated.getWidth(); x++)
{
//line[x] = thisImg.getPixelAt(rotated.getHeight() - y - 1, x);
rotated.setPixelAt(x, y, thisImg.getPixelAt(rotated.getHeight() - y - 1, x));
//line[x] = roundToInt(thisImg.getPixelAt(rotated.getHeight() - y, x).getBrightness() * 255.0f);
}
}

return f.writeImageToStream(rotated, fos);
}

Expand Down
14 changes: 14 additions & 0 deletions hi_scripting/scripting/api/ScriptingApiObjects.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ namespace ScriptingObjects
/** Loads the track (zero-based) of the MIDI file. If successful, it returns an object containing the time signature and a list of all events. */
var loadAsMidiFile(int trackIndex);

/** Loads the binary file and returns its content as a Base64 string. */
String loadAsBase64String() const;

/** Replaces the file content with the given text. */
bool writeString(String text);

Expand Down Expand Up @@ -759,6 +762,8 @@ namespace ScriptingObjects

/** Returns outputSpectrum if getOutput is true, otherwise spectrum. */
Image getSpectrum(bool getOutput) const
{
    if (getOutput)
        return outputSpectrum;

    return spectrum;
}

Image getRescaledAndRotatedSpectrum(bool getOutput, int numFreqPixels, int numTimePixels);

private:

AudioSampleBuffer windowBuffer;
Expand Down Expand Up @@ -1589,6 +1594,12 @@ namespace ScriptingObjects
/** Loads the model layout and weights from a Pytorch model JSON. */
void loadPytorchModel(const var& modelJSON);

/** Loads the ONNX runtime model for spectral analysis. */
bool loadOnnxModel(const var& base64Data, int numOutputValues);

/** Processes the FFT spectrum and returns the output tensor as array of float numbers. */
var processFFTSpectrum(var fftObject, int numFreqPixels, int numTimePixels);

/** Returns the model JSON. */
var getModelJSON();

Expand Down Expand Up @@ -1618,6 +1629,9 @@ namespace ScriptingObjects
NeuralNetwork::Ptr nn;
#endif

ONNXLoader::Ptr onnx;
std::vector<float> onnxOutput;

JUCE_DECLARE_WEAK_REFERENCEABLE(ScriptNeuralNetwork);
};

Expand Down
Loading

0 comments on commit 8448dfb

Please sign in to comment.