Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

give primitive types a faster key&value compare, copy, and zero methods in TypedDict #9520

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions docs/upcoming_changes/9520.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Faster compare, copy, and zero methods for primitive types in dict
----------------------------------------------------------------------

Primitive key and value types now get faster, type-specialized compare, copy,
and zero methods in TypedDict. This is a performance optimization for TypedDict.
60 changes: 38 additions & 22 deletions numba/cext/dictobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ set_index(NB_DictKeys *dk, Py_ssize_t i, Py_ssize_t ix)

/* INV_USABLE_FRACTION gives the inverse of USABLE_FRACTION.
* Used for sizing a new dictionary to a specified number of keys.
*
*
* NOTE: If the denominator of the USABLE_FRACTION ratio is not a power
* of 2, must add 1 to the result of the inverse for correct sizing.
*
*
* For example, when USABLE_FRACTION ratio = 5/8 (8 is a power of 2):
* #define INV_USABLE_FRACTION(n) (((n) << 3)/5) // inv_ratio: 8/5
*
*
* When USABLE_FRACTION ratio = 5/7 (7 is not a power of 2):
* #define INV_USABLE_FRACTION(n) ((7*(n))/5 + 1) // inv_ratio: 7/5
*/
Expand Down Expand Up @@ -420,23 +420,39 @@ get_entry(NB_DictKeys *dk, Py_ssize_t idx) {
}

/* Reset the key slot at `data` to its zero state.
 *
 * dk:   keys object; supplies the methods table and the key size.
 * data: start of the key storage to clear.
 *
 * When a type-specific `key_zero` routine is installed in the methods
 * table it is used; otherwise `dk->key_size` bytes are cleared with memset.
 */
static void
key_zero(NB_DictKeys *dk, char *data){
    if ( dk->methods.key_zero ) {
        /* Plain call, not `return f(...)`: the callback returns void and
         * ISO C does not allow a return expression in a void function. */
        dk->methods.key_zero(data);
    } else {
        memset(data, 0, dk->key_size);
    }
}

/* Reset the value slot at `data` to its zero state.
 *
 * dk:   keys object; supplies the methods table and the value size.
 * data: start of the value storage to clear.
 *
 * When a type-specific `value_zero` routine is installed in the methods
 * table it is used; otherwise `dk->val_size` bytes are cleared with memset.
 */
static void
value_zero(NB_DictKeys *dk, char *data){
    if ( dk->methods.value_zero ) {
        /* Plain call, not `return f(...)`: the callback returns void and
         * ISO C does not allow a return expression in a void function. */
        dk->methods.value_zero(data);
    } else {
        memset(data, 0, dk->val_size);
    }
}

/* Copy one key from `src` into `dst`.
 *
 * dk:  keys object; supplies the methods table and the key size.
 * dst: destination key storage.
 * src: source key bytes (not modified).
 *
 * When a type-specific `key_copy` routine is installed in the methods
 * table it is used; otherwise `dk->key_size` bytes are copied with memcpy,
 * so on that path `dst` and `src` must not overlap.
 */
static void
key_copy(NB_DictKeys *dk, char *dst, const char *src){
    if ( dk->methods.key_copy ) {
        dk->methods.key_copy(dst, src);
    } else {
        memcpy(dst, src, dk->key_size);
    }
}

/* Copy one value from `src` into `dst`.
 *
 * dk:  keys object; supplies the methods table and the value size.
 * dst: destination value storage.
 * src: source value bytes (not modified).
 *
 * When a type-specific `value_copy` routine is installed in the methods
 * table it is used; otherwise `dk->val_size` bytes are copied with memcpy,
 * so on that path `dst` and `src` must not overlap.
 */
static void
value_copy(NB_DictKeys *dk, char *dst, const char *src){
    if ( dk->methods.value_copy ) {
        dk->methods.value_copy(dst, src);
    } else {
        memcpy(dst, src, dk->val_size);
    }
}

