# Source code for inspec.spectrum

# -*- coding: utf-8 -*-
# Time-stamp: <2021-03-18 15:07:42 ycopin>

"""
.. _spectrum:

spectrum - Spectrum manipulations
=================================

Spectrum I/O, multi-roll merge, adaptive continuum fit, emission-line
detection and fit.

.. Warning:: All input wavelengths in Å.

.. autosummary::

   Table
   Spectrum
   SimSpectrum
"""

__author__ = "Yannick Copin <y.copin@ipnl.in2p3.fr>"

import os
import warnings

import numpy as N
import astropy.io.fits as F

from . import statistics as S


def formatwarning(message, category, filename, lineno, file=None, line=None):
    """
    Compact replacement for :func:`warnings.formatwarning`.

    Only the message itself is reported, prefixed with 'WARNING: '; the
    category, filename, line number and source line are ignored.
    (Verbose alternative:
    '%s:%s: %s: %s\\n' % (filename, lineno, category.__name__, message))
    """
    text = f'WARNING: {message}\n'
    return text


# Install the compact formatter for all subsequently emitted warnings.
warnings.formatwarning = formatwarning
# warnings.simplefilter("always")  # Uncomment to always print warnings
def get_extension(name, default=0):
    """
    Return (name, EXT) from 'name[ext]', using default *ext* if unspecified.

    Only a *trailing* bracketed specification is interpreted as an
    extension: brackets elsewhere in the path are part of the file name
    (e.g. 'run[1].fits' is a plain file name).

    :param name: file name, optionally suffixed with '[ext]'
    :param default: extension returned when *name* has no trailing '[ext]'
    :return: (bname, ext), with *ext* converted to int when possible,
        upper-cased string otherwise
    """
    import re

    # Decipher name and extension from name[EXT]; the extension itself
    # cannot contain brackets, and must close the name.
    match = re.fullmatch(r'(.*)\[([^\[\]]*)\]', name)
    if match:
        bname, ext = match.groups()
    else:
        bname, ext = name, default

    try:
        ext = int(ext)          # ext is an integer
    except ValueError:
        ext = ext.upper()       # ext is a string (extension name)

    return bname, ext
# Table class ##############################
class Table:
    """
    Simple FITS table handler.
    """

    def __init__(self, name, colX, colY, ext=1, keepHdr=False):
        """
        Read columns *colX* and *colY* from a FITS table extension.

        colY can be a single column name (self.ny=1) or a tuple of column
        names (self.ny=len(colY)). If self.ny>1, self.y is a
        (self.ny, self.npts) 2D-array.

        ext specifies the extension number or name, which can also be
        specified in name[ext].

        :raises KeyError: requested extension does not exist
        :raises NameError: requested columns do not exist
        """
        if name.endswith(']'):  # Decipher name and extension from name[EXT]
            name, ext = get_extension(name, default=ext)

        self.name = name
        self.basename = os.path.basename(self.name)
        self.colX = colX
        self.colY = colY
        self.ext = ext
        if isinstance(colY, (tuple, list)):
            self.ny = len(colY)
        else:
            self.ny = 1

        # Context manager: the file is closed even if a lookup fails
        with F.open(self.name, ignore_missing_end=True) as ffile:
            try:  # Missing name -> KeyError, out-of-range index -> IndexError
                hdu = ffile[ext]
            except (KeyError, IndexError):
                raise KeyError("Cannot find extension %s among %s" % (
                    str(ext),
                    ["%d:%s" % (i, h.name) for i, h in enumerate(ffile)]))

            if keepHdr:  # Make a copy of requested header
                self.hdr = hdu.header.copy()

            data = hdu.data
            try:  # Make a copy of requested columns
                self.x = data.field(self.colX).copy()
                if self.ny == 1:
                    self.y = data.field(self.colY).copy()
                else:
                    self.y = N.array([data.field(col).copy()
                                      for col in self.colY])
            except (KeyError, NameError):  # astropy raises KeyError
                raise NameError(
                    "Cannot find requested columns %s, %s among %s" %
                    (self.colX, self.colY, data.names))

        self.npts = len(self.x)

    def __str__(self):

        return "Table '%s[%s]': col. %s, %s, %d rows" % \
            (self.basename, str(self.ext), self.colX, self.colY, self.npts)

    @staticmethod
    def isSorted(x, increasing=None, strictly=False):
        """
        Check if x in [strictly] sorted (increasing=None), increasing
        (True) or decreasing (False).
        """
        xx = N.ravel(x)
        if increasing is None or increasing:
            if strictly:
                inc = (xx[1:] > xx[:-1])
            else:
                inc = (xx[1:] >= xx[:-1])
        if increasing is None or not increasing:
            if strictly:
                dec = (xx[1:] < xx[:-1])
            else:
                dec = (xx[1:] <= xx[:-1])

        if increasing is None:
            return inc.all() or dec.all()
        elif increasing:
            return inc.all()
        else:
            return dec.all()

    def sort(self):
        """
        Sort table in place by strictly increasing x.

        The y values follow x; for self.ny>1 the sort is applied along the
        last (row) axis of the (ny, npts) array.
        """
        if not self.isSorted(self.x, increasing=True, strictly=True):
            isort = self.x.argsort()
            self.x = self.x[isort]
            self.y = self.y[..., isort]  # Row axis, whatever self.ny

    def interpolate(self, x, cols=None, monotonic=False):
        """
        Interpolate table over input array x.

        For self.ny>1, cols specifies the column indices to be
        interpolated (all by default).

        :param monotonic: use a shape-preserving PCHIP interpolator
            instead of an interpolating cubic spline
        """
        if self.isSorted(self.x, increasing=True, strictly=True):
            xsort, ysort = self.x, self.y
        else:
            isort = self.x.argsort()
            xsort = self.x[isort]
            ysort = self.y[..., isort]  # Row axis, whatever self.ny

        if monotonic:
            from scipy.interpolate import PchipInterpolator as interpolator
        else:
            from scipy.interpolate import UnivariateSpline

            def interpolator(xx, yy):
                # s=0: spline goes through every point
                return UnivariateSpline(xx, yy, s=0)

        if self.ny == 1:
            return interpolator(xsort, ysort)(x)

        if cols is None:  # Interpolate all columns
            cols = range(self.ny)
        xx = N.asarray(x)
        res = N.array([interpolator(xsort, ysort[i])(xx) for i in cols])

        return res.squeeze()  # Keep it simple
