Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow windows-result to work on non-Windows platforms #3082

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 85 additions & 49 deletions crates/libs/result/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ffi::c_void;
#[derive(Clone, PartialEq, Eq)]
pub struct Error {
code: HRESULT,
#[cfg(windows)]
info: Option<ComPtr>,
}

Expand All @@ -13,35 +14,54 @@ impl Error {
pub const fn empty() -> Self {
Self {
code: HRESULT(0),
#[cfg(windows)]
info: None,
}
}

/// Creates a new error object, capturing the stack and other information about the
/// point of failure.
pub fn new<T: AsRef<str>>(code: HRESULT, message: T) -> Self {
sivadeilra marked this conversation as resolved.
Show resolved Hide resolved
let message: Vec<_> = message.as_ref().encode_utf16().collect();

if message.is_empty() {
Self::from_hresult(code)
} else {
unsafe {
RoOriginateErrorW(code.0, message.len() as u32, message.as_ptr());
#[cfg(windows)]
{
let message: Vec<_> = message.as_ref().encode_utf16().collect();
if message.is_empty() {
Self::from_hresult(code)
} else {
unsafe {
RoOriginateErrorW(code.0, message.len() as u32, message.as_ptr());
}
code.into()
}
code.into()
}
#[cfg(not(windows))]
{
let _ = message;
Self::from_hresult(code)
}
}

/// Creates a new error object with an error code, but without additional error information.
pub fn from_hresult(code: HRESULT) -> Self {
Self { code, info: None }
Self {
code,
#[cfg(windows)]
info: None,
}
}

/// Creates a new `Error` from the Win32 error code returned by `GetLastError()`.
pub fn from_win32() -> Self {
kennykerr marked this conversation as resolved.
Show resolved Hide resolved
Self {
code: HRESULT::from_win32(unsafe { GetLastError() }),
info: None,
#[cfg(windows)]
{
Self {
code: HRESULT::from_win32(unsafe { GetLastError() }),
info: None,
}
}
#[cfg(not(windows))]
{
unimplemented!()
}
}

Expand All @@ -52,49 +72,53 @@ impl Error {

/// The error message describing the error.
pub fn message(&self) -> String {
if let Some(info) = &self.info {
let mut message = BasicString::default();

// First attempt to retrieve the restricted error information.
if let Some(info) = info.cast(&IID_IRestrictedErrorInfo) {
let mut fallback = BasicString::default();
let mut code = 0;

unsafe {
com_call!(
IRestrictedErrorInfo_Vtbl,
info.GetErrorDetails(
&mut fallback as *mut _ as _,
&mut code,
&mut message as *mut _ as _,
&mut BasicString::default() as *mut _ as _
)
);
#[cfg(windows)]
{
if let Some(info) = &self.info {
let mut message = BasicString::default();

// First attempt to retrieve the restricted error information.
if let Some(info) = info.cast(&IID_IRestrictedErrorInfo) {
let mut fallback = BasicString::default();
let mut code = 0;

unsafe {
com_call!(
IRestrictedErrorInfo_Vtbl,
info.GetErrorDetails(
&mut fallback as *mut _ as _,
&mut code,
&mut message as *mut _ as _,
&mut BasicString::default() as *mut _ as _
)
);
}

if message.is_empty() {
message = fallback
};
}

// Next attempt to retrieve the regular error information.
if message.is_empty() {
message = fallback
};
}

// Next attempt to retrieve the regular error information.
if message.is_empty() {
unsafe {
com_call!(
IErrorInfo_Vtbl,
info.GetDescription(&mut message as *mut _ as _)
);
unsafe {
com_call!(
IErrorInfo_Vtbl,
info.GetDescription(&mut message as *mut _ as _)
);
}
}
}

return String::from_utf16_lossy(wide_trim_end(message.as_wide()));
return String::from_utf16_lossy(wide_trim_end(message.as_wide()));
}
}

// Otherwise fallback to a generic error code description.
self.code.message()
}

/// The error object describing the error.
#[cfg(windows)]
pub fn as_ptr(&self) -> *mut c_void {
self.info
.as_ref()
Expand All @@ -109,9 +133,12 @@ unsafe impl Sync for Error {}

impl From<Error> for HRESULT {
fn from(error: Error) -> Self {
if let Some(info) = error.info {
unsafe {
SetErrorInfo(0, info.as_raw());
#[cfg(windows)]
{
if let Some(info) = error.info {
unsafe {
SetErrorInfo(0, info.as_raw());
}
}
}
error.code
Expand All @@ -120,9 +147,15 @@ impl From<Error> for HRESULT {

impl From<HRESULT> for Error {
fn from(code: HRESULT) -> Self {
let mut info = None;
unsafe { GetErrorInfo(0, &mut info as *mut _ as _) };
Self { code, info }
Self {
code,
#[cfg(windows)]
info: {
let mut info = None;
unsafe { GetErrorInfo(0, &mut info as *mut _ as _) };
info
},
}
}
}

Expand All @@ -147,6 +180,7 @@ impl From<alloc::string::FromUtf16Error> for Error {
fn from(_: alloc::string::FromUtf16Error) -> Self {
Self {
code: HRESULT::from_win32(ERROR_NO_UNICODE_TRANSLATION),
#[cfg(windows)]
info: None,
}
}
Expand All @@ -156,6 +190,7 @@ impl From<alloc::string::FromUtf8Error> for Error {
fn from(_: alloc::string::FromUtf8Error) -> Self {
Self {
code: HRESULT::from_win32(ERROR_NO_UNICODE_TRANSLATION),
#[cfg(windows)]
info: None,
}
}
Expand All @@ -165,6 +200,7 @@ impl From<core::num::TryFromIntError> for Error {
fn from(_: core::num::TryFromIntError) -> Self {
Self {
code: HRESULT::from_win32(ERROR_INVALID_DATA),
#[cfg(windows)]
info: None,
}
}
Expand Down
77 changes: 44 additions & 33 deletions crates/libs/result/src/hresult.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,41 +64,52 @@ impl HRESULT {

/// The error message describing the error.
pub fn message(&self) -> String {
let mut message = HeapString::default();
let mut code = self.0;
let mut module = 0;

let mut flags = FORMAT_MESSAGE_ALLOCATE_BUFFER
| FORMAT_MESSAGE_FROM_SYSTEM
| FORMAT_MESSAGE_IGNORE_INSERTS;

unsafe {
if self.0 & 0x1000_0000 == 0x1000_0000 {
code ^= 0x1000_0000;
flags |= FORMAT_MESSAGE_FROM_HMODULE;

module =
LoadLibraryExA(b"ntdll.dll\0".as_ptr(), 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
#[cfg(windows)]
{
let mut message = HeapString::default();
let mut code = self.0;
let mut module = 0;

let mut flags = FORMAT_MESSAGE_ALLOCATE_BUFFER
| FORMAT_MESSAGE_FROM_SYSTEM
| FORMAT_MESSAGE_IGNORE_INSERTS;

unsafe {
if self.0 & 0x1000_0000 == 0x1000_0000 {
code ^= 0x1000_0000;
flags |= FORMAT_MESSAGE_FROM_HMODULE;

module = LoadLibraryExA(
b"ntdll.dll\0".as_ptr(),
0,
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS,
);
}

let size = FormatMessageW(
flags,
module as _,
code as _,
0,
&mut message.0 as *mut _ as *mut _,
0,
core::ptr::null(),
);

if !message.0.is_null() && size > 0 {
String::from_utf16_lossy(wide_trim_end(core::slice::from_raw_parts(
message.0,
size as usize,
)))
} else {
String::default()
}
}
}

let size = FormatMessageW(
flags,
module as _,
code as _,
0,
&mut message.0 as *mut _ as *mut _,
0,
core::ptr::null(),
);

if !message.0.is_null() && size > 0 {
String::from_utf16_lossy(wide_trim_end(core::slice::from_raw_parts(
message.0,
size as usize,
)))
} else {
String::default()
}
#[cfg(not(windows))]
{
return format!("0x{:08x}", self.0 as u32);
}
}

Expand Down
5 changes: 5 additions & 0 deletions crates/libs/result/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Learn more about Rust for Windows here: <https://github.com/microsoft/windows-rs
debugger_visualizer(natvis_file = "../.natvis")
)]
#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
#![cfg_attr(not(windows), allow(unused_imports))]

extern crate alloc;

Expand All @@ -16,10 +17,14 @@ use alloc::vec::Vec;
mod bindings;
use bindings::*;

#[cfg(windows)]
mod com;
#[cfg(windows)]
use com::*;

#[cfg(windows)]
mod strings;
#[cfg(windows)]
use strings::*;

mod error;
Expand Down
34 changes: 34 additions & 0 deletions crates/tests/linux/tests/hresult.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// This tests code paths in `windows-result` that are different on non-Windows platforms.
#![cfg(not(windows))]

use windows::core::Error;
use windows::Win32::Foundation::{E_FAIL, S_OK};

#[test]
fn basic_hresult() {
assert!(E_FAIL.is_err());
assert!(S_OK.is_ok());

let ok_message = S_OK.message();
assert_eq!(ok_message, "0x00000000");
}

#[test]
fn error_message_is_not_supported() {
let e = Error::new(S_OK, "this gets ignored");
let message = e.message();
assert_eq!(message, "0x00000000");
}

#[test]
#[should_panic]
fn from_win32_panics() {
// from_win32() is not implemented on non-Windows platforms.
let _e = Error::from_win32();
}

#[test]
fn error_from_hresult() {
let e = Error::from(E_FAIL);
assert_eq!(e.code(), E_FAIL);
}