diff --git a/README.md b/README.md index 2ea9a420..f481118b 100644 --- a/README.md +++ b/README.md @@ -295,6 +295,20 @@ gives: +-----------+------+------------+-----------------+ ``` +##### Other alignment + +You can also change the alignment of the headers of the columns using the `header_align` +attribute in the same way as the `align` attribute. + +```python +x.header_align["City name"] = "l" +print(x) +``` + +The `valign` attribute works in the same way as the other align attributes for +controlling vertical alignment. It accepts the values `"t"`, `"m"` and `"b"` for top, +middle, and bottom respectively. + ##### Sorting your table by a field You can make sure that your ASCII tables are produced with the data sorted by one diff --git a/src/prettytable/prettytable.py b/src/prettytable/prettytable.py index 2f24599b..89bff650 100644 --- a/src/prettytable/prettytable.py +++ b/src/prettytable/prettytable.py @@ -135,6 +135,7 @@ def __init__(self, field_names=None, **kwargs): self._rows = [] self.align = {} self.valign = {} + self.header_align = {} self.max_width = {} self.min_width = {} self.int_format = {} @@ -189,6 +190,7 @@ def __init__(self, field_names=None, **kwargs): "bottom_left_junction_char", "align", "valign", + "header_align", "max_width", "min_width", "none_format", @@ -231,6 +233,7 @@ def __init__(self, field_names=None, **kwargs): # Column specific arguments, use property.setters self.align = kwargs["align"] or {} self.valign = kwargs["valign"] or {} + self.header_align = kwargs["header_align"] or {} self.max_width = kwargs["max_width"] or {} self.min_width = kwargs["min_width"] or {} self.int_format = kwargs["int_format"] or {} @@ -615,6 +618,15 @@ def field_names(self, val): self._align[field_name] = self._align[BASE_ALIGN_VALUE] else: self.align = "c" + if self._header_align and old_names: + for old_name, new_name in zip(old_names, val): + self._header_align[new_name] = self._header_align[old_name] + for old_name in old_names: + if old_name in self._header_align and old_name not in val: + self._header_align.pop(old_name) + elif self._header_align: + for field_name in self._field_names: + self._header_align[field_name] = self._header_align[BASE_ALIGN_VALUE] if self._valign and old_names: for old_name, new_name in zip(old_names, val): self._valign[new_name] = self._valign[old_name] @@ -648,6 +660,30 @@ def align(self, val): for field in self._field_names: self._align[field] = val + @property + def header_align(self): + """Controls alignment of header fields + Arguments: + + header_align - header alignment, one of "l", "c", or "r" """ + return self._header_align + + @header_align.setter + def header_align(self, val): + if val is None or (isinstance(val, dict) and len(val) == 0): + if not self._field_names: + self._header_align = {BASE_ALIGN_VALUE: "c"} + else: + for field in self._field_names: + self._header_align[field] = "c" + else: + self._validate_align(val) + if not self._field_names: + self._header_align = {BASE_ALIGN_VALUE: val} + else: + for field in self._field_names: + self._header_align[field] = val + @property def valign(self): """Controls vertical alignment of fields @@ -1428,7 +1464,7 @@ def del_row(self, row_index): ) del self._rows[row_index] - def add_column(self, fieldname, column, align="c", valign="t"): + def add_column(self, fieldname, column, align="c", valign="t", header_align="c"): """Add a column to the table. @@ -1444,10 +1480,12 @@ def add_column(self, fieldname, column, align="c", valign="t"): if len(self._rows) in (0, len(column)): self._validate_align(align) + self._validate_align(header_align) self._validate_valign(valign) self._field_names.append(fieldname) self._align[fieldname] = align self._valign[fieldname] = valign + self._header_align[fieldname] = header_align for i in range(0, len(column)): if len(self._rows) < i + 1: self._rows.append([]) @@ -1465,6 +1503,7 @@ def add_autoindex(self, fieldname="Index"): self._field_names.insert(0, fieldname) self._align[fieldname] = self.align self._valign[fieldname] = self.valign + self._header_align[fieldname] = self.header_align for i, row in enumerate(self._rows): row.insert(0, i + 1) @@ -1850,7 +1889,7 @@ def _stringify_header(self, options): fieldname = fieldname[:width] bits.append( " " * lpad - + self._justify(fieldname, width, self._align[field]) + + self._justify(fieldname, width, self._header_align[field]) + " " * rpad ) if options["border"] or options["preserve_internal_border"]: @@ -1858,7 +1897,6 @@ def _stringify_header(self, options): bits.append(options["vertical_char"]) else: bits.append(" ") - # If only preserve_internal_border is true, then we just appended # a vertical character at the end when we wanted a space if not options["border"] and options["preserve_internal_border"]: @@ -2178,8 +2216,15 @@ def _get_formatted_html_string(self, options): if options["fields"] and field not in options["fields"]: continue lines.append( - ' %s' # noqa: E501 - % (lpad, rpad, escape(field).replace("\n", linebreak)) + ' %s' # noqa: E501 + % ( + lpad, + rpad, + {"l": "left", "r": "right", "c": "center"}[ + self._header_align[field] + ], + escape(field).replace("\n", linebreak), + ) # noqa: E501 ) lines.append(" ") lines.append(" ") diff --git a/tests/test_prettytable.py b/tests/test_prettytable.py index 7ddce3e4..f14c57ab 100644 --- a/tests/test_prettytable.py +++ b/tests/test_prettytable.py @@ -331,6 +331,7 @@ def test_add_field_names_later(self, field_name_less_table: prettytable): def aligned_before_table(): x = PrettyTable() x.align = "r" + x.header_align = "r" x.field_names = ["City name", "Area", "Population", "Annual Rainfall"] x.add_row(["Adelaide", 1295, 1158259, 600.5]) x.add_row(["Brisbane", 5905, 1857594, 1146.4]) @@ -354,6 +355,7 @@ def aligned_after_table(): x.add_row(["Melbourne", 1566, 3806092, 646.9]) x.add_row(["Perth", 5386, 1554769, 869.4]) x.align = "r" + x.header_align = "r" return x @@ -1435,6 +1437,41 @@ def test_style_align(self, style, expected): result = t.get_string() assert result.strip() == expected.strip() + @pytest.mark.parametrize( + "style, expected", + [ + pytest.param( + DEFAULT, + """ ++---------+--------+--------+ +| L | C | R | ++---------+--------+--------+ +| value 1 | value2 | value3 | +| value 4 | value5 | value6 | +| value 7 | value8 | value9 | ++---------+--------+--------+ +""", + id="MARKDOWN", + ), + ], + ) + def test_style_header_align(self, style, expected): + # Arrange + t = helper_table() + t.field_names = ["L", "C", "R"] + + assert t.header_align["L"] == "c" + + # Act + t.set_style(style) + t.header_align["L"] = "l" + t.header_align["C"] = "c" + t.header_align["R"] = "r" + + # Assert + result = t.get_string() + assert result.strip() == expected.strip() + class TestCsvOutput: def test_csv_output(self):