# Spectrum class ##############################
class Spectrum:
    """
    FITS 1D-spectrum handler.

    Class to read and manage a spectrum typically from a FITS file
    (`NAXIS=1`), including the associated [co]variance from an extension
    or an external file.

    .. Warning:: this is not directly compatible w/ NISP simulated
       spectrum from `Simulations_140425`. See :class:`SimSpectrum`.
    """

    def __init__(self, name, varname=None, keepFits=True):
        """
        Spectrum initialization.

        :param name: FITS spectrum name, possibly 'name[ext]'; `None`
            creates a blank instance (see :meth:`from_vecs`)
        :param varname: external variance spectrum name, overriding any
            VARIANCE extension
        :param keepFits: keep the FITS file open in 'update' mode (needed
            by :meth:`writeto`); otherwise it is opened read-only and
            closed once read
        :raises AssertionError: incompatible external variance spectrum,
            or variance incompatible with covariance diagonal
        """
        self.name = name  # Generic name
        self._fits = None  # No attached FITS file (yet)
        if name is None:  # Blank instance
            return

        # Read signal [and variance/covariance extensions if any]
        self._readFits(name, mode='update' if keepFits else 'readonly')
        if not keepFits:
            self.close()

        if varname:  # Override variance extension if any
            if self.varname:  # Set by _readFits from var. extension
                warnings.warn("%s: VARIANCE extension overriden by %s" %
                              (name, varname), RuntimeWarning)
            self.varname = varname
            # The variance spectrum is only read, never written back: open
            # it read-only and release the file right away.
            v = Spectrum(varname, varname=None, keepFits=False)
            assert ((v.npts, v.start, v.step) ==
                    (self.npts, self.start, self.step)), \
                f"Incompatible variance spectrum '{varname}' wrt. to spectrum '{name}'"
            self.v = v.y.copy()
            # All other attributes and header keywords should be
            # essentially the same as for signal spectrum, no need to
            # keep them

        if self.hasCov and self.hasVar:  # Test variance vs. cov. coherence
            assert N.allclose(self.v, self.cov.diagonal()), \
                "%s: VARIANCE and COVARiance diagonal are incompatible" % name

    @property
    def hasVar(self):
        """True if a variance array is attached."""
        return hasattr(self, 'v') and self.v is not None

    @property
    def hasCov(self):
        """True if a covariance matrix is attached."""
        return hasattr(self, 'cov') and self.cov is not None

    def close(self):
        """Close FITS file (if any) and forget about it."""
        if self._fits is not None:
            self._fits.close()
            self._fits = None

    def __str__(self):

        s = "Spectrum %s: %d px [%.2f-%.2f A] @%.2f A/px" % \
            (self.name, self.npts, self.start, self.end, self.step)
        if self.hasCov:
            s += " with covariance"
        elif self.hasVar:
            s += " with variance"
        else:
            s += " (no [co]variance)"
        if self._fits is None:
            s += " (closed)"

        return s

    def _readFits(self, name, mode='readonly'):
        """
        Initialize a Spectrum from FITS spectrum name.

        'name' can be 'name[ext]', in which case only extension 'ext' is
        considered. Extensions VARIANCE (variance) and COVAR
        (lower-triangular covariance) are read if present.

        :raises IOError: requested extension cannot be read
        """
        # Decipher name and extension from name[EXT]
        self.filename, self.ext = get_extension(name)

        self._fits = F.open(self.filename, mode=mode, ignore_missing_end=True)
        extnames = [h.name for h in self._fits]  # "PRIMARY", etc.

        try:
            spec = self._fits[self.ext]  # Spectrum extension
        except (IndexError, KeyError):
            self.close()  # Do not leak the open file
            raise IOError("Cannot read extension %s in %s:%s" %
                          (self.ext, self.filename, extnames))

        self._hdr = spec.header.copy()  # Spectrum header
        self._hdr['CRPIX1'] = self._hdr.get('CRPIX1', 1)  # Make it mandatory

        self.npts = self._hdr['NAXIS1']
        self.step = self._hdr['CDELT1']
        self.start = self._hdr['CRVAL1'] - \
            (self._hdr['CRPIX1'] - 1) * self.step
        self.end = self.start + (self.npts - 1) * self.step
        self.x = N.linspace(self.start, self.end, self.npts)  # Wavelength
        self.y = spec.data.copy()  # Signal

        if 'VARIANCE' in extnames:  # Read VARIANCE extension
            # Headers are only inspected, never modified in place: in
            # 'update' mode, in-place edits would be written back to the
            # input file on close.
            self._checkWcs(self._fits['VARIANCE'].header, 'VARIANCE', naxis=1)
            self.varname = f"{self.filename}[VARIANCE]"
            self.v = self._fits['VARIANCE'].data.copy()  # Variance
        else:
            self.varname = None
            self.v = None

        if 'COVAR' in extnames:  # Read COVAR extension
            self._checkWcs(self._fits['COVAR'].header, 'COVAR', naxis=2)
            self.covname = f"{self.filename}[COVAR]"
            # Lower-tri. covariance
            self.cov = self._fits['COVAR'].data.copy()
            self.cov += N.triu(self.cov.T, 1)  # Reconstruct full cov.
        else:
            self.covname = None
            self.cov = None

    def _checkWcs(self, hdr, extname, naxis=1):
        """
        Warn if extension header *hdr* does not share the signal sampling.

        Each of the first *naxis* axes must have the same NAXIS, CDELT,
        CRVAL and CRPIX (defaulting to 1) as the signal spectrum. Explicit
        test (rather than `assert`) so it also runs under `python -O`.
        """
        ref = (self.npts, self.step, self._hdr['CRVAL1'], self._hdr['CRPIX1'])
        for i in range(1, naxis + 1):
            this = (hdr.get(f'NAXIS{i}'), hdr.get(f'CDELT{i}'),
                    hdr.get(f'CRVAL{i}'), hdr.get(f'CRPIX{i}', 1))
            if this != ref:
                warnings.warn(
                    "%s[%s]: header incompatible with primary header" %
                    (self.filename, extname), RuntimeWarning)
                break

    def readKey(self, keyword, default=None):
        """Read a single keyword, defaulting to *default* if any."""
        if default is None:
            return self._hdr[keyword]
        else:
            return self._hdr.get(keyword, default)

    def setKey(self, keywords=(), **kwargs):
        """
        Set keywords from `keywords=((key, val[, comment]),)` or kwargs
        `key=val` or `key=(val, comment)`.
        """
        for key in keywords:
            name, val = key[0], key[1:]  # name, (value, [comment])
            self._hdr[name.upper()] = val
        for key in kwargs:
            self._hdr[key.upper()] = kwargs[key]

    def resetHeader(self):
        """Delete all non-standard keywords."""
        # Delete all reference keywords (deleting keyword by keyword fails
        # on repeated cards such as COMMENT or HISTORY)
        self._hdr.clear()
        # Add mandatory keywords
        self._hdr['SIMPLE'] = True
        self._hdr['BITPIX'] = -64
        self._hdr['NAXIS'] = 1
        self._hdr['NAXIS1'] = self.npts
        self._hdr['CDELT1'] = self.step
        self._hdr['CRPIX1'] = 1
        self._hdr['CRVAL1'] = self.start

    def writeto(self, outName, force=False, hdrOnly=False,
                keywords=(), **kwargs):
        """
        Save Spectrum to new FITS-file.

        :param outName: output file name
        :param force: overwrite existing file without confirmation
        :param hdrOnly: do not update the signal data
        :param keywords: additional keywords `((key, val[, comment]),)`
        :param kwargs: additional keywords `key=val`; the special boolean
            flags `varext` and `covext` (default True) control whether
            VARIANCE/COVAR extensions are written, and are never stored
            as header keywords
        :raises IOError: FITS file has been closed
        """
        if self._fits is None:  # FITS file has been closed
            raise IOError("Cannot write to disk to closed FITS file")

        # Special flags, popped unconditionally so they never end up in
        # the header
        varext = kwargs.pop('varext', True)
        covext = kwargs.pop('covext', True)

        # Test output file presence *before* modifying the HDU list
        if force:
            clobber = True  # Overwrite existing file
        else:
            clobber = False  # DO NOT overwrite existing file...
            if os.path.exists(outName):
                ans = input(f"Overwrite output file '{outName}'? [N/y] ")
                if ans and ans[0].lower() == 'y':
                    clobber = True  # ...except if confirmed
                else:
                    warnings.warn(f"Output file {outName} not overwritten")
                    return

        spec = self._fits[self.ext]
        self._hdr['CRPIX1'] = self._hdr.get('CRPIX1', 1)  # Make it mandatory
        crval1 = self.start + (self._hdr['CRPIX1'] - 1) * self.step

        if not hdrOnly:  # Update FITS-data
            spec.data = N.array(self.y)

        # Update FITS-header
        self._hdr['NAXIS1'] = self.npts
        self._hdr['CDELT1'] = self.step
        self._hdr['CRVAL1'] = crval1

        # Remove any prior VARIANCE/COVARiance extensions if any:
        # they will then be re-added as needed
        extnames = [ext.name for ext in self._fits]
        for extname in ("VARIANCE", "COVAR"):
            if extname in extnames:
                i = self._fits.index_of(extname)
                self._fits.remove(self._fits[i])

        if self.hasVar and varext:  # Add variance spectrum as VARIANCE
            assert len(self.v) == self.npts, \
                "Variance extension (%d px) " \
                "is not coherent with signal (%d px)" % \
                (len(self.v), self.npts)
            var = F.ImageHDU(self.v, name='VARIANCE')
            var.header['CRVAL1'] = crval1
            var.header['CDELT1'] = self.step
            var.header['CRPIX1'] = self._hdr['CRPIX1']
            self._fits.append(var)

        if self.hasCov and covext:  # Add covariance array as COVAR
            assert self.cov.shape == (self.npts, self.npts), \
                "Covariance extension %s " \
                "is not coherent with signal (%d px)" % \
                (self.cov.shape, self.npts)
            # Add lower-tri COVARiance matrix as an image extension
            cov = F.ImageHDU(N.tril(self.cov), name='COVAR')
            cov.header['CRVAL1'] = cov.header['CRVAL2'] = crval1
            cov.header['CDELT1'] = cov.header['CDELT2'] = self.step
            cov.header['CRPIX1'] = cov.header['CRPIX2'] = \
                self._hdr['CRPIX1']
            self._fits.append(cov)

        # Update required keywords
        if keywords or kwargs:
            self.setKey(keywords=keywords, **kwargs)

        # Reset header from local copy self._hdr
        spec.header = self._hdr
        # Fix missing keywords (but should be OK); `overwrite` replaces
        # the deprecated `clobber` argument of astropy
        self._fits.writeto(outName, overwrite=clobber,
                           output_verify='silentfix')
        self.name = outName
        self.filename = outName

    @classmethod
    def read_spectrum(cls, arg, keepFits=True):
        """
        Return an initiated Spectrum from arg=name[, var_name],
        including proper deciphering of arg.
        """
        innames = arg.split(',')  # Check for spectrum, var_spectrum
        specname = innames[0]
        if len(innames) == 2:  # Explicit specName, var_specName
            varname = innames[1]
        else:  # Get variance name and test existence
            varname = cls.get_varname(specname, exists=True)

        # Set variance if any
        return cls(specname, varname=varname, keepFits=keepFits)

    @staticmethod
    def get_varname(specname, exists=False):
        """
        Return variance spectrum name associated to spectrum 'specname'.

        Assumes variance spectrum is located in same directory as
        spectrum. If exists, test if variance file exists or return None.
        """
        path, bname = os.path.split(specname)
        varname = 'var_' + bname  # Prefix bname with 'var_'
        outname = os.path.join(path, varname)  # Add path to varname
        if exists and not os.path.isfile(outname):
            outname = None

        return outname

    @classmethod
    def from_vecs(cls, x, y, name='', v=None):
        """
        Initialize spectrum from arbitrary vectors.

        Wavelengths with x[0] <= 1000 are assumed to be in µm and
        converted to Å. Signal (resp. variance) with RMS < 1e-14
        (resp. 1e-28) is assumed to be in erg/s/cm²/Å (resp. squared)
        and converted to 1e-17 erg/s/cm²/Å (resp. squared). Null
        variances are set to infinity.

        :param x: wavelength array [Å]
        :param y: signal array [erg/s/cm²/Å]
        :param name: name of spectrum
        :param v: variance array
        :raises AssertionError: incompatible vectors or non-uniformly
            sampled wavelengths
        """
        # Initialize spectrum from tables
        spec = cls(None)
        spec.name = name
        spec.x = N.copy(x)
        if spec.x[0] <= 1000:  # Otherwise supposedly already in Å
            warnings.warn("Converting input wavelengths from µm to Å")
            spec.x *= 1e4  # In µm, convert to Å

        spec.y = N.copy(y)
        rms = N.sqrt((spec.y ** 2).mean())
        if rms < 1e-14:  # Supposedly in erg/s/cm²/Å
            spec.y /= 1e-17  # Convert to 1e-17 erg/s/cm²/Å
        assert len(spec.y) == len(spec.x), "Incompatible vectors x and y"

        if v is not None:
            spec.v = N.copy(v)
            rms = N.sqrt((spec.v ** 2).mean())
            if rms < 1e-28:  # In (erg/s/cm²/Å)²
                spec.v /= 1e-34  # Convert to (1e-17 erg/s/cm²/Å)²
            spec.v[spec.v == 0] = N.inf  # Discard null-variance
            assert len(spec.v) == len(spec.x), "Incompatible vectors x and v"
        else:
            spec.v = None

        spec.npts = len(spec.x)
        spec.start, spec.end = spec.x[0], spec.x[-1]
        steps = N.diff(spec.x)
        spec.step = steps.mean()
        # Check constant wavelength sampling
        usteps = N.unique(N.round(steps, 4))  # Unique steps [AA]
        assert N.allclose(steps, spec.step, rtol=1e-3), \
            f"Wavelength vector x is not uniformly sampled: steps={usteps} A"
        if len(usteps) > 1:
            warnings.warn("Wavelength vector x is not strictly "
                          "uniformly sampled: steps={} A".format(usteps))

        spec._fits = None

        return spec

    @classmethod
    def from_simtables(cls, name, path='', minnpts=3):
        r"""
        Read spectrum stored in simulation table "`name`.fits" and
        associated noise table "`name`\_noise.fits". Return `None` if
        spectrum is too short.

        .. Note:: valid for `Simulations_140425`.
        """
        fullname = os.path.join(path, name)

        signal = Table(fullname + '.fits', 'Waves', 'Fluxes')
        if len(signal.x) < minnpts:  # Discard very short spectra
            warnings.warn("Spectrum '%s' is only %d px long, discarded" %
                          (name, len(signal.x)))
            return None
        signal.sort()

        noise = Table(fullname + '_noise.fits', 'Waves', 'Noise')
        noise.sort()

        assert N.allclose(signal.x, noise.x), \
            "Incompatible signal and noise wavelengths"

        # Initialize spectrum from table content
        spec = cls.from_vecs(x=signal.x[1:-1],  # 1st & last px can be shorter
                             y=signal.y[1:-1],
                             name=name,
                             # Convert stderr to variance
                             v=noise.y[1:-1] ** 2)

        return spec

    def extend(self, x):
        """
        Extend all spectra to match wavelength sampling `x`.

        Added pixels get a null signal and an infinite variance.

        :raises ValueError: `x` does not fully cover the spectrum range
        """
        start, end = x[0], x[-1]
        # Extend self on both side to match x (positive if too short)
        istart = int(round((self.start - start) / self.step))
        iend = int(round((end - self.end) / self.step))
        if not (istart >= 0 and iend >= 0):
            # Plain spectra have no 'roll' attribute: fall back to name
            raise ValueError(
                "incompatible wavelength range: {r}: {s.start}, {s.end} "
                "vs. target: {}, {}".format(
                    start, end, s=self, r=getattr(self, 'roll', self.name)))

        if istart > 0:  # Extend self on the left
            self.y = N.concatenate(([0] * istart, self.y))
            if self.v is not None:
                self.v = N.concatenate(([N.inf] * istart, self.v))
            self.x = N.concatenate(
                (self.start + N.arange(-istart, 0) * self.step, self.x))
            self.start = self.x[0]
            self.npts += istart

        if iend > 0:  # Extend self on the right
            self.y = N.concatenate((self.y, [0] * iend))
            if self.v is not None:
                self.v = N.concatenate((self.v, [N.inf] * iend))
            self.x = N.concatenate((
                self.x, self.end + N.arange(1, iend + 1) * self.step))
            self.end = self.x[-1]
            self.npts += iend

        # Tests
        assert self.npts == len(self.x) == len(x), \
            f"ARGH! {self.npts}, {len(self.x)}, {len(x)}"

    def plot(self, ax=None, errorband=False, sgfilter=None, **kwargs):
        """
        Plot spectrum.

        Null-signal pixels, and pixels with non-finite or non-positive
        variance, are masked.

        :param ax: plot axis
        :param errorband: add errorbands
        :param sgfilter: Savitzky-Golay filter parameters `(hsize, order)`
        :param kwargs: propagated to `plot` command
        :return: plot axis
        """
        if ax is None:
            import matplotlib.pyplot as P

            fig = P.figure()
            ax = fig.add_subplot(1, 1, 1,
                                 xlabel="Wavelength [Å]",
                                 ylabel="Flux [1e-17 erg/s/cm²/Å]")

        mask = (self.y == 0)
        if self.hasVar:
            mask |= ~N.isfinite(self.v) | (self.v <= 0)

        y = self.y.copy()
        if sgfilter:
            hsize, order = sgfilter
            y[mask] = N.nan
            y = S.savitzky_golay(y, hsize, order, derivative=0)

        x = N.ma.masked_where(mask, self.x)
        y = N.ma.masked_where(mask, y)
        line, = ax.plot(x, y, **kwargs)
        if errorband and self.hasVar:
            v = N.ma.masked_where(mask, self.v)
            dy = v ** 0.5
            ax.fill_between(x, y - dy, y + dy,
                            color=line.get_color(), alpha=0.5)

        return ax
