refactor: rpc build (#2244)

rewrite rpc build with quota crate
This commit is contained in:
Luna Yao
2026-05-15 08:01:56 +02:00
committed by GitHub
parent 8428a89d2d
commit 811f151155
9 changed files with 730 additions and 502 deletions
Generated
+2 -23
View File
@@ -2264,7 +2264,6 @@ dependencies = [
"derivative",
"derive_builder",
"derive_more 2.1.1",
"easytier-rpc-build",
"encoding",
"flume 0.12.0",
"forwarded-header-value",
@@ -2310,6 +2309,7 @@ dependencies = [
"pin-project-lite",
"pnet",
"prefix-trie",
"proc-macro2",
"prost",
"prost-build",
"prost-reflect",
@@ -2319,6 +2319,7 @@ dependencies = [
"prost-wkt-types",
"quinn",
"quinn-plaintext",
"quote",
"rand 0.8.5",
"rcgen",
"regex",
@@ -2358,7 +2359,6 @@ dependencies = [
"tokio-util",
"tokio-websockets",
"toml 0.8.19",
"tonic-build",
"tracing",
"tracing-subscriber",
"tun-easytier",
@@ -2437,14 +2437,6 @@ dependencies = [
"windows 0.52.0",
]
[[package]]
name = "easytier-rpc-build"
version = "0.1.0"
dependencies = [
"heck 0.5.0",
"prost-build",
]
[[package]]
name = "easytier-uptime"
version = "0.1.0"
@@ -10055,19 +10047,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "tonic-build"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "568392c5a2bd0020723e3f387891176aabafe36fd9fcd074ad309dfa0c8eb964"
dependencies = [
"prettyplease",
"proc-macro2",
"prost-build",
"quote",
"syn 2.0.117",
]
[[package]]
name = "tower"
version = "0.4.13"
-1
View File
@@ -3,7 +3,6 @@ resolver = "2"
members = [
"easytier",
"easytier-gui/src-tauri",
"easytier-rpc-build",
"easytier-web",
"easytier-contrib/easytier-ffi",
"easytier-contrib/easytier-uptime",
-20
View File
@@ -1,20 +0,0 @@
[package]
name = "easytier-rpc-build"
description = "Protobuf RPC Service Generator for EasyTier"
version = "0.1.0"
edition.workspace = true
homepage = "https://github.com/EasyTier/EasyTier"
repository = "https://github.com/EasyTier/EasyTier"
authors = ["kkrainbow"]
keywords = ["vpn", "p2p", "network", "easytier"]
categories = ["network-programming", "command-line-utilities"]
license-file = "LICENSE"
readme = "README.md"
[dependencies]
heck = "0.5"
prost-build = "0.13"
[features]
default = []
internal-namespace = []
-1
View File
@@ -1 +0,0 @@
../LICENSE
-3
View File
@@ -1,3 +0,0 @@
# Introduction
This is a protobuf rpc service stub generator for [EasyTier](https://github.com/EasyTier/EasyTier) project.
-449
View File
@@ -1,449 +0,0 @@
extern crate heck;
extern crate prost_build;
use std::fmt;
#[cfg(feature = "internal-namespace")]
const NAMESPACE: &str = "crate::proto::rpc_types";
#[cfg(not(feature = "internal-namespace"))]
const NAMESPACE: &str = "easytier::proto::rpc_types";
/// The service generator to be used with `prost-build` to generate RPC implementations for
/// `prost-simple-rpc`.
///
/// See the crate-level documentation for more info.
#[allow(missing_copy_implementations)]
#[derive(Clone, Debug, Default)]
pub struct ServiceGenerator {
_private: (),
}
impl prost_build::ServiceGenerator for ServiceGenerator {
fn generate(&mut self, service: prost_build::Service, mut buf: &mut String) {
use std::fmt::Write;
let descriptor_name = format!("{}Descriptor", service.name);
let server_name = format!("{}Server", service.name);
let client_name = format!("{}Client", service.name);
let method_descriptor_name = format!("{}MethodDescriptor", service.name);
let mut trait_methods = String::new();
let mut weak_impl_methods = String::new();
let mut enum_methods = String::new();
let mut list_enum_methods = String::new();
let mut client_methods = String::new();
let mut client_own_methods = String::new();
let mut match_name_methods = String::new();
let mut match_proto_name_methods = String::new();
let mut match_input_type_methods = String::new();
let mut match_input_proto_type_methods = String::new();
let mut match_output_type_methods = String::new();
let mut match_output_proto_type_methods = String::new();
let mut match_handle_methods = String::new();
// generate trait default method Xxx::json_call_method match branch
let mut match_trait_json_methods = String::new();
let mut match_method_try_from = String::new();
for (idx, method) in service.methods.iter().enumerate() {
assert!(
!method.client_streaming,
"Client streaming not yet supported for method {}",
method.proto_name
);
assert!(
!method.server_streaming,
"Server streaming not yet supported for method {}",
method.proto_name
);
ServiceGenerator::write_comments(&mut trait_methods, 4, &method.comments).unwrap();
writeln!(
trait_methods,
r#" async fn {name}(&self, ctrl: Self::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}>;"#,
name = method.name,
input_type = method.input_type,
output_type = method.output_type,
namespace = NAMESPACE,
)
.unwrap();
writeln!(
weak_impl_methods,
r#" async fn {method_name}(&self, ctrl: Self::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
let Some(service) = self.upgrade() else {{
return Err({namespace}::error::Error::Shutdown);
}};
service.{method_name}(ctrl, input).await
}}"#,
method_name = method.name,
input_type = method.input_type,
output_type = method.output_type,
namespace = NAMESPACE,
)
.unwrap();
ServiceGenerator::write_comments(&mut enum_methods, 4, &method.comments).unwrap();
writeln!(
enum_methods,
" {name} = {index},",
name = method.proto_name,
index = idx + 1
)
.unwrap();
writeln!(
match_method_try_from,
" {index} => Ok({service_name}MethodDescriptor::{name}),",
service_name = service.name,
name = method.proto_name,
index = idx + 1,
)
.unwrap();
writeln!(
list_enum_methods,
" {service_name}MethodDescriptor::{name},",
service_name = service.name,
name = method.proto_name
)
.unwrap();
writeln!(
client_methods,
r#" async fn {name}(&self, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
{client_name}Client::{name}_inner(self.0.clone(), ctrl, input).await
}}"#,
name = method.name,
input_type = method.input_type,
output_type = method.output_type,
client_name = service.name,
namespace = NAMESPACE,
)
.unwrap();
writeln!(
client_own_methods,
r#" async fn {name}_inner(handler: H, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
{namespace}::__rt::call_method(handler, ctrl, {method_descriptor_name}::{proto_name}, input).await
}}"#,
name = method.name,
method_descriptor_name = method_descriptor_name,
proto_name = method.proto_name,
input_type = method.input_type,
output_type = method.output_type,
namespace = NAMESPACE,
).unwrap();
let case = format!(
" {service_name}MethodDescriptor::{proto_name} => ",
service_name = service.name,
proto_name = method.proto_name
);
writeln!(match_name_methods, "{}{:?},", case, method.name).unwrap();
writeln!(match_proto_name_methods, "{}{:?},", case, method.proto_name).unwrap();
writeln!(
match_input_type_methods,
"{}::std::any::TypeId::of::<{}>(),",
case, method.input_type
)
.unwrap();
writeln!(
match_input_proto_type_methods,
"{}{:?},",
case, method.input_proto_type
)
.unwrap();
writeln!(
match_output_type_methods,
"{}::std::any::TypeId::of::<{}>(),",
case, method.output_type
)
.unwrap();
writeln!(
match_output_proto_type_methods,
"{}{:?},",
case, method.output_proto_type
)
.unwrap();
write!(
match_handle_methods,
r#"{} {{
let decoded: {input_type} = {namespace}::__rt::decode(input)?;
let ret = service.{name}(ctrl, decoded).await?;
{namespace}::__rt::encode(ret)
}}
"#,
case,
input_type = method.input_type,
name = method.name,
namespace = NAMESPACE,
)
.unwrap();
write!(
match_trait_json_methods,
r#" "{name}" | "{proto_name}" => {{
let req: {input_type} = ::serde_json::from_value(json).map_err(|e| {namespace}::error::Error::MalformatRpcPacket(format!("json error: {{}}", e)))?;
let resp = self.{typed_method}(ctrl, req).await?;
Ok(::serde_json::to_value(resp).map_err(|e| {namespace}::error::Error::MalformatRpcPacket(format!("json error: {{}}", e)))?)
}}
"#,
name = method.name,
proto_name = method.proto_name,
input_type = method.input_type,
typed_method = method.name,
namespace = NAMESPACE,
)
.unwrap();
}
ServiceGenerator::write_comments(&mut buf, 0, &service.comments).unwrap();
write!(
buf,
r#"
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait {name} {{
type Controller: {namespace}::controller::Controller;
{trait_methods}
async fn json_call_method(
&self,
ctrl: Self::Controller,
method_name: &str,
json: ::serde_json::Value,
) -> {namespace}::error::Result<::serde_json::Value> {{
match method_name {{
{match_trait_json_methods}
_ => Err({namespace}::error::Error::InvalidMethodIndex(0, method_name.to_string())),
}}
}}
}}
#[async_trait::async_trait]
impl<T> {name} for ::std::sync::Weak<T>
where
T: Send + Sync + 'static,
::std::sync::Arc<T>: {name},
{{
type Controller = <::std::sync::Arc<T> as {name}>::Controller;
{weak_impl_methods}
}}
/// A service descriptor for a `{name}`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Default)]
pub struct {descriptor_name};
/// Methods available on a `{name}`.
///
/// This can be used as a key when routing requests for servers/clients of a `{name}`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum {method_descriptor_name} {{
{enum_methods}
}}
impl std::convert::TryFrom<u8> for {method_descriptor_name} {{
type Error = {namespace}::error::Error;
fn try_from(value: u8) -> {namespace}::error::Result<Self> {{
match value {{
{match_method_try_from}
_ => Err({namespace}::error::Error::InvalidMethodIndex(value, "{name}".to_string())),
}}
}}
}}
/// A client for a `{name}`.
///
/// This implements the `{name}` trait by dispatching all method calls to the supplied `Handler`.
#[derive(Clone, Debug)]
pub struct {client_name}<H>(H) where H: {namespace}::handler::Handler;
impl<H> {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
/// Creates a new client instance that delegates all method calls to the supplied handler.
pub fn new(handler: H) -> {client_name}<H> {{
{client_name}(handler)
}}
}}
impl<H> {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
{client_own_methods}
}}
#[async_trait::async_trait]
impl<H> {name} for {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
type Controller = H::Controller;
{client_methods}
}}
pub struct {client_name}Factory<C: {namespace}::controller::Controller>(std::marker::PhantomData<C>);
impl<C: {namespace}::controller::Controller> Clone for {client_name}Factory<C> {{
fn clone(&self) -> Self {{
Self(std::marker::PhantomData)
}}
}}
impl<C> {namespace}::__rt::RpcClientFactory for {client_name}Factory<C> where C: {namespace}::controller::Controller {{
type Descriptor = {descriptor_name};
type ClientImpl = Box<dyn {name}<Controller = C> + Send + Sync + 'static>;
type Controller = C;
fn new(handler: impl {namespace}::handler::Handler<Descriptor = Self::Descriptor, Controller = Self::Controller>) -> Self::ClientImpl {{
Box::new({client_name}::new(handler))
}}
}}
/// A server for a `{name}`.
///
/// This implements the `Server` trait by handling requests and dispatch them to methods on the
/// supplied `{name}`.
#[derive(Clone, Debug)]
pub struct {server_name}<A>(A) where A: {name} + Clone + Send + 'static;
impl<T> {server_name}<::std::sync::Weak<T>>
where
T: Send + Sync + 'static,
::std::sync::Arc<T>: {name},
{{
pub fn new_arc(service: ::std::sync::Arc<T>) -> {server_name}<::std::sync::Weak<T>> {{
{server_name}(::std::sync::Arc::downgrade(&service))
}}
}}
impl<A> {server_name}<A> where A: {name} + Clone + Send + 'static {{
/// Creates a new server instance that dispatches all calls to the supplied service.
pub fn new(service: A) -> {server_name}<A> {{
{server_name}(service)
}}
async fn call_inner(
service: A,
method: {method_descriptor_name},
ctrl: A::Controller,
input: ::bytes::Bytes)
-> {namespace}::error::Result<::bytes::Bytes> {{
match method {{
{match_handle_methods}
}}
}}
}}
impl {namespace}::descriptor::ServiceDescriptor for {descriptor_name} {{
type Method = {method_descriptor_name};
fn name(&self) -> &'static str {{ {name:?} }}
fn proto_name(&self) -> &'static str {{ {proto_name:?} }}
fn package(&self) -> &'static str {{ {package:?} }}
fn methods(&self) -> &'static [Self::Method] {{
&[ {list_enum_methods} ]
}}
}}
#[async_trait::async_trait]
impl<A> {namespace}::handler::Handler for {server_name}<A>
where
A: {name} + Clone + Send + Sync + 'static {{
type Descriptor = {descriptor_name};
type Controller = A::Controller;
async fn call(
&self,
ctrl: A::Controller,
method: {method_descriptor_name},
input: ::bytes::Bytes)
-> {namespace}::error::Result<::bytes::Bytes> {{
{server_name}::call_inner(self.0.clone(), method, ctrl, input).await
}}
}}
impl {namespace}::descriptor::MethodDescriptor for {method_descriptor_name} {{
fn name(&self) -> &'static str {{
match *self {{
{match_name_methods}
}}
}}
fn proto_name(&self) -> &'static str {{
match *self {{
{match_proto_name_methods}
}}
}}
fn input_type(&self) -> ::std::any::TypeId {{
match *self {{
{match_input_type_methods}
}}
}}
fn input_proto_type(&self) -> &'static str {{
match *self {{
{match_input_proto_type_methods}
}}
}}
fn output_type(&self) -> ::std::any::TypeId {{
match *self {{
{match_output_type_methods}
}}
}}
fn output_proto_type(&self) -> &'static str {{
match *self {{
{match_output_proto_type_methods}
}}
}}
fn index(&self) -> u8 {{
*self as u8
}}
}}
"#,
name = service.name,
descriptor_name = descriptor_name,
server_name = server_name,
client_name = client_name,
method_descriptor_name = method_descriptor_name,
proto_name = service.proto_name,
package = service.package,
trait_methods = trait_methods,
weak_impl_methods = weak_impl_methods,
enum_methods = enum_methods,
list_enum_methods = list_enum_methods,
client_own_methods = client_own_methods,
client_methods = client_methods,
match_name_methods = match_name_methods,
match_proto_name_methods = match_proto_name_methods,
match_input_type_methods = match_input_type_methods,
match_input_proto_type_methods = match_input_proto_type_methods,
match_output_type_methods = match_output_type_methods,
match_output_proto_type_methods = match_output_proto_type_methods,
match_handle_methods = match_handle_methods,
match_trait_json_methods = match_trait_json_methods,
namespace = NAMESPACE,
).unwrap();
}
}
impl ServiceGenerator {
fn write_comments<W>(
mut write: W,
indent: usize,
comments: &prost_build::Comments,
) -> fmt::Result
where
W: fmt::Write,
{
for comment in &comments.leading {
for line in comment.lines().filter(|s| !s.is_empty()) {
writeln!(write, "{}///{}", " ".repeat(indent), line)?;
}
}
Ok(())
}
}
+4 -4
View File
@@ -11,6 +11,7 @@ keywords = ["vpn", "p2p", "network", "easytier"]
categories = ["network-programming", "command-line-utilities"]
license-file = "LICENSE"
readme = "README.md"
build = "build/main.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -316,15 +317,14 @@ jemalloc-sys = { package = "tikv-jemalloc-sys", version = "0.6.0", features = [
[build-dependencies]
cfg_aliases = "0.2.1"
tonic-build = "0.12"
indoc = "2.0"
globwalk = "0.8.1"
regex = "1"
prost-build = "0.13.5"
prost-wkt-build = "0.6"
easytier-rpc-build = { path = "../easytier-rpc-build", features = [
"internal-namespace",
] }
prost-reflect-build = { version = "0.14.0" }
proc-macro2 = "1"
quote = "1"
thunk-rs = { git = "https://github.com/easytier/thunk.git", default-features = false, features = [
"win7",
] }
+4 -1
View File
@@ -1,3 +1,6 @@
mod rpc;
use crate::rpc::ServiceGenerator;
use cfg_aliases::cfg_aliases;
use prost_wkt_build::{FileDescriptorSet, Message as _};
#[cfg(target_os = "windows")]
@@ -197,7 +200,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.type_attribute("acl.Rule", "#[serde(default)]")
.type_attribute("acl.GroupInfo", "#[serde(default)]")
.field_attribute(".api.manage.NetworkConfig", "#[serde(default)]")
.service_generator(Box::new(easytier_rpc_build::ServiceGenerator::default()))
.service_generator(Box::new(ServiceGenerator::default()))
.btree_map(["."])
.skip_debug([".common.Ipv4Addr", ".common.Ipv6Addr", ".common.UUID"]);
+720
View File
@@ -0,0 +1,720 @@
#![allow(non_snake_case)]
use indoc::formatdoc;
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};
use std::str::FromStr;
fn parse(value: &str) -> TokenStream {
TokenStream::from_str(value)
.unwrap_or_else(|err| panic!("Failed to parse tokens: {} ({})", value, err))
}
fn doc(comments: &prost_build::Comments) -> TokenStream {
let doc = comments
.leading
.iter()
.flat_map(|c| c.lines().filter(|s| !s.is_empty()));
quote! { #( #[doc = #doc] )* }
}
const NAMESPACE: &str = "crate::proto::rpc_types";
struct Method {
index: u8,
doc: TokenStream,
method: Ident,
method_inner: Ident,
method_str: String,
method_proto: Ident,
method_proto_str: String,
Input: TokenStream,
Input_proto_str: String,
Output: TokenStream,
Output_proto_str: String,
}
impl Method {
fn new(index: u8, method: prost_build::Method) -> Self {
assert!(
!method.client_streaming,
"Client streaming not yet supported for method {}",
method.proto_name
);
assert!(
!method.server_streaming,
"Server streaming not yet supported for method {}",
method.proto_name
);
Self {
index,
doc: doc(&method.comments),
method: format_ident!("{}", method.name),
method_inner: format_ident!("{}_inner", method.name),
method_str: method.name,
method_proto: format_ident!("{}", method.proto_name),
method_proto_str: method.proto_name,
Input: parse(&method.input_type),
Input_proto_str: method.input_proto_type,
Output: parse(&method.output_type),
Output_proto_str: method.output_proto_type,
}
}
}
struct Service {
namespace: TokenStream,
doc: TokenStream,
Service: Ident,
ServiceDescriptor: Ident,
ServiceServer: Ident,
ServiceClient: Ident,
ServiceClientFactory: Ident,
ServiceMethodDescriptor: Ident,
Service_str: String,
Service_proto_str: String,
Service_package_str: String,
methods: Vec<Method>,
}
impl Service {
fn new(service: prost_build::Service) -> Self {
let methods = service
.methods
.into_iter()
.enumerate()
.map(|(i, method)| Method::new((i + 1) as u8, method))
.collect();
Self {
namespace: parse(NAMESPACE),
doc: doc(&service.comments),
Service: format_ident!("{}", service.name),
ServiceDescriptor: format_ident!("{}Descriptor", service.name),
ServiceServer: format_ident!("{}Server", service.name),
ServiceClient: format_ident!("{}Client", service.name),
ServiceClientFactory: format_ident!("{}ClientFactory", service.name),
ServiceMethodDescriptor: format_ident!("{}MethodDescriptor", service.name),
Service_str: service.name,
Service_proto_str: service.proto_name,
Service_package_str: service.package,
methods,
}
}
fn trait_Service(&self) -> TokenStream {
let Self {
namespace,
doc,
Service,
methods,
..
} = self;
let match_json_call_method = methods.iter().map(
|Method {
method,
method_str,
method_proto_str,
Input,
..
}| {
quote! {
#method_str | #method_proto_str => {
let req: #Input = ::serde_json::from_value(json)
.map_err(|e| #namespace::error::Error::MalformatRpcPacket(format!("json error: {}", e)))?;
let resp = self.#method(ctrl, req).await?;
Ok(::serde_json::to_value(resp)
.map_err(|e| #namespace::error::Error::MalformatRpcPacket(format!("json error: {}", e)))?)
}
}
},
);
let methods = methods.iter().map(
|Method {
doc,
method,
Input,
Output,
..
}| {
quote! {
#doc
async fn #method(&self, ctrl: Self::Controller, input: #Input) -> #namespace::error::Result<#Output>;
}
},
);
quote! {
#doc
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait #Service {
type Controller: #namespace::controller::Controller;
#(#methods)*
async fn json_call_method(
&self,
ctrl: Self::Controller,
method: &str,
json: ::serde_json::Value,
) -> #namespace::error::Result<::serde_json::Value> {
match method {
#(#match_json_call_method)*
_ => Err(#namespace::error::Error::InvalidMethodIndex(0, method.to_string())),
}
}
}
}
}
fn impl_Service_for_Weak(&self) -> TokenStream {
let Self {
namespace,
Service,
methods,
..
} = self;
let methods = methods.iter().map(
|Method {
method,
Input,
Output,
..
}| {
quote! {
async fn #method(&self, ctrl: Self::Controller, input: #Input) -> #namespace::error::Result<#Output> {
let Some(service) = self.upgrade() else {
return Err(#namespace::error::Error::Shutdown);
};
service.#method(ctrl, input).await
}
}
},
);
quote! {
#[async_trait::async_trait]
impl<T> #Service for ::std::sync::Weak<T>
where
T: Send + Sync + 'static,
::std::sync::Arc<T>: #Service,
{
type Controller = <::std::sync::Arc<T> as #Service>::Controller;
#(#methods)*
}
}
}
fn struct_ServiceDescriptor(&self) -> TokenStream {
let Self {
namespace,
ServiceDescriptor,
ServiceMethodDescriptor,
Service_str,
Service_proto_str,
Service_package_str,
methods,
..
} = self;
let doc = format!("A service descriptor for a `{}`.", Service_str);
let methods = methods.iter().map(|Method { method_proto, .. }| {
quote! { #ServiceMethodDescriptor::#method_proto, }
});
quote! {
#[doc = #doc]
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Default)]
pub struct #ServiceDescriptor;
impl #namespace::descriptor::ServiceDescriptor for #ServiceDescriptor {
type Method = #ServiceMethodDescriptor;
fn name(&self) -> &'static str { #Service_str }
fn proto_name(&self) -> &'static str { #Service_proto_str }
fn package(&self) -> &'static str { #Service_package_str }
fn methods(&self) -> &'static [Self::Method] {
&[ #(#methods)* ]
}
}
}
}
fn enum_ServiceMethodDescriptor(&self) -> TokenStream {
let Self {
ServiceMethodDescriptor,
Service_str,
methods,
..
} = self;
let doc = formatdoc! {"
Methods available on a `{Service_str}`.
This can be used as a key when routing requests for servers/clients of a `{Service_str}`.
"};
let variants = methods.iter().map(
|Method {
method_proto,
index,
..
}| {
quote! { #method_proto = #index, }
},
);
let impl_MethodDescriptor = self.impl_MethodDescriptor_for_ServiceMethodDescriptor();
let impl_TryFrom = self.impl_TryFrom_for_ServiceMethodDescriptor();
quote! {
#[doc = #doc]
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum #ServiceMethodDescriptor {
#(#variants)*
}
#impl_MethodDescriptor
#impl_TryFrom
}
}
fn impl_MethodDescriptor_for_ServiceMethodDescriptor(&self) -> TokenStream {
let Self {
namespace,
ServiceMethodDescriptor,
methods,
..
} = self;
let name = {
let arms = methods.iter().map(
|Method {
method_proto,
method_str,
..
}| {
quote! { #ServiceMethodDescriptor::#method_proto => #method_str, }
},
);
quote! {
fn name(&self) -> &'static str {
match *self {
#(#arms)*
}
}
}
};
let proto_name = {
let arms = methods.iter().map(
|Method {
method_proto,
method_proto_str,
..
}| {
quote! { #ServiceMethodDescriptor::#method_proto => #method_proto_str, }
},
);
quote! {
fn proto_name(&self) -> &'static str {
match *self {
#(#arms)*
}
}
}
};
let input_type = {
let arms = methods.iter().map(|Method { method_proto, Input, .. }| {
quote! { #ServiceMethodDescriptor::#method_proto => ::std::any::TypeId::of::<#Input>(), }
});
quote! {
fn input_type(&self) -> ::std::any::TypeId {
match *self {
#(#arms)*
}
}
}
};
let input_proto_type = {
let arms = methods.iter().map(
|Method {
method_proto,
Input_proto_str,
..
}| {
quote! { #ServiceMethodDescriptor::#method_proto => #Input_proto_str, }
},
);
quote! {
fn input_proto_type(&self) -> &'static str {
match *self {
#(#arms)*
}
}
}
};
let output_type = {
let arms = methods.iter().map(|Method { method_proto, Output, .. }| {
quote! { #ServiceMethodDescriptor::#method_proto => ::std::any::TypeId::of::<#Output>(), }
});
quote! {
fn output_type(&self) -> ::std::any::TypeId {
match *self {
#(#arms)*
}
}
}
};
let output_proto_type = {
let arms = methods.iter().map(
|Method {
method_proto,
Output_proto_str,
..
}| {
quote! { #ServiceMethodDescriptor::#method_proto => #Output_proto_str, }
},
);
quote! {
fn output_proto_type(&self) -> &'static str {
match *self {
#(#arms)*
}
}
}
};
quote! {
impl #namespace::descriptor::MethodDescriptor for #ServiceMethodDescriptor {
#name
#proto_name
#input_type
#input_proto_type
#output_type
#output_proto_type
fn index(&self) -> u8 {
*self as u8
}
}
}
}
fn impl_TryFrom_for_ServiceMethodDescriptor(&self) -> TokenStream {
let Self {
namespace,
ServiceMethodDescriptor,
Service_str,
methods,
..
} = self;
let arms = methods.iter().map(
|Method {
method_proto,
index,
..
}| {
quote! { #index => Ok(#ServiceMethodDescriptor::#method_proto), }
},
);
quote! {
impl std::convert::TryFrom<u8> for #ServiceMethodDescriptor {
type Error = #namespace::error::Error;
fn try_from(value: u8) -> #namespace::error::Result<Self> {
match value {
#(#arms)*
_ => Err(#namespace::error::Error::InvalidMethodIndex(value, #Service_str.to_string())),
}
}
}
}
}
fn struct_ServiceClient(&self) -> TokenStream {
let Self {
namespace,
ServiceDescriptor,
ServiceClient,
Service_str,
..
} = self;
let doc = formatdoc! {"
A client for a `{Service_str}`.
This implements the `{Service_str}` trait by dispatching all method calls to the supplied `Handler`.
"};
let impl_service_client = self.impl_ServiceClient();
let impl_service_for_client = self.impl_Service_for_ServiceClient();
quote! {
#[doc = #doc]
#[derive(Clone, Debug)]
pub struct #ServiceClient<H>(H) where H: #namespace::handler::Handler;
impl<H> #ServiceClient<H> where H: #namespace::handler::Handler<Descriptor = #ServiceDescriptor> {
/// Creates a new client instance that delegates all method calls to the supplied handler.
pub fn new(handler: H) -> Self {
Self(handler)
}
}
#impl_service_client
#impl_service_for_client
}
}
fn impl_ServiceClient(&self) -> TokenStream {
let Self {
namespace,
ServiceClient,
ServiceDescriptor,
ServiceMethodDescriptor,
methods,
..
} = self;
let methods = methods.iter().map(
|Method {
method_inner,
method_proto,
Input,
Output,
..
}| {
quote! {
async fn #method_inner(handler: H, ctrl: H::Controller, input: #Input) -> #namespace::error::Result<#Output> {
#namespace::__rt::call_method(handler, ctrl, #ServiceMethodDescriptor::#method_proto, input).await
}
}
},
);
quote! {
impl<H> #ServiceClient<H> where H: #namespace::handler::Handler<Descriptor = #ServiceDescriptor> {
#(#methods)*
}
}
}
fn impl_Service_for_ServiceClient(&self) -> TokenStream {
let Self {
namespace,
Service,
ServiceClient,
ServiceDescriptor,
methods,
..
} = self;
let methods = methods.iter().map(
|Method {
method,
method_inner,
Input,
Output,
..
}| {
quote! {
async fn #method(&self, ctrl: H::Controller, input: #Input) -> #namespace::error::Result<#Output> {
#ServiceClient::#method_inner(self.0.clone(), ctrl, input).await
}
}
},
);
quote! {
#[async_trait::async_trait]
impl<H> #Service for #ServiceClient<H> where H: #namespace::handler::Handler<Descriptor = #ServiceDescriptor> {
type Controller = H::Controller;
#(#methods)*
}
}
}
fn struct_ServiceClientFactory(&self) -> TokenStream {
let Self {
namespace,
Service,
ServiceClient,
ServiceClientFactory,
ServiceDescriptor,
..
} = self;
quote! {
pub struct #ServiceClientFactory<C: #namespace::controller::Controller>(std::marker::PhantomData<C>);
impl<C: #namespace::controller::Controller> Clone for #ServiceClientFactory<C> {
fn clone(&self) -> Self {
Self(std::marker::PhantomData)
}
}
impl<C> #namespace::__rt::RpcClientFactory for #ServiceClientFactory<C> where C: #namespace::controller::Controller {
type Descriptor = #ServiceDescriptor;
type ClientImpl = Box<dyn #Service<Controller = C> + Send + Sync + 'static>;
type Controller = C;
fn new(handler: impl #namespace::handler::Handler<Descriptor = Self::Descriptor, Controller = Self::Controller>) -> Self::ClientImpl {
Box::new(#ServiceClient::new(handler))
}
}
}
}
fn struct_ServiceServer(&self) -> TokenStream {
let Self {
namespace,
Service,
ServiceDescriptor,
ServiceServer,
ServiceMethodDescriptor,
Service_str,
methods,
..
} = self;
let doc = formatdoc! {"
A server for a `{Service_str}`.
This implements the `Server` trait by handling requests and dispatch them to methods on the
supplied `{Service_str}`.
"};
let arms = methods.iter().map(
|Method {
method_proto,
method,
Input,
..
}| {
quote! {
#ServiceMethodDescriptor::#method_proto => {
let decoded: #Input = #namespace::__rt::decode(input)?;
let ret = service.#method(ctrl, decoded).await?;
#namespace::__rt::encode(ret)
}
}
},
);
quote! {
#[doc = #doc]
#[derive(Clone, Debug)]
pub struct #ServiceServer<A>(A) where A: #Service + Clone + Send + 'static;
impl<T> #ServiceServer<::std::sync::Weak<T>>
where
T: Send + Sync + 'static,
::std::sync::Arc<T>: #Service,
{
pub fn new_arc(service: ::std::sync::Arc<T>) -> #ServiceServer<::std::sync::Weak<T>> {
#ServiceServer(::std::sync::Arc::downgrade(&service))
}
}
impl<A> #ServiceServer<A> where A: #Service + Clone + Send + 'static {
/// Creates a new server instance that dispatches all calls to the supplied service.
pub fn new(service: A) -> #ServiceServer<A> {
#ServiceServer(service)
}
async fn call_inner(
service: A,
method: #ServiceMethodDescriptor,
ctrl: A::Controller,
input: ::bytes::Bytes)
-> #namespace::error::Result<::bytes::Bytes> {
match method {
#(#arms)*
}
}
}
#[async_trait::async_trait]
impl<A> #namespace::handler::Handler for #ServiceServer<A>
where
A: #Service + Clone + Send + Sync + 'static {
type Descriptor = #ServiceDescriptor;
type Controller = A::Controller;
async fn call(
&self,
ctrl: A::Controller,
method: #ServiceMethodDescriptor,
input: ::bytes::Bytes)
-> #namespace::error::Result<::bytes::Bytes> {
#ServiceServer::call_inner(self.0.clone(), method, ctrl, input).await
}
}
}
}
}
/// The service generator to be used with `prost-build` to generate RPC implementations for
/// `prost-simple-rpc`.
///
/// See the crate-level documentation for more info.
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct ServiceGenerator;
impl prost_build::ServiceGenerator for ServiceGenerator {
fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
let info = Service::new(service);
let trait_Service = info.trait_Service();
let impl_Service_for_Weak = info.impl_Service_for_Weak();
let struct_ServiceDescriptor = info.struct_ServiceDescriptor();
let enum_ServiceMethodDescriptor = info.enum_ServiceMethodDescriptor();
let struct_ServiceClient = info.struct_ServiceClient();
let struct_ServiceClientFactory = info.struct_ServiceClientFactory();
let struct_ServiceServer = info.struct_ServiceServer();
let tokens = quote! {
#trait_Service
#impl_Service_for_Weak
#struct_ServiceDescriptor
#enum_ServiceMethodDescriptor
#struct_ServiceClient
#struct_ServiceClientFactory
#struct_ServiceServer
};
buf.push('\n');
buf.push_str(&tokens.to_string());
buf.push('\n');
}
}