Logistic Regression
In this chapter, you'll learn about:
- Binary Classification with Probabilistic Models: Modeling binary outcomes using probabilities.
- Bernoulli Distribution: Understanding the distribution for binary random variables.
- Logistic Function (Sigmoid Function): Introducing the squashing function to map linear combinations to probabilities.
- Logistic Regression Model: Formulating the logistic regression for binary classification.
- Maximum Likelihood Estimation (MLE): Deriving the loss function for logistic regression.
- Cross-Entropy Loss: Connecting the logistic regression loss to cross-entropy and KL divergence.
- Gradient Computation: Calculating gradients for optimization.
- Convexity and Optimization: Discussing the convex nature of logistic regression and optimization methods.
In previous chapters, we introduced classification problems and explored linear classifiers. We discussed the limitations of using linear regression for classification and the need for models specifically designed for categorical outcomes.
In this chapter, we delve into logistic regression, a fundamental algorithm for binary classification tasks. Logistic regression models the probability that a given input belongs to a particular category, allowing for probabilistic interpretation of predictions. It is widely used due to its simplicity, interpretability, and effectiveness.
Binary Classification and the Bernoulli Distribution
Binary Classification Recap
- Objective: Assign an input to one of two classes, labeled as 0 or 1.
- Examples: Spam detection (spam or not spam), disease diagnosis (disease or healthy).
Bernoulli Distribution
- Definition: A discrete probability distribution for a random variable that has two possible outcomes, 1 (success) and 0 (failure).
- Parameter: $p = P(Y = 1)$, where $0 \le p \le 1$.
- Probability Mass Function: $P(Y = y) = p^{y}(1 - p)^{1 - y}$ for $y \in \{0, 1\}$.
- Use in Classification: Models the probability that the target variable equals 1.
Modeling the Probability with Inputs
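- Goal: Model $p = P(y = 1 \mid \mathbf{x})$ as a function of the input features $\mathbf{x} \in \mathbb{R}^d$.
- Linear Combination: Compute a linear combination $z = \mathbf{w}^\top\mathbf{x} + b$.
- Issue: The linear combination $z$ can take any real value, but $p$ must lie between 0 and 1.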
The Logistic Function (Sigmoid Function)
Need for a Squashing Function
- Purpose: Map the linear combination to a value between 0 and 1.
- Requirements:
- Monotonic increasing function.
- Outputs values strictly between 0 and 1.
Logistic (Sigmoid) Function
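- Definition: $\sigma(z) = \dfrac{1}{1 + e^{-z}}$
- Properties:
- Range: $0 < \sigma(z) < 1$ for all real $z$.
- S-shape Curve: As $z \to \infty$, $\sigma(z) \to 1$; as $z \to -\infty$, $\sigma(z) \to 0$.
- Symmetry: $\sigma(-z) = 1 - \sigma(z)$.
- Visualization: [Figure: S-shaped curve of $\sigma(z)$ plotted against $z$]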
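These notes contain no code, so choosing Python here is an assumption. The sketch below checks the properties above numerically. It evaluates $\sigma(z)$ with a branch on the sign of $z$, so `math.exp` never receives a large positive argument. The function name `sigmoid` is illustrative.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function 1 / (1 + e^(-z)), evaluated without overflow."""
    if z >= 0:
        # Here e^(-z) <= 1, so the direct formula is safe.
        return 1.0 / (1.0 + math.exp(-z))
    # For z < 0, use the algebraically equal form e^z / (1 + e^z) so exp never overflows.
    ez = math.exp(z)
    return ez / (1.0 + ez)


# Check range and symmetry on a few points.
for z in [-30.0, -2.0, 0.0, 2.0, 30.0]:
    s = sigmoid(z)
    assert 0.0 < s < 1.0
    assert abs(sigmoid(-z) - (1.0 - s)) < 1e-12

print(sigmoid(0.0))  # prints 0.5
```

The branch matters in practice. The naive formula raises `OverflowError` once $-z$ exceeds roughly 709, while this version simply returns 0.0.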
Alternative Functions
- Probit Function: Based on the cumulative distribution function (CDF) of the normal distribution.
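- Definition: $\Phi(z) = \displaystyle\int_{-\infty}^{z} \frac{1}{\sqrt{2\pi}}\,e^{-t^2/2}\,dt$, giving the model $P(y = 1 \mid \mathbf{x}) = \Phi(\mathbf{w}^\top\mathbf{x} + b)$.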
- Used in: Probit regression.
- Why Logistic Function?
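- Mathematical Convenience: The logistic function has a closed form and a simple derivative, $\sigma'(z) = \sigma(z)(1 - \sigma(z))$. This leads to a convex loss with simple gradients.
- Interpretability: The linear score equals the log-odds, $\mathbf{w}^\top\mathbf{x} + b = \log\frac{p}{1 - p}$, so each weight is the change in log-odds per unit increase in its feature.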
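To compare the two squashing functions, this sketch evaluates the logistic function alongside the standard normal CDF $\Phi$, written in terms of `math.erf`. The last column uses the commonly cited approximation $\Phi(z) \approx \sigma(1.702\,z)$. That approximation suggests the two models differ mainly in the scale of their weights. The helper names are hypothetical.

```python
import math


def logistic(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def probit_cdf(z: float) -> float:
    """Standard normal CDF: Phi(z) = (1 + erf(z / sqrt(2))) / 2."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


print(f"{'z':>5} {'logistic':>9} {'probit':>8} {'logistic(1.702z)':>17}")
for z in [-3.0, -1.0, 0.0, 1.0, 3.0]:
    print(f"{z:5.1f} {logistic(z):9.4f} {probit_cdf(z):8.4f} {logistic(1.702 * z):17.4f}")
```

Both curves equal 0.5 at $z = 0$. The probit curve approaches 0 and 1 faster in the tails.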
Logistic Regression Model
Model Formulation
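- Probability of Class 1: $P(y = 1 \mid \mathbf{x}) = \sigma(\mathbf{w}^\top\mathbf{x} + b) = \dfrac{1}{1 + e^{-(\mathbf{w}^\top\mathbf{x} + b)}}$
- Probability of Class 0: $P(y = 0 \mid \mathbf{x}) = 1 - \sigma(\mathbf{w}^\top\mathbf{x} + b) = \sigma\bigl(-(\mathbf{w}^\top\mathbf{x} + b)\bigr)$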
- Interpretation: The logistic function maps the linear combination of inputs to a probability.
Decision Rule
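- Predicted Class: $\hat{y} = 1$ if $\sigma(\mathbf{w}^\top\mathbf{x} + b) \ge 0.5$, and $\hat{y} = 0$ otherwise.
- Equivalently: $\hat{y} = 1$ if $\mathbf{w}^\top\mathbf{x} + b \ge 0$, since $\sigma(z) \ge 0.5$ exactly when $z \ge 0$. The decision boundary $\mathbf{w}^\top\mathbf{x} + b = 0$ is therefore a hyperplane.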
Terminology
- Goodness Score (Logit): $z = \mathbf{w}^\top\mathbf{x} + b$.
- Soft Output: $\hat{p} = \sigma(z)$, the predicted probability.
- Hard Output: $\hat{y} \in \{0, 1\}$, the predicted class label.
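Here is a minimal sketch of the three outputs for a single input. It assumes plain Python lists for $\mathbf{w}$ and $\mathbf{x}$ and uses made-up parameter values. `logit_score`, `predict_proba`, and `predict` are illustrative names, not part of any particular library. The hard prediction thresholds $z$ at 0, which is the same as thresholding $\sigma(z)$ at 0.5.

```python
import math


def logit_score(w, b, x):
    """Goodness score (logit) z = w^T x + b."""
    return sum(wj * xj for wj, xj in zip(w, x)) + b


def predict_proba(w, b, x):
    """Soft output: estimated P(y = 1 | x) = sigma(z).

    Uses the direct formula, so math.exp raises OverflowError for very negative z.
    """
    return 1.0 / (1.0 + math.exp(-logit_score(w, b, x)))


def predict(w, b, x):
    """Hard output: 1 if z >= 0 (equivalently sigma(z) >= 0.5), else 0."""
    return 1 if logit_score(w, b, x) >= 0.0 else 0


# Hypothetical parameters and input, for illustration only.
w, b = [2.0, -1.0], -0.5
x = [1.0, 1.0]
print(logit_score(w, b, x), predict_proba(w, b, x), predict(w, b, x))
```

For this input $z = 0.5$, the soft output is $\sigma(0.5) \approx 0.62$ and the hard output is 1.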
Naming Convention
- Despite being called "regression," logistic regression is used for classification tasks. The name originates from its historical development in statistics.
Training Logistic Regression via Maximum Likelihood Estimation
Training Data
- Dataset: $\{(\mathbf{x}_i, y_i)\}_{i=1}^{N}$, where $\mathbf{x}_i \in \mathbb{R}^d$ and $y_i \in \{0, 1\}$.
Likelihood Function
- Assumption: Observations are independent.
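- Likelihood: $L(\mathbf{w}, b) = \prod_{i=1}^{N} P(y_i \mid \mathbf{x}_i; \mathbf{w}, b)$
- Using the Model: $P(y_i \mid \mathbf{x}_i; \mathbf{w}, b) = p_i^{y_i}(1 - p_i)^{1 - y_i}$, where $p_i = \sigma(\mathbf{w}^\top\mathbf{x}_i + b)$.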
Log-Likelihood Function
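- Log-Likelihood: $\ell(\mathbf{w}, b) = \log L(\mathbf{w}, b) = \sum_{i=1}^{N}\bigl[y_i \log p_i + (1 - y_i)\log(1 - p_i)\bigr]$
- Objective: Maximize $\ell(\mathbf{w}, b)$ with respect to $\mathbf{w}$ and $b$.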
Loss Function
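- Negative Log-Likelihood (Cross-Entropy Loss): $\mathcal{L}(\mathbf{w}, b) = -\ell(\mathbf{w}, b) = -\sum_{i=1}^{N}\bigl[y_i \log p_i + (1 - y_i)\log(1 - p_i)\bigr]$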
- Purpose: Convert maximization problem into minimization.
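Evaluating the loss literally can hit $\log 0$ when $p_i$ rounds to 0 or 1. One common workaround is to rewrite each term as $\log(1 + e^{z_i}) - y_i z_i$, where $z_i = \mathbf{w}^\top\mathbf{x}_i + b$, and then compute $\log(1 + e^{z})$ stably. The sketch below takes that approach and assumes the summed (not averaged) loss. The function names are hypothetical.

```python
import math


def softplus(z: float) -> float:
    """log(1 + e^z), rearranged as max(z, 0) + log(1 + e^(-|z|)) to avoid overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def nll_loss(w, b, X, y):
    """Summed negative log-likelihood of logistic regression.

    Each term -[y log p + (1 - y) log(1 - p)] with p = sigma(z) is rewritten as
    softplus(z) - y * z, which never evaluates log(0).
    """
    total = 0.0
    for xi, yi in zip(X, y):
        z = sum(wj * xj for wj, xj in zip(w, xi)) + b
        total += softplus(z) - yi * z
    return total


# A confidently wrong prediction: z = 1000 for a sample labeled 0.
# In floating point the literal formula would need log(1 - sigma(1000)) = log(0).
print(nll_loss([1000.0], 0.0, [[1.0]], [0]))  # prints 1000.0
```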
Interpretation as Cross-Entropy and KL Divergence
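- Cross-Entropy Loss: Measures the difference between two probability distributions: $H(q, p) = -\sum_{y} q(y)\log p(y)$.
- KL Divergence: $D_{\mathrm{KL}}(q \,\|\, p) = \sum_{y} q(y)\log\frac{q(y)}{p(y)} = H(q, p) - H(q)$.
- Relation: For sample $i$, let $q_i$ be the true label distribution (a point mass on $y_i$) and $\hat{P}_i$ the predicted Bernoulli distribution with parameter $p_i$. Since $H(q_i) = 0$, the per-sample loss equals both $H(q_i, \hat{P}_i)$ and $D_{\mathrm{KL}}(q_i \,\|\, \hat{P}_i)$. The logistic regression loss is therefore the sum of these KL divergences over the training set.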
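The relation can be checked on a single sample. In this sketch, which uses hypothetical helper names, a label $y = 1$ is encoded as the point mass $q = (0, 1)$ over the outcomes $(0, 1)$. The model predicts $P(y = 1) = 0.7$. Cross-entropy and KL divergence then coincide because $H(q) = 0$.

```python
import math


def cross_entropy(q, p):
    """H(q, p) = -sum_y q(y) log p(y); outcomes with q(y) = 0 contribute nothing."""
    return sum(-qy * math.log(py) for qy, py in zip(q, p) if qy > 0)


def entropy(q):
    """H(q) = H(q, q)."""
    return cross_entropy(q, q)


def kl_divergence(q, p):
    """D_KL(q || p) = sum_y q(y) log(q(y) / p(y))."""
    return sum(qy * math.log(qy / py) for qy, py in zip(q, p) if qy > 0)


# Distributions over the outcomes (y = 0, y = 1).
q_true = [0.0, 1.0]   # observed label y = 1 as a point mass
p_pred = [0.3, 0.7]   # model's predicted P(y = 0), P(y = 1)

print(cross_entropy(q_true, p_pred), kl_divergence(q_true, p_pred), entropy(q_true))
```

The first two printed values are both $-\log 0.7 \approx 0.357$, and the entropy is 0.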
Gradient Computation for Optimization
Need for Gradient
- Purpose: Use gradient-based optimization methods (e.g., gradient descent) to minimize the loss function.
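- Challenge: The loss function is convex, but setting its gradient to zero yields nonlinear equations in $\mathbf{w}$ and $b$ with no closed-form solution. It must therefore be minimized iteratively.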
Computing the Gradient
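- Gradient w.r.t. Weights $\mathbf{w}$: $\nabla_{\mathbf{w}}\mathcal{L} = \sum_{i=1}^{N}(p_i - y_i)\,\mathbf{x}_i$, where $p_i = \sigma(\mathbf{w}^\top\mathbf{x}_i + b)$.
- Gradient w.r.t. Bias $b$: $\dfrac{\partial \mathcal{L}}{\partial b} = \sum_{i=1}^{N}(p_i - y_i)$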
- Derivation Highlights:
- Chain Rule: Used to compute derivatives of composite functions.
- Sigmoid Derivative: $\sigma'(z) = \sigma(z)\bigl(1 - \sigma(z)\bigr)$
- Simplification: The gradients simplify to expressions involving the residual $p_i - y_i$.
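Written out for a single sample with per-sample loss $\mathcal{L}_i$ (so that $\mathcal{L} = \sum_i \mathcal{L}_i$), the chain-rule steps listed above give the residual form directly:

```latex
\begin{aligned}
\mathcal{L}_i &= -\bigl[y_i \log p_i + (1 - y_i)\log(1 - p_i)\bigr],
  \qquad p_i = \sigma(z_i),\quad z_i = \mathbf{w}^\top\mathbf{x}_i + b \\
\frac{\partial \mathcal{L}_i}{\partial p_i} &= -\frac{y_i}{p_i} + \frac{1 - y_i}{1 - p_i}
  = \frac{p_i - y_i}{p_i(1 - p_i)} \\
\frac{\partial p_i}{\partial z_i} &= \sigma(z_i)\bigl(1 - \sigma(z_i)\bigr) = p_i(1 - p_i) \\
\frac{\partial \mathcal{L}_i}{\partial z_i} &= \frac{\partial \mathcal{L}_i}{\partial p_i}\,
  \frac{\partial p_i}{\partial z_i} = p_i - y_i \\
\nabla_{\mathbf{w}}\mathcal{L}_i &= (p_i - y_i)\,\mathbf{x}_i,
  \qquad \frac{\partial \mathcal{L}_i}{\partial b} = p_i - y_i
\end{aligned}
```

The factor $p_i(1 - p_i)$ from the sigmoid derivative cancels the denominator from the log terms. That cancellation is why the result contains no sigmoid derivatives.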
Matrix Notation
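- Gradient Compact Form: $\nabla_{\mathbf{w}}\mathcal{L} = X^\top(\mathbf{p} - \mathbf{y})$, $\quad \dfrac{\partial \mathcal{L}}{\partial b} = \mathbf{1}^\top(\mathbf{p} - \mathbf{y})$
where:
- $X \in \mathbb{R}^{N \times d}$: Design matrix (stacked input vectors; row $i$ is $\mathbf{x}_i^\top$).
- $\mathbf{p} = (p_1, \dots, p_N)^\top$: Vector of predicted probabilities.
- $\mathbf{y} = (y_1, \dots, y_N)^\top$: Vector of true labels.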
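The compact form $X^\top(\mathbf{p} - \mathbf{y})$ can be sketched in plain Python; NumPy would express it as a single matrix product. The sketch checks the analytic gradient against central finite differences of the summed loss, a standard way to catch gradient bugs. The data and parameter values are arbitrary.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss(w, b, X, y):
    """Summed negative log-likelihood (direct formula; fine for moderate z)."""
    total = 0.0
    for xi, yi in zip(X, y):
        p = sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b)
        total -= yi * math.log(p) + (1 - yi) * math.log(1 - p)
    return total


def gradient(w, b, X, y):
    """Analytic gradient: (X^T (p - y), 1^T (p - y))."""
    residuals = [sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b) - yi
                 for xi, yi in zip(X, y)]
    # Entry j of X^T (p - y) is column j of X dotted with the residual vector.
    grad_w = [sum(r * xi[j] for r, xi in zip(residuals, X)) for j in range(len(w))]
    return grad_w, sum(residuals)


def numerical_gradient(w, b, X, y, h=1e-6):
    """Central finite differences of the loss, used only as a check."""
    grad_w = []
    for j in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[j] += h
        w_minus[j] -= h
        grad_w.append((loss(w_plus, b, X, y) - loss(w_minus, b, X, y)) / (2 * h))
    grad_b = (loss(w, b + h, X, y) - loss(w, b - h, X, y)) / (2 * h)
    return grad_w, grad_b


X = [[0.5, -1.2], [1.5, 0.3], [-0.7, 2.0], [0.1, 0.1]]
y = [1, 0, 1, 0]
w, b = [0.4, -0.3], 0.2
print(gradient(w, b, X, y))
print(numerical_gradient(w, b, X, y))
```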
Similarity to Linear Regression
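- Linear Regression Gradient: $\nabla_{\mathbf{w}}\tfrac{1}{2}\|X\mathbf{w} - \mathbf{y}\|^2 = X^\top(X\mathbf{w} - \mathbf{y})$ (bias absorbed into $\mathbf{w}$).
- Observation: The logistic regression gradient has the same form, with the predicted probabilities $\mathbf{p}$ in place of the linear predictions $X\mathbf{w}$.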
Optimization Methods
No Closed-Form Solution
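- Unlike linear regression, logistic regression does not have a closed-form solution for $\mathbf{w}$ and $b$.
- Reason: The sigmoid makes the optimality condition $\sum_{i}(p_i - y_i)\mathbf{x}_i = \mathbf{0}$ nonlinear in the parameters, because each $p_i = \sigma(\mathbf{w}^\top\mathbf{x}_i + b)$ depends on them nonlinearly.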
Gradient-Based Optimization
- Methods:
- Batch Gradient Descent: Updates parameters using the entire dataset.
- Stochastic Gradient Descent (SGD): Updates parameters using one sample at a time.
- Mini-Batch Gradient Descent: Updates parameters using subsets of the data.
- Algorithm:
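- Initialize $\mathbf{w}$ and $b$.
- Compute the gradients $\nabla_{\mathbf{w}}\mathcal{L}$ and $\partial\mathcal{L}/\partial b$.
- Update parameters: $\mathbf{w} \leftarrow \mathbf{w} - \eta\,\nabla_{\mathbf{w}}\mathcal{L}$ and $b \leftarrow b - \eta\,\partial\mathcal{L}/\partial b$, where $\eta$ is the learning rate.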
- Repeat until convergence.
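The algorithm above might look like the following sketch, which handles the batch, stochastic, and mini-batch variants in one loop. It assumes a fixed number of epochs as the stopping rule. It also averages the gradient over each batch so that the step size is comparable across batch sizes; for a batch of size $m$, this only rescales the summed gradient by $1/m$. The function `train` and its parameters are hypothetical.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function: (1 + tanh(z / 2)) / 2."""
    return 0.5 * (1.0 + math.tanh(0.5 * z))


def train(X, y, lr=0.1, epochs=200, batch_size=None, seed=0):
    """Gradient descent on the mean negative log-likelihood over each batch.

    batch_size=None uses all rows per step (batch GD), batch_size=1 gives SGD,
    and values in between give mini-batch GD. Stops after a fixed number of epochs.
    """
    rng = random.Random(seed)
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0                      # initialize parameters
    step = n if batch_size is None else batch_size
    order = list(range(n))
    for _ in range(epochs):                    # repeat (fixed budget instead of a convergence test)
        rng.shuffle(order)
        for start in range(0, n, step):
            batch = order[start:start + step]
            grad_w, grad_b = [0.0] * d, 0.0
            for i in batch:                    # accumulate gradients over the batch
                r = sigmoid(sum(wj * xj for wj, xj in zip(w, X[i])) + b) - y[i]
                for j in range(d):
                    grad_w[j] += r * X[i][j]
                grad_b += r
            m = len(batch)
            w = [wj - lr * gj / m for wj, gj in zip(w, grad_w)]   # update step
            b -= lr * grad_b / m
    return w, b


# Toy 1-D data whose classes overlap, so a finite minimizer exists.
X = [[-2.0], [-1.0], [-0.5], [0.5], [1.0], [2.0]]
y = [0, 0, 1, 0, 1, 1]
w_batch, b_batch = train(X, y)                 # batch gradient descent
w_sgd, b_sgd = train(X, y, batch_size=1)       # stochastic gradient descent
print(w_batch, b_batch, w_sgd, b_sgd)
```

The points at $x = -0.5$ and $x = 0.5$ carry the "wrong" labels. As a result, the weight settles at a finite positive value instead of growing indefinitely.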
Convexity of the Loss Function
- Property: The logistic regression loss function is convex.
- Implication: Any local minimum is the global minimum.
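- Benefit: With an appropriate learning rate, gradient-based methods converge to a global minimum whenever one exists. If the training data are linearly separable, however, the unregularized loss has no finite minimizer. It keeps decreasing as $\|\mathbf{w}\|$ grows, so regularization or early stopping is needed.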
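On linearly separable data the unregularized loss has no finite minimizer, and this small experiment illustrates that. It runs plain gradient descent on a made-up, perfectly separable 1-D dataset and records the weight and mean loss at a few checkpoints. The bias is omitted because the data are symmetric about 0. The setup and the step size are assumptions chosen for illustration.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function: (1 + tanh(z / 2)) / 2."""
    return 0.5 * (1.0 + math.tanh(0.5 * z))


def softplus(z):
    """log(1 + e^z) without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


# Perfectly separable toy data: x < 0 has label 0, x > 0 has label 1.
X = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
y = [0, 0, 0, 1, 1, 1]

w, lr = 0.0, 0.5
history = []
for step in range(1, 20001):
    # Mean gradient of the NLL with respect to the single weight (no bias term).
    grad = sum((sigmoid(w * xi) - yi) * xi for xi, yi in zip(X, y)) / len(X)
    w -= lr * grad
    if step in (100, 1000, 10000, 20000):
        mean_loss = sum(softplus(w * xi) - yi * w * xi for xi, yi in zip(X, y)) / len(X)
        history.append((step, w, mean_loss))

for step, w_t, loss_t in history:
    print(f"step {step:>5}: w = {w_t:6.3f}, mean loss = {loss_t:.2e}")
```

The weight keeps growing while the loss keeps shrinking toward 0, so no number of iterations reaches an optimum. Adding an L2 penalty restores a finite minimizer.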
Regularization in Logistic Regression
Need for Regularization
- Purpose: Prevent overfitting by penalizing large weights.
- Approach: Add a regularization term to the loss function.
L2 Regularization (Ridge)
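- Regularized Loss Function: $\mathcal{L}_{\mathrm{L2}}(\mathbf{w}, b) = \mathcal{L}(\mathbf{w}, b) + \dfrac{\lambda}{2}\|\mathbf{w}\|_2^2$, where $\lambda \ge 0$ controls the regularization strength.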
- Interpretation: Encourages smaller weights.
L1 Regularization (Lasso)
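- Regularized Loss Function: $\mathcal{L}_{\mathrm{L1}}(\mathbf{w}, b) = \mathcal{L}(\mathbf{w}, b) + \lambda\|\mathbf{w}\|_1 = \mathcal{L}(\mathbf{w}, b) + \lambda\sum_{j}|w_j|$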
- Interpretation: Encourages sparsity in the weights.
Impact on Gradient
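- Modified Gradient (the bias $b$ is usually left unpenalized):
- For L2: $\nabla_{\mathbf{w}}\mathcal{L}_{\mathrm{L2}} = X^\top(\mathbf{p} - \mathbf{y}) + \lambda\mathbf{w}$.
- For L1: $|w_j|$ is not differentiable at $w_j = 0$. Methods therefore use the subgradient $X^\top(\mathbf{p} - \mathbf{y}) + \lambda\,\mathrm{sign}(\mathbf{w})$ or a proximal method (soft-thresholding) instead.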
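Here is a sketch of the modified gradients, assuming the summed loss with the bias left unpenalized. For L1 it uses the subgradient $\lambda\,\mathrm{sign}(w_j)$ and chooses 0 at $w_j = 0$. This is a simplification; proximal (soft-thresholding) updates are the more common way to obtain exact zeros. The function names and the `penalty` argument are hypothetical.

```python
import math


def sign(v):
    """Return 1, -1, or 0 (the subgradient choice used here at v = 0)."""
    return (v > 0) - (v < 0)


def regularized_gradient(w, b, X, y, lam=0.0, penalty="l2"):
    """Gradient of the summed NLL plus a penalty on w; the bias b is not penalized.

    penalty="l2": adds lam * w, the gradient of (lam / 2) * ||w||_2^2.
    penalty="l1": adds lam * sign(w), a subgradient of lam * ||w||_1.
    """
    residuals = []
    for xi, yi in zip(X, y):
        z = sum(wj * xj for wj, xj in zip(w, xi)) + b
        residuals.append(0.5 * (1.0 + math.tanh(0.5 * z)) - yi)   # sigma(z) - y, stable form
    grad_w = [sum(r * xi[j] for r, xi in zip(residuals, X)) for j in range(len(w))]
    grad_b = sum(residuals)
    if penalty == "l2":
        grad_w = [g + lam * wj for g, wj in zip(grad_w, w)]
    elif penalty == "l1":
        grad_w = [g + lam * sign(wj) for g, wj in zip(grad_w, w)]
    else:
        raise ValueError(f"unknown penalty: {penalty!r}")
    return grad_w, grad_b


print(regularized_gradient([1.0, -2.0], 0.0, [[1.0, 1.0]], [0], lam=0.1, penalty="l1"))
```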
Logistic Regression as a Generalized Linear Model (GLM)
Connection to GLMs
- GLMs: Extend linear models to allow the response variable to follow a non-normal distribution from the exponential family.
- Components:
- Random Component: Specifies the distribution of the response variable (e.g., Bernoulli).
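- Systematic Component: Linear predictor $z = \mathbf{w}^\top\mathbf{x} + b$.
- Link Function: A function $g$ that connects the mean $\mu = E[y \mid \mathbf{x}]$ to the linear predictor via $g(\mu) = z$. For logistic regression, $g$ is the logit, $g(\mu) = \log\frac{\mu}{1 - \mu}$, and its inverse is the logistic function $\mu = \sigma(z)$.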
Canonical Link Function
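- Definition: The canonical link is the link function that makes the linear predictor equal to the natural parameter of the exponential-family distribution. It yields simple sufficient statistics and a concave log-likelihood.
- For Logistic Regression: The logit is the canonical link function for the Bernoulli distribution. Its inverse, the logistic function, maps the linear predictor to the mean.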
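The canonical link can be read off by writing the Bernoulli mass function in exponential-family form:

```latex
\begin{aligned}
p(y \mid \mu) &= \mu^{y}(1 - \mu)^{1 - y}, \qquad y \in \{0, 1\} \\
&= \exp\!\Bigl( y \log\mu + (1 - y)\log(1 - \mu) \Bigr) \\
&= \exp\!\Bigl( y \log\frac{\mu}{1 - \mu} + \log(1 - \mu) \Bigr)
\end{aligned}
```

The coefficient of $y$, $\theta = \log\frac{\mu}{1 - \mu}$, is the natural parameter. Setting $\theta = z = \mathbf{w}^\top\mathbf{x} + b$ gives $\mu = \sigma(z)$, which is exactly the logistic regression model.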
Practical Considerations
Feature Scaling
- Importance: Speeds up the convergence of gradient-based methods. It also puts features on a comparable scale, which matters when weights are regularized.
- Methods:
- Standardization (zero mean, unit variance).
- Normalization (scaling features to a specific range).
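Here is a minimal standardization sketch that uses only the standard library (`statistics.fmean` requires Python 3.8+). It fits the per-feature mean and population standard deviation on training data; the same statistics should then be reused for validation and test data. The function names and the toy data are illustrative.

```python
import statistics


def fit_standardizer(X):
    """Per-feature mean and population standard deviation, computed on training rows."""
    columns = list(zip(*X))
    means = [statistics.fmean(col) for col in columns]
    stds = [statistics.pstdev(col) for col in columns]
    return means, stds


def standardize(X, means, stds):
    """Map each feature to (x - mean) / std; constant features (std == 0) are only centered."""
    return [[(xj - m) / s if s > 0 else xj - m for xj, m, s in zip(row, means, stds)]
            for row in X]


# Two features on very different scales.
X_train = [[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]]
means, stds = fit_standardizer(X_train)
print(standardize(X_train, means, stds))
```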
Choice of Learning Rate
- Trade-off:
- Too Large: May cause the algorithm to diverge.
- Too Small: Slow convergence.
- Adaptive Methods: Algorithms such as Adam and RMSProp adapt per-parameter step sizes during training.
Handling Imbalanced Data
- Issue: Class imbalance can bias the model toward the majority class.
- Solutions:
- Resampling techniques (oversampling minority class, undersampling majority class).
- Using evaluation metrics suitable for imbalanced data (precision, recall, F1-score).
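Random oversampling is the simplest of the resampling techniques listed above. This sketch duplicates randomly chosen minority-class rows until the two classes are the same size. The function name and seed handling are assumptions, and the function should be applied to the training split only.

```python
import random


def random_oversample(X, y, seed=0):
    """Duplicate random minority-class rows until both classes have equal counts.

    Assumes labels in {0, 1} and that both classes are present.
    """
    rng = random.Random(seed)
    pos = [i for i, yi in enumerate(y) if yi == 1]
    neg = [i for i, yi in enumerate(y) if yi == 0]
    minority, majority = (pos, neg) if len(pos) < len(neg) else (neg, pos)
    extra = [rng.choice(minority) for _ in range(len(majority) - len(minority))]
    order = list(range(len(y))) + extra
    rng.shuffle(order)   # mix the duplicates in with the original rows
    return [X[i] for i in order], [y[i] for i in order]


X = [[float(i)] for i in range(10)]
y = [1, 0, 0, 0, 0, 0, 0, 0, 1, 0]   # 2 positives, 8 negatives
X_bal, y_bal = random_oversample(X, y)
print(len(y_bal), sum(y_bal))
```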
Evaluation Metrics
- Accuracy: May be misleading with imbalanced data.
- Confusion Matrix: Provides detailed insights.
- ROC Curve and AUC: Evaluate the trade-off between true positive rate and false positive rate.
- Precision-Recall Curve: More informative with imbalanced datasets.
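The metrics above can be computed from scratch, as in this sketch with illustrative names and toy scores. AUC is computed through its probabilistic interpretation: the chance that a randomly chosen positive receives a higher score than a randomly chosen negative, with ties counted as one half. This pairwise version is quadratic in the data size and meant only for small examples.

```python
def confusion_counts(y_true, y_pred):
    """Return (TP, FP, FN, TN) for binary labels in {0, 1}."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    return tp, fp, fn, tn


def precision_recall_f1(y_true, y_pred):
    tp, fp, fn, _ = confusion_counts(y_true, y_pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def roc_auc(y_true, scores):
    """P(random positive scores higher than random negative), ties counted as 1/2."""
    pos = [s for s, t in zip(scores, y_true) if t == 1]
    neg = [s for s, t in zip(scores, y_true) if t == 0]
    wins = sum(1.0 if sp > sn else 0.5 if sp == sn else 0.0 for sp in pos for sn in neg)
    return wins / (len(pos) * len(neg))


# Imbalanced toy data: 6 negatives, 2 positives, with predicted probabilities.
y_true = [0, 0, 0, 0, 0, 0, 1, 1]
scores = [0.1, 0.2, 0.3, 0.4, 0.6, 0.2, 0.7, 0.55]
y_pred = [1 if s >= 0.5 else 0 for s in scores]

tp, fp, fn, tn = confusion_counts(y_true, y_pred)
accuracy = (tp + tn) / len(y_true)
print(accuracy, precision_recall_f1(y_true, y_pred), roc_auc(y_true, scores))
```

On this toy data, accuracy is 7/8 = 0.875, precision is 2/3, recall is 1, F1 is 0.8, and AUC is 11/12.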
Conclusion
Logistic regression is a powerful and widely used algorithm for binary classification tasks. By modeling the probability of class membership using the logistic function, it provides both a probabilistic framework and a linear decision boundary.
Understanding logistic regression lays the foundation for more advanced classification algorithms and deep learning models. Its principles are fundamental in machine learning and are essential knowledge for any practitioner.
Recap
- Bernoulli Distribution: Used for modeling binary outcomes.
- Logistic Function: Maps linear combinations to probabilities between 0 and 1.
- Logistic Regression Model: Predicts the probability of class membership.
- Maximum Likelihood Estimation: Used to derive the loss function.
- Cross-Entropy Loss: The negative log-likelihood function for logistic regression.
- Gradient Computation: Necessary for optimizing the loss function.
- Convexity: Ensures that any local minimum is global, so gradient-based methods cannot get stuck in a suboptimal local minimum.
- Regularization: Prevents overfitting by penalizing large weights.
- GLMs: Logistic regression is a special case of generalized linear models.