class SimSpectrum(Spectrum):
    """
    Multi-roll NISP simulated spectrum (`Simulations_140425`).

    .. Warning:: Wavelengths in Å, fluxes in 1e-17 erg/s/cm²/Å.
    """

    @classmethod
    def from_idx(cls, idx, path=''):
        """
        Read spectra corresponding to index `idx`, return `{'roll': spectrum}`.

        :raises IOError: no (long enough) spectrum found for `idx`
        """
        from glob import glob

        specname = 'spectrum_%d_' % idx  # 'spectrum_####_'
        # Sorted for a reproducible roll order
        specnames = sorted(
            glob(os.path.join(path, f'{specname}roll???_??.fits')))
        if not specnames:
            raise IOError("Cannot find spectrum #%d in %s" % (idx, path))

        pathnames = [os.path.split(name)  # [ (path, name) ]
                     for name in specnames]
        # Check all spectra are coming from the same directory
        paths = {dirname for dirname, _ in pathnames}
        assert len(paths) == 1
        path = paths.pop()

        names = [os.path.splitext(name)[0]  # Remove ext
                 for _, name in pathnames]
        rolls = ['_'.join(name.split('_')[-2:])  # 'roll###_##'
                 for name in names]

        sdict = {}
        for roll in rolls:
            spec = cls.from_simtables(specname + roll, path=path)
            if spec is not None:
                spec.idx = idx
                spec.path = path
                spec.roll = roll
                sdict[roll] = spec

        if not sdict:  # All spectra discarded as too short
            raise IOError("Cannot find spectrum #%d in %s" % (idx, path))
        else:
            print("Spectrum #%d ('%s'): %d rolls" % (idx, path, len(rolls)))

        return sdict  # { 'roll###_##': spec }

    def __str__(self):

        return "{s.roll}: RMS={rms:.6g}, " \
            "{s.npts:3d} px [{s.start:.2f}-{s.end:.2f}] " \
            "step={step:.2f} AA/px, mstart={mstart:.2f} AA".format(
                s=self, rms=N.sqrt((self.y ** 2).mean()),
                step=self.step, mstart=(self.start % self.step))

    @staticmethod
    def select_refspec(sdict):
        """
        Select reference spectrum from `sdict = { 'roll': spectrum }`.

        Starting from the first roll (alphabetically), select the longest
        spectrum among those with start % step = 4.8 A.
        """
        rolls = sorted(key for key in sdict if key != 'merged')
        refroll = rolls[0]
        for roll in rolls[1:]:
            spec = sdict[roll]
            if round(spec.start % spec.step, 1) == 4.8 and \
               spec.npts > sdict[refroll].npts:
                refroll = roll

        return refroll

    @classmethod
    def merge_specs(cls, sdict, rtol=2e-2):
        """
        Merge spectra from `sdict = { 'roll': spectrum }` to produce a
        :class:`Spectrum`.

        Rolls with a wavelength step or start offset incompatible with the
        reference are discarded. Selected spectra are extended in place
        to a common wavelength domain, then averaged. The merged spectrum
        is also stored as `sdict['merged']` (a previous 'merged' entry is
        ignored as input, and replaced).
        """
        # Choose the longest spectrum as a temporary reference for
        # coherence checks
        rolls = sorted(roll for roll in sdict if roll != 'merged')
        refroll = rolls[N.argmax([sdict[roll].npts for roll in rolls])]
        refspec = sdict[refroll]  # Reference spectrum
        print("Ref.", refspec)
        selrolls = [refroll]  # Selected rolls

        # Coherence checks and extended wavelength domain computation
        x = refspec.x.copy()  # Extended wavelength domain
        for roll in rolls:
            if roll == refroll:
                continue
            spec = sdict[roll]
            if not N.isclose(spec.step, refspec.step):
                warnings.warn(
                    "incompatible wavelength steps: {}:{} vs. {}:{}".format(
                        roll, spec.step, refroll, refspec.step))
            elif rtol and not N.isclose(spec.start % spec.step,
                                        refspec.start % refspec.step,
                                        rtol=rtol):
                warnings.warn(
                    "incompatible wavelength starts: {}:{} vs. {}:{}".format(
                        roll, spec.start % spec.step,
                        refroll, refspec.start % refspec.step))
            else:
                selrolls.append(roll)  # Current roll selected for merging
                print("    ", spec)
                # Look for largest common wavelength domain
                istart = int(round((x[0] - spec.start) / refspec.step))
                iend = int(round((spec.end - x[-1]) / refspec.step))
                if istart > 0:  # Extend x to the left
                    x = N.concatenate(
                        (x[0] + N.arange(-istart, 0) * refspec.step, x))
                if iend > 0:  # Extend x to the right
                    x = N.concatenate(
                        (x, x[-1] + N.arange(1, iend + 1) * refspec.step))

        print(f"Common wavelength domain: {len(x)} px [{x[0]:.2f}, {x[-1]:.2f}]")

        # Extend spectra to common wavelength domain
        for roll in selrolls:
            sdict[roll].extend(x)

        # Merge selected spectra
        print("Merging %d/%d spectra" % (len(selrolls), len(rolls)))
        specs = [sdict[roll] for roll in selrolls]  # Selected spectra
        y, v = S.sample_mean(N.vstack([spec.y for spec in specs]),
                             N.vstack([spec.v for spec in specs]),
                             axis=0)
        spec = cls.from_vecs(refspec.x, y, v=v, name='merged')
        spec.rolls = selrolls  # Selected rolls
        spec.roll = 'merged'

        # Update sdict with a new 'merged' key
        sdict['merged'] = spec

        return spec  # Merged spectrum

    def writeto(self, name, path='', clobber=True, keywords=None):
        """
        Minimal FITS writer. No propagation of header.

        :param name: output file name, relative to *path*
        :param path: output directory (default: `self.path` if any)
        :param clobber: overwrite existing file
        :param keywords: additional `{key: value}` primary header keywords
        """
        shdu = F.PrimaryHDU(self.y)  # Signal in primary header
        hdr = shdu.header
        hdr['NAXIS1'] = self.npts
        hdr['CDELT1'] = self.step
        hdr['CRPIX1'] = 1
        hdr['CRVAL1'] = self.start
        hdr['CUNIT1'] = 'AA'
        hdr['FXUNIT'] = '1e-17 erg/s/cm2/AA'

        # Update header with additional keywords
        if hasattr(self, 'rolls'):  # Rolls used in the merge
            for i, roll in enumerate(self.rolls, start=1):
                hdr['ROLLS%d' % i] = roll
        for key, val in (keywords or {}).items():
            hdr[key.upper()] = val

        hdulist = F.HDUList([shdu])

        if self.v is not None:
            vhdu = F.ImageHDU(self.v)
            hdr = vhdu.header
            hdr['NAXIS1'] = self.npts
            hdr['CDELT1'] = self.step
            hdr['CRPIX1'] = 1
            hdr['CRVAL1'] = self.start
            hdr['CUNIT1'] = 'AA'
            hdr['FXUNIT'] = '(1e-17 erg/s/cm2/AA)**2'
            hdr['EXTNAME'] = 'Variance'
            hdulist.append(vhdu)

        if not path and hasattr(self, 'path'):
            path = self.path
        specname = os.path.join(path, name)
        print(f"Saving spectrum '{self.name}' in '{specname}'")
        # `overwrite` replaces the deprecated `clobber` argument of astropy
        hdulist.writeto(specname, overwrite=clobber)
