Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Octogonapus committed Mar 12, 2024
1 parent b0ed11b commit e0c9099
Show file tree
Hide file tree
Showing 18 changed files with 1,471 additions and 684 deletions.
24 changes: 20 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,41 @@ jobs:
matrix:
version:
- "1.9"
- "1.10"
os:
- ubuntu-latest
- macOS-latest
arch:
- x64
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1

- name: Run regular tests
uses: julia-actions/julia-runtest@v1
env:
ENDPOINT: ${{ secrets.ENDPOINT }}
CERT_STRING: ${{ secrets.CERT_STRING }}
PRI_KEY_STRING: ${{ secrets.PRI_KEY_STRING }}

- name: Run parallel tests
env:
ENDPOINT: ${{ secrets.ENDPOINT }}
CERT_STRING: ${{ secrets.CERT_STRING }}
PRI_KEY_STRING: ${{ secrets.PRI_KEY_STRING }}
AWSCRT_TESTS_PARALLEL: true
run: |
NTHREADS = $(($(nproc) * 2))
NTIMES=100
julia -t $NTHREADS test/run_parallel_commands.jl $NTIMES julia -t 2 --project -e 'using Pkg; Pkg.test()'
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v2
- uses: codecov/codecov-action@v4
with:
file: ./lcov.info
fail_ci_if_error: false
Expand All @@ -46,13 +61,14 @@ jobs:
matrix:
version:
- "1.9"
- "1.10"
os:
- ubuntu-latest
- macOS-latest
arch:
- x64
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@ LibAWSCRT = "df7458b6-5204-493f-a0e7-404b4eb72fac"
AWSCRT_jll = "0.1"
CEnum = "0.4"
CountDownLatches = "2"
ForeignCallbacks = "0.1"
ForeignCallbacks = "0.1.1"
JSON = "0.21"
LibAWSCRT = "0.1"
julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Documenter", "Random", "Test"]
test = ["Aqua", "Dates", "Documenter", "Random", "Test"]
177 changes: 116 additions & 61 deletions src/AWSCRT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,76 +14,44 @@ using LibAWSCRT, ForeignCallbacks, CountDownLatches, CEnum, JSON
import Base: lock, unlock
export lock, unlock

const _AWSCRT_ALLOCATOR = Ref{Union{Ptr{aws_allocator},Nothing}}(nothing)
const _GLOBAL_REFS = Vector{Ref}()
const _LIBPTR = Ref{Ptr{Cvoid}}(Ptr{Cvoid}(0))

function __init__()
_LIBPTR[] = Libc.Libdl.dlopen(LibAWSCRT.libawscrt)

_AWSCRT_ALLOCATOR[] = let level = get(ENV, "AWS_CRT_MEMORY_TRACING", "")
if !isempty(level)
level = parse(Int, strip(level))
level = aws_mem_trace_level(level)
if Symbol(level) == :UnknownMember
error(
"Invalid value for env var AWS_CRT_MEMORY_TRACING. " *
"See aws_mem_trace_level docs for valid values.",
)
end
frames_per_stack = parse(Int, strip(get(ENV, "AWS_CRT_MEMORY_TRACING_FRAMES_PER_STACK", "0")))
aws_mem_tracer_new(aws_default_allocator(), C_NULL, level, frames_per_stack)
else
aws_default_allocator()
end
end

let log_level = get(ENV, "AWS_CRT_LOG_LEVEL", "")
if !isempty(log_level)
log_level = parse(Int, strip(log_level))
log_level = aws_log_level(log_level)
if Symbol(log_level) == :UnknownMember
error("Invalid value for env var AWS_CRT_LOG_LEVEL. See aws_log_level docs for valid values.")
end

log_path = get(ENV, "AWS_CRT_LOG_PATH", "")
if isempty(log_path)
error("Env var AWS_CRT_LOG_PATH must be set to the path at which to save the log file.")
end
log_path = Ref(deepcopy(log_path))
push!(_GLOBAL_REFS, log_path)

logger = Ref(aws_logger(C_NULL, C_NULL, C_NULL))
push!(_GLOBAL_REFS, logger)

