This notebook contains torch implementations of a few linear algebra techniques, and
initial implementations of secure linear regression and Jonathan Bloom's DASH that leverage PySyft for secure computation.
These implementations of linear regression and DASH are not currently strictly secure, in that a few final steps are performed on the local worker for now. That's because our implementations of LDLt decomposition, QR decomposition, etc. don't quite work for the PySyft AdditiveSharingTensor
just yet. They definitely do in principle (because they're compositions of operations the SPDZ supports), but there are still a few details to hammer out.
import numpy as np
import torch as th
import syft as sy
from scipy import stats
WARNING: Logging before flag parsing goes to stderr. W0710 23:13:43.013911 4542494144 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/Users/andrew/.virtualenvs/pysyft/lib/python3.7/site-packages/tf_encrypted-0.5.6-py3.7-macosx-10.14-x86_64.egg/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so' W0710 23:13:43.023926 4542494144 deprecation_wrapper.py:119] From /Users/andrew/.virtualenvs/pysyft/lib/python3.7/site-packages/tf_encrypted-0.5.6-py3.7-macosx-10.14-x86_64.egg/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.
sy.create_sandbox(globals())
Setting up Sandbox... - Hooking PyTorch - Creating Virtual Workers: - bob - theo - jason - alice - andy - jon Storing hook and workers as global variables... Loading datasets from SciKit Learn... - Boston Housing Dataset - Diabetes Dataset - Breast Cancer Dataset - Digits Dataset - Iris Dataset - Wine Dataset - Linnerud Dataset Distributing Datasets Amongst Workers... Collecting workers into a VirtualGrid... Done!
These are torch implementations of basic linear algebra routines we'll use to perform regression (and also in parts of the next section).
def _eye(n):
"""th.eye doesn't seem to work after hooking torch, so just adding
a workaround for now.
"""
return th.FloatTensor(np.eye(n))
def ldlt_decomposition(x):
    """Decompose the square, symmetric, full-rank matrix X as X = LDL^t, where
    - L is lower triangular with unit diagonal
    - D is diagonal.

    Args:
        x: symmetric, full-rank tensor of shape (n, n)
    Returns:
        (L, D, L^t), each a tensor of shape (n, n)
    """
    n, _ = x.shape
    # L starts as the identity; D is accumulated as a 1D vector of diagonal entries.
    l, diag = _eye(n), th.zeros(n).float()
    for j in range(n):
        # d_j = x_jj - sum_{k<j} l_jk^2 * d_k
        diag[j] = x[j, j] - (th.sum((l[j, :j] ** 2) * diag[:j]))
        for i in range(j + 1, n):
            # Fill in column j of L below the diagonal.
            l[i, j] = (x[i, j] - th.sum(diag[:j] * l[i, :j] * l[j, :j])) / diag[j]
    return l, th.diag(diag), l.transpose(0, 1)
def back_solve(u, y):
    """Solve Ux = y for U a square, upper triangular matrix of full rank"""
    n = u.shape[0]
    solution = th.zeros(n)
    # Work upward from the last row: row i only involves solution[i:].
    for row in reversed(range(n)):
        residual = y[row] - th.sum(u[row, row + 1:] * solution[row + 1:])
        solution[row] = residual / u[row, row]
    return solution.reshape(-1, 1)
def forward_solve(l, y):
    """Solve Lx = y for L a square, lower triangular matrix of full rank."""
    n = l.shape[0]
    solution = th.zeros(n)
    # Work downward from the first row: row i only involves solution[:i].
    for row in range(n):
        residual = y[row] - th.sum(l[row, :row] * solution[:row])
        solution[row] = residual / l[row, row]
    return solution.reshape(-1, 1)
def invert_triangular(t, upper=True):
    """
    Invert a triangular matrix by solving T x = e_i for each basis vector.

    TODO: -Could be made more efficient with vectorized implementation of forward/backsolve
          -detection and validation around triangularity/squareness
    """
    n = t.shape[0]
    solve = back_solve if upper else forward_solve
    inverse = th.zeros_like(t)
    for col in range(n):
        basis_vector = th.zeros(n, 1)
        basis_vector[col] = 1.
        # Column `col` of the inverse is the solution of T x = e_col.
        inverse[:, [col]] = solve(t, basis_vector)
    return inverse
def solve_symmetric(a, y):
    """Solve the linear system Ax = y where A is a symmetric matrix of full rank."""
    # Factor A = L D L^t, then solve (L D) z = y followed by L^t x = z.
    # TODO: more efficient to just extract diagonal of d as 1D vector and scale?
    l, d, lt = ldlt_decomposition(a)
    intermediate = forward_solve(l.mm(d), y)
    return back_solve(lt, intermediate)
"""
Basic tests for LDLt decomposition.
"""
def _assert_small(x, failure_msg=None, threshold=1E-5):
norm = x.norm()
assert norm < threshold, failure_msg
def test_ldlt_case(a):
    """Check the defining properties of the LDL^t decomposition of a."""
    lower, diagonal, lower_t = ldlt_decomposition(a)
    # The two triangular factors are transposes of one another.
    _assert_small(lower - lower_t.transpose(0, 1))
    # The factorization reproduces a.
    _assert_small(lower.mm(diagonal).mm(lower_t) - a, 'Decomposition is inaccurate.')
    # Structural checks on the factors.
    _assert_small(lower - th.tril(lower), 'L is not lower triangular.')
    _assert_small(th.triu(th.tril(diagonal)) - diagonal, 'D is not diagonal.')
    print(f'PASSED for {a}')
def test_solve_symmetric_case(a, x):
    """Verify that solve_symmetric recovers x from a and y = a x."""
    rhs = a.mm(x)
    _assert_small(solve_symmetric(a, rhs) - x)
    print(f'PASSED for {a}, {x}')
# Symmetric, full-rank test fixture and a known solution vector.
a = th.tensor([[1, 2, 3],
               [2, 1, 2],
               [3, 2, 1]]).float()
x = th.tensor([1, 2, 3]).float().reshape(-1, 1)
test_ldlt_case(a)
test_solve_symmetric_case(a, x)
PASSED for tensor([[1., 2., 3.], [2., 1., 2.], [3., 2., 1.]]) PASSED for tensor([[1., 2., 3.], [2., 1., 2.], [3., 2., 1.]]), tensor([[1.], [2.], [3.]])
We're solving $$ \min_\beta \|X \beta - y\|_2 $$ in the situation where the data $(X, y)$ is horizontally partitioned (each worker $w$ owns chunks $X_w, y_w$ of the rows of $X$ and $y$).
We want to do this using `solve_symmetric` (defined above). The correct $\beta$ is $[1, 2, -1]$.
# Synthetic regression data: 30k rows, 3 features; noiseless response
# with true coefficients [1, 2, -1].
X = th.tensor(10 * np.random.randn(30000, 3))
y = (X[:, 0] + 2 * X[:, 1] - X[:, 2]).reshape(-1, 1)
Split the data into chunks and send a chunk to each worker, storing pointers to chunks in two `MultiPointerTensor`s.
# Sandbox workers that will hold horizontal shards of the data;
# jon acts as the crypto provider for the SPDZ protocol.
workers = [alice, bob, theo]
crypto_provider = jon
# Rows per worker (assumes len(workers) divides the row count evenly).
chunk_size = int(X.shape[0] / len(workers))
def _get_chunk_pointers(data, chunk_size, workers):
return [
data[(i * chunk_size):((i+1)*chunk_size), :].send(worker)
for i, worker in enumerate(workers)
]
# One MultiPointerTensor per dataset, each wrapping the per-worker chunk pointers.
X_ptrs = sy.MultiPointerTensor(
    children=_get_chunk_pointers(X, chunk_size, workers))
y_ptrs = sy.MultiPointerTensor(
    children=_get_chunk_pointers(y, chunk_size, workers))
This is the only step that depends on the number of rows of $X, y$, and it's performed locally on each worker in plain text. The result is two `MultiPointerTensor`s with pointers to each worker's summand of $X^tX$ (or $X^ty$).
# Each worker computes its local summand of X^tX and X^ty in plain text.
Xt_ptrs = X_ptrs.transpose(0, 1)
XtX_summand_ptrs = Xt_ptrs.mm(X_ptrs)
Xty_summand_ptrs = Xt_ptrs.mm(y_ptrs)
We add those summands up in two steps:
def _generate_shared_summand_pointers(
        summand_ptrs,
        workers,
        crypto_provider):
    """Yield each worker's summand, secret-shared among `workers`."""
    for pointer in summand_ptrs.child.values():
        # Fix precision first (secure computation works over integers),
        # then share additively among the workers.
        shared = pointer.fix_precision().share(
            *workers, crypto_provider=crypto_provider)
        # get() pulls the shared tensor back to the local worker.
        yield shared.get()
# Sum the secret-shared summands to obtain X^tX and X^ty as shared tensors.
XtX_shared = sum(
    _generate_shared_summand_pointers(
        XtX_summand_ptrs, workers, crypto_provider))
Xty_shared = sum(_generate_shared_summand_pointers(
    Xty_summand_ptrs, workers, crypto_provider))
The coefficient $\beta$ is the solution to $$X^t X \beta = X^t y$$
We solve for $\beta$ using `solve_symmetric`. Critically, this is a composition of linear operations that should be supported by `AdditiveSharingTensor`. Unlike the classic Cholesky decomposition, the $LDL^t$ decomposition in step 1 does not involve taking square roots, which would be challenging.
TODO: there's still some additional work required to get `solve_symmetric` working for `AdditiveSharingTensor`, so we're performing the final linear solve publicly for now.
# NOTE(review): get()/float_precision() decrypts here — the final solve is public for now.
beta = solve_symmetric(XtX_shared.get().float_precision(), Xty_shared.get().float_precision())
beta
tensor([[ 1.0000], [ 2.0000], [-1.0000]])
A $m \times n$ real matrix $A$ with $m \geq n$ can be written as $$A = QR$$ for $Q$ orthogonal and $R$ upper triangular. This is helpful in solving systems of equations, among other things. It is also central to the compression idea of DASH.
"""
Full QR decomposition via Householder transforms,
following Numerical Linear Algebra (Trefethen and Bau).
"""
def _apply_householder_transform(a, v):
return a - 2 * v.mm(v.transpose(0, 1).mm(a))
def _build_householder_matrix(v):
    """Return the explicit reflection matrix I - 2uu^t, where u = v / |v|."""
    n = v.shape[0]
    unit = v / v.norm()
    return _eye(n) - 2 * unit.mm(unit.transpose(0, 1))
def _householder_qr_step(a):
    """One Householder step: reflect a so that a[1:, 0] becomes zero.

    Returns:
        a: the reflected matrix (first column now a multiple of e_1)
        u: the unit normal vector of the reflecting hyperplane
    """
    x = a[:, 0].reshape(-1, 1)
    # torch tensors are duplicated with .clone() (there is no Tensor.copy()).
    u = x.clone()
    # note: can get better stability by multiplying by sign(u[0, 0])
    # (where sign(0) = 1); is this supported in the secure context?
    u[0, 0] += u.norm()
    # is there a simple way of getting around computing the norm twice?
    u /= u.norm()
    a = _apply_householder_transform(a, u)
    return a, u
def _recover_q(householder_vectors):
    """
    Build the orthogonal matrix Q from the stored Householder vectors.

    Q is never formed during the factorization, so we recover it column by
    column by applying the reflections to each standard basis vector.
    """
    n = len(householder_vectors)
    m = householder_vectors[0].shape[0]

    def _apply_transforms(x):
        """Trefethen and Bau, Algorithm 10.3"""
        for k in range(n - 1, -1, -1):
            x[k:, :] = _apply_householder_transform(
                x[k:, :],
                householder_vectors[k])
        return x

    q = th.zeros(m, m)
    # Determine q by evaluating it on a basis
    for i in range(m):
        e = th.zeros(m, 1)
        e[i] = 1.
        q[:, [i]] = _apply_transforms(e)
    return q
def qr(a, return_q=True):
    """
    Full QR decomposition via Householder reflections.

    Args:
        a: shape (m, n), m >= n
        return_q: bool, whether to reconstruct q
    Returns:
        orthogonal q of shape (m, m) (None if return_q is False)
        upper-triangular r of shape (m, n)
    """
    m, n = a.shape
    assert m >= n, \
        f"Passed a of shape {a.shape}, must have a.shape[0] >= a.shape[1]"
    # torch tensors are duplicated with .clone() (there is no Tensor.copy()).
    r = a.clone()
    householder_unit_normal_vectors = []
    for k in range(n):
        # Zero out column k below the diagonal, recording the reflection vector.
        r[k:, k:], u = _householder_qr_step(r[k:, k:])
        householder_unit_normal_vectors.append(u)
    q = _recover_q(householder_unit_normal_vectors) if return_q else None
    return q, r
"""
Basic tests for QR decomposition
"""
def _test_qr_case(a):
    """Check QR = A, orthogonality of Q, and upper-triangularity of R."""
    q, r = qr(a)
    m, _ = a.shape
    # The factorization reproduces a.
    _assert_small(q.mm(r) - a, "QR = A failed")
    # Q is orthogonal.
    _assert_small(
        q.mm(q.transpose(0, 1)) - _eye(m),
        "QQ^t = I failed"
    )
    # Everything strictly below the diagonal of R vanishes.
    below_diagonal = th.tensor([
        r[i, j].item()
        for i in range(r.shape[0])
        for j in range(i)])
    _assert_small(
        below_diagonal,
        "R is not upper triangular"
    )
    print(f"PASSED for \n{a}\n")
def test_qr():
    """Run the QR checks on one square and one tall example."""
    square_case = th.tensor([[1, 0, 1],
                             [1, 1, 0],
                             [0, 1, 1]]).float()
    tall_case = th.tensor([[1, 0, 1],
                           [1, 1, 0],
                           [0, 1, 1],
                           [1, 1, 1]]).float()
    _test_qr_case(square_case)
    _test_qr_case(tall_case)
# Run the QR decomposition tests.
test_qr()
PASSED for tensor([[1., 0., 1.], [1., 1., 0.], [0., 1., 1.]]) PASSED for tensor([[1., 0., 1.], [1., 1., 0.], [0., 1., 1.], [1., 1., 1.]])
We follow https://github.com/jbloom22/DASH/.
The overall structure is roughly analogous to the linear regression example above; the secure aggregation step again produces `AdditiveSharingTensor`s.
def _generate_worker_data_pointers(
        n, m, k, worker,
        beta_correct, gamma_correct, epsilon=0.01
):
    """
    Generate synthetic data on `worker` and return pointers to it.

    Args:
        n: number of rows
        m: number of transient features
        k: number of covariates
        worker: the worker that will own the generated data
        beta_correct: coefficients for transient features (tensor of shape (m, 1))
        gamma_correct: coefficients for covariates (tensor of shape (k, 1))
        epsilon: scale of noise added to response
    Return:
        y, X, C: pointers to response, transients, and covariates
    """
    X = th.randn(n, m).send(worker)
    C = th.randn(n, k).send(worker)
    # .clone() before .send() so the local coefficient tensors stay put
    # (torch has no Tensor.copy(); .clone() is the duplicating call).
    y = (X.mm(beta_correct.clone().send(worker)).reshape(-1, 1) +
         C.mm(gamma_correct.clone().send(worker)).reshape(-1, 1))
    y += (epsilon * th.randn(n, 1)).send(worker)
    return y, X, C
def _dot(x):
return (x * x).sum(dim=0).reshape(-1, 1)
def _secure_sum(worker_level_pointers, workers, crypto_provider):
    """
    Securely add up an iterable of pointers to (same-sized) tensors.

    Args:
        worker_level_pointers: iterable of pointer tensors
        workers: list of workers
        crypto_provider: worker
    Returns:
        AdditiveSharingTensor shared among workers
    """
    # Fix precision (secure computation works over integers), secret-share each
    # summand among the workers, and pull the shares back before summing.
    return sum(
        p.fix_precision(precision_fractional=10)
         .share(*workers, crypto_provider=crypto_provider)
         .get()
        for p in worker_level_pointers
    )
def dash_example_secure(
    workers, crypto_provider,
    n_samples_by_worker, m, k,
    beta_correct, gamma_correct,
    epsilon=0.01
):
    """
    Run DASH on synthetic data distributed across workers.

    Args:
        workers: dict mapping worker ids to workers
        crypto_provider: worker
        n_samples_by_worker: dict mapping worker ids to ints (number of rows of data)
        m: number of transients
        k: number of covariates
        beta_correct: coefficient for transient features
        gamma_correct: coefficient for covariates
        epsilon: scale of noise added to response
    Returns:
        beta, sigma, tstat, pval: coefficient of transients and accompanying statistics
    """
    # Generate each worker's data
    worker_data_pointers = {
        p: _generate_worker_data_pointers(
            n, m, k, workers[p],
            beta_correct, gamma_correct,
            epsilon=epsilon)
        for p, n in n_samples_by_worker.items()
    }
    # to be populated with pointers to results of local, worker-level computations
    Ctys, CtXs, yys, Xys, XXs, Rs = {}, {}, {}, {}, {}, {}

    def _sum(pointers):
        # Fix: use the `workers` argument rather than the global `players`.
        return _secure_sum(pointers, list(workers.values()), crypto_provider)

    # worker-level compression step
    for p, (y, X, C) in worker_data_pointers.items():
        # y^t y must be the *squared* norm so the per-worker pieces sum to the
        # full inner product (previously this summed unsquared norms).
        yys[p] = (y * y).sum()
        Xys[p] = X.transpose(0, 1).mm(y)
        XXs[p] = _dot(X)
        Ctys[p] = C.transpose(0, 1).mm(y)
        CtXs[p] = C.transpose(0, 1).mm(X)
        _, R_full = qr(C, return_q=False)
        Rs[p] = R_full[:k, :]
    # Perform secure sum
    # - We're returning result to the local worker and computing there for the rest
    #   of the way, but should be possible to compute via SMPC (on pointers to
    #   AdditiveSharingTensors)
    # - still a few minor-looking issues with implementing invert_triangular/qr for
    #   AdditiveSharingTensor
    yy = _sum(yys.values()).get().float_precision()
    Xy = _sum(Xys.values()).get().float_precision()
    XX = _sum(XXs.values()).get().float_precision()
    Cty = _sum(Ctys.values()).get().float_precision()
    CtX = _sum(CtXs.values()).get().float_precision()
    # Rest is done publicly on the local worker for now
    _, R_public = qr(
        th.cat([R.get() for R in Rs.values()], dim=0),
        return_q=False)
    invR_public = invert_triangular(R_public[:k, :])
    Qty = invR_public.transpose(0, 1).mm(Cty)
    QtX = invR_public.transpose(0, 1).mm(CtX)
    QtXQty = QtX.transpose(0, 1).mm(Qty)
    QtyQty = _dot(Qty)
    QtXQtX = _dot(QtX)
    # Project the covariates out of the aggregated sufficient statistics.
    yyq = yy - QtyQty
    Xyq = Xy - QtXQty
    XXq = XX - QtXQtX
    # Residual degrees of freedom: total samples minus covariates minus the transient.
    d = sum(n_samples_by_worker.values()) - k - 1
    beta = Xyq / XXq
    # abs() guards against tiny negative values from fixed-precision roundoff.
    sigma = ((yyq / XXq - (beta ** 2)) / d).abs() ** 0.5
    tstat = beta / sigma
    pval = 2 * stats.t.cdf(-abs(tstat), d)
    return beta, sigma, tstat, pval
# Workers participating in the computation, keyed by worker id.
players = {
    worker.id: worker
    for worker in [alice, bob, theo]
}
# Per-worker sample counts for the synthetic data.
n_samples_by_player = {
    alice.id: 100000,
    bob.id: 200000,
    theo.id: 100000
}
crypto_provider = jon
m = 100  # number of transient features
k = 3  # number of covariates
# residual degrees of freedom
d = sum(n_samples_by_player.values()) - k - 1
# True coefficients used to generate the synthetic response.
beta_correct = th.ones(m, 1)
gamma_correct = th.ones(k, 1)
dash_example_secure(
    players, crypto_provider,
    n_samples_by_player, m, k,
    beta_correct, gamma_correct)
(tensor([[1.0198], [0.9758], [0.9643], [1.0004], [1.0150], [0.9944], [0.9961], [1.0318], [1.0002], [0.9830], [0.9790], [0.9926], [1.0008], [0.9697], [0.9954], [0.9995], [0.9960], [0.9767], [1.0059], [0.9838], [0.9911], [1.0179], [1.0080], [0.9829], [0.9937], [0.9819], [1.0188], [0.9811], [0.9971], [0.9866], [1.0117], [0.9953], [0.9966], [0.9952], [0.9957], [0.9860], [1.0206], [0.9928], [0.9925], [1.0149], [0.9587], [0.9851], [1.0102], [1.0127], [1.0143], [1.0050], [0.9926], [0.9646], [0.9966], [0.9906], [1.0212], [0.9948], [1.0253], [0.9936], [0.9834], [0.9770], [0.9885], [0.9890], [0.9954], [0.9900], [0.9795], [0.9657], [0.9836], [1.0042], [0.9957], [0.9929], [1.0127], [0.9869], [0.9969], [1.0172], [1.0030], [0.9844], [1.0121], [1.0071], [0.9954], [0.9936], [0.9954], [1.0070], [0.9928], [0.9900], [0.9970], [0.9992], [0.9851], [0.9942], [0.9710], [0.9799], [0.9675], [1.0246], [1.0085], [0.9906], [0.9984], [1.0182], [0.9805], [0.9905], [1.0034], [0.9965], [0.9983], [0.9973], [0.9872], [0.9937]]), tensor([[0.0032], [0.0032], [0.0031], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0031], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0031], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0031], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0031], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], [0.0032], 
[0.0032], [0.0032]]), tensor([[319.5643], [309.3729], [306.5470], [315.1175], [318.3896], [313.9088], [314.0636], [322.6956], [315.1400], [311.2341], [310.3611], [313.5351], [315.4006], [307.5811], [313.6877], [314.7977], [314.0446], [309.6685], [316.4051], [311.1368], [313.2263], [319.4759], [316.2847], [310.5904], [313.5548], [311.4909], [319.7083], [310.5916], [314.3991], [311.7694], [317.7445], [314.2973], [313.8512], [313.7931], [313.8863], [311.7593], [320.1660], [313.1064], [312.8649], [318.5549], [305.3120], [311.6055], [317.7577], [318.0811], [318.6992], [316.0238], [313.5295], [306.4926], [314.0880], [312.5821], [319.9753], [313.6790], [321.0503], [313.4674], [310.9974], [310.1643], [312.6641], [312.5327], [313.7359], [312.9331], [310.3143], [307.3026], [311.3515], [316.4857], [314.6939], [313.4534], [318.1405], [312.0404], [314.7274], [318.8090], [315.5815], [311.3936], [317.9680], [317.0735], [314.1143], [313.6470], [313.5154], [316.9404], [313.7480], [312.7427], [314.7064], [314.5428], [311.3454], [313.8516], [308.4349], [310.3224], [307.0361], [320.6432], [317.1022], [312.7644], [314.5857], [319.2417], [310.1894], [312.5483], [315.9466], [314.3875], [314.4272], [314.6446], [312.2668], [313.8906]]), array([[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]]))