Skip to content

Commit

Permalink
Experiment with rate limiting at the membrane layer
Browse files Browse the repository at this point in the history
  • Loading branch information
jerel committed Mar 7, 2024
1 parent f958b8d commit 9c7072c
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 8 deletions.
42 changes: 42 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions dart_example/test/main_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -506,4 +506,17 @@ void main() {
id: Filter(value: [Match(field: "id", value: "1")]),
withinGdpr: GDPR(value: true));
});

test('test that functions can be rate limited', () async {
final contact =
Contact(id: 1, fullName: "Alice Smith", status: Status.pending);
final accounts = AccountsApi();

assert(await accounts.rateLimitedFunction(contact: contact) ==
contact.fullName);
expect(() async => await accounts.rateLimitedFunction(contact: contact),
throwsA(isA<MembraneRateLimited>()));
expect(() async => await accounts.rateLimitedFunction(contact: contact),
throwsA(isA<MembraneRateLimited>()));
});
}
2 changes: 2 additions & 0 deletions example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ skip-codegen = ["membrane/skip-generate"]

[dependencies]
async-stream = "0.3"
derated = {path = "../../derated"}
futures = "0.3"
membrane = {path = "../membrane"}
once_cell = "*"
serde = {version = "1.0", features = ["derive"]}
serde_bytes = "0.11"
tokio = {version = "1", features = ["full"]}
Expand Down
29 changes: 29 additions & 0 deletions example/src/application/advanced.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use data::OptionsDemo;
use membrane::emitter::{emitter, Emitter, StreamEmitter};
use membrane::{async_dart, sync_dart};
use once_cell::sync::Lazy;
use tokio_stream::Stream;

use std::collections::hash_map::DefaultHasher;
// used for background threading examples
use std::{thread, time::Duration};

Expand Down Expand Up @@ -584,3 +586,30 @@ pub async fn get_org_with_borrowed_type(
pub async fn unused_duplicate_borrows(_id: i64) -> Result<data::Organization, String> {
todo!()
}

struct MyLimit(RateLimit);

impl MyLimit {
fn per_milliseconds(milliseconds: u64, max_queued: Option<u64>) -> Self {
Self(RateLimit::per_milliseconds(milliseconds, max_queued))
}

fn hash_rate_limited_function(&self, fn_name: &str, contact: &data::Contact) -> u64 {
use std::hash::{Hash, Hasher};
let mut s = DefaultHasher::new();
(fn_name, contact.id).hash(&mut s);
s.finish()
}

async fn check(&self, key: &'static str, hash: u64) -> Result<(), derated::Dropped> {
self.0.check(key, hash).await
}
}

use derated::RateLimit;
static RATE_LIMIT: Lazy<MyLimit> = Lazy::new(|| MyLimit::per_milliseconds(100, None));

#[async_dart(namespace = "accounts", rate_limit = RATE_LIMIT)]
pub async fn rate_limited_function(contact: data::Contact) -> Result<String, String> {
Ok(contact.full_name)
}
4 changes: 2 additions & 2 deletions example/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};

