Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: conversion of ROOT histograms to PlottableHistogram #27

Merged
merged 5 commits into from Jun 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 28 additions & 3 deletions .github/workflows/ci.yml
Expand Up @@ -26,6 +26,7 @@ jobs:
checks:
name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }}
runs-on: ${{ matrix.runs-on }}
needs: [pre-commit]
strategy:
fail-fast: false
matrix:
Expand All @@ -36,7 +37,6 @@ jobs:
- python-version: pypy-3.7
runs-on: ubuntu-latest


steps:
- uses: actions/checkout@v2

Expand All @@ -47,16 +47,41 @@ jobs:
- name: Install package
run: python -m pip install .[test]

- name: Test package
- name: Test
run: python -m pytest -ra

root:
name: ROOT test
runs-on: ubuntu-latest
needs: [pre-commit]
defaults:
run:
shell: bash -l {0}

steps:
- uses: actions/checkout@v2

- uses: conda-incubator/setup-miniconda@v2
with:
miniforge-variant: Mambaforge
use-mamba: true
environment-file: environment.yml

- name: Install package
run: pip install .

- name: Test root
run: pytest -ra tests/test_root.py



dist:
name: Distribution build
runs-on: ubuntu-latest
needs: [pre-commit]

steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2

- name: Build sdist and wheel
run: pipx run build
Expand Down
9 changes: 9 additions & 0 deletions environment.yml
@@ -0,0 +1,9 @@
name: uhi
channels:
- conda-forge
dependencies:
- pip
- pytest
- root
- pytest
- boost-histogram
15 changes: 13 additions & 2 deletions noxfile.py
Expand Up @@ -17,7 +17,7 @@ def lint(session):
Run the linter.
"""
session.install("pre-commit")
session.run("pre-commit", "run", "--all-files")
session.run("pre-commit", "run", "--all-files", *session.posargs)


@nox.session(python=ALL_PYTHONS)
Expand All @@ -26,7 +26,7 @@ def tests(session):
Run the unit and regular tests.
"""
session.install(".[test]")
session.run("pytest")
session.run("pytest", *session.posargs)


@nox.session
Expand Down Expand Up @@ -63,3 +63,14 @@ def build(session):

session.install("build")
session.run("python", "-m", "build")


@nox.session(venv_backend="conda")
def root_tests(session):
"""
Test against ROOT.
"""

session.conda_install("--channel=conda-forge", "ROOT", "pytest", "boost-histogram")
session.install(".")
session.run("pytest", "tests/test_root.py")
175 changes: 175 additions & 0 deletions src/uhi/numpy_plottable.py
Expand Up @@ -160,6 +160,172 @@ def variances(self) -> Optional[np.ndarray]:
_: PlottableHistogram = cast(NumPyPlottableHistogram, None)


def _roottarray_asnumpy(
tarr: Any, shape: Optional[Tuple[int, ...]] = None
) -> np.ndarray:
llv = tarr.GetArray()
arr: np.ndarray = np.frombuffer(llv, dtype=llv.typecode, count=tarr.GetSize())
if shape is not None:
return np.reshape(arr, shape, order="F")
else:
return arr


class ROOTAxis:
def __init__(self, tAxis: Any) -> None:
self.tAx = tAxis

def __len__(self) -> int:
return self.tAx.GetNbins() # type: ignore

def __getitem__(self, index: int) -> Any:
pass

def __eq__(self, other: Any) -> bool:
if not isinstance(other, ROOTAxis):
return NotImplemented
return len(self) == len(other) and all(
aEdges == bEdges for aEdges, bEdges in zip(self, other)
)

def __iter__(self) -> Union[Iterator[Tuple[float, float]], Iterator[str]]:
pass

@staticmethod
def create(tAx: Any) -> Union["DiscreteROOTAxis", "ContinuousROOTAxis"]:
if all(tAx.GetBinLabel(i + 1) for i in range(tAx.GetNbins())):
return DiscreteROOTAxis(tAx)
else:
return ContinuousROOTAxis(tAx)


class ContinuousROOTAxis(ROOTAxis):
@property
def traits(self) -> PlottableTraits:
return Traits(circular=False, discrete=False)

def __getitem__(self, index: int) -> Tuple[float, float]:
return (self.tAx.GetBinLowEdge(index + 1), self.tAx.GetBinUpEdge(index + 1))

def __iter__(self) -> Iterator[Tuple[float, float]]:
for i in range(len(self)):
yield self[i]


class DiscreteROOTAxis(ROOTAxis):
@property
def traits(self) -> PlottableTraits:
return Traits(circular=False, discrete=True)

def __getitem__(self, index: int) -> str:
return self.tAx.GetBinLabel(index + 1) # type: ignore

def __iter__(self) -> Iterator[str]:
for i in range(len(self)):
yield self[i]


class ROOTPlottableHistBase:
"""Common base for ROOT histograms and TProfile"""

def __init__(self, thist: Any) -> None:
self.thist: Any = thist
nDim = thist.GetDimension()
self._shape: Tuple[int, ...] = tuple(
getattr(thist, f"GetNbins{ax}")() + 2 for ax in "XYZ"[:nDim]
)
self.axes: Tuple[Union[ContinuousROOTAxis, DiscreteROOTAxis], ...] = tuple(
ROOTAxis.create(getattr(thist, f"Get{ax}axis")()) for ax in "XYZ"[:nDim]
)

