diff --git a/miette-derive/src/diagnostic.rs b/miette-derive/src/diagnostic.rs index ef1bdae..1239fce 100644 --- a/miette-derive/src/diagnostic.rs +++ b/miette-derive/src/diagnostic.rs @@ -4,6 +4,7 @@ use syn::{punctuated::Punctuated, DeriveInput, Token}; use crate::code::Code; use crate::diagnostic_arg::DiagnosticArg; +use crate::diagnostic_source::DiagnosticSource; use crate::forward::{Forward, WhichFn}; use crate::help::Help; use crate::label::Labels; @@ -66,6 +67,7 @@ pub struct DiagnosticConcreteArgs { pub url: Option, pub forward: Option, pub related: Option, + pub diagnostic_source: Option, } impl DiagnosticConcreteArgs { @@ -74,6 +76,7 @@ impl DiagnosticConcreteArgs { let source_code = SourceCode::from_fields(fields)?; let related = Related::from_fields(fields)?; let help = Help::from_fields(fields)?; + let diagnostic_source = DiagnosticSource::from_fields(fields)?; Ok(DiagnosticConcreteArgs { code: None, help, @@ -83,6 +86,7 @@ impl DiagnosticConcreteArgs { url: None, forward: None, source_code, + diagnostic_source, }) } @@ -283,6 +287,8 @@ impl Diagnostic { let source_code_method = forward.gen_struct_method(WhichFn::SourceCode); let severity_method = forward.gen_struct_method(WhichFn::Severity); let related_method = forward.gen_struct_method(WhichFn::Related); + let diagnostic_source_method = + forward.gen_struct_method(WhichFn::DiagnosticSource); quote! { impl #impl_generics miette::Diagnostic for #ident #ty_generics #where_clause { @@ -293,6 +299,7 @@ impl Diagnostic { #severity_method #source_code_method #related_method + #diagnostic_source_method } } } @@ -338,6 +345,11 @@ impl Diagnostic { .as_ref() .and_then(|x| x.gen_struct(fields)) .or_else(|| forward(WhichFn::SourceCode)); + let diagnostic_source = concrete + .diagnostic_source + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::DiagnosticSource)); quote! { impl #impl_generics miette::Diagnostic for #ident #ty_generics #where_clause { #code_body @@ -347,6 +359,7 @@ impl Diagnostic { #url_body #labels_body #src_body + #diagnostic_source } } } @@ -365,6 +378,7 @@ impl Diagnostic { let src_body = SourceCode::gen_enum(variants); let rel_body = Related::gen_enum(variants); let url_body = Url::gen_enum(ident, variants); + let diagnostic_source_body = DiagnosticSource::gen_enum(variants); quote! { impl #impl_generics miette::Diagnostic for #ident #ty_generics #where_clause { #code_body @@ -374,6 +388,7 @@ impl Diagnostic { #src_body #rel_body #url_body + #diagnostic_source_body } } } diff --git a/miette-derive/src/diagnostic_source.rs b/miette-derive/src/diagnostic_source.rs new file mode 100644 index 0000000..949defe --- /dev/null +++ b/miette-derive/src/diagnostic_source.rs @@ -0,0 +1,78 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::spanned::Spanned; + +use crate::forward::WhichFn; +use crate::{ + diagnostic::{DiagnosticConcreteArgs, DiagnosticDef}, + utils::{display_pat_members, gen_all_variants_with}, +}; + +pub struct DiagnosticSource(syn::Member); + +impl DiagnosticSource { + pub(crate) fn from_fields(fields: &syn::Fields) -> syn::Result> { + match fields { + syn::Fields::Named(named) => Self::from_fields_vec(named.named.iter().collect()), + syn::Fields::Unnamed(unnamed) => { + Self::from_fields_vec(unnamed.unnamed.iter().collect()) + } + syn::Fields::Unit => Ok(None), + } + } + + fn from_fields_vec(fields: Vec<&syn::Field>) -> syn::Result> { + for (i, field) in fields.iter().enumerate() { + for attr in &field.attrs { + if attr.path.is_ident("diagnostic_source") { + let diagnostic_source = if let Some(ident) = field.ident.clone() { + syn::Member::Named(ident) + } else { + syn::Member::Unnamed(syn::Index { + index: i as u32, + span: field.span(), + }) + }; + return Ok(Some(DiagnosticSource(diagnostic_source))); + } + } + } + Ok(None) + } + + pub(crate) fn gen_enum(variants: &[DiagnosticDef]) -> Option { + gen_all_variants_with( + variants, + WhichFn::DiagnosticSource, + |ident, + fields, + DiagnosticConcreteArgs { + diagnostic_source, .. + }| { + let (display_pat, _display_members) = display_pat_members(fields); + diagnostic_source.as_ref().map(|diagnostic_source| { + let rel = match &diagnostic_source.0 { + syn::Member::Named(ident) => ident.clone(), + syn::Member::Unnamed(syn::Index { index, .. }) => { + quote::format_ident!("_{}", index) + } + }; + quote! { + Self::#ident #display_pat => { + std::option::Option::Some(#rel.as_ref()) + } + } + }) + }, + ) + } + + pub(crate) fn gen_struct(&self) -> Option { + let rel = &self.0; + Some(quote! { + fn diagnostic_source<'a>(&'a self) -> std::option::Option<&'a dyn miette::Diagnostic> { + std::option::Option::Some(&self.#rel) + } + }) + } +} diff --git a/miette-derive/src/forward.rs b/miette-derive/src/forward.rs index ca7e1b3..c8757b2 100644 --- a/miette-derive/src/forward.rs +++ b/miette-derive/src/forward.rs @@ -38,6 +38,7 @@ pub enum WhichFn { Labels, SourceCode, Related, + DiagnosticSource, } impl WhichFn { @@ -50,6 +51,7 @@ impl WhichFn { Self::Labels => quote! { labels() }, Self::SourceCode => quote! { source_code() }, Self::Related => quote! { related() }, + Self::DiagnosticSource => quote! { diagnostic_source() }, } } @@ -76,6 +78,9 @@ impl WhichFn { Self::SourceCode => quote! { fn source_code(&self) -> std::option::Option<&dyn miette::SourceCode> }, + Self::DiagnosticSource => quote! { + fn diagnostic_source(&self) -> std::option::Option<&dyn miette::Diagnostic> + }, } } diff --git a/miette-derive/src/lib.rs b/miette-derive/src/lib.rs index da8f8bb..0f7e64e 100644 --- a/miette-derive/src/lib.rs +++ b/miette-derive/src/lib.rs @@ -6,6 +6,7 @@ use diagnostic::Diagnostic; mod code; mod diagnostic; mod diagnostic_arg; +mod diagnostic_source; mod fmt; mod forward; mod help; @@ -16,7 +17,10 @@ mod source_code; mod url; mod utils; -#[proc_macro_derive(Diagnostic, attributes(diagnostic, source_code, label, related, help))] +#[proc_macro_derive( + Diagnostic, + attributes(diagnostic, source_code, label, related, help, diagnostic_source) +)] pub fn derive_diagnostic(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let cmd = match Diagnostic::from_derive_input(input) { diff --git a/tests/test_diagnostic_source_macro.rs b/tests/test_diagnostic_source_macro.rs new file mode 100644 index 0000000..d7c5d2d --- /dev/null +++ b/tests/test_diagnostic_source_macro.rs @@ -0,0 +1,20 @@ +use miette::Diagnostic; + +#[derive(Debug, miette::Diagnostic, thiserror::Error)] +#[error("AnErr")] +struct AnErr; + +#[derive(Debug, miette::Diagnostic, thiserror::Error)] +#[error("TestError")] +struct TestError { + #[diagnostic_source] + asdf_inner_foo: AnErr, +} + +#[test] +fn test_diagnostic_source() { + let error = TestError { + asdf_inner_foo: AnErr, + }; + assert!(error.diagnostic_source().is_some()); +}