# -*- coding: utf-8 -*-
# Time-stamp: <2021-03-18 15:07:42 ycopin>
"""
.. _spectrum:
spectrum - Spectrum manipulations
=================================
Spectrum I/O, multi-roll merge, adaptive continuum fit, emission-line
detection and fit.
.. Warning:: All input wavelengths in Å.
.. autosummary::
Table
Spectrum
SimSpectrum
"""
__author__ = "Yannick Copin <y.copin@ipnl.in2p3.fr>"
import os
import warnings
import numpy as N
import astropy.io.fits as F
from . import statistics as S
# Install a custom formatter for every warning emitted through `warnings`.
# NOTE(review): `formatwarning` is neither defined nor imported in the visible
# part of this module -- confirm where it comes from, otherwise this line
# raises NameError at import time.
warnings.formatwarning = formatwarning # Override default warning format
# warnings.simplefilter("always") # Always print warnings
def get_extension(name, default=0):
    """
    Return (name, EXT) from 'name[ext]', using default *ext* if unspecified.

    Only a *trailing* '[ext]' suffix is interpreted as an extension, so that
    brackets appearing elsewhere in the path (e.g. 'run[2]/spec.fits') do not
    truncate the file name.

    :param name: file name, optionally suffixed by '[ext]'
    :param default: extension used when `name` has no '[ext]' suffix
    :return: `(bname, ext)`; `ext` is an `int` when it can be converted to
        one, an upper-cased string otherwise
    """
    import re

    # Decipher name and extension from a trailing name[EXT]
    match = re.fullmatch(r'(.*)\[([^\[\]]*)\]', name)
    if match:
        bname, ext = match.groups()
    else:
        bname, ext = name, default
    try:
        ext = int(ext)          # ext is an integer
    except ValueError:
        ext = ext.upper()       # ext is a string
    return bname, ext
# Table class ##############################
class Table:
    """
    Simple FITS table handler.

    Reads one abscissa column (`x`) and one or several ordinate columns
    (`y`) from a FITS table extension, and offers sorting and interpolation.
    """

    def __init__(self, name, colX, colY, ext=1, keepHdr=False):
        """
        Read columns `colX` and `colY` from FITS table `name`.

        colY can be a single column name (self.ny=1) or a tuple of column
        names (self.ny=len(colY)). If self.ny>1, self.y is a
        (self.ny, self.npts) 2D-array. ext specifies the extension
        number or name, which can also be specified in name[ext].

        :param name: FITS file name, optionally suffixed by '[ext]'
        :param colX: name of the abscissa column
        :param colY: name (or tuple/list of names) of the ordinate column(s)
        :param ext: extension number or name
        :param keepHdr: if set, keep a copy of the extension header in
            `self.hdr`
        :raises KeyError: requested extension cannot be found
        :raises KeyError, NameError: requested columns cannot be found (the
            type raised by the FITS library is preserved)
        """
        if name.endswith(']'):  # Decipher name and extension from name[EXT]
            name, ext = get_extension(name, default=ext)
        self.name = name
        self.basename = os.path.basename(self.name)
        self.colX = colX
        self.colY = colY
        self.ext = ext
        if isinstance(colY, (tuple, list)):
            self.ny = len(colY)
        else:
            self.ny = 1

        # The context manager closes the file even when an error is raised
        with F.open(self.name, ignore_missing_end=True) as ffile:
            try:                # Check existence of requested extension
                ffile.index_of(ext)
            except KeyError:
                raise KeyError("Cannot find extension %s among %s" %
                               (str(ext), ["%d:%s" % (i, hdu.name)
                                           for i, hdu in enumerate(ffile)]))
            hdu = ffile[ext]
            if keepHdr:         # Make a copy of requested header
                self.hdr = hdu.header.copy()
            try:                # Make a copy of requested columns
                self.x = hdu.data.field(self.colX).copy()
                if self.ny == 1:
                    self.y = hdu.data.field(self.colY).copy()
                else:
                    self.y = N.array([hdu.data.field(col).copy()
                                      for col in self.colY])
            except (KeyError, NameError) as err:
                # astropy signals an unknown column with KeyError (NameError
                # was used by older FITS readers): keep the type, improve
                # the message.
                raise type(err)(
                    "Cannot find requested columns %s, %s among %s" %
                    (self.colX, self.colY, hdu.data.names)) from err
        self.npts = len(self.x)

    def __str__(self):
        return "Table '%s[%s]': col. %s, %s, %d rows" % \
            (self.basename, str(self.ext), self.colX, self.colY, self.npts)

    @staticmethod
    def isSorted(x, increasing=None, strictly=False):
        """
        Check if x is [strictly] sorted (increasing=None), increasing
        (True) or decreasing (False).

        :param x: array-like, flattened before the test
        :param increasing: `None` (either direction), `True` or `False`
        :param strictly: if set, equal consecutive values are not accepted
        :return: boolean
        """
        xx = N.ravel(x)
        if increasing is None or increasing:
            if strictly:
                inc = (xx[1:] > xx[:-1])
            else:
                inc = (xx[1:] >= xx[:-1])
        if increasing is None or not increasing:
            if strictly:
                dec = (xx[1:] < xx[:-1])
            else:
                dec = (xx[1:] <= xx[:-1])
        if increasing is None:
            return inc.all() or dec.all()
        elif increasing:
            return inc.all()
        else:
            return dec.all()

    def sort(self):
        """Sort the table in place by increasing `x`."""
        if not self.isSorted(self.x, increasing=True, strictly=True):
            isort = self.x.argsort()
            self.x = self.x[isort]
            # Points are along the *last* axis of y, be it 1D (npts,) or
            # 2D (ny, npts)
            self.y = self.y[..., isort]

    def interpolate(self, x, cols=None, monotonic=False):
        """
        Interpolate table over input array x. For self.ny>1, cols specifies
        the column indices to be interpolated (all by default).

        :param x: abscissae at which to evaluate the interpolation
        :param cols: iterable of column indices (only used if self.ny>1)
        :param monotonic: use `scipy.interpolate.pchip` rather than an
            interpolating `UnivariateSpline` (s=0)
        :return: interpolated values (squeezed array if self.ny>1)
        """
        if not self.isSorted(self.x, increasing=True, strictly=True):
            isort = self.x.argsort()
            xsort = self.x[isort]
            ysort = self.y[..., isort]  # Reorder points, not columns
        else:
            xsort = self.x
            ysort = self.y

        if monotonic:
            from scipy.interpolate import pchip as interpolator
        else:
            from scipy.interpolate import UnivariateSpline

            def interpolator(xx, yy):
                return UnivariateSpline(xx, yy, s=0)

        if self.ny == 1:
            return interpolator(xsort, ysort)(x)

        if cols is None:        # Interpolate all columns
            cols = range(self.ny)
        xx = N.asarray(x)
        res = N.array([interpolator(xsort, ysort[i])(xx) for i in cols])
        return res.squeeze()    # Keep it simple