@property
def name(self) -> str:
return self.thist.GetName() # type: ignore


class ROOTPlottableHistogram(ROOTPlottableHistBase):
def __init__(self, thist: Any) -> None:
super().__init__(thist)

@property
def hasWeights(self) -> bool:
return bool(self.thist.GetSumw2() and self.thist.GetSumw2N())

@property
def kind(self) -> str:
return Kind.COUNT

def values(self) -> np.ndarray:
return _roottarray_asnumpy(self.thist, shape=self._shape)[ # type: ignore
tuple([slice(1, -1)] * len(self._shape))
]

def variances(self) -> np.ndarray:
if self.hasWeights:
return _roottarray_asnumpy(self.thist.GetSumw2(), shape=self._shape)[ # type: ignore
tuple([slice(1, -1)] * len(self._shape))
]
else:
return self.values()

def counts(self) -> np.ndarray:
if not self.hasWeights:
return self.values()

sumw = self.values()
return np.divide( # type: ignore
sumw ** 2,
self.variances(),
out=np.zeros_like(sumw, dtype=np.float64),
where=sumw != 0,
)


class ROOTPlottableProfile(ROOTPlottableHistBase):
def __init__(self, thist: Any) -> None:
super().__init__(thist)

@property
def kind(self) -> str:
return Kind.MEAN

def values(self) -> np.ndarray:
return np.array( # type: ignore
[self.thist.GetBinContent(i) for i in range(self.thist.GetNcells())]
).reshape(self._shape, order="F")[tuple([slice(1, -1)] * len(self._shape))]

def variances(self) -> np.ndarray:
return ( # type: ignore
np.array([self.thist.GetBinError(i) for i in range(self.thist.GetNcells())])
** 2
).reshape(self._shape, order="F")[tuple([slice(1, -1)] * len(self._shape))]

def counts(self) -> np.ndarray:
sumw = _roottarray_asnumpy(self.thist, shape=self._shape)[
tuple([slice(1, -1)] * len(self._shape))
]
if not (self.thist.GetSumw2() and self.thist.GetSumw2N()):
return sumw # type: ignore

sumw2 = _roottarray_asnumpy(self.thist.GetSumw2(), shape=self._shape)[
tuple([slice(1, -1)] * len(self._shape))
]
return np.divide( # type: ignore
sumw ** 2,
sumw2,
out=np.zeros_like(sumw, dtype=np.float64),
where=sumw != 0,
)


if TYPE_CHECKING:
# Verify that the above class is a valid PlottableHistogram
_axis = cast(ContinuousROOTAxis, None)
_axis2: PlottableAxisGeneric[str] = cast(DiscreteROOTAxis, None)
_ = cast(ROOTPlottableHistogram, None)
_ = cast(ROOTPlottableProfile, None)


def ensure_plottable_histogram(hist: Any) -> PlottableHistogram:
"""
Ensure a histogram follows the PlottableHistogram Protocol.
Expand Down Expand Up @@ -206,5 +372,14 @@ def ensure_plottable_histogram(hist: Any) -> PlottableHistogram:
# Standard tuple
return NumPyPlottableHistogram(*(np.asarray(h) for h in hist))

elif hasattr(hist, "InheritsFrom") and hist.InheritsFrom("TH1"):
henryiii marked this conversation as resolved.
Show resolved Hide resolved
if any(
hist.InheritsFrom(profCls)
for profCls in ("TProfile", "TProfile2D", "TProfile3D")
):
return ROOTPlottableProfile(hist)

return ROOTPlottableHistogram(hist)

else:
raise TypeError(f"Can't be used on this type of object: {hist!r}")
39 changes: 39 additions & 0 deletions tests/test_root.py
@@ -0,0 +1,39 @@
import numpy as np
import pytest
from pytest import approx

from uhi.numpy_plottable import ensure_plottable_histogram

ROOT = pytest.importorskip("ROOT")


def test_root_imported() -> None:
assert ROOT.TString("hi") == "hi"


def test_root_th1f_convert() -> None:
th = ROOT.TH1F("h1", "h1", 50, -2.5, 2.5)
th.FillRandom("gaus", 10000)
h = ensure_plottable_histogram(th)
assert all(th.GetBinContent(i + 1) == approx(iv) for i, iv in enumerate(h.values()))
assert all(
th.GetBinError(i + 1) == approx(ie)
for i, ie in enumerate(np.sqrt(h.variances())) # type: ignore
)


def test_root_th2f_convert() -> None:
th = ROOT.TH2F("h2", "h2", 50, -2.5, 2.5, 50, -2.5, 2.5)
_ = ROOT.TF2("xyg", "xygaus", -2.5, 2.5, -2.5, 2.5)
th.FillRandom("xyg", 10000)
h = ensure_plottable_histogram(th)
assert all(
th.GetBinContent(i + 1, j + 1) == approx(iv)
for i, row in enumerate(h.values())
for j, iv in enumerate(row)
)
assert all(
th.GetBinError(i + 1, j + 1) == approx(ie)
for i, row in enumerate(np.sqrt(h.variances())) # type: ignore
for j, ie in enumerate(row)
)