diff --git a/src/lib.rs b/src/lib.rs index 70a67c9..ba21f70 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,13 +5,22 @@ use syn::{parse_macro_input, Lit, ItemEnum, DeriveInput, Fields, Data}; #[proc_macro_derive(HttpRequest, attributes(http_get))] pub fn derive_http_get_request(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - let name = &input.ident; + let query_name = &input.ident; + let query_name_str = query_name.to_string(); - // Extract the base URL from #[http_get(url = "...")] + // Derive response enum name by replacing "Query" suffix with "Response" + let response_name_str = if query_name_str.ends_with("Query") { + query_name_str.trim_end_matches("Query").to_string() + "Response" + } else { + panic!("HttpRequest derive macro expects the struct name to end with 'Query'"); + }; + let response_name = format_ident!("{}", response_name_str); + + // Extract base URL from #[http_get(url = "...")] let mut base_url = None; for attr in &input.attrs { - if attr.path().is_ident("http_get") { - let _ = attr.parse_nested_meta(|meta| { + if attr.path.is_ident("http_get") { + attr.parse_nested_meta(|meta| { if meta.path.is_ident("url") { let value: Lit = meta.value()?.parse()?; if let Lit::Str(litstr) = value { @@ -19,210 +28,75 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream { } } Ok(()) - }); + }).unwrap(); } } let base_url = base_url.expect("Missing #[http_get(url = \"...\")] attribute"); - let base_url_lit = Lit::Str(syn::LitStr::new(&base_url, proc_macro2::Span::call_site())); + let base_url_lit = syn::LitStr::new(&base_url, proc_macro2::Span::call_site()); - let expanded = match &input.data { - Data::Struct(data_struct) => { - let fields = match &data_struct.fields { - Fields::Named(named) => &named.named, - _ => panic!("#[derive(HttpRequest)] only supports structs with named fields"), - }; - - let mut query_param_code = Vec::new(); - for field in fields { - let ident = field.ident.clone().unwrap(); + // Collect query parameters from fields named "lnk_p_*" + let query_param_code = if let Data::Struct(data_struct) = &input.data { + if let Fields::Named(fields_named) = &data_struct.fields { + fields_named.named.iter().filter_map(|field| { + let ident = field.ident.as_ref()?; let field_name = ident.to_string(); if field_name.starts_with("lnk_p_") { let key = &field_name["lnk_p_".len()..]; - query_param_code.push(quote! { + Some(quote! { query_params.push((#key.to_string(), self.#ident.to_string())); - }); + }) + } else { + None } - } - - quote! { - impl Queryable for #name { - fn send( - &self, - headers: Option>, - ) -> Result { - use urlencoding::encode; - use awc::Client; - - let mut query_params: Vec<(String, String)> = Vec::new(); - #(#query_param_code)* - - let mut url = #base_url_lit.to_string(); - if !query_params.is_empty() { - let query_parts: Vec = query_params.iter() - .map(|(k, v)| format!("{}={}", k, encode(v))) - .collect(); - url.push('?'); - url.push_str(&query_parts.join("&")); - } - - let client = Client::default(); - let mut request = client.get(url); - - if let Some(hdrs) = headers { - for (k, v) in hdrs { - request = request.append_header((k, v)); - } - } - - let response = rt::System::new() - .block_on(async { - request.send() - .await - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string())) - })?; - - let body_bytes = rt::System::new() - .block_on(async { - response.body() - .await - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string())) - })?; - - let body = Response::receive(body_bytes.to_vec()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; - - Ok(body) - } - } - } + }).collect::>() + } else { + Vec::new() } - - Data::Enum(data_enum) => { - let mut variant_arms = Vec::new(); - for variant in &data_enum.variants { - let vname = &variant.ident; - match &variant.fields { - Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { - variant_arms.push(quote! { - #name::#vname(inner) => inner.send(headers.clone()).await, - }); - } - _ => panic!("#[derive(HttpRequest)] enum variants must have a single unnamed field"), - } - } - - quote! { - impl Queryable for #name { - fn send( - &self, - headers: Option>, - ) -> Result { - match self { - #(#variant_arms)* - } - } - } - } - } - - _ => panic!("#[derive(HttpRequest)] only supports structs and enums"), + } else { + Vec::new() }; - TokenStream::from(expanded) -} - -#[proc_macro_attribute] -pub fn http_response(_attr: TokenStream, item: TokenStream) -> TokenStream { - let item_for_ast = item.clone(); - let item_for_quote = item.clone(); - - let ast: DeriveInput = parse_macro_input!(item_for_ast as DeriveInput); - let name = &ast.ident; - - let impl_block = match &ast.data { - Data::Struct(_) => { - // Impl for struct: deserialize entire response body as Self - quote! { - #[async_trait::async_trait] - impl Responsable for #name - where - Self: serde::de::DeserializeOwned + Sized + Send, - { - async fn receive(mut resp: awc::ClientResponse) -> Result> { - let parsed = resp.json::().await?; - Ok(parsed) - } - } - } - } - Data::Enum(data_enum) => { - // Impl for enum with tuple variants of length 1 (newtype variants) - let variant_arms = data_enum.variants.iter().filter_map(|variant| { - let vname = &variant.ident; - match &variant.fields { - Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { - let inner_ty = &fields.unnamed.first().unwrap().ty; - Some(quote! { - if let Ok(inner) = serde_json::from_slice::<#inner_ty>(&body) { - return Ok(#name::#vname(inner)); - } - }) - } - _ => None, - } - }); - - quote! { - #[async_trait::async_trait] - impl Responsable for #name - where - Self: Sized + Send, - { - async fn receive(mut resp: awc::ClientResponse) -> Result> { - let body = resp.body().await?; // Bytes, sized! - #(#variant_arms)* - Err(Box::new(std::io::Error::new( - std::io::ErrorKind::Other, - concat!("No matching enum variant in ", stringify!(#name)) - ))) - } - } - } - } - _ => panic!("#[http_response] only supports structs and tuple-style enum variants"), - }; - - let original: syn::Item = syn::parse(item_for_quote).expect("Failed to parse item as syn::Item"); - - let output = quote! { - #original - #impl_block - }; - - output.into() -} - -#[proc_macro_derive(ResponseVec)] -pub fn derive_response_vec(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let name = &input.ident; - + // Generate the impl let expanded = quote! { - impl Responsable for #name { - /// Deserializes all responses sequentially into a Vec. - /// Assumes `Self` implements `DeserializeOwned`. - pub async fn response_vec( - responses: Vec, - ) -> Result, std::error::Error> - where - Self: Sized + serde::de::DeserializeOwned, - { - let mut results = Vec::with_capacity(responses.len()); - for resp in responses { - let item = resp.json::().await?; - results.push(item); + #[async_trait::async_trait] + impl Queryable for #query_name { + type Response = #response_name; + + async fn send( + &self, + headers: Option>, + ) -> Result> { + use awc::Client; + use urlencoding::encode; + + let mut query_params: Vec<(String, String)> = Vec::new(); + #(#query_param_code)* + + let mut url = #base_url_lit.to_string(); + if !query_params.is_empty() { + let query_string: String = query_params.iter() + .map(|(k, v)| format!("{}={}", k, encode(v))) + .collect::>() + .join("&"); + url.push('?'); + url.push_str(&query_string); } - Ok(results) + + let client = Client::default(); + let mut request = client.get(url); + + if let Some(hdrs) = headers { + for (k, v) in hdrs { + request = request.append_header((k, v)); + } + } + + let response = request.send().await?; + let bytes = response.body().await?; + + // Deserialize into associated Response type + let parsed: Self::Response = serde_json::from_slice(&bytes)?; + Ok(parsed) } } };