Skip to content

Commit

Permalink
Use SafeConnection instead of Connection
Browse files Browse the repository at this point in the history
  • Loading branch information
crsib committed Feb 19, 2024
1 parent d6cace7 commit 0892c6c
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 174 deletions.
49 changes: 24 additions & 25 deletions libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,30 +73,31 @@ CloudProjectsDatabase& CloudProjectsDatabase::Get()

bool CloudProjectsDatabase::IsOpen() const
{
return mConnection.IsOpen();
return !!mConnection;
}
sqlite::Connection& CloudProjectsDatabase::GetConnection()
sqlite::SafeConnection::Lock CloudProjectsDatabase::GetConnection()
{
if (!mConnection)
OpenConnection();

return mConnection;
// It is safe to call lock on a null connection
return sqlite::SafeConnection::Lock { mConnection };
}

const sqlite::Connection& CloudProjectsDatabase::GetConnection() const
const sqlite::SafeConnection::Lock CloudProjectsDatabase::GetConnection() const
{
return const_cast<CloudProjectsDatabase*>(this)->GetConnection();
}

std::optional<DBProjectData>
CloudProjectsDatabase::GetProjectData(const std::string_view& projectId) const
{
auto& connection = GetConnection();
auto connection = GetConnection();

if (!connection)
return {};

auto statement = connection.CreateStatement(
auto statement = connection->CreateStatement(
"SELECT project_id, snapshot_id, saves_count, last_audio_preview_save, local_path, last_modified, last_read, sync_status FROM projects WHERE project_id = ? LIMIT 1");

if (!statement)
Expand All @@ -108,12 +109,12 @@ CloudProjectsDatabase::GetProjectData(const std::string_view& projectId) const
std::optional<DBProjectData> CloudProjectsDatabase::GetProjectDataForPath(
const std::string& projectFilePath) const
{
auto& connection = GetConnection();
auto connection = GetConnection();

if (!connection)
return {};

auto statement = connection.CreateStatement(
auto statement = connection->CreateStatement(
"SELECT project_id, snapshot_id, saves_count, last_audio_preview_save, local_path, last_modified, last_read, sync_status FROM projects WHERE local_path = ? LIMIT 1");

if (!statement)
Expand All @@ -125,12 +126,12 @@ std::optional<DBProjectData> CloudProjectsDatabase::GetProjectDataForPath(
bool CloudProjectsDatabase::MarkProjectAsSynced(
const std::string_view& projectId, const std::string_view& snapshotId)
{
auto& connection = GetConnection();
auto connection = GetConnection();

if (!connection)
return false;

auto statement = connection.CreateStatement ("UPDATE projects SET sync_status = ? WHERE project_id = ? AND snapshot_id = ?");
auto statement = connection->CreateStatement ("UPDATE projects SET sync_status = ? WHERE project_id = ? AND snapshot_id = ?");

if (!statement)
return false;
Expand All @@ -146,16 +147,16 @@ bool CloudProjectsDatabase::MarkProjectAsSynced(
void CloudProjectsDatabase::UpdateProjectBlockList(
const std::string_view& projectId, const SampleBlockIDSet& blockSet)
{
auto& connection = GetConnection();
auto connection = GetConnection();

if (!connection)
return;

auto inProjectSet = connection.CreateScalarFunction(
auto inProjectSet = connection->CreateScalarFunction(
"inProjectSet", [&blockSet](int64_t blockIndex)
{ return blockSet.find(blockIndex) != blockSet.end(); });

auto statement = connection.CreateStatement(
auto statement = connection->CreateStatement(
"DELETE FROM block_hashes WHERE project_id = ? AND NOT inProjectSet(block_id)");

auto result = statement->Prepare(projectId).Run();
Expand All @@ -169,12 +170,12 @@ void CloudProjectsDatabase::UpdateProjectBlockList(
std::optional<std::string> CloudProjectsDatabase::GetBlockHash(
const std::string_view& projectId, int64_t blockId) const
{
auto& connection = GetConnection();
auto connection = GetConnection();

if (!connection)
return {};

auto statement = connection.CreateStatement ("SELECT hash FROM block_hashes WHERE project_id = ? AND block_id = ? LIMIT 1");
auto statement = connection->CreateStatement ("SELECT hash FROM block_hashes WHERE project_id = ? AND block_id = ? LIMIT 1");

if (!statement)
return {};
Expand All @@ -198,17 +199,17 @@ void CloudProjectsDatabase::UpdateBlockHashes(
const std::string_view& projectId,
const std::vector<std::pair<int64_t, std::string>>& hashes)
{
auto& connection = GetConnection();
auto connection = GetConnection();

if (!connection)
return;

const int localVar {};
auto transaction = connection.BeginTransaction(
auto transaction = connection->BeginTransaction(
std::string("UpdateBlockHashes_") +
std::to_string(reinterpret_cast<size_t>(&localVar)));

auto statement = connection.CreateStatement (
auto statement = connection->CreateStatement (
"INSERT OR REPLACE INTO block_hashes (project_id, block_id, hash) VALUES (?, ?, ?)");

for (const auto& [blockId, hash] : hashes)
Expand All @@ -220,12 +221,12 @@ void CloudProjectsDatabase::UpdateBlockHashes(
bool CloudProjectsDatabase::UpdateProjectData(
const DBProjectData& projectData)
{
auto& connection = GetConnection();
auto connection = GetConnection();

if (!connection)
return false;

auto statement = connection.CreateStatement (
auto statement = connection->CreateStatement (
"INSERT OR REPLACE INTO projects (project_id, snapshot_id, saves_count, last_audio_preview_save, local_path, last_modified, last_read, sync_status) VALUES (?, ?, ?, ?, ?, ?, ?, ?)");

if (!statement)
Expand Down Expand Up @@ -292,14 +293,12 @@ bool CloudProjectsDatabase::OpenConnection()
const auto configDir = FileNames::ConfigDir();
const auto configPath = configDir + "/audiocom_sync.db";

auto db = sqlite::Connection::Open(audacity::ToUTF8(configPath));
mConnection = sqlite::SafeConnection::Open(audacity::ToUTF8(configPath));

if (!db)
if (!mConnection)
return false;

mConnection = std::move(*db);

auto result = mConnection.Execute(createTableQuery);
auto result = mConnection->Acquire()->Execute(createTableQuery);

if (!result)
{
Expand Down
8 changes: 4 additions & 4 deletions libraries/lib-cloud-audiocom/sync/CloudProjectsDatabase.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <optional>
#include <string>

#include "sqlite/Connection.h"
#include "sqlite/SafeConnection.h"

#include "CloudSyncUtils.h"

Expand Down Expand Up @@ -47,8 +47,8 @@ class CloudProjectsDatabase final

bool IsOpen() const;

sqlite::Connection& GetConnection();
const sqlite::Connection& GetConnection() const;
sqlite::SafeConnection::Lock GetConnection();
const sqlite::SafeConnection::Lock GetConnection() const;

std::optional<DBProjectData> GetProjectData(const std::string_view& projectId) const;
std::optional<DBProjectData> GetProjectDataForPath(const std::string& projectPath) const;
Expand All @@ -67,7 +67,7 @@ class CloudProjectsDatabase final
private:
std::optional<DBProjectData> DoGetProjectData(sqlite::RunResult result) const;
bool OpenConnection();
sqlite::Connection mConnection;
std::shared_ptr<sqlite::SafeConnection> mConnection;

};
} // namespace cloud::audiocom::sync
Loading

0 comments on commit 0892c6c

Please sign in to comment.