This page uses Wasm code blocks so you can run the examples directly in the browser.
SA Engine supports multiple linear regression through the `linear_regression` system model. It provides functions both for training a model on a dataset and for making estimates from a given weight vector. Training uses gradient descent.
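For reference, the standard formulation of multiple linear regression trained by gradient descent is shown below. It assumes that the loss being minimized is the sum of squared errors and that `w_0` is the intercept (`m` in the 2D case). This page does not state the exact loss scaling or stopping rule SA Engine uses, and `η` stands for the learning rate.

```latex
\hat{y}_i = w_0 + \sum_{j=1}^{d} w_j x_{ij}, \qquad
E(\mathbf{w}) = \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2

\frac{\partial E}{\partial w_0} = -2 \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right), \qquad
\frac{\partial E}{\partial w_j} = -2 \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right) x_{ij}

\mathbf{w} \leftarrow \mathbf{w} - \eta \, \nabla E(\mathbf{w})
```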
The `linear_regression` system model is loaded with the …
To train the regression model, we use the `linear_regression()` function:
```
set :s = (select vector of v
          from Vector v
          where v in csv:file_stream(system_models:folder("linear_regression") + …));

set :r = linear_regression(:s,        // training data
                           0.0000001, // learning rate
                           100,       // max iterations
                           [1],       // indices to use in prediction
                           2);        // index of value to predict
set :w = :r[1];                       // keep only the weight vector
```
In the preceding example we train the regression model on a set of 2D vectors loaded from the file `linear_reg2.csv`. We use `0.0000001` as the learning rate for the gradient descent algorithm, set the maximum number of iterations to `100`, and specify that the first value in each vector should be used to predict the second value.
The returned vector contains the weight vector and the sum of squared errors (SSE) of the trained model.
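To make the roles of the arguments concrete, here is a minimal Python sketch of batch gradient descent that takes the same five inputs as the call above. It is an illustration under stated assumptions, not SA Engine's implementation: `train_linear_regression` and `_predict` are hypothetical names, the loss is assumed to be the plain SSE (summed, not averaged), training always runs the full number of iterations from all-zero weights, indices are 1-based as in the OSQL call, and the weights are assumed to be laid out intercept first, `[m, k1, ..., kd]`.

```python
def _predict(weights, x):
    # Assumed layout: intercept first, then one weight per input value.
    return weights[0] + sum(w * xj for w, xj in zip(weights[1:], x))


def train_linear_regression(rows, learning_rate, max_iterations,
                            input_indices, output_index):
    """Hypothetical Python analogue of linear_regression(); not SA Engine code.

    rows are equal-length lists of numbers. Indices are 1-based, as in the
    OSQL call. Runs plain batch gradient descent on the SSE for exactly
    max_iterations steps and returns (weights, sse).
    """
    xs = [[row[i - 1] for i in input_indices] for row in rows]
    ys = [row[output_index - 1] for row in rows]
    weights = [0.0] * (len(input_indices) + 1)  # start from all-zero weights

    for _ in range(max_iterations):
        grad = [0.0] * len(weights)
        for x, y in zip(xs, ys):
            residual = y - _predict(weights, x)
            grad[0] -= 2.0 * residual               # dE/dw0
            for j, xj in enumerate(x):
                grad[j + 1] -= 2.0 * residual * xj  # dE/dwj
        # Step against the gradient, scaled by the learning rate.
        weights = [w - learning_rate * g for w, g in zip(weights, grad)]

    sse = sum((y - _predict(weights, x)) ** 2 for x, y in zip(xs, ys))
    return weights, sse


# Toy data on the line y = 2x + 1 (not the linear_reg2.csv data used above).
data = [[0.0, 1.0], [1.0, 3.0], [2.0, 5.0], [3.0, 7.0]]
w, sse = train_linear_regression(data, 0.01, 2000, [1], 2)
print(w, sse)  # w is approximately [1.0, 2.0], sse approximately 0
```

Because this sketch sums the gradient over all rows, a learning rate that converges depends on the number and scale of the data points, which is one reason very small rates such as `0.0000001` can be needed for data with large x-values.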
In the case where we have 2D data points `[x, y]` and a linear regression model `y = kx + m` that predicts `y` from `x` (as in the code example above), the weight vector returned by `linear_regression()` has the format `[m, k]`.
To apply the regression model to an input, we use the `lr_estimate()` function. It takes the weight vector and an input vector and returns the value predicted by the regression model.
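A minimal Python sketch of what such an estimate computes, assuming the weight vector holds the intercept first followed by one weight per input value (the intercept `m`, then the slope `k`, in the 2D case). `predict` is a hypothetical name used only for illustration; it is not part of SA Engine.

```python
def predict(weights, x):
    """Hypothetical counterpart of lr_estimate(): weights = [m, k1, ..., kd]."""
    if len(weights) != len(x) + 1:
        raise ValueError("need one weight per input value plus an intercept")
    # Intercept plus the dot product of the remaining weights with the input.
    return weights[0] + sum(k * xi for k, xi in zip(weights[1:], x))


# Weight vector reported for the model trained on this page.
w = [0.00000100622172569312, 0.00180844720146722]
print(round(predict(w, [1500]), 3))  # 2.713
```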
For example, training our model gave the weight vector `[0.00000100622172569312, 0.00180844720146722]`, so the intercept is `m ≈ 0.000001006` and the slope is `k ≈ 0.001808`. This means that for an x-value of `1500` we should get a y-value of approximately `1500 * 0.00180844720146722 + 0.00000100622172569312 ≈ 2.713`. Let's check this by evaluating `lr_estimate(:w, [1500])`.
The prediction was as expected.
Since the model is linear, applying it to the x-values of the training data should give predictions that lie on a straight line. The following code plots the training data in one color and the prediction for each training point in another color.
```
//plot: Multi plot
"sa_plot": "Scatter plot",
…
select vector of y
from Vector of Number v, Vector of Number y
where v in :s
  and y = [v[1], lr_estimate(:w, permute(v, [1])), 2]
   or v in :s
  and y = concat(v, [1]);
```
Let's go through the code in detail.
Visualization description (`//plot: Multi plot` and the JSON lines that follow it): tells SA Engine how to visualize the output. More information about scatter plots can be found in the Multi plot section of the SA Studio manual.
`select vector of y`: returns a vector of vectors (i.e., a matrix), with one row per plotted point.
`from Vector of Number v, Vector of Number y`: specifies the data types of the variables in the query.
First `where` branch (`v in :s and y = [...]`): picks each data point from the training set and creates a new vector `y = [x, p(x), 2]`, where `x` is the x-value of the data point, `p(x)` is the y-value predicted by the regression model, and `2` is simply a color ID for the visualization (note that we specified `"color_axis": 3` in the visualization, which means that the third element of each vector determines the color). The function `permute()` extracts the first element (the x-value) of the data point `v` as a single-element vector, which is the input vector that `lr_estimate()` expects.
`or` branch (`v in :s and y = concat(v, [1])`): picks each data point from the training set and appends a `1` as its third element, so the original points get color ID `1` in the scatter plot.
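The rows this query produces can be mimicked in a few lines of Python. This is a hypothetical sketch for the 2D case only, assuming the weight vector is `[m, k]` (intercept first); `scatter_rows` is an illustrative name, not an SA Engine function. It emits one predicted row (color ID `2`) and one original row (color ID `1`) per training point, in a fixed order that the OSQL query is not assumed to guarantee.

```python
def scatter_rows(training_data, weights):
    """Build rows of the form [x, value, color_id] for a 2D model [m, k]."""
    m, k = weights
    rows = []
    for x, y in training_data:
        rows.append([x, m + k * x, 2])  # predicted point, like the where branch
        rows.append([x, y, 1])          # original point, like concat(v, [1])
    return rows


rows = scatter_rows([[0.0, 1.5], [2.0, 4.0]], [1.0, 2.0])
predicted = [r for r in rows if r[2] == 2]
print(predicted)  # [[0.0, 1.0, 2], [2.0, 5.0, 2]]
```

Keeping the color ID as the third element mirrors the `"color_axis": 3` setting, so a plotting tool can split the two series on that column alone.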