logger_options =
Ref(aws_logger_standard_options(log_level, Base.unsafe_convert(Ptr{Cchar}, log_path[]), C_NULL))
push!(_GLOBAL_REFS, logger_options)
const _C_IDS_LOCK = ReentrantLock()
const _C_IDS = IdDict{Any,Any}()

aws_logger_init_standard(logger, _AWSCRT_ALLOCATOR[], logger_options)
aws_logger_set(logger)
end
end
const _C_ON_ANY_MESSAGE_IDS_LOCK = ReentrantLock()
const _C_ON_ANY_MESSAGE_IDS = IdDict{Any,Any}()

aws_mqtt_library_init(_AWSCRT_ALLOCATOR[]) # also does io and http
# set during __init__
const _LIBPTR = Ref{Ptr{Cvoid}}(Ptr{Cvoid}(0))
const _AWSCRT_ALLOCATOR = Ref{Union{Ptr{aws_allocator},Nothing}}(nothing)

# TODO try cleanup using this approach https://github.com/JuliaLang/julia/pull/20124/files
end
# cfunctions set during __init__
const _C_ON_CONNECTION_INTERRUPTED = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_CONNECTION_RESUMED = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_CONNECTION_COMPLETE = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_DISCONNECT_COMPLETE = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_SUBSCRIBE_MESSAGE = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_ANY_MESSAGE = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_SUBSCRIBE_COMPLETE = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_UNSUBSCRIBE_COMPLETE = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_RESUBSCRIBE_COMPLETE = Ref{Ptr{Cvoid}}(C_NULL)
const _C_ON_PUBLISH_COMPLETE = Ref{Ptr{Cvoid}}(C_NULL)

function _release(; include_mem_tracer = isempty(get(ENV, "AWS_CRT_MEMORY_TRACING", "")))
aws_thread_set_managed_join_timeout_ns(5e8) # 0.5 seconds

i = findfirst(x -> x isa Ref{aws_logger}, _GLOBAL_REFS)
if i !== nothing
aws_logger_clean_up(_GLOBAL_REFS[i])
end
lock(_C_IDS_LOCK) do
logger = findfirst(x -> x isa Ref{aws_logger}, keys(_C_IDS))
if logger !== nothing
aws_logger_clean_up(logger)
end

aws_mqtt_library_clean_up() # also does io and http
empty!(_GLOBAL_REFS)
if include_mem_tracer
aws_mem_tracer_destroy(_AWSCRT_ALLOCATOR[])
aws_mqtt_library_clean_up() # also does io and http
empty!(_C_IDS)
if include_mem_tracer
aws_mem_tracer_destroy(_AWSCRT_ALLOCATOR[])
end
return nothing
end
return nothing
end

aws_err_string(code = aws_last_error()) = "AWS Error $code: " * Base.unsafe_string(aws_error_debug_str(code))
Expand Down Expand Up @@ -133,5 +101,92 @@ export ShadowDocumentPreUpdateCallback
export ShadowDocumentPostUpdateCallback
export shadow_client
export publish_current_state
export wait_until_synced

function __init__()
_LIBPTR[] = Libc.Libdl.dlopen(LibAWSCRT.libawscrt)

