Skip to content

Commit

Permalink
fast path for ASCII python strings (#72)
Browse files Browse the repository at this point in the history
Co-authored-by: David Hewitt <david.hewitt@pydantic.dev>
  • Loading branch information
samuelcolvin and davidhewitt committed Apr 2, 2024
1 parent cbf2e30 commit 37f9138
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 33 deletions.
2 changes: 1 addition & 1 deletion jiter-python/Cargo.toml
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"

[dependencies]
pyo3 = { version = "0.21.0-beta.0", features = ["num-bigint", "auto-initialize"] }
pyo3 = { version = "0.21.0", features = ["num-bigint", "auto-initialize"] }
jiter = { path = "..", features = ["python"] }

[features]
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Expand Up @@ -20,6 +20,6 @@ pub use parse::Peek;
pub use value::{JsonArray, JsonObject, JsonValue};

#[cfg(feature = "python")]
pub use py_string_cache::{cache_clear, cache_usage, cached_py_string, StringCacheMode};
pub use py_string_cache::{cache_clear, cache_usage, cached_py_string, pystring_fast_new, StringCacheMode};
#[cfg(feature = "python")]
pub use python::{map_json_error, python_parse};
55 changes: 38 additions & 17 deletions src/py_string_cache.rs
Expand Up @@ -2,6 +2,7 @@ use std::cell::RefCell;

use ahash::random_state::RandomState;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi;
use pyo3::prelude::*;
use pyo3::sync::{GILOnceCell, GILProtected};
use pyo3::types::{PyBool, PyString};
Expand Down Expand Up @@ -45,38 +46,38 @@ impl From<bool> for StringCacheMode {
}

pub trait StringMaybeCache {
fn get_key<'py>(py: Python<'py>, json_str: &str) -> Bound<'py, PyString>;
fn get_key<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString>;

fn get_value<'py>(py: Python<'py>, json_str: &str) -> Bound<'py, PyString> {
Self::get_key(py, json_str)
fn get_value<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
Self::get_key(py, json_str, ascii_only)
}
}

pub struct StringCacheAll;

impl StringMaybeCache for StringCacheAll {
fn get_key<'py>(py: Python<'py>, json_str: &str) -> Bound<'py, PyString> {
cached_py_string(py, json_str)
fn get_key<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
cached_py_string(py, json_str, ascii_only)
}
}

pub struct StringCacheKeys;

impl StringMaybeCache for StringCacheKeys {
fn get_key<'py>(py: Python<'py>, json_str: &str) -> Bound<'py, PyString> {
cached_py_string(py, json_str)
fn get_key<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
cached_py_string(py, json_str, ascii_only)
}

fn get_value<'py>(py: Python<'py>, json_str: &str) -> Bound<'py, PyString> {
PyString::new_bound(py, json_str)
fn get_value<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
pystring_fast_new(py, json_str, ascii_only)
}
}

pub struct StringNoCache;

impl StringMaybeCache for StringNoCache {
fn get_key<'py>(py: Python<'py>, json_str: &str) -> Bound<'py, PyString> {
PyString::new_bound(py, json_str)
fn get_key<'py>(py: Python<'py>, json_str: &str, ascii_only: bool) -> Bound<'py, PyString> {
pystring_fast_new(py, json_str, ascii_only)
}
}

Expand All @@ -98,12 +99,12 @@ pub fn cache_clear(py: Python) {
get_string_cache!(py).borrow_mut().clear()
}

pub fn cached_py_string<'py>(py: Python<'py>, raw_str: &str) -> Bound<'py, PyString> {
pub fn cached_py_string<'py>(py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
// from tests, 0 and 1 character strings are faster not cached
if (2..64).contains(&raw_str.len()) {
get_string_cache!(py).borrow_mut().get_or_insert(py, raw_str)
if (2..64).contains(&s.len()) {
get_string_cache!(py).borrow_mut().get_or_insert(py, s, ascii_only)
} else {
PyString::new_bound(py, raw_str)
pystring_fast_new(py, s, ascii_only)
}
}

Expand Down Expand Up @@ -135,13 +136,13 @@ impl Default for PyStringCache {
impl PyStringCache {
/// Lookup the cache for an entry with the given string. If it exists, return it.
/// If it is not set or has a different string, insert it and return it.
fn get_or_insert<'py>(&mut self, py: Python<'py>, s: &str) -> Bound<'py, PyString> {
fn get_or_insert<'py>(&mut self, py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
let hash = self.hash_builder.hash_one(s);

let hash_index = hash as usize % CAPACITY;

let set_entry = |entry: &mut Entry| {
let py_str = PyString::new_bound(py, s);
let py_str = pystring_fast_new(py, s, ascii_only);
*entry = Some((hash, py_str.to_owned().unbind()));
py_str
};
Expand Down Expand Up @@ -183,3 +184,23 @@ impl PyStringCache {
self.entries.fill(None);
}
}

