Skip to content

Commit

Permalink
feat: conversion of ROOT histograms to PlottableHistogram (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
pieterdavid committed Jun 13, 2021
1 parent 74562b6 commit c2e481e
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 5 deletions.
31 changes: 28 additions & 3 deletions .github/workflows/ci.yml
Expand Up @@ -26,6 +26,7 @@ jobs:
checks:
name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }}
runs-on: ${{ matrix.runs-on }}
needs: [pre-commit]
strategy:
fail-fast: false
matrix:
Expand All @@ -36,7 +37,6 @@ jobs:
- python-version: pypy-3.7
runs-on: ubuntu-latest


steps:
- uses: actions/checkout@v2

Expand All @@ -47,16 +47,41 @@ jobs:
- name: Install package
run: python -m pip install .[test]

- name: Test package
- name: Test
run: python -m pytest -ra

root:
name: ROOT test
runs-on: ubuntu-latest
needs: [pre-commit]
defaults:
run:
shell: bash -l {0}

steps:
- uses: actions/checkout@v2

- uses: conda-incubator/setup-miniconda@v2
with:
miniforge-variant: Mambaforge
use-mamba: true
environment-file: environment.yml

- name: Install package
run: pip install .

- name: Test root
run: pytest -ra tests/test_root.py



dist:
name: Distribution build
runs-on: ubuntu-latest
needs: [pre-commit]

steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2

- name: Build sdist and wheel
run: pipx run build
Expand Down
9 changes: 9 additions & 0 deletions environment.yml
@@ -0,0 +1,9 @@
name: uhi
channels:
- conda-forge
dependencies:
- pip
- pytest
- root
- pytest
- boost-histogram
15 changes: 13 additions & 2 deletions noxfile.py
Expand Up @@ -17,7 +17,7 @@ def lint(session):
Run the linter.
"""
session.install("pre-commit")
session.run("pre-commit", "run", "--all-files")
session.run("pre-commit", "run", "--all-files", *session.posargs)


@nox.session(python=ALL_PYTHONS)
Expand All @@ -26,7 +26,7 @@ def tests(session):
Run the unit and regular tests.
"""
session.install(".[test]")
session.run("pytest")
session.run("pytest", *session.posargs)


@nox.session
Expand Down Expand Up @@ -63,3 +63,14 @@ def build(session):

session.install("build")
session.run("python", "-m", "build")


@nox.session(venv_backend="conda")
def root_tests(session):
"""
Test against ROOT.
"""

session.conda_install("--channel=conda-forge", "ROOT", "pytest", "boost-histogram")
session.install(".")
session.run("pytest", "tests/test_root.py")
175 changes: 175 additions & 0 deletions src/uhi/numpy_plottable.py
Expand Up @@ -160,6 +160,172 @@ def variances(self) -> Optional[np.ndarray]:
_: PlottableHistogram = cast(NumPyPlottableHistogram, None)


def _roottarray_asnumpy(
tarr: Any, shape: Optional[Tuple[int, ...]] = None
) -> np.ndarray:
llv = tarr.GetArray()
arr: np.ndarray = np.frombuffer(llv, dtype=llv.typecode, count=tarr.GetSize())
if shape is not None:
return np.reshape(arr, shape, order="F")
else:
return arr


class ROOTAxis:
def __init__(self, tAxis: Any) -> None:
self.tAx = tAxis

def __len__(self) -> int:
return self.tAx.GetNbins() # type: ignore

def __getitem__(self, index: int) -> Any:
pass

def __eq__(self, other: Any) -> bool:
if not isinstance(other, ROOTAxis):
return NotImplemented
return len(self) == len(other) and all(
aEdges == bEdges for aEdges, bEdges in zip(self, other)
)

def __iter__(self) -> Union[Iterator[Tuple[float, float]], Iterator[str]]:
pass

@staticmethod
def create(tAx: Any) -> Union["DiscreteROOTAxis", "ContinuousROOTAxis"]:
if all(tAx.GetBinLabel(i + 1) for i in range(tAx.GetNbins())):
return DiscreteROOTAxis(tAx)
else:
return ContinuousROOTAxis(tAx)


class ContinuousROOTAxis(ROOTAxis):
@property
def traits(self) -> PlottableTraits:
return Traits(circular=False, discrete=False)

def __getitem__(self, index: int) -> Tuple[float, float]:
return (self.tAx.GetBinLowEdge(index + 1), self.tAx.GetBinUpEdge(index + 1))

def __iter__(self) -> Iterator[Tuple[float, float]]:
for i in range(len(self)):
yield self[i]


class DiscreteROOTAxis(ROOTAxis):
@property
def traits(self) -> PlottableTraits:
return Traits(circular=False, discrete=True)

def __getitem__(self, index: int) -> str:
return self.tAx.GetBinLabel(index + 1) # type: ignore

def __iter__(self) -> Iterator[str]:
for i in range(len(self)):
yield self[i]


class ROOTPlottableHistBase:
"""Common base for ROOT histograms and TProfile"""

def __init__(self, thist: Any) -> None:
self.thist: Any = thist
nDim = thist.GetDimension()
self._shape: Tuple[int, ...] = tuple(
getattr(thist, f"GetNbins{ax}")() + 2 for ax in "XYZ"[:nDim]
)
self.axes: Tuple[Union[ContinuousROOTAxis, DiscreteROOTAxis], ...] = tuple(
ROOTAxis.create(getattr(thist, f"Get{ax}axis")()) for ax in "XYZ"[:nDim]
)