_C_ON_CONNECTION_INTERRUPTED[] =
@cfunction(_c_on_connection_interrupted, Cvoid, (Ptr{aws_mqtt_client_connection}, Cint, Ptr{Cvoid}))
_C_ON_CONNECTION_RESUMED[] =
@cfunction(_c_on_connection_resumed, Cvoid, (Ptr{aws_mqtt_client_connection}, Cint, Cint, Ptr{Cvoid}))
_C_ON_CONNECTION_COMPLETE[] =
@cfunction(_c_on_connection_complete, Cvoid, (Ptr{aws_mqtt_client_connection}, Cint, Cint, Cuchar, Ptr{Cvoid}))
_C_ON_DISCONNECT_COMPLETE[] =
@cfunction(_c_on_disconnect_complete, Cvoid, (Ptr{aws_mqtt_client_connection}, Ptr{Cvoid}))
_C_ON_SUBSCRIBE_MESSAGE[] = @cfunction(
_c_on_subscribe_message,
Cvoid,
(Ptr{aws_mqtt_client_connection}, Ptr{aws_byte_cursor}, Ptr{aws_byte_cursor}, Cuchar, Cint, Cuchar, Ptr{Cvoid})
)
_C_ON_ANY_MESSAGE[] = @cfunction(
_c_on_any_message,
Cvoid,
(Ptr{aws_mqtt_client_connection}, Ptr{aws_byte_cursor}, Ptr{aws_byte_cursor}, Cuchar, Cint, Cuchar, Ptr{Cvoid})
)
_C_ON_SUBSCRIBE_COMPLETE[] = @cfunction(
_c_on_subscribe_complete,
Cvoid,
(Ptr{aws_mqtt_client_connection}, Cuint, Ptr{aws_byte_cursor}, Cint, Cint, Ptr{Cvoid})
)
_C_ON_UNSUBSCRIBE_COMPLETE[] =
@cfunction(_c_on_unsubscribe_complete, Cvoid, (Ptr{aws_mqtt_client_connection}, Cuint, Cint, Ptr{Cvoid}))
_C_ON_RESUBSCRIBE_COMPLETE[] = @cfunction(
_c_on_resubscribe_complete,
Cvoid,
(Ptr{aws_mqtt_client_connection}, Cuint, Ptr{aws_array_list}, Cint, Ptr{Cvoid})
)
_C_ON_PUBLISH_COMPLETE[] =
@cfunction(_c_on_publish_complete, Cvoid, (Ptr{aws_mqtt_client_connection}, Cuint, Cint, Ptr{Cvoid}))

_AWSCRT_ALLOCATOR[] = let level = get(ENV, "AWS_CRT_MEMORY_TRACING", "")
if !isempty(level)
level = parse(Int, strip(level))
level = aws_mem_trace_level(level)
if Symbol(level) == :UnknownMember
error(
"Invalid value for env var AWS_CRT_MEMORY_TRACING. " *
"See aws_mem_trace_level docs for valid values.",
)
end
frames_per_stack = parse(Int, strip(get(ENV, "AWS_CRT_MEMORY_TRACING_FRAMES_PER_STACK", "0")))
aws_mem_tracer_new(aws_default_allocator(), C_NULL, level, frames_per_stack)
else
aws_default_allocator()
end
end

let log_level = get(ENV, "AWS_CRT_LOG_LEVEL", "")
if !isempty(log_level)
log_level = parse(Int, strip(log_level))
log_level = aws_log_level(log_level)
if Symbol(log_level) == :UnknownMember
error("Invalid value for env var AWS_CRT_LOG_LEVEL. See aws_log_level docs for valid values.")
end

log_path = get(ENV, "AWS_CRT_LOG_PATH", "")
if isempty(log_path)
error("Env var AWS_CRT_LOG_PATH must be set to the path at which to save the log file.")
end
log_path = Ref(deepcopy(log_path))
logger = Ref(aws_logger(C_NULL, C_NULL, C_NULL))
logger_options =
Ref(aws_logger_standard_options(log_level, Base.unsafe_convert(Ptr{Cchar}, log_path[]), C_NULL))

lock(_C_IDS_LOCK) do
_C_IDS[log_path] = nothing
_C_IDS[logger] = nothing
_C_IDS[logger_options] = nothing
end

aws_logger_init_standard(logger, _AWSCRT_ALLOCATOR[], logger_options)
aws_logger_set(logger)
end
end

aws_mqtt_library_init(_AWSCRT_ALLOCATOR[]) # also does io and http

# TODO try cleanup using this approach https://github.com/JuliaLang/julia/pull/20124/files
end

