Softmax Regression
In this chapter, you'll learn about:
- Limitations of Linear Regression for Classification: Understanding why mean squared error is not ideal for classification tasks.
- Extension from Binary to Multiclass Classification: Generalizing logistic regression to handle multiple classes.
- Softmax Function: Introducing the softmax function to model multiclass probabilities.
- Cross-Entropy Loss for Multiclass Classification: Formulating the loss function suitable for multiclass settings.
- Deriving the Gradient for Optimization: Computing gradients for parameter updates in softmax regression.
- Connection Between Softmax and Logistic Regression: Showing that logistic regression is a special case of softmax regression.
- Inference in Multiclass Classification: Making predictions using the trained softmax regression model.
- Decision Principles and Expected Risk: Justifying the use of maximum a posteriori (MAP) inference in classification.
In previous chapters, we discussed logistic regression for binary classification, focusing on modeling the probability of binary outcomes and optimizing the cross-entropy loss function. However, many real-world classification problems involve more than two classes. In this chapter, we extend the principles of logistic regression to multiclass classification using softmax regression.
Softmax regression, also known as multinomial logistic regression, generalizes logistic regression by modeling the probability distribution over multiple classes. It employs the softmax function to ensure that the predicted probabilities are positive and sum to one.
Limitations of Linear Regression for Classification
Mean Squared Error (MSE) Issues
Using linear regression with mean squared error for classification tasks is problematic for several reasons:
- Inappropriate Loss Penalization: MSE does not penalize misclassifications adequately, especially when the predictions are confidently wrong.
- Prediction Range: Linear regression can produce predictions outside the [0, 1] interval, which are not valid probabilities.
- Symmetric Penalty: MSE penalizes deviations of equal magnitude identically, regardless of whether the prediction errs toward or away from the correct class; for a target of 1, predictions of 0.8 and 1.2 incur the same penalty, even though overshooting the correct label is harmless for classification.
Comparison with Cross-Entropy Loss
- Cross-Entropy Loss:
- Penalizes confident wrong predictions more heavily.
- Ensures that the loss goes to infinity as the predicted probability of the incorrect class approaches one.
- Gradient Behavior:
- Although the loss itself diverges for confident mistakes, the gradient of the cross-entropy loss with respect to the logits remains bounded (for logistic regression it is $(\sigma(z) - y)\,\mathbf{x}$), preventing numerical instability during optimization.
- Via the chain rule, the vanishing derivative of the sigmoid at saturation cancels the exploding derivative of the log loss, so the drastic increase in loss does not translate into an exploding gradient magnitude; a small numeric comparison follows below.
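To make the contrast concrete, here is a minimal NumPy sketch (the specific logit value is illustrative, not from the text) comparing the two losses and their gradients with respect to the logit when a sigmoid model is confidently wrong:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# True label is 1, but the model is confidently wrong (large negative logit).
y, z = 1.0, -8.0
p = sigmoid(z)  # predicted probability of the true class, close to 0

mse = (p - y) ** 2   # bounded: never exceeds 1
ce = -np.log(p)      # grows without bound as p -> 0

# Gradients with respect to the logit z (via the chain rule):
grad_mse = 2 * (p - y) * p * (1 - p)  # vanishes as p -> 0, so learning stalls
grad_ce = p - y                       # bounded, but does not vanish

print(f"p = {p:.6f}")
print(f"MSE loss = {mse:.4f}, d(MSE)/dz = {grad_mse:.6f}")
print(f"CE  loss = {ce:.4f}, d(CE)/dz  = {grad_ce:.6f}")
```

The run illustrates the point above: cross-entropy assigns a large loss to the confident mistake yet keeps a bounded, non-vanishing gradient, while the MSE gradient nearly vanishes.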
Multiclass Classification
Problem Setup
- Objective: Assign an input $\mathbf{x} \in \mathbb{R}^d$ to one of $K$ classes, labeled $1, 2, \dots, K$.
- Targets:
- Represented using one-hot encoding: $\mathbf{y} \in \{0, 1\}^K$ with $y_k = 1$ if the true class is $k$ and $y_j = 0$ for $j \neq k$ (see the sketch after this list).
- Alternatively, use an index representation where $y \in \{1, 2, \dots, K\}$.
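As a quick illustration of the two target representations, a minimal NumPy sketch (0-based class indices; variable names are my own):

```python
import numpy as np

labels = np.array([0, 2, 1, 2])  # index representation, K = 3 classes
K = 3
one_hot = np.eye(K)[labels]      # each row is the one-hot vector for one sample
print(one_hot)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]
#  [0. 0. 1.]]
```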
Issues with Linear Regression
- Ordinal Encoding Problems: Assigning numerical values to classes (e.g., 1, 2, 3) introduces artificial ordering and distance relationships that do not exist between categories.
- Mean Squared Error Limitations: The same issues as in binary classification arise, exacerbated by the presence of multiple classes.
Softmax Function
Definition
The softmax function converts raw scores (logits) from a linear model into probabilities that sum to one.
- Logits (Scores): $z_k = \mathbf{w}_k^\top \mathbf{x} + b_k$ for $k = 1, \dots, K$, collected into a vector $\mathbf{z} = (z_1, \dots, z_K)$.
- Softmax Function: $\hat{y}_k = \operatorname{softmax}(\mathbf{z})_k = \dfrac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}}$
- Properties:
- $\hat{y}_k > 0$ for all $k$.
- $\sum_{k=1}^{K} \hat{y}_k = 1$.
Interpretation
- Probabilities: Each $\hat{y}_k$ represents the model's estimated probability that the input belongs to class $k$.
- Exponential Transformation: Ensures all outputs are positive.
- Normalization: Dividing by the sum ensures the outputs sum to one.
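In practice, a direct implementation can overflow for large logits, so the standard trick is to subtract the maximum logit before exponentiating; this leaves the output unchanged because softmax is invariant to shifting all logits by a constant. A minimal NumPy sketch:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - np.max(z, axis=-1, keepdims=True)  # shift by a constant: output unchanged
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z, axis=-1, keepdims=True)

z = np.array([2.0, 1.0, 0.1])
p = softmax(z)
print(p)        # approx. [0.659 0.242 0.099]: all positive
print(p.sum())  # 1.0: sums to one
```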
Connection to Logistic Function
- Binary Classification: The logistic (sigmoid) function is a special case of the softmax function when $K = 2$.
- Derivation:
- For $K = 2$, the softmax probabilities reduce to: $\hat{y}_1 = \dfrac{e^{z_1}}{e^{z_1} + e^{z_2}} = \dfrac{1}{1 + e^{-(z_1 - z_2)}} = \sigma(z_1 - z_2)$, which is the sigmoid function applied to the score difference $z_1 - z_2$.
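This equivalence is easy to check numerically (a small sketch; the logit values are arbitrary):

```python
import numpy as np

z1, z2 = 1.3, -0.4
p_softmax = np.exp(z1) / (np.exp(z1) + np.exp(z2))  # two-class softmax for class 1
p_sigmoid = 1.0 / (1.0 + np.exp(-(z1 - z2)))        # sigmoid of the score difference
print(np.isclose(p_softmax, p_sigmoid))  # True
```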
Softmax Regression Model
Model Formulation
- Scores: $z_k = \mathbf{w}_k^\top \mathbf{x} + b_k$ for $k = 1, \dots, K$.
- Predicted Probabilities: $\hat{y}_k = \dfrac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}}$
- Parameterization:
- Weights: $\mathbf{w}_k \in \mathbb{R}^d$ for each class $k$.
- Biases: $b_k \in \mathbb{R}$ for each class $k$.
Decision Rule
- Predicted Class: $\hat{k} = \arg\max_k \hat{y}_k = \arg\max_k z_k$ (softmax is monotone, so the largest logit yields the largest probability).
- Inference: Choose the class with the highest predicted probability (a sketch follows below).
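Putting the pieces together, a minimal sketch of the forward pass and decision rule (the shapes and names are my own convention, with one weight column and one bias per class):

```python
import numpy as np

def predict(X, W, b):
    """Predicted class index for each row of X.

    X: (N, d) inputs, W: (d, K) weights, b: (K,) biases.
    """
    Z = X @ W + b                                         # scores (logits), shape (N, K)
    Z = Z - Z.max(axis=1, keepdims=True)                  # numerical stability
    P = np.exp(Z) / np.exp(Z).sum(axis=1, keepdims=True)  # softmax probabilities
    return P.argmax(axis=1)  # same as Z.argmax(axis=1), since softmax is monotone
```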
Training Softmax Regression via Maximum Likelihood Estimation
Training Data
- Dataset: $\mathcal{D} = \{(\mathbf{x}^{(i)}, \mathbf{y}^{(i)})\}_{i=1}^{N}$, where $\mathbf{y}^{(i)} \in \{0, 1\}^K$ is a one-hot encoded vector.
Likelihood Function
- Assumption: Observations are independent.
- Likelihood: $L(\boldsymbol{\theta}) = \prod_{i=1}^{N} p\!\left(\mathbf{y}^{(i)} \mid \mathbf{x}^{(i)}; \boldsymbol{\theta}\right) = \prod_{i=1}^{N} \prod_{k=1}^{K} \left(\hat{y}_k^{(i)}\right)^{y_k^{(i)}}$, where $\boldsymbol{\theta}$ represents all model parameters (weights and biases).
Log-Likelihood Function
- Log-Likelihood: $\ell(\boldsymbol{\theta}) = \sum_{i=1}^{N} \sum_{k=1}^{K} y_k^{(i)} \log \hat{y}_k^{(i)}$
- Objective: Maximize $\ell(\boldsymbol{\theta})$.
Cross-Entropy Loss
- Loss Function: $J(\boldsymbol{\theta}) = -\ell(\boldsymbol{\theta}) = -\sum_{i=1}^{N} \sum_{k=1}^{K} y_k^{(i)} \log \hat{y}_k^{(i)}$
- Per-Sample Loss: $L^{(i)} = -\log \hat{y}_{c_i}^{(i)}$, where $c_i$ is the index of the true class for sample $i$.
Interpretation
- The cross-entropy loss measures the difference between the true distribution $\mathbf{y}$ and the predicted distribution $\hat{\mathbf{y}}$.
- Penalizes incorrect predictions more heavily when the model is confident but wrong.
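A minimal sketch of the loss for one-hot targets (the small clipping constant, my own choice, guards against log(0)):

```python
import numpy as np

def cross_entropy(Y, P, eps=1e-12):
    """Average cross-entropy between one-hot labels Y and predictions P, both (N, K)."""
    P = np.clip(P, eps, 1.0)                        # avoid log(0)
    return -np.mean(np.sum(Y * np.log(P), axis=1))

Y = np.array([[1, 0, 0], [0, 1, 0]], dtype=float)
P = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
print(cross_entropy(Y, P))  # (-log 0.7 - log 0.8) / 2, approx. 0.2899
```

Note that averaging over samples (rather than summing, as in the formula above) only rescales the loss and its gradient; both conventions appear in practice.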
Gradient Computation for Optimization
Need for Gradient
- Purpose: Use gradient-based optimization methods (e.g., gradient descent) to minimize the loss function.
- Challenge: No closed-form solution for the optimal parameters.
Computing the Gradient
- Gradient w.r.t. Weights $\mathbf{w}_k$: $\nabla_{\mathbf{w}_k} J = \sum_{i=1}^{N} \left(\hat{y}_k^{(i)} - y_k^{(i)}\right) \mathbf{x}^{(i)}$
- Gradient w.r.t. Biases $b_k$: $\dfrac{\partial J}{\partial b_k} = \sum_{i=1}^{N} \left(\hat{y}_k^{(i)} - y_k^{(i)}\right)$
- Derivation Highlights:
- Softmax Derivative: $\dfrac{\partial \hat{y}_k}{\partial z_j} = \hat{y}_k \left(\delta_{kj} - \hat{y}_j\right)$, where $\delta_{kj}$ is the Kronecker delta (a numerical check follows this list).
- Chain Rule: Applied to compute derivatives of composite functions.
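The softmax derivative is easy to sanity-check with finite differences (a minimal sketch; the test logits are arbitrary):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

z = np.array([0.5, -1.0, 2.0])
p = softmax(z)

# Analytic Jacobian: dy_k/dz_j = y_k * (delta_kj - y_j)
J_analytic = np.diag(p) - np.outer(p, p)

# Central finite differences
eps = 1e-6
J_numeric = np.zeros((3, 3))
for j in range(3):
    dz = np.zeros(3)
    dz[j] = eps
    J_numeric[:, j] = (softmax(z + dz) - softmax(z - dz)) / (2 * eps)

print(np.allclose(J_analytic, J_numeric, atol=1e-8))  # True
```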
Matrix Notation
- Gradient Compact Form: $\nabla_{\mathbf{W}} J = \mathbf{X}^\top \left(\hat{\mathbf{Y}} - \mathbf{Y}\right)$
where:
- $\mathbf{W} \in \mathbb{R}^{d \times K}$: Matrix of weights with columns $\mathbf{w}_k$.
- $\mathbf{X} \in \mathbb{R}^{N \times d}$: Design matrix (stacked input vectors).
- $\hat{\mathbf{Y}} \in \mathbb{R}^{N \times K}$: Matrix of predicted probabilities.
- $\mathbf{Y} \in \mathbb{R}^{N \times K}$: Matrix of true labels (one-hot encoded).
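In code, the compact form is a single matrix product. A minimal sketch (the 1/N scaling corresponds to averaging the loss over samples, a common convention):

```python
import numpy as np

def gradients(X, Y, W, b):
    """Gradients of the averaged cross-entropy loss for softmax regression.

    X: (N, d), Y: (N, K) one-hot, W: (d, K), b: (K,).
    """
    Z = X @ W + b
    Z = Z - Z.max(axis=1, keepdims=True)                  # numerical stability
    P = np.exp(Z) / np.exp(Z).sum(axis=1, keepdims=True)  # predicted probabilities Y_hat
    N = X.shape[0]
    grad_W = X.T @ (P - Y) / N   # the compact form X^T (Y_hat - Y), averaged
    grad_b = (P - Y).mean(axis=0)
    return grad_W, grad_b
```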
Optimization Methods
- Gradient Descent Variants:
- Batch Gradient Descent.
- Stochastic Gradient Descent (SGD).
- Mini-Batch Gradient Descent.
- Regularization: Add penalty terms to prevent overfitting (e.g., L2 regularization).
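A minimal sketch of mini-batch gradient descent with L2 regularization (all hyperparameter values are illustrative, not from the text):

```python
import numpy as np

def train(X, Y, lr=0.1, l2=1e-3, epochs=100, batch_size=32, seed=0):
    """Fit softmax regression by mini-batch gradient descent with L2 regularization."""
    rng = np.random.default_rng(seed)
    N, d = X.shape
    K = Y.shape[1]
    W, b = np.zeros((d, K)), np.zeros(K)
    for _ in range(epochs):
        perm = rng.permutation(N)                  # reshuffle samples each epoch
        for start in range(0, N, batch_size):
            idx = perm[start:start + batch_size]
            Xb, Yb = X[idx], Y[idx]
            Z = Xb @ W + b
            Z = Z - Z.max(axis=1, keepdims=True)   # numerical stability
            P = np.exp(Z) / np.exp(Z).sum(axis=1, keepdims=True)
            grad_W = Xb.T @ (P - Yb) / len(idx)    # X^T (Y_hat - Y), averaged
            grad_b = (P - Yb).mean(axis=0)
            W -= lr * (grad_W + l2 * W)            # L2 penalty on weights only
            b -= lr * grad_b                       # biases left unregularized
    return W, b
```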
Connection Between Softmax and Logistic Regression
Logistic Regression as a Special Case
- Binary Classification: When $K = 2$, softmax regression reduces to logistic regression.
- Derivation:
- Softmax Probabilities: $\hat{y}_1 = \dfrac{e^{z_1}}{e^{z_1} + e^{z_2}}, \quad \hat{y}_2 = \dfrac{e^{z_2}}{e^{z_1} + e^{z_2}}$
- Simplify $\hat{y}_1$: $\hat{y}_1 = \dfrac{1}{1 + e^{-(z_1 - z_2)}} = \sigma(z_1 - z_2)$, which is the sigmoid function applied to $z_1 - z_2$.
- Parameter Equivalence:
- Let $\mathbf{w} = \mathbf{w}_1 - \mathbf{w}_2$.
- Let $b = b_1 - b_2$.
Implications
- The softmax function generalizes the logistic (sigmoid) function to multiple classes.
- Understanding this connection helps in grasping the underlying principles of classification models.
Inference in Multiclass Classification
Making Predictions
- Compute Scores: $z_k = \mathbf{w}_k^\top \mathbf{x} + b_k$ for $k = 1, \dots, K$.
- Compute Probabilities: $\hat{y}_k = \operatorname{softmax}(\mathbf{z})_k = \dfrac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}}$
- Predict Class: $\hat{k} = \arg\max_k \hat{y}_k$
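A worked example of the three steps for a single input with $K = 3$ (all numbers are illustrative):

```python
import numpy as np

x = np.array([1.0, -2.0])         # input, d = 2
W = np.array([[0.5, -0.3,  0.1],  # one weight column per class
              [0.2,  0.4, -0.6]])
b = np.array([0.0, 0.1, -0.1])

z = x @ W + b                                        # 1. scores: [0.1, -1.0, 1.2]
p = np.exp(z - z.max()) / np.exp(z - z.max()).sum()  # 2. probabilities: approx. [0.23, 0.08, 0.69]
k = int(np.argmax(p))                                # 3. predicted class index: 2
print(z, p.round(2), k)
```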