/// Create a Python `str` from `s`, using a faster construction path when the caller already
/// knows that `s` is pure ASCII.
///
/// * `ascii_only == true`: the bytes of `s` are copied directly into a freshly allocated
///   Python string by `pystring_ascii_new`.
/// * `ascii_only == false`: falls back to `PyString::new_bound`.
///
/// `ascii_only` must only be `true` when every byte of `s` is ASCII. To keep the fast path
/// cheap this is not verified in release builds; debug builds assert it.
pub fn pystring_fast_new<'py>(py: Python<'py>, s: &str, ascii_only: bool) -> Bound<'py, PyString> {
    if ascii_only {
        debug_assert!(s.is_ascii(), "`ascii_only` was set for a string containing non-ASCII bytes");
        // SAFETY: the caller asserts through `ascii_only` that `s` contains only ASCII bytes,
        // which is the precondition of `pystring_ascii_new` (checked above in debug builds).
        unsafe { pystring_ascii_new(py, s) }
    } else {
        PyString::new_bound(py, s)
    }
}

/// Faster creation of PyString from an ASCII string, inspired by
/// https://github.com/ijl/orjson/blob/3.10.0/src/str/create.rs#L41
///
/// Allocates a string object with `PyUnicode_New(len, 127)` and copies the bytes of `s`
/// straight into its data buffer, followed by a trailing NUL byte.
///
/// # Panics
///
/// Panics if `PyUnicode_New` returns a null pointer (e.g. the allocation failed), rather
/// than writing through that null pointer.
///
/// # Safety
///
/// `s` must contain only ASCII bytes: the object is allocated with a maximum character of
/// 127 and the bytes of `s` are copied in without any validation or decoding.
unsafe fn pystring_ascii_new<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
    debug_assert!(s.is_ascii(), "pystring_ascii_new called with a non-ASCII string");
    // a `str` can never be longer than `isize::MAX` bytes, so this cast cannot wrap
    let ptr = ffi::PyUnicode_New(s.len() as isize, 127);
    // Take ownership *before* touching the buffer: `Bound::from_owned_ptr` panics on a null
    // pointer, so a failed allocation is reported instead of being dereferenced below.
    let py_str = Bound::from_owned_ptr(py, ptr);
    // see https://github.com/pydantic/jiter/pull/72#discussion_r1545485907
    debug_assert_eq!(ffi::PyUnicode_KIND(ptr), ffi::PyUnicode_1BYTE_KIND);
    let data_ptr = ffi::PyUnicode_DATA(ptr).cast::<u8>();
    core::ptr::copy_nonoverlapping(s.as_ptr(), data_ptr, s.len());
    // keep the buffer NUL-terminated, one byte past the copied data
    core::ptr::write(data_ptr.add(s.len()), 0);
    py_str.downcast_into_unchecked()
}
6 changes: 3 additions & 3 deletions src/python.rs
Expand Up @@ -85,7 +85,7 @@ impl<'j> PythonParser<'j> {
}
Peek::String => {
let s = self.parser.consume_string::<StringDecoder>(&mut self.tape)?;
Ok(StringCache::get_value(py, s.as_str()).into_any())
Ok(StringCache::get_value(py, s.as_str(), s.ascii_only()).into_any())
}
Peek::Array => {
let peek_first = match self.parser.array_first() {
Expand Down Expand Up @@ -162,12 +162,12 @@ impl<'j> PythonParser<'j> {
}
};
if let Some(first_key) = self.parser.object_first::<StringDecoder>(&mut self.tape)? {
let first_key = StringCache::get_key(py, first_key.as_str());
let first_key = StringCache::get_key(py, first_key.as_str(), first_key.ascii_only());
let peek = self.parser.peek()?;
let first_value = self._check_take_value::<StringCache>(py, peek)?;
set_item(first_key, first_value);
while let Some(key) = self.parser.object_step::<StringDecoder>(&mut self.tape)? {
let key = StringCache::get_key(py, key.as_str());
let key = StringCache::get_key(py, key.as_str(), key.ascii_only());
let peek = self.parser.peek()?;
let value = self._check_take_value::<StringCache>(py, peek)?;
set_item(key, value);
Expand Down
27 changes: 17 additions & 10 deletions src/string_decoder.rs
Expand Up @@ -25,33 +25,40 @@ pub enum StringOutput<'t, 'j>
where
'j: 't,
{
Tape(&'t str),
Data(&'j str),
Tape(&'t str, bool),
Data(&'j str, bool),
}

impl From<StringOutput<'_, '_>> for String {
fn from(val: StringOutput) -> Self {
match val {
StringOutput::Tape(s) => s.to_owned(),
StringOutput::Data(s) => s.to_owned(),
StringOutput::Tape(s, _) => s.to_owned(),
StringOutput::Data(s, _) => s.to_owned(),
}
}
}

impl<'t, 'j> From<StringOutput<'t, 'j>> for Cow<'j, str> {
fn from(val: StringOutput<'t, 'j>) -> Self {
match val {
StringOutput::Tape(s) => s.to_owned().into(),
StringOutput::Data(s) => s.into(),
StringOutput::Tape(s, _) => s.to_owned().into(),
StringOutput::Data(s, _) => s.into(),
}
}
}

impl<'t, 'j> StringOutput<'t, 'j> {
pub fn as_str(&self) -> &'t str {
match self {
Self::Tape(s) => s,
Self::Data(s) => s,
Self::Tape(s, _) => s,
Self::Data(s, _) => s,
}
}

/// Returns the ASCII-only flag stored alongside the string slice, regardless of whether
/// the string lives on the tape or in the original input data.
pub fn ascii_only(&self) -> bool {
    match *self {
        Self::Tape(_, flag) | Self::Data(_, flag) => flag,
    }
}
}
Expand Down Expand Up @@ -143,7 +150,7 @@ where
CharType::Quote => {
let s = to_str(&data[start..index], ascii_only, start)?;
index += 1;
return Ok((StringOutput::Data(s), index));
return Ok((StringOutput::Data(s, ascii_only), index));
}
CharType::Backslash => return decode_to_tape(data, index, tape, start, ascii_only),
CharType::ControlChar => return json_err!(ControlCharacterWhileParsingString, index),
Expand Down Expand Up @@ -204,7 +211,7 @@ fn decode_to_tape<'t, 'j>(
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
let s = to_str(tape, ascii_only, start)?;
return Ok((StringOutput::Tape(s), index));
return Ok((StringOutput::Tape(s, ascii_only), index));
}
CharType::Backslash => on_backslash!(),
CharType::ControlChar => return json_err!(ControlCharacterWhileParsingString, index),
Expand Down
50 changes: 49 additions & 1 deletion tests/python.rs
Expand Up @@ -2,7 +2,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::ToPyObject;

use jiter::{cache_clear, cache_usage, map_json_error, python_parse, JsonValue, StringCacheMode};
use jiter::{cache_clear, cache_usage, map_json_error, pystring_fast_new, python_parse, JsonValue, StringCacheMode};

#[test]
fn test_to_py_object_numeric() {
Expand Down Expand Up @@ -269,3 +269,51 @@ fn test_cache_into() {
);
})
}

/// A string containing a backslash escape must still parse to the expected Python string.
#[test]
fn test_use_tape() {
    Python::with_gil(|py| {
        cache_clear(py);
        let source = r#" "foo\nbar" "#;
        let parsed = python_parse(py, source.as_bytes(), false, StringCacheMode::None, false).unwrap();
        let rendered = parsed.to_string();
        assert_eq!(rendered, "foo\nbar");
    })
}

/// Non-ASCII keys and values must parse correctly with string caching disabled.
#[test]
fn test_unicode() {
    Python::with_gil(|py| {
        cache_clear(py);
        let source = r#"{"💩": "£"}"#;
        let parsed = python_parse(py, source.as_bytes(), false, StringCacheMode::None, false).unwrap();
        let rendered = parsed.to_string();
        assert_eq!(rendered, "{'💩': '£'}");
    })
}

/// Non-ASCII keys and values must parse correctly when all strings go through the cache.
#[test]
fn test_unicode_cache() {
    Python::with_gil(|py| {
        cache_clear(py);
        let source = r#"{"💩": "£"}"#;
        let parsed = python_parse(py, source.as_bytes(), false, StringCacheMode::All, false).unwrap();
        let rendered = parsed.to_string();
        assert_eq!(rendered, "{'💩': '£'}");
    })
}

/// With `ascii_only = false`, a string containing non-ASCII characters round-trips unchanged.
#[test]
fn test_pystring_fast_new_non_ascii() {
    Python::with_gil(|py| {
        let text = "£100 💩";
        let py_str = pystring_fast_new(py, text, false);
        let rendered = py_str.to_string();
        assert_eq!(rendered, "£100 💩");
    })
}

/// With `ascii_only = true`, a pure-ASCII string round-trips unchanged.
#[test]
fn test_pystring_fast_new_ascii() {
    Python::with_gil(|py| {
        let text = "100abc";
        let py_str = pystring_fast_new(py, text, true);
        let rendered = py_str.to_string();
        assert_eq!(rendered, "100abc");
    })
}

0 comments on commit 37f9138

Please sign in to comment.