/* Returns -1 for error; 0 for not equal; 1 for equal */
Expand Down Expand Up @@ -641,7 +657,7 @@ numba_dict_lookup(NB_Dict *d, const char *key_bytes, Py_hash_t hash, char *oldva
for (;;) {
Py_ssize_t ix = get_index(dk, i);
if (ix == DKIX_EMPTY) {
zero_val(dk, oldval_bytes);
value_zero(dk, oldval_bytes);
return ix;
}
if (ix >= 0) {
Expand All @@ -659,7 +675,7 @@ numba_dict_lookup(NB_Dict *d, const char *key_bytes, Py_hash_t hash, char *oldva
}
if (cmp > 0) {
// key is equal; retrieve the value.
copy_val(dk, oldval_bytes, entry_get_val(dk, ep));
value_copy(dk, oldval_bytes, entry_get_val(dk, ep));
return ix;
}
}
Expand Down Expand Up @@ -734,10 +750,10 @@ numba_dict_insert(
hashpos = find_empty_slot(dk, hash);
ep = get_entry(dk, dk->nentries);
set_index(dk, hashpos, dk->nentries);
copy_key(dk, entry_get_key(dk, ep), key_bytes);
key_copy(dk, entry_get_key(dk, ep), key_bytes);
assert ( hash != -1 );
ep->hash = hash;
copy_val(dk, entry_get_val(dk, ep), val_bytes);
value_copy(dk, entry_get_val(dk, ep), val_bytes);

/* incref */
dk_incref_key(dk, key_bytes);
Expand All @@ -753,7 +769,7 @@ numba_dict_insert(
/* decref old value */
dk_decref_val(dk, oldval_bytes);
// Replace the previous value
copy_val(dk, entry_get_val(dk, get_entry(dk, ix)), val_bytes);
value_copy(dk, entry_get_val(dk, get_entry(dk, ix)), val_bytes);

/* incref */
dk_incref_val(dk, val_bytes);
Expand Down Expand Up @@ -893,8 +909,8 @@ numba_dict_delitem(NB_Dict *d, Py_hash_t hash, Py_ssize_t ix)
dk_decref_val(dk, entry_get_val(dk, ep));

/* zero the entries */
zero_key(dk, entry_get_key(dk, ep));
zero_val(dk, entry_get_val(dk, ep));
key_zero(dk, entry_get_key(dk, ep));
value_zero(dk, entry_get_val(dk, ep));
ep->hash = DKIX_EMPTY; // to mark it as empty;

return OK;
Expand Down Expand Up @@ -931,11 +947,11 @@ numba_dict_popitem(NB_Dict *d, char *key_bytes, char *val_bytes)
key_ptr = entry_get_key(d->keys, ep);
val_ptr = entry_get_val(d->keys, ep);

copy_key(d->keys, key_bytes, key_ptr);
copy_val(d->keys, val_bytes, val_ptr);
key_copy(d->keys, key_bytes, key_ptr);
value_copy(d->keys, val_bytes, val_ptr);

zero_key(d->keys, key_ptr);
zero_val(d->keys, val_ptr);
key_zero(d->keys, key_ptr);
value_zero(d->keys, val_ptr);

/* We can't dk_usable++ since there is DKIX_DUMMY in indices */
d->keys->nentries = i;
Expand Down
9 changes: 9 additions & 0 deletions numba/cext/dictobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,24 @@ typedef struct {


typedef int (*dict_key_comparator_t)(const char *lhs, const char *rhs);
typedef void (*dict_copy_op_t)(char *dst, const char *src);
typedef void (*dict_zero_op_t)(char *data);
typedef void (*dict_refcount_op_t)(const void*);


/* Per-dictionary table of type-specific callbacks.
 *
 * The field order is significant: dictobject.py builds a matching struct of
 * nine pointers in this same order, so the two must be changed together.
 */
typedef struct {
/* These five slots apply to every key/value type.
 * A NULL copy or zero slot makes dictobject.c fall back to memcpy/memset
 * over key_size / val_size bytes. */
dict_key_comparator_t key_equal;
dict_copy_op_t key_copy;
dict_zero_op_t key_zero;
dict_copy_op_t value_copy;
dict_zero_op_t value_zero;
/* Reference-count hooks, needed only when the key or value is a
 * container type. */
dict_refcount_op_t key_incref;
dict_refcount_op_t key_decref;
dict_refcount_op_t value_incref;
dict_refcount_op_t value_decref;

} type_based_methods_table;


Expand Down
2 changes: 1 addition & 1 deletion numba/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def from_argument(self, builder, val):
@register_default(types.Boolean)
@register_default(types.BooleanLiteral)
class BooleanModel(DataModel):
_bit_type = ir.IntType(1)
be_type = _bit_type = ir.IntType(1)
dlee992 marked this conversation as resolved.
Show resolved Hide resolved
_byte_type = ir.IntType(8)

def get_value_type(self):
Expand Down
115 changes: 89 additions & 26 deletions numba/typed/dictobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
from numba.typed.typedobjectutils import (_as_bytes, _cast, _nonoptional,
_sentry_safe_cast_default,
_get_incref_decref,
_get_equal, _container_get_data,)
_get_container_equal,
_container_get_data,
_get_primitive_equal,
_get_copy, _get_zero,)

ll_dict_type = cgutils.voidptr_t
ll_dictiter_type = cgutils.voidptr_t
Expand Down Expand Up @@ -262,46 +265,57 @@ def codegen(context, builder, sig, args):

@intrinsic
def _dict_set_method_table(typingctx, dp, keyty, valty):
"""Wrap numba_dict_set_method_table
"""
"""Wrap numba_dict_set_method_table"""
resty = types.void
sig = resty(dp, keyty, valty)

def codegen(context, builder, sig, args):
vtablety = ir.LiteralStructType([
ll_voidptr_type, # equal
ll_voidptr_type, # key incref
ll_voidptr_type, # key decref
ll_voidptr_type, # val incref
ll_voidptr_type, # val decref
])
# correspond to type_based_methods_table in dictobject.h
vtablety = ir.LiteralStructType(
[
ll_voidptr_type, # key equal
ll_voidptr_type, # key copy
ll_voidptr_type, # key zero
ll_voidptr_type, # value copy
ll_voidptr_type, # value zero
ll_voidptr_type, # key incref
ll_voidptr_type, # key decref
ll_voidptr_type, # value incref
ll_voidptr_type, # value decref
]
)
setmethod_fnty = ir.FunctionType(
ir.VoidType(),
[ll_dict_type, vtablety.as_pointer()]
ir.VoidType(), [ll_dict_type, vtablety.as_pointer()]
)
setmethod_fn = ir.Function(
builder.module,
setmethod_fnty,
name='numba_dict_set_method_table',
name="numba_dict_set_method_table",
)
dp = args[0]
vtable = cgutils.alloca_once(builder, vtablety, zfill=True)

# install key incref/decref
# install type_based_methods_table
key_equal_ptr = cgutils.gep_inbounds(builder, vtable, 0, 0)
key_incref_ptr = cgutils.gep_inbounds(builder, vtable, 0, 1)
key_decref_ptr = cgutils.gep_inbounds(builder, vtable, 0, 2)
val_incref_ptr = cgutils.gep_inbounds(builder, vtable, 0, 3)
val_decref_ptr = cgutils.gep_inbounds(builder, vtable, 0, 4)
key_copy_ptr = cgutils.gep_inbounds(builder, vtable, 0, 1)
key_zero_ptr = cgutils.gep_inbounds(builder, vtable, 0, 2)
value_copy_ptr = cgutils.gep_inbounds(builder, vtable, 0, 3)
value_zero_ptr = cgutils.gep_inbounds(builder, vtable, 0, 4)
# install inc/dec refs
key_incref_ptr = cgutils.gep_inbounds(builder, vtable, 0, 5)
key_decref_ptr = cgutils.gep_inbounds(builder, vtable, 0, 6)
value_incref_ptr = cgutils.gep_inbounds(builder, vtable, 0, 7)
value_decref_ptr = cgutils.gep_inbounds(builder, vtable, 0, 8)

dm_key = context.data_model_manager[keyty.instance_type]
if dm_key.contains_nrt_meminfo():
equal = _get_equal(context, builder.module, dm_key, 'dict_key')
key_equal = _get_container_equal(
context, builder.module, dm_key)
key_incref, key_decref = _get_incref_decref(
context, builder.module, dm_key, 'dict_key'
context, builder.module, dm_key, "dict_key"
)
builder.store(
builder.bitcast(equal, key_equal_ptr.type.pointee),
builder.bitcast(key_equal, key_equal_ptr.type.pointee),
key_equal_ptr,
)
builder.store(
Expand All @@ -313,18 +327,67 @@ def codegen(context, builder, sig, args):
key_decref_ptr,
)

if isinstance(dm_key, models.PrimitiveModel):
# dm_key is a primitive type, e.g., int64
# generate key_equal, key_copy, key_zero
key_equal = _get_primitive_equal(
context,
builder.module,
dm_key,
)
key_copy = _get_copy(
context,
builder.module,
dm_key,
"dict_key",
)
key_zero = _get_zero(
context,
builder.module,
dm_key,
"dict_key",
)
builder.store(
builder.bitcast(key_equal, key_equal_ptr.type.pointee),
key_equal_ptr,
)
builder.store(
builder.bitcast(key_copy, key_copy_ptr.type.pointee),
key_copy_ptr,
)
builder.store(
builder.bitcast(key_zero, key_zero_ptr.type.pointee),
key_zero_ptr,
)

dm_val = context.data_model_manager[valty.instance_type]
if dm_val.contains_nrt_meminfo():
val_incref, val_decref = _get_incref_decref(
context, builder.module, dm_val, 'dict_value'
context, builder.module, dm_val, "dict_value"
)
builder.store(
builder.bitcast(val_incref, val_incref_ptr.type.pointee),
val_incref_ptr,
builder.bitcast(val_incref, value_incref_ptr.type.pointee),
value_incref_ptr,
)
builder.store(
builder.bitcast(val_decref, value_decref_ptr.type.pointee),
value_decref_ptr,
)

if isinstance(dm_val, models.PrimitiveModel):
# dm_val doesn't contain meminfo, is a primitive type, e.g., int64
# generate value_copy, value_zero
value_copy = _get_copy(
context, builder.module, dm_val, "dict_value")
value_zero = _get_zero(
context, builder.module, dm_val, "dict_value")
builder.store(
builder.bitcast(value_copy, value_copy_ptr.type.pointee),
value_copy_ptr,
)
builder.store(
builder.bitcast(val_decref, val_decref_ptr.type.pointee),
val_decref_ptr,
builder.bitcast(value_zero, value_zero_ptr.type.pointee),
value_zero_ptr,
)

builder.call(setmethod_fn, [dp, vtable])
Expand Down