Skip to content

Commit

Permalink
Merge pull request astropy#12614 from mhvk/refactor-mixin-tests
Browse files Browse the repository at this point in the history
Refactor mixin tests, to avoid duplication between ECSV, FITS, HDF5
  • Loading branch information
taldcroft committed Dec 31, 2021
2 parents 82edac5 + 1b6d9f3 commit dc588da
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 353 deletions.
105 changes: 4 additions & 101 deletions astropy/io/ascii/tests/test_ecsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@
import numpy as np
import yaml

from astropy.table import Table, Column, QTable, NdarrayMixin
from astropy.table import Table, Column, QTable
from astropy.table.table_helpers import simple_table
from astropy.coordinates import (SkyCoord, Latitude, Longitude, Angle, EarthLocation,
SphericalRepresentation, CartesianRepresentation,
SphericalCosLatDifferential)
from astropy.time import Time, TimeDelta
from astropy.units import allclose as quantity_allclose
from astropy.units import QuantityInfo

Expand All @@ -30,6 +26,7 @@
from astropy.io import ascii
from astropy import units as u

from astropy.io.tests.mixin_columns import mixin_cols, compare_attrs
from .common import TEST_DIR

DTYPES = ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32',
Expand Down Expand Up @@ -271,101 +268,6 @@ def assert_objects_equal(obj1, obj2, attrs, compare_class=True):
assert np.all(obj1 == obj2)


# TODO: unify with the very similar tests in fits/tests/test_connect.py
# and misc/tests/test_hd5f.py.
el = EarthLocation(x=[1, 2] * u.km, y=[3, 4] * u.km, z=[5, 6] * u.km)
sr = SphericalRepresentation(
[0, 1]*u.deg, [2, 3]*u.deg, 1*u.kpc)
cr = CartesianRepresentation(
[0, 1]*u.pc, [4, 5]*u.pc, [8, 6]*u.pc)
sd = SphericalCosLatDifferential(
[0, 1]*u.mas/u.yr, [0, 1]*u.mas/u.yr, 10*u.km/u.s)
srd = SphericalRepresentation(sr, differentials=sd)
sc = SkyCoord([1, 2], [3, 4], unit='deg,deg', frame='fk4',
obstime='J1990.5')
scd = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,m', frame='fk4',
obstime=['J1990.5'] * 2)
scdc = scd.copy()
scdc.representation_type = 'cartesian'
scpm = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
pm_ra_cosdec=[7, 8]*u.mas/u.yr, pm_dec=[9, 10]*u.mas/u.yr)
scpmrv = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
pm_ra_cosdec=[7, 8]*u.mas/u.yr, pm_dec=[9, 10]*u.mas/u.yr,
radial_velocity=[11, 12]*u.km/u.s)
scrv = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
radial_velocity=[11, 12]*u.km/u.s)
tm = Time([51000.5, 51001.5], format='mjd', scale='tai', precision=5, location=el[0])
tm2 = Time(tm, format='iso')
tm3 = Time(tm, location=el)
tm3.info.serialize_method['ecsv'] = 'jd1_jd2'
obj = Column([{'a': 1}, {'b': [2]}], dtype='object')

# NOTE: in the test below the name of the column "x" for the Quantity is
# important since it tests the fix for #10215 (namespace clash, where "x"
# clashes with "el.x").
mixin_cols = {
'tm': tm,
'tm2': tm2,
'tm3': tm3,
'dt': TimeDelta([1, 2] * u.day),
'sc': sc,
'scd': scd,
'scdc': scdc,
'scpm': scpm,
'scpmrv': scpmrv,
'scrv': scrv,
'x': [1, 2] * u.m,
'qdb': [10, 20] * u.dB(u.mW),
'qdex': [4.5, 5.5] * u.dex(u.cm / u.s**2),
'qmag': [21, 22] * u.ABmag,
'lat': Latitude([1, 2] * u.deg),
'lon': Longitude([1, 2] * u.deg, wrap_angle=180. * u.deg),
'ang': Angle([1, 2] * u.deg),
'el': el,
'sr': sr,
'cr': cr,
'sd': sd,
'srd': srd,
'nd': NdarrayMixin([1, 2]),
'obj': obj
}

