Skip to content

Commit

Permalink
Implement server side named parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
luc65r committed Sep 2, 2021
1 parent 17883c0 commit 6ae633c
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 106 deletions.
5 changes: 1 addition & 4 deletions derive/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ impl DeriveOptions {
options.enable_client = true;
options.enable_server = true;
}
if options.enable_server && options.params_style == ParamStyle::Named {
// This is not allowed at this time
panic!("Server code generation only supports `params = \"positional\"` (default) or `params = \"raw\" at this time.")
}

Ok(options)
}
}
14 changes: 0 additions & 14 deletions derive/src/rpc_trait.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::options::DeriveOptions;
use crate::params_style::ParamStyle;
use crate::rpc_attr::{AttributeKind, PubSubMethodKind, RpcMethodAttribute};
use crate::to_client::generate_client_module;
use crate::to_delegate::{generate_trait_item_method, MethodRegistration, RpcMethod};
Expand All @@ -22,10 +21,6 @@ const MISSING_UNSUBSCRIBE_METHOD_ERR: &str =
"Can't find unsubscribe method, expected a method annotated with `unsubscribe` \
e.g. `#[pubsub(subscription = \"hello\", unsubscribe, name = \"hello_unsubscribe\")]`";

pub const USING_NAMED_PARAMS_WITH_SERVER_ERR: &str =
"`params = \"named\"` can only be used to generate a client (on a trait annotated with #[rpc(client)]). \
At this time the server does not support named parameters.";

const RPC_MOD_NAME_PREFIX: &str = "rpc_impl_";

