cleanup Flux.Losses documentation #1930

Open · wants to merge 6 commits into base: master
1 change: 1 addition & 0 deletions docs/make.jl
@@ -1,6 +1,7 @@
using Documenter, Flux, NNlib, Functors, MLUtils

+DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

makedocs(modules = [Flux, NNlib, Functors, MLUtils],
         doctest = false,
         sitename = "Flux",
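[Editorial note, not part of the diff: `DocTestSetup` injects `using Flux` into every `jldoctest` block in Flux's docstrings before it runs. A minimal sketch of a doctest that relies on this setup; the call and output assume the standard `mse` definition:]

```jldoctest
julia> Flux.mse([0, 2], [1, 1])  # no `using Flux` written here; DocTestSetup supplies it
1.0
```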
3 changes: 2 additions & 1 deletion docs/src/models/advanced.md
@@ -213,13 +213,14 @@ model(gpu(rand(10)))
A custom loss function for the multiple outputs may look like this:
```julia
using Statistics
+using Flux.Losses: mse

# assuming model returns the output of a Split
# x is a single input
# ys is a tuple of outputs
function loss(x, ys, model)
  # rms over all the mse
  ŷs = model(x)
-  return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
+  return sqrt(mean(mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
end
```
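[Editorial usage sketch, not part of the diff: `trunk`, `heads`, and the tuple-returning `model` below are hypothetical stand-ins for the `Split` layer defined earlier in this guide, reusing the `loss` function from the snippet above:]

```julia
using Flux, Statistics
using Flux.Losses: mse

# hypothetical stand-in for Split: one shared trunk feeding two output heads
trunk = Dense(10 => 5, relu)
heads = (Dense(5 => 2), Dense(5 => 3))
model(x) = (h = trunk(x); map(f -> f(h), heads))  # returns a tuple, like Split

x  = rand(Float32, 10, 7)                        # one batch of 7 examples
ys = (rand(Float32, 2, 7), rand(Float32, 3, 7))  # one target array per head
loss(x, ys, model)                               # rms over the per-head mse values
```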
49 changes: 21 additions & 28 deletions docs/src/models/losses.md
@@ -3,43 +3,36 @@
Flux provides a large number of common loss functions used for training machine learning models.
They are grouped together in the `Flux.Losses` module.

-Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ`.
-In Flux's convention, the order of the arguments is the following
+As an example, the crossentropy function for multi-class classification that takes logit predictions (i.e. not [`softmax`](@ref)ed)
+can be imported with

+```julia
+using Flux.Losses: logitcrossentropy
+```

+Loss functions for supervised learning typically expect as inputs a true target `y` and a prediction `ŷ`.
+In Flux's convention, the order of the arguments is the following:

```julia
loss(ŷ, y)
```
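[Editorial sketch, not part of the diff, illustrating the argument order and the logit convention; the values below are made up:]

```julia
using Flux
using Flux.Losses: logitcrossentropy, crossentropy

ŷ = [2.0 0.5; -1.0 3.0; 0.0 -0.5]   # raw logits: 3 classes × 2 examples
y = Flux.onehotbatch([1, 2], 1:3)   # true classes, one-hot encoded

logitcrossentropy(ŷ, y)             # prediction first, target second
crossentropy(softmax(ŷ), y)         # ≈ same value, but less numerically stable
```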

-Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
-batch:
+They are commonly passed as arrays of size `num_target_features x num_examples_in_batch`.
+
+Most losses in Flux have an optional argument `agg` accepting a function to be used
+as a final aggregation:

```julia
-loss(ŷ, y) # defaults to `mean`
-loss(ŷ, y, agg=sum) # use `sum` for reduction
-loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction
-loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean
-loss(ŷ, y, agg=identity) # no aggregation.
+loss(ŷ, y) # defaults to `mean`
+loss(ŷ, y, agg = sum) # use `sum` for reduction
+loss(ŷ, y, agg = x -> sum(x, dims=2)) # partial reduction
+loss(ŷ, y, agg = x -> mean(w .* x)) # weighted mean
+loss(ŷ, y, agg = identity) # no aggregation.
```
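[Editorial sketch, not part of the diff, of the batch-shape and `agg` conventions using `mse`; the data are made up:]

```julia
using Statistics
using Flux.Losses: mse

ŷ = rand(Float32, 4, 8)   # 4 target features × 8 examples in the batch
y = rand(Float32, 4, 8)

mse(ŷ, y)                               # scalar: mean over all entries
mse(ŷ, y, agg = sum)                    # scalar: summed instead of averaged
mse(ŷ, y, agg = x -> sum(x, dims = 2))  # 4×1 column: per-feature partial reduction
```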

## Losses Reference

-```@docs
-Flux.Losses.mae
-Flux.Losses.mse
-Flux.Losses.msle
-Flux.Losses.huber_loss
-Flux.Losses.label_smoothing
-Flux.Losses.crossentropy
-Flux.Losses.logitcrossentropy
-Flux.Losses.binarycrossentropy
-Flux.Losses.logitbinarycrossentropy
-Flux.Losses.kldivergence
-Flux.Losses.poisson_loss
-Flux.Losses.hinge_loss
-Flux.Losses.squared_hinge_loss
-Flux.Losses.dice_coeff_loss
-Flux.Losses.tversky_loss
-Flux.Losses.binary_focal_loss
-Flux.Losses.focal_loss
-Flux.Losses.siamese_contrastive_loss
+```@autodocs
+Modules = [Flux.Losses]
+Pages = ["functions.jl"]
```
3 changes: 1 addition & 2 deletions src/Flux.jl
@@ -51,9 +51,8 @@ include("outputsize.jl")
include("data/Data.jl")
using .Data


include("losses/Losses.jl")
using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12
using .Losses # TODO: stop importing Losses in Flux's namespace in v0.14?

include("deprecations.jl")
