From f5386f1a58ea285822a988f0f8e906dd5064ca59 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 6 Apr 2022 01:58:54 +0200 Subject: [PATCH 1/6] clean losses --- docs/make.jl | 5 ++ docs/src/models/advanced.md | 3 +- docs/src/models/losses.md | 44 ++++++------- src/Flux.jl | 3 +- src/losses/functions.jl | 122 ++++++++++++++++++------------------ test/losses.jl | 112 ++++++++++++++++----------------- 6 files changed, 144 insertions(+), 145 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 7332766786..05f9335c5e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,11 @@ using Documenter, Flux, NNlib, Functors, MLUtils DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true) +DocMeta.setdocmeta!(Flux.Losses, :DocTestSetup, :(using Flux.Losses); recursive = true) + +# In the Losses module, doctests which differ in the printed Float32 values won't fail +DocMeta.setdocmeta!(Flux.Losses, :DocTestFilters, :(r"[0-9\.]+f0"); recursive = true) + makedocs(modules = [Flux, NNlib, Functors, MLUtils], doctest = false, sitename = "Flux", diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index 8757828594..1453da95cc 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -213,6 +213,7 @@ model(gpu(rand(10))) A custom loss function for the multiple outputs may look like this: ```julia using Statistics +using Flux.Losses: mse # assuming model returns the output of a Split # x is a single input @@ -220,6 +221,6 @@ using Statistics function loss(x, ys, model) # rms over all the mse ŷs = model(x) - return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs))) + return sqrt(mean(mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs))) end ``` diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md index 39703ec365..177f24d5fb 100644 --- a/docs/src/models/losses.md +++ b/docs/src/models/losses.md @@ -3,8 +3,16 @@ Flux provides a large number of common loss functions used for training machine learning models. They are grouped together in the `Flux.Losses` module. -Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ`. -In Flux's convention, the order of the arguments is the following +As an example, the crossentropy function for multi-class classification that takes logit predictions (i.e. not [`softmax`](@ref)ed) +can be imported with + +```julia +using Flux.Losses: logitcrossentropy +``` + +Loss functions for supervised learning typically expect as inputs a true target `y` and a prediction `ŷ`, +typically passed as arrays of size `num_target_features x num_examples_in_batch`. +In Flux's convention, the order of the arguments is the following: ```julia loss(ŷ, y) @@ -14,32 +22,16 @@ Most loss functions in Flux have an optional argument `agg`, denoting the type o batch: ```julia -loss(ŷ, y) # defaults to `mean` -loss(ŷ, y, agg=sum) # use `sum` for reduction -loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction -loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean -loss(ŷ, y, agg=identity) # no aggregation. +loss(ŷ, y) # defaults to `mean` +loss(ŷ, y, agg = sum) # use `sum` for reduction +loss(ŷ, y, agg = x -> sum(x, dims=2)) # partial reduction +loss(ŷ, y, agg = x -> mean(w .* x)) # weighted mean +loss(ŷ, y, agg = identity) # no aggregation. 
``` ## Losses Reference -```@docs -Flux.Losses.mae -Flux.Losses.mse -Flux.Losses.msle -Flux.Losses.huber_loss -Flux.Losses.label_smoothing -Flux.Losses.crossentropy -Flux.Losses.logitcrossentropy -Flux.Losses.binarycrossentropy -Flux.Losses.logitbinarycrossentropy -Flux.Losses.kldivergence -Flux.Losses.poisson_loss -Flux.Losses.hinge_loss -Flux.Losses.squared_hinge_loss -Flux.Losses.dice_coeff_loss -Flux.Losses.tversky_loss -Flux.Losses.binary_focal_loss -Flux.Losses.focal_loss -Flux.Losses.siamese_contrastive_loss +```@autodocs +Modules = [Flux.Losses] +Pages = ["functions.jl"] ``` diff --git a/src/Flux.jl b/src/Flux.jl index aa3f021595..963b0dee05 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -51,9 +51,8 @@ include("outputsize.jl") include("data/Data.jl") using .Data - include("losses/Losses.jl") -using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12 +using .Losses # TODO: stop importing Losses in Flux's namespace in v0.14? include("deprecations.jl") diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 8b036e0fb4..93bbe170e4 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -1,8 +1,3 @@ -# In this file, doctests which differ in the printed Float32 values won't fail -```@meta -DocTestFilters = r"[0-9\.]+f0" -``` - """ mae(ŷ, y; agg = mean) @@ -10,11 +5,12 @@ Return the loss corresponding to mean absolute error: agg(abs.(ŷ .- y)) -# Example +# Examples + ```jldoctest julia> y_model = [1.1, 1.9, 3.1]; -julia> Flux.mae(y_model, 1:3) +julia> mae(y_model, 1:3) 0.10000000000000009 ``` """ @@ -32,13 +28,14 @@ Return the loss corresponding to mean square error: See also: [`mae`](@ref), [`msle`](@ref), [`crossentropy`](@ref). -# Example +# Examples + ```jldoctest julia> y_model = [1.1, 1.9, 3.1]; julia> y_true = 1:3; -julia> Flux.mse(y_model, y_true) +julia> mse(y_model, y_true) 0.010000000000000018 ``` """ @@ -57,12 +54,13 @@ The loss corresponding to mean squared logarithmic errors, calculated as The `ϵ` term provides numerical stability. Penalizes an under-estimation more than an over-estimatation. -# Example +# Examples + ```jldoctest -julia> Flux.msle(Float32[1.1, 2.2, 3.3], 1:3) +julia> msle(Float32[1.1, 2.2, 3.3], 1:3) 0.009084041f0 -julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3) +julia> msle(Float32[0.9, 1.8, 2.7], 1:3) 0.011100831f0 ``` """ @@ -112,14 +110,15 @@ value of α larger the smoothing of `y`. `dims` denotes the one-hot dimension, unless `dims=0` which denotes the application of label smoothing to binary distributions encoded in a single number. -# Example +# Examples + ```jldoctest julia> y = Flux.onehotbatch([1, 1, 1, 0, 1, 0], 0:1) 2×6 OneHotMatrix(::Vector{UInt32}) with eltype Bool: ⋅ ⋅ ⋅ 1 ⋅ 1 1 1 1 ⋅ 1 ⋅ -julia> y_smoothed = Flux.label_smoothing(y, 0.2f0) +julia> y_smoothed = label_smoothing(y, 0.2f0) 2×6 Matrix{Float32}: 0.1 0.1 0.1 0.9 0.1 0.9 0.9 0.9 0.9 0.1 0.9 0.1 @@ -134,10 +133,10 @@ julia> y_dis = vcat(y_sim[2,:]', y_sim[1,:]') 0.666667 0.666667 0.666667 0.333333 0.666667 0.333333 0.333333 0.333333 0.333333 0.666667 0.333333 0.666667 -julia> Flux.crossentropy(y_sim, y) < Flux.crossentropy(y_sim, y_smoothed) +julia> crossentropy(y_sim, y) < crossentropy(y_sim, y_smoothed) true -julia> Flux.crossentropy(y_dis, y) > Flux.crossentropy(y_dis, y_smoothed) +julia> crossentropy(y_dis, y) > crossentropy(y_dis, y_smoothed) true ``` """ @@ -177,7 +176,8 @@ computing the loss. See also: [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref). 
-# Example +# Examples + ```jldoctest julia> y_label = Flux.onehotbatch([0, 1, 2, 1, 0], 0:2) 3×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool: @@ -195,19 +195,19 @@ julia> sum(y_model; dims=1) 1×5 Matrix{Float32}: 1.0 1.0 1.0 1.0 1.0 -julia> Flux.crossentropy(y_model, y_label) +julia> crossentropy(y_model, y_label) 1.6076053f0 -julia> 5 * ans ≈ Flux.crossentropy(y_model, y_label; agg=sum) +julia> 5 * ans ≈ crossentropy(y_model, y_label; agg=sum) true -julia> y_smooth = Flux.label_smoothing(y_label, 0.15f0) +julia> y_smooth = label_smoothing(y_label, 0.15f0) 3×5 Matrix{Float32}: 0.9 0.05 0.05 0.05 0.9 0.05 0.9 0.05 0.9 0.05 0.05 0.05 0.9 0.05 0.05 -julia> Flux.crossentropy(y_model, y_smooth) +julia> crossentropy(y_model, y_smooth) 1.5776052f0 ``` """ @@ -229,9 +229,10 @@ and [`softmax`](@ref) separately. See also: [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`label_smoothing`](@ref). -# Example +# Examples + ```jldoctest -julia> y_label = Flux.onehotbatch(collect("abcabaa"), 'a':'c') +julia> y_label = onehotbatch(collect("abcabaa"), 'a':'c') 3×7 OneHotMatrix(::Vector{UInt32}) with eltype Bool: 1 ⋅ ⋅ 1 ⋅ 1 1 ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ @@ -243,10 +244,10 @@ julia> y_model = reshape(vcat(-9:0, 0:9, 7.5f0), 3, 7) -8.0 -5.0 -2.0 0.0 3.0 6.0 9.0 -7.0 -4.0 -1.0 1.0 4.0 7.0 7.5 -julia> Flux.logitcrossentropy(y_model, y_label) +julia> logitcrossentropy(y_model, y_label) 1.5791205f0 -julia> Flux.crossentropy(softmax(y_model), y_label) +julia> crossentropy(softmax(y_model), y_label) 1.5791197f0 ``` """ @@ -284,18 +285,18 @@ julia> y_prob = softmax(reshape(vcat(1:3, 3:5), 2, 3) .* 1f0) 0.268941 0.5 0.268941 0.731059 0.5 0.731059 -julia> Flux.binarycrossentropy(y_prob[2,:], y_bin) +julia> binarycrossentropy(y_prob[2,:], y_bin) 0.43989f0 julia> all(p -> 0 < p < 1, y_prob[2,:]) # else DomainError true -julia> y_hot = Flux.onehotbatch(y_bin, 0:1) +julia> y_hot = onehotbatch(y_bin, 0:1) 2×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool: ⋅ 1 ⋅ 1 ⋅ 1 -julia> Flux.crossentropy(y_prob, y_hot) +julia> crossentropy(y_prob, y_hot) 0.43989f0 ``` """ @@ -322,10 +323,10 @@ julia> y_model = Float32[2, -1, pi] -1.0 3.1415927 -julia> Flux.logitbinarycrossentropy(y_model, y_bin) +julia> logitbinarycrossentropy(y_model, y_bin) 0.160832f0 -julia> Flux.binarycrossentropy(sigmoid.(y_model), y_bin) +julia> binarycrossentropy(sigmoid.(y_model), y_bin) 0.16083185f0 ``` """ @@ -344,7 +345,8 @@ between the given probability distributions. The KL divergence is a measure of how much one probability distribution is different from the other. It is always non-negative, and zero only when both the distributions are equal. 
-# Example +# Examples + ```jldoctest julia> p1 = [1 0; 0 1] 2×2 Matrix{Int64}: @@ -356,16 +358,16 @@ julia> p2 = fill(0.5, 2, 2) 0.5 0.5 0.5 0.5 -julia> Flux.kldivergence(p2, p1) ≈ log(2) +julia> kldivergence(p2, p1) ≈ log(2) true -julia> Flux.kldivergence(p2, p1; agg = sum) ≈ 2log(2) +julia> kldivergence(p2, p1; agg = sum) ≈ 2log(2) true -julia> Flux.kldivergence(p2, p2; ϵ = 0) # about -2e-16 with the regulator +julia> kldivergence(p2, p2; ϵ = 0) # about -2e-16 with the regulator 0.0 -julia> Flux.kldivergence(p1, p2; ϵ = 0) # about 17.3 with the regulator +julia> kldivergence(p1, p2; ϵ = 0) # about 17.3 with the regulator Inf ``` """ @@ -377,12 +379,14 @@ function kldivergence(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ)) end """ - poisson_loss(ŷ, y) + poisson_loss(ŷ, y; agg = mean) -# Return how much the predicted distribution `ŷ` diverges from the expected Poisson -# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`. +Return how much the predicted distribution `ŷ` diverges from the expected Poisson +distribution `y`. Calculated as -[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson). + sum(ŷ .- y .* log.(ŷ)) / size(y, 2) + +[More information](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson). """ function poisson_loss(ŷ, y; agg = mean) _check_sizes(ŷ, y) @@ -393,10 +397,11 @@ end hinge_loss(ŷ, y; agg = mean) Return the [hinge_loss loss](https://en.wikipedia.org/wiki/Hinge_loss) given the -prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as -`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`. +prediction `ŷ` and true labels `y` (containing 1 or -1). Calculated as + + sum(max.(0, 1 .- ŷ .* y)) / size(y, 2) -See also: [`squared_hinge_loss`](@ref) +See also: [`squared_hinge_loss`](@ref). """ function hinge_loss(ŷ, y; agg = mean) _check_sizes(ŷ, y) @@ -409,7 +414,7 @@ end Return the squared hinge_loss loss given the prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`. -See also: [`hinge_loss`](@ref) +See also [`hinge_loss`](@ref). """ function squared_hinge_loss(ŷ, y; agg = mean) _check_sizes(ŷ, y) @@ -438,6 +443,7 @@ Return the [Tversky loss](https://arxiv.org/abs/1706.05721). Used with imbalanced data to give more weight to false negatives. Larger β weigh recall more than precision (by placing more emphasis on false negatives) Calculated as: + 1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1) """ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) @@ -454,9 +460,10 @@ end Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf) The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output). -For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref). +For `γ == 0`, the loss is mathematically equivalent to [`binarycrossentropy`](@ref). + +# Examples -# Example ```jldoctest julia> y = [0 1 0 1 0 1] @@ -470,12 +477,11 @@ julia> ŷ = [0.268941 0.5 0.268941 0.268941 0.5 0.268941 0.731059 0.5 0.731059 -julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 +julia> binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 true ``` -See also: [`Losses.focal_loss`](@ref) for multi-class setting - +See also: [`focal_loss`](@ref) for multi-class setting. 
""" function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ)) _check_sizes(ŷ, y) @@ -496,9 +502,10 @@ It down-weights well-classified examples and focuses on hard examples. The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output). The modulating factor, `γ`, controls the down-weighting strength. -For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref). +For `γ == 0`, the loss is mathematically equivalent to [`crossentropy`](@ref). + +# Examples -# Example ```jldoctest julia> y = [1 0 0 0 1 0 1 0 1 0 @@ -514,11 +521,11 @@ julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0) 0.244728 0.244728 0.244728 0.244728 0.244728 0.665241 0.665241 0.665241 0.665241 0.665241 -julia> Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628 +julia> focal_loss(ŷ, y) ≈ 1.1277571935622628 true ``` -See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels +See also: [`binary_focal_loss`](@ref) for binary (not one-hot) labels """ function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ)) @@ -535,15 +542,10 @@ which can be useful for training Siamese Networks. It is given by agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2) -Specify `margin` to set the baseline for distance at which pairs are dissimilar. - +Specify `margin` to set the baseline for distance at which pairs are dissimilar. """ -function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1) - _check_sizes(ŷ, y) +function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1) + _check_sizes(ŷ, y) margin < 0 && throw(DomainError(margin, "Margin must be non-negative")) return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2) end - -```@meta -DocTestFilters = nothing -``` diff --git a/test/losses.jl b/test/losses.jl index 2ca697a657..78fa2f5404 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -1,20 +1,20 @@ using Test -using Flux: onehotbatch, σ +using Flux: onehotbatch, sigmoid -using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy +using Flux.Losses using Flux.Losses: xlogx, xlogy # group here all losses, used in tests -const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle, - Flux.Losses.crossentropy, Flux.Losses.logitcrossentropy, - Flux.Losses.binarycrossentropy, Flux.Losses.logitbinarycrossentropy, - Flux.Losses.kldivergence, - Flux.Losses.huber_loss, - Flux.Losses.tversky_loss, - Flux.Losses.dice_coeff_loss, - Flux.Losses.poisson_loss, - Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss, - Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss] +const ALL_LOSSES = [mse, mae, msle, + crossentropy, logitcrossentropy, + binarycrossentropy, logitbinarycrossentropy, + kldivergence, + huber_loss, + tversky_loss, + dice_coeff_loss, + poisson_loss, + hinge_loss, squared_hinge_loss, + binary_focal_loss, focal_loss, siamese_contrastive_loss] @testset "xlogx & xlogy" begin @@ -45,17 +45,17 @@ y = [1, 1, 0, 0] end @testset "mae" begin - @test Flux.mae(ŷ, y) ≈ 1/2 + @test mae(ŷ, y) ≈ 1/2 end @testset "huber_loss" begin - @test Flux.huber_loss(ŷ, y) ≈ 0.20500000000000002 + @test huber_loss(ŷ, y) ≈ 0.20500000000000002 end y = [123.0,456.0,789.0] ŷ = [345.0,332.0,789.0] @testset "msle" begin - @test Flux.msle(ŷ, y) ≈ 0.38813985859136585 + @test msle(ŷ, y) ≈ 0.38813985859136585 end # Now onehot y's @@ -104,15 +104,15 @@ logŷ, y = randn(3), rand(3) yls = y.*(1-2sf).+sf @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; 
dims=0); ϵ=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ)) - @test binarycrossentropy(σ.(logŷ), y; ϵ=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))) - @test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))) + @test binarycrossentropy.(sigmoid.(logŷ), label_smoothing(y, 2sf; dims=0); ϵ=0) ≈ -yls.*log.(sigmoid.(logŷ)) - (1 .- yls).*log.(1 .- sigmoid.(logŷ)) + @test binarycrossentropy(sigmoid.(logŷ), y; ϵ=0) ≈ mean(-y.*log.(sigmoid.(logŷ)) - (1 .- y).*log.(1 .- sigmoid.(logŷ))) + @test binarycrossentropy(sigmoid.(logŷ), y) ≈ mean(-y.*log.(sigmoid.(logŷ) .+ eps.(sigmoid.(logŷ))) - (1 .- y).*log.(1 .- sigmoid.(logŷ) .+ eps.(sigmoid.(logŷ)))) @test binarycrossentropy([0.1,0.2,0.9], 1) ≈ -mean(log, [0.1,0.2,0.9]) # constant label end @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); ϵ=0) - @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; ϵ=0) + @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(sigmoid.(logŷ), label_smoothing(y, 0.2); ϵ=0) + @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(sigmoid.(logŷ), y; ϵ=0) end y = onehotbatch([1], 0:1) @@ -128,44 +128,44 @@ y = [1 2 3] ŷ = [4.0 5.0 6.0] @testset "kldivergence" begin - @test Flux.kldivergence([0.1,0.0,0.9], [0.1,0.0,0.9]) ≈ Flux.kldivergence([0.1,0.9], [0.1,0.9]) - @test Flux.kldivergence(ŷ, y) ≈ -1.7661057888493457 - @test Flux.kldivergence(y, y) ≈ 0 + @test kldivergence([0.1,0.0,0.9], [0.1,0.0,0.9]) ≈ kldivergence([0.1,0.9], [0.1,0.9]) + @test kldivergence(ŷ, y) ≈ -1.7661057888493457 + @test kldivergence(y, y) ≈ 0 end y = [1 2 3 4] ŷ = [5.0 6.0 7.0 8.0] @testset "hinge_loss" begin - @test Flux.hinge_loss(ŷ, y) ≈ 0 - @test Flux.hinge_loss(y, 0.5 .* y) ≈ 0.125 + @test hinge_loss(ŷ, y) ≈ 0 + @test hinge_loss(y, 0.5 .* y) ≈ 0.125 end @testset "squared_hinge_loss" begin - @test Flux.squared_hinge_loss(ŷ, y) ≈ 0 - @test Flux.squared_hinge_loss(y, 0.5 .* y) ≈ 0.0625 + @test squared_hinge_loss(ŷ, y) ≈ 0 + @test squared_hinge_loss(y, 0.5 .* y) ≈ 0.0625 end y = [0.1 0.2 0.3] ŷ = [0.4 0.5 0.6] @testset "poisson_loss" begin - @test Flux.poisson_loss(ŷ, y) ≈ 0.6278353988097339 - @test Flux.poisson_loss(y, y) ≈ 0.5044459776946685 + @test poisson_loss(ŷ, y) ≈ 0.6278353988097339 + @test poisson_loss(y, y) ≈ 0.5044459776946685 end y = [1.0 0.5 0.3 2.4] ŷ = [0 1.4 0.5 1.2] @testset "dice_coeff_loss" begin - @test Flux.dice_coeff_loss(ŷ, y) ≈ 0.2799999999999999 - @test Flux.dice_coeff_loss(y, y) ≈ 0.0 + @test dice_coeff_loss(ŷ, y) ≈ 0.2799999999999999 + @test dice_coeff_loss(y, y) ≈ 0.0 end @testset "tversky_loss" begin - @test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383 - @test Flux.tversky_loss(ŷ, y, β=0.8) ≈ -0.09490740740740744 - @test Flux.tversky_loss(y, y) ≈ -0.5576923076923075 + @test tversky_loss(ŷ, y) ≈ -0.06772009029345383 + @test tversky_loss(ŷ, y, β=0.8) ≈ -0.09490740740740744 + @test tversky_loss(y, y) ≈ -0.5576923076923075 end @testset "no spurious promotions" begin @@ -173,7 +173,7 @@ end y = rand(T, 2) ŷ = rand(T, 2) for f in ALL_LOSSES - fwd, back = Flux.pullback(f, ŷ, y) + fwd, back = pullback(f, ŷ, y) @test fwd isa T @test eltype(back(one(T))[1]) == T end @@ -190,9 +190,9 @@ end 0 1] ŷ1 = [0.6 0.3 0.4 0.7] - @test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 - @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222 - @test Flux.binary_focal_loss(ŷ, 
y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y) + @test binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 + @test binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222 + @test binary_focal_loss(ŷ, y; γ=0.0) ≈ binarycrossentropy(ŷ, y) end @testset "focal_loss" begin @@ -206,9 +206,9 @@ end ŷ1 = [0.4 0.2 0.5 0.5 0.1 0.3] - @test Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628 - @test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157 - @test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y) + @test focal_loss(ŷ, y) ≈ 1.1277571935622628 + @test focal_loss(ŷ1, y1) ≈ 0.45990566879720157 + @test focal_loss(ŷ, y; γ=0.0) ≈ crossentropy(ŷ, y) end @testset "siamese_contrastive_loss" begin @@ -232,19 +232,19 @@ end 0.1 0.2 0.7] - @test Flux.siamese_contrastive_loss(ŷ, y) ≈ 0.2333333333333333 - @test Flux.siamese_contrastive_loss(ŷ, y, margin = 0.5f0) ≈ 0.10000000000000002 - @test Flux.siamese_contrastive_loss(ŷ, y, margin = 1.5f0) ≈ 0.5333333333333333 - @test Flux.siamese_contrastive_loss(ŷ1, y1) ≈ 0.32554644f0 - @test Flux.siamese_contrastive_loss(ŷ1, y1, margin = 0.5f0) ≈ 0.16271012f0 - @test Flux.siamese_contrastive_loss(ŷ1, y1, margin = 1.5f0) ≈ 0.6532292f0 - @test Flux.siamese_contrastive_loss(ŷ, y, margin = 1) ≈ Flux.siamese_contrastive_loss(ŷ, y) - @test Flux.siamese_contrastive_loss(y, y) ≈ 0.0 - @test Flux.siamese_contrastive_loss(y1, y1) ≈ 0.0 - @test Flux.siamese_contrastive_loss(ŷ, y, margin = 0) ≈ 0.09166666666666667 - @test Flux.siamese_contrastive_loss(ŷ1, y1, margin = 0) ≈ 0.13161165f0 - @test Flux.siamese_contrastive_loss(ŷ2, y2) ≈ 0.21200000000000005 - @test Flux.siamese_contrastive_loss(ŷ2, ŷ2) ≈ 0.18800000000000003 - @test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5) - @test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1) + @test siamese_contrastive_loss(ŷ, y) ≈ 0.2333333333333333 + @test siamese_contrastive_loss(ŷ, y, margin = 0.5f0) ≈ 0.10000000000000002 + @test siamese_contrastive_loss(ŷ, y, margin = 1.5f0) ≈ 0.5333333333333333 + @test siamese_contrastive_loss(ŷ1, y1) ≈ 0.32554644f0 + @test siamese_contrastive_loss(ŷ1, y1, margin = 0.5f0) ≈ 0.16271012f0 + @test siamese_contrastive_loss(ŷ1, y1, margin = 1.5f0) ≈ 0.6532292f0 + @test siamese_contrastive_loss(ŷ, y, margin = 1) ≈ siamese_contrastive_loss(ŷ, y) + @test siamese_contrastive_loss(y, y) ≈ 0.0 + @test siamese_contrastive_loss(y1, y1) ≈ 0.0 + @test siamese_contrastive_loss(ŷ, y, margin = 0) ≈ 0.09166666666666667 + @test siamese_contrastive_loss(ŷ1, y1, margin = 0) ≈ 0.13161165f0 + @test siamese_contrastive_loss(ŷ2, y2) ≈ 0.21200000000000005 + @test siamese_contrastive_loss(ŷ2, ŷ2) ≈ 0.18800000000000003 + @test_throws DomainError(-0.5, "Margin must be non-negative") siamese_contrastive_loss(ŷ1, y1, margin = -0.5) + @test_throws DomainError(-1, "Margin must be non-negative") siamese_contrastive_loss(ŷ, y, margin = -1) end From 44d2912a02f28c90d1e407e5c84f2da20bb4ebe9 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 6 Apr 2022 02:02:54 +0200 Subject: [PATCH 2/6] cleanup --- docs/src/models/losses.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md index 177f24d5fb..440e4e7e2e 100644 --- a/docs/src/models/losses.md +++ b/docs/src/models/losses.md @@ -10,14 +10,15 @@ can be imported with using Flux.Losses: logitcrossentropy ``` -Loss functions for supervised learning typically expect as inputs a true target `y` and a prediction `ŷ`, -typically 
passed as arrays of size `num_target_features x num_examples_in_batch`. +Loss functions for supervised learning typically expect as inputs a true target `y` and a prediction `ŷ`. In Flux's convention, the order of the arguments is the following: ```julia loss(ŷ, y) ``` +They are commonly passed as arrays of size `num_target_features x num_examples_in_batch`. + Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the batch: From 590a209efa44c9c0a1767385362816d53fd29554 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 6 Apr 2022 02:06:09 +0200 Subject: [PATCH 3/6] cleanup --- src/losses/functions.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 93bbe170e4..079fee2baf 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -384,7 +384,7 @@ end Return how much the predicted distribution `ŷ` diverges from the expected Poisson distribution `y`. Calculated as - sum(ŷ .- y .* log.(ŷ)) / size(y, 2) + agg(ŷ .- y .* log.(ŷ)) [More information](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson). """ @@ -399,7 +399,7 @@ end Return the [hinge_loss loss](https://en.wikipedia.org/wiki/Hinge_loss) given the prediction `ŷ` and true labels `y` (containing 1 or -1). Calculated as - sum(max.(0, 1 .- ŷ .* y)) / size(y, 2) + agg(max.(0, 1 .- ŷ .* y)) See also: [`squared_hinge_loss`](@ref). """ @@ -412,7 +412,9 @@ end squared_hinge_loss(ŷ, y) Return the squared hinge_loss loss given the prediction `ŷ` and true labels `y` -(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`. +(containing 1 or -1). Calculated as + + agg((max.(0, 1 .- ŷ .* y)).^2) See also [`hinge_loss`](@ref). """ @@ -458,7 +460,7 @@ end binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=eps(ŷ)) Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf) -The input, 'ŷ', is expected to be normalized (i.e. [`softmax`](@ref) output). +The input `ŷ` is expected to be normalized (i.e. [`softmax`](@ref) output). For `γ == 0`, the loss is mathematically equivalent to [`binarycrossentropy`](@ref). From 4295a7cafb3d49abed7253525793adf76b923f8d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 7 Apr 2022 08:40:25 +0200 Subject: [PATCH 4/6] improve docs --- docs/make.jl | 4 ---- docs/src/models/losses.md | 4 ++-- src/losses/functions.jl | 24 ++++++++++++++++++++++++ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 05f9335c5e..653108241c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,10 +1,6 @@ using Documenter, Flux, NNlib, Functors, MLUtils DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true) -DocMeta.setdocmeta!(Flux.Losses, :DocTestSetup, :(using Flux.Losses); recursive = true) - -# In the Losses module, doctests which differ in the printed Float32 values won't fail -DocMeta.setdocmeta!(Flux.Losses, :DocTestFilters, :(r"[0-9\.]+f0"); recursive = true) makedocs(modules = [Flux, NNlib, Functors, MLUtils], doctest = false, diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md index 440e4e7e2e..9fe9e6c15b 100644 --- a/docs/src/models/losses.md +++ b/docs/src/models/losses.md @@ -19,8 +19,8 @@ loss(ŷ, y) They are commonly passed as arrays of size `num_target_features x num_examples_in_batch`. 
-Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the -batch: +Most losses in Flux have an optional argument `agg` accepting a function to be used as +as a final aggregation: ```julia loss(ŷ, y) # defaults to `mean` diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 079fee2baf..134fc4592a 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -8,6 +8,8 @@ Return the loss corresponding to mean absolute error: # Examples ```jldoctest +julia> using Flux.Losses: mae + julia> y_model = [1.1, 1.9, 3.1]; julia> mae(y_model, 1:3) @@ -31,6 +33,8 @@ See also: [`mae`](@ref), [`msle`](@ref), [`crossentropy`](@ref). # Examples ```jldoctest +julia> using Flux.Losses: mse + julia> y_model = [1.1, 1.9, 3.1]; julia> y_true = 1:3; @@ -57,6 +61,8 @@ Penalizes an under-estimation more than an over-estimatation. # Examples ```jldoctest +julia> using Flux.Losses: msle + julia> msle(Float32[1.1, 2.2, 3.3], 1:3) 0.009084041f0 @@ -113,6 +119,8 @@ of label smoothing to binary distributions encoded in a single number. # Examples ```jldoctest +julia> using Flux.Losses: label_smoothing, crossentropy + julia> y = Flux.onehotbatch([1, 1, 1, 0, 1, 0], 0:1) 2×6 OneHotMatrix(::Vector{UInt32}) with eltype Bool: ⋅ ⋅ ⋅ 1 ⋅ 1 @@ -179,6 +187,8 @@ See also: [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbina # Examples ```jldoctest +julia> using Flux.Losses: label_smoothing, crossentropy + julia> y_label = Flux.onehotbatch([0, 1, 2, 1, 0], 0:2) 3×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool: 1 ⋅ ⋅ ⋅ 1 @@ -232,6 +242,8 @@ See also: [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`lab # Examples ```jldoctest +julia> using Flux.Losses: crossentropy, logitcrossentropy + julia> y_label = onehotbatch(collect("abcabaa"), 'a':'c') 3×7 OneHotMatrix(::Vector{UInt32}) with eltype Bool: 1 ⋅ ⋅ 1 ⋅ 1 1 @@ -273,7 +285,10 @@ computing the loss. See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref). # Examples + ```jldoctest +julia> using Flux.Losses: binarycrossentropy, crossentropy + julia> y_bin = Bool[1,0,1] 3-element Vector{Bool}: 1 @@ -314,7 +329,10 @@ Mathematically equivalent to See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref). # Examples + ```jldoctest +julia> using Flux.Losses: binarycrossentropy, logitbinarycrossentropy + julia> y_bin = Bool[1,0,1]; julia> y_model = Float32[2, -1, pi] @@ -348,6 +366,8 @@ from the other. It is always non-negative, and zero only when both the distribut # Examples ```jldoctest +julia> using Flux.Losses: kldivergence + julia> p1 = [1 0; 0 1] 2×2 Matrix{Int64}: 1 0 @@ -467,6 +487,8 @@ For `γ == 0`, the loss is mathematically equivalent to [`binarycrossentropy`](@ # Examples ```jldoctest +julia> using Flux.Losses: binary_focal_loss + julia> y = [0 1 0 1 0 1] 2×3 Matrix{Int64}: @@ -509,6 +531,8 @@ For `γ == 0`, the loss is mathematically equivalent to [`crossentropy`](@ref). 
# Examples ```jldoctest +julia> using Flux.Losses: focal_loss + julia> y = [1 0 0 0 1 0 1 0 1 0 0 0 1 0 0] From 18cdb63cd450c2198ecc96f9c289d1e1f2ecec68 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 7 Apr 2022 10:34:55 +0200 Subject: [PATCH 5/6] reinstate docfilter --- src/losses/functions.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 134fc4592a..d425340c04 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -1,3 +1,8 @@ +# In this file, doctests which differ in the printed Float32 values won't fail +```@meta +DocTestFilters = r"[0-9\.]+f0" +``` + """ mae(ŷ, y; agg = mean) @@ -575,3 +580,8 @@ function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1) margin < 0 && throw(DomainError(margin, "Margin must be non-negative")) return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2) end + + +```@meta +DocTestFilters = nothing +``` \ No newline at end of file From 567dbc1631721daa86b0e90e0b1fba3f91418653 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 7 Apr 2022 10:58:38 +0200 Subject: [PATCH 6/6] qualify onehotbatch --- src/losses/functions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index d425340c04..fe858cb231 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -249,7 +249,7 @@ See also: [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`lab ```jldoctest julia> using Flux.Losses: crossentropy, logitcrossentropy -julia> y_label = onehotbatch(collect("abcabaa"), 'a':'c') +julia> y_label = Flux.onehotbatch(collect("abcabaa"), 'a':'c') 3×7 OneHotMatrix(::Vector{UInt32}) with eltype Bool: 1 ⋅ ⋅ 1 ⋅ 1 1 ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ @@ -311,7 +311,7 @@ julia> binarycrossentropy(y_prob[2,:], y_bin) julia> all(p -> 0 < p < 1, y_prob[2,:]) # else DomainError true -julia> y_hot = onehotbatch(y_bin, 0:1) +julia> y_hot = Flux.onehotbatch(y_bin, 0:1) 2×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool: ⋅ 1 ⋅ 1 ⋅ 1
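
As a quick illustration of the usage pattern these commits converge on — loss functions imported from `Flux.Losses` and called unqualified, with the `agg` keyword controlling batch aggregation — here is a minimal sketch. The shapes and values below are arbitrary, chosen only for demonstration, and assume a Flux build that includes this patch.

```julia
using Flux                                 # for Flux.onehotbatch
using Flux.Losses: mse, logitcrossentropy  # import losses, call them unqualified

ŷ = rand(Float32, 3, 5)    # predictions: 3 target features × 5 examples in the batch
y = rand(Float32, 3, 5)    # targets, same shape

mse(ŷ, y)                  # scalar loss, aggregated with `mean` by default
mse(ŷ, y, agg = sum)       # aggregate with `sum` instead
mse(ŷ, y, agg = identity)  # no aggregation: returns the 3×5 elementwise squared errors

# For multi-class classification, logitcrossentropy takes raw logits
# (no softmax applied to ŷ) and one-hot encoded targets:
labels = Flux.onehotbatch([1, 2, 3, 1, 2], 1:3)
logitcrossentropy(ŷ, labels)
```

With `agg = identity` the unreduced elementwise losses are returned, which is what the partial-reduction and weighted-mean variants shown in `docs/src/models/losses.md` build on.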