fix: init hidden state for reactant (#1026)
* fix: init hidden state for reactant

[skip tests] [skip docs] [skip ci]

* test: Reactant with recurrent layers

* fix: handle cases where similar returns an AoS

* chore: bump minimum reactant version

* test: use @jit for simplified testing code

* test: compile functions in tests correctly

* test: update and fix reactant tests
avik-pal authored Nov 15, 2024
1 parent fc75290 commit bbf5033
Showing 10 changed files with 126 additions and 71 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -98,7 +98,7 @@ NNlib = "0.9.24"
Optimisers = "0.3.4, 0.4"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.4"
Reactant = "0.2.6"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -54,7 +54,7 @@ Optimisers = "0.3.4, 0.4"
Pkg = "1.10"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.4"
Reactant = "0.2.6"
StableRNGs = "1"
StaticArrays = "1"
WeightInitializers = "1"
3 changes: 2 additions & 1 deletion ext/LuxReactantExt/LuxReactantExt.jl
@@ -2,13 +2,14 @@ module LuxReactantExt

using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
-using Reactant: Reactant, @compile, TracedRArray
+using Reactant: Reactant, @compile, TracedRArray, TracedRNumber
using Setfield: @set!
using Static: False

using Lux: Lux, LuxOps, Training
using Lux.Training: TrainingBackendCache, ReactantBackend

include("patches.jl")
include("training.jl")

end
7 changes: 7 additions & 0 deletions ext/LuxReactantExt/patches.jl
@@ -0,0 +1,7 @@
# For some reason xlogx and xlogy with boolean inputs sometimes lead to incorrect results
# XXX: Remove once https://github.com/EnzymeAD/Reactant.jl/pull/278 is merged and tagged
LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x)

function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber)
    return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y))
end
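For reference, these helpers compute x·log(x) and x·log(y) with the convention that the result is zero when x is zero. A minimal sketch of the generic definitions (assumed here for illustration; see LuxOps for the actual implementations):

function xlogx(x::Number)
    result = x * log(x)          # NaN when x == 0, since 0 * -Inf == NaN
    return iszero(x) ? zero(result) : result
end

function xlogy(x::Number, y::Number)
    result = x * log(y)
    return iszero(x) ? zero(result) : result
end

For a Boolean input both cases vanish: xlogx(false) is 0 by the 0·log(0) := 0 convention, and xlogx(true) == 1·log(1) == 0, so the zero(x) patch above is exact rather than an approximation.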
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
@@ -38,7 +38,7 @@ EnzymeCore = "0.8.5"
Functors = "0.5"
MLDataDevices = "1.6"
Random = "1.10"
Reactant = "0.2.4"
Reactant = "0.2.6"
ReverseDiff = "1.15"
Setfield = "1"
Tracker = "0.2.36"
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
@@ -63,7 +63,7 @@ Metal = "1"
OneHotArrays = "0.2.5"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2.4"
Reactant = "0.2.6"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
9 changes: 5 additions & 4 deletions src/utils.jl
@@ -12,7 +12,6 @@ using Static: Static, StaticBool, StaticInteger, StaticSymbol
using StaticArraysCore: SMatrix, SVector

using LuxCore: LuxCore, AbstractLuxLayer
-using MLDataDevices: get_device
using NNlib: NNlib

const CRC = ChainRulesCore
@@ -162,11 +161,13 @@ add!!(x::Number, y::Number) = x + y
add!!(::Nothing, ::Nothing) = nothing

function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix)
-    # TODO: Once we support moving `rng` to the device, we can directly initialize on the
-    # device
-    return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x)
+    y = similar(x, rnn.out_dims, Base.size(x, 2))
+    copyto!(y, rnn.init_state(rng, size(y)...))
+    return ArrayInterface.aos_to_soa(y)
end

@non_differentiable init_rnn_hidden_state(::Any...)

function init_trainable_rnn_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix)
    return repeat(hidden_state, 1, Base.size(x, 2))
end
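In words: the old path initialized the hidden state with the host RNG and then moved it with get_device(x); the new path allocates the state with similar(x, ...) so it automatically matches x's array type (GPU arrays, Reactant traced arrays, ...), fills it with copyto!, and finally runs it through ArrayInterface.aos_to_soa to cover the cases where similar returns an array of structs, per the commit message. A rough sketch of the new behavior, with a hypothetical initializer and illustrative dimensions (not the exact Lux internals):

using ArrayInterface, Random

init_state(rng, dims...) = randn(rng, Float32, dims...)   # hypothetical init function

x = rand(Float32, 4, 16)            # input: features × batch
y = similar(x, 8, size(x, 2))       # allocated with x's array type and device
copyto!(y, init_state(Random.default_rng(), size(y)...))
y = ArrayInterface.aos_to_soa(y)    # expected to be a no-op for plain arrays; converts an
                                    # array-of-structs result (e.g. tracked scalars) to
                                    # struct-of-arrays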
65 changes: 65 additions & 0 deletions test/reactant/layer_tests.jl
@@ -0,0 +1,65 @@
@testsetup module SharedReactantLayersTestSetup

using Lux, Reactant, Enzyme, Zygote

sumabs2(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

function ∇sumabs2_zygote(model, x, ps, st)
    return Zygote.gradient((x, ps) -> sumabs2(model, x, ps, st), x, ps)
end

function ∇sumabs2_enzyme(model, x, ps, st)
    dx = Enzyme.make_zero(x)
    dps = Enzyme.make_zero(ps)
    Enzyme.autodiff(
        Enzyme.Reverse, sumabs2, Active,
        Const(model), Duplicated(x, dx),
        Duplicated(ps, dps), Const(st)
    )
    return dx, dps
end

export ∇sumabs2_zygote, ∇sumabs2_enzyme

end
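A note on the Enzyme call above: Const(model) and Const(st) mark arguments that need no derivative, each Duplicated(x, dx) pairs a differentiable argument with a zero-initialized shadow that autodiff accumulates gradients into, and the leading Active declares the scalar return of sumabs2 as the active output for reverse mode.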

@testitem "Recurrent Layers" tags=[:reactant] setup=[
SharedTestSetup, SharedReactantLayersTestSetup] skip=:(Sys.iswindows()) begin
    using Reactant, Lux
    using LuxTestUtils: check_approx

    rng = StableRNG(123)

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset for cell in (RNNCell, LSTMCell, GRUCell)
model = Recurrence(cell(4 => 4))
ps, st = Lux.setup(rng, model)
ps_ra, st_ra = (ps, st) |> Reactant.to_rarray
x = rand(Float32, 4, 16, 12)
x_ra = x |> Reactant.to_rarray

y_ra, _ = @jit model(x_ra, ps_ra, st_ra)
y, _ = model(x, ps, st)

@test y_ray atol=1e-3 rtol=1e-3

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra∂x atol=1e-3 rtol=1e-3
@test check_approx(∂ps_ra, ∂ps; atol=1e-3, rtol=1e-3)
end
end
end
end
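These tests lean on Reactant's @jit, which compiles a call and immediately executes it, so no separate f_compiled = @compile f(...) binding is needed per function; the loss tests below are rewritten in the same style. A minimal sketch of the two styles (illustrative function and input):

using Reactant

f(x) = sum(abs2, x)
x_ra = Reactant.to_rarray(rand(Float32, 4))

f_compiled = @compile f(x_ra)   # old style: compile explicitly, then call
f_compiled(x_ra)

@jit f(x_ra)                    # new style: compile and run in one step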
88 changes: 31 additions & 57 deletions test/reactant/loss_tests.jl
@@ -24,11 +24,8 @@
fn1(x) = LuxOps.xlogx.(x)
fn2(x, y) = LuxOps.xlogy.(x, y)

-fn1_compiled = @compile fn1(x_ra)
-@test fn1(x) ≈ fn1_compiled(x_ra)
-
-fn2_compiled = @compile fn2(x_ra, y_ra)
-@test fn2(x, y) ≈ fn2_compiled(x_ra, y_ra)
+@test fn1(x) ≈ @jit(fn1(x_ra))
+@test fn2(x, y) ≈ @jit(fn2(x_ra, y_ra))
end

@testset "Regression Loss" begin
@@ -43,14 +40,9 @@
loss_sum = eval(Symbol(loss * "Loss"))(; agg=sum)
loss_sum2 = eval(Symbol(loss * "Loss"))(; agg=(args...) -> sum(args...))

-loss_mean_compiled = @compile loss_mean(ŷ_ra, y_ra)
-@test loss_mean(ŷ, y) ≈ loss_mean_compiled(ŷ_ra, y_ra)
-
-loss_sum_compiled = @compile loss_sum(ŷ_ra, y_ra)
-@test loss_sum(ŷ, y) ≈ loss_sum_compiled(ŷ_ra, y_ra)
-
-loss_sum2_compiled = @compile loss_sum2(ŷ_ra, y_ra)
-@test loss_sum2(ŷ, y) ≈ loss_sum2_compiled(ŷ_ra, y_ra)
+@test loss_mean(ŷ, y) ≈ @jit(loss_mean(ŷ_ra, y_ra))
+@test loss_sum(ŷ, y) ≈ @jit(loss_sum(ŷ_ra, y_ra))
+@test loss_sum2(ŷ, y) ≈ @jit(loss_sum2(ŷ_ra, y_ra))
end

@testset "MSLE" begin
@@ -61,8 +53,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

loss_msle = MSLELoss()
-loss_msle_compiled = @compile loss_msle(ŷ_ra, y_ra)
-@test loss_msle(ŷ, y) ≈ loss_msle_compiled(ŷ_ra, y_ra)
+@test loss_msle(ŷ, y) ≈ @jit(loss_msle(ŷ_ra, y_ra))
end
end

@@ -75,39 +66,35 @@

@testset "CrossEntropyLoss" begin
celoss = CrossEntropyLoss()
-celoss_compiled = @compile celoss(ŷ_ra, y_ra)
-@test celoss(ŷ, y) ≈ celoss_compiled(ŷ_ra, y_ra)
+@test celoss(ŷ, y) ≈ @jit(celoss(ŷ_ra, y_ra))

celoss_ls = CrossEntropyLoss(; label_smoothing=0.1)
-celoss_ls_compiled = @compile celoss_ls(ŷ_ra, y_ra)
-@test celoss_ls(ŷ, y) ≈ celoss_ls_compiled(ŷ_ra, y_ra)
+@test celoss_ls(ŷ, y) ≈ @jit(celoss_ls(ŷ_ra, y_ra))

celoss_lp = CrossEntropyLoss(; logits=Val(true))
-celoss_lp_compiled = @compile celoss_lp(log.(ŷ_ra), y_ra)
-@test celoss_lp(log.(ŷ), y) ≈ celoss_lp_compiled(log.(ŷ_ra), y_ra)
+logit_celoss_lp = (ŷ, y) -> celoss_lp(log.(ŷ), y)
+@test logit_celoss_lp(ŷ, y) ≈ @jit(logit_celoss_lp(ŷ_ra, y_ra))

celoss_lp_ls = CrossEntropyLoss(; logits=Val(true), label_smoothing=0.1)
-celoss_lp_ls_compiled = @compile celoss_lp_ls(log.(ŷ_ra), y_ra)
-@test celoss_lp_ls(log.(ŷ), y) ≈ celoss_lp_ls_compiled(log.(ŷ_ra), y_ra)
+logit_celoss_lp_ls = (ŷ, y) -> celoss_lp_ls(log.(ŷ), y)
+@test logit_celoss_lp_ls(ŷ, y) ≈ @jit(logit_celoss_lp_ls(ŷ_ra, y_ra))
end

@testset "Binary CrossEntropyLoss" begin
bceloss = BinaryCrossEntropyLoss()
-bceloss_compiled = @compile bceloss(ŷ_ra, y_ra)
-@test bceloss(ŷ, y) ≈ bceloss_compiled(ŷ_ra, y_ra)
+@test bceloss(ŷ, y) ≈ @jit(bceloss(ŷ_ra, y_ra))

bceloss_ls = BinaryCrossEntropyLoss(; label_smoothing=0.1)
-bceloss_ls_compiled = @compile bceloss_ls(ŷ_ra, y_ra)
-@test bceloss_ls(ŷ, y) ≈ bceloss_ls_compiled(ŷ_ra, y_ra)
+@test bceloss_ls(ŷ, y) ≈ @jit(bceloss_ls(ŷ_ra, y_ra))

bceloss_lp = BinaryCrossEntropyLoss(; logits=Val(true))
-bceloss_lp_compiled = @compile bceloss_lp(log.(ŷ_ra), y_ra)
-@test bceloss_lp(log.(ŷ), y) ≈ bceloss_lp_compiled(log.(ŷ_ra), y_ra)
+logit_bceloss_lp = (ŷ, y) -> bceloss_lp(log.(ŷ), y)
+@test logit_bceloss_lp(ŷ, y) ≈ @jit(logit_bceloss_lp(ŷ_ra, y_ra))

bceloss_lp_ls = BinaryCrossEntropyLoss(;
logits=Val(true), label_smoothing=0.1)
-bceloss_lp_ls_compiled = @compile bceloss_lp_ls(log.(ŷ_ra), y_ra)
-@test bceloss_lp_ls(log.(ŷ), y) ≈ bceloss_lp_ls_compiled(log.(ŷ_ra), y_ra)
+logit_bceloss_lp_ls = (ŷ, y) -> bceloss_lp_ls(log.(ŷ), y)
+@test logit_bceloss_lp_ls(ŷ, y) ≈ @jit(logit_bceloss_lp_ls(ŷ_ra, y_ra))
end

@testset "BinaryFocalLoss" begin
@@ -120,8 +107,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

bfl = BinaryFocalLoss()
-bfl_compiled = @compile bfl(ŷ_ra, y_ra)
-@test bfl(ŷ, y) ≈ bfl_compiled(ŷ_ra, y_ra)
+@test bfl(ŷ, y) ≈ @jit(bfl(ŷ_ra, y_ra))
end

@testset "FocalLoss" begin
@@ -134,8 +120,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

fl = FocalLoss()
-fl_compiled = @compile fl(ŷ_ra, y_ra)
-@test fl(ŷ, y) ≈ fl_compiled(ŷ_ra, y_ra)
+@test fl(ŷ, y) ≈ @jit(fl(ŷ_ra, y_ra))
end
end

@@ -148,8 +133,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

kldl = KLDivergenceLoss()
-kldl_compiled = @compile kldl(ŷ_ra, y_ra)
-@test kldl(ŷ, y) ≈ kldl_compiled(ŷ_ra, y_ra)
+@test kldl(ŷ, y) ≈ @jit(kldl(ŷ_ra, y_ra))
end

@testset "HingeLoss" begin
@@ -160,12 +144,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

hl = HingeLoss()
-hl_compiled = @compile hl(ŷ_ra, y_ra)
-@test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
+@test hl(ŷ, y) ≈ @jit(hl(ŷ_ra, y_ra))

hl = HingeLoss(; agg=mean)
-hl_compiled = @compile hl(ŷ_ra, y_ra)
-@test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
+@test hl(ŷ, y) ≈ @jit(hl(ŷ_ra, y_ra))
end

@testset "SquaredHingeLoss" begin
Expand All @@ -176,12 +158,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

hl = SquaredHingeLoss()
-hl_compiled = @compile hl(ŷ_ra, y_ra)
-@test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
+@test hl(ŷ, y) ≈ @jit(hl(ŷ_ra, y_ra))

hl = SquaredHingeLoss(; agg=mean)
-hl_compiled = @compile hl(ŷ_ra, y_ra)
-@test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
+@test hl(ŷ, y) ≈ @jit(hl(ŷ_ra, y_ra))
end

@testset "PoissonLoss" begin
Expand All @@ -192,12 +172,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

pl = PoissonLoss()
-pl_compiled = @compile pl(ŷ_ra, y_ra)
-@test pl(ŷ, y) ≈ pl_compiled(ŷ_ra, y_ra)
+@test pl(ŷ, y) ≈ @jit(pl(ŷ_ra, y_ra))

pl = PoissonLoss(; agg=mean)
-pl_compiled = @compile pl(ŷ_ra, y_ra)
-@test pl(ŷ, y) ≈ pl_compiled(ŷ_ra, y_ra)
+@test pl(ŷ, y) ≈ @jit(pl(ŷ_ra, y_ra))
end

@testset "DiceCoeffLoss" begin
Expand All @@ -208,12 +186,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

dl = DiceCoeffLoss()
-dl_compiled = @compile dl(ŷ_ra, y_ra)
-@test dl(ŷ, y) ≈ dl_compiled(ŷ_ra, y_ra)
+@test dl(ŷ, y) ≈ @jit(dl(ŷ_ra, y_ra))

dl = DiceCoeffLoss(; agg=mean)
-dl_compiled = @compile dl(ŷ_ra, y_ra)
-@test dl(ŷ, y) ≈ dl_compiled(ŷ_ra, y_ra)
+@test dl(ŷ, y) ≈ @jit(dl(ŷ_ra, y_ra))
end

@testset "Siamese Contrastive Loss" begin
Expand All @@ -228,12 +204,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

sl = SiameseContrastiveLoss()
-sl_compiled = @compile sl(ŷ_ra, y_ra)
-@test sl(ŷ, y) ≈ sl_compiled(ŷ_ra, y_ra)
+@test sl(ŷ, y) ≈ @jit(sl(ŷ_ra, y_ra))

sl = SiameseContrastiveLoss(; agg=mean)
-sl_compiled = @compile sl(ŷ_ra, y_ra)
-@test sl(ŷ, y) ≈ sl_compiled(ŷ_ra, y_ra)
+@test sl(ŷ, y) ≈ @jit(sl(ŷ_ra, y_ra))
end
end
end
17 changes: 12 additions & 5 deletions test/reactant/training_tests.jl
@@ -24,17 +24,23 @@
ps, st = Lux.setup(StableRNG(1234), model) |> xdev

x_ra = randn(Float32, 2, 32) |> xdev
+y_ra = rand(Float32, 2, 32) |> xdev

-inference_fn = @compile model(x_ra, ps, Lux.testmode(st))
+inference_loss_fn = (xᵢ, yᵢ, model, ps, st) -> begin
+    ŷᵢ, _ = model(xᵢ, ps, Lux.testmode(st))
+    return MSELoss()(ŷᵢ, yᵢ)
+end
+inference_loss_fn_compiled = @compile inference_loss_fn(
+    x_ra, y_ra, model, ps, st
+)

x = [rand(Float32, 2, 32) for _ in 1:32]
y = [xᵢ .^ 2 for xᵢ in x]

dataloader = DeviceIterator(xdev, zip(x, y))

total_initial_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
-    ŷᵢ, _ = inference_fn(xᵢ, ps, Lux.testmode(st))
-    return MSELoss()(ŷᵢ, yᵢ)
+    inference_loss_fn_compiled(xᵢ, yᵢ, model, ps, st)
end

train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
@@ -52,8 +58,9 @@
end

total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
-    ŷᵢ, _ = inference_fn(xᵢ, train_state.parameters, Lux.testmode(st))
-    return MSELoss()(ŷᵢ, yᵢ)
+    inference_loss_fn_compiled(
+        xᵢ, yᵢ, model, train_state.parameters, train_state.states
+    )
end

@test total_final_loss < 100 * total_initial_loss
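The reworked test compiles the whole inference-plus-loss computation once, against example inputs of the right shape, and reuses the compiled function for every batch; it also threads train_state.states through at the end instead of the stale initial st. A minimal sketch of the compile-once pattern (hypothetical loss_fn, illustrative shapes):

using Reactant

loss_fn(x, y) = sum(abs2, x .- y)   # hypothetical stand-in for the MSE computation
x_ra = Reactant.to_rarray(randn(Float32, 2, 32))
y_ra = Reactant.to_rarray(rand(Float32, 2, 32))

loss_compiled = @compile loss_fn(x_ra, y_ra)   # traced once with example inputs
loss_compiled(x_ra, y_ra)                      # reusable for any same-shaped batch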

10 comments on commit bbf5033

@avik-pal (Member Author)

@JuliaRegistrator register subdir=lib/MLDataDevices

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/119457

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text "Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a MLDataDevices-v1.6.0 -m "<description of version>" bbf503374b42432324654d4701d284fa5bac74f3
git push origin MLDataDevices-v1.6.0

@avik-pal (Member Author)

@JuliaRegistrator register subdir=lib/LuxCore

@avik-pal (Member Author)

@JuliaRegistrator register subdir=lib/LuxTestUtils

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/119458


Tagging


git tag -a LuxCore-v1.2.0 -m "<description of version>" bbf503374b42432324654d4701d284fa5bac74f3
git push origin LuxCore-v1.2.0

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/119459


Tagging


git tag -a LuxTestUtils-v1.6.0 -m "<description of version>" bbf503374b42432324654d4701d284fa5bac74f3
git push origin LuxTestUtils-v1.6.0

@avik-pal (Member Author)

@JuliaRegistrator register subdir=lib/LuxLib

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/119461


Tagging


git tag -a LuxLib-v1.3.8 -m "<description of version>" bbf503374b42432324654d4701d284fa5bac74f3
git push origin LuxLib-v1.3.8

@avik-pal (Member Author)

@JuliaRegistrator register

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/119463


Tagging


git tag -a v1.3.0 -m "<description of version>" bbf503374b42432324654d4701d284fa5bac74f3
git push origin v1.3.0
