Federated Learning

This page uses Wasm code blocks so you can run the examples directly in the browser.

This guide demonstrates a simple federated learning task: two local linear regression models, each fitted on its own local data, are combined into a global model whose weights are then transferred back to the local models.

Simulating multiple edges

Federated learning in an edge computing context usually involves contributions from multiple edge instances. Ordinarily, virtual edges can be started using the start_edges() function, but this is not supported in Wasm. We therefore simulate each edge as a separately named model instead.

First, load the built-in OSQL model for linear regression.

system_models:load("linear_regression");

Next, create some utility functions for creating and fetching models, setting model weights, and training a model on local data.

// Create ML model
create function add_federated_model(charstring model_name) -> object
as new_lr_model(model_name,[1],2);

// Get ML model
create function get_federated_model(charstring model_name) -> object
as merge(lr_model(model_name), {"r2": 0});

// Set weights of ML model
create function set_weights(charstring model_name, vector of number weights) -> record
as {
declare Record rec;
set rec = merge(lr_model(model_name), {"weights": weights});
set lr_model(model_name) = rec;
return rec;
};

// Train ML model
create function train_local(charstring model_name, matrix data, number learning_rate, integer num_rounds) -> record
as {
declare Record rec;
set rec = linear_regression(model_name, data, learning_rate, num_rounds);
set lr_model(model_name) = rec;
return rec;
};

Then define some functions for generating synthetic data:

// Generate a synthetic data point
create function syn_data_point(number slope, number intercept) -> vector of number
as select cast([x, x * slope + intercept + noise] as vector of number) from number x, number noise where x = frand(10) and noise = frand(-0.5, 0.5);

// Generate n synthetic data points
create function syn_data(number slope, number intercept, integer n) -> matrix
as select vector of syn_data_point(slope, intercept)
from integer i where i in range(n);
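
As a quick sanity check, the generator can be called on its own. This is a minimal sketch that assumes the two functions above have been defined; the slope 2, intercept 10, and count 5 are arbitrary illustration values.

```osql
// Generate five synthetic points around the line y = 2x + 10.
// Each row is [x, y], where y = x * slope + intercept + noise.
// The values differ between calls because x and noise come from frand.
syn_data(2, 10, 5);
```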

Running on different "edges"

// Linear regression models
set :server_model_name = "lr_server";
set :edge_models = [{"name": "lr_edge_1", "slope": frand(1.8, 2.2), "intercept": frand(9, 11)},
{"name": "lr_edge_2", "slope": frand(1.8, 2.2), "intercept": frand(9, 11)}];

// Initialize server model
add_federated_model(:server_model_name);

// Initialize edge models
select add_federated_model(cast(edge["name"] as charstring)) from record edge where edge in :edge_models;

create function run_fedml_lr(charstring server_model_name, vector edge_models, integer num_epochs, number learning_rate, integer rounds) -> stream of record
as for each integer e where e in range(num_epochs)
{
declare vector edge_weights, vector new_weights;
set edge_weights = [0, 0];
for each record m where m in edge_models
{
declare charstring name, number slope, number intercept, matrix data;
set name = m["name"];
set slope = m["slope"];
set intercept = m["intercept"];
set data = syn_data(slope, intercept, 10);
train_local(name, data, learning_rate, rounds);
set edge_weights = edge_weights + cast(lr_model(name)["weights"] as vector);
};

set new_weights = edge_weights / dim(edge_models);
set_weights(server_model_name, new_weights);

for each record m where m in edge_models
{
declare charstring name;
set name = m["name"];
set_weights(name, new_weights);
};

return new_weights;
};

In the code above, the session variable :edge_models (line 3) contains two "edges" whose ground truths are initialized to random values near a slope of 2 and a y-intercept of 10.

The function run_fedml_lr contains the main federated learning loop, implemented as a procedural function. In each epoch, simulated data is generated for every edge using that edge's ground truth settings (line 23). The edge model is then trained on the simulated data (line 24), and its weights are added to the running sum in the variable edge_weights (line 25).

In the following lines (28-29), the edge weights are averaged and used to update the weights of the server model.

Finally, in lines 31-36, the updated server weights are deployed to the edges.
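
The averaging step can also be tried in isolation. The sketch below uses two hypothetical session variables, :w_edge_1 and :w_edge_2, holding made-up weight vectors; they are not part of the guide's code. It relies only on the vector addition and division by a number already used inside run_fedml_lr.

```osql
// Hypothetical weights from two edges (made-up values for illustration).
set :w_edge_1 = [9.8, 2.1];
set :w_edge_2 = [10.2, 1.9];

// Element-wise sum divided by the number of edges, as in run_fedml_lr.
(:w_edge_1 + :w_edge_2) / 2;
```

Every edge contributes equally to this average. That matches the loop above, where each edge trains on the same number of data points.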

The following code calls the run_fedml_lr function, running the federated learning loop for 100 epochs with a learning rate of 0.01 and 10 local training rounds per epoch. As expected, the slope ends up close to 2 and the y-intercept close to 10.

//plot: Line plot
select {
"Intercept": x[1],
"Slope": x[2]
}
from JSON x
where x in
run_fedml_lr(:server_model_name, :edge_models, 100, .01, 10);
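
Once the query above has run to completion, the stored models can be inspected directly. This is a sketch that assumes the session variables and model names defined earlier are still in scope. Since the last step of each epoch copies the averaged weights to every edge, the server and the edges should report the same weights.

```osql
// Weights held by the server model after the run.
lr_model(:server_model_name)["weights"];

// Weights held by one of the simulated edges; expected to match the server's.
lr_model("lr_edge_1")["weights"];
```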