#[dart_enum(namespace = "accounts")]
#[dart_enum(namespace = "orgs")]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, Hash)]
pub enum Status {
Pending,
Active,
Expand Down Expand Up @@ -42,7 +42,7 @@ pub struct Mixed {
three: Option<VecWrapper>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, Hash)]
pub struct Contact {
pub id: i64,
pub full_name: String,
Expand Down
4 changes: 4 additions & 0 deletions membrane/src/generators/exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class MembraneRustPanicException extends MembraneException {
class MembraneUnknownResponseVariantException extends MembraneException {
const MembraneUnknownResponseVariantException([String? message]) : super(message);
}
class MembraneRateLimited extends MembraneException {
const MembraneRateLimited([String? message]) : super(message);
}
"#
.to_string()
}
15 changes: 12 additions & 3 deletions membrane/src/generators/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,11 @@ impl Callable for Ffi {
_log.{fine_logger}('Deserializing data from {fn_name}');
}}
final deserializer = BincodeDeserializer(data.asTypedList(length + 8).sublist(8));
if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{
final msgCode = deserializer.deserializeUint8();
if (msgCode == MembraneMsgKind.ok) {{
return {return_de};
}} else if (msgCode == MembraneMsgKind.rateLimited) {{
throw MembraneRateLimited();
}}
throw {class_name}ApiError({error_de});
}} finally {{
Expand All @@ -362,8 +365,11 @@ impl Callable for Ffi {
_log.{fine_logger}('Deserializing data from {fn_name}');
}}
final deserializer = BincodeDeserializer(input as Uint8List);
if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{
final msgCode = deserializer.deserializeUint8();
if (msgCode == MembraneMsgKind.ok) {{
return {return_de};
}} else if (msgCode == MembraneMsgKind.rateLimited) {{
throw MembraneRateLimited();
}}
throw {class_name}ApiError({error_de});
}});
Expand Down Expand Up @@ -394,8 +400,11 @@ impl Callable for Ffi {
_log.{fine_logger}('Deserializing data from {fn_name}');
}}
final deserializer = BincodeDeserializer(await _port.first{timeout} as Uint8List);
if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{
final msgCode = deserializer.deserializeUint8();
if (msgCode == MembraneMsgKind.ok) {{
return {return_de};
}} else if (msgCode == MembraneMsgKind.rateLimited) {{
throw MembraneRateLimited();
}}
throw {class_name}ApiError({error_de});
}} finally {{
Expand Down
5 changes: 4 additions & 1 deletion membrane/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ impl<'a> Membrane {
typedef enum MembraneMsgKind {
Ok,
Error,
RateLimited,
} MembraneMsgKind;
typedef enum MembraneResponseKind {
Expand Down Expand Up @@ -826,6 +827,7 @@ enums:
'Error': 'error'
'Ok': 'ok'
'Panic': 'panic'
'RateLimited': 'rateLimited'
macros:
include:
- __none__
Expand Down Expand Up @@ -1354,10 +1356,11 @@ pub struct MembraneResponse {

#[doc(hidden)]
#[repr(u8)]
#[derive(serde::Serialize)]
#[derive(serde::Serialize, PartialEq)]
pub enum MembraneMsgKind {
Ok,
Error,
RateLimited,
}

#[doc(hidden)]
Expand Down
8 changes: 8 additions & 0 deletions membrane/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ use crate::SourceCodeLocation;
use allo_isolate::Isolate;
use serde::ser::Serialize;

pub fn send_rate_limited(isolate: Isolate) -> bool {
if let Ok(buffer) = crate::bincode::serialize(&(crate::MembraneMsgKind::RateLimited as u8)) {
isolate.post(crate::allo_isolate::ZeroCopyBuffer(buffer))
} else {
false
}
}

pub fn send<T: Serialize, E: Serialize>(isolate: Isolate, result: Result<T, E>) -> bool {
match result {
Ok(value) => {
Expand Down
21 changes: 19 additions & 2 deletions membrane_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ fn to_token_stream(
timeout,
os_thread,
borrow,
rate_limit,
} = options;

let mut functions = TokenStream::new();
Expand Down Expand Up @@ -204,6 +205,18 @@ fn to_token_stream(
let dart_transforms: Vec<String> = DartTransforms::try_from(&inputs)?.into();
let dart_inner_args: Vec<String> = DartArgs::from(&inputs).into();

let rate_limit_condition = if let Some(limiter) = rate_limit {
let hasher_function = Ident::new(&format!("hash_{}", &rust_fn_name), Span::call_site());
quote! {
let ::std::result::Result::Err(err) = {
let hash = #limiter.#hasher_function(#rust_fn_name, #(&#rust_inner_args),*);
#limiter.check(#rust_fn_name, hash).await
}
}
} else {
quote!(false)
};

let return_statement = match output_style {
OutputStyle::EmitterSerialized | OutputStyle::StreamEmitterSerialized if sync => {
syn::Error::new(
Expand Down Expand Up @@ -281,9 +294,13 @@ fn to_token_stream(
OutputStyle::Serialized => quote! {
let membrane_join_handle = crate::RUNTIME.get().info_spawn(
async move {
let result: ::std::result::Result<#output, #error> = #fn_name(#(#rust_inner_args),*).await;
let isolate = ::membrane::allo_isolate::Isolate::new(membrane_port);
::membrane::utils::send::<#output, #error>(isolate, result);
if #rate_limit_condition {
::membrane::utils::send_rate_limited(isolate);
} else {
let result: ::std::result::Result<#output, #error> = #fn_name(#(#rust_inner_args),*).await;
::membrane::utils::send::<#output, #error>(isolate, result);
}
},
::membrane::runtime::Info { name: #rust_fn_name }
);
Expand Down
8 changes: 8 additions & 0 deletions membrane_macro/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) struct Options {
pub timeout: Option<i32>,
pub os_thread: bool,
pub borrow: Vec<String>,
pub rate_limit: Option<syn::Path>,
}

pub(crate) fn extract_options(
Expand Down Expand Up @@ -64,6 +65,13 @@ pub(crate) fn extract_options(
options.disable_logging = val.value();
options
}
Some((ident, syn::Expr::Path(syn::ExprPath { path, .. })))
if ident == "rate_limit" && !sync =>
{
options.rate_limit = Some(path);
options
}
// TODO handle the invalid rate_limit case
Some((
ident,
Lit(ExprLit {
Expand Down

0 comments on commit 9c7072c

Please sign in to comment.