struct RpcTrait {
Expand Down Expand Up @@ -222,12 +217,6 @@ fn rpc_wrapper_mod_name(rpc_trait: &syn::ItemTrait) -> syn::Ident {
syn::Ident::new(&mod_name, proc_macro2::Span::call_site())
}

fn has_named_params(methods: &[RpcMethod]) -> bool {
methods
.iter()
.any(|method| method.attr.params_style == Some(ParamStyle::Named))
}

pub fn crate_name(name: &str) -> Result<Ident> {
proc_macro_crate::crate_name(name)
.map(|name| Ident::new(&name, Span::call_site()))
Expand Down Expand Up @@ -264,9 +253,6 @@ pub fn rpc_impl(input: syn::Item, options: &DeriveOptions) -> Result<proc_macro2
});
}
if options.enable_server {
if has_named_params(&methods) {
return Err(syn::Error::new_spanned(rpc_trait, USING_NAMED_PARAMS_WITH_SERVER_ERR));
}
let rpc_server_module = generate_server_module(&method_registrations, &rpc_trait, &methods)?;
submodules.push(rpc_server_module);
exports.push(quote! {
Expand Down
125 changes: 76 additions & 49 deletions derive/src/to_delegate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,21 +206,37 @@ impl RpcMethod {
}

fn generate_delegate_closure(&self, is_subscribe: bool) -> Result<proc_macro2::TokenStream> {
let mut param_types: Vec<_> = self
let args = self
.trait_item
.sig
.inputs
.iter()
.cloned()
.filter_map(|arg| match arg {
syn::FnArg::Typed(ty) => Some(*ty.ty),
syn::FnArg::Typed(pat_type) => Some(pat_type),
_ => None,
})
.collect();
.enumerate();

let (special_args, fn_args) = {
// special args are those which are not passed directly via rpc params: metadata, subscriber
let mut special_args = vec![];
let mut fn_args = vec![];

for (i, arg) in args {
if let Some(sarg) = Self::special_arg(i, arg.clone()) {
special_args.push(sarg);
} else {
fn_args.push(arg);
}
}

(special_args, fn_args)
};

let param_types: Vec<_> = fn_args.iter().map(|arg| *arg.ty.clone()).collect();
let arg_names: Vec<_> = fn_args.iter().map(|arg| *arg.pat.clone()).collect();

// special args are those which are not passed directly via rpc params: metadata, subscriber
let special_args = Self::special_args(&param_types);
param_types.retain(|ty| !special_args.iter().any(|(_, sty)| sty == ty));
if param_types.len() > TUPLE_FIELD_NAMES.len() {
return Err(syn::Error::new_spanned(
&self.trait_item,
Expand All @@ -232,28 +248,38 @@ impl RpcMethod {
.take(param_types.len())
.map(|name| ident(name))
.collect());
let param_types = &param_types;
let parse_params = {
// last arguments that are `Option`-s are optional 'trailing' arguments
let trailing_args_num = param_types.iter().rev().take_while(|t| is_option_type(t)).count();

if trailing_args_num != 0 {
self.params_with_trailing(trailing_args_num, param_types, tuple_fields)
} else if param_types.is_empty() {
quote! { let params = params.expect_no_params(); }
} else if self.attr.params_style == Some(ParamStyle::Raw) {
quote! { let params: _jsonrpc_core::Result<_> = Ok((params,)); }
} else if self.attr.params_style == Some(ParamStyle::Positional) {
quote! { let params = params.parse::<(#(#param_types, )*)>(); }
} else {
unimplemented!("Server side named parameters are not implemented");
let parse_params = if param_types.is_empty() {
quote! { let params = params.expect_no_params(); }
} else {
match self.attr.params_style.as_ref().unwrap() {
ParamStyle::Raw => quote! { let params: _jsonrpc_core::Result<_> = Ok((params,)); },
ParamStyle::Positional => {
// last arguments that are `Option`-s are optional 'trailing' arguments
let trailing_args_num = param_types.iter().rev().take_while(|t| is_option_type(t)).count();
if trailing_args_num != 0 {
self.params_with_trailing(trailing_args_num, &param_types, tuple_fields)
} else {
quote! { let params = params.parse::<(#(#param_types, )*)>(); }
}
}
ParamStyle::Named => quote! {
#[derive(serde::Deserialize)]
#[allow(non_camel_case_types)]
struct __Params {
#(
#fn_args,
)*
}
let params = params.parse::<__Params>()
.map(|__Params { #(#arg_names, )* }| (#(#arg_names, )*));
},
}
};

let method_ident = self.ident();
let result = &self.trait_item.sig.output;
let extra_closure_args: &Vec<_> = &special_args.iter().cloned().map(|arg| arg.0).collect();
let extra_method_types: &Vec<_> = &special_args.iter().cloned().map(|arg| arg.1).collect();
let extra_closure_args: Vec<_> = special_args.iter().map(|arg| *arg.pat.clone()).collect();
let extra_method_types: Vec<_> = special_args.iter().map(|arg| *arg.ty.clone()).collect();

let closure_args = quote! { base, params, #(#extra_closure_args), * };
let method_sig = quote! { fn(&Self, #(#extra_method_types, ) * #(#param_types), *) #result };
Expand Down Expand Up @@ -301,34 +327,35 @@ impl RpcMethod {
})
}

fn special_args(param_types: &[syn::Type]) -> Vec<(syn::Ident, syn::Type)> {
let meta_arg = param_types.first().and_then(|ty| {
if *ty == parse_quote!(Self::Metadata) {
Some(ty.clone())
} else {
None
}
});
let subscriber_arg = param_types.get(1).and_then(|ty| {
if let syn::Type::Path(path) = ty {
if path.path.segments.iter().any(|s| s.ident == SUBSCRIBER_TYPE_IDENT) {
Some(ty.clone())
} else {
None
fn special_arg(index: usize, arg: syn::PatType) -> Option<syn::PatType> {
match index {
0 if arg.ty == parse_quote!(Self::Metadata) => Some(syn::PatType {
pat: Box::new(syn::Pat::Ident(syn::PatIdent {
attrs: vec![],
by_ref: None,
mutability: None,
ident: ident(METADATA_CLOSURE_ARG),
subpat: None,
})),
..arg
}),
1 => match *arg.ty {
syn::Type::Path(ref path) if path.path.segments.iter().any(|s| s.ident == SUBSCRIBER_TYPE_IDENT) => {
Some(syn::PatType {
pat: Box::new(syn::Pat::Ident(syn::PatIdent {
attrs: vec![],
by_ref: None,
mutability: None,
ident: ident(SUBSCRIBER_CLOSURE_ARG),
subpat: None,
})),
..arg
})
}
} else {
None
}
});

let mut special_args = Vec::new();
if let Some(meta) = meta_arg {
special_args.push((ident(METADATA_CLOSURE_ARG), meta));
}
if let Some(subscriber) = subscriber_arg {
special_args.push((ident(SUBSCRIBER_CLOSURE_ARG), subscriber));
_ => None,
},
_ => None,
}
special_args
}

fn params_with_trailing(
Expand Down
101 changes: 95 additions & 6 deletions derive/tests/macros.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use jsonrpc_core::types::params::Params;
use jsonrpc_core::{IoHandler, Response};
use jsonrpc_derive::rpc;
use serde;
use serde_json;

pub enum MyError {}
Expand All @@ -14,6 +15,8 @@ type Result<T> = ::std::result::Result<T, MyError>;

#[rpc]
pub trait Rpc {
type Metadata;

/// Returns a protocol version.
#[rpc(name = "protocolVersion")]
fn protocol_version(&self) -> Result<String>;
Expand All @@ -30,6 +33,18 @@ pub trait Rpc {
#[rpc(name = "raw", params = "raw")]
fn raw(&self, params: Params) -> Result<String>;

/// Adds two numbers and returns a result.
#[rpc(name = "named_add", params = "named")]
fn named_add(&self, a: u64, b: u64) -> Result<u64>;

/// Adds one or two numbers and returns a result.
#[rpc(name = "option_named_add", params = "named")]
fn option_named_add(&self, a: u64, b: Option<u64>) -> Result<u64>;

/// Adds two numbers and returns a result.
#[rpc(meta, name = "meta_named_add", params = "named")]
fn meta_named_add(&self, meta: Self::Metadata, a: u64, b: u64) -> Result<u64>;

/// Handles a notification.
#[rpc(name = "notify")]
fn notify(&self, a: u64);
Expand All @@ -39,6 +54,8 @@ pub trait Rpc {
struct RpcImpl;

impl Rpc for RpcImpl {
type Metadata = Metadata;

fn protocol_version(&self) -> Result<String> {
Ok("version1".into())
}
Expand All @@ -55,14 +72,30 @@ impl Rpc for RpcImpl {
Ok("OK".into())
}

fn named_add(&self, a: u64, b: u64) -> Result<u64> {
Ok(a + b)
}

fn option_named_add(&self, a: u64, b: Option<u64>) -> Result<u64> {
Ok(a + b.unwrap_or_default())
}

fn meta_named_add(&self, _meta: Self::Metadata, a: u64, b: u64) -> Result<u64> {
Ok(a + b)
}

fn notify(&self, a: u64) {
println!("Received `notify` with value: {}", a);
}
}

#[derive(Clone, Default)]
struct Metadata;
impl jsonrpc_core::Metadata for Metadata {}

#[test]
fn should_accept_empty_array_as_no_params() {
let mut io = IoHandler::new();
let mut io = IoHandler::default();
let rpc = RpcImpl::default();
io.extend_with(rpc.to_delegate());

Expand Down Expand Up @@ -94,7 +127,7 @@ fn should_accept_empty_array_as_no_params() {

#[test]
fn should_accept_single_param() {
let mut io = IoHandler::new();
let mut io = IoHandler::default();
let rpc = RpcImpl::default();
io.extend_with(rpc.to_delegate());

Expand All @@ -120,7 +153,7 @@ fn should_accept_single_param() {

#[test]
fn should_accept_multiple_params() {
let mut io = IoHandler::new();
let mut io = IoHandler::default();
let rpc = RpcImpl::default();
io.extend_with(rpc.to_delegate());

Expand All @@ -146,7 +179,7 @@ fn should_accept_multiple_params() {

#[test]
fn should_use_method_name_aliases() {
let mut io = IoHandler::new();
let mut io = IoHandler::default();
let rpc = RpcImpl::default();
io.extend_with(rpc.to_delegate());

Expand Down Expand Up @@ -187,7 +220,7 @@ fn should_use_method_name_aliases() {

#[test]
fn should_accept_any_raw_params() {
let mut io = IoHandler::new();
let mut io = IoHandler::default();
let rpc = RpcImpl::default();
io.extend_with(rpc.to_delegate());

Expand Down Expand Up @@ -222,9 +255,65 @@ fn should_accept_any_raw_params() {
assert_eq!(expected, result4);
}

#[test]
fn should_accept_named_params() {
let mut io = IoHandler::default();
let rpc = RpcImpl::default();
io.extend_with(rpc.to_delegate());

// when
let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"named_add","params":{"a":1,"b":2}}"#;
let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"named_add","params":{"b":2,"a":1}}"#;

let res1 = io.handle_request_sync(req1);
let res2 = io.handle_request_sync(req2);

let expected = r#"{
"jsonrpc": "2.0",
"result": 3,
"id": 1
}"#;
let expected: Response = serde_json::from_str(expected).unwrap();

// then
let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap();
assert_eq!(expected, result1);

let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap();
assert_eq!(expected, result2);
}

#[test]
fn should_accept_option_named_params() {
let mut io = IoHandler::default();
let rpc = RpcImpl::default();
io.extend_with(rpc.to_delegate());

// when
let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"option_named_add","params":{"a":1,"b":2}}"#;
let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"option_named_add","params":{"a":3}}"#;

let res1 = io.handle_request_sync(req1);
let res2 = io.handle_request_sync(req2);

let expected = r#"{
"jsonrpc": "2.0",
"result": 3,
"id": 1
}"#;
let expected: Response = serde_json::from_str(expected).unwrap();

// then
let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap();
assert_eq!(expected, result1);

let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap();
assert_eq!(expected, result2);
}

#[test]
fn should_accept_only_notifications() {
let mut io = IoHandler::new();
let mut io = IoHandler::default();
let rpc = RpcImpl::default();
io.extend_with(rpc.to_delegate());

Expand Down

0 comments on commit 6ae633c

Please sign in to comment.