diff --git a/test/utils.jl b/test/utils.jl index a5f3fb2aac..92cf9e862b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,8 +1,9 @@ using Flux using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, - sparse_init, identity_init, stack, unstack, batch, unbatch, + sparse_init, identity_init, unstack, batch, unbatch, unsqueeze, params, loadparams!, loadmodel! +using MLUtils using StatsBase: var, std using Statistics, LinearAlgebra using Random @@ -326,14 +327,14 @@ end @testset "Stacking" begin x = randn(3,3) - stacked = stack([x, x], dims=2) + stacked = MLUtils.stack([x, x], dims=2) @test size(stacked) == (3,2,3) stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ] unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]] @test unstack(stacked_array, dims=2) == unstacked_array - @test stack(unstacked_array, dims=2) == stacked_array - @test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array + @test MLUtils.stack(unstacked_array, dims=2) == stacked_array + @test MLUtils.stack(unstack(stacked_array, dims=1), dims=1) == stacked_array end @testset "Batching" begin