Add required bounds to derived impl

This is implemented using a new TypeBoundsStore struct which tracks
usage of types during parsing and stores required bounds for generics,
trying to use some heuristics to remove unneeded bounds.

Signed-off-by: Justus Fluegel <justusfluegel@gmail.com>
Signed-off-by: Justus Flügel <justusfluegel@gmail.com>
This commit is contained in:
Justus Fluegel 2025-02-02 21:00:52 +01:00
parent df7bcfa17d
commit 1a3a4c224a
No known key found for this signature in database
GPG Key ID: DD4B1903FEACCC4D
9 changed files with 484 additions and 35 deletions

View File

@ -13,4 +13,4 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0.83"
quote = "1.0.35"
syn = "2.0.87"
syn = { version = "2.0.87", features = ["extra-traits"] }

View File

@ -11,6 +11,7 @@ use crate::label::Labels;
use crate::related::Related;
use crate::severity::Severity;
use crate::source_code::SourceCode;
use crate::trait_bounds::TraitBoundStore;
use crate::url::Url;
pub enum Diagnostic {
@ -19,11 +20,13 @@ pub enum Diagnostic {
ident: syn::Ident,
fields: syn::Fields,
args: DiagnosticDefArgs,
bound_store: TraitBoundStore,
},
Enum {
ident: syn::Ident,
generics: syn::Generics,
variants: Vec<DiagnosticDef>,
bound_store: TraitBoundStore,
},
}
@ -71,12 +74,15 @@ pub struct DiagnosticConcreteArgs {
}
impl DiagnosticConcreteArgs {
fn for_fields(fields: &syn::Fields) -> Result<Self, syn::Error> {
let labels = Labels::from_fields(fields)?;
let source_code = SourceCode::from_fields(fields)?;
let related = Related::from_fields(fields)?;
fn for_fields(
fields: &syn::Fields,
bounds_store: &mut TraitBoundStore,
) -> Result<Self, syn::Error> {
let labels = Labels::from_fields(fields, bounds_store)?;
let source_code = SourceCode::from_fields(fields, bounds_store)?;
let related = Related::from_fields(fields, bounds_store)?;
let help = Help::from_fields(fields)?;
let diagnostic_source = DiagnosticSource::from_fields(fields)?;
let diagnostic_source = DiagnosticSource::from_fields(fields, bounds_store)?;
Ok(DiagnosticConcreteArgs {
code: None,
help,
@ -156,6 +162,7 @@ impl DiagnosticDefArgs {
_ident: &syn::Ident,
fields: &syn::Fields,
attrs: &[&syn::Attribute],
bounds_store: &mut TraitBoundStore,
allow_transparent: bool,
) -> syn::Result<Self> {
let mut errors = Vec::new();
@ -166,7 +173,7 @@ impl DiagnosticDefArgs {
attrs[0].parse_args_with(Punctuated::<DiagnosticArg, Token![,]>::parse_terminated)
{
if matches!(args.first(), Some(DiagnosticArg::Transparent)) {
let forward = Forward::for_transparent_field(fields)?;
let forward = Forward::for_transparent_field(fields, bounds_store)?;
return Ok(Self::Transparent(forward));
}
}
@ -182,7 +189,7 @@ impl DiagnosticDefArgs {
matches!(d, DiagnosticArg::Transparent)
}
let mut concrete = DiagnosticConcreteArgs::for_fields(fields)?;
let mut concrete = DiagnosticConcreteArgs::for_fields(fields, bounds_store)?;
for attr in attrs {
let args =
attr.parse_args_with(Punctuated::<DiagnosticArg, Token![,]>::parse_terminated);
@ -226,10 +233,13 @@ impl Diagnostic {
.collect::<Vec<&syn::Attribute>>();
Ok(match input.data {
syn::Data::Struct(data_struct) => {
let mut bounds_store = TraitBoundStore::new(&input.generics);
let args = DiagnosticDefArgs::parse(
&input.ident,
&data_struct.fields,
&input_attrs,
&mut bounds_store,
true,
)?;
@ -238,16 +248,23 @@ impl Diagnostic {
ident: input.ident,
generics: input.generics,
args,
bound_store: bounds_store,
}
}
syn::Data::Enum(syn::DataEnum { variants, .. }) => {
let mut vars = Vec::new();
let mut bound_store = TraitBoundStore::new(&input.generics);
for var in variants {
let mut variant_attrs = input_attrs.clone();
variant_attrs
.extend(var.attrs.iter().filter(|x| x.path().is_ident("diagnostic")));
let args =
DiagnosticDefArgs::parse(&var.ident, &var.fields, &variant_attrs, true)?;
let args = DiagnosticDefArgs::parse(
&var.ident,
&var.fields,
&variant_attrs,
&mut bound_store,
true,
)?;
vars.push(DiagnosticDef {
ident: var.ident,
fields: var.fields,
@ -258,6 +275,7 @@ impl Diagnostic {
ident: input.ident,
generics: input.generics,
variants: vars,
bound_store,
}
}
syn::Data::Union(_) => {
@ -276,8 +294,11 @@ impl Diagnostic {
fields,
generics,
args,
bound_store,
} => {
let (impl_generics, ty_generics, where_clause) = &generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let where_clause = bound_store.merge_with(where_clause);
match args {
DiagnosticDefArgs::Transparent(forward) => {
let code_method = forward.gen_struct_method(WhichFn::Code);
@ -369,8 +390,11 @@ impl Diagnostic {
ident,
generics,
variants,
bound_store,
} => {
let (impl_generics, ty_generics, where_clause) = &generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let where_clause = bound_store.merge_with(where_clause);
let code_body = Code::gen_enum(variants);
let help_body = Help::gen_enum(variants);
let sev_body = Severity::gen_enum(variants);

View File

@ -3,6 +3,7 @@ use quote::quote;
use syn::spanned::Spanned;
use crate::forward::WhichFn;
use crate::trait_bounds::TraitBoundStore;
use crate::{
diagnostic::{DiagnosticConcreteArgs, DiagnosticDef},
utils::{display_pat_members, gen_all_variants_with},
@ -11,17 +12,25 @@ use crate::{
pub struct DiagnosticSource(syn::Member);
impl DiagnosticSource {
pub(crate) fn from_fields(fields: &syn::Fields) -> syn::Result<Option<Self>> {
pub(crate) fn from_fields(
fields: &syn::Fields,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Option<Self>> {
match fields {
syn::Fields::Named(named) => Self::from_fields_vec(named.named.iter().collect()),
syn::Fields::Named(named) => {
Self::from_fields_vec(named.named.iter().collect(), bounds_store)
}
syn::Fields::Unnamed(unnamed) => {
Self::from_fields_vec(unnamed.unnamed.iter().collect())
Self::from_fields_vec(unnamed.unnamed.iter().collect(), bounds_store)
}
syn::Fields::Unit => Ok(None),
}
}
fn from_fields_vec(fields: Vec<&syn::Field>) -> syn::Result<Option<Self>> {
fn from_fields_vec(
fields: Vec<&syn::Field>,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Option<Self>> {
for (i, field) in fields.iter().enumerate() {
for attr in &field.attrs {
if attr.path().is_ident("diagnostic_source") {
@ -33,6 +42,9 @@ impl DiagnosticSource {
span: field.span(),
})
};
bounds_store.register_source_usage(&field.ty);
return Ok(Some(DiagnosticSource(diagnostic_source)));
}
}

View File

@ -6,6 +6,8 @@ use syn::{
spanned::Spanned,
};
use crate::trait_bounds::TraitBoundStore;
pub enum Forward {
Unnamed(usize),
Named(syn::Ident),
@ -90,7 +92,10 @@ impl WhichFn {
}
impl Forward {
pub fn for_transparent_field(fields: &syn::Fields) -> syn::Result<Self> {
pub fn for_transparent_field(
fields: &syn::Fields,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Self> {
let make_err = || {
syn::Error::new(
fields.span(),
@ -108,12 +113,18 @@ impl Forward {
.ident
.clone()
.unwrap_or_else(|| format_ident!("unnamed"));
bounds_store.register_transparent_usage(&field.ty);
Ok(Self::Named(field_name))
}
syn::Fields::Unnamed(unnamed) => {
if unnamed.unnamed.iter().len() != 1 {
let mut iter = unnamed.unnamed.iter();
let field = iter.next().ok_or_else(make_err)?;
if iter.next().is_some() {
return Err(make_err());
}
bounds_store.register_transparent_usage(&field.ty);
Ok(Self::Unnamed(0))
}
_ => Err(syn::Error::new(

View File

@ -11,6 +11,7 @@ use crate::{
diagnostic::{DiagnosticConcreteArgs, DiagnosticDef},
fmt::{self, Display},
forward::WhichFn,
trait_bounds::TraitBoundStore,
utils::{display_pat_members, gen_all_variants_with},
};
@ -101,22 +102,31 @@ impl Parse for LabelAttr {
} else {
(LabelType::Default, None)
};
Ok(LabelAttr { label, lbl_ty })
}
}
impl Labels {
pub fn from_fields(fields: &syn::Fields) -> syn::Result<Option<Self>> {
pub fn from_fields(
fields: &syn::Fields,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Option<Self>> {
match fields {
syn::Fields::Named(named) => Self::from_fields_vec(named.named.iter().collect()),
syn::Fields::Named(named) => {
Self::from_fields_vec(named.named.iter().collect(), bounds_store)
}
syn::Fields::Unnamed(unnamed) => {
Self::from_fields_vec(unnamed.unnamed.iter().collect())
Self::from_fields_vec(unnamed.unnamed.iter().collect(), bounds_store)
}
syn::Fields::Unit => Ok(None),
}
}
fn from_fields_vec(fields: Vec<&syn::Field>) -> syn::Result<Option<Self>> {
fn from_fields_vec(
fields: Vec<&syn::Field>,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Option<Self>> {
let mut labels = Vec::new();
for (i, field) in fields.iter().enumerate() {
for attr in &field.attrs {
@ -144,6 +154,16 @@ impl Labels {
));
}
match lbl_ty {
LabelType::Default | LabelType::Primary => {
bounds_store.register_source_span_usage(&field.ty);
}
LabelType::Collection => {
bounds_store.register_source_span_collection_usage(&field.ty);
}
}
labels.push(Label {
label,
span,

View File

@ -14,6 +14,7 @@ mod label;
mod related;
mod severity;
mod source_code;
mod trait_bounds;
mod url;
mod utils;

View File

@ -5,23 +5,32 @@ use syn::spanned::Spanned;
use crate::{
diagnostic::{DiagnosticConcreteArgs, DiagnosticDef},
forward::WhichFn,
trait_bounds::TraitBoundStore,
utils::{display_pat_members, gen_all_variants_with},
};
pub struct Related(syn::Member);
impl Related {
pub(crate) fn from_fields(fields: &syn::Fields) -> syn::Result<Option<Self>> {
pub(crate) fn from_fields(
fields: &syn::Fields,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Option<Self>> {
match fields {
syn::Fields::Named(named) => Self::from_fields_vec(named.named.iter().collect()),
syn::Fields::Named(named) => {
Self::from_fields_vec(named.named.iter().collect(), bounds_store)
}
syn::Fields::Unnamed(unnamed) => {
Self::from_fields_vec(unnamed.unnamed.iter().collect())
Self::from_fields_vec(unnamed.unnamed.iter().collect(), bounds_store)
}
syn::Fields::Unit => Ok(None),
}
}
fn from_fields_vec(fields: Vec<&syn::Field>) -> syn::Result<Option<Self>> {
fn from_fields_vec(
fields: Vec<&syn::Field>,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Option<Self>> {
for (i, field) in fields.iter().enumerate() {
for attr in &field.attrs {
if attr.path().is_ident("related") {
@ -33,6 +42,7 @@ impl Related {
span: field.span(),
})
};
bounds_store.register_related_usage(&field.ty);
return Ok(Some(Related(related)));
}
}

View File

@ -1,10 +1,11 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::spanned::Spanned;
use syn::{spanned::Spanned, AngleBracketedGenericArguments, GenericArgument, PathArguments};
use crate::{
diagnostic::{DiagnosticConcreteArgs, DiagnosticDef},
forward::WhichFn,
trait_bounds::TraitBoundStore,
utils::{display_pat_members, gen_all_variants_with},
};
@ -14,17 +15,25 @@ pub struct SourceCode {
}
impl SourceCode {
pub fn from_fields(fields: &syn::Fields) -> syn::Result<Option<Self>> {
pub fn from_fields(
fields: &syn::Fields,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Option<Self>> {
match fields {
syn::Fields::Named(named) => Self::from_fields_vec(named.named.iter().collect()),
syn::Fields::Named(named) => {
Self::from_fields_vec(named.named.iter().collect(), bounds_store)
}
syn::Fields::Unnamed(unnamed) => {
Self::from_fields_vec(unnamed.unnamed.iter().collect())
Self::from_fields_vec(unnamed.unnamed.iter().collect(), bounds_store)
}
syn::Fields::Unit => Ok(None),
}
}
fn from_fields_vec(fields: Vec<&syn::Field>) -> syn::Result<Option<Self>> {
fn from_fields_vec(
fields: Vec<&syn::Field>,
bounds_store: &mut TraitBoundStore,
) -> syn::Result<Option<Self>> {
for (i, field) in fields.iter().enumerate() {
for attr in &field.attrs {
if attr.path().is_ident("source_code") {
@ -35,12 +44,33 @@ impl SourceCode {
{
segments
.last()
.map(|seg| seg.ident == "Option")
.unwrap_or(false)
.filter(|seg| seg.ident == "Option")
.and_then(|seg| match &seg.arguments {
PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args,
..
}) => {
let mut iter = args.iter();
let ty = iter.next();
iter.next().xor(ty)
}
_ => None,
})
.and_then(|arg| match arg {
GenericArgument::Type(ty) => Some(ty),
_ => None,
})
} else {
false
None
};
if let Some(option_ty) = is_option {
bounds_store.register_source_code_usage(option_ty);
} else {
bounds_store.register_source_code_usage(&field.ty);
}
let source_code = if let Some(ident) = field.ident.clone() {
syn::Member::Named(ident)
} else {
@ -51,7 +81,7 @@ impl SourceCode {
};
return Ok(Some(SourceCode {
source_code,
is_option,
is_option: is_option.is_some(),
}));
}
}

View File

@ -0,0 +1,341 @@
use std::{
collections::{HashMap, VecDeque},
iter::{once, FromIterator},
};
use proc_macro2::Span;
use syn::{
punctuated::Punctuated, token::Plus, AngleBracketedGenericArguments, AssocType,
GenericArgument, GenericParam, Generics, ParenthesizedGenericArguments, Path, PathArguments,
PathSegment, PredicateType, ReturnType, Token, Type, TypeArray, TypeGroup, TypeParamBound,
TypeParen, TypePath, TypePtr, TypeReference, TypeSlice, TypeTuple, WhereClause, WherePredicate,
};
#[derive(Default)]
pub struct RequiredTraitBound {
r#static: bool,
std_error: bool,
miette_diagnostic: bool,
source_code: bool,
into_source_span: bool,
std_into_iter: bool,
}
impl RequiredTraitBound {
fn to_bounds(&self) -> Punctuated<TypeParamBound, Plus> {
let mut bounds = Punctuated::new();
if self.std_error && !self.miette_diagnostic {
bounds.push(TypeParamBound::Trait(syn::parse_quote!(
::std::error::Error
)));
}
if self.miette_diagnostic {
bounds.push(TypeParamBound::Trait(syn::parse_quote!(
::miette::Diagnostic
)))
}
if self.source_code {
bounds.push(TypeParamBound::Trait(syn::parse_quote!(
::miette::SourceCode
)))
}
if self.into_source_span {
bounds.push(TypeParamBound::Trait(syn::parse_quote!(
::std::convert::Into<::miette::SourceSpan>
)))
}
if self.std_into_iter {
bounds.push(TypeParamBound::Trait(syn::parse_quote!(
::std::iter::IntoIterator
)))
}
if self.r#static {
bounds.push(TypeParamBound::Lifetime(syn::parse_quote!('static)))
}
bounds
}
fn register_transparent_usage(&mut self) {
self.r#static = true;
self.miette_diagnostic = true;
}
fn register_source_code_usage(&mut self) {
self.source_code = true;
}
fn register_source_span_usage(&mut self) {
self.into_source_span = true;
}
fn register_collection_usage(&mut self) {
self.std_into_iter = true;
}
fn register_related_item_usage(&mut self) {
self.miette_diagnostic = true;
self.r#static = true;
}
fn register_source_usage(&mut self) {
self.miette_diagnostic = true;
self.r#static = true;
}
}
pub struct TraitBoundStore(HashMap<Type, RequiredTraitBound>);
impl TraitBoundStore {
pub fn new(generics: &Generics) -> Self {
let hash_map = generics
.params
.iter()
.filter_map(|param| {
if let GenericParam::Type(ty) = param {
Some(ty)
} else {
None
}
})
.map(|param| {
Type::Path(TypePath {
qself: None,
path: Path {
leading_colon: None,
segments: Punctuated::from_iter(once(PathSegment {
ident: param.ident.clone(),
arguments: PathArguments::None,
})),
},
})
})
.map(|v| (v, RequiredTraitBound::default()))
.collect::<HashMap<_, _>>();
Self(hash_map)
}
fn check_generic_usage<'ty>(&self, mut r#type: &'ty Type) -> Option<&'ty Type> {
// in theory we could skip all this logic and just allow trivial bounds but that would add redundant trait bounds
// to the derived impl - would be another choice to make. I choose to filter as much as possible so that we don't
// introduce unneccessary bounds.
// this reduces the type down as much as possible to remove unneeded groups.
let original_type = loop {
match r#type {
Type::Paren(TypeParen { elem, .. }) => r#type = &**elem,
Type::Group(TypeGroup { elem, .. }) => r#type = &**elem,
x => break x,
}
};
let mut depends_on_generic = false;
// max depth to check, after which we'll just add the (maybe redundant) bound anyways.
// this is a tradeoff between filtering speed and compiler speed so I'll keep it
// reasonably low for now, since I assume the compiler is better optimized for more complex
// checks.
let max_depth = 8;
let mut to_check_queue: VecDeque<(&Type, usize)> = VecDeque::new();
to_check_queue.push_back((original_type, 0));
while !depends_on_generic {
// this needs to be like this cuz if-let-chains aren't supported yet
let Some((elem, current_depth)) = to_check_queue.pop_front() else {
break;
};
// if we exceed the max depth we just assume it depends on the generic and let the compiler check it
if current_depth > max_depth {
depends_on_generic = true;
break;
}
// the map contains types that we know depend on generics so we can just short circuit
//
// this is also the "bottom" check since we add the generics themselves to the map when
// constructing self
if self.0.contains_key(elem) {
depends_on_generic = true;
break;
}
// basically go through the type and add all referenced types inside it to the check queue
match elem {
Type::Group(_) => unreachable!("This is unwrapped above"),
Type::Paren(_) => unreachable!("This is unwrapped above"),
// function pointer's can never implement the required trait bounds anyways so we just accept the errors
Type::BareFn(_) => return None,
// impl trait types aren't allowed from struct/enum definitions anyways so we can just ignore them
Type::ImplTrait(_) => return None,
// infered types aren't allowed either
Type::Infer(_) => return None,
// macros are opaque to us and i don't really know how to properly implement this.
// we could in theory I think introduce a type alias and use that instead but honestly
// type macros are such a niche usecase especially in combination with a generic,
// I would say we should just recommend to implement
// the trait manually, as such we just accept the error if any occurs (this still allows using macros when they
// return concrete types which don't depend on any generic or when the generic doesn't affect the
// required trait implementation)
Type::Macro(_) => return None,
// trait objects which depend on a generic inside them seem like very much a hassle to implement so i'll ignore
// them for now, if the need arises we could support that in a future pr maybe?
//
// this again doesn't restrict the usage of trait objects which implement the required traits regardless of the generics.
Type::TraitObject(_) => return None,
// Well never is never and never never.
Type::Never(_) => return None,
Type::Array(TypeArray { elem, .. })
| Type::Ptr(TypePtr { elem, .. })
| Type::Reference(TypeReference { elem, .. })
| Type::Slice(TypeSlice { elem, .. }) => {
to_check_queue.push_back((&**elem, current_depth + 1));
}
Type::Path(TypePath { qself, path }) => {
if let Some(qself) = qself {
to_check_queue.push_back((&qself.ty, current_depth + 1));
}
for segment in &path.segments {
match &segment.arguments {
PathArguments::None => {}
PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args,
..
}) => {
for argument in args {
match argument {
GenericArgument::Type(ty)
| GenericArgument::AssocType(AssocType { ty, .. }) => {
to_check_queue.push_back((ty, current_depth + 1));
}
_ => {}
}
}
}
PathArguments::Parenthesized(ParenthesizedGenericArguments {
inputs,
output,
..
}) => {
for inp in inputs {
to_check_queue.push_back((inp, current_depth + 1));
}
if let ReturnType::Type(_, ty) = output {
to_check_queue.push_back((ty, current_depth + 1));
}
}
}
}
}
Type::Tuple(TypeTuple { elems, .. }) => {
for elem in elems {
to_check_queue.push_back((elem, current_depth + 1));
}
}
// we can't really handle verbatim so we just assume it depends on the generics
Type::Verbatim(_) => depends_on_generic = true,
_ => depends_on_generic = true,
}
}
depends_on_generic.then_some(original_type)
}
pub fn merge_with(&self, where_clause: Option<&WhereClause>) -> Option<WhereClause> {
if self.0.is_empty() {
return where_clause.cloned();
}
let mut where_clause = where_clause.cloned().unwrap_or_else(|| WhereClause {
where_token: Token![where](Span::mixed_site()),
predicates: Punctuated::new(),
});
where_clause
.predicates
.extend(self.0.iter().map(|(ty, bounds)| {
WherePredicate::Type(PredicateType {
lifetimes: None,
bounded_ty: ty.clone(),
colon_token: Token![:](Span::mixed_site()),
bounds: bounds.to_bounds(),
})
}));
Some(where_clause)
}
pub fn register_transparent_usage(&mut self, r#type: &Type) {
let Some(r#type) = self.check_generic_usage(r#type) else {
return;
};
let type_opts = self.0.entry(r#type.clone()).or_default();
type_opts.register_transparent_usage()
}
pub fn register_source_code_usage(&mut self, r#type: &Type) {
let Some(r#type) = self.check_generic_usage(r#type) else {
return;
};
let type_opts = self.0.entry(r#type.clone()).or_default();
type_opts.register_source_code_usage()
}
pub fn register_source_span_usage(&mut self, r#type: &Type) {
let Some(r#type) = self.check_generic_usage(r#type) else {
return;
};
let type_opts = self.0.entry(r#type.clone()).or_default();
type_opts.register_source_span_usage()
}
pub fn register_source_span_collection_usage(&mut self, r#type: &Type) {
let Some(ty) = self.check_generic_usage(r#type) else {
return;
};
let type_opts = self.0.entry(ty.clone()).or_default();
type_opts.register_collection_usage();
let type_opts_item = self
.0
.entry(syn::parse_quote!(<#ty as ::std::iter::IntoIterator>::Item))
.or_default();
type_opts_item.register_source_span_usage();
}
pub fn register_related_usage(&mut self, r#type: &Type) {
let Some(ty) = self.check_generic_usage(r#type) else {
return;
};
let type_opts = self.0.entry(ty.clone()).or_default();
type_opts.register_collection_usage();
let type_opts_item = self
.0
.entry(syn::parse_quote!(<#ty as ::std::iter::IntoIterator>::Item))
.or_default();
type_opts_item.register_related_item_usage();
}
pub(crate) fn register_source_usage(&mut self, r#type: &Type) {
let Some(ty) = self.check_generic_usage(r#type) else {
return;
};
let type_opts = self.0.entry(ty.clone()).or_default();
type_opts.register_source_usage();
}
}