@property
def name(self) -> str:
return self.thist.GetName() # type: ignore


class ROOTPlottableHistogram(ROOTPlottableHistBase):
def __init__(self, thist: Any) -> None:
super().__init__(thist)

@property
def hasWeights(self) -> bool:
return bool(self.thist.GetSumw2() and self.thist.GetSumw2N())

@property
def kind(self) -> str:
return Kind.COUNT

def values(self) -> np.ndarray:
return _roottarray_asnumpy(self.thist, shape=self._shape)[ # type: ignore
tuple([slice(1, -1)] * len(self._shape))
]

def variances(self) -> np.ndarray:
if self.hasWeights:
return _roottarray_asnumpy(self.thist.GetSumw2(), shape=self._shape)[ # type: ignore
tuple([slice(1, -1)] * len(self._shape))
]
else:
return self.values()

def counts(self) -> np.ndarray:
if not self.hasWeights:
return self.values()

sumw = self.values()
return np.divide( # type: ignore
sumw ** 2,
self.variances(),
out=np.zeros_like(sumw, dtype=np.float64),
where=sumw != 0,
)


class ROOTPlottableProfile(ROOTPlottableHistBase):
def __init__(self, thist: Any) -> None:
super().__init__(thist)

@property
def kind(self) -> str:
return Kind.MEAN

def values(self) -> np.ndarray:
return np.array( # type: ignore
[self.thist.GetBinContent(i) for i in range(self.thist.GetNcells())]
).reshape(self._shape, order="F")[tuple([slice(1, -1)] * len(self._shape))]

def variances(self) -> np.ndarray:
return ( # type: ignore
np.array([self.thist.GetBinError(i) for i in range(self.thist.GetNcells())])
** 2
).reshape(self._shape, order="F")[tuple([slice(1, -1)] * len(self._shape))]

def counts(self) -> np.ndarray:
sumw = _roottarray_asnumpy(self.thist, shape=self._shape)[
tuple([slice(1, -1)] * len(self._shape))
]
if not (self.thist.GetSumw2() and self.thist.GetSumw2N()):
return sumw # type: ignore

sumw2 = _roottarray_asnumpy(self.thist.GetSumw2(), shape=self._shape)[
tuple([slice(1, -1)] * len(self._shape))
]
return np.divide( # type: ignore
sumw ** 2,
sumw2,
out=np.zeros_like(sumw, dtype=np.float64),
where=sumw != 0,
)


if TYPE_CHECKING:
# Verify that the above class is a valid PlottableHistogram
_axis = cast(ContinuousROOTAxis, None)
_axis2: PlottableAxisGeneric[str] = cast(DiscreteROOTAxis, None)
_ = cast(ROOTPlottableHistogram, None)
_ = cast(ROOTPlottableProfile, None)


def ensure_plottable_histogram(hist: Any) -> PlottableHistogram:
"""
Ensure a histogram follows the PlottableHistogram Protocol.
Expand Down Expand Up @@ -206,5 +372,14 @@ def ensure_plottable_histogram(hist: Any) -> PlottableHistogram:
# Standard tuple
return NumPyPlottableHistogram(*(np.asarray(h) for h in hist))

elif hasattr(hist, "InheritsFrom") and hist.InheritsFrom("TH1"):
if any(
hist.InheritsFrom(profCls)
for profCls in ("TProfile", "TProfile2D", "TProfile3D")
):
return ROOTPlottableProfile(hist)

return ROOTPlottableHistogram(hist)

else:
raise TypeError(f"Can't be used on this type of object: {hist!r}")
39 changes: 39 additions & 0 deletions tests/test_root.py
@@ -0,0 +1,39 @@
import numpy as np
import pytest
from pytest import approx

from uhi.numpy_plottable import ensure_plottable_histogram

ROOT = pytest.importorskip("ROOT")


def test_root_imported() -> None:
assert ROOT.TString("hi") == "hi"


def test_root_th1f_convert() -> None:
th = ROOT.TH1F("h1", "h1", 50, -2.5, 2.5)
th.FillRandom("gaus", 10000)
h = ensure_plottable_histogram(th)
assert all(th.GetBinContent(i + 1) == approx(iv) for i, iv in enumerate(h.values()))
assert all(
th.GetBinError(i + 1) == approx(ie)
for i, ie in enumerate(np.sqrt(h.variances())) # type: ignore
)


def test_root_th2f_convert() -> None:
th = ROOT.TH2F("h2", "h2", 50, -2.5, 2.5, 50, -2.5, 2.5)
_ = ROOT.TF2("xyg", "xygaus", -2.5, 2.5, -2.5, 2.5)
th.FillRandom("xyg", 10000)
h = ensure_plottable_histogram(th)
assert all(
th.GetBinContent(i + 1, j + 1) == approx(iv)
for i, row in enumerate(h.values())
for j, iv in enumerate(row)
)
assert all(
th.GetBinError(i + 1, j + 1) == approx(ie)
for i, row in enumerate(np.sqrt(h.variances())) # type: ignore
for j, ie in enumerate(row)
)

0 comments on commit c2e481e

Please sign in to comment.