Skip to content

Commit c178210

Browse files
authored
Make cast method produced by implement macro unsafe (#1753)
* Remove broken cast method on types using implement macro * Simplify From impls from implement macro * Free implementation in alloc if querying fails * Add back cast but make it unsafe * Remove alloc
1 parent caecac5 commit c178210

File tree

5 files changed

+27
-27
lines changed

5 files changed

+27
-27
lines changed

crates/libs/implement/src/lib.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
5454
impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for #interface_ident {
5555
fn from(this: #original_ident::<#(#generics,)*>) -> Self {
5656
let this = #impl_ident::<#(#generics,)*>::new(this);
57-
let mut this = ::std::boxed::Box::new(this);
58-
let vtable_ptr = &mut this.vtables.#offset as *mut *const <#interface_ident as ::windows::core::Interface>::Vtable;
59-
let _ = ::std::boxed::Box::leak(this);
60-
unsafe { ::core::mem::transmute_copy(&vtable_ptr) }
57+
let mut this = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this));
58+
let vtable_ptr = &this.vtables.#offset;
59+
// SAFETY: interfaces are in-memory equivalent to pointers to their vtables.
60+
unsafe { ::core::mem::transmute(vtable_ptr) }
6161
}
6262
}
6363
impl <#constraints> ::windows::core::AsImpl<#original_ident::<#(#generics,)*>> for #interface_ident {
@@ -145,12 +145,16 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
145145
}
146146
}
147147
impl <#constraints> #original_ident::<#(#generics,)*> {
148-
fn cast<ResultType: ::windows::core::Interface>(&self) -> ::windows::core::Result<ResultType> {
149-
unsafe {
150-
let boxed = (self as *const #original_ident::<#(#generics,)*> as *mut #original_ident::<#(#generics,)*> as *mut ::windows::core::RawPtr).sub(2 + #interfaces_len) as *mut #impl_ident::<#(#generics,)*>;
151-
let mut result = None;
152-
<#impl_ident::<#(#generics,)*> as ::windows::core::IUnknownImpl>::QueryInterface(&*boxed, &ResultType::IID, &mut result as *mut _ as _).and_some(result)
153-
}
148+
/// Try casting as the provided interface
149+
///
150+
/// # Safety
151+
///
152+
/// This function can only be safely called if `self` has been heap allocated and pinned using
153+
/// the mechanisms provided by `implement` macro.
154+
unsafe fn cast<I: ::windows::core::Interface>(&self) -> ::windows::core::Result<I> {
155+
let boxed = (self as *const _ as *const ::windows::core::RawPtr).sub(2 + #interfaces_len) as *mut #impl_ident::<#(#generics,)*>;
156+
let mut result = None;
157+
<#impl_ident::<#(#generics,)*> as ::windows::core::IUnknownImpl>::QueryInterface(&*boxed, &I::IID, &mut result as *mut _ as _).and_some(result)
154158
}
155159
}
156160
impl <#constraints> ::windows::core::Compose for #original_ident::<#(#generics,)*> {
@@ -163,23 +167,19 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
163167
}
164168
impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for ::windows::core::IUnknown {
165169
fn from(this: #original_ident::<#(#generics,)*>) -> Self {
170+
let this = #impl_ident::<#(#generics,)*>::new(this);
171+
let boxed = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this));
166172
unsafe {
167-
let this = #impl_ident::<#(#generics,)*>::new(this);
168-
let ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(this));
169-
::core::mem::transmute_copy(&::core::ptr::NonNull::new_unchecked(
170-
&mut (*ptr).identity as *mut _ as _
171-
))
173+
::core::mem::transmute(&boxed.identity)
172174
}
173175
}
174176
}
175177
impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for ::windows::core::IInspectable {
176178
fn from(this: #original_ident::<#(#generics,)*>) -> Self {
179+
let this = #impl_ident::<#(#generics,)*>::new(this);
180+
let boxed = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this));
177181
unsafe {
178-
let this = #impl_ident::<#(#generics,)*>::new(this);
179-
let ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(this));
180-
::core::mem::transmute_copy(&::core::ptr::NonNull::new_unchecked(
181-
&mut (*ptr).identity as *mut _ as _
182-
))
182+
::core::mem::transmute(&boxed.identity)
183183
}
184184
}
185185
}

crates/tests/nightly_implement/tests/cast_self.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ use windows::UI::Xaml::*;
55
// TODO: This is a compile-only test for now until #81 is further along and can provide composable test classes.
66

77
#[implement(IApplicationOverrides)]
8-
struct App();
8+
struct App;
99

1010
#[allow(non_snake_case)]
1111
impl IApplicationOverrides_Impl for App {
1212
fn OnLaunched(&self, _: &Option<LaunchActivatedEventArgs>) -> Result<()> {
13-
let app: Application = self.cast()?;
13+
let app: Application = unsafe { self.cast()? };
1414
assert!(app.FocusVisualKind()? == FocusVisualKind::DottedLine);
1515
Ok(())
1616
}

crates/tests/nightly_implement/tests/com.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use windows::Win32::System::WinRT::Composition::*;
77
use windows::Win32::System::WinRT::Display::*;
88

99
#[implement(windows::Foundation::IStringable, windows::Win32::System::WinRT::Composition::ISwapChainInterop, windows::Win32::System::WinRT::Display::IDisplayPathInterop)]
10-
struct Mix();
10+
struct Mix;
1111

1212
impl IStringable_Impl for Mix {
1313
fn ToString(&self) -> Result<HSTRING> {
@@ -32,13 +32,13 @@ impl IDisplayPathInterop_Impl for Mix {
3232

3333
#[test]
3434
fn mix() -> Result<()> {
35-
let a: ISwapChainInterop = Mix().into();
35+
let a: ISwapChainInterop = Mix.into();
3636
unsafe { a.SetSwapChain(None)? };
3737

3838
let b: IStringable = a.cast()?;
3939
assert!(b.ToString()? == "Mix");
4040

41-
let c: IStringable = Mix().into();
41+
let c: IStringable = Mix.into();
4242
assert!(c.ToString()? == "Mix");
4343

4444
let d: ISwapChainInterop = c.cast()?;

crates/tests/nightly_implement/tests/into_impl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ where
5454
#[allow(non_snake_case)]
5555
impl<T: RuntimeType + 'static> IIterable_Impl<T> for Iterable<T> {
5656
fn First(&self) -> Result<IIterator<T>> {
57-
Ok(Iterator::<T>((self.cast()?, 0).into()).into())
57+
Ok(Iterator::<T>((unsafe { self.cast()? }, 0).into()).into())
5858
}
5959
}
6060

crates/tests/nightly_vector/tests/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ impl<T: ::windows::core::RuntimeType + 'static> IVector_Impl<T> for Vector<T> {
6161
self.Size()
6262
}
6363
fn GetView(&self) -> Result<windows::Foundation::Collections::IVectorView<T>> {
64-
self.cast()
64+
unsafe { self.cast() }
6565
}
6666
fn IndexOf(&self, value: &T::DefaultType, result: &mut u32) -> Result<bool> {
6767
self.IndexOf(value, result)

0 commit comments

Comments
 (0)