Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bind macro #1427

Merged
merged 2 commits into from Dec 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions rusqlite-macros/Cargo.toml
Expand Up @@ -14,3 +14,4 @@ proc-macro = true
[dependencies]
sqlite3-parser = { version = "0.12", default-features = false, features = ["YYNOERRORRECOVERY"] }
fallible-iterator = "0.3"
litrs = { version = "0.4", default-features = false }
40 changes: 8 additions & 32 deletions rusqlite-macros/src/lib.rs
@@ -1,6 +1,7 @@
//! Private implementation details of `rusqlite`.

use proc_macro::{Delimiter, Group, Literal, Span, TokenStream, TokenTree};
use litrs::StringLit;
use proc_macro::{Group, Span, TokenStream, TokenTree};

use fallible_iterator::FallibleIterator;
use sqlite3_parser::ast::{ParameterInfo, ToTokens};
Expand All @@ -25,15 +26,12 @@ fn try_bind(input: TokenStream) -> Result<TokenStream> {
(stmt, literal)
};

let literal = match into_literal(&literal) {
Some(it) => it,
None => return Err("expected a plain string literal".to_string()),
let call_site = literal.span();
let string_lit = match StringLit::try_from(literal) {
Ok(string_lit) => string_lit,
Err(e) => return Ok(e.to_compile_error()),
};
let sql = literal.to_string();
if !sql.starts_with('"') {
return Err("expected a plain string literal".to_string());
}
let sql = strip_matches(&sql, "\"");
let sql = string_lit.value();

let mut parser = Parser::new(sql.as_bytes());
let ast = match parser.next() {
Expand All @@ -48,13 +46,12 @@ fn try_bind(input: TokenStream) -> Result<TokenStream> {
return Err(err.to_string());
}
if info.count == 0 {
return Ok(input);
return Ok(TokenStream::new());
}
if info.count as usize != info.names.len() {
return Err("Mixing named and numbered parameters is not supported.".to_string());
}

let call_site = literal.span();
let mut res = TokenStream::new();
for (i, name) in info.names.iter().enumerate() {
res.extend(Some(stmt.clone()));
Expand All @@ -71,27 +68,6 @@ fn try_bind(input: TokenStream) -> Result<TokenStream> {
Ok(res)
}

fn into_literal(ts: &TokenTree) -> Option<Literal> {
match ts {
TokenTree::Literal(l) => Some(l.clone()),
TokenTree::Group(g) => match g.delimiter() {
Delimiter::None => match g.stream().into_iter().collect::<Vec<_>>().as_slice() {
[TokenTree::Literal(l)] => Some(l.clone()),
_ => None,
},
Delimiter::Parenthesis | Delimiter::Brace | Delimiter::Bracket => None,
},
_ => None,
}
}

fn strip_matches<'a>(s: &'a str, pattern: &str) -> &'a str {
s.strip_prefix(pattern)
.unwrap_or(s)
.strip_suffix(pattern)
.unwrap_or(s)
}

fn respan(ts: TokenStream, span: Span) -> TokenStream {
let mut res = TokenStream::new();
for tt in ts {
Expand Down
13 changes: 10 additions & 3 deletions rusqlite-macros/tests/test.rs
Expand Up @@ -20,13 +20,20 @@ fn test_literal() -> Result {
Ok(())
}

/* FIXME
#[test]
fn test_no_placeholder() {
let _stmt = Stmt;
__bind!(_stmt "SELECT 1");
}

#[test]
fn test_raw_string() {
let stmt = ();
__bind!(stmt r#"SELECT 1"#);
let _stmt = Stmt;
__bind!(_stmt r"SELECT 1");
__bind!(_stmt r#"SELECT 1"#);
}

/* FIXME
#[test]
fn test_const() {
const SQL: &str = "SELECT 1";
Expand Down