From d7b5640a25c4260b274eaf892b0d6e9d3e886bec Mon Sep 17 00:00:00 2001 From: Mikolaj Wielgus Date: Sat, 16 Sep 2023 23:08:20 +0200 Subject: [PATCH] contracts: Improve abstract type detection (i.e. containing impl) Now it also checks for if there's an impl inside the type, not only on the outside. --- .../contracts/src/implementation/codegen.rs | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/vendor/contracts/src/implementation/codegen.rs b/vendor/contracts/src/implementation/codegen.rs index 7655d16..c3a59f9 100644 --- a/vendor/contracts/src/implementation/codegen.rs +++ b/vendor/contracts/src/implementation/codegen.rs @@ -6,7 +6,7 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::ToTokens; use syn::{ spanned::Spanned, visit_mut as visitor, Attribute, Expr, ExprCall, - ReturnType, Type, + ReturnType, TypeImplTrait, }; use crate::implementation::{ @@ -321,12 +321,18 @@ pub(crate) fn generate( let body = 'blk: { let mut block = func.function.block.clone(); - syn::visit_mut::visit_block_mut(&mut ReturnReplacer {}, &mut block); + syn::visit_mut::visit_block_mut(&mut ReturnReplacer, &mut block); - if let ReturnType::Type(.., ref return_type) = func.function.sig.output - { - if let Type::ImplTrait(..) = **return_type { - } else { + let mut impl_detector = ImplDetector { found_impl: false }; + syn::visit::visit_return_type( + &mut impl_detector, + &func.function.sig.output, + ); + + if !impl_detector.found_impl { + if let ReturnType::Type(.., ref return_type) = + func.function.sig.output + { break 'blk quote::quote! { let ret: #return_type = 'run: #block; }; } } @@ -365,7 +371,7 @@ pub(crate) fn generate( func.function.into_token_stream() } -struct ReturnReplacer {} +struct ReturnReplacer; impl syn::visit_mut::VisitMut for ReturnReplacer { fn visit_expr_mut(&mut self, node: &mut Expr) { @@ -375,3 +381,13 @@ impl syn::visit_mut::VisitMut for ReturnReplacer { } } } + +struct ImplDetector { + found_impl: bool, +} + +impl<'a> syn::visit::Visit<'a> for ImplDetector { + fn visit_type_impl_trait(&mut self, _node: &'a TypeImplTrait) { + self.found_impl = true; + } +}