Comprehensive Guide on the Softmax Function
What is the Softmax function?
The Softmax function is defined as follows:

$$y_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}} \qquad \text{for } i = 1, 2, \ldots, N \label{eq:yHWjjyQou5VFcDGhpZV}$$

Where:
$y_i$ is the $i$-th output value
$x_i$ is the $i$-th value in the vector $\boldsymbol{x}$
$N$ is the dimension of the vector $\boldsymbol{x}$
The Softmax function is a practical function that turns an arbitrary vector of numbers into probabilities that sum up to one.
Example of using the Softmax function
Consider the following input vector:

$$\boldsymbol{x} = (2,\; 1,\; 0.1) \label{eq:ei1XptgAwaMk77sju1d}$$
Using the formula for Softmax \eqref{eq:yHWjjyQou5VFcDGhpZV} gives us:

$$y_1 = \frac{e^{2}}{e^{2}+e^{1}+e^{0.1}} \approx 0.659 \qquad y_2 = \frac{e^{1}}{e^{2}+e^{1}+e^{0.1}} \approx 0.242 \qquad y_3 = \frac{e^{0.1}}{e^{2}+e^{1}+e^{0.1}} \approx 0.099$$

Therefore, we have that:

$$\boldsymbol{y} \approx (0.659,\; 0.242,\; 0.099)$$
Note the following:
the entries of the output sum to $1$, which means you can interpret them as probabilities.
the input to the Softmax function is often referred to as the logits, while the output $\boldsymbol{y}$ gives the corresponding probabilities.
Application to neural networks
When modelling with neural networks, we often run into the Softmax function. Suppose we want to build a neural network that classifies whether an image is a cat, a dog or a bird. In such a case, we often use the Softmax function as the activation function for the final layer, so that the network outputs one probability per class (a small code sketch of this follows the list below). For example, output probabilities of $(0.7, 0.2, 0.1)$ say that the model is:
70% sure the image is a cat
20% sure the image is a dog
10% sure the image is a bird
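As a rough illustration, here is a minimal NumPy sketch of such a final layer. The logits below are hypothetical raw scores standing in for whatever the earlier layers of a real network would produce; they are chosen so that the probabilities come out to roughly 70%, 20% and 10%. The softmax helper is the numerically stable version we implement later in this guide.

import numpy as np

def softmax(x):
    """Numerically stable Softmax for a 1D array (see the implementation section below)."""
    x = np.asarray(x, dtype=float)
    e = np.exp(x - np.max(x))
    return e / np.sum(e)

logits = np.array([2.0, 0.75, 0.05])   # hypothetical raw scores for [cat, dog, bird]
probs = softmax(logits)                 # final-layer activation

for label, p in zip(["cat", "dog", "bird"], probs):
    print(f"{label}: {p:.2f}")          # cat: 0.70, dog: 0.20, bird: 0.10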
If you are performing predictions only, without needing probabilities, then the Softmax function is not strictly necessary: because Softmax preserves the ordering of its inputs, the class with the largest raw score is already the predicted class.
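Continuing the sketch above, the predicted class is the same whether or not we apply Softmax:

np.argmax(logits)            # 0 (cat), using only the raw scores
np.argmax(softmax(logits))   # also 0, since Softmax preserves the ordering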
Comparison with Sigmoid function
Both the Softmax and sigmoid functions map inputs to the range $(0, 1)$. However, the difference is that the outputs of the sigmoid, applied element-wise to a vector, do not generally sum to one as probabilities should.
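As a quick sanity check, here is a small NumPy sketch contrasting the two. The sigmoid helper below is simply $1/(1+e^{-x})$ written out by hand, and the input vector is the same one used in the example above:

import numpy as np

def sigmoid(x):
    """Element-wise sigmoid: squashes each value into (0, 1) independently."""
    x = np.asarray(x, dtype=float)
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    """Softmax: turns the whole vector into probabilities that sum to 1."""
    x = np.asarray(x, dtype=float)
    e = np.exp(x - np.max(x))
    return e / np.sum(e)

x = [2, 1, 0.1]
print(sigmoid(x).sum())   # roughly 2.14, so not a probability distribution
print(softmax(x).sum())   # 1.0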
Implementing Softmax function using Python's NumPy
Basic implementation
We can easily implement the Softmax function as described by equation \eqref{eq:yHWjjyQou5VFcDGhpZV} using NumPy like so:
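import numpy as np

def softmax(x):
    """
    x: 1D NumPy array of inputs
    """
    # direct translation of the definition: exponentiate, then normalise
    return np.exp(x) / np.sum(np.exp(x))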
Let's use this function to compute the Softmax of vector \eqref{eq:ei1XptgAwaMk77sju1d}:
softmax([2, 1, 0.1])
array([0.65900114, 0.24243297, 0.09856589])
Notice how the output is identical to what we calculated by hand.
Optimised implementation
Our basic implementation of the Softmax function above follows the definition of the Softmax function \eqref{eq:yHWjjyQou5VFcDGhpZV} directly.
The problem with this implementation is that the exponential function $e^x$ quickly becomes large as the value of $x$ increases. For instance, consider $\exp(100)$:

np.exp(100)
2.6881171418161356e+43
Notice how even a modest input of $x=100$ results in an extremely large number. In fact, if we try $\exp(800)$, the value is so large that it cannot be represented:

np.exp(800)
inf
This happens because computers represent numerical values using a fixed number of bytes (e.g. 8 bytes for a double-precision float). The caveat is that extremely small or large numbers cannot be represented simply because there aren't enough bits. If a number is too large to be represented with this fixed number of bytes, NumPy returns inf.
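For instance, a double-precision float tops out at roughly $1.8\times10^{308}$, which we can confirm with NumPy (a quick check, separate from our Softmax code itself):

import numpy as np

print(np.finfo(np.float64).max)   # 1.7976931348623157e+308, the largest representable double
print(np.exp(800))                # exceeds that limit, so NumPy warns about overflow and returns inf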
This limitation of our basic implementation means that large inputs will fail:
softmax([800, 500, 600])
array([nan,  0.,  0.])
Here, nan stands for not-a-number: the intermediate values are too large to compute, so the result is undefined. For this reason, the basic implementation is never used in practice.
The way to overcome this limitation is to reformulate the Softmax function like so:

$$y_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}} = \frac{Ce^{x_i}}{C\sum_{j=1}^{N} e^{x_j}} = \frac{e^{x_i + \log C}}{\sum_{j=1}^{N} e^{x_j + \log C}} = \frac{e^{x_i + C'}}{\sum_{j=1}^{N} e^{x_j + C'}} \label{eq:u0SjbfEiloxYNtGxR2o}$$
Note that all we have done is multiply the numerator and denominator by some scalar constant $C$ (and write $C' = \log C$), and hence \eqref{eq:u0SjbfEiloxYNtGxR2o} is equivalent to the original equation of the Softmax function \eqref{eq:yHWjjyQou5VFcDGhpZV}.
Let's now understand why \eqref{eq:u0SjbfEiloxYNtGxR2o} is better for numerical computation. $C'$ can be any constant value, so we can choose $C'$ such that the exponent ($C'+x_i$) is small. This is how we can avoid large uncomputable numbers.
Now, what is a good value of $C'$? Since our goal is to keep the exponents $C'+x_i$ small, a natural choice is to set $C'$ to the negative of the maximum of our input vector $\boldsymbol{x}$, so that the largest exponent becomes $0$ and all the others are negative.
For instance, consider the following input vector:

$$\boldsymbol{x} = (800,\; 500,\; 600)$$
The negative of the maximum of $\boldsymbol{x}$ is:

$$C' = -\max(\boldsymbol{x}) = -800$$
From \eqref{eq:u0SjbfEiloxYNtGxR2o} we know that:

$$y_i = \frac{e^{x_i - 800}}{e^{800-800} + e^{500-800} + e^{600-800}} = \frac{e^{x_i - 800}}{e^{0} + e^{-300} + e^{-200}}$$
Notice how we now avoid $\exp(800)$, and our exponents are much smaller!
The implementation of \eqref{eq:u0SjbfEiloxYNtGxR2o} in NumPy is as follows:
def softmax(x):
    """
    x: 1D NumPy array of inputs
    """
    x = np.asarray(x, dtype=float)  # work on a float copy so lists are accepted and the caller's array is untouched
    c = -np.max(x)                  # shift by the negative maximum so the largest exponent is 0
    x = x + c
    return np.exp(x) / np.sum(np.exp(x))
Now, we can use the function like so:
softmax([800, 500, 600])
array([1.00000000e+000, 5.14820022e-131, 1.38389653e-087])
Notice how we do not get any nan this time.
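If you have SciPy available, you can also sanity-check our implementation against scipy.special.softmax, which likewise handles large inputs without overflowing:

from scipy.special import softmax as scipy_softmax

print(scipy_softmax([800, 500, 600]))   # matches the output of our softmax above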