from ._base import BaseLinearModel, RegressorMixin
from jasmine.losses import mse_loss
class LinearRegression(RegressorMixin, BaseLinearModel):
    """Linear regression trained with gradient descent.

    Parameters
    ----------
    use_bias : bool, default=True
        Forwarded to ``BaseLinearModel``; presumably controls whether an
        intercept term is learned (confirm in the base class).
    learning_rate : float, default=0.01
        Forwarded to ``BaseLinearModel``; presumably the optimizer step
        size (confirm in the base class).
    n_epochs : int, default=1000
        Forwarded to ``BaseLinearModel``; presumably the number of training
        passes over the data (confirm in the base class).
    loss_function : callable, default=mse_loss
        Data-fit loss, called as ``loss_function(y, predictions)`` by
        :meth:`loss_fn`.
    l1_penalty : float, default=0.0
        Forwarded to ``BaseLinearModel``; presumably the L1 regularization
        strength consumed by ``_regularization_penalty`` (confirm).
    l2_penalty : float, default=0.0
        Forwarded to ``BaseLinearModel``; presumably the L2 regularization
        strength consumed by ``_regularization_penalty`` (confirm).
    optimizer : optional, default=None
        Forwarded to ``BaseLinearModel``; how ``None`` is interpreted is
        defined there, not in this class.
    """

    # Class-level flags consumed by the base class (not visible here).
    # NOTE(review): presumably ``center_targets`` makes the base class center
    # ``y`` before fitting and ``log_every_epoch`` enables per-epoch loss
    # logging -- confirm against BaseLinearModel.
    center_targets = True
    log_every_epoch = True

    def __init__(
        self,
        use_bias=True,
        learning_rate=0.01,
        n_epochs=1000,
        loss_function=mse_loss,
        l1_penalty=0.0,
        l2_penalty=0.0,
        optimizer=None,
    ):
        """Store hyperparameters by forwarding them to ``BaseLinearModel``.

        See the class docstring for the meaning of each parameter.
        """
        # Every hyperparameter is listed explicitly and passed through by
        # keyword; this class adds no state of its own.
        super().__init__(
            use_bias=use_bias,
            learning_rate=learning_rate,
            n_epochs=n_epochs,
            loss_function=loss_function,
            l1_penalty=l1_penalty,
            l2_penalty=l2_penalty,
            optimizer=optimizer,
        )

    def loss_fn(self, params, X, y):
        """Return the regularized training objective for ``params``.

        Parameters
        ----------
        params
            Model parameters in whatever structure ``self.forward`` and
            ``self._regularization_penalty`` expect (defined in the base
            class).
        X
            Input features passed to ``self.forward``.
        y
            Target values compared against the predictions.

        Returns
        -------
        The data loss ``self.loss_function(y, self.forward(params, X))`` plus
        ``self._regularization_penalty(params)``.
        """
        predictions = self.forward(params, X)
        # Argument order is (targets, predictions).
        loss = self.loss_function(y, predictions)
        return loss + self._regularization_penalty(params)