From 41f85afa6cdb4afd20a205ba11f56bc9c2040ff3 Mon Sep 17 00:00:00 2001 From: buckn Date: Fri, 18 Jul 2025 13:34:27 -0400 Subject: [PATCH] ud, fixed errors --- src/lib.rs | 138 ++++++++++++++++++++++------------------------------- 1 file changed, 58 insertions(+), 80 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a456c44..c44c7c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,85 +147,79 @@ pub fn derive_http_get_request(input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream { let input_enum = parse_macro_input!(item as ItemEnum); - let top_enum_ident = &input_enum.ident; // e.g., "Cmd" + let top_enum_ident = &input_enum.ident; let top_variants = &input_enum.variants; - // For each variant like Alpaca(AlpacaCmd) - let outer_match_arms = top_variants.iter().map(|v| { - let variant_ident = &v.ident; // e.g., "Alpaca" + // Build outer match arms + let match_arms: Vec<_> = top_variants.iter().map(|variant| { + let variant_ident = &variant.ident; - // Extract the inner sub-enum type like AlpacaCmd - if let syn::Fields::Unnamed(fields) = &v.fields { - if let Some(field) = fields.unnamed.first() { - if let syn::Type::Path(inner_type) = &field.ty { - let inner_type_ident = &inner_type.path.segments.last().unwrap().ident; + // Expecting tuple variants like Alpaca(AlpacaCmd) + let inner_type = match &variant.fields { + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + match &fields.unnamed.first().unwrap().ty { + syn::Type::Path(p) => p.path.segments.last().unwrap().ident.clone(), + _ => panic!("Expected tuple variant with a type path"), + } + } + _ => panic!("Each variant must be a tuple variant like `Alpaca(AlpacaCmd)`"), + }; - // Match arms inside the nested enum (AlpacaCmd) - let inner_match_arm = quote! { - match #inner_type_ident::parse() { - #inner_type_ident::Bulk { input } => { - // Bulk: read and parse Vec - let mut reader: Box = match input { - Some(path) => Box::new(std::fs::File::open(path)?), - None => Box::new(std::io::stdin()), - }; - let mut buf = String::new(); - reader.read_to_string(&mut buf)?; - let queries: Vec<#inner_type_ident> = serde_json::from_str(&buf)?; + quote! { + #top_enum_ident::#variant_ident(inner) => { + match inner { + #inner_type::Bulk { input } => { + let mut reader: Box = match input { + Some(path) => Box::new(std::fs::File::open(path)?), + None => Box::new(std::io::stdin()), + }; - use std::sync::Arc; - let client = Arc::new(awc::Client::default()); - let api_keys = Arc::new(crate::load_api_keys()?); + let mut buf = String::new(); + reader.read_to_string(&mut buf)?; + let queries: Vec<#inner_type> = serde_json::from_str(&buf)?; - const THREADS: usize = 4; - let total = queries.len(); - let per_thread = total / THREADS; - let shared_queries = Arc::new(queries); + use std::sync::Arc; + let client = Arc::new(awc::Client::default()); + let keys = Arc::new(crate::load_api_keys()?); - let mut handles = Vec::new(); - for i in 0..THREADS { - let queries = Arc::clone(&shared_queries); - let client = Arc::clone(&client); - let keys = Arc::clone(&api_keys); - let start = i * per_thread; - let end = if i == THREADS - 1 { total } else { start + per_thread }; + const THREADS: usize = 4; + let total = queries.len(); + let per_thread = std::cmp::max(1, total / THREADS); + let shared_queries = Arc::new(queries); - let handle = std::thread::spawn(move || { - let rt = tokio::runtime::Runtime::new().unwrap(); - for q in &queries[start..end] { - rt.block_on(q.send_all(&client, &keys)); - } - }); + let mut handles = Vec::new(); + for i in 0..THREADS { + let queries = Arc::clone(&shared_queries); + let client = Arc::clone(&client); + let keys = Arc::clone(&keys); + let start = i * per_thread; + let end = if i == THREADS - 1 { total } else { start + per_thread }; - handles.push(handle); + let handle = std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + for q in &queries[start..end] { + rt.block_on(q.send_all(&client, &keys)).unwrap(); } + }); - for h in handles { - h.join().expect("Thread panicked"); - } - } - other => { - let client = awc::Client::default(); - let keys = crate::load_api_keys()?; - other.send_all(&client, &keys).await?; - } + handles.push(handle); } - }; - // Wrap the outer enum match (Cmd::Alpaca(inner)) - return quote! { - #top_enum_ident::#variant_ident(inner) => { - #inner_match_arm + for h in handles { + h.join().expect("Thread panicked"); } - }; + } + other => { + let client = awc::Client::default(); + let keys = crate::load_api_keys()?; + other.send_all(&client, &keys).await?; + } } } } + }).collect(); - panic!("Each outer enum variant must be a tuple variant like `Alpaca(AlpacaCmd)`"); - }); - - // Generate the final program + // Generate the final code let expanded = quote! { use clap::Parser; use std::io::Read; @@ -237,35 +231,19 @@ pub fn alpaca_cli(_attr: TokenStream, item: TokenStream) -> TokenStream { async fn main() -> Result<(), Box> { let cmd = #top_enum_ident::parse(); match cmd { - #(#outer_match_arms),* + #(#match_arms),* } Ok(()) } - // Helper trait to unify async calls on sub-commands - trait ApiDispatch { + // Trait for dispatching API calls + pub trait ApiDispatch { fn send_all( &self, client: &awc::Client, keys: &std::collections::HashMap, ) -> std::pin::Pin>> + Send>>; } - - // Implement ApiDispatch for every subcommand variant - #(impl ApiDispatch for #top_enum_ident { - fn send_all( - &self, - client: &awc::Client, - keys: &std::collections::HashMap, - ) -> std::pin::Pin>> + Send>> { - Box::pin(async move { - match self { - #(#outer_match_arms),* - } - Ok(()) - }) - } - })* }; TokenStream::from(expanded)