end
36 changes: 23 additions & 13 deletions src/AWSIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,30 +304,42 @@ Arguments:
"""
function ClientTLSContext(options::TLSContextOptions)
tls_ctx_opt = Ref(aws_tls_ctx_options(ntuple(_ -> UInt8(0), 200)))
# tls_ctx_opt = Ref(
# aws_tls_ctx_options(
# C_NULL,
# options.min_tls_version,
# aws_tls_cipher_pref(0),
# aws_byte_buf(0, C_NULL, 0, C_NULL),
# C_NULL,
# C_NULL,
# aws_byte_buf(0, C_NULL, 0, C_NULL),
# aws_byte_buf(0, C_NULL, 0, C_NULL),
# 0,
# options.verify_peer,
# C_NULL,
# C_NULL,
# ),
# )
GC.@preserve tls_ctx_opt begin
tls_ctx_opt_ptr = Base.unsafe_convert(Ptr{aws_tls_ctx_options}, tls_ctx_opt)

# TODO pkcs11
# TODO pkcs12
# TODO windows cert store
if options.cert_data !== nothing
# mTLS with certificate and private key
cert = Ref(aws_byte_cursor_from_c_str(options.cert_data))
key = Ref(aws_byte_cursor_from_c_str(options.pk_data))
if aws_tls_ctx_options_init_client_mtls(tls_ctx_opt_ptr, _AWSCRT_ALLOCATOR[], cert, key) != AWS_OP_SUCCESS
if aws_tls_ctx_options_init_client_mtls(tls_ctx_opt, _AWSCRT_ALLOCATOR[], cert, key) != AWS_OP_SUCCESS
error("Failed to create client TLS context. $(aws_err_string())")
end
else
# no mTLS
aws_tls_ctx_options_init_default_client(tls_ctx_opt_ptr, _AWSCRT_ALLOCATOR[])
aws_tls_ctx_options_init_default_client(tls_ctx_opt, _AWSCRT_ALLOCATOR[])
end

tls_ctx_opt_ptr.minimum_tls_version = options.min_tls_version

try
if options.ca_dirpath !== nothing || options.ca_filepath !== nothing
if aws_tls_ctx_options_override_default_trust_store_from_path(
tls_ctx_opt_ptr,
tls_ctx_opt,
options.ca_dirpath === nothing ? C_NULL : options.ca_dirpath,
options.ca_filepath === nothing ? C_NULL : options.ca_filepath,
) != AWS_OP_SUCCESS
Expand All @@ -337,21 +349,19 @@ function ClientTLSContext(options::TLSContextOptions)

if options.ca_data !== nothing
ca = Ref(aws_byte_cursor_from_c_str(options.ca_data))
if aws_tls_ctx_options_override_default_trust_store(tls_ctx_opt_ptr, ca) != AWS_OP_SUCCESS
if aws_tls_ctx_options_override_default_trust_store(tls_ctx_opt, ca) != AWS_OP_SUCCESS
error("Failed to override trust store. $(aws_err_string())")
end
end

if options.alpn_list !== nothing
alpn_list_string = join(options.alpn_list, ';')
if aws_tls_ctx_options_set_alpn_list(tls_ctx_opt_ptr, alpn_list_string) != AWS_OP_SUCCESS
if aws_tls_ctx_options_set_alpn_list(tls_ctx_opt, alpn_list_string) != AWS_OP_SUCCESS
error("Failed to set ALPN list. $(aws_err_string())")
end
end

tls_ctx_opt_ptr.verify_peer = options.verify_peer

tls_ctx = aws_tls_client_ctx_new(_AWSCRT_ALLOCATOR[], tls_ctx_opt_ptr)
tls_ctx = aws_tls_client_ctx_new(_AWSCRT_ALLOCATOR[], tls_ctx_opt)
if tls_ctx == C_NULL
error("Failed to create TLS context. $(aws_err_string())")
end
Expand All @@ -361,7 +371,7 @@ function ClientTLSContext(options::TLSContextOptions)
aws_tls_ctx_release(x.ptr)
end
catch
aws_tls_ctx_options_clean_up(tls_ctx_opt_ptr)
aws_tls_ctx_options_clean_up(tls_ctx_opt)
rethrow()
end
end
Expand Down
Loading

0 comments on commit e0c9099

Please sign in to comment.