# Utility functions ==============================
def fit_adaptive(spec, pmax=0.05, dmax=3, nmax=3, verbose=False):
    """
    Adaptive continuum + emission-line fit of spectrum *spec*.

    Thin wrapper around :func:`S.fit_adaptive_spectrum`, fed with the
    wavelengths, signal and variance of *spec*; the fitting options are
    forwarded unchanged.

    :return: tuple `(cont, lines)` of fitted models
    """
    options = dict(pmax=pmax, dmax=dmax, nmax=nmax, verbose=verbose)
    result = S.fit_adaptive_spectrum(spec.x, spec.y, spec.v, **options)
    # Unpacking enforces a 2-element result and always returns a tuple
    continuum, emission = result

    return continuum, emission
def plot_specs(sdict, sgfilter=None, model=None, metadata=None,
               emilines=None, fig=None, add_pull=True):
    """
    Plot spectra from `sdict = { 'roll': spectrum }`.

    :param sdict: `{ 'roll': spectrum }`, possibly including a 'merged' key
    :param sgfilter: Savitzky-Golay filter parameters `(hsize, order)`
        applied to an extra plot of the merged spectrum
    :param model: `(cont, lines)` fitted models, overplotted on the merged
        spectrum (if any)
    :param metadata: dict with at least 'idx' and 'z' keys, used for the
        title and to redshift *emilines*
    :param emilines: `{name: rest-frame wavelength [Å]}`; names ending
        with '_' are drawn but not annotated
    :param fig: figure to draw into (created if None)
    :param add_pull: add a pull-distribution histogram (requires *model*)
    :return: figure
    """
    import scipy.stats as SS
    import matplotlib.pyplot as P
    from matplotlib.transforms import blended_transform_factory
    from . import mpl
    from itertools import cycle

    if emilines is None:  # Avoid a shared mutable default
        emilines = {}

    linecycler = cycle(["-", "--", "-.", ":"])

    if fig is None:
        fig = P.figure(figsize=(14, 5))
    ax = fig.add_subplot(1, 1, 1,
                         xlabel="Wavelength [Å]",
                         ylabel="Flux [1e-17 erg/s/cm²/Å]")

    # Merged spectrum (if any)
    if 'merged' in sdict:
        merged = sdict['merged']  # Merged spectrum
        merged.plot(ax, c=mpl.blue, lw=2, zorder=3,
                    label='Merged (%d)' % len(merged.rolls), errorband=True)
        if sgfilter:
            merged.plot(ax, c=mpl.blue, lw=2, zorder=3, alpha=0.5,
                        label=f'SG-filtered {sgfilter}', sgfilter=sgfilter)
        rolls = merged.rolls  # Selected rolls
    else:
        rolls = []  # No selected rolls

    # Raw spectra
    for roll in sorted(sdict):
        if roll == 'merged':
            continue
        spec = sdict[roll]
        color = '0.6' if roll in rolls else '0.3'  # Highlight unselected rolls
        spec.plot(ax, c=color, lw=1, errorband=False, ls=next(linecycler),
                  label=f"{roll} ({spec.start % spec.step:.2f})")

    # Spectrum metadata
    if metadata:
        ax.set_title("Spectrum #%d, z=%.3f" % (metadata['idx'], metadata['z']))
        lmin, lmax = ax.get_xlim()
        # x in data coordinates, y in axes fraction
        trans = blended_transform_factory(ax.transData, ax.transAxes)
        for name, lbda in emilines.items():  # Redshifted emission lines
            zlbda = lbda * (1 + metadata['z'])  # [Å]
            if lmin < zlbda < lmax:
                ax.axvline(zlbda, c=mpl.green)
                if not name.endswith('_'):
                    ax.annotate(name, (zlbda, 0.9), xycoords=trans,
                                rotation='vertical')

    # Model
    if model and 'merged' in sdict:
        mask = ~N.isfinite(merged.v)
        x = N.ma.masked_where(mask, merged.x)
        y = N.ma.masked_where(mask, merged.y)
        v = N.ma.masked_where(mask, merged.v)
        cont, lines = model
        ycont = cont(x)  # Continuum
        ymodel = ycont + lines(x)
        label = "Fit: d°={}/{}, {}/{} lines".format(
            cont.degree, cont.dmax, len(lines.lines), lines.nmax)
        if lines.pnext:
            label += f" (next: {lines.pnext:.1%})"
        ax.plot(x, ymodel, c=mpl.red, lw=1, zorder=4, label=label)
        for i, line in enumerate(lines.lines):  # Individual emission lines
            if len(lines.lines) > 1:
                mu = getattr(lines, f'mean_{i}').value
                ax.plot(x, ycont + line(x),
                        c=mpl.red, lw=1, ls='--', zorder=3)
            else:
                mu = getattr(lines, 'mean').value
            ax.axvline(mu, c=mpl.red, ls='--')  # Adjusted position
        ax.set_title(ax.get_title() +
                     f" (pmax={cont.pmax:.1%}/{lines.pmax:.1%})")

        if add_pull:  # Pull histogram
            from mpl_toolkits.axes_grid1 import make_axes_locatable
            from matplotlib.ticker import NullFormatter

            # Pulls from merged spectrum: explicitly drop masked pixels
            # (non-finite variance) before histogramming
            mpulls = N.ma.compressed((y - ymodel) / v ** 0.5)
            mpulls = mpulls[mpulls != 0]

            # Pulls from raw spectra
            def _pull(roll):
                """Pulls of roll *roll* wrt. model, over finite-variance px."""
                spec = sdict[roll]
                ok = N.isfinite(spec.v)
                return (spec.y[ok] - cont(spec.x[ok]) - lines(spec.x[ok])) \
                    / spec.v[ok] ** 0.5

            spulls = N.concatenate([_pull(roll) for roll in sorted(sdict)
                                    if not roll == 'merged'])
            spulls = spulls[spulls != 0]

            bins = S.hist_bins(spulls)

            axh = make_axes_locatable(ax).append_axes("right", 2, pad=0.2)
            axh.hist(mpulls, bins=bins, density=True,  # log=True,
                     orientation='horizontal', histtype='stepfilled',
                     color=mpl.blue,
                     label="µ={:+.2f}, σ={:.2f}".format(
                         mpulls.mean(), mpulls.std(ddof=1)))
            axh.hist(spulls, bins=bins, density=True,  # log=True,
                     orientation='horizontal', histtype='step',
                     lw=1, ec='0.6',
                     label=f"µ={spulls.mean():+.2f}, σ={spulls.std(ddof=1):.2f}")
            # Add normal profile
            axh.plot(SS.norm.pdf(bins), bins, c=mpl.green, ls='--',
                     label="µ≡0, σ≡1")
            axh.xaxis.set_major_formatter(NullFormatter())
            axh.set_title("Pull distribution")
            axh.legend(loc='best', fontsize='small')

    ax.legend(loc='upper left', fontsize='small', frameon=True, framealpha=0.5)

    return fig