Skip to content

Commit

Permalink
only import enum module once
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelhly committed Oct 12, 2023
1 parent 62c0691 commit da22d87
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/input/input_python.rs
Expand Up @@ -21,7 +21,10 @@ use super::datetime::{
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
EitherTime,
};
use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float,
str_as_int,
};
use super::{
py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
Expand Down Expand Up @@ -765,7 +768,7 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult<Option<Cow<
/// Utility for extracting an enum value, if possible.
fn maybe_as_enum<'a>(v: &'a PyAny) -> Option<&'a PyAny> {
let py = v.py();
let enum_meta_object = py.import("enum").unwrap().getattr("EnumMeta").unwrap().to_object(py);
let enum_meta_object = get_enum_meta_object(py);
let meta_type = v.get_type().get_type();
if meta_type.is(&enum_meta_object) {
v.getattr(intern!(py, "value")).ok()
Expand Down
16 changes: 15 additions & 1 deletion src/input/shared.rs
@@ -1,11 +1,25 @@
use num_bigint::BigInt;
use pyo3::{intern, PyAny, Python};
use pyo3::sync::GILOnceCell;
use pyo3::{intern, Py, PyAny, Python, ToPyObject};

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};

use super::parse_json::{JsonArray, JsonInput};
use super::{EitherFloat, EitherInt, Input};

static ENUM_META_OBJECT: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

pub fn get_enum_meta_object(py: Python) -> Py<PyAny> {
ENUM_META_OBJECT
.get_or_init(py, || {
py.import("enum")
.and_then(|decimal_module| decimal_module.getattr("EnumMeta"))
.unwrap()
.to_object(py)
})
.clone()
}

pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> {
ValError::new(
ErrorType::JsonInvalid {
Expand Down

0 comments on commit da22d87

Please sign in to comment.