diff --git a/numpy/polynomial/_polybase.py b/numpy/polynomial/_polybase.py index b084f37c8d0..d0be8538db1 100644 --- a/numpy/polynomial/_polybase.py +++ b/numpy/polynomial/_polybase.py @@ -438,18 +438,18 @@ def _repr_latex_(self): # get the scaled argument string to the basis functions off, scale = self.mapparms() if off == 0 and scale == 1: - term = 'x' + term = self.symbol needs_parens = False elif scale == 1: - term = f"{self._repr_latex_scalar(off)} + x" + term = f"{self._repr_latex_scalar(off)} + {self.symbol}" needs_parens = True elif off == 0: - term = f"{self._repr_latex_scalar(scale)}x" + term = f"{self._repr_latex_scalar(scale)}{self.symbol}" needs_parens = True else: term = ( f"{self._repr_latex_scalar(off)} + " - f"{self._repr_latex_scalar(scale)}x" + f"{self._repr_latex_scalar(scale)}{self.symbol}" ) needs_parens = True @@ -485,7 +485,7 @@ def _repr_latex_(self): # in case somehow there are no coefficients at all body = '0' - return rf"$x \mapsto {body}$" + return rf"${self.symbol} \mapsto {body}$" diff --git a/numpy/polynomial/tests/test_printing.py b/numpy/polynomial/tests/test_printing.py index 8da4e75c190..60e04f51844 100644 --- a/numpy/polynomial/tests/test_printing.py +++ b/numpy/polynomial/tests/test_printing.py @@ -419,3 +419,24 @@ def test_multichar_basis_func(self): p = poly.HermiteE([1, 2, 3]) assert_equal(self.as_latex(p), r'$x \mapsto 1.0\,{He}_{0}(x) + 2.0\,{He}_{1}(x) + 3.0\,{He}_{2}(x)$') + + def test_symbol_basic(self): + # default input + p = poly.Polynomial([1, 2, 3], symbol='z') + assert_equal(self.as_latex(p), + r'$z \mapsto 1.0 + 2.0\,z + 3.0\,z^{2}$') + + # translated input + p = poly.Polynomial([1, 2, 3], domain=[-2, 0], symbol='z') + assert_equal(self.as_latex(p), + r'$z \mapsto 1.0 + 2.0\,\left(1.0 + z\right) + 3.0\,\left(1.0 + z\right)^{2}$') + + # scaled input + p = poly.Polynomial([1, 2, 3], domain=[-0.5, 0.5], symbol='z') + assert_equal(self.as_latex(p), + r'$z \mapsto 1.0 + 2.0\,\left(2.0z\right) + 3.0\,\left(2.0z\right)^{2}$') + + # affine input + p = poly.Polynomial([1, 2, 3], domain=[-1, 0], symbol='z') + assert_equal(self.as_latex(p), + r'$z \mapsto 1.0 + 2.0\,\left(1.0 + 2.0z\right) + 3.0\,\left(1.0 + 2.0z\right)^{2}$')