diff --git a/crates/esthri/src/types.rs b/crates/esthri/src/types.rs index 27a71d38..85857261 100644 --- a/crates/esthri/src/types.rs +++ b/crates/esthri/src/types.rs @@ -10,12 +10,15 @@ * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. */ +use std::fmt::{Display, Formatter}; +use std::str::FromStr; use std::{ path::{Path, PathBuf}, result::Result as StdResult, }; use regex::Regex; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; #[derive(Debug, Clone)] pub enum S3PathParam { @@ -43,15 +46,21 @@ impl S3PathParam { } } -impl std::str::FromStr for S3PathParam { +impl FromStr for S3PathParam { type Err = String; fn from_str(s: &str) -> StdResult { let s3_format = Regex::new(r"^s3://(?P[^/]+)/(?P.*)$").unwrap(); if let Some(captures) = s3_format.captures(s) { - let bucket = captures.name("bucket").unwrap().as_str(); - let key = captures.name("key").unwrap().as_str(); + let bucket = captures + .name("bucket") + .ok_or("s3 bucket not found".to_string())? + .as_str(); + let key = captures + .name("key") + .ok_or("s3 key not found".to_string())? + .as_str(); Ok(S3PathParam::new_bucket(bucket, key)) } else { Ok(S3PathParam::new_local(s)) @@ -59,6 +68,28 @@ impl std::str::FromStr for S3PathParam { } } +impl Display for S3PathParam { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + S3PathParam::Local { path } => write!(f, "{}", path.display()), + S3PathParam::Bucket { bucket, key } => write!(f, "s3://{bucket}/{key}"), + } + } +} + +impl Serialize for S3PathParam { + fn serialize(&self, serializer: S) -> StdResult { + serializer.collect_str(self) + } +} + +impl<'de> Deserialize<'de> for S3PathParam { + fn deserialize>(deserializer: D) -> StdResult { + let s = String::deserialize(deserializer)?; + S3PathParam::from_str(&s).map_err(serde::de::Error::custom) + } +} + #[derive(Debug, Clone)] pub enum S3ListingItem { S3Object(S3Object),