diff --git a/Cargo.lock b/Cargo.lock index 6db0df14..88755425 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2264,7 +2264,6 @@ dependencies = [ "derivative", "derive_builder", "derive_more 2.1.1", - "easytier-rpc-build", "encoding", "flume 0.12.0", "forwarded-header-value", @@ -2310,6 +2309,7 @@ dependencies = [ "pin-project-lite", "pnet", "prefix-trie", + "proc-macro2", "prost", "prost-build", "prost-reflect", @@ -2319,6 +2319,7 @@ dependencies = [ "prost-wkt-types", "quinn", "quinn-plaintext", + "quote", "rand 0.8.5", "rcgen", "regex", @@ -2358,7 +2359,6 @@ dependencies = [ "tokio-util", "tokio-websockets", "toml 0.8.19", - "tonic-build", "tracing", "tracing-subscriber", "tun-easytier", @@ -2437,14 +2437,6 @@ dependencies = [ "windows 0.52.0", ] -[[package]] -name = "easytier-rpc-build" -version = "0.1.0" -dependencies = [ - "heck 0.5.0", - "prost-build", -] - [[package]] name = "easytier-uptime" version = "0.1.0" @@ -10055,19 +10047,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "tonic-build" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "568392c5a2bd0020723e3f387891176aabafe36fd9fcd074ad309dfa0c8eb964" -dependencies = [ - "prettyplease", - "proc-macro2", - "prost-build", - "quote", - "syn 2.0.117", -] - [[package]] name = "tower" version = "0.4.13" diff --git a/Cargo.toml b/Cargo.toml index fecdef70..2a2eb144 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,6 @@ resolver = "2" members = [ "easytier", "easytier-gui/src-tauri", - "easytier-rpc-build", "easytier-web", "easytier-contrib/easytier-ffi", "easytier-contrib/easytier-uptime", diff --git a/easytier-rpc-build/Cargo.toml b/easytier-rpc-build/Cargo.toml deleted file mode 100644 index 510d8381..00000000 --- a/easytier-rpc-build/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "easytier-rpc-build" -description = "Protobuf RPC Service Generator for EasyTier" -version = "0.1.0" -edition.workspace = true -homepage = "https://github.com/EasyTier/EasyTier" -repository = "https://github.com/EasyTier/EasyTier" -authors = ["kkrainbow"] -keywords = ["vpn", "p2p", "network", "easytier"] -categories = ["network-programming", "command-line-utilities"] -license-file = "LICENSE" -readme = "README.md" - -[dependencies] -heck = "0.5" -prost-build = "0.13" - -[features] -default = [] -internal-namespace = [] diff --git a/easytier-rpc-build/LICENSE b/easytier-rpc-build/LICENSE deleted file mode 120000 index ea5b6064..00000000 --- a/easytier-rpc-build/LICENSE +++ /dev/null @@ -1 +0,0 @@ -../LICENSE \ No newline at end of file diff --git a/easytier-rpc-build/README.md b/easytier-rpc-build/README.md deleted file mode 100644 index 66b1cae0..00000000 --- a/easytier-rpc-build/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Introduction - -This is a protobuf rpc service stub generator for [EasyTier](https://github.com/EasyTier/EasyTier) project. diff --git a/easytier-rpc-build/src/lib.rs b/easytier-rpc-build/src/lib.rs deleted file mode 100644 index f41f1e18..00000000 --- a/easytier-rpc-build/src/lib.rs +++ /dev/null @@ -1,449 +0,0 @@ -extern crate heck; -extern crate prost_build; - -use std::fmt; - -#[cfg(feature = "internal-namespace")] -const NAMESPACE: &str = "crate::proto::rpc_types"; - -#[cfg(not(feature = "internal-namespace"))] -const NAMESPACE: &str = "easytier::proto::rpc_types"; - -/// The service generator to be used with `prost-build` to generate RPC implementations for -/// `prost-simple-rpc`. -/// -/// See the crate-level documentation for more info. -#[allow(missing_copy_implementations)] -#[derive(Clone, Debug, Default)] -pub struct ServiceGenerator { - _private: (), -} - -impl prost_build::ServiceGenerator for ServiceGenerator { - fn generate(&mut self, service: prost_build::Service, mut buf: &mut String) { - use std::fmt::Write; - - let descriptor_name = format!("{}Descriptor", service.name); - let server_name = format!("{}Server", service.name); - let client_name = format!("{}Client", service.name); - let method_descriptor_name = format!("{}MethodDescriptor", service.name); - - let mut trait_methods = String::new(); - let mut weak_impl_methods = String::new(); - let mut enum_methods = String::new(); - let mut list_enum_methods = String::new(); - let mut client_methods = String::new(); - let mut client_own_methods = String::new(); - let mut match_name_methods = String::new(); - let mut match_proto_name_methods = String::new(); - let mut match_input_type_methods = String::new(); - let mut match_input_proto_type_methods = String::new(); - let mut match_output_type_methods = String::new(); - let mut match_output_proto_type_methods = String::new(); - let mut match_handle_methods = String::new(); - // generate trait default method Xxx::json_call_method match branch - let mut match_trait_json_methods = String::new(); - - let mut match_method_try_from = String::new(); - - for (idx, method) in service.methods.iter().enumerate() { - assert!( - !method.client_streaming, - "Client streaming not yet supported for method {}", - method.proto_name - ); - assert!( - !method.server_streaming, - "Server streaming not yet supported for method {}", - method.proto_name - ); - - ServiceGenerator::write_comments(&mut trait_methods, 4, &method.comments).unwrap(); - writeln!( - trait_methods, - r#" async fn {name}(&self, ctrl: Self::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}>;"#, - name = method.name, - input_type = method.input_type, - output_type = method.output_type, - namespace = NAMESPACE, - ) - .unwrap(); - - writeln!( - weak_impl_methods, - r#" async fn {method_name}(&self, ctrl: Self::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{ - let Some(service) = self.upgrade() else {{ - return Err({namespace}::error::Error::Shutdown); - }}; - service.{method_name}(ctrl, input).await - }}"#, - method_name = method.name, - input_type = method.input_type, - output_type = method.output_type, - namespace = NAMESPACE, - ) - .unwrap(); - - ServiceGenerator::write_comments(&mut enum_methods, 4, &method.comments).unwrap(); - writeln!( - enum_methods, - " {name} = {index},", - name = method.proto_name, - index = idx + 1 - ) - .unwrap(); - - writeln!( - match_method_try_from, - " {index} => Ok({service_name}MethodDescriptor::{name}),", - service_name = service.name, - name = method.proto_name, - index = idx + 1, - ) - .unwrap(); - - writeln!( - list_enum_methods, - " {service_name}MethodDescriptor::{name},", - service_name = service.name, - name = method.proto_name - ) - .unwrap(); - - writeln!( - client_methods, - r#" async fn {name}(&self, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{ - {client_name}Client::{name}_inner(self.0.clone(), ctrl, input).await - }}"#, - name = method.name, - input_type = method.input_type, - output_type = method.output_type, - client_name = service.name, - namespace = NAMESPACE, - ) - .unwrap(); - - writeln!( - client_own_methods, - r#" async fn {name}_inner(handler: H, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{ - {namespace}::__rt::call_method(handler, ctrl, {method_descriptor_name}::{proto_name}, input).await - }}"#, - name = method.name, - method_descriptor_name = method_descriptor_name, - proto_name = method.proto_name, - input_type = method.input_type, - output_type = method.output_type, - namespace = NAMESPACE, - ).unwrap(); - - let case = format!( - " {service_name}MethodDescriptor::{proto_name} => ", - service_name = service.name, - proto_name = method.proto_name - ); - - writeln!(match_name_methods, "{}{:?},", case, method.name).unwrap(); - writeln!(match_proto_name_methods, "{}{:?},", case, method.proto_name).unwrap(); - writeln!( - match_input_type_methods, - "{}::std::any::TypeId::of::<{}>(),", - case, method.input_type - ) - .unwrap(); - writeln!( - match_input_proto_type_methods, - "{}{:?},", - case, method.input_proto_type - ) - .unwrap(); - writeln!( - match_output_type_methods, - "{}::std::any::TypeId::of::<{}>(),", - case, method.output_type - ) - .unwrap(); - writeln!( - match_output_proto_type_methods, - "{}{:?},", - case, method.output_proto_type - ) - .unwrap(); - write!( - match_handle_methods, - r#"{} {{ - let decoded: {input_type} = {namespace}::__rt::decode(input)?; - let ret = service.{name}(ctrl, decoded).await?; - {namespace}::__rt::encode(ret) - }} -"#, - case, - input_type = method.input_type, - name = method.name, - namespace = NAMESPACE, - ) - .unwrap(); - - write!( - match_trait_json_methods, - r#" "{name}" | "{proto_name}" => {{ - let req: {input_type} = ::serde_json::from_value(json).map_err(|e| {namespace}::error::Error::MalformatRpcPacket(format!("json error: {{}}", e)))?; - let resp = self.{typed_method}(ctrl, req).await?; - Ok(::serde_json::to_value(resp).map_err(|e| {namespace}::error::Error::MalformatRpcPacket(format!("json error: {{}}", e)))?) - }} -"#, - name = method.name, - proto_name = method.proto_name, - input_type = method.input_type, - typed_method = method.name, - namespace = NAMESPACE, - ) - .unwrap(); - } - - ServiceGenerator::write_comments(&mut buf, 0, &service.comments).unwrap(); - write!( - buf, - r#" -#[async_trait::async_trait] -#[auto_impl::auto_impl(&, Arc, Box)] -pub trait {name} {{ - type Controller: {namespace}::controller::Controller; - - {trait_methods} - - async fn json_call_method( - &self, - ctrl: Self::Controller, - method_name: &str, - json: ::serde_json::Value, - ) -> {namespace}::error::Result<::serde_json::Value> {{ - match method_name {{ -{match_trait_json_methods} - _ => Err({namespace}::error::Error::InvalidMethodIndex(0, method_name.to_string())), - }} - }} -}} - -#[async_trait::async_trait] -impl {name} for ::std::sync::Weak -where - T: Send + Sync + 'static, - ::std::sync::Arc: {name}, -{{ - type Controller = <::std::sync::Arc as {name}>::Controller; - - {weak_impl_methods} -}} - -/// A service descriptor for a `{name}`. -#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Default)] -pub struct {descriptor_name}; - -/// Methods available on a `{name}`. -/// -/// This can be used as a key when routing requests for servers/clients of a `{name}`. -#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] -#[repr(u8)] -pub enum {method_descriptor_name} {{ - {enum_methods} -}} - -impl std::convert::TryFrom for {method_descriptor_name} {{ - type Error = {namespace}::error::Error; - fn try_from(value: u8) -> {namespace}::error::Result {{ - match value {{ - {match_method_try_from} - _ => Err({namespace}::error::Error::InvalidMethodIndex(value, "{name}".to_string())), - }} - }} -}} - -/// A client for a `{name}`. -/// -/// This implements the `{name}` trait by dispatching all method calls to the supplied `Handler`. -#[derive(Clone, Debug)] -pub struct {client_name}(H) where H: {namespace}::handler::Handler; - -impl {client_name} where H: {namespace}::handler::Handler {{ - /// Creates a new client instance that delegates all method calls to the supplied handler. - pub fn new(handler: H) -> {client_name} {{ - {client_name}(handler) - }} -}} - -impl {client_name} where H: {namespace}::handler::Handler {{ - {client_own_methods} -}} - -#[async_trait::async_trait] -impl {name} for {client_name} where H: {namespace}::handler::Handler {{ - type Controller = H::Controller; - - {client_methods} -}} - -pub struct {client_name}Factory(std::marker::PhantomData); - -impl Clone for {client_name}Factory {{ - fn clone(&self) -> Self {{ - Self(std::marker::PhantomData) - }} -}} - -impl {namespace}::__rt::RpcClientFactory for {client_name}Factory where C: {namespace}::controller::Controller {{ - type Descriptor = {descriptor_name}; - type ClientImpl = Box + Send + Sync + 'static>; - type Controller = C; - - fn new(handler: impl {namespace}::handler::Handler) -> Self::ClientImpl {{ - Box::new({client_name}::new(handler)) - }} -}} - -/// A server for a `{name}`. -/// -/// This implements the `Server` trait by handling requests and dispatch them to methods on the -/// supplied `{name}`. -#[derive(Clone, Debug)] -pub struct {server_name}(A) where A: {name} + Clone + Send + 'static; - -impl {server_name}<::std::sync::Weak> -where - T: Send + Sync + 'static, - ::std::sync::Arc: {name}, -{{ - pub fn new_arc(service: ::std::sync::Arc) -> {server_name}<::std::sync::Weak> {{ - {server_name}(::std::sync::Arc::downgrade(&service)) - }} -}} - -impl {server_name} where A: {name} + Clone + Send + 'static {{ - /// Creates a new server instance that dispatches all calls to the supplied service. - pub fn new(service: A) -> {server_name} {{ - {server_name}(service) - }} - - async fn call_inner( - service: A, - method: {method_descriptor_name}, - ctrl: A::Controller, - input: ::bytes::Bytes) - -> {namespace}::error::Result<::bytes::Bytes> {{ - match method {{ - {match_handle_methods} - }} - }} -}} - -impl {namespace}::descriptor::ServiceDescriptor for {descriptor_name} {{ - type Method = {method_descriptor_name}; - fn name(&self) -> &'static str {{ {name:?} }} - fn proto_name(&self) -> &'static str {{ {proto_name:?} }} - fn package(&self) -> &'static str {{ {package:?} }} - fn methods(&self) -> &'static [Self::Method] {{ - &[ {list_enum_methods} ] - }} -}} - -#[async_trait::async_trait] -impl {namespace}::handler::Handler for {server_name} -where - A: {name} + Clone + Send + Sync + 'static {{ - type Descriptor = {descriptor_name}; - type Controller = A::Controller; - - async fn call( - &self, - ctrl: A::Controller, - method: {method_descriptor_name}, - input: ::bytes::Bytes) - -> {namespace}::error::Result<::bytes::Bytes> {{ - {server_name}::call_inner(self.0.clone(), method, ctrl, input).await - }} -}} - -impl {namespace}::descriptor::MethodDescriptor for {method_descriptor_name} {{ - fn name(&self) -> &'static str {{ - match *self {{ - {match_name_methods} - }} - }} - - fn proto_name(&self) -> &'static str {{ - match *self {{ - {match_proto_name_methods} - }} - }} - - fn input_type(&self) -> ::std::any::TypeId {{ - match *self {{ - {match_input_type_methods} - }} - }} - - fn input_proto_type(&self) -> &'static str {{ - match *self {{ - {match_input_proto_type_methods} - }} - }} - - fn output_type(&self) -> ::std::any::TypeId {{ - match *self {{ - {match_output_type_methods} - }} - }} - - fn output_proto_type(&self) -> &'static str {{ - match *self {{ - {match_output_proto_type_methods} - }} - }} - - fn index(&self) -> u8 {{ - *self as u8 - }} -}} -"#, - name = service.name, - descriptor_name = descriptor_name, - server_name = server_name, - client_name = client_name, - method_descriptor_name = method_descriptor_name, - proto_name = service.proto_name, - package = service.package, - trait_methods = trait_methods, - weak_impl_methods = weak_impl_methods, - enum_methods = enum_methods, - list_enum_methods = list_enum_methods, - client_own_methods = client_own_methods, - client_methods = client_methods, - match_name_methods = match_name_methods, - match_proto_name_methods = match_proto_name_methods, - match_input_type_methods = match_input_type_methods, - match_input_proto_type_methods = match_input_proto_type_methods, - match_output_type_methods = match_output_type_methods, - match_output_proto_type_methods = match_output_proto_type_methods, - match_handle_methods = match_handle_methods, - match_trait_json_methods = match_trait_json_methods, - namespace = NAMESPACE, - ).unwrap(); - } -} - -impl ServiceGenerator { - fn write_comments( - mut write: W, - indent: usize, - comments: &prost_build::Comments, - ) -> fmt::Result - where - W: fmt::Write, - { - for comment in &comments.leading { - for line in comment.lines().filter(|s| !s.is_empty()) { - writeln!(write, "{}///{}", " ".repeat(indent), line)?; - } - } - Ok(()) - } -} diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 4b645546..0748909e 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -11,6 +11,7 @@ keywords = ["vpn", "p2p", "network", "easytier"] categories = ["network-programming", "command-line-utilities"] license-file = "LICENSE" readme = "README.md" +build = "build/main.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -316,15 +317,14 @@ jemalloc-sys = { package = "tikv-jemalloc-sys", version = "0.6.0", features = [ [build-dependencies] cfg_aliases = "0.2.1" -tonic-build = "0.12" +indoc = "2.0" globwalk = "0.8.1" regex = "1" prost-build = "0.13.5" prost-wkt-build = "0.6" -easytier-rpc-build = { path = "../easytier-rpc-build", features = [ - "internal-namespace", -] } prost-reflect-build = { version = "0.14.0" } +proc-macro2 = "1" +quote = "1" thunk-rs = { git = "https://github.com/easytier/thunk.git", default-features = false, features = [ "win7", ] } diff --git a/easytier/build.rs b/easytier/build/main.rs similarity index 98% rename from easytier/build.rs rename to easytier/build/main.rs index 6b2ea81f..8d111d80 100644 --- a/easytier/build.rs +++ b/easytier/build/main.rs @@ -1,3 +1,6 @@ +mod rpc; + +use crate::rpc::ServiceGenerator; use cfg_aliases::cfg_aliases; use prost_wkt_build::{FileDescriptorSet, Message as _}; #[cfg(target_os = "windows")] @@ -197,7 +200,7 @@ fn main() -> Result<(), Box> { .type_attribute("acl.Rule", "#[serde(default)]") .type_attribute("acl.GroupInfo", "#[serde(default)]") .field_attribute(".api.manage.NetworkConfig", "#[serde(default)]") - .service_generator(Box::new(easytier_rpc_build::ServiceGenerator::default())) + .service_generator(Box::new(ServiceGenerator::default())) .btree_map(["."]) .skip_debug([".common.Ipv4Addr", ".common.Ipv6Addr", ".common.UUID"]); diff --git a/easytier/build/rpc.rs b/easytier/build/rpc.rs new file mode 100644 index 00000000..e8d4472d --- /dev/null +++ b/easytier/build/rpc.rs @@ -0,0 +1,720 @@ +#![allow(non_snake_case)] + +use indoc::formatdoc; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote}; +use std::str::FromStr; + +fn parse(value: &str) -> TokenStream { + TokenStream::from_str(value) + .unwrap_or_else(|err| panic!("Failed to parse tokens: {} ({})", value, err)) +} + +fn doc(comments: &prost_build::Comments) -> TokenStream { + let doc = comments + .leading + .iter() + .flat_map(|c| c.lines().filter(|s| !s.is_empty())); + quote! { #( #[doc = #doc] )* } +} + +const NAMESPACE: &str = "crate::proto::rpc_types"; + +struct Method { + index: u8, + doc: TokenStream, + method: Ident, + method_inner: Ident, + method_str: String, + method_proto: Ident, + method_proto_str: String, + Input: TokenStream, + Input_proto_str: String, + Output: TokenStream, + Output_proto_str: String, +} + +impl Method { + fn new(index: u8, method: prost_build::Method) -> Self { + assert!( + !method.client_streaming, + "Client streaming not yet supported for method {}", + method.proto_name + ); + assert!( + !method.server_streaming, + "Server streaming not yet supported for method {}", + method.proto_name + ); + Self { + index, + doc: doc(&method.comments), + method: format_ident!("{}", method.name), + method_inner: format_ident!("{}_inner", method.name), + method_str: method.name, + method_proto: format_ident!("{}", method.proto_name), + method_proto_str: method.proto_name, + Input: parse(&method.input_type), + Input_proto_str: method.input_proto_type, + Output: parse(&method.output_type), + Output_proto_str: method.output_proto_type, + } + } +} + +struct Service { + namespace: TokenStream, + doc: TokenStream, + Service: Ident, + ServiceDescriptor: Ident, + ServiceServer: Ident, + ServiceClient: Ident, + ServiceClientFactory: Ident, + ServiceMethodDescriptor: Ident, + Service_str: String, + Service_proto_str: String, + Service_package_str: String, + methods: Vec, +} + +impl Service { + fn new(service: prost_build::Service) -> Self { + let methods = service + .methods + .into_iter() + .enumerate() + .map(|(i, method)| Method::new((i + 1) as u8, method)) + .collect(); + + Self { + namespace: parse(NAMESPACE), + doc: doc(&service.comments), + Service: format_ident!("{}", service.name), + ServiceDescriptor: format_ident!("{}Descriptor", service.name), + ServiceServer: format_ident!("{}Server", service.name), + ServiceClient: format_ident!("{}Client", service.name), + ServiceClientFactory: format_ident!("{}ClientFactory", service.name), + ServiceMethodDescriptor: format_ident!("{}MethodDescriptor", service.name), + Service_str: service.name, + Service_proto_str: service.proto_name, + Service_package_str: service.package, + methods, + } + } + + fn trait_Service(&self) -> TokenStream { + let Self { + namespace, + doc, + Service, + methods, + .. + } = self; + + let match_json_call_method = methods.iter().map( + |Method { + method, + method_str, + method_proto_str, + Input, + .. + }| { + quote! { + #method_str | #method_proto_str => { + let req: #Input = ::serde_json::from_value(json) + .map_err(|e| #namespace::error::Error::MalformatRpcPacket(format!("json error: {}", e)))?; + let resp = self.#method(ctrl, req).await?; + Ok(::serde_json::to_value(resp) + .map_err(|e| #namespace::error::Error::MalformatRpcPacket(format!("json error: {}", e)))?) + } + } + }, + ); + + let methods = methods.iter().map( + |Method { + doc, + method, + Input, + Output, + .. + }| { + quote! { + #doc + async fn #method(&self, ctrl: Self::Controller, input: #Input) -> #namespace::error::Result<#Output>; + } + }, + ); + + quote! { + #doc + #[async_trait::async_trait] + #[auto_impl::auto_impl(&, Arc, Box)] + pub trait #Service { + type Controller: #namespace::controller::Controller; + + #(#methods)* + + async fn json_call_method( + &self, + ctrl: Self::Controller, + method: &str, + json: ::serde_json::Value, + ) -> #namespace::error::Result<::serde_json::Value> { + match method { + #(#match_json_call_method)* + _ => Err(#namespace::error::Error::InvalidMethodIndex(0, method.to_string())), + } + } + } + } + } + + fn impl_Service_for_Weak(&self) -> TokenStream { + let Self { + namespace, + Service, + methods, + .. + } = self; + let methods = methods.iter().map( + |Method { + method, + Input, + Output, + .. + }| { + quote! { + async fn #method(&self, ctrl: Self::Controller, input: #Input) -> #namespace::error::Result<#Output> { + let Some(service) = self.upgrade() else { + return Err(#namespace::error::Error::Shutdown); + }; + service.#method(ctrl, input).await + } + } + }, + ); + + quote! { + #[async_trait::async_trait] + impl #Service for ::std::sync::Weak + where + T: Send + Sync + 'static, + ::std::sync::Arc: #Service, + { + type Controller = <::std::sync::Arc as #Service>::Controller; + + #(#methods)* + } + } + } + + fn struct_ServiceDescriptor(&self) -> TokenStream { + let Self { + namespace, + ServiceDescriptor, + ServiceMethodDescriptor, + Service_str, + Service_proto_str, + Service_package_str, + methods, + .. + } = self; + + let doc = format!("A service descriptor for a `{}`.", Service_str); + + let methods = methods.iter().map(|Method { method_proto, .. }| { + quote! { #ServiceMethodDescriptor::#method_proto, } + }); + + quote! { + #[doc = #doc] + #[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Default)] + pub struct #ServiceDescriptor; + + impl #namespace::descriptor::ServiceDescriptor for #ServiceDescriptor { + type Method = #ServiceMethodDescriptor; + fn name(&self) -> &'static str { #Service_str } + fn proto_name(&self) -> &'static str { #Service_proto_str } + fn package(&self) -> &'static str { #Service_package_str } + fn methods(&self) -> &'static [Self::Method] { + &[ #(#methods)* ] + } + } + } + } + + fn enum_ServiceMethodDescriptor(&self) -> TokenStream { + let Self { + ServiceMethodDescriptor, + Service_str, + methods, + .. + } = self; + + let doc = formatdoc! {" + Methods available on a `{Service_str}`. + + This can be used as a key when routing requests for servers/clients of a `{Service_str}`. + "}; + + let variants = methods.iter().map( + |Method { + method_proto, + index, + .. + }| { + quote! { #method_proto = #index, } + }, + ); + + let impl_MethodDescriptor = self.impl_MethodDescriptor_for_ServiceMethodDescriptor(); + let impl_TryFrom = self.impl_TryFrom_for_ServiceMethodDescriptor(); + quote! { + #[doc = #doc] + #[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] + #[repr(u8)] + pub enum #ServiceMethodDescriptor { + #(#variants)* + } + + #impl_MethodDescriptor + + #impl_TryFrom + } + } + + fn impl_MethodDescriptor_for_ServiceMethodDescriptor(&self) -> TokenStream { + let Self { + namespace, + ServiceMethodDescriptor, + methods, + .. + } = self; + + let name = { + let arms = methods.iter().map( + |Method { + method_proto, + method_str, + .. + }| { + quote! { #ServiceMethodDescriptor::#method_proto => #method_str, } + }, + ); + + quote! { + fn name(&self) -> &'static str { + match *self { + #(#arms)* + } + } + } + }; + + let proto_name = { + let arms = methods.iter().map( + |Method { + method_proto, + method_proto_str, + .. + }| { + quote! { #ServiceMethodDescriptor::#method_proto => #method_proto_str, } + }, + ); + + quote! { + fn proto_name(&self) -> &'static str { + match *self { + #(#arms)* + } + } + } + }; + + let input_type = { + let arms = methods.iter().map(|Method { method_proto, Input, .. }| { + quote! { #ServiceMethodDescriptor::#method_proto => ::std::any::TypeId::of::<#Input>(), } + }); + + quote! { + fn input_type(&self) -> ::std::any::TypeId { + match *self { + #(#arms)* + } + } + } + }; + + let input_proto_type = { + let arms = methods.iter().map( + |Method { + method_proto, + Input_proto_str, + .. + }| { + quote! { #ServiceMethodDescriptor::#method_proto => #Input_proto_str, } + }, + ); + + quote! { + fn input_proto_type(&self) -> &'static str { + match *self { + #(#arms)* + } + } + } + }; + + let output_type = { + let arms = methods.iter().map(|Method { method_proto, Output, .. }| { + quote! { #ServiceMethodDescriptor::#method_proto => ::std::any::TypeId::of::<#Output>(), } + }); + + quote! { + fn output_type(&self) -> ::std::any::TypeId { + match *self { + #(#arms)* + } + } + } + }; + + let output_proto_type = { + let arms = methods.iter().map( + |Method { + method_proto, + Output_proto_str, + .. + }| { + quote! { #ServiceMethodDescriptor::#method_proto => #Output_proto_str, } + }, + ); + + quote! { + fn output_proto_type(&self) -> &'static str { + match *self { + #(#arms)* + } + } + } + }; + + quote! { + impl #namespace::descriptor::MethodDescriptor for #ServiceMethodDescriptor { + #name + + #proto_name + + #input_type + + #input_proto_type + + #output_type + + #output_proto_type + + fn index(&self) -> u8 { + *self as u8 + } + } + } + } + + fn impl_TryFrom_for_ServiceMethodDescriptor(&self) -> TokenStream { + let Self { + namespace, + ServiceMethodDescriptor, + Service_str, + methods, + .. + } = self; + + let arms = methods.iter().map( + |Method { + method_proto, + index, + .. + }| { + quote! { #index => Ok(#ServiceMethodDescriptor::#method_proto), } + }, + ); + + quote! { + impl std::convert::TryFrom for #ServiceMethodDescriptor { + type Error = #namespace::error::Error; + fn try_from(value: u8) -> #namespace::error::Result { + match value { + #(#arms)* + _ => Err(#namespace::error::Error::InvalidMethodIndex(value, #Service_str.to_string())), + } + } + } + } + } + + fn struct_ServiceClient(&self) -> TokenStream { + let Self { + namespace, + ServiceDescriptor, + ServiceClient, + Service_str, + .. + } = self; + + let doc = formatdoc! {" + A client for a `{Service_str}`. + + This implements the `{Service_str}` trait by dispatching all method calls to the supplied `Handler`. + "}; + + let impl_service_client = self.impl_ServiceClient(); + let impl_service_for_client = self.impl_Service_for_ServiceClient(); + quote! { + #[doc = #doc] + #[derive(Clone, Debug)] + pub struct #ServiceClient(H) where H: #namespace::handler::Handler; + + impl #ServiceClient where H: #namespace::handler::Handler { + /// Creates a new client instance that delegates all method calls to the supplied handler. + pub fn new(handler: H) -> Self { + Self(handler) + } + } + + #impl_service_client + + #impl_service_for_client + } + } + + fn impl_ServiceClient(&self) -> TokenStream { + let Self { + namespace, + ServiceClient, + ServiceDescriptor, + ServiceMethodDescriptor, + methods, + .. + } = self; + + let methods = methods.iter().map( + |Method { + method_inner, + method_proto, + Input, + Output, + .. + }| { + quote! { + async fn #method_inner(handler: H, ctrl: H::Controller, input: #Input) -> #namespace::error::Result<#Output> { + #namespace::__rt::call_method(handler, ctrl, #ServiceMethodDescriptor::#method_proto, input).await + } + } + }, + ); + + quote! { + impl #ServiceClient where H: #namespace::handler::Handler { + #(#methods)* + } + } + } + + fn impl_Service_for_ServiceClient(&self) -> TokenStream { + let Self { + namespace, + Service, + ServiceClient, + ServiceDescriptor, + methods, + .. + } = self; + + let methods = methods.iter().map( + |Method { + method, + method_inner, + Input, + Output, + .. + }| { + quote! { + async fn #method(&self, ctrl: H::Controller, input: #Input) -> #namespace::error::Result<#Output> { + #ServiceClient::#method_inner(self.0.clone(), ctrl, input).await + } + } + }, + ); + + quote! { + #[async_trait::async_trait] + impl #Service for #ServiceClient where H: #namespace::handler::Handler { + type Controller = H::Controller; + + #(#methods)* + } + } + } + + fn struct_ServiceClientFactory(&self) -> TokenStream { + let Self { + namespace, + Service, + ServiceClient, + ServiceClientFactory, + ServiceDescriptor, + .. + } = self; + + quote! { + pub struct #ServiceClientFactory(std::marker::PhantomData); + + impl Clone for #ServiceClientFactory { + fn clone(&self) -> Self { + Self(std::marker::PhantomData) + } + } + + impl #namespace::__rt::RpcClientFactory for #ServiceClientFactory where C: #namespace::controller::Controller { + type Descriptor = #ServiceDescriptor; + type ClientImpl = Box + Send + Sync + 'static>; + type Controller = C; + + fn new(handler: impl #namespace::handler::Handler) -> Self::ClientImpl { + Box::new(#ServiceClient::new(handler)) + } + } + } + } + + fn struct_ServiceServer(&self) -> TokenStream { + let Self { + namespace, + Service, + ServiceDescriptor, + ServiceServer, + ServiceMethodDescriptor, + Service_str, + methods, + .. + } = self; + + let doc = formatdoc! {" + A server for a `{Service_str}`. + + This implements the `Server` trait by handling requests and dispatch them to methods on the + supplied `{Service_str}`. + "}; + + let arms = methods.iter().map( + |Method { + method_proto, + method, + Input, + .. + }| { + quote! { + #ServiceMethodDescriptor::#method_proto => { + let decoded: #Input = #namespace::__rt::decode(input)?; + let ret = service.#method(ctrl, decoded).await?; + #namespace::__rt::encode(ret) + } + } + }, + ); + + quote! { + #[doc = #doc] + #[derive(Clone, Debug)] + pub struct #ServiceServer(A) where A: #Service + Clone + Send + 'static; + + impl #ServiceServer<::std::sync::Weak> + where + T: Send + Sync + 'static, + ::std::sync::Arc: #Service, + { + pub fn new_arc(service: ::std::sync::Arc) -> #ServiceServer<::std::sync::Weak> { + #ServiceServer(::std::sync::Arc::downgrade(&service)) + } + } + + impl #ServiceServer where A: #Service + Clone + Send + 'static { + /// Creates a new server instance that dispatches all calls to the supplied service. + pub fn new(service: A) -> #ServiceServer { + #ServiceServer(service) + } + + async fn call_inner( + service: A, + method: #ServiceMethodDescriptor, + ctrl: A::Controller, + input: ::bytes::Bytes) + -> #namespace::error::Result<::bytes::Bytes> { + match method { + #(#arms)* + } + } + } + + #[async_trait::async_trait] + impl #namespace::handler::Handler for #ServiceServer + where + A: #Service + Clone + Send + Sync + 'static { + type Descriptor = #ServiceDescriptor; + type Controller = A::Controller; + + async fn call( + &self, + ctrl: A::Controller, + method: #ServiceMethodDescriptor, + input: ::bytes::Bytes) + -> #namespace::error::Result<::bytes::Bytes> { + #ServiceServer::call_inner(self.0.clone(), method, ctrl, input).await + } + } + } + } +} + +/// The service generator to be used with `prost-build` to generate RPC implementations for +/// `prost-simple-rpc`. +/// +/// See the crate-level documentation for more info. +#[non_exhaustive] +#[derive(Debug, Default)] +pub struct ServiceGenerator; + +impl prost_build::ServiceGenerator for ServiceGenerator { + fn generate(&mut self, service: prost_build::Service, buf: &mut String) { + let info = Service::new(service); + + let trait_Service = info.trait_Service(); + let impl_Service_for_Weak = info.impl_Service_for_Weak(); + let struct_ServiceDescriptor = info.struct_ServiceDescriptor(); + let enum_ServiceMethodDescriptor = info.enum_ServiceMethodDescriptor(); + let struct_ServiceClient = info.struct_ServiceClient(); + let struct_ServiceClientFactory = info.struct_ServiceClientFactory(); + let struct_ServiceServer = info.struct_ServiceServer(); + + let tokens = quote! { + #trait_Service + + #impl_Service_for_Weak + + #struct_ServiceDescriptor + + #enum_ServiceMethodDescriptor + + #struct_ServiceClient + + #struct_ServiceClientFactory + + #struct_ServiceServer + }; + + buf.push('\n'); + buf.push_str(&tokens.to_string()); + buf.push('\n'); + } +}