time_attrs = ['value', 'shape', 'format', 'scale', 'precision',
'in_subfmt', 'out_subfmt', 'location']
compare_attrs = {
'c1': ['data'],
'c2': ['data'],
'tm': time_attrs,
'tm2': time_attrs,
'tm3': time_attrs,
'dt': ['shape', 'value', 'format', 'scale'],
'sc': ['ra', 'dec', 'representation_type', 'frame.name'],
'scd': ['ra', 'dec', 'distance', 'representation_type', 'frame.name'],
'scdc': ['x', 'y', 'z', 'representation_type', 'frame.name'],
'scpm': ['ra', 'dec', 'distance', 'pm_ra_cosdec', 'pm_dec',
'representation_type', 'frame.name'],
'scpmrv': ['ra', 'dec', 'distance', 'pm_ra_cosdec', 'pm_dec',
'radial_velocity', 'representation_type', 'frame.name'],
'scrv': ['ra', 'dec', 'distance', 'radial_velocity', 'representation_type',
'frame.name'],
'x': ['value', 'unit'],
'qdb': ['value', 'unit'],
'qdex': ['value', 'unit'],
'qmag': ['value', 'unit'],
'lon': ['value', 'unit', 'wrap_angle'],
'lat': ['value', 'unit'],
'ang': ['value', 'unit'],
'el': ['x', 'y', 'z', 'ellipsoid'],
'nd': ['data'],
'sr': ['lon', 'lat', 'distance'],
'cr': ['x', 'y', 'z'],
'sd': ['d_lon_coslat', 'd_lat', 'd_distance'],
'srd': ['lon', 'lat', 'distance', 'differentials.s.d_lon_coslat',
'differentials.s.d_lat', 'differentials.s.d_distance'],
'obj': []
}


