diff --git a/src/crc32.rs b/src/crc32.rs index ebace898d..d2f974180 100644 --- a/src/crc32.rs +++ b/src/crc32.rs @@ -12,7 +12,7 @@ pub struct Crc32Reader { check: u32, /// Signals if `inner` stores aes encrypted data. /// AE-2 encrypted data doesn't use crc and sets the value to 0. - ae2_encrypted: bool, + enabled: bool, } impl Crc32Reader { @@ -23,7 +23,7 @@ impl Crc32Reader { inner, hasher: Hasher::new(), check: checksum, - ae2_encrypted, + enabled: !ae2_encrypted, } } @@ -36,20 +36,51 @@ impl Crc32Reader { } } +#[cold] +fn invalid_checksum() -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, "Invalid checksum") +} + impl Read for Crc32Reader { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let invalid_check = !buf.is_empty() && !self.check_matches() && !self.ae2_encrypted; + let count = self.inner.read(buf)?; - let count = match self.inner.read(buf) { - Ok(0) if invalid_check => { - return Err(io::Error::new(io::ErrorKind::Other, "Invalid checksum")) + if self.enabled { + if count == 0 && !buf.is_empty() && !self.check_matches() { + return Err(invalid_checksum()); } - Ok(n) => n, - Err(e) => return Err(e), - }; - self.hasher.update(&buf[0..count]); + self.hasher.update(&buf[..count]); + } Ok(count) } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + let start = buf.len(); + let n = self.inner.read_to_end(buf)?; + + if self.enabled { + self.hasher.update(&buf[start..]); + if !self.check_matches() { + return Err(invalid_checksum()); + } + } + + Ok(n) + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + let start = buf.len(); + let n = self.inner.read_to_string(buf)?; + + if self.enabled { + self.hasher.update(&buf.as_bytes()[start..]); + if !self.check_matches() { + return Err(invalid_checksum()); + } + } + + Ok(n) + } } #[cfg(test)] diff --git a/src/read.rs b/src/read.rs index b702b4f21..e4add3bf4 100644 --- a/src/read.rs +++ b/src/read.rs @@ -91,6 +91,24 @@ impl<'a> Read for CryptoReader<'a> { CryptoReader::Aes { reader: r, .. } => r.read(buf), } } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + match self { + CryptoReader::Plaintext(r) => r.read_to_end(buf), + CryptoReader::ZipCrypto(r) => r.read_to_end(buf), + #[cfg(feature = "aes-crypto")] + CryptoReader::Aes { reader: r, .. } => r.read_to_end(buf), + } + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + match self { + CryptoReader::Plaintext(r) => r.read_to_string(buf), + CryptoReader::ZipCrypto(r) => r.read_to_string(buf), + #[cfg(feature = "aes-crypto")] + CryptoReader::Aes { reader: r, .. } => r.read_to_string(buf), + } + } } impl<'a> CryptoReader<'a> { @@ -153,6 +171,60 @@ impl<'a> Read for ZipFileReader<'a> { ZipFileReader::Zstd(r) => r.read(buf), } } + + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + match self { + ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::Raw(r) => r.read_exact(buf), + ZipFileReader::Stored(r) => r.read_exact(buf), + #[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" + ))] + ZipFileReader::Deflated(r) => r.read_exact(buf), + #[cfg(feature = "bzip2")] + ZipFileReader::Bzip2(r) => r.read_exact(buf), + #[cfg(feature = "zstd")] + ZipFileReader::Zstd(r) => r.read_exact(buf), + } + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + match self { + ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::Raw(r) => r.read_to_end(buf), + ZipFileReader::Stored(r) => r.read_to_end(buf), + #[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" + ))] + ZipFileReader::Deflated(r) => r.read_to_end(buf), + #[cfg(feature = "bzip2")] + ZipFileReader::Bzip2(r) => r.read_to_end(buf), + #[cfg(feature = "zstd")] + ZipFileReader::Zstd(r) => r.read_to_end(buf), + } + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + match self { + ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::Raw(r) => r.read_to_string(buf), + ZipFileReader::Stored(r) => r.read_to_string(buf), + #[cfg(any( + feature = "deflate", + feature = "deflate-miniz", + feature = "deflate-zlib" + ))] + ZipFileReader::Deflated(r) => r.read_to_string(buf), + #[cfg(feature = "bzip2")] + ZipFileReader::Bzip2(r) => r.read_to_string(buf), + #[cfg(feature = "zstd")] + ZipFileReader::Zstd(r) => r.read_to_string(buf), + } + } } impl<'a> ZipFileReader<'a> { @@ -979,6 +1051,18 @@ impl<'a> Read for ZipFile<'a> { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.get_reader().read(buf) } + + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.get_reader().read_exact(buf) + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + self.get_reader().read_to_end(buf) + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + self.get_reader().read_to_string(buf) + } } impl<'a> Drop for ZipFile<'a> {