// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, Parser};
use syn::{DeriveInput, GenericArgument, PathArguments, Type, parse_macro_input, parse_quote};
fn extract_type_from_box(ty: &Type) -> Option<&Type> {512x21
let Type::Path(type_path) = ty else {512x22
return None;23
};24
if type_path.path.segments.first()?.ident != "Box" {512x25
return None;328x26
}184x27
let arguments = &type_path.path.segments.first()?.arguments;184x28
let PathArguments::AngleBracketed(angle_bracketed) = arguments else {184x29
return None;30
};31
let generic_arg = angle_bracketed.args.first()?;184x32
let GenericArgument::Type(ty) = generic_arg else {184x33
return None;34
};35
if matches!(ty, Type::TraitObject(_)) {184x36
None24x37
} else {38
Some(ty)160x39
}40
}512x41
pub fn trace_error(_attr: TokenStream, item: TokenStream) -> TokenStream {91x43
let mut input = parse_macro_input!(item as DeriveInput);91x44
let syn::Data::Enum(enum_data) = &mut input.data else {91x45
panic!("not an enum")46
};47
for variant in enum_data.variants.iter_mut() {922x48
if matches!(variant.fields, syn::Fields::Unit) {922x49
variant.fields =136x50
syn::Fields::Named(syn::FieldsNamed::parse.parse2(quote! {{}}).unwrap());136x51
}786x52
let syn::Fields::Named(field) = &mut variant.fields else {922x53
panic!("not a named field ")54
};55
field.named.push(922x56
syn::Field::parse_named57
.parse2(quote! {#[snafu(implicit)] _location: ::snafu::Location})922x58
.unwrap(),922x59
);60
if let Some(source) = field.named.iter_mut().find(|f| {1418x61
let name = f.ident.as_ref().unwrap();1418x62
name == "source" || name == "error"1418x63
}) {1418x64
if let Some(inner_type) = extract_type_from_box(&source.ty) {512x65
source160x66
.attrs160x67
.push(parse_quote! {#[snafu(source(from(#inner_type, Box::new)))]})160x68
} else {69
source.attrs.push(parse_quote! {#[snafu(source)]})352x70
}71
}410x72
}73
74
quote! { #input }.into()91x75
}91x76
pub fn derive_debug_trace(input: TokenStream) -> TokenStream {91x78
let mut input = parse_macro_input!(input as DeriveInput);91x79
let name = &input.ident;91x80
let syn::Data::Enum(enum_data) = &mut input.data else {91x81
panic!("not an enum")82
};83
let mut debug_trace_arms = vec![];91x84
for variant in enum_data.variants.iter_mut() {889x85
let syn::Fields::Named(field) = &mut variant.fields else {889x86
panic!("not a named field ")87
};88
let mut cfg_attrs = vec![];889x89
for attr in &variant.attrs {927x90
if attr.path().is_ident("cfg") {927x91
cfg_attrs.push(attr);38x92
}889x93
}94
let is_source = |f: &syn::Field| f.ident.as_ref().unwrap() == "source";1658x95
let has_source = field.named.iter().any(is_source);889x96
let is_error = |f: &syn::Field| f.ident.as_ref().unwrap() == "error";1571x97
let has_error = field.named.iter().any(is_error);889x98
99
let variant_name = &variant.ident;889x100
let debug_trace_arm = if has_source {889x101
quote! {204x102
#(#cfg_attrs)*103
#name::#variant_name {_location, source, ..} => {104
let level = source.debug_trace(f)?;105
writeln!(f, "{level}: {self}, at {_location}")?;106
Ok(level + 1)107
}108
}109
} else if has_error {685x110
quote! {291x111
#(#cfg_attrs)*112
#name::#variant_name {_location, error, ..} => {113
writeln!(f, "0: {error}")?;114
writeln!(f, "1: {self}, at {_location}")?;115
Ok(2)116
}117
}118
} else {119
quote! {394x120
#(#cfg_attrs)*121
#name::#variant_name {_location, .. } => {122
writeln!(f, "0: {self}, at {_location}")?;123
Ok(1)124
}125
}126
};127
debug_trace_arms.push(debug_trace_arm);889x128
}129
130
quote! {91x131
impl DebugTrace for #name {132
#[inline(never)]133
fn debug_trace(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::result::Result<u32, ::std::fmt::Error> {134
match self {135
#(#debug_trace_arms)*136
}137
}138
}139
140
impl ::std::fmt::Debug for #name {141
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {142
writeln!(f, "{self}")?;143
DebugTrace::debug_trace(self, f)?;144
Ok(())145
}146
}147
}148
.into()91x149
}91x150