Skip to content

Commit

Permalink
Merge pull request #26 from gcmurphy/fix/switch-async-lib
Browse files Browse the repository at this point in the history
fix update client to use reqwest and tokio
  • Loading branch information
gcmurphy authored Mar 2, 2024
2 parents a306eb4 + 0557365 commit b6b75db
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 37 deletions.
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2021"
authors = ["Grant Murphy <[email protected]>"]
repository = "https://github.com/gcmurphy/osv"
documentation = "https://docs.rs/osv"
description = "Rust client library for the osv API"
description = "Rust library for parsing the OSV schema and client API"
readme = "README.md"
license = "Apache-2.0"
keywords = ["vulnerabilities", "security", "osv"]
Expand All @@ -14,23 +14,23 @@ keywords = ["vulnerabilities", "security", "osv"]
all-features = true

[dependencies]
async-std = { version = "1.12.0", features = ["attributes"], optional = true }
chrono = { version = "0.4", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
surf = { version = "2.3.2", optional = true }
reqwest = { version = "0.11", features = ["json"], optional = true }
tokio = { version = "1", features = ["full"], optional = true }
thiserror = { version = "1.0", optional = true }
url = { version = "2.3.1", optional = true }

[dev-dependencies]
comfy-table = "5.0.1"
textwrap = { version = "0.15.0", features = ["default", "terminal_size"] }
tokio-test = "0.4.3"

[features]
default = ["schema"]
schema = []
client = ["dep:async-std", "dep:surf", "dep:url", "dep:thiserror", "schema"]
search = []
client = ["dep:tokio", "dep:reqwest", "dep:url", "dep:thiserror", "schema"]

[[example]]
name = "commit"
Expand Down
2 changes: 1 addition & 1 deletion examples/commit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#[async_std::main]
#[tokio::main]
async fn main() {
let commit = "6879efc2c1596d11a6a6ad296f80063b558d5e0f";
let res = osv::client::query_commit(commit).await.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion examples/package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use osv::client;
use osv::schema::Ecosystem::PyPI;
use textwrap::termwidth;

#[async_std::main]
#[tokio::main]
async fn main() -> Result<(), client::ApiError> {
if let Some(vulns) = client::query_package("jinja2", "2.4.1", PyPI).await? {
let default = String::from("-");
Expand Down
2 changes: 1 addition & 1 deletion examples/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::process;

use osv::client;

#[async_std::main]
#[tokio::main]
async fn main() -> Result<(), client::ApiError> {
let args: Vec<String> = env::args().skip(1).collect();
if args.len() <= 0 {
Expand Down
2 changes: 1 addition & 1 deletion examples/vulnerability.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use osv::client;

#[async_std::main]
#[tokio::main]
async fn main() -> Result<(), client::ApiError> {
let vuln = client::vulnerability("GHSA-jfh8-c2jp-5v3q").await?;
println!("{:#?}", vuln);
Expand Down
59 changes: 31 additions & 28 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//! use osv::schema::Ecosystem::PyPI;
//! use textwrap::termwidth;
//!
//! #[async_std::main]
//! #[tokio::main]
//! async fn main() -> Result<(), osv::client::ApiError> {
//!
//! if let Some(vulns) = osv::client::query_package("jinja2", "2.4.1", PyPI).await? {
Expand All @@ -45,7 +45,7 @@
use super::schema::*;
use serde::{Deserialize, Serialize};
use surf::http::StatusCode;
use reqwest::StatusCode;
use thiserror::Error;
use url::Url;

Expand Down Expand Up @@ -86,14 +86,14 @@ pub enum ApiError {
SerializationError(#[from] serde_json::Error),

#[error("request to osv endpoint failed: {0:?}")]
RequestFailed(surf::Error),
RequestFailed(reqwest::Error),

#[error("unexpected error has occurred")]
Unexpected,
}

impl From<surf::Error> for ApiError {
fn from(err: surf::Error) -> Self {
impl From<reqwest::Error> for ApiError {
fn from(err: reqwest::Error) -> Self {
ApiError::RequestFailed(err)
}
}
Expand All @@ -113,8 +113,8 @@ impl From<surf::Error> for ApiError {
/// # Examples
///
/// ```
/// # use async_std::task;
/// # task::block_on(async {
/// # use tokio_test;
/// # tokio_test::block_on(async {
/// let ver = osv::schema::Version::from("2.4.1");
/// let pkg = "jinja2".to_string();
/// let req = osv::client::Request::PackageQuery {
Expand All @@ -133,12 +133,14 @@ impl From<surf::Error> for ApiError {
///
///
pub async fn query(q: &Request) -> Result<Option<Vec<Vulnerability>>, ApiError> {
let mut res = surf::post("https://api.osv.dev/v1/query")
.body_json(q)?
let client = reqwest::Client::new();
let res = client.post("https://api.osv.dev/v1/query")
.json(q)
.send()
.await?;

match res.status() {
StatusCode::NotFound => {
StatusCode::NOT_FOUND => {
let err = match q {
Request::PackageQuery {
version: _,
Expand All @@ -153,7 +155,7 @@ pub async fn query(q: &Request) -> Result<Option<Vec<Vulnerability>>, ApiError>
Err(ApiError::NotFound(err))
}
_ => {
let vulns: Response = res.body_json().await?;
let vulns: Response = res.json().await?;
match vulns {
Response::Vulnerabilities { vulns: vs } => Ok(Some(vs)),
_ => Ok(None),
Expand All @@ -175,8 +177,8 @@ pub async fn query(q: &Request) -> Result<Option<Vec<Vulnerability>>, ApiError>
/// ```
/// use osv::client::query_package;
/// use osv::schema::Ecosystem::PyPI;
/// # use async_std::task;
/// # task::block_on(async {
/// # use tokio_test;
/// # tokio_test::block_on(async {
/// let pkg = "jinja2";
/// let ver = "2.4.1";
/// if let Some(vulns) = query_package(pkg, ver, PyPI).await.unwrap() {
Expand Down Expand Up @@ -218,9 +220,9 @@ pub async fn query_package(
/// # Examples
///
/// ```
/// # use async_std::task;
/// # use osv::client::query_commit;
/// # task::block_on(async {
/// # use tokio_test;
/// # tokio_test::block_on(async {
/// let vulnerable = query_commit("6879efc2c1596d11a6a6ad296f80063b558d5e0f")
/// .await
/// .expect("api error");
Expand All @@ -245,9 +247,10 @@ pub async fn query_commit(commit: &str) -> Result<Option<Vec<Vulnerability>>, Ap
/// # Examples
///
/// ```
/// # use async_std::task;
/// # use tokio::task;
/// use osv::client::vulnerability;
/// # task::block_on(async {
/// # use tokio_test;
/// # tokio_test::block_on(async {
/// let vuln = vulnerability("OSV-2020-484").await.unwrap();
/// assert!(vuln.id.eq("OSV-2020-484"));
///
Expand All @@ -256,11 +259,11 @@ pub async fn query_commit(commit: &str) -> Result<Option<Vec<Vulnerability>>, Ap
pub async fn vulnerability(vuln_id: &str) -> Result<Vulnerability, ApiError> {
let base = Url::parse("https://api.osv.dev/v1/vulns/")?;
let req = base.join(vuln_id)?;
let mut res = surf::get(req.as_str()).await?;
if res.status() == StatusCode::NotFound {
let res = reqwest::get(req.as_str()).await?;
if res.status() == StatusCode::NOT_FOUND {
Err(ApiError::NotFound(vuln_id.to_string()))
} else {
let vuln: Vulnerability = res.body_json().await?;
let vuln: Vulnerability = res.json().await?;
Ok(vuln)
}
}
Expand All @@ -269,7 +272,7 @@ pub async fn vulnerability(vuln_id: &str) -> Result<Vulnerability, ApiError> {
mod tests {
use super::*;

#[async_std::test]
#[tokio::test]
async fn test_package_query() {
let req = Request::PackageQuery {
version: Version::from("2.4.1"),
Expand All @@ -283,15 +286,15 @@ mod tests {
assert!(res.is_some());
}

#[async_std::test]
#[tokio::test]
async fn test_package_query_wrapper() {
let res = query_package("jinja2", "2.4.1", Ecosystem::PyPI)
.await
.unwrap();
assert!(res.is_some());
}

#[async_std::test]
#[tokio::test]
async fn test_invalid_packagename() {
let res = query_package(
"asdfasdlfkjlksdjfklsdjfklsdjfklds",
Expand All @@ -303,7 +306,7 @@ mod tests {
assert!(res.is_none());
}

#[async_std::test]
#[tokio::test]
async fn test_commit_query() {
let req = Request::CommitQuery {
commit: Commit::from("6879efc2c1596d11a6a6ad296f80063b558d5e0f"),
Expand All @@ -312,27 +315,27 @@ mod tests {
assert!(res.is_some());
}

#[async_std::test]
#[tokio::test]
async fn test_commit_query_wrapper() {
let res = query_commit("6879efc2c1596d11a6a6ad296f80063b558d5e0f")
.await
.unwrap();
assert!(res.is_some());
}

#[async_std::test]
#[tokio::test]
async fn test_invalid_commit() {
let res = query_commit("zzzz").await.unwrap();
assert!(res.is_none());
}

#[async_std::test]
#[tokio::test]
async fn test_vulnerability() {
let res = vulnerability("OSV-2020-484").await;
assert!(res.is_ok());
}

#[async_std::test]
#[tokio::test]
async fn test_get_missing_cve() {
let res = vulnerability("CVE-2014-0160").await;
assert!(res.is_err());
Expand Down

0 comments on commit b6b75db

Please sign in to comment.