From 3dddc383e7800edd1db0390bfb219fe04bdfac6f Mon Sep 17 00:00:00 2001 From: zamfofex Date: Sat, 5 Oct 2024 20:27:47 -0300 Subject: [PATCH] add simple JS backend --- cross-files/wasm32-emscripten | 26 +++++ js/.gitignore | 5 + js/README.md | 19 ++++ js/build.sh | 29 +++++ js/example/.gitignore | 5 + js/example/README.md | 16 +++ js/example/index.html | 70 ++++++++++++ js/example/package.json | 11 ++ js/example/vite.config.js | 11 ++ js/main.js | 94 ++++++++++++++++ js/package.json | 10 ++ js/worker.js | 171 ++++++++++++++++++++++++++++++ meson.build | 5 + src/chess/uciloop.cc | 31 +++++- src/mcts/node.h | 2 +- src/mcts/search.cc | 6 ++ src/neural/backends/network_js.cc | 165 ++++++++++++++++++++++++++++ src/neural/onnx/builder.cc | 9 ++ src/neural/onnx/builder.h | 1 + src/neural/onnx/converter.cc | 10 +- src/neural/onnx/converter.h | 1 + 21 files changed, 694 insertions(+), 3 deletions(-) create mode 100644 cross-files/wasm32-emscripten create mode 100644 js/.gitignore create mode 100644 js/README.md create mode 100755 js/build.sh create mode 100644 js/example/.gitignore create mode 100644 js/example/README.md create mode 100644 js/example/index.html create mode 100644 js/example/package.json create mode 100644 js/example/vite.config.js create mode 100644 js/main.js create mode 100644 js/package.json create mode 100644 js/worker.js create mode 100644 src/neural/backends/network_js.cc diff --git a/cross-files/wasm32-emscripten b/cross-files/wasm32-emscripten new file mode 100644 index 0000000000..b25b004733 --- /dev/null +++ b/cross-files/wasm32-emscripten @@ -0,0 +1,26 @@ +[host_machine] +system = 'emscripten' +cpu_family = 'wasm32' +cpu = 'wasm32' +endian = 'little' + +[binaries] +c = 'emcc' +cpp = 'em++' +ar = 'emar' +strip = 'emstrip' + +[built-in options] +cpp_args = ['--use-port=zlib', '-fexceptions', '-msse', '-msse2', '-msse3', '-msimd128'] +cpp_link_args = [ + '--use-port=zlib', + '-fexceptions', + '-sASYNCIFY', '-sASYNCIFY_STACK_SIZE=65536', + 
'-sSTACK_SIZE=1048576',
+  '-sMODULARIZE', '-sEXPORT_ES6',
+  '-sDEFAULT_LIBRARY_FUNCS_TO_INCLUDE=$stringToNewUTF8',
+  '-sALLOW_MEMORY_GROWTH',
+  '-sWASM_BIGINT',
+  '-sENVIRONMENT=worker',
+  '-sEXTRA_EXPORTED_RUNTIME_METHODS=["FS"]',
+  ]
diff --git a/js/.gitignore b/js/.gitignore
new file mode 100644
index 0000000000..0cc0c39d6f
--- /dev/null
+++ b/js/.gitignore
@@ -0,0 +1,5 @@
+node_modules
+package-lock.json
+dist
+build
+lc0.tar
diff --git a/js/README.md b/js/README.md
new file mode 100644
index 0000000000..68a378dd0f
--- /dev/null
+++ b/js/README.md
@@ -0,0 +1,19 @@
+
+Leela Chess Zero (Wasm)
+===
+
+Lc0 compiled to WebAssembly (running on WebGPU when supported).
+
+compiling
+---
+
+- Install [Emscripten].
+- Install [npm].
+- Install [Meson].
+- Run `npm install`
+- Run `npm run build` (or `./build.sh`)
+
+[Emscripten]: https://emscripten.org
+[npm]: https://www.npmjs.com
+[Meson]: https://mesonbuild.com
diff --git a/js/build.sh b/js/build.sh
new file mode 100755
index 0000000000..065fa7c10e
--- /dev/null
+++ b/js/build.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env sh
+set -ex
+meson setup --buildtype=release -Ddefault_library=static --prefer-static --cross-file=../cross-files/wasm32-emscripten -Dblas=false build ..
|| : +meson compile -C build lc0 +esbuild --minify --outdir=dist --format=esm main.js worker.js build/lc0.js build/lc0.worker.mjs +mv dist/build/lc0.worker.js dist/build/lc0.worker.mjs +cp build/lc0.wasm dist/build +cat > dist/package.json << END +{ + "name": "lc0", + "description": "Leela Chess Zero", + "version": "0.0.0.1", + "license": "GPL", + "homepage": "https://lczero.org", + "repository": { + "type": "git", + "url": "https://github.com/LeelaChessZero/lc0" + }, + "main": "./main.js", + "exports": { + ".": { + "import": "./main.js" + } + }, + "dependencies": { + "onnxruntime-web": "1.20.1" + } +} +END diff --git a/js/example/.gitignore b/js/example/.gitignore new file mode 100644 index 0000000000..53e5f0de14 --- /dev/null +++ b/js/example/.gitignore @@ -0,0 +1,5 @@ +node_modules +package-lock.json +dist +net.pb.gz +.parcel-cache diff --git a/js/example/README.md b/js/example/README.md new file mode 100644 index 0000000000..011ffa9be2 --- /dev/null +++ b/js/example/README.md @@ -0,0 +1,16 @@ + + +Lc0 Web Example (Wasm) +=== + +First, download an Lc0 network and name the file `net.pb.gz` and place it on this directory. Then, run the following commands: + +- Install [Emscripten]. +- Install [npm]. +- Install [Meson]. +- Run `npm install` +- Run `npm run dev` + +[Emscripten]: +[npm]: +[Meson]: diff --git a/js/example/index.html b/js/example/index.html new file mode 100644 index 0000000000..0f210c9e69 --- /dev/null +++ b/js/example/index.html @@ -0,0 +1,70 @@ + + + + + + Lc0 Web Example + + + + + +
diff --git a/js/example/package.json b/js/example/package.json new file mode 100644 index 0000000000..6f2a62e087 --- /dev/null +++ b/js/example/package.json @@ -0,0 +1,11 @@ +{ + "scripts": { + "pack": "cd .. && npm run pack", + "dev": "npm run pack && npm install && vite" + }, + "dependencies": { + "@xterm/xterm": "5.5.0", + "lc0": "../lc0.tar", + "vite": "6.0.4" + } +} diff --git a/js/example/vite.config.js b/js/example/vite.config.js new file mode 100644 index 0000000000..113fc7a208 --- /dev/null +++ b/js/example/vite.config.js @@ -0,0 +1,11 @@ +export default { + server: { + headers: { + "Cross-Origin-Embedder-Policy": "require-corp", + "Cross-Origin-Opener-Policy": "same-origin", + } + }, + optimizeDeps: { + exclude: ["lc0"], + }, +} diff --git a/js/main.js b/js/main.js new file mode 100644 index 0000000000..37cdff64e9 --- /dev/null +++ b/js/main.js @@ -0,0 +1,94 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . 
+ + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +function Stream(worker, type, finished) +{ + let gotLine + const lines = [] + + worker.addEventListener("message", ({data}) => + { + if (data.type !== type) return + if (!gotLine) { + lines.push(data.text) + return + } + gotLine(data.text) + gotLine = undefined + }) + + function next() + { + if (lines.length !== 0) return lines.shift() + else return new Promise(resolve => gotLine = resolve) + } + + const it = {next: async () => finished() && lines.length === 0 ? {done: true} : {done: false, value: await next()}} + Object.freeze(it) + + const peek = () => lines[0] + return {next, peek, [Symbol.asyncIterator]: () => it} +} + +export function Lc0(network) +{ + const worker = new Worker(new URL("worker.js", import.meta.url), {type: "module"}) + + let commands = [] + let post0 = command => commands.push(command) + + worker.addEventListener("message", () => + { + worker.postMessage({network}, [network]) + for (const command of commands) worker.postMessage(command) + commands = undefined + post0 = command => worker.postMessage(command) + }, {once: true}) + + const post = command => + { + if (finished) throw new Error("Cannot post command to finished Lc0") + post0(String(command)) + } + + let finished = false + + // todo: this should send a message to the worker instead + // so that it can end its pthread workers too + function finish() + { + finished = true + worker.terminate() + } + + const stdout = Stream(worker, "stdout", () => finished) + const stderr = Stream(worker, "stderr", () => finished) + 
+ const lc0 = {post, finish, ...stdout, stderr, get finished() { return finished }} + Object.freeze(lc0) + return lc0 +} diff --git a/js/package.json b/js/package.json new file mode 100644 index 0000000000..e4b08a4fc9 --- /dev/null +++ b/js/package.json @@ -0,0 +1,10 @@ +{ + "scripts": { + "build": "npm install && ./build.sh", + "pack": "npm run build && tar cf lc0.tar dist" + }, + "dependencies": { + "esbuild": "0.24.2" + } +} + diff --git a/js/worker.js b/js/worker.js new file mode 100644 index 0000000000..049bb8b4a1 --- /dev/null +++ b/js/worker.js @@ -0,0 +1,171 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. 
+*/ + +import Module from "./build/lc0.js" +import {InferenceSession, Tensor} from "onnxruntime-web/all" + +let gotLine +const lines = [] + +addEventListener("message", ({data}) => +{ + if (typeof data !== "string") return + if (!gotLine) { + lines.push(data) + return + } + gotLine(data) + gotLine = undefined +}) + +const {data: {network}} = await new Promise(resolve => +{ + postMessage({type: "ready"}) + addEventListener("message", resolve, {once: true}) +}) + +let bytes = new Uint8Array(await new Response(network).arrayBuffer()) + +let id = 0 +const map = new Map() + +function lc0web_get_line() +{ + if (lines.length !== 0) return lines.shift() + return new Promise(resolve => gotLine = resolve) +} + +function lc0web_is_cpu(_id) +{ + // TODO + return true +} + +function lc0web_computation(id2) +{ + const i = id++ + map.set(i, {input: [], session: map.get(id2)}) + return i +} + +function lc0web_batch_size(id) +{ + return map.get(id).input.length +} + +function lc0web_remove(id) +{ + return map.delete(id) +} + +function lc0web_q_val(id, _sample) +{ + const [w, _d, l] = map.get(id).output["/output/wdl"].cpuData + return w - l +} + +function lc0web_d_val(id, _sample) +{ + const [_w, d] = map.get(id).output["/output/wdl"].cpuData + return d +} + +function lc0web_p_val(id, sample, moveID) +{ + return map.get(id).output["/output/policy"].cpuData[sample * 1858 + moveID] +} + +function lc0web_m_val(id, sample) +{ + return map.get(id).output["/output/mlh"].cpuData[sample] +} + +function lc0web_add_input(id) +{ + return map.get(id).input.push([]) +} + +function lc0web_add_plane(id, index, mask, value) +{ + const array = map.get(id).input[index] + for (let i = 0 ; i < 64 ; i++) { + if (mask & 1n) array.push(value) + else array.push(0) + mask >>= 1n + } +} + +async function lc0web_compute(id) +{ + const value = map.get(id) + const array = new Float32Array(value.input.flat(Infinity)) + const tensor = new Tensor("float32", array, [value.input.length, 112, 8, 8]) + value.output = 
await value.session.run({"/input/planes": tensor}) +} + +async function lc0web_network(data, length) +{ + const i = id++ + const buffer = module.HEAPU8.subarray(data, data + length) + const session = await InferenceSession.create(buffer, {executionProviders: ["webgpu", "wasm"]}) + map.set(i, session) + return i +} + +Object.assign(globalThis, { + lc0web_get_line, + lc0web_is_cpu, + lc0web_computation, + lc0web_batch_size, + lc0web_remove, + lc0web_q_val, + lc0web_d_val, + lc0web_p_val, + lc0web_m_val, + lc0web_add_input, + lc0web_add_plane, + lc0web_compute, + lc0web_network, +}) + +let module +Module({ + preRun: m => + { + module = m + const file = module.FS.open("net.pb.gz", "w") + module.FS.write(file, bytes, 0, bytes.length) + module.FS.close(file) + // free the buffer + if (bytes.buffer.transfer) bytes.buffer.transfer(0) + // let it be garbage-collected + bytes = undefined + }, + arguments: ["--preload", "-w", "net.pb.gz"], + print: text => postMessage({type: "stdout", text}), + printErr: text => postMessage({type: "stderr", text}), +}) diff --git a/meson.build b/meson.build index 587ea74c92..4ed68b582b 100644 --- a/meson.build +++ b/meson.build @@ -632,6 +632,11 @@ if get_option('build_backends') has_backends = true endif + if host_machine.system() == 'emscripten' + files += 'src/neural/backends/network_js.cc' + has_backends = true + endif + endif # if get_option('build_backends') if not has_backends and get_option('lc0') and get_option('build_backends') diff --git a/src/chess/uciloop.cc b/src/chess/uciloop.cc index 4135895965..f24da87733 100644 --- a/src/chess/uciloop.cc +++ b/src/chess/uciloop.cc @@ -37,6 +37,10 @@ #include #include +#ifdef __EMSCRIPTEN__ +#include +#endif + #include "utils/exception.h" #include "utils/logging.h" #include "utils/string.h" @@ -127,12 +131,37 @@ bool ContainsKey(const std::unordered_map& params, const std::string& key) { return params.find(key) != params.end(); } + +#ifdef __EMSCRIPTEN__ + +extern "C" { + +EM_ASYNC_JS(char*, 
lc0web_get_line, (), { + return stringToNewUTF8(String(await globalThis.lc0web_get_line())) +}); + +} + +bool GetLine(std::string& line) { + char *cline = lc0web_get_line(); + line = cline; + free(cline); + return true; +} + +#else + +bool GetLine(std::string& line) { + return static_cast(std::getline(std::cin, line)); +} + +#endif } // namespace void UciLoop::RunLoop() { std::cout.setf(std::ios::unitbuf); std::string line; - while (std::getline(std::cin, line)) { + while (GetLine(line)) { LOGFILE << ">> " << line; try { auto command = ParseCommand(line); diff --git a/src/mcts/node.h b/src/mcts/node.h index 2982de24da..7abf596084 100644 --- a/src/mcts/node.h +++ b/src/mcts/node.h @@ -346,7 +346,7 @@ class Node { #endif // A basic sanity check. This must be adjusted when Node members are adjusted. -#if defined(__i386__) || (defined(__arm__) && !defined(__aarch64__)) +#if defined(__i386__) || (defined(__arm__) && !defined(__aarch64__)) || defined(__EMSCRIPTEN__) static_assert(sizeof(Node) == 48, "Unexpected size of Node for 32bit compile"); #else static_assert(sizeof(Node) == 64, "Unexpected size of Node"); diff --git a/src/mcts/search.cc b/src/mcts/search.cc index b3326a7661..b81f38d47b 100644 --- a/src/mcts/search.cc +++ b/src/mcts/search.cc @@ -899,6 +899,11 @@ EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const { } void Search::StartThreads(size_t how_many) { +#ifdef __EMSCRIPTEN__ + (void) how_many; + SearchWorker worker(this, params_, 0); + worker.RunBlocking(); +#else Mutex::Lock lock(threads_mutex_); if (how_many == 0 && threads_.size() == 0) { how_many = network_->GetThreads() + !network_->IsCpu(); @@ -920,6 +925,7 @@ void Search::StartThreads(size_t how_many) { std::chrono::steady_clock::now() - start_time_) .count() << "ms already passed."; +#endif } void Search::RunBlocking(size_t threads) { diff --git a/src/neural/backends/network_js.cc b/src/neural/backends/network_js.cc new file mode 100644 index 0000000000..874ebe685a --- 
/dev/null +++ b/src/neural/backends/network_js.cc @@ -0,0 +1,165 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. 
+*/
+
+#include <emscripten.h>
+#include <memory>
+
+#include "neural/factory.h"
+#include "neural/loader.h"
+#include "neural/network.h"
+#include "neural/onnx/converter.h"
+
+namespace lczero {
+namespace {
+
+extern "C" {
+
+EM_JS(int, lc0web_is_cpu, (int id), { return globalThis.lc0web_is_cpu(id) });
+EM_JS(int, lc0web_computation, (int id), { return globalThis.lc0web_computation(id) });
+EM_JS(float, lc0web_q_val, (int id, int sample), { return globalThis.lc0web_q_val(id, sample) });
+EM_JS(float, lc0web_d_val, (int id, int sample), { return globalThis.lc0web_d_val(id, sample) });
+EM_JS(float, lc0web_p_val, (int id, int sample, int move_id), { return globalThis.lc0web_p_val(id, sample, move_id) });
+EM_JS(float, lc0web_m_val, (int id, int sample), { return globalThis.lc0web_m_val(id, sample) });
+EM_JS(int, lc0web_remove, (int id), { return globalThis.lc0web_remove(id) });
+EM_JS(void, lc0web_add_input, (int id), { return globalThis.lc0web_add_input(id) });
+EM_JS(void, lc0web_add_plane, (int id, int index, uint64_t mask, float value), { return globalThis.lc0web_add_plane(id, index, mask, value) });
+EM_JS(int, lc0web_batch_size, (int id), { return globalThis.lc0web_batch_size(id) });
+EM_ASYNC_JS(void, lc0web_compute, (int id), { return globalThis.lc0web_compute(id) });
+EM_ASYNC_JS(int, lc0web_network, (const char *data, size_t length), { return globalThis.lc0web_network(data, length) });
+
+}
+
+class JSComputation : public NetworkComputation {
+ public:
+  JSComputation(int id);
+  ~JSComputation() override;
+  void AddInput(InputPlanes&& input) override;
+  int GetBatchSize() const override;
+  void ComputeBlocking() override;
+  float GetQVal(int sample) const override;
+  float GetDVal(int sample) const override;
+  float GetPVal(int sample, int move_id) const override;
+  float GetMVal(int sample) const override;
+ private:
+  int id;
+};
+
+class JSNetwork : public Network {
+ public:
+  JSNetwork(std::string_view bytes);
+  ~JSNetwork() override;
+  const NetworkCapabilities& GetCapabilities()
const override { + return capabilities; + }; + std::unique_ptr NewComputation() override; + bool IsCpu() const override; + private: + int id; + const NetworkCapabilities capabilities = { + pblczero::NetworkFormat_InputFormat_INPUT_CLASSICAL_112_PLANE, + pblczero::NetworkFormat_OutputFormat_OUTPUT_WDL, + pblczero::NetworkFormat_MovesLeftFormat_MOVES_LEFT_V1, + }; +}; + +std::unique_ptr MakeJSNetwork( + const std::optional& w, + const OptionsDict& opts) { + (void) opts; + if (!w) { + throw Exception("The JS backend requires a network file."); + } + auto weights = *w; + if (!weights.has_onnx_model()) { + WeightsToOnnxConverterOptions onnx_options; + onnx_options.alt_mish = true; + onnx_options.alt_selu = true; + weights = ConvertWeightsToOnnx(weights, onnx_options); + } + const auto& onnx = weights.onnx_model(); + return std::make_unique(onnx.model()); +} + +bool JSNetwork::IsCpu() const { + return lc0web_is_cpu(id); +} + +std::unique_ptr JSNetwork::NewComputation() { + return std::make_unique(id); +} + +float JSComputation::GetQVal(int sample) const { + return lc0web_q_val(id, sample); +} + +float JSComputation::GetDVal(int sample) const { + return lc0web_d_val(id, sample); +} + +float JSComputation::GetPVal(int sample, int move_id) const { + return lc0web_p_val(id, sample, move_id); +} + +float JSComputation::GetMVal(int sample) const { + return lc0web_m_val(id, sample); +} + +void JSComputation::AddInput(InputPlanes&& input) { + int i = GetBatchSize(); + lc0web_add_input(id); + for (auto& plane : input) + lc0web_add_plane(id, i, plane.mask, plane.value); +} + +int JSComputation::GetBatchSize() const { + return lc0web_batch_size(id); +} + +void JSComputation::ComputeBlocking() { + lc0web_compute(id); +} + +JSComputation::JSComputation(int id2) { + id = lc0web_computation(id2); +} + +JSComputation::~JSComputation() { + lc0web_remove(id); +} + +JSNetwork::JSNetwork(std::string_view bytes) { + id = lc0web_network(bytes.data(), bytes.length()); +} + 
+JSNetwork::~JSNetwork() { + lc0web_remove(id); +} + +REGISTER_NETWORK("js", MakeJSNetwork, 1000) + +} +} diff --git a/src/neural/onnx/builder.cc b/src/neural/onnx/builder.cc index fe09d5cb1c..7fe724adaa 100644 --- a/src/neural/onnx/builder.cc +++ b/src/neural/onnx/builder.cc @@ -143,6 +143,7 @@ std::string OnnxBuilder::Conv(const std::string& name, node->add_input(AddInitializer(name + "/w/bias", bias_weights)); AddIntsAttribute(node, "pads", {pads, pads, pads, pads}); AddIntsAttribute(node, "kernel_shape", {shape, shape}); + AddIntsAttribute(node, "dilations", {1, 1}); return out; } @@ -286,6 +287,14 @@ std::string OnnxBuilder::Identity(const std::string& name, return PopulateStdNodeFields(node, name, input, "Identity"); } +std::string OnnxBuilder::Elu(const std::string& name, + const std::string& input, float alpha) { + auto* node = model_.mutable_graph()->add_node(); + auto out = PopulateStdNodeFields(node, name, input, "Elu"); + AddFloatAttribute(node, "alpha", alpha); + return out; +} + std::string OnnxBuilder::Selu(const std::string& name, const std::string& input) { auto* node = model_.mutable_graph()->add_node(); diff --git a/src/neural/onnx/builder.h b/src/neural/onnx/builder.h index 4ada3c37f7..873b734103 100644 --- a/src/neural/onnx/builder.h +++ b/src/neural/onnx/builder.h @@ -97,6 +97,7 @@ class OnnxBuilder { std::initializer_list perm = {}); std::string Pad(const std::string& name, const std::string& input, std::initializer_list pads); + std::string Elu(const std::string& name, const std::string& input, float alpha); std::string Selu(const std::string& name, const std::string& input); std::string Slice(const std::string& name, const std::string& input, std::initializer_list starts, diff --git a/src/neural/onnx/converter.cc b/src/neural/onnx/converter.cc index 07986d4ef6..247a4181d4 100644 --- a/src/neural/onnx/converter.cc +++ b/src/neural/onnx/converter.cc @@ -308,7 +308,15 @@ std::string Converter::MakeActivation(OnnxBuilder* builder, case 
ACTIVATION_SELU: { auto flow = input; flow = StartOptionalBf16Fix(builder, flow, name); - flow = builder->Selu(name + "/selu", flow); + if (!options_.alt_selu) { + flow = builder->Selu(name + "/selu", flow); + } else { + // For the JS backend (while ORT-web doesn't implement SELU). + auto& alpha = + static_cast(FloatOnnxConst({1.67326f}, {1})); + flow = builder->Elu(name + "/selu/elu", flow, 1.0507f); + flow = builder->Mul(name + "/selu/mul", flow, alpha); + } return EndOptionalBf16Fix(builder, flow, name); } case ACTIVATION_SWISH: diff --git a/src/neural/onnx/converter.h b/src/neural/onnx/converter.h index 632f65c94b..a9203aa38a 100644 --- a/src/neural/onnx/converter.h +++ b/src/neural/onnx/converter.h @@ -47,6 +47,7 @@ struct WeightsToOnnxConverterOptions { int opset = 17; bool alt_mish = false; // Use "Mish" approximation (fp32 only). bool alt_layernorm = false; // Discrete "LayerNormalization" implementation. + bool alt_selu = false; // Implement "SELU" using "ELU". bool no_shape = false; // Avoid use of "Shape" operator. std::string policy_head = "vanilla"; std::string value_head = "winner";