# Spectrum class ##############################
class Spectrum:
    """
    FITS 1D-spectrum handler.

    Class to read and manage a spectrum typically from a FITS file
    (`NAXIS=1`), including the associated [co]variance from an extension
    or an external file.

    .. Warning:: this is not directly compatible w/ NISP simulated spectrum
       from `Simulations_140425`. See :class:`SimSpectrum`.
    """

    def __init__(self, name, varname=None, keepFits=True):
        """
        Spectrum initialization.

        :param name: FITS file name, optionally 'name[ext]'; `None` creates
            a blank instance (no attribute beyond `name` is set)
        :param varname: external variance spectrum, overriding the
            VARIANCE extension if any
        :param keepFits: keep the FITS file open (in 'update' mode),
            otherwise read it in 'readonly' mode and close it
        :raises AssertionError: variance spectrum incompatible with signal,
            or variance incompatible with covariance diagonal
        """
        self.name = name                  # Generic name
        if name is None:                  # Blank instance
            return

        self._readFits(name,              # Read signal [and variance if any]
                       mode='update' if keepFits else 'readonly')
        if not keepFits:
            self.close()

        if varname:                       # Override variance extension if any
            if self.varname:              # Set by _readFits from var. extension
                warnings.warn("%s: VARIANCE extension overriden by %s" %
                              (name, varname), RuntimeWarning)
            self.varname = varname
            # Only the data are needed: read-only access, file closed at once
            v = Spectrum(varname, varname=None, keepFits=False)
            assert ((v.npts, v.start, v.step) ==
                    (self.npts, self.start, self.step)), \
                f"Incompatible variance spectrum '{varname}' wrt. to spectrum '{name}'"
            self.v = v.y.copy()
            # All other attributes and header keywords should be
            # essentially the same as for signal spectrum, no need to
            # keep them

        if self.hasCov and self.hasVar:   # Test variance vs. cov. coherence
            assert N.allclose(self.v, self.cov.diagonal()), \
                "%s: VARIANCE and COVARiance diagonal are incompatible" % name

    @property
    def hasVar(self):
        """True if a variance array is available."""
        return hasattr(self, 'v') and self.v is not None

    @property
    def hasCov(self):
        """True if a covariance matrix is available."""
        return hasattr(self, 'cov') and self.cov is not None

    def close(self):
        """Close FITS file (if any) and forget about it."""
        # Blank instances have no _fits attribute at all
        if getattr(self, '_fits', None) is not None:
            self._fits.close()
        self._fits = None

    def __str__(self):
        s = "Spectrum %s: %d px [%.2f-%.2f A] @%.2f A/px" % \
            (self.name, self.npts, self.start, self.end, self.step)
        if self.hasCov:
            s += " with covariance"
        elif self.hasVar:
            s += " with variance"
        else:
            s += " (no [co]variance)"
        if self._fits is None:
            s += " (closed)"
        return s

    def _readFits(self, name, mode='readonly'):
        """
        Initialize a Spectrum from FITS spectrum name. 'name' can be
        'name[ext]', in which case only extension 'ext' is
        considered.

        Sets the wavelength ramp (`start`, `step`, `end`, `npts`, `x`), the
        signal `y`, and `v`/`cov` (with `varname`/`covname`) from the
        VARIANCE/COVAR extensions if present, `None` otherwise.

        :param name: FITS file name, optionally 'name[ext]'
        :param mode: opening mode handed over to `astropy.io.fits.open`
        :raises IOError: requested extension cannot be read
        """
        # Decipher name and extension from name[EXT]
        self.filename, self.ext = get_extension(name)
        self._fits = F.open(self.filename,
                            mode=mode, ignore_missing_end=True)
        extnames = [h.name for h in self._fits]  # "PRIMARY", etc.
        try:
            spec = self._fits[self.ext]          # Spectrum extension
        except (IndexError, KeyError):
            self.close()                         # Do not leak the file handle
            raise IOError("Cannot read extension %s in %s:%s" %
                          (self.ext, self.filename, extnames))

        self._hdr = spec.header.copy()           # Spectrum header
        self._hdr['CRPIX1'] = self._hdr.get('CRPIX1', 1)  # Make it mandatory
        self.npts = self._hdr['NAXIS1']
        self.step = self._hdr['CDELT1']
        self.start = self._hdr['CRVAL1'] - \
            (self._hdr['CRPIX1'] - 1) * self.step
        self.end = self.start + (self.npts - 1) * self.step
        self.x = N.linspace(self.start, self.end, self.npts)  # Wavelength
        self.y = spec.data.copy()                # Signal

        if 'VARIANCE' in extnames:               # Read VARIANCE extension
            vhdr = self._fits['VARIANCE'].header
            vhdr['CRPIX1'] = vhdr.get('CRPIX1', 1)  # Make it mandatory
            # Explicit test (not `assert`) so that it survives `python -O`
            compatible = (vhdr['NAXIS1'] == self.npts and
                          vhdr['CDELT1'] == self.step and
                          vhdr['CRVAL1'] == self._hdr['CRVAL1'] and
                          vhdr['CRPIX1'] == self._hdr['CRPIX1'])
            if not compatible:
                warnings.warn(
                    "%s[VARIANCE]: header incompatible with primary header" %
                    self.filename, RuntimeWarning)
            self.varname = f"{self.filename}[VARIANCE]"
            self.v = self._fits['VARIANCE'].data.copy()  # Variance
        else:
            self.varname = None
            self.v = None

        if 'COVAR' in extnames:                  # Read COVAR extension
            vhdr = self._fits['COVAR'].header
            vhdr['CRPIX1'] = vhdr.get('CRPIX1', 1)  # Make it mandatory
            vhdr['CRPIX2'] = vhdr.get('CRPIX2', 1)
            compatible = (
                vhdr['NAXIS1'] == vhdr['NAXIS2'] == self.npts and
                vhdr['CDELT1'] == vhdr['CDELT2'] == self.step and
                vhdr['CRVAL1'] == vhdr['CRVAL2'] == self._hdr['CRVAL1'] and
                vhdr['CRPIX1'] == vhdr['CRPIX2'] == self._hdr['CRPIX1'])
            if not compatible:
                warnings.warn(
                    "%s[COVAR]: header incompatible with primary header" %
                    self.filename, RuntimeWarning)
            self.covname = f"{self.filename}[COVAR]"
            self.cov = self._fits['COVAR'].data.copy()  # Lower-tri. covariance
            self.cov += N.triu(self.cov.T, 1)    # Reconstruct full cov.
        else:
            self.covname = None
            self.cov = None

    def readKey(self, keyword, default=None):
        """
        Read a single keyword, defaulting to *default* if any.

        :raises KeyError: keyword is missing and no default was given
        """
        if default is None:
            return self._hdr[keyword]
        else:
            return self._hdr.get(keyword, default)

    def setKey(self, keywords=(), **kwargs):
        """
        Set keywords from `keywords=((key, val[, comment]),)` or kwargs
        `key=val` or `key=(val, comment)`. Keyword names are upper-cased.
        """
        for key in keywords:
            name, val = key[0], key[1:]   # name, (value, [comment])
            self._hdr[name.upper()] = val
        for key in kwargs:
            self._hdr[key.upper()] = kwargs[key]

    def writeto(self, outName, force=False, hdrOnly=False, keywords=(), **kwargs):
        """
        Save Spectrum to new FITS-file.

        :param outName: output file name
        :param force: overwrite an existing file without asking
        :param hdrOnly: do not update the signal data, only the header
        :param keywords: header keywords, see :meth:`setKey`
        :param kwargs: `varext=False`/`covext=False` disable the writing of
            the VARIANCE/COVAR extensions; all others are header keywords
        :raises IOError: the FITS file has been closed
        """
        if self._fits is None:            # FITS file has been closed
            raise IOError("Cannot write to disk to closed FITS file")

        # Consume the control options unconditionally, so that they never
        # end up as header keywords (e.g. varext=False without variance)
        varext = kwargs.pop('varext', True)
        covext = kwargs.pop('covext', True)

        spec = self._fits[self.ext]
        self._hdr['CRPIX1'] = self._hdr.get('CRPIX1', 1)  # Make it mandatory
        if not hdrOnly:                   # Update FITS-data
            spec.data = N.array(self.y)
        # Update FITS-header
        self._hdr['NAXIS1'] = self.npts
        self._hdr['CDELT1'] = self.step
        crval = self.start + (self._hdr['CRPIX1'] - 1) * self.step
        self._hdr['CRVAL1'] = crval

        # Remove any prior VARIANCE/COVARiance extensions if any:
        # they will then be re-added as needed
        extnames = [ext.name for ext in self._fits]
        for extname in ("VARIANCE", "COVAR"):
            if extname in extnames:
                i = self._fits.index_of(extname)
                self._fits.remove(self._fits[i])

        if self.hasVar and varext:
            # Add variance spectrum as extension VARIANCE
            assert len(self.v) == self.npts, \
                "Variance extension (%d px) " \
                "is not coherent with signal (%d px)" % \
                (len(self.v), self.npts)
            var = F.ImageHDU(self.v, name='VARIANCE')
            var.header['CRVAL1'] = crval
            var.header['CDELT1'] = self.step
            var.header['CRPIX1'] = self._hdr['CRPIX1']
            self._fits.append(var)

        if self.hasCov and covext:
            # Add covariance array as extension COVAR
            assert self.cov.shape == (self.npts, self.npts), \
                "Covariance extension %s " \
                "is not coherent with signal (%d px)" % \
                (self.cov.shape, self.npts)
            # Add lower-tri COVARiance matrix as an image extension
            cov = F.ImageHDU(N.tril(self.cov), name='COVAR')
            cov.header['CRVAL1'] = cov.header['CRVAL2'] = crval
            cov.header['CDELT1'] = cov.header['CDELT2'] = self.step
            cov.header['CRPIX1'] = cov.header['CRPIX2'] = \
                self._hdr['CRPIX1']
            self._fits.append(cov)

        # Update required keywords
        if keywords or kwargs:
            self.setKey(keywords=keywords, **kwargs)

        # Test output file presence
        overwrite = bool(force)           # Overwrite existing file if forced
        if not overwrite and os.path.exists(outName):
            ans = input(f"Overwrite output file '{outName}'? [N/y] ")
            if ans and ans[0].lower() == 'y':
                overwrite = True          # ...except if confirmed
            else:
                warnings.warn(f"Output file {outName} not overwritten")
                return

        # Reset header from local copy self._hdr
        spec.header = self._hdr
        # Fix missing keywords (but should be OK). `overwrite` supersedes
        # the `clobber` keyword, no longer accepted by recent astropy.
        self._fits.writeto(outName, overwrite=overwrite,
                           output_verify='silentfix')
        self.name = outName
        self.filename = outName

    @classmethod
    def read_spectrum(cls, arg, keepFits=True):
        """
        Return an initiated Spectrum from arg=name[, var_name], including
        proper deciphering of arg.

        Without an explicit variance name, the 'var_'-prefixed file next to
        the spectrum is used if it exists (see :meth:`get_varname`).
        """
        innames = arg.split(',')          # Check for spectrum, var_spectrum
        specname = innames[0]
        if len(innames) == 2:             # Explicit specName, var_specName
            varname = innames[1]
        else:                             # Get variance name and test existence
            varname = cls.get_varname(specname, exists=True)
        # Set variance if any
        return cls(specname, varname=varname, keepFits=keepFits)

    @staticmethod
    def get_varname(specname, exists=False):
        """
        Return variance spectrum name associated to spectrum 'specname'. Assumes
        variance spectrum is located in same directory as spectrum. If exists,
        test if variance file exists or return None.
        """
        path, bname = os.path.split(specname)
        varname = 'var_' + bname               # Prefix bname with 'var_'
        outname = os.path.join(path, varname)  # Add path to varname
        if exists and not os.path.isfile(outname):
            outname = None
        return outname

    @classmethod
    def from_vecs(cls, x, y, name='', v=None):
        """
        Initialize spectrum from arbitrary vectors.

        Inputs are copied to float arrays. Units are guessed: wavelengths
        are multiplied by 1e4 unless `x[0] > 1000`; signal is divided by
        1e-17 if its RMS is below 1e-14; variance is divided by 1e-34 if
        its RMS is below 1e-28. Null variances are replaced by infinity.

        :param x: wavelength array [Å]
        :param y: signal array [erg/s/cm²/Å]
        :param name: name of spectrum
        :param v: variance array
        :return: spectrum with no FITS file attached
        :raises AssertionError: incompatible vectors or
                                non-uniformly sampled wavelengths
        """
        # Initialize spectrum from tables
        spec = cls(None)
        spec.name = name

        # Float copies, so that in-place unit conversions below also work
        # on integer inputs
        spec.x = N.array(x, dtype=float)
        if not spec.x[0] > 1000:          # Not in Å, supposedly in µm
            warnings.warn("Converting input wavelengths from µm to Å")
            spec.x *= 1e4                 # In µm, convert to Å

        spec.y = N.array(y, dtype=float)
        rms = N.sqrt((spec.y ** 2).mean())
        if rms < 1e-14:                   # Supposedly in erg/s/cm²/Å
            spec.y /= 1e-17               # Convert to 1e-17 erg/s/cm²/Å
        assert len(spec.y) == len(spec.x), "Incompatible vectors x and y"

        if v is not None:
            spec.v = N.array(v, dtype=float)
            rms = N.sqrt((spec.v ** 2).mean())
            if rms < 1e-28:               # In erg/s/cm²/Å
                spec.v /= 1e-34           # Convert to 1e-17 erg/s/cm²/Å
            spec.v[spec.v == 0] = N.inf   # Discard null-variance
            assert len(spec.v) == len(spec.x), "Incompatible vectors x and v"
        else:
            spec.v = None

        spec.npts = len(spec.x)
        spec.start, spec.end = spec.x[0], spec.x[-1]
        steps = N.diff(spec.x)
        spec.step = steps.mean()
        # Check constant wavelength sampling
        usteps = N.unique(N.round(steps, 4))  # Unique steps [AA]
        assert N.allclose(steps, spec.step, rtol=1e-3), \
            f"Wavelength vector x is not uniformly sampled: steps={usteps} A"
        if len(usteps) > 1:
            warnings.warn("Wavelength vector x is not strictly "
                          "uniformly sampled: steps={} A".format(usteps))
        spec._fits = None
        return spec

    @classmethod
    def from_simtables(cls, name, path='', minnpts=3):
        r"""
        Read spectrum stored in simulation table "`name`.fits" and
        associated noise table "`name`\_noise.fits".

        Return `None` if spectrum is too short.

        .. Note:: valid for `Simulations_140425`.

        :param name: table base name (without '.fits')
        :param path: directory of the tables
        :param minnpts: minimal number of pixels
        :raises AssertionError: incompatible signal and noise wavelengths
        """
        fullname = os.path.join(path, name)
        signal = Table(fullname + '.fits', 'Waves', 'Fluxes')
        if len(signal.x) < minnpts:       # Discard very short spectra
            warnings.warn("Spectrum '%s' is only %d px long, discarded" %
                          (name, len(signal.x)))
            return None
        signal.sort()
        noise = Table(fullname + '_noise.fits', 'Waves', 'Noise')
        noise.sort()
        assert N.allclose(signal.x, noise.x), \
            "Incompatible signal and noise wavelengths"
        # Initialize spectrum from table content
        spec = cls.from_vecs(x=signal.x[1:-1],  # 1st & last px can be shorter
                             y=signal.y[1:-1],
                             name=name,
                             # Convert stderr to variance
                             v=noise.y[1:-1] ** 2)
        return spec

    def extend(self, x):
        """
        Extend all spectra to match wavelength sampling `x`.

        Signal is padded with 0 and variance (if any) with infinity.

        :param x: target wavelength array, covering the current range
        :raises ValueError: `x` does not cover the current wavelength range
        :raises AssertionError: extended spectrum does not match `x` length
        """
        start, end = x[0], x[-1]
        # Extend self on both side to match x (positive if too short)
        istart = int(round((self.start - start) / self.step))
        iend = int(round((end - self.end) / self.step))
        if not (istart >= 0 and iend >= 0):
            # `roll` only exists on multi-roll spectra: fall back on name
            label = getattr(self, 'roll', self.name)
            raise ValueError(
                "incompatible wavelength range: {}: {s.start}, {s.end} "
                "vs. target: {}, {}".format(label, start, end, s=self))

        if istart > 0:                    # Extend self on the left
            self.y = N.concatenate(([0] * istart, self.y))
            if self.v is not None:
                self.v = N.concatenate(([N.inf] * istart, self.v))
            self.x = N.concatenate(
                (self.start + N.arange(-istart, 0) * self.step, self.x))
            self.start = self.x[0]
            self.npts += istart
        if iend > 0:                      # Extend self on the right
            self.y = N.concatenate((self.y, [0] * iend))
            if self.v is not None:
                self.v = N.concatenate((self.v, [N.inf] * iend))
            self.x = N.concatenate((
                self.x, self.end + N.arange(1, iend + 1) * self.step))
            self.end = self.x[-1]
            self.npts += iend

        # Tests
        assert self.npts == len(self.x) == len(x), \
            f"ARGH! {self.npts}, {len(self.x)}, {len(x)}"

    def plot(self, ax=None, errorband=False, sgfilter=None, **kwargs):
        """
        Plot spectrum.

        Pixels with null signal, or with non-finite or non-positive
        variance, are masked.

        :param ax: plot axis
        :param errorband: add errorbands
        :param sgfilter: Savitzky-Golay filter parameters `(hsize, order)`
        :param kwargs: propagated to `plot` command
        :return: plot axis
        """
        if ax is None:
            import matplotlib.pyplot as P
            fig = P.figure()
            ax = fig.add_subplot(1, 1, 1,
                                 xlabel="Wavelength [Å]",
                                 ylabel="Flux [1e-17 erg/s/cm²/Å]")

        mask = (self.y == 0)
        if self.hasVar:
            mask |= ~N.isfinite(self.v) | (self.v <= 0)
        y = self.y.copy()
        if sgfilter:
            hsize, order = sgfilter
            y[mask] = N.nan
            y = S.savitzky_golay(y, hsize, order, derivative=0)
        x = N.ma.masked_where(mask, self.x)
        y = N.ma.masked_where(mask, y)
        line, = ax.plot(x, y, **kwargs)
        if errorband and self.hasVar:
            v = N.ma.masked_where(mask, self.v)
            dy = v ** 0.5
            ax.fill_between(x, y - dy, y + dy,
                            color=line.get_color(), alpha=0.5)
        return ax
