Source code for jasmine.linear_model._linear

from ._base import BaseLinearModel, RegressorMixin
from jasmine.losses import mse_loss


class LinearRegression(RegressorMixin, BaseLinearModel):
    """Linear regression trained with gradient descent.

    This class adds only two class-level flags and a training objective
    (``loss_fn``) on top of ``BaseLinearModel``. All constructor arguments
    are forwarded unchanged to the base class.
    """

    # NOTE(review): both flags are consumed by BaseLinearModel/RegressorMixin;
    # presumably they enable target centering and per-epoch logging during
    # fit — confirm against jasmine.linear_model._base.
    center_targets = True
    log_every_epoch = True

    def __init__(
        self,
        use_bias=True,
        learning_rate=0.01,
        n_epochs=1000,
        loss_function=mse_loss,
        l1_penalty=0.0,
        l2_penalty=0.0,
        optimizer=None,
    ):
        """Configure the model by forwarding every setting to ``BaseLinearModel``.

        Args:
            use_bias: Passed through to the base class (presumably whether
                an intercept term is fitted).
            learning_rate: Passed through to the base class.
            n_epochs: Passed through to the base class.
            loss_function: Callable invoked as ``loss_function(y, predictions)``
                by ``loss_fn``; defaults to ``mse_loss``.
            l1_penalty: Passed through to the base class (assumed to feed
                ``_regularization_penalty``; confirm in ``_base``).
            l2_penalty: Passed through to the base class (assumed to feed
                ``_regularization_penalty``; confirm in ``_base``).
            optimizer: Passed through to the base class; ``None`` by default.
        """
        # Gather the settings once, then hand them to the parent as keywords
        # in the same order the signature declares them.
        base_settings = dict(
            use_bias=use_bias,
            learning_rate=learning_rate,
            n_epochs=n_epochs,
            loss_function=loss_function,
            l1_penalty=l1_penalty,
            l2_penalty=l2_penalty,
            optimizer=optimizer,
        )
        super().__init__(**base_settings)

    def loss_fn(self, params, X, y):
        """Return the training objective for ``params`` on the batch ``(X, y)``.

        The objective is ``self.loss_function(y, self.forward(params, X))``
        plus ``self._regularization_penalty(params)``.

        Args:
            params: Model parameters, in whatever form ``forward`` and
                ``_regularization_penalty`` expect (defined in the base class).
            X: Input features passed to ``forward``.
            y: Targets; passed as the first argument of ``loss_function``.

        Returns:
            The data loss plus the regularization penalty.
        """
        data_term = self.loss_function(y, self.forward(params, X))
        penalty = self._regularization_penalty(params)
        return data_term + penalty