-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.py
More file actions
276 lines (241 loc) · 12.2 KB
/
client.py
File metadata and controls
276 lines (241 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# client.py - Handles client-side logic for U-shaped split learning.
# Runs training sequentially (train cmd) or testing (test cmd).
# ADDED TIMING
import jax
import jax.numpy as jnp
from jax import jit
import flax.linen as nn
from flax.training import train_state
from flax import serialization
import optax
import numpy as np
import os
import time # Import time for timing
import requests
import sys
# --- Config ---
# Hyperparameters and filesystem/network locations shared by train and test modes.
CONFIG = {
    "epochs": 3,        # local epochs per client turn (see train_one_client)
    "lr": 0.001,        # Adam learning rate for both client model parts
    "batch_size": 64,
    "num_clients": 3,   # clients trained sequentially in __main__
    "data_dir": "./data_splits_jax",   # per-client .npz splits + test_data.npz
    "model_dir": "./trained_model"     # where the final bundled params are written
}
# Split-learning server endpoint (Flask app exposing /forward, /backward, /get_weights).
SERVER_URL = "http://127.0.0.1:5000"
CLIENT_PART1_WEIGHTS = "client_part1.msgpack"  # serialized TrainState for part 1
CLIENT_PART2_WEIGHTS = "client_part2.msgpack"  # serialized TrainState for part 2
# Convenience alias; used to build dummy shape templates for deserialization.
BATCH_SIZE = CONFIG['batch_size']
# --- 1. Client Model Parts ---
class ClientModelPart1(nn.Module):
    """First client-side segment of the U-shaped model: conv -> relu -> pool."""

    @nn.compact
    def __call__(self, x):
        # x: (batch, 28, 28, 1), unnormalized pixels.
        h = nn.Conv(features=16, kernel_size=(5, 5), padding='SAME')(x)
        h = nn.relu(h)
        # 2x2 max-pool halves the spatial dims -> (batch, 14, 14, 16).
        return nn.max_pool(h, window_shape=(2, 2), strides=(2, 2))
class ServerModel(nn.Module):
    """Server-side middle segment; the definition is needed locally so its
    serialized weights can be deserialized during the test phase."""

    @nn.compact
    def __call__(self, x):
        # x: (batch, 14, 14, 16) activations from ClientModelPart1.
        h = nn.Conv(features=32, kernel_size=(5, 5), padding='SAME')(x)
        h = nn.relu(h)
        # 2x2 max-pool -> (batch, 7, 7, 32).
        return nn.max_pool(h, window_shape=(2, 2), strides=(2, 2))
class ClientModelPart2(nn.Module):
    """Final client-side segment of the U: flatten -> dense -> logits."""

    @nn.compact
    def __call__(self, x):
        # x: (batch, 7, 7, 32) activations from the server segment.
        flat = x.reshape((x.shape[0], -1))
        hidden = nn.relu(nn.Dense(features=256)(flat))
        # (batch, 10) class logits.
        return nn.Dense(features=10)(hidden)
# --- 2. Data Loading ---
def load_client_data(client_id):
    """Load one client's training split; returns (data, labels) arrays."""
    path = os.path.join(CONFIG['data_dir'], f'client_{client_id}_data.npz')
    with np.load(path) as archive:
        return archive['data'], archive['labels']
def load_test_data():
    """Load the shared held-out test split; returns (data, labels) arrays."""
    path = os.path.join(CONFIG['data_dir'], 'test_data.npz')
    with np.load(path) as archive:
        return archive['data'], archive['labels']
def data_iterator(key, data, labels, batch_size):
    """Yield shuffled (data, labels) minibatches.

    Samples are permuted with the given PRNG key; any trailing partial
    batch is silently dropped so every yielded batch has `batch_size` rows.
    """
    n = data.shape[0]
    usable = (n // batch_size) * batch_size
    perm = jax.random.permutation(key, n)[:usable]
    for batch_idx in perm.reshape((-1, batch_size)):
        yield data[batch_idx], labels[batch_idx]
# --- 3. Client Weight Management ---
def load_or_init_client_states():
    """Build TrainStates for both client parts, restoring saved weights when present.

    Freshly-initialized states provide the pytree structure that
    `serialization.from_bytes` needs to deserialize into.
    """
    root = jax.random.PRNGKey(0)
    root, key1, key2 = jax.random.split(root, 3)
    part1, part2 = ClientModelPart1(), ClientModelPart2()

    c1_state = train_state.TrainState.create(
        apply_fn=part1.apply,
        params=part1.init(key1, jnp.ones([1, 28, 28, 1])),
        tx=optax.adam(CONFIG['lr']),
    )
    c2_state = train_state.TrainState.create(
        apply_fn=part2.apply,
        params=part2.init(key2, jnp.ones([1, 7, 7, 32])),
        tx=optax.adam(CONFIG['lr']),
    )

    if os.path.exists(CLIENT_PART1_WEIGHTS):
        print("Loading Client Part 1 weights...")
        with open(CLIENT_PART1_WEIGHTS, 'rb') as f:
            c1_state = serialization.from_bytes(c1_state, f.read())
    else:
        print("Initializing Client Part 1 weights...")

    if os.path.exists(CLIENT_PART2_WEIGHTS):
        print("Loading Client Part 2 weights...")
        with open(CLIENT_PART2_WEIGHTS, 'rb') as f:
            c2_state = serialization.from_bytes(c2_state, f.read())
    else:
        print("Initializing Client Part 2 weights...")

    return c1_state, c2_state
def save_client_states(c1_state, c2_state):
    """Serialize both client TrainStates to their msgpack files."""
    targets = ((CLIENT_PART1_WEIGHTS, c1_state), (CLIENT_PART2_WEIGHTS, c2_state))
    for path, state in targets:
        with open(path, 'wb') as f:
            f.write(serialization.to_bytes(state))
    print("Client weights saved.")
# --- 4. JITted Client Computations ---
@jit
def client1_forward_jitted(params, batch_data):
    """Part-1 forward pass via vjp.

    Returns (activations, pullback); the pullback is kept so gradients
    arriving back from the server can be propagated into part-1 params.
    """
    model = ClientModelPart1()
    return jax.vjp(model.apply, params, batch_data)
@jit
def client2_backward_jitted(params, server_activations, labels):
    """Part-2 forward pass + loss + gradients.

    Returns ((grads wrt part-2 params, grads wrt server activations), loss).
    The activation gradients are what get sent back over the wire so the
    server — and then client part 1 — can continue backpropagation.
    """
    def loss_fn(client2_params, server_activations_input):
        logits = ClientModelPart2().apply(client2_params, server_activations_input)
        return optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=labels).mean()

    # value_and_grad produces the loss and both gradients in one evaluation;
    # the previous code called grad_fn and then loss_fn again, computing the
    # forward pass twice.
    loss_val, grads = jax.value_and_grad(loss_fn, argnums=(0, 1))(params, server_activations)
    return grads, loss_val
@jit
def client1_backward_jitted(client1_vjp_pullback, smashed_data_grads):
    """Finish part-1 backprop: feed the server-returned activation gradients
    through the stored vjp pullback and keep only the parameter gradients."""
    param_grads, _input_grads = client1_vjp_pullback(smashed_data_grads)
    return param_grads
# --- 5. Train Function (one client's turn) ---
def train_one_client(client_id, c1_state, c2_state):
    """Run CONFIG['epochs'] epochs of U-shaped split training for one client.

    Per batch: part-1 forward locally -> POST activations to server /forward
    -> part-2 forward/backward locally (loss + grads) -> POST activation
    grads to /backward -> part-1 backward locally -> apply both updates.

    Returns (c1_state, c2_state, success): success is False only when the
    server is unreachable; any other per-batch error is logged and skipped.
    """
    print(f"\n--- Training Client {client_id} ---")
    data, labels = load_client_data(client_id)
    key = jax.random.PRNGKey(int(time.time()) + client_id) # Seed differently per client
    client_start_time = time.time() # Start timer for this client
    for epoch in range(CONFIG['epochs']):
        key, data_key = jax.random.split(key)
        data_gen = data_iterator(data_key, data, labels, CONFIG['batch_size'])
        for batch_idx, (batch_data, batch_labels) in enumerate(data_gen):
            try:
                # --- U-SHAPE FLOW ---
                # 1. My Forward 1
                smashed_data, client1_vjp = client1_forward_jitted(c1_state.params, batch_data)
                # 2. Network Call -> Server Forward
                res = requests.post(f"{SERVER_URL}/forward", data=serialization.to_bytes(smashed_data))
                if res.status_code != 200: raise Exception(f"Server forward error: {res.text}")
                # Step-ID correlates this forward with the later /backward call.
                step_id = res.headers.get('Step-ID')
                # Dummy array supplies the shape/dtype template from_bytes needs;
                # safe because data_iterator only yields full BATCH_SIZE batches.
                dummy_server_output = jnp.ones([BATCH_SIZE, 7, 7, 32], dtype=jnp.float32)
                server_activations = serialization.from_bytes(dummy_server_output, res.content)
                # 3. My Backward 2 (gets loss & grads)
                (client2_grads, server_activations_grads), loss = client2_backward_jitted(
                    c2_state.params, server_activations, batch_labels
                )
                # 4. Network Call -> Server Backward
                headers = {'Step-ID': step_id}
                res = requests.post(f"{SERVER_URL}/backward", data=serialization.to_bytes(server_activations_grads), headers=headers)
                if res.status_code != 200: raise Exception(f"Server backward error: {res.text}")
                # Template for the gradients flowing back into part 1's output.
                dummy_smashed_grads = jnp.ones([BATCH_SIZE, 14, 14, 16], dtype=jnp.float32)
                smashed_data_grads = serialization.from_bytes(dummy_smashed_grads, res.content)
                # 5. My Backward 1
                client1_grads = client1_backward_jitted(client1_vjp, smashed_data_grads)
                # 6. Update my weights
                c1_state = c1_state.apply_gradients(grads=client1_grads)
                c2_state = c2_state.apply_gradients(grads=client2_grads)
                # --- END ---
                if batch_idx % 10 == 0: # Log often
                    print(f" C{client_id} E{epoch+1} B{batch_idx} Loss: {loss:.4f}")
            except requests.exceptions.ConnectionError:
                print(f"\n--- FATAL ERROR: Cannot connect to server at {SERVER_URL}. ---")
                return c1_state, c2_state, False # Indicate failure
            except Exception as e:
                print(f"\n--- ERROR C{client_id} B{batch_idx}: {e} ---")
                continue # Skip batch
    client_end_time = time.time() # Stop timer
    print(f"--- Client {client_id} Epochs Complete (took {client_end_time - client_start_time:.2f}s) ---")
    return c1_state, c2_state, True # Indicate success
# --- 6. Test Function ---
def test():
    """Evaluate the full U-model (client1 -> server -> client2) and publish it.

    Downloads the server segment's weights over HTTP, computes accuracy over
    all full test batches, saves the combined parameters locally, and uploads
    the bundle to a local IPFS node (best-effort).
    """
    print("--- Final Test & Publish ---")
    test_start_time = time.time() # Start test timer
    c1_state, c2_state = load_or_init_client_states()
    # Get server weights
    print("Downloading server weights...")
    try:
        res = requests.get(f"{SERVER_URL}/get_weights")
        if res.status_code != 200: raise Exception(f"Server weight error: {res.text}")
        # Need base state structure to load into
        s_model = ServerModel(); key = jax.random.PRNGKey(1); dummy_input = jnp.ones([1, 14, 14, 16])
        params = s_model.init(key, dummy_input); optimizer = optax.adam(0.001)
        s_state_base = train_state.TrainState.create(apply_fn=s_model.apply, params=params, tx=optimizer)
        s_state = serialization.from_bytes(s_state_base, res.content)
        print("Server weights loaded.")
    except Exception as e:
        print(f"Could not get server weights: {e}")
        return
    # JITted end-to-end eval: all three segments run locally here.
    @jit
    def eval_step(c1_params, s_params, c2_params, batch_data, labels):
        smashed = ClientModelPart1().apply(c1_params, batch_data)
        server_act = ServerModel().apply(s_params, smashed)
        logits = ClientModelPart2().apply(c2_params, server_act)
        return jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    # Run evaluation (trailing partial batch is dropped by data_iterator).
    test_data, test_labels = load_test_data()
    key = jax.random.PRNGKey(0)
    test_iter = data_iterator(key, test_data, test_labels, CONFIG['batch_size'])
    accuracies = [eval_step(c1_state.params, s_state.params, c2_state.params, d, l) for d, l in test_iter]
    final_accuracy = jnp.mean(jnp.array(accuracies))
    test_end_time = time.time() # Stop test timer
    print(f"\n--- Final Accuracy: {final_accuracy * 100:.2f}% (Testing took {test_end_time - test_start_time:.2f}s) ---")
    # Save and upload combined model parameters to IPFS
    print("\n--- Saving and Uploading Final Model ---")
    os.makedirs(CONFIG['model_dir'], exist_ok=True)
    # Bundle just the params (no optimizer state)
    final_model_params = {'client1': c1_state.params, 'server': s_state.params, 'client2': c2_state.params}
    model_bytes = serialization.to_bytes(final_model_params)
    model_path = os.path.join(CONFIG['model_dir'], "final_u_model.msgpack")
    with open(model_path, "wb") as f: f.write(model_bytes)
    print(f"Final model params saved: {model_path}")
    def add_to_ipfs(file_path):
        # Best-effort upload to a local IPFS daemon; returns the CID or None.
        try:
            with open(file_path, 'rb') as f:
                response = requests.post('http://127.0.0.1:5001/api/v0/add', files={'file': f})
            return response.json()['Hash'] if response.status_code == 200 else None
        # Was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit;
        # narrowed to the failures this call can realistically hit
        # (network errors, missing file, bad JSON, missing 'Hash' key).
        except (requests.exceptions.RequestException, OSError, ValueError, KeyError):
            return None
    model_cid = add_to_ipfs(model_path)
    if model_cid: print(f"\n--- Model Uploaded --- CID: {model_cid}")
    else: print("\n--- Failed model upload ---")
# --- Entry Point ---
if __name__ == "__main__":
    # Require exactly one of the two supported modes.
    if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
        print("Usage: python3 client.py [train|test]")
        sys.exit(1)
    mode = sys.argv[1]
    if mode == 'test':
        test()
    elif mode == 'train':
        run_start = time.time()  # overall wall-clock timer
        c1_state, c2_state = load_or_init_client_states()
        all_success = True
        # Each client takes a full sequential turn (1..num_clients).
        for client_num in range(1, CONFIG['num_clients'] + 1):
            c1_state, c2_state, turn_ok = train_one_client(client_num, c1_state, c2_state)
            if not turn_ok:
                all_success = False
                break  # stop the round on the first failure
        run_end = time.time()
        # Persist weights only when every client finished cleanly.
        if all_success:
            save_client_states(c1_state, c2_state)
            print(f"\n--- All client training finished successfully (Total time: {run_end - run_start:.2f}s) ---")
        else:
            print("\n--- Training stopped due to error. ---")