def test_ecsv_mixins_ascii_read_class():
"""Ensure that ascii.read(ecsv_file) returns the correct class
(QTable if any Quantity subclasses, Table otherwise).
Expand Down Expand Up @@ -508,7 +410,8 @@ def test_ecsv_mixins_per_column(table_cls, name_col, ndim):

for colname in t.colnames:
assert len(t2[colname].shape) == ndim
assert_objects_equal(t[colname], t2[colname], compare_attrs[colname])
compare = ['data'] if colname in ('c1', 'c2') else compare_attrs[colname]
assert_objects_equal(t[colname], t2[colname], compare)

# Special case to make sure Column type doesn't leak into Time class data
if name.startswith('tm'):
Expand Down
2 changes: 1 addition & 1 deletion astropy/io/fits/fitstime.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def time_to_fits(table):
hdr.extend([Card(keyword=f'OBSGEO-{dim.upper()}',
value=getattr(location, dim).to_value(u.m))
for dim in ('x', 'y', 'z')])
elif location != col.location:
elif np.any(location != col.location):
raise ValueError('Multiple Time Columns with different geocentric '
'observatory locations ({}, {}) encountered.'
'This is not supported by the FITS standard.'
Expand Down
156 changes: 39 additions & 117 deletions astropy/io/fits/tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,25 @@
AstropyDeprecationWarning)
from astropy.utils.misc import _NOT_OVERWRITING_MSG_MATCH

from astropy.coordinates import (SkyCoord, Latitude, Longitude, Angle, EarthLocation,
SphericalRepresentation, CartesianRepresentation,
SphericalCosLatDifferential)
from astropy.time import Time, TimeDelta
from astropy.time import Time
from astropy.units.quantity import QuantityInfo

from astropy.io.tests.mixin_columns import mixin_cols, compare_attrs, serialized_names


# FITS does not preserve precision, in_subfmt, and out_subfmt.
time_attrs = ['value', 'shape', 'format', 'scale', 'location']
compare_attrs = {name: (time_attrs if isinstance(col, Time) else compare_attrs[name])
for name, col in mixin_cols.items()}
# FITS does not support multi-element location, array with object dtype,
# or logarithmic quantities.
unsupported_cols = {name: col for name, col in mixin_cols.items()
if (isinstance(col, Time) and col.location.shape != ()
or isinstance(col, np.ndarray) and col.dtype.kind == 'O'
or isinstance(col, u.LogQuantity))}
mixin_cols = {name: col for name, col in mixin_cols.items()
if name not in unsupported_cols}


def equal_data(a, b):
for name in a.dtype.names:
Expand Down Expand Up @@ -713,85 +726,6 @@ def assert_objects_equal(obj1, obj2, attrs, compare_class=True):
else:
assert np.all(a1 == a2)

# Testing FITS table read/write with mixins. This is mostly
# copied from ECSV mixin testing. Analogous tests also exist for HDF5.


el = EarthLocation(x=1 * u.km, y=3 * u.km, z=5 * u.km)
el2 = EarthLocation(x=[1, 2] * u.km, y=[3, 4] * u.km, z=[5, 6] * u.km)
sr = SphericalRepresentation(
[0, 1]*u.deg, [2, 3]*u.deg, 1*u.kpc)
cr = CartesianRepresentation(
[0, 1]*u.pc, [4, 5]*u.pc, [8, 6]*u.pc)
sd = SphericalCosLatDifferential(
[0, 1]*u.mas/u.yr, [0, 1]*u.mas/u.yr, 10*u.km/u.s)
srd = SphericalRepresentation(sr, differentials=sd)
sc = SkyCoord([1, 2], [3, 4], unit='deg,deg', frame='fk4',
obstime='J1990.5')
scd = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,m', frame='fk4',
obstime=['J1990.5', 'J1991.5'])
scdc = scd.copy()
scdc.representation_type = 'cartesian'
scpm = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
pm_ra_cosdec=[7, 8]*u.mas/u.yr, pm_dec=[9, 10]*u.mas/u.yr)
scpmrv = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
pm_ra_cosdec=[7, 8]*u.mas/u.yr, pm_dec=[9, 10]*u.mas/u.yr,
radial_velocity=[11, 12]*u.km/u.s)
scrv = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
radial_velocity=[11, 12]*u.km/u.s)
tm = Time([2450814.5, 2450815.5], format='jd', scale='tai', location=el)

# NOTE: in the test below the name of the column "x" for the Quantity is
# important since it tests the fix for #10215 (namespace clash, where "x"
# clashes with "el2.x").
mixin_cols = {
'tm': tm,
'dt': TimeDelta([1, 2] * u.day),
'sc': sc,
'scd': scd,
'scdc': scdc,
'scpm': scpm,
'scpmrv': scpmrv,
'scrv': scrv,
'x': [1, 2] * u.m,
'lat': Latitude([1, 2] * u.deg),
'lon': Longitude([1, 2] * u.deg, wrap_angle=180. * u.deg),
'ang': Angle([1, 2] * u.deg),
'el2': el2,
'sr': sr,
'cr': cr,
'sd': sd,
'srd': srd,
}

time_attrs = ['value', 'shape', 'format', 'scale', 'location']
compare_attrs = {
'c1': ['data'],
'c2': ['data'],
'tm': time_attrs,
'dt': ['shape', 'value', 'format', 'scale'],
'sc': ['ra', 'dec', 'representation_type', 'frame.name'],
'scd': ['ra', 'dec', 'distance', 'representation_type', 'frame.name'],
'scdc': ['x', 'y', 'z', 'representation_type', 'frame.name'],
'scpm': ['ra', 'dec', 'distance', 'pm_ra_cosdec', 'pm_dec',
'representation_type', 'frame.name'],
'scpmrv': ['ra', 'dec', 'distance', 'pm_ra_cosdec', 'pm_dec',
'radial_velocity', 'representation_type', 'frame.name'],
'scrv': ['ra', 'dec', 'distance', 'radial_velocity', 'representation_type',
'frame.name'],
'x': ['value', 'unit'],
'lon': ['value', 'unit', 'wrap_angle'],
'lat': ['value', 'unit'],
'ang': ['value', 'unit'],
'el2': ['x', 'y', 'z', 'ellipsoid'],
'nd': ['x', 'y', 'z'],
'sr': ['lon', 'lat', 'distance'],
'cr': ['x', 'y', 'z'],
'sd': ['d_lon_coslat', 'd_lat', 'd_distance'],
'srd': ['lon', 'lat', 'distance', 'differentials.s.d_lon_coslat',
'differentials.s.d_lat', 'differentials.s.d_distance'],
}


def test_fits_mixins_qtable_to_table(tmpdir):
"""Test writing as QTable and reading as Table. Ensure correct classes
Expand Down Expand Up @@ -835,35 +769,12 @@ def test_fits_mixins_as_one(table_cls, tmpdir):
"""Test write/read all cols at once and validate intermediate column names"""
filename = str(tmpdir.join('test_simple.fits'))
names = sorted(mixin_cols)