class SimSpectrum(Spectrum):
    """
    Multi-roll NISP simulated spectrum (`Simulations_140425`).

    .. Warning:: Wavelengths in Å, fluxes in 1e-17 erg/s/cm²/Å.
    """

    @classmethod
    def from_idx(cls, idx, path=''):
        """
        Read spectra corresponding to index `idx`, return `{'roll': spectrum}`.

        Files are looked for as 'spectrum_<idx>_roll???_??.fits' in `path`.
        Each spectrum gets attributes `idx`, `path` and `roll`.

        :param idx: spectrum index
        :param path: directory of the simulation tables
        :raises IOError: no (long enough) spectrum found for `idx`
        """
        from glob import glob

        specname = 'spectrum_%d_' % idx   # 'spectrum_####_'
        specnames = glob(os.path.join(path, f'{specname}roll???_??.fits'))
        if not specnames:
            # Fail with the intended error rather than on the directory
            # uniqueness assertion below
            raise IOError("Cannot find spectrum #%d in %s" % (idx, path))
        pathnames = [os.path.split(name)  # [ (path, name) ]
                     for name in specnames]
        # Check all spectra are coming from the same directory
        paths = {dirname for dirname, _ in pathnames}  # Set
        assert len(paths) == 1
        path = paths.pop()
        names = [os.path.splitext(name)[0]  # Remove ext
                 for _, name in pathnames]
        rolls = ['_'.join(name.split('_')[-2:])  # 'roll###_##'
                 for name in names]

        sdict = {}
        for roll in rolls:
            spec = cls.from_simtables(specname + roll, path=path)
            if spec is not None:          # None: spectrum too short
                spec.idx = idx
                spec.path = path
                spec.roll = roll
                sdict[roll] = spec
        if not sdict:
            raise IOError("Cannot find spectrum #%d in %s" % (idx, path))
        else:
            print("Spectrum #%d ('%s'): %d rolls" % (idx, path, len(rolls)))
        return sdict                      # { 'roll###_##': spec }

    def __str__(self):
        return "{s.roll}: RMS={rms:.6g}, " \
            "{s.npts:3d} px [{s.start:.2f}-{s.end:.2f}] " \
            "step={step:.2f} AA/px, mstart={mstart:.2f} AA".format(
                s=self, rms=N.sqrt((self.y ** 2).mean()),
                step=self.step, mstart=(self.start % self.step))

    @staticmethod
    def select_refspec(sdict):
        """
        Select reference spectrum from `sdict = { 'roll': spectrum }`.

        Starting from the first roll (alphabetically), select the longest
        spectrum with start % step = 4.8 A (rounded to 0.1 A). The 'merged'
        entry, if any, is ignored.

        :return: name of the reference roll
        """
        rolls = sorted(key for key in sdict if key != 'merged')
        refroll = rolls[0]
        for roll in rolls[1:]:
            spec = sdict[roll]
            if round(spec.start % spec.step, 1) == 4.8 and \
               spec.npts > sdict[refroll].npts:
                refroll = roll
        return refroll

    @classmethod
    def merge_specs(cls, sdict, rtol=2e-2):
        """
        Merge spectra from `sdict = { 'roll': spectrum }` to produce a
        :class:`Spectrum`.

        Rolls whose wavelength step or start (modulo step) is not
        compatible with the longest spectrum are left out with a warning.
        Selected spectra are extended *in place* to a common wavelength
        domain, and the merged spectrum is stored under `sdict['merged']`.

        :param sdict: `{ 'roll': spectrum }`; a previous 'merged' entry is
            ignored and replaced
        :param rtol: relative tolerance on start % step (no test if null)
        :return: merged spectrum, with attributes `rolls` (selected rolls)
            and `roll='merged'`
        """
        # Choose the longuest spectrum as a temporary reference for coherence
        # checks. A former merged spectrum is not an input roll.
        rolls = sorted(key for key in sdict if key != 'merged')
        refroll = rolls[N.argmax([sdict[roll].npts for roll in rolls])]
        refspec = sdict[refroll]          # Reference spectrum
        print("Ref.", refspec)
        selrolls = [refroll]              # Selected rolls

        # Coherence checks and extended wavelength domain computation
        x = refspec.x.copy()              # Extended wavelength domain
        for roll in rolls:
            if roll == refroll:
                continue
            spec = sdict[roll]
            if not N.isclose(spec.step, refspec.step):
                warnings.warn(
                    "incompatible wavelength steps: {}:{} vs. {}:{}".format(
                        roll, spec.step, refroll, refspec.step))
            elif rtol and not N.isclose(spec.start % spec.step,
                                        refspec.start % refspec.step,
                                        rtol=rtol):
                warnings.warn(
                    "incompatible wavelength starts: {}:{} vs. {}:{}".format(
                        roll, spec.start % spec.step,
                        refroll, refspec.start % refspec.step))
            else:
                selrolls.append(roll)     # Current roll selected for merging
                print("    ", spec)
                # Look for largest common wavelength domain
                istart = int(round((x[0] - spec.start) / refspec.step))
                iend = int(round((spec.end - x[-1]) / refspec.step))
                if istart > 0:            # Extend x to the left
                    x = N.concatenate(
                        (x[0] + N.arange(-istart, 0) * refspec.step, x))
                if iend > 0:              # Extend x to the right
                    x = N.concatenate(
                        (x, x[-1] + N.arange(1, iend + 1) * refspec.step))
        print(f"Common wavelength domain: {len(x)} px [{x[0]:.2f}, {x[-1]:.2f}]")

        # Extend spectra to common wavelength domain
        for roll in selrolls:
            sdict[roll].extend(x)

        # Merge selected spectra
        print("Merging %d/%d spectra" % (len(selrolls), len(rolls)))
        specs = [sdict[roll] for roll in selrolls]  # Selected spectra
        y, v = S.sample_mean(N.vstack([spec.y for spec in specs]),
                             N.vstack([spec.v for spec in specs]),
                             axis=0)
        spec = cls.from_vecs(refspec.x, y, v=v, name='merged')
        spec.rolls = selrolls             # Selected rolls
        spec.roll = 'merged'
        # Update sdict with a new 'merged' key
        sdict['merged'] = spec
        return spec                       # Merged spectrum

    def writeto(self, name, path='', clobber=True, keywords=None):
        """
        Minimal FITS writer. No propagation of header.

        Signal goes to the primary HDU, variance (if any) to an image
        extension named 'Variance'.

        :param name: output file name
        :param path: output directory (default: `self.path` if any)
        :param clobber: overwrite existing file
        :param keywords: `{key: val}` additional header keywords
        """
        shdu = F.PrimaryHDU(self.y)       # Signal in primary header
        hdr = shdu.header
        hdr['NAXIS1'] = self.npts
        hdr['CDELT1'] = self.step
        hdr['CRPIX1'] = 1
        hdr['CRVAL1'] = self.start
        hdr['CUNIT1'] = 'AA'
        hdr['FXUNIT'] = '1e-17 erg/s/cm2/AA'
        # Update header with additional keywords
        if hasattr(self, 'rolls'):        # Rolls used in the merge
            for i, roll in enumerate(self.rolls, start=1):
                hdr['ROLLS%d' % i] = roll
        for key, val in (keywords or {}).items():
            hdr[key.upper()] = val
        hdulist = F.HDUList([shdu])

        if self.v is not None:
            vhdu = F.ImageHDU(self.v)
            hdr = vhdu.header
            hdr['NAXIS1'] = self.npts
            hdr['CDELT1'] = self.step
            hdr['CRPIX1'] = 1
            hdr['CRVAL1'] = self.start
            hdr['CUNIT1'] = 'AA'
            hdr['FXUNIT'] = '(1e-17 erg/s/cm2/AA)**2'
            hdr['EXTNAME'] = 'Variance'
            hdulist.append(vhdu)

        if not path and hasattr(self, 'path'):
            path = self.path
        specname = os.path.join(path, name)
        print(f"Saving spectrum '{self.name}' in '{specname}'")
        # astropy's keyword is `overwrite` (`clobber` is no longer accepted)
        hdulist.writeto(specname, overwrite=clobber)
