from __future__ import print_function
import numpy as np
from sympy import Symbol,Function
import sympy
from pixell import fft as efft, enmap
import os,sys,warnings
from . import utils
"""
Routines to reduce and evaluate symbolic mode coupling integrals
"""
# Built-in special symbols: Cartesian components and magnitudes of the
# wavevectors vec(l)_1, vec(l)_2 and vec(L):
#   l1x, l1y -> components of vec(l)_1,  l1 -> |vec(l)_1|
#   l2x, l2y -> components of vec(l)_2,  l2 -> |vec(l)_2|
#   Lx, Ly   -> components of vec(L),    L  -> |vec(L)|
l1x, l1y, l2x, l2y, l1, l2, Lx, Ly, L = sympy.symbols('l1x l1y l2x l2y l1 l2 Lx Ly L')

# Dot products vec(L).vec(l)_1 and vec(L).vec(l)_2
Ldl1 = Lx * l1x + Ly * l1y
Ldl2 = Lx * l2x + Ly * l2y

# Cross-product terms used for curl.
# NOTE(review): Ly*l1x - Lx*l1y is the z-component of vec(l)_1 x vec(L), i.e. minus
# that of vec(L) x vec(l)_1; the sign convention is presumably intentional.
Lxl1 = Ly * l1x - Lx * l1y
Lxl2 = Ly * l2x - Lx * l2y

# cos(2 theta_12) and sin(2 theta_12) for polarization, built by utils.substitute_trig
# from the l1/l2 components and magnitudes above.
cos2t12, sin2t12 = utils.substitute_trig(l1x, l1y, l2x, l2y, l1, l2)
# Custom symbol wrapper
def e(symbol):
    """Return a sympy Symbol with the given name.

    Intended for user-defined symbols. ``integrate`` treats symbols whose names
    end in ``_l1`` / ``_l2`` as functions of vec(l)_1 / vec(l)_2 respectively.
    """
    # TODO: add exceptions if symbol doesn't correspond to key structure
    custom = Symbol(symbol)
    return custom
def ifft(x, normalize='phys'):
    """Inverse FFT of ``x`` via ``enmap.ifft`` (``normalize='phys'`` by default)."""
    return enmap.ifft(x, normalize=normalize)


def fft(x, normalize='phys'):
    """Forward FFT of ``x`` via ``enmap.fft`` (``normalize='phys'`` by default)."""
    return enmap.fft(x, normalize=normalize)


# Numerically evaluate a sympy expression against a feed_dict of arrays
# (alias of utils.evaluate).
evaluate = utils.evaluate
def factorize_2d_convolution_integral(expr, l1funcs=None, l2funcs=None, groups=None, validate=True):
    """Reduce a sympy expression into a sum of products of separable factors.

    The expression in variables l1x, l1y, l2x, l2y, l1, l2 is reduced into a sum of
    terms. Each term is a product of three factors: one depending only on vec(l1), one
    depending only on vec(l2), and one depending on neither. If the expression appeared
    as the integrand in an integral over vec(l1), where vec(l2) = vec(L) - vec(l1), this
    reduction allows the integral to be evaluated as a function of vec(L) using FFTs
    instead of as a convolution.

    Parameters
    ----------
    expr: :obj:`sympy.core.expr.Expr`
        The full Sympy expression to reduce to a sum of products of functions of l1 and l2.
    l1funcs: list, optional
        Symbols that are functions of l1. The built-ins l1x, l1y and l1 are always
        included. The caller's list is not modified.
    l2funcs: list, optional
        Symbols that are functions of l2. The built-ins l2x, l2y and l2 are always
        included. The caller's list is not modified.
    groups: list, optional
        Group all terms such that they have common factors of the provided list of
        expressions to reduce the number of FFTs. A term belongs to ``groups[i]`` if
        substituting that expression by a dummy symbol ``d`` in the term's
        l1/l2-independent factor yields ``s`` with ``s/d`` free of ``d``.
    validate: boolean, optional
        Whether to check that the final expression and the original agree. Defaults to True.

    Returns
    -------
    terms : list of dict
        One dict per non-zero term, with the following keys:

        - 'l1' and 'l2': the l1 and l2 factors, "homogenized". Every l1x/l2x is replaced
          by Lx, every l1y/l2y by Ly, and l1/l2 by L. Each custom symbol has its last
          three characters (e.g. the ``_l1`` suffix) stripped from its name.
        - 'l1index' and 'l2index': positions of those factors in unique_l1s and unique_l2s.
        - 'other': the factor that depends on neither l1 nor l2.
    unique_l1s : list
        Distinct homogenized l1 factors.
    unique_l2s : list
        Distinct homogenized l2 factors.
    ogroups : list of int or None
        For each entry of ``terms``, the index of its group (None if ``groups`` is None).
    ogroup_weights : list of float or None
        For each entry of ``terms``, the numerical coefficient of its 'other' factor,
        obtained by setting all its free symbols to 1 (None if ``groups`` is None).
    ogroup_symbols : sympy Matrix or None
        (len(groups), 1) column holding each group's 'other' factor with the
        coefficient divided out. Groups that received no term keep the value 1.

    Raises
    ------
    AssertionError
        If ``validate`` is True and the factorization does not reproduce ``expr``,
        or if two terms of the same group disagree on the group's symbolic factor.
    ValueError
        If ``groups`` is given and a term matches no group or more than one group.
    """
    # Generic message if validation fails
    val_fail_message = "Validation failed. This expression is likely not reducible to FFT form."

    def _require(condition, message=val_fail_message):
        # Explicit raise instead of `assert` so checks still run under `python -O`;
        # AssertionError is kept so callers catching it keep working.
        if not condition:
            raise AssertionError(message)

    # Copy so that adding the built-in symbols does not mutate the caller's lists.
    l1funcs = list(l1funcs) if l1funcs is not None else []
    l2funcs = list(l2funcs) if l2funcs is not None else []
    for sym in (l1x, l1y, l1):
        if sym not in l1funcs:
            l1funcs.append(sym)
    for sym in (l2x, l2y, l2):
        if sym not in l2funcs:
            l2funcs.append(sym)
    # Custom (non built-in) symbols, renamed during homogenization.
    ofuncs1 = set(l1funcs) - {l1x, l1y, l1}
    ofuncs2 = set(l2funcs) - {l2x, l2y, l2}
    ofuncs = ofuncs1 | ofuncs2

    def homogenize(inexp):
        # Map both l1 and l2 variables onto the common L variables, and strip the
        # 3-character suffix (e.g. "_l1") from custom symbol names.
        outexp = inexp.subs([[l1x, Lx], [l2x, Lx], [l1y, Ly], [l2y, Ly], [l1, L], [l2, L]])
        for ofunc in ofuncs:
            outexp = outexp.subs(ofunc, Symbol(str(ofunc)[:-3]))
        return outexp

    def get_group(inexp):
        if groups is None:
            return 0
        d = Symbol('dummy')
        index = None
        for i, group in enumerate(groups):
            s = inexp.subs(group, d)
            # The group expression appears as an overall first-power factor.
            if not (s / d).has(d):
                if index is not None:
                    raise ValueError(
                        "Groups don't seem to be mutually exclusive: {} matches both {} and {}."
                        .format(inexp, groups[index], group))
                index = i
        if index is None:
            raise ValueError("Couldn't associate a group for {}".format(inexp))
        return index

    if groups is not None:
        ogroups = []
        ogroup_weights = []
        ogroup_symbols = sympy.ones(len(groups), 1)
        # Groups whose symbol has been set; a symbol that legitimately equals 1
        # must still be checked against later terms of the same group.
        assigned_groups = set()
    else:
        ogroups = ogroup_weights = ogroup_symbols = None

    # Expand so that the top-level operation is Add, i.e. A + B + C + ...;
    # a single term comes back as a non-Add expression.
    expr = sympy.expand(expr)
    arguments = expr.args if expr.func is sympy.Add else [expr]

    terms = []
    prodterms = []
    unique_l1s = []
    unique_l2s = []
    for arg in arguments:
        # Split each term into l1-dependent, l2-dependent and remaining factors.
        temp, ll1terms = arg.as_independent(*l1funcs, as_Mul=True)
        loterms, ll2terms = temp.as_independent(*l2funcs, as_Mul=True)
        if any(x == 0 for x in (ll1terms, ll2terms, loterms)):
            continue
        # Group ffts
        if groups is not None:
            gindex = get_group(loterms)
            ogroups.append(gindex)
            fsyms = loterms.free_symbols
            ocoeff = loterms.evalf(subs=dict.fromkeys(fsyms, 1))
            ogroup_weights.append(float(ocoeff))
            gsymbol = loterms / ocoeff
            if gindex not in assigned_groups:
                ogroup_symbols[gindex] = gsymbol
                assigned_groups.add(gindex)
            else:
                _require(ogroup_symbols[gindex] == gsymbol, "Error validating group membership")
        hl1 = homogenize(ll1terms)
        hl2 = homogenize(ll2terms)
        if hl1 not in unique_l1s:
            unique_l1s.append(hl1)
        if hl2 not in unique_l2s:
            unique_l2s.append(hl2)
        terms.append({
            'l1': hl1,
            'l2': hl2,
            'l1index': unique_l1s.index(hl1),
            'l2index': unique_l2s.index(hl2),
            'other': loterms,
        })
        if validate:
            # The factors of this term must give back the original term ...
            products = ll1terms * ll2terms * loterms
            _require(sympy.simplify(products - arg) == 0)
            prodterms.append(products)
            # ... and must not include symbols they shouldn't.
            _require(not any(ll1terms.has(x) for x in l2funcs))
            _require(not any(ll2terms.has(x) for x in l1funcs))
            _require(not any(loterms.has(x) for x in l1funcs))
            _require(not any(loterms.has(x) for x in l2funcs))
    # Check that the sum of products of the final form matches the original expression.
    if validate:
        _require(sympy.simplify(expr - sympy.Add(*prodterms)) == 0)
    return terms, unique_l1s, unique_l2s, ogroups, ogroup_weights, ogroup_symbols
def integrate(shape, wcs, feed_dict, expr, xmask=None, ymask=None, cache=True, validate=True, groups=None, physical_units=True):
    """
    Integrate an arbitrary expression after factorizing it.

    The expression is reduced with ``factorize_2d_convolution_integral``. For each term,
    the l1 and l2 factors are masked and inverse-FFT'd, multiplied in real space, and
    forward-FFT'd. The result is then multiplied by the term's remaining factor.

    Parameters
    ----------
    shape : tuple
        The shape of the array for the geometry of the footprint. Typically
        (...,Ny,Nx) for Ny pixels in the y-direction and Nx in the x-direction.
    wcs : :obj:`astropy.wcs.wcs.WCS`
        The wcs object completing the specification of the geometry of the footprint.
    feed_dict: dict
        Mapping from names of custom symbols to numpy arrays. NOTE: modified in place;
        the keys 'L', 'Ly' and 'Lx' are (re)set to ``enmap.modlmap`` and the two
        ``enmap.lmap`` components for this geometry.
    expr: :obj:`sympy.core.expr.Expr`
        A sympy expression containing recognized symbols (see docs). Custom symbols
        whose names end in ``_l1`` / ``_l2`` are treated as functions of l1 / l2.
    xmask: (Ny,Nx) ndarray,optional
        Fourier space 2D mask for the l1 part of the integral. Defaults to ones (with a warning).
    ymask: (Ny,Nx) ndarray, optional
        Fourier space 2D mask for the l2 part of the integral. Defaults to ones (with a warning).
    cache: boolean, optional
        Whether to store in memory and reuse repeated terms. Defaults to true.
    validate: boolean,optional
        Whether to check that the final expression and the original agree. Defaults to True.
    groups: list,optional
        Group all terms such that they have common factors of the provided list of
        expressions to reduce the number of FFTs.
    physical_units: boolean,optional
        If True (default) the result is returned as is; if False it is divided by the
        pixel area ``np.prod(enmap.pixshape(shape, wcs))``.

    Returns
    -------
    result : (Ny,Nx) ndarray
        The numerical (complex) result of the integration of the expression after
        factorization.
    """
    # Geometry
    modlmap = enmap.modlmap(shape, wcs)
    lymap, lxmap = enmap.lmap(shape, wcs)
    pixarea = np.prod(enmap.pixshape(shape, wcs))
    feed_dict['L'] = modlmap
    feed_dict['Ly'] = lymap
    feed_dict['Lx'] = lxmap
    # tuple() so that (ngroups,) + shape below also works for list input.
    shape = tuple(shape[-2:])
    ones = enmap.ones(shape, wcs, dtype=np.float32)
    if xmask is None:
        warnings.warn("No xmask specified; assuming all ones. This is probably not going to end well.")
        xmask = ones
    if ymask is None:
        warnings.warn("No ymask specified; assuming all ones. This is probably not going to end well.")
        ymask = ones

    # Expression: classify custom symbols by their _l1 / _l2 suffix.
    l1funcs = []
    l2funcs = []
    for sym in expr.free_symbols:
        name = str(sym)
        if name.endswith("_l1"):
            l1funcs.append(sym)
        elif name.endswith("_l2"):
            l2funcs.append(sym)
    integrands, ul1s, ul2s, ogroups, ogroup_weights, ogroup_symbols = \
        factorize_2d_convolution_integral(expr, l1funcs=l1funcs, l2funcs=l2funcs,
                                          validate=validate, groups=groups)

    def _fft(x):
        return fft(enmap.enmap(x + 0j, wcs))

    def _ifft(x):
        return ifft(enmap.enmap(x + 0j, wcs))

    def _l1_part(u1):
        # Evaluate the l1 factor on the 2D grid, apply xmask, go to real space.
        return _ifft(evaluate(u1, feed_dict) * ones * xmask)

    def _l2_part(u2):
        # Evaluate the l2 factor on the 2D grid, apply ymask, go to real space.
        return _ifft(evaluate(u2, feed_dict) * ones * ymask)

    if cache:
        cached_u1s = [_l1_part(u1) for u1 in ul1s]
        cached_u2s = [_l2_part(u2) for u2 in ul2s]

    def get_l1l2(term):
        if cache:
            return cached_u1s[term['l1index']], cached_u2s[term['l2index']]
        return _l1_part(term['l1']), _l2_part(term['l2'])

    val = 0.
    if ogroups is None:
        # One forward FFT per term.
        for term in integrands:
            ifft1, ifft2 = get_l1l2(term)
            ot2d = evaluate(term['other'], feed_dict) * ones
            val += ot2d * _fft(ifft1 * ifft2)
    else:
        # Accumulate weighted real-space products per group, then one forward FFT per group.
        vals = np.zeros((len(ogroup_symbols),) + shape, dtype=np.float32) + 0j
        for i, term in enumerate(integrands):
            ifft1, ifft2 = get_l1l2(term)
            vals[ogroups[i], ...] += ifft1 * ifft2 * ogroup_weights[i]
        for i, gsymbol in enumerate(ogroup_symbols):
            ot2d = evaluate(gsymbol, feed_dict) * ones
            val += ot2d * _fft(vals[i, ...])
    mul = 1 if physical_units else 1. / pixarea
    return val * mul