serialized_names = ['ang',
'cr.x', 'cr.y', 'cr.z',
'dt.jd1', 'dt.jd2',
'el2.x', 'el2.y', 'el2.z',
'lat',
'lon',
'sc.ra', 'sc.dec',
'scd.ra', 'scd.dec', 'scd.distance',
'scd.obstime.jd1', 'scd.obstime.jd2',
'scdc.x', 'scdc.y', 'scdc.z',
'scdc.obstime.jd1', 'scdc.obstime.jd2',
'scpm.ra', 'scpm.dec', 'scpm.distance',
'scpm.pm_ra_cosdec', 'scpm.pm_dec',
'scpmrv.ra', 'scpmrv.dec', 'scpmrv.distance',
'scpmrv.pm_ra_cosdec', 'scpmrv.pm_dec',
'scpmrv.radial_velocity',
'scrv.ra', 'scrv.dec', 'scrv.distance',
'scrv.radial_velocity',
'sd.d_lon_coslat', 'sd.d_lat', 'sd.d_distance',
'sr.lon', 'sr.lat', 'sr.distance',
'srd.lon', 'srd.lat', 'srd.distance',
'srd.differentials.s.d_lon_coslat',
'srd.differentials.s.d_lat',
'srd.differentials.s.d_distance',
'tm', # serialize_method is formatted_value
'x',
]

# FITS stores times directly, so we just get the column back.
all_serialized_names = []
for name in sorted(mixin_cols):
all_serialized_names.extend(
[name] if isinstance(mixin_cols[name], Time)
else serialized_names[name])
t = table_cls([mixin_cols[name] for name in names], names=names)
t.meta['C'] = 'spam'
t.meta['comments'] = ['this', 'is', 'a', 'comment']
Expand All @@ -880,7 +791,7 @@ def test_fits_mixins_as_one(table_cls, tmpdir):

# Read directly via fits and confirm column names
with fits.open(filename) as hdus:
assert hdus[1].columns.names == serialized_names
assert hdus[1].columns.names == all_serialized_names


@pytest.mark.parametrize('name_col', list(mixin_cols.items()))
Expand All @@ -898,23 +809,34 @@ def test_fits_mixins_per_column(table_cls, name_col, tmpdir):
if not t.has_mixin_columns:
pytest.skip('column is not a mixin (e.g. Quantity subclass in Table)')

if isinstance(t[name], NdarrayMixin):
pytest.xfail('NdarrayMixin not supported')

t.write(filename, format="fits")
t2 = table_cls.read(filename, format='fits', astropy_native=True)
if isinstance(col, Time):
# FITS Time does not preserve format
t2[name].format = col.format

assert t.colnames == t2.colnames

for colname in t.colnames:
assert_objects_equal(t[colname], t2[colname], compare_attrs[colname])
compare = ['data'] if colname in ('c1', 'c2') else compare_attrs[colname]
assert_objects_equal(t[colname], t2[colname], compare)

# Special case to make sure Column type doesn't leak into Time class data
if name.startswith('tm'):
assert t2[name]._time.jd1.__class__ is np.ndarray
assert t2[name]._time.jd2.__class__ is np.ndarray


@pytest.mark.parametrize('name_col', unsupported_cols.items())
@pytest.mark.xfail(reason='column type unsupported')
def test_fits_unsupported_mixin(self, name_col, tmpdir):
# Check that we actually fail in writing unsupported columns defined
# on top.
filename = str(tmpdir.join('test_simple.fits'))
name, col = name_col
Table([col], names=[name]).write(filename, format='fits')


def test_info_attributes_with_no_mixins(tmpdir):
"""Even if there are no mixin columns, if there is metadata that would be lost it still
gets serialized
Expand Down

0 comments on commit dc588da

Please sign in to comment.