Skip to content

Commit

Permalink
Proor of concept on how to fix issues PyWavelets#531, PyWavelets#535
Browse files Browse the repository at this point in the history
  • Loading branch information
amanita-citrina committed Oct 12, 2020
1 parent 196b5d3 commit 6a5a6bc
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 21 deletions.
45 changes: 26 additions & 19 deletions pywt/_cwt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from math import floor, ceil
from scipy import interpolate

from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
Wavelet, _check_dtype)
from ._functions import integrate_wavelet, scale2frequency
from ._functions import evaluate_wavelet, scale2frequency


__all__ = ["cwt"]
Expand Down Expand Up @@ -123,13 +124,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
precision = 10
int_psi, x = integrate_wavelet(wavelet, precision=precision)
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
psi, x = evaluate_wavelet(wavelet, precision=precision)
psi = np.conj(psi) if wavelet.complex_cwt else psi

# convert int_psi, x to the same precision as the data
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
int_psi = np.asarray(int_psi, dtype=dt_psi)
# convert psi, x to the same precision as the data
dt_psi = dt_cplx if psi.dtype.kind == 'c' else dt
psi = np.asarray(psi, dtype=dt_psi)
x = np.asarray(x, dtype=data.real.dtype)
# FIXME: The original wavelet function could be used here, but
# interpolation is computationally more efficient.
wavefun = interpolate.interp1d(x, psi, kind='cubic', assume_sorted=True)

if method == 'fft':
size_scale0 = -1
Expand All @@ -146,41 +150,44 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
data = data.reshape((-1, data.shape[-1]))

for i, scale in enumerate(scales):
step = x[1] - x[0]
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
j = j.astype(int) # floor
if j[-1] >= int_psi.size:
j = np.extract(j < int_psi.size, j)
int_psi_scale = int_psi[j][::-1]
# FIXME: Boundary points might be discarded erroneously
if np.sign(x[0])*np.sign(x[-1])<0:
# Wavelet is sampled at 0.0 if the range includes it
xsl = np.arange(0.0, x[0], -1.0/scale)
xsr = np.arange(0.0, x[-1], 1.0/scale)
xs = np.concatenate((xsl[:0:-1], xsr))
else:
xs = np.arange(x[0], x[-1], 1.0/scale)
psi_scale = wavefun(xs)[::-1]

if method == 'conv':
if data.ndim == 1:
conv = np.convolve(data, int_psi_scale)
conv = np.convolve(data, psi_scale)
else:
# batch convolution via loop
conv_shape = list(data.shape)
conv_shape[-1] += int_psi_scale.size - 1
conv_shape[-1] += psi_scale.size - 1
conv_shape = tuple(conv_shape)
conv = np.empty(conv_shape, dtype=dt_out)
for n in range(data.shape[0]):
conv[n, :] = np.convolve(data[n], int_psi_scale)
conv[n, :] = np.convolve(data[n], psi_scale)
else:
# The padding is selected for:
# - optimal FFT complexity
# - to be larger than the two signals length to avoid circular
# convolution
size_scale = next_fast_len(
data.shape[-1] + int_psi_scale.size - 1
data.shape[-1] + psi_scale.size - 1
)
if size_scale != size_scale0:
# Must recompute fft_data when the padding size changes.
fft_data = fftmodule.fft(data, size_scale, axis=-1)
size_scale0 = size_scale
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
fft_wav = fftmodule.fft(psi_scale, size_scale, axis=-1)
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
conv = conv[..., :data.shape[-1] + psi_scale.size - 1]

coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
coef = conv / np.sqrt(scale)
if out.dtype.kind != 'c':
coef = coef.real
# transform axis is always -1 due to the data reshape above
Expand Down
55 changes: 53 additions & 2 deletions pywt/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet


__all__ = ["integrate_wavelet", "central_frequency", "scale2frequency", "qmf",
"orthogonal_filter_bank",
__all__ = ["integrate_wavelet", "evaluate_wavelet", "central_frequency",
"scale2frequency", "qmf", "orthogonal_filter_bank",
"intwave", "centrfrq", "scal2frq", "orthfilt"]


Expand Down Expand Up @@ -119,6 +119,57 @@ def integrate_wavelet(wavelet, precision=8):
return _integrate(psi_d, step), _integrate(psi_r, step), x


def evaluate_wavelet(wavelet, precision=8):
"""
Evaluate `psi` wavelet function between lower and upper bound.
Parameters
----------
wavelet : Wavelet instance or str
Wavelet to evaluate. If a string, should be the name of a wavelet.
precision : int, optional
Number of wavelet function points computed with Wavelet's
wavefun(level=precision) method (default: 8).
Returns
-------
[psi, x] :
for orthogonal wavelets
[psi_d, psi_r, x] :
for other wavelets
Examples
--------
>>> from pywt import Wavelet, evaluate_wavelet
>>> wavelet1 = Wavelet('db2')
>>> [psi, x] = evaluate_wavelet(wavelet1, precision=5)
>>> wavelet2 = Wavelet('bior1.3')
>>> [psi_d, psi_r, x] = evaluate_wavelet(wavelet2, precision=5)
"""

if type(wavelet) in (tuple, list):
psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1])
return psi, x
elif not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
wavelet = DiscreteContinuousWavelet(wavelet)

functions_approximations = wavelet.wavefun(precision)

if len(functions_approximations) == 2: # continuous wavelet
psi, x = functions_approximations
return psi, x

elif len(functions_approximations) == 3: # orthogonal wavelet
phi, psi, x = functions_approximations
return psi, x

else: # biorthogonal wavelet
phi_d, psi_d, phi_r, psi_r, x = functions_approximations
return psi_d, psi_r, x


def central_frequency(wavelet, precision=8):
"""
Computes the central frequency of the `psi` wavelet function.
Expand Down

0 comments on commit 6a5a6bc

Please sign in to comment.