Skip to content

Commit

Permalink
fix thread unsafety (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
adienes authored Apr 7, 2024
1 parent b384824 commit b7dd39f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[compat]
Distances = "0.8.1, 0.9, 0.10"
julia = "1.3"
StatsAPI = "1"
julia = "1.3"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
1 change: 1 addition & 0 deletions src/StringDistances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module StringDistances

using Distances: Distances, SemiMetric, Metric, evaluate, result_type
using StatsAPI: StatsAPI, pairwise, pairwise!

# Distances API
abstract type StringSemiMetric <: SemiMetric end
abstract type StringMetric <: Metric end
Expand Down
65 changes: 44 additions & 21 deletions src/find.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ julia> compare("martha", "marhta", Levenshtein())
"""
function compare(s1, s2, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
    # The similarity score is one minus the value of `Normalized(dist)`.
    # A minimum acceptable score of `min_score` is forwarded to the distance
    # as the cut-off `max_dist = 1 - min_score`.
    return 1 - Normalized(dist)(s1, s2; max_dist = 1 - min_score)
end

"""
findnearest(s, itr, dist::Union{StringMetric, StringSemiMetric}) -> (x, index)
Expand All @@ -35,22 +35,34 @@ julia> findnearest(s, iter, Levenshtein(); min_score = 0.9)
```
"""
function findnearest(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
    # Materialize `itr` once so it has a length and can be indexed by position,
    # regardless of what kind of iterable the caller passed.
    candidates = collect(itr)
    isempty(candidates) && return (nothing, nothing)

    query = _preprocess(dist, s)

    # Roughly two chunks per thread; never less than one element per chunk.
    chunk_size = max(1, length(candidates) ÷ (2 * Threads.nthreads()))

    # Each task scores its own chunk and returns a fresh vector, so no mutable
    # state is shared between tasks and nothing depends on `Threads.threadid()`.
    chunk_score_tasks = map(Iterators.partition(candidates, chunk_size)) do chunk
        Threads.@spawn map(chunk) do x
            compare(query, _preprocess(dist, x), dist; min_score = min_score)
        end
    end

    # `fetch` on a task is not inferrable: assert that every chunk has the
    # element type `compare` yields for the query compared against itself.
    score_type = typeof(compare(query, query, dist; min_score = min_score))
    chunk_scores = fetch.(chunk_score_tasks)::Vector{Vector{score_type}}

    # Chunks are fetched in partition order, so positions in `scores` line up
    # with positions in `candidates`.
    scores = reduce(vcat, chunk_scores)

    # All-zero scores mean no candidate matched.
    iszero(scores) && return (nothing, nothing)
    imax = argmax(scores)
    return (candidates[imax], imax)
end

# Prepare an input for comparison, dispatching on the kind of distance.
# For q-gram distances a `missing` input is passed through as `missing`.
_preprocess(::AbstractQGramDistance, ::Missing) = missing

# Any other input to a q-gram distance is wrapped in a `QGramSortedVector`
# built with the distance's `q` field.
function _preprocess(dist::AbstractQGramDistance, s)
    return QGramSortedVector(s, dist.q)
end

# Every other string distance receives its input unchanged.
_preprocess(::Union{StringSemiMetric, StringMetric}, s) = s
Expand Down Expand Up @@ -83,14 +95,25 @@ julia> findall(s, iter, Levenshtein(); min_score = 0.9)
```
"""
function Base.findall(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.8)
    # Materialize `itr` once; all later steps work on this collected copy so
    # the source is never traversed a second time.
    candidates = collect(itr)
    isempty(candidates) && return Int[]

    query = _preprocess(dist, s)

    # Roughly two chunks per thread; never less than one element per chunk.
    chunk_size = max(1, length(candidates) ÷ (2 * Threads.nthreads()))

    # Each task scores its own chunk and returns a fresh vector, so no mutable
    # state is shared between tasks and nothing depends on `Threads.threadid()`.
    chunk_score_tasks = map(Iterators.partition(candidates, chunk_size)) do chunk
        Threads.@spawn map(chunk) do x
            compare(query, _preprocess(dist, x), dist; min_score = min_score)
        end
    end

    # `fetch` on a task is not inferrable: assert that every chunk has the
    # element type `compare` yields for the query compared against itself.
    score_type = typeof(compare(query, query, dist; min_score = min_score))
    chunk_scores = fetch.(chunk_score_tasks)::Vector{Vector{score_type}}

    # Chunks are fetched in partition order, so positions in `scores` line up
    # with positions in `candidates`.
    scores = reduce(vcat, chunk_scores)
    return findall(>=(min_score), scores)
end

0 comments on commit b7dd39f

Please sign in to comment.