# Utility functions ==============================
def fit_adaptive(spec, pmax=0.05, dmax=3, nmax=3, verbose=False):
    """
    Trivial wrapper to :func:`S.fit_adaptive_spectrum`.

    Hands over the wavelength, signal and variance arrays of `spec`
    together with the fit options, and returns the resulting
    `(continuum, lines)` pair.
    """
    options = dict(pmax=pmax, dmax=dmax, nmax=nmax, verbose=verbose)
    continuum, emission = S.fit_adaptive_spectrum(
        spec.x, spec.y, spec.v, **options)
    return continuum, emission
def plot_specs(sdict, sgfilter=None, model=None, metadata=None, emilines=None,
               fig=None, add_pull=True):
    """
    Plot spectra from `sdict = { 'roll': spectrum }`.

    :param sdict: `{ 'roll': spectrum }`, possibly including a 'merged'
        spectrum (with attribute `rolls`)
    :param sgfilter: Savitzky-Golay filter parameters, used to overplot a
        filtered version of the merged spectrum
    :param model: `(cont, lines)` continuum and emission-line models,
        overplotted on the merged spectrum
    :param metadata: dictionary with keys 'idx' and 'z', used for the title
        and to position the redshifted emission lines
    :param emilines: `{name: wavelength}` of restframe emission lines;
        names ending with '_' are not annotated
    :param fig: figure to plot into (a new one is created by default)
    :param add_pull: add pull histogram (only with a model)
    :return: figure
    """
    import scipy.stats as SS
    import matplotlib.pyplot as P
    from . import mpl
    from itertools import cycle

    if emilines is None:                  # Avoid a shared mutable default
        emilines = {}

    linecycler = cycle(["-", "--", "-.", ":"])
    if fig is None:
        fig = P.figure(figsize=(14, 5))
    ax = fig.add_subplot(1, 1, 1,
                         xlabel="Wavelength [Å]",
                         ylabel="Flux [1e-17 erg/s/cm²/Å]")

    # Merged spectrum (if any)
    if 'merged' in sdict:
        merged = sdict['merged']          # Merged spectrum
        merged.plot(ax, c=mpl.blue, lw=2, zorder=3,
                    label='Merged (%d)' % len(merged.rolls),
                    errorband=True)
        if sgfilter:
            merged.plot(ax, c=mpl.blue, lw=2, zorder=3, alpha=0.5,
                        label=f'SG-filtered {sgfilter}',
                        sgfilter=sgfilter)
        rolls = merged.rolls              # Selected rolls
    else:
        rolls = []                        # No selected rolls

    # Raw spectra
    for roll in sorted(sdict):
        if roll == 'merged':
            continue
        spec = sdict[roll]
        color = '0.6' if roll in rolls else '0.3'  # Highlight unselected rolls
        spec.plot(ax, c=color, lw=1, errorband=False,
                  ls=next(linecycler),
                  label=f"{roll} ({spec.start % spec.step:.2f})")

    # Spectrum metadata
    if metadata:
        ax.set_title("Spectrum #%d, z=%.3f" % (metadata['idx'], metadata['z']))
        lmin, lmax = ax.get_xlim()
        # x in data coordinates, y in axes coordinates
        trans = P.matplotlib.transforms.blended_transform_factory(
            ax.transData, ax.transAxes)
        for name, lbda in emilines.items():   # Redshifted emission lines
            zlbda = lbda * (1 + metadata['z'])  # [Å]
            if lmin < zlbda < lmax:
                ax.axvline(zlbda, c=mpl.green)
                if not name.endswith('_'):
                    ax.annotate(name, (zlbda, 0.9), xycoords=trans,
                                rotation='vertical')

    # Model
    if model and 'merged' in sdict:
        mask = ~N.isfinite(merged.v)
        x = N.ma.masked_where(mask, merged.x)
        y = N.ma.masked_where(mask, merged.y)
        v = N.ma.masked_where(mask, merged.v)
        cont, lines = model
        ycont = cont(x)                   # Continuum
        ymodel = ycont + lines(x)
        label = "Fit: d°={}/{}, {}/{} lines".format(
            cont.degree, cont.dmax, len(lines.lines), lines.nmax)
        if lines.pnext:
            label += f" (next: {lines.pnext:.1%})"
        ax.plot(x, ymodel, c=mpl.red, lw=1, zorder=4, label=label)
        for i, line in enumerate(lines.lines):  # Individual emission lines
            if len(lines.lines) > 1:
                # Parameters of a multi-line model are suffixed by the index
                mu = getattr(lines, f'mean_{i}').value
                ax.plot(x, ycont + line(x),
                        c=mpl.red, lw=1, ls='--', zorder=3)
            else:
                mu = getattr(lines, 'mean').value
            ax.axvline(mu, c=mpl.red, ls='--')  # Adjusted position
        ax.set_title(ax.get_title() +
                     f" (pmax={cont.pmax:.1%}/{lines.pmax:.1%})")

        if add_pull:                      # Pull histogram
            from mpl_toolkits.axes_grid1 import make_axes_locatable
            from matplotlib.ticker import NullFormatter

            mpulls = (y - ymodel) / v ** 0.5  # Pulls from merged spectrum
            mpulls = mpulls[mpulls != 0]

            # Pulls from raw spectra
            def _pull(roll):
                ok = N.isfinite(sdict[roll].v)
                return (sdict[roll].y[ok] - cont(sdict[roll].x[ok]) -
                        lines(sdict[roll].x[ok])) / sdict[roll].v[ok] ** 0.5

            spulls = N.concatenate([_pull(roll) for roll in sorted(sdict)
                                    if not roll == 'merged'])
            spulls = spulls[spulls != 0]

            bins = S.hist_bins(spulls)
            axh = make_axes_locatable(ax).append_axes("right", 2, pad=0.2)
            axh.hist(mpulls, bins=bins, density=True,  # log=True,
                     orientation='horizontal', histtype='stepfilled',
                     color=mpl.blue, label="µ={:+.2f}, σ={:.2f}".format(
                         mpulls.mean(), mpulls.std(ddof=1)))
            axh.hist(spulls, bins=bins, density=True,  # log=True,
                     orientation='horizontal', histtype='step', lw=1, ec='0.6',
                     label=f"µ={spulls.mean():+.2f}, σ={spulls.std(ddof=1):.2f}")
            # Add normal profile
            axh.plot(SS.norm.pdf(bins), bins, c=mpl.green, ls='--',
                     label="µ≡0, σ≡1")
            axh.xaxis.set_major_formatter(NullFormatter())
            axh.set_title("Pull distribution")
            axh.legend(loc='best', fontsize='small')

    ax.legend(loc='upper left', fontsize='small', frameon=True, framealpha=0.5)
    return fig