-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathBasicSA.py
More file actions
143 lines (122 loc) · 4.28 KB
/
BasicSA.py
File metadata and controls
143 lines (122 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Lab-SA Basic SA for Federated Learning
import random, math
import SecureProtocol as sp
from ast import literal_eval
from learning.utils import sum_weights, add_to_weights
def getCommonValues(R=100):
    """Return the public parameters shared by the server and all users.

    Args:
        R: exclusive upper bound of the mask domain; masks and self-mask
            seeds elsewhere in this module are drawn from [1, R - 1].
            Defaults to 100, the value previously hard-coded here as a
            temporary placeholder.

    Returns:
        dict with keys "g" (generator, ``sp.g``), "p" (prime, ``sp.p``)
        and "R" (mask domain bound).
    """
    return {"g": sp.g, "p": sp.p, "R": R}
def generateKeyPairs():
    """Create the two independent key pairs a user needs for one round.

    Both pairs come from ``sp.generateKeyPair(sp.g, sp.p)``. The first
    ("c") pair is the one used for key agreement before encrypting shares;
    the second ("s") pair is the one used to derive pairwise mask seeds.

    Returns:
        ((c_pk, c_sk), (s_pk, s_sk))
    """
    pairs = []
    for _ in range(2):
        public_key, secret_key = sp.generateKeyPair(sp.g, sp.p)
        pairs.append((public_key, secret_key))
    encryption_pair, masking_pair = pairs
    return encryption_pair, masking_pair
def generateSharesOfMask(t, u, s_sk, c_sk, c_pk_dic, R):
    """Secret-share user ``u``'s masking key and a fresh self-mask seed.

    Draws a self-mask seed ``bu`` from [1, R - 1]. Splits both ``s_sk`` and
    ``bu`` into shares with ``sp.make_shares(secret, t, n)``, where
    ``n = len(c_pk_dic)``. For every other user v, encrypts the plaintext
    ``str([u, v, s_sk_share_v, bu_share_v])`` under the key agreed from
    ``c_sk`` and v's ``c_pk``.

    Args:
        t: sharing parameter passed to ``sp.make_shares`` (presumably the
            reconstruction threshold -- TODO confirm).
        u: id of this user.
        s_sk: this user's masking secret key (the value being shared).
        c_sk: this user's encryption secret key.
        c_pk_dic: {user id: c_pk} for all users of the current round,
            including ``u``. Keys are converted with ``int()``. The
            resulting ids index the share lists, so they presumably lie in
            0..n-1 -- verify against caller.
        R: exclusive upper bound for ``bu``.

    Returns:
        (euv_list, bu): ``euv_list`` is a list of (u, v, ciphertext) tuples,
        one per peer v != u. ``bu`` is the self-mask seed, which the caller
        must keep secret.
    """
    import secrets  # CSPRNG for the secret seed

    n = len(c_pk_dic)
    # bu seeds this user's self-mask, so it must be unpredictable. The
    # ``random`` module's Mersenne Twister is not cryptographically secure:
    # its output is predictable, and its shared global state can be
    # re-seeded by any caller. randbelow(R - 1) + 1 covers [1, R - 1],
    # exactly the range of the previous random.randrange(1, R).
    bu = secrets.randbelow(R - 1) + 1
    s_sk_shares_list = sp.make_shares(s_sk, t, n)
    bu_shares_list = sp.make_shares(bu, t, n)

    euv_list = []
    for i, c_pk in c_pk_dic.items():
        v = int(i)
        if u == v:
            continue  # a user does not send shares to itself
        s_uv = sp.agree(c_sk, c_pk, sp.p)
        plainText = str([u, v, s_sk_shares_list[v], bu_shares_list[v]])
        euv_list.append((u, v, sp.encrypt(s_uv, plainText)))
    return euv_list, bu
def generateMaskedInput(u, bu, xu, s_sk, euv_list, s_pk_dic, R):
    """Return user ``u``'s weights ``xu`` with its masks added.

    The added mask is ``p_u + sum_v p_uv``, where:

    * ``p_u`` (self-mask) is drawn from [1, R - 1] by a PRG seeded with ``bu``.
    * ``p_uv`` (pairwise mask) is drawn from [1, R - 1] by a PRG seeded with
      the value agreed from ``s_sk`` and v's public key, and negated when
      ``u < v``. Assuming ``sp.agree`` is symmetric (Diffie-Hellman style),
      peer v derives the same value with the opposite sign.

    Args:
        u: id of this user.
        bu: self-mask seed returned by generateSharesOfMask.
        xu: this user's weights, in whatever structure
            ``add_to_weights`` accepts.
        s_sk: this user's masking secret key.
        euv_list: list of (sender, v, ciphertext) tuples as produced by
            generateSharesOfMask for this user. Only the peer ids v are
            used here; entries are assumed to have sender == u.
        s_pk_dic: {v: s_pk of v}.
        R: exclusive upper bound of the mask domain.

    Returns:
        The masked weights y_u.
    """
    pairwise_masks = []
    for _sender, v, _euv in euv_list:
        if v == u:
            continue
        s_uv = sp.agree(s_sk, s_pk_dic[v], sp.p)
        # A private generator yields exactly what random.seed(s_uv);
        # random.randrange(1, R) would. It does not overwrite the
        # module-wide generator's state with a secret-derived seed, which
        # would make every later unseeded ``random`` call predictable.
        p_uv = random.Random(s_uv).randrange(1, R)
        pairwise_masks.append(-p_uv if u < v else p_uv)

    # Self-mask p_u from the seed bu (same private-generator reasoning).
    pu = random.Random(bu).randrange(1, R)

    mask = pu + sum(pairwise_masks)
    return add_to_weights(xu, mask)
def unmasking(u, c_sk, euv_dic, c_pk_dic, users_previous, users_last):
    """Decrypt the shares user ``u`` holds for its peers and sort them.

    For every peer v in ``users_previous`` other than ``u``:

    1. Derive the key agreed from ``c_sk`` and v's ``c_pk``.
    2. Decrypt ``euv_dic[v]`` and check that the plaintext is addressed from
       v to u.
    3. File the share by v's status. A peer still listed in ``users_last``
       contributes its ``bu`` share; a peer that dropped out (in
       users_previous but not users_last, i.e. U2 \\ U3) contributes its
       ``s_sk`` share.

    Args:
        u: id of this user.
        c_sk: this user's encryption secret key.
        euv_dic: {v: ciphertext that v sent to u}.
        c_pk_dic: {v: c_pk of v}.
        users_previous: users who were alive in the previous round (U2).
        users_last: users who were alive in the most recent round (U3).
            Not modified. The previous implementation removed survivors
            from this list as a side effect.

    Returns:
        (s_sk_shares_dic, bu_shares_dic), both keyed by peer id.

    Raises:
        Exception: 'Decryption failed.' if key agreement, decryption or
            parsing of a peer's ciphertext fails (original error chained).
        Exception: 'Something went wrong during reconstruction.' if a
            plaintext's sender/recipient ids do not match (v, u).
    """
    survivors = set(users_last)  # membership test only; caller's list untouched
    s_sk_shares_dic = {}
    bu_shares_dic = {}
    for v in users_previous:
        if v == u:
            continue
        try:
            s_uv = sp.agree(c_sk, c_pk_dic[v], sp.p)
            plainText = sp.decrypt(s_uv, euv_dic[v])
            # Plaintext is str([sender, recipient, s_sk_share, bu_share]).
            _v, _u, s_sk_shares, bu_shares = literal_eval(plainText)
            sender, recipient = int(_v), int(_u)
        except Exception as err:
            raise Exception('Decryption failed.') from err

        if not (u == recipient and v == sender):
            raise Exception('Something went wrong during reconstruction.')

        if v in survivors:  # v is in U3
            bu_shares_dic[v] = bu_shares
        else:  # v is in U2 \ U3
            s_sk_shares_dic[v] = s_sk_shares
    return s_sk_shares_dic, bu_shares_dic
def reconstruct(shares_list):
    """Recombine a collection of secret shares into the shared secret.

    Thin wrapper over ``sp.combine_shares``. Presumably enough shares must
    be supplied to meet the sharing threshold -- see SecureProtocol.
    """
    secret = sp.combine_shares(shares_list)
    return secret
def reconstructPvu(v, u, s_sk_v, s_pk_u, R):
    """Recompute the pairwise mask p_vu of a dropped user v toward user u.

    Uses the same rule as the pairwise masks in generateMaskedInput, from
    v's side:

    * the seed is agreed from v's reconstructed masking secret key and u's
      masking public key;
    * the value is drawn from [1, R - 1];
    * the value is negated when ``v < u``.

    Args:
        v: id of the dropped user whose mask is being recomputed.
        u: id of the peer.
        s_sk_v: v's masking secret key (e.g. recombined from shares).
        s_pk_u: u's masking public key.
        R: exclusive upper bound of the mask domain.

    Returns:
        The signed pairwise mask p_vu.
    """
    s_uv = sp.agree(s_sk_v, s_pk_u, sp.p)
    # Private generator: same value as random.seed(s_uv);
    # random.randrange(1, R), but leaves the module-wide generator untouched.
    p_vu = random.Random(s_uv).randrange(1, R)
    return -p_vu if v < u else p_vu
def reconstructPu(bu_shares_list, R):
    """Recompute a surviving user's self-mask p_u from shares of its seed.

    Args:
        bu_shares_list: shares of user u's seed ``bu`` (as collected by
            unmasking), recombined with ``sp.combine_shares``.
        R: exclusive upper bound of the mask domain.

    Returns:
        p_u, drawn from [1, R - 1] by a PRG seeded with the recovered ``bu``.
    """
    bu = sp.combine_shares(bu_shares_list)
    # Private generator: same value as random.seed(bu);
    # random.randrange(1, R), but the recovered secret seed is not pushed
    # into the module-wide generator's state.
    return random.Random(bu).randrange(1, R)
def generatingOutput(yu_list, mask):
    """Aggregate the masked inputs and apply the correction ``mask``.

    Sums all users' masked weights with ``sum_weights``, then adds ``mask``
    with ``add_to_weights``. The caller is expected to supply the value
    that cancels the remaining masks (presumably the negated sum of
    self-masks and dropped users' pairwise masks -- verify against caller).

    Returns:
        The aggregate sum of x_u.
    """
    return add_to_weights(sum_weights(yu_list), mask)
def stochasticQuantization(weights, q, p):
    """Stochastically quantize real weights to integers modulo ``p``.

    Each weight x is scaled to q*x and randomly rounded to floor(q*x) or
    floor(q*x) + 1. The upper neighbour is chosen with probability
    q*x - floor(q*x), so the expected result equals q*x. Negative results
    are shifted into the field by adding ``p``; convertToRealDomain
    performs the inverse mapping.

    Args:
        weights: iterable of real numbers (a user's local model).
        q: quantization level (scale factor).
        p: large prime modulus.

    Returns:
        list of ints, one per weight.
    """
    quantized = []
    for x in weights:
        scaled = q * x
        floor_qx = math.floor(scaled)
        frac = scaled - floor_qx
        # Choose between the two neighbouring integers directly. Choosing
        # floor_qx / q or (floor_qx + 1) / q and multiplying by q again is
        # inexact in floating point (e.g. (1/49) * 49 == 0.9999999999999999),
        # and int() then truncated to the wrong integer. random.choices
        # consumes the RNG exactly as before, so the chosen side is unchanged.
        selected = random.choices(
            population=[floor_qx, floor_qx + 1],
            weights=[1 - frac, frac],
            k=1,  # select one
        )[0]
        if selected < 0:
            selected = selected + p
        quantized.append(selected)
    return quantized
def convertToRealDomain(weights, q, p):
    """Map field elements back to real numbers (inverse of quantization).

    The mapping is:

    * values in [0, (p - 1) / 2) are non-negative and are divided by ``q``;
    * every other value (nominally [(p - 1) / 2, p)) is taken to represent
      x - p and is divided by ``q``.

    Args:
        weights: iterable of field elements (e.g. an aggregated model).
        q: quantization level used when quantizing.
        p: large prime modulus.

    Returns:
        list of floats, one per input value.
    """
    real_numbers = []
    for x in weights:
        # 2*x < p - 1 is exactly x < (p - 1) / 2, but avoids the float
        # division, which loses precision for large primes (e.g. p = 2**61 - 1)
        # and misclassified values at the midpoint.
        if 0 <= x and 2 * x < p - 1:
            real_numbers.append(x / q)
        else:  # (p-1)/2 <= x < p
            real_numbers.append((x - p) / q)
    return real_numbers