Skip to content

Commit

Permalink
Add Script::load function (#603)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiburt committed Jun 9, 2022
1 parent a188202 commit bb329dd
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 30 deletions.
81 changes: 51 additions & 30 deletions src/script.rs
Expand Up @@ -4,6 +4,7 @@ use sha1_smol::Sha1;
use crate::cmd::cmd;
use crate::connection::ConnectionLike;
use crate::types::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs};
use crate::Cmd;

/// Represents a lua script.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -124,26 +125,15 @@ impl<'a> ScriptInvocation<'a> {
/// Invokes the script and returns the result.
#[inline]
pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
loop {
match cmd("EVALSHA")
.arg(self.script.hash.as_bytes())
.arg(self.keys.len())
.arg(&*self.keys)
.arg(&*self.args)
.query(con)
{
Ok(val) => {
return Ok(val);
}
Err(err) => {
if err.kind() == ErrorKind::NoScriptError {
cmd("SCRIPT")
.arg("LOAD")
.arg(self.script.code.as_bytes())
.query(con)?;
} else {
fail!(err);
}
let eval_cmd = self.eval_cmd();
match eval_cmd.query(con) {
Ok(val) => Ok(val),
Err(err) => {
if err.kind() == ErrorKind::NoScriptError {
self.load_cmd().query(con)?;
eval_cmd.query(con)
} else {
Err(err)
}
}
}
Expand All @@ -157,15 +147,7 @@ impl<'a> ScriptInvocation<'a> {
C: crate::aio::ConnectionLike,
T: FromRedisValue,
{
let mut eval_cmd = cmd("EVALSHA");
eval_cmd
.arg(self.script.hash.as_bytes())
.arg(self.keys.len())
.arg(&*self.keys)
.arg(&*self.args);

let mut load_cmd = cmd("SCRIPT");
load_cmd.arg("LOAD").arg(self.script.code.as_bytes());
let eval_cmd = self.eval_cmd();
match eval_cmd.query_async(con).await {
Ok(val) => {
// Return the value from the script evaluation
Expand All @@ -174,12 +156,51 @@ impl<'a> ScriptInvocation<'a> {
Err(err) => {
// Load the script into Redis if the script hash wasn't there already
if err.kind() == ErrorKind::NoScriptError {
load_cmd.query_async(con).await?;
self.load_cmd().query_async(con).await?;
eval_cmd.query_async(con).await
} else {
Err(err)
}
}
}
}

/// Loads the script and returns the SHA1 of it.
#[inline]
pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
let hash: String = self.load_cmd().query(con)?;

debug_assert_eq!(hash, self.script.hash);

Ok(hash)
}

/// Asynchronously loads the script and returns the SHA1 of it.
#[inline]
#[cfg(feature = "aio")]
pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
where
C: crate::aio::ConnectionLike,
{
let hash: String = self.load_cmd().query_async(con).await?;

debug_assert_eq!(hash, self.script.hash);

Ok(hash)
}

fn load_cmd(&self) -> Cmd {
let mut cmd = cmd("SCRIPT");
cmd.arg("LOAD").arg(self.script.code.as_bytes());
cmd
}

fn eval_cmd(&self) -> Cmd {
let mut cmd = cmd("EVALSHA");
cmd.arg(self.script.hash.as_bytes())
.arg(self.keys.len())
.arg(&*self.keys)
.arg(&*self.args);
cmd
}
}
14 changes: 14 additions & 0 deletions tests/test_async.rs
Expand Up @@ -350,6 +350,20 @@ fn test_script() {
.unwrap();
}

#[test]
#[cfg(feature = "script")]
fn test_script_load() {
let ctx = TestContext::new();
let script = redis::Script::new("return 'Hello World'");

block_on_all(async move {
let mut con = ctx.multiplexed_async_connection().await.unwrap();

let hash = script.prepare_invoke().load_async(&mut con).await.unwrap();
assert_eq!(hash, script.get_hash().to_string());
});
}

#[test]
#[cfg(feature = "script")]
fn test_script_returning_complex_type() {
Expand Down
16 changes: 16 additions & 0 deletions tests/test_async_async_std.rs
Expand Up @@ -286,6 +286,22 @@ fn test_script() {
.unwrap();
}

#[test]
#[cfg(feature = "script")]
fn test_script_load() {
let ctx = TestContext::new();
let mut con = ctx.connection();

let script = redis::Script::new("return 'Hello World'");

block_on_all(async move {
let mut con = ctx.multiplexed_async_connection_async_std().await.unwrap();

let hash = script.prepare_invoke().load_async(&mut con).await.unwrap();
assert_eq!(hash, script.get_hash().to_string());
});
}

#[test]
#[cfg(feature = "script")]
fn test_script_returning_complex_type() {
Expand Down
13 changes: 13 additions & 0 deletions tests/test_basic.rs
Expand Up @@ -723,6 +723,19 @@ fn test_script() {
assert_eq!(response, Ok(("foo".to_string(), 42)));
}

#[test]
#[cfg(feature = "script")]
fn test_script_load() {
let ctx = TestContext::new();
let mut con = ctx.connection();

let script = redis::Script::new("return 'Hello World'");

let hash = script.prepare_invoke().load(&mut con);

assert_eq!(hash, Ok(script.get_hash().to_string()));
}

#[test]
fn test_tuple_args() {
let ctx = TestContext::new();
Expand Down

0 comments on commit bb329dd

Please sign in to comment.