Skip to content

Commit

Permalink
Merge pull request #19062 from BvB93/ctypes-plugin
Browse files Browse the repository at this point in the history
ENH: Add a mypy plugin for inferring the precision of `np.ctypeslib.c_intp`
  • Loading branch information
charris committed May 28, 2021
2 parents ca28804 + f5a5fdb commit 336cda1
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 34 deletions.
21 changes: 21 additions & 0 deletions doc/release/upcoming_changes/19062.new_feature.rst
@@ -0,0 +1,21 @@
Assign the platform-specific ``c_intp`` precision via a mypy plugin
-------------------------------------------------------------------

The mypy_ plugin, introduced in `numpy/numpy#17843`_, has again been expanded:
the plugin now is now responsible for setting the platform-specific precision
of `numpy.ctypeslib.c_intp`, the latter being used as data type for various
`numpy.ndarray.ctypes` attributes.

Without the plugin, aforementioned type will default to `ctypes.c_int64`.

To enable the plugin, one must add it to their mypy `configuration file`_:

.. code-block:: ini
[mypy]
plugins = numpy.typing.mypy_plugin
.. _mypy: http://mypy-lang.org/
.. _configuration file: https://mypy.readthedocs.io/en/stable/config_file.html
.. _`numpy/numpy#17843`: https://github.com/numpy/numpy/pull/17843
11 changes: 3 additions & 8 deletions numpy/core/_internal.pyi
Expand Up @@ -2,6 +2,7 @@ from typing import Any, TypeVar, Type, overload, Optional, Generic
import ctypes as ct

from numpy import ndarray
from numpy.ctypeslib import c_intp

_CastT = TypeVar("_CastT", bound=ct._CanCastTo) # Copied from `ctypes.cast`
_CT = TypeVar("_CT", bound=ct._CData)
Expand All @@ -15,18 +16,12 @@ class _ctypes(Generic[_PT]):
def __new__(cls, array: ndarray[Any, Any], ptr: None = ...) -> _ctypes[None]: ...
@overload
def __new__(cls, array: ndarray[Any, Any], ptr: _PT) -> _ctypes[_PT]: ...

# NOTE: In practice `shape` and `strides` return one of the concrete
# platform dependant array-types (`c_int`, `c_long` or `c_longlong`)
# corresponding to C's `int_ptr_t`, as determined by `_getintp_ctype`
# TODO: Hook this in to the mypy plugin so that a more appropiate
# `ctypes._SimpleCData[int]` sub-type can be returned
@property
def data(self) -> _PT: ...
@property
def shape(self) -> ct.Array[ct.c_int64]: ...
def shape(self) -> ct.Array[c_intp]: ...
@property
def strides(self) -> ct.Array[ct.c_int64]: ...
def strides(self) -> ct.Array[c_intp]: ...
@property
def _as_parameter_(self) -> ct.c_void_p: ...

Expand Down
9 changes: 5 additions & 4 deletions numpy/ctypeslib.pyi
@@ -1,11 +1,12 @@
from typing import List, Type
from ctypes import _SimpleCData

# NOTE: Numpy's mypy plugin is used for importing the correct
# platform-specific `ctypes._SimpleCData[int]` sub-type
from ctypes import c_int64 as _c_intp

__all__: List[str]

# TODO: Update the `npt.mypy_plugin` such that it substitutes `c_intp` for
# a specific `_SimpleCData[int]` subclass (e.g. `ctypes.c_long`)
c_intp: Type[_SimpleCData[int]]
c_intp = _c_intp

def load_library(libname, loader_path): ...
def ndpointer(dtype=..., ndim=..., shape=..., flags=...): ...
Expand Down
11 changes: 7 additions & 4 deletions numpy/typing/__init__.py
Expand Up @@ -23,15 +23,18 @@
-----------
A mypy_ plugin is distributed in `numpy.typing` for managing a number of
platform-specific annotations. Its function can be split into to parts:
platform-specific annotations. Its functionality can be split into three
distinct parts:
* Assigning the (platform-dependent) precisions of certain `~numpy.number` subclasses,
including the likes of `~numpy.int_`, `~numpy.intp` and `~numpy.longlong`.
See the documentation on :ref:`scalar types <arrays.scalars.built-in>` for a
comprehensive overview of the affected classes. without the plugin the precision
of all relevant classes will be inferred as `~typing.Any`.
comprehensive overview of the affected classes. Without the plugin the
precision of all relevant classes will be inferred as `~typing.Any`.
* Assigning the (platform-dependent) precision of `~numpy.ctypeslib.c_intp`.
Without the plugin aforementioned type will default to `ctypes.c_int64`.
* Removing all extended-precision `~numpy.number` subclasses that are unavailable
for the platform in question. Most notable this includes the likes of
for the platform in question. Most notably, this includes the likes of
`~numpy.float128` and `~numpy.complex256`. Without the plugin *all*
extended-precision types will, as far as mypy is concerned, be available
to all platforms.
Expand Down
62 changes: 47 additions & 15 deletions numpy/typing/mypy_plugin.py
Expand Up @@ -61,13 +61,29 @@ def _get_extended_precision_list() -> t.List[str]:
return [i.__name__ for i in extended_types if i.__name__ in extended_names]


def _get_c_intp_name() -> str:
# Adapted from `np.core._internal._getintp_ctype`
char = np.dtype('p').char
if char == 'i':
return "c_int"
elif char == 'l':
return "c_long"
elif char == 'q':
return "c_longlong"
else:
return "c_long"


#: A dictionary mapping type-aliases in `numpy.typing._nbit` to
#: concrete `numpy.typing.NBitBase` subclasses.
_PRECISION_DICT: t.Final = _get_precision_dict()

#: A list with the names of all extended precision `np.number` subclasses.
_EXTENDED_PRECISION_LIST: t.Final = _get_extended_precision_list()

#: The name of the ctypes quivalent of `np.intp`
_C_INTP: t.Final = _get_c_intp_name()


def _hook(ctx: AnalyzeTypeContext) -> Type:
"""Replace a type-alias with a concrete ``NBitBase`` subclass."""
Expand All @@ -87,8 +103,23 @@ def _index(iterable: t.Iterable[Statement], id: str) -> int:
raise ValueError("Failed to identify a `ImportFrom` instance "
f"with the following id: {id!r}")

def _override_imports(
file: MypyFile,
module: str,
imports: t.List[t.Tuple[str, t.Optional[str]]],
) -> None:
"""Override the first `module`-based import with new `imports`."""
# Construct a new `from module import y` statement
import_obj = ImportFrom(module, 0, names=imports)
import_obj.is_top_level = True

# Replace the first `module`-based import statement with `import_obj`
for lst in [file.defs, file.imports]: # type: t.List[Statement]
i = _index(lst, module)
lst[i] = import_obj

class _NumpyPlugin(Plugin):
"""A plugin for assigning platform-specific `numpy.number` precisions."""
"""A mypy plugin for handling versus numpy-specific typing tasks."""

def get_type_analyze_hook(self, fullname: str) -> None | _HookFunc:
"""Set the precision of platform-specific `numpy.number` subclasses.
Expand All @@ -100,25 +131,26 @@ def get_type_analyze_hook(self, fullname: str) -> None | _HookFunc:
return None

def get_additional_deps(self, file: MypyFile) -> t.List[t.Tuple[int, str, int]]:
"""Import platform-specific extended-precision `numpy.number` subclasses.
"""Handle all import-based overrides.
* Import platform-specific extended-precision `numpy.number`
subclasses (*e.g.* `numpy.float96`, `numpy.float128` and
`numpy.complex256`).
* Import the appropriate `ctypes` equivalent to `numpy.intp`.
For example: `numpy.float96`, `numpy.float128` and `numpy.complex256`.
"""
ret = [(PRI_MED, file.fullname, -1)]

if file.fullname == "numpy":
# Import ONLY the extended precision types available to the
# platform in question
imports = ImportFrom(
"numpy.typing._extended_precision", 0,
names=[(v, v) for v in _EXTENDED_PRECISION_LIST],
_override_imports(
file, "numpy.typing._extended_precision",
imports=[(v, v) for v in _EXTENDED_PRECISION_LIST],
)
elif file.fullname == "numpy.ctypeslib":
_override_imports(
file, "ctypes",
imports=[(_C_INTP, "_c_intp")],
)
imports.is_top_level = True

# Replace the much broader extended-precision import
# (defined in `numpy/__init__.pyi`) with a more specific one
for lst in [file.defs, file.imports]: # type: t.List[Statement]
i = _index(lst, "numpy.typing._extended_precision")
lst[i] = imports
return ret

def plugin(version: str) -> t.Type[_NumpyPlugin]:
Expand Down
3 changes: 3 additions & 0 deletions numpy/typing/tests/data/reveal/ctypeslib.py
@@ -0,0 +1,3 @@
import numpy as np

reveal_type(np.ctypeslib.c_intp()) # E: {c_intp}
4 changes: 2 additions & 2 deletions numpy/typing/tests/data/reveal/ndarray_misc.py
Expand Up @@ -23,8 +23,8 @@ class SubClass(np.ndarray): ...
ctypes_obj = AR_f8.ctypes

reveal_type(ctypes_obj.data) # E: int
reveal_type(ctypes_obj.shape) # E: ctypes.Array[ctypes.c_int64]
reveal_type(ctypes_obj.strides) # E: ctypes.Array[ctypes.c_int64]
reveal_type(ctypes_obj.shape) # E: ctypes.Array[{c_intp}]
reveal_type(ctypes_obj.strides) # E: ctypes.Array[{c_intp}]
reveal_type(ctypes_obj._as_parameter_) # E: ctypes.c_void_p

reveal_type(ctypes_obj.data_as(ct.c_void_p)) # E: ctypes.c_void_p
Expand Down
9 changes: 8 additions & 1 deletion numpy/typing/tests/test_typing.py
Expand Up @@ -8,7 +8,11 @@

import pytest
import numpy as np
from numpy.typing.mypy_plugin import _PRECISION_DICT, _EXTENDED_PRECISION_LIST
from numpy.typing.mypy_plugin import (
_PRECISION_DICT,
_EXTENDED_PRECISION_LIST,
_C_INTP,
)

try:
from mypy import api
Expand Down Expand Up @@ -219,6 +223,9 @@ def _construct_format_dict():

# numpy.typing
"_NBitInt": dct['_NBitInt'],

# numpy.ctypeslib
"c_intp": f"ctypes.{_C_INTP}"
}


Expand Down

0 comments on commit 336cda1

Please sign in to comment.