Skip to content

Commit

Permalink
Merge pull request #54 from JeffFessler/master
Browse files Browse the repository at this point in the history
Support differing x,y eltypes for dwt
  • Loading branch information
gummif authored Jul 8, 2020
2 parents d044058 + 72cbb3d commit 7f2da5b
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 33 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ julia:
- 1.0
- 1.1
- 1.2
- 1.4
- nightly
matrix:
allow_failures:
Expand Down
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
name = "Wavelets"
uuid = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
author = ["Gudmundur Adalsteinsson "]
version = "0.9.0"
version = "0.9.1"

[compat]
DSP = "0.5.1"
FFTW = "0.2.4"
Reexport = "0.2.0"
SpecialFunctions = "0.7.1"
julia = " 1"

[deps]
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
Expand Down
11 changes: 8 additions & 3 deletions src/mod/Transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ const DWTArray = AbstractArray
const WPTArray = AbstractVector
const ValueType = Union{AbstractFloat, Complex}

const FVector = StridedVector # e.g., work space vectors

# DWT

"""
Expand Down Expand Up @@ -114,8 +116,8 @@ for (Xwt, Xwt!, _Xwt!, fw) in ((:dwt, :dwt!, :_dwt!, true),
y = Array{T}(undef, size(x))
return ($_Xwt!)(y, x, filter, L, $fw)
end
function ($Xwt!)(y::DWTArray{T}, x::DWTArray{T}, filter::OrthoFilter,
L::Integer=maxtransformlevels(x)) where T<:ValueType
function ($Xwt!)(y::DWTArray{<:ValueType}, x::DWTArray{<:ValueType}, filter::OrthoFilter,
L::Integer=maxtransformlevels(x))
return ($_Xwt!)(y, x, filter, L, $fw)
end
# lifting
Expand Down Expand Up @@ -485,7 +487,10 @@ end # for

# Array with shared memory
function unsafe_vectorslice(A::Array{T}, i::Int, n::Int) where T
return unsafe_wrap(Array, pointer(A, i), n)::Vector{T}
return unsafe_wrap(Array, pointer(A, i), n)::Vector{T}
end
function unsafe_vectorslice(A::StridedArray{T}, i::Int, n::Int) where T
return @view A[i:(i-1+n)]
end

# linear indices of start of rows/cols/planes
Expand Down
4 changes: 2 additions & 2 deletions src/mod/Util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,15 @@ function merge!(b::AbstractArray{T}, ib::Integer, incb::Integer, a::AbstractVect
end


function stridedcopy!(b::AbstractVector{T}, a::AbstractArray{T}, ia::Integer, inca::Integer, n::Integer) where T<:Number
function stridedcopy!(b::AbstractVector{<:Number}, a::AbstractArray{<:Number}, ia::Integer, inca::Integer, n::Integer)
@assert ia+(n-1)*inca <= length(a) && n <= length(b)

@inbounds for i = 1:n
b[i] = a[ia + (i-1)*inca]
end
return b
end
function stridedcopy!(b::AbstractArray{T}, ib::Integer, incb::Integer, a::AbstractVector{T}, n::Integer) where T<:Number
function stridedcopy!(b::AbstractArray{<:Number}, ib::Integer, incb::Integer, a::AbstractVector{<:Number}, n::Integer)
@assert ib+(n-1)*incb <= length(b) && n <= length(a)

@inbounds for i = 1:n
Expand Down
55 changes: 29 additions & 26 deletions src/mod/transforms_filter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
# DWT
# 1-D
# writes to y
function _dwt!(y::AbstractVector{T}, x::AbstractVector{T},
filter::OrthoFilter, L::Integer, fw::Bool) where T<:Number
si = Vector{T}(undef, length(filter)-1) # tmp filter vector
function _dwt!(y::AbstractVector{Ty}, x::AbstractVector{Tx},
filter::OrthoFilter, L::Integer, fw::Bool) where {Tx<:Number, Ty<:Number}
T = promote_type(Tx, Ty)
si = Vector{T}(undef, length(filter)-1) # tmp filter vector
scfilter, dcfilter = WT.makereverseqmfpair(filter, fw, T)
return _dwt!(y, x, filter, L, fw, dcfilter, scfilter, si)
end
function _dwt!(y::AbstractVector{T}, x::AbstractVector{T},
function _dwt!(y::AbstractVector{<:Number}, x::AbstractVector{<:Number},
filter::OrthoFilter, L::Integer, fw::Bool,
dcfilter::Vector{T}, scfilter::Vector{T},
si::Vector{T}, snew::Vector{T} = Vector{T}(undef, ifelse(L>1, length(x)>>1, 0))) where T<:Number
Expand All @@ -35,7 +36,7 @@ function _dwt!(y::AbstractVector{T}, x::AbstractVector{T},
if L == 0
return copyto!(y,x)
end
s = x # s is currect scaling coefs location
s = x # s is current scaling coefs location
filtlen = length(filter)

lrange = 1:L
Expand All @@ -59,10 +60,10 @@ function _dwt!(y::AbstractVector{T}, x::AbstractVector{T},
end
return y
end
function unsafe_dwt1level!(y::AbstractVector{T}, x::AbstractVector{T},
function unsafe_dwt1level!(y::AbstractVector{<:Number}, x::AbstractVector{<:Number},
filter::OrthoFilter, fw::Bool,
dcfilter::Vector{T}, scfilter::Vector{T},
si::Vector{T}) where T<:Number
dcfilter::FVector{T}, scfilter::FVector{T},
si::FVector{T}) where T<:Number
n = length(x)
l = 1
filtlen = length(filter)
Expand All @@ -81,11 +82,11 @@ function unsafe_dwt1level!(y::AbstractVector{T}, x::AbstractVector{T},
return y
end

function dwt_transform_strided!(y::Array{T}, x::AbstractArray{T},
function dwt_transform_strided!(y::AbstractArray{<:Number}, x::AbstractArray{<:Number},
msub::Int, nsub::Int, stride::Int, idx_func::Function,
tmpvec::Vector{T}, tmpvec2::Vector{T},
tmpvec::FVector{T}, tmpvec2::FVector{T},
filter::OrthoFilter, fw::Bool,
dcfilter::Vector{T}, scfilter::Vector{T}, si::Vector{T}) where T<:Number
dcfilter::FVector{T}, scfilter::FVector{T}, si::FVector{T}) where T<:Number
for i=1:msub
xi = idx_func(i)
stridedcopy!(tmpvec, x, xi, stride, nsub)
Expand All @@ -94,11 +95,11 @@ function dwt_transform_strided!(y::Array{T}, x::AbstractArray{T},
end
end

function dwt_transform_cols!(y::Array{T}, x::AbstractArray{T},
function dwt_transform_cols!(y::AbstractArray{<:Number}, x::AbstractArray{<:Number},
msub::Int, nsub::Int, idx_func::Function,
tmpvec::Vector{T},
tmpvec::FVector{T},
filter::OrthoFilter, fw::Bool,
dcfilter::Vector{T}, scfilter::Vector{T}, si::Vector{T}) where T<:Number
dcfilter::FVector{T}, scfilter::FVector{T}, si::FVector{T}) where T<:Number
for i=1:nsub
xi = idx_func(i)
copyto!(tmpvec, 1, x, xi, msub)
Expand All @@ -109,16 +110,17 @@ end

# 2-D
# writes to y
function _dwt!(y::Matrix{T}, x::AbstractMatrix{T},
filter::OrthoFilter, L::Integer, fw::Bool) where T<:Number
function _dwt!(y::AbstractMatrix{Ty}, x::AbstractMatrix{Tx},
filter::OrthoFilter, L::Integer, fw::Bool) where {Tx<:Number, Ty<:Number}
m, n = size(x)
T = promote_type(Tx, Ty)
si = Vector{T}(undef, length(filter)-1) # tmp filter vector
tmpbuffer = Vector{T}(undef, max(n<<1, m)) # tmp storage vector
scfilter, dcfilter = WT.makereverseqmfpair(filter, fw, T)

return _dwt!(y, x, filter, L, fw, dcfilter, scfilter, si, tmpbuffer)
end
function _dwt!(y::Matrix{T}, x::AbstractMatrix{T},
function _dwt!(y::AbstractMatrix{<:Number}, x::AbstractMatrix{<:Number},
filter::OrthoFilter, L::Integer, fw::Bool,
dcfilter::Vector{T}, scfilter::Vector{T},
si::Vector{T}, tmpbuffer::Vector{T}) where T<:Number
Expand Down Expand Up @@ -187,16 +189,17 @@ end

# 3-D
# writes to y
function _dwt!(y::Array{T, 3}, x::AbstractArray{T, 3},
filter::OrthoFilter, L::Integer, fw::Bool) where T<:Number
function _dwt!(y::AbstractArray{Ty, 3}, x::AbstractArray{Tx, 3},
filter::OrthoFilter, L::Integer, fw::Bool) where {Tx<:Number, Ty<:Number}
m, n, d = size(x)
T = promote_type(Tx, Ty)
si = Vector{T}(undef, length(filter)-1) # tmp filter vector
tmpbuffer = Vector{T}(undef, max(m, n<<1, d<<1)) # tmp storage vector
scfilter, dcfilter = WT.makereverseqmfpair(filter, fw, T)

return _dwt!(y, x, filter, L, fw, dcfilter, scfilter, si, tmpbuffer)
end
function _dwt!(y::Array{T, 3}, x::AbstractArray{T, 3},
function _dwt!(y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3},
filter::OrthoFilter, L::Integer, fw::Bool,
dcfilter::Vector{T}, scfilter::Vector{T},
si::Vector{T}, tmpbuffer::Vector{T}) where T<:Number
Expand Down Expand Up @@ -329,7 +332,7 @@ function _wpt!(y::AbstractVector{T}, x::AbstractVector{T}, filter::OrthoFilter,
Lfw = (fw ? Lmax-L : L-1)
nj = detailn(n, Lfw)
treeind = 2^(Lfw)-1
dx = unsafe_vectorslice(snew, 1, nj)
dx = first ? x : unsafe_vectorslice(snew, 1, nj) # dx will be overwritten if first

while ix <= n
if tree[treeind+k]
Expand Down Expand Up @@ -381,9 +384,9 @@ end
# x : filter convolved with x[ix:ix+nx-1], where nx=nout*2 (shifted by shift)
# ss : shift downsampling
# based on Base.filt
function filtdown!(f::Vector{T}, si::Vector{T},
out::AbstractVector{T}, iout::Integer, nout::Integer,
x::AbstractVector{T}, ix::Integer,
function filtdown!(f::AbstractVector{T}, si::AbstractVector{T},
out::AbstractVector{<:Number}, iout::Integer, nout::Integer,
x::AbstractVector{<:Number}, ix::Integer,
shift::Integer=0, ss::Bool=false) where T<:Number
nx = nout<<1
silen = length(si)
Expand Down Expand Up @@ -462,8 +465,8 @@ end
# ss : shift upsampling
# based on Base.filt
function filtup!(add2out::Bool, f::Vector{T}, si::Vector{T},
out::AbstractVector{T}, iout::Integer, nout::Integer,
x::AbstractVector{T}, ix::Integer,
out::AbstractVector{<:Number}, iout::Integer, nout::Integer,
x::AbstractVector{<:Number}, ix::Integer,
shift::Integer=0, ss::Bool=false) where T<:Number
nx = nout>>1
silen = length(si)
Expand Down
4 changes: 3 additions & 1 deletion src/mod/transforms_lifting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ end
# tmp: size at least n>>2
# oopc: use oop computation, if false iy and incy are assumed to be 1
# oopv: the out of place location
function unsafe_dwt1level!(y::AbstractArray{T}, iy::Integer, incy::Integer, oopc::Bool, oopv::Vector{T}, scheme::GLS, fw::Bool, stepseq::Vector, norm1::T, norm2::T, tmp::Vector{T}) where T<:Number
function unsafe_dwt1level!(y::AbstractArray{T}, iy::Integer, incy::Integer, oopc::Bool,
oopv::FVector{T}, scheme::GLS, fw::Bool, stepseq::FVector, norm1::T, norm2::T,
tmp::FVector{T}) where T<:Number
if !oopc
oopv = y
end
Expand Down

0 comments on commit 7f2da5b

Please sign in to comment.