Commit 102113e

jstac and claude authored
Remove num_layers from Config and improve visualization (#265)
- Remove redundant num_layers field from Config class
- Update Keras model builder to use len(config.layer_sizes)
- Format MSE values to 6 decimal places in summary table
- Change plots to show validation data instead of training data
- Use red color with alpha=0.5 for scatter plots

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 9724c6c · commit 102113e

File tree: 2 files changed (+24, −14 lines)


lectures/ifp_dl.md

Lines changed: 10 additions & 2 deletions
````diff
@@ -302,9 +302,18 @@ def initialize_network(
 ```
 
 
-Here's a function to train the network by gradient ascent, given a generic loss
+Next we write a function to train the network by gradient *descent*, given a generic loss
 function.
 
+```{note}
+We use gradient descent rather than ascent because we'll employ optax, which
+expects to be minimizing a loss function.
+
+To make this work, we'll set the loss to $- \hat M(\theta)$.
+```
+
+Here's the function.
+
 ```{code-cell} ipython3
 @partial(jax.jit, static_argnames=('config', 'loss_fn'))
 def train_network(
@@ -318,7 +327,6 @@ def train_network(
     models by providing an appropriate loss function.
 
     """
-
     # Initialize network parameters
     key = random.PRNGKey(config.seed)
     params = initialize_network(key, config.layer_sizes)
````
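The `{note}` block added above relies on a standard trick: optax optimizers minimize, so a quantity we want to maximize is negated before being handed over. A minimal self-contained sketch of that pattern, with a toy objective standing in for the lecture's $\hat M(\theta)$ (the objective, learning rate, and iteration count below are all illustrative):

```python
import jax
import jax.numpy as jnp
import optax

# Toy stand-in for the lecture's objective M_hat(θ); maximized at θ = 1.
def M_hat(θ):
    return -jnp.sum((θ - 1.0)**2)

# optax minimizes, so set the loss to the negated objective.
def loss_fn(θ):
    return -M_hat(θ)

θ = jnp.zeros(3)
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(θ)

for _ in range(200):
    grads = jax.grad(loss_fn)(θ)
    updates, opt_state = optimizer.update(grads, opt_state)
    θ = optax.apply_updates(θ, updates)

print(θ)  # ≈ [1. 1. 1.]
```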

lectures/jax_nn.md

Lines changed: 14 additions & 12 deletions
````diff
@@ -103,7 +103,6 @@ Our default value of $k$ will be 10.
 ```{code-cell} ipython3
 class Config(NamedTuple):
     epochs: int = 4000                       # Number of passes through the data set
-    num_layers: int = 4                      # Depth of the network
     output_dim: int = 10                     # Output dimension of input and hidden layers
     learning_rate: float = 0.001             # Learning rate for gradient descent
     layer_sizes: tuple = (1, 10, 10, 10, 1)  # Sizes of each layer in the network
````
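With `num_layers` gone, the network's depth is implied by `layer_sizes` itself. A quick check of how the two quantities relate, using the defaults from the diff:

```python
from typing import NamedTuple

class Config(NamedTuple):
    epochs: int = 4000
    output_dim: int = 10
    learning_rate: float = 0.001
    layer_sizes: tuple = (1, 10, 10, 10, 1)

config = Config()
# n layer sizes describe n - 1 weight matrices, so the old default
# num_layers = 4 is recoverable from layer_sizes alone:
print(len(config.layer_sizes) - 1)  # 4
```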
````diff
@@ -167,15 +166,15 @@ def build_keras_model(
 ):
     model = Sequential()
     # Add layers to the network sequentially, from inputs towards outputs
-    for i in range(config.num_layers-1):
+    for i in range(len(config.layer_sizes) - 1):
         model.add(
             Dense(units=config.output_dim, activation=activation_function)
         )
     # Add a final layer that maps to a scalar value, for regression.
     model.add(Dense(units=1))
     # Embed training configurations
     model.compile(
-        optimizer=keras.optimizers.SGD(), 
+        optimizer=keras.optimizers.SGD(),
         loss='mean_squared_error'
     )
     return model
````
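For context, here is the rewritten loop rebuilt as a standalone sketch rather than through `build_keras_model`, whose full signature is elided in the diff; the `'tanh'` activation is an assumption:

```python
import keras
from keras import Sequential
from keras.layers import Dense

layer_sizes = (1, 10, 10, 10, 1)
output_dim = 10

model = Sequential()
# One hidden layer per pass through the rewritten loop
for i in range(len(layer_sizes) - 1):
    model.add(Dense(units=output_dim, activation='tanh'))
# Final scalar output for regression, as in the diff
model.add(Dense(units=1))
model.compile(optimizer=keras.optimizers.SGD(),
              loss='mean_squared_error')
```

One behavioral consequence worth noting: with the defaults, `range(len(layer_sizes) - 1)` runs four times where `range(config.num_layers - 1)` ran three, so the rebuilt network carries one more hidden layer than before.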
````diff
@@ -214,10 +213,10 @@ The next function extracts and visualizes a prediction from the trained model.
 
 ```{code-cell} ipython3
 def plot_keras_output(model, x, y, x_validate, y_validate):
-    y_predict = model.predict(x, verbose=2)
+    y_predict = model.predict(x_validate, verbose=2)
     fig, ax = plt.subplots()
-    ax.scatter(x, y)
-    ax.plot(x, y_predict, label="fitted model", color='black')
+    ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+    ax.plot(x_validate, y_predict, label="fitted model", color='black')
     ax.set_xlabel('x')
     ax.set_ylabel('y')
     plt.show()
````
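Continuing from the builder sketch above, a hypothetical end-to-end call; the synthetic data below is a stand-in for the lecture's data-generation step, which this diff does not show:

```python
import numpy as np

# Synthetic stand-ins for the lecture's training and validation arrays.
rng = np.random.default_rng(0)
x = np.linspace(0, 5, 400).reshape(-1, 1)
y = np.cos(x) + 0.1 * rng.standard_normal(x.shape)
x_validate = np.linspace(0, 5, 100).reshape(-1, 1)
y_validate = np.cos(x_validate) + 0.1 * rng.standard_normal(x_validate.shape)

model.fit(x, y, epochs=200, verbose=0)   # model from the sketch above
# After this change the function plots only the validation pair, even
# though x and y still travel through its signature.
plot_keras_output(model, x, y, x_validate, y_validate)
```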
````diff
@@ -495,8 +494,8 @@ Here's a visualization of the quality of our fit.
 
 ```{code-cell} ipython3
 fig, ax = plt.subplots()
-ax.scatter(x_train, y_train)
-ax.plot(x_train.flatten(), f(θ, x_train).flatten(),
+ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+ax.plot(x_validate.flatten(), f(θ, x_validate).flatten(),
         label="fitted model", color='black')
 ax.set_xlabel('x')
 ax.set_ylabel('y')
````
````diff
@@ -566,8 +565,8 @@ print(f"Final MSE on validation data = {optax_sgd_mse:.6f}")
 
 ```{code-cell} ipython3
 fig, ax = plt.subplots()
-ax.scatter(x_train, y_train)
-ax.plot(x_train.flatten(), f(θ, x_train).flatten(),
+ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+ax.plot(x_validate.flatten(), f(θ, x_validate).flatten(),
         label="fitted model", color='black')
 ax.set_xlabel('x')
 ax.set_ylabel('y')
````
````diff
@@ -633,8 +632,8 @@ Here's a visualization of the result.
 
 ```{code-cell} ipython3
 fig, ax = plt.subplots()
-ax.scatter(x_train, y_train)
-ax.plot(x_train.flatten(), f(θ, x_train).flatten(),
+ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+ax.plot(x_validate.flatten(), f(θ, x_validate).flatten(),
         label="fitted model", color='black')
 ax.set_xlabel('x')
 ax.set_ylabel('y')
````
````diff
@@ -688,6 +687,9 @@ results = {
 }
 
 df = pd.DataFrame(results)
+# Format MSE columns to 6 decimal places
+df['Training MSE'] = df['Training MSE'].apply(lambda x: f"{x:.6f}")
+df['Validation MSE'] = df['Validation MSE'].apply(lambda x: f"{x:.6f}")
 print("\nSummary of Training Methods:")
 print(df.to_string(index=False))
 ```
````
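The three added lines convert the float columns to fixed six-decimal strings before printing. A self-contained illustration with made-up method names and MSE values:

```python
import pandas as pd

# Illustrative numbers only; the lecture fills these in from actual runs.
results = {
    'Method': ['Keras', 'JAX', 'JAX + optax'],
    'Training MSE': [0.00912345678, 0.00887654321, 0.00901234567],
    'Validation MSE': [0.01023456789, 0.00998765432, 0.01011111111],
}
df = pd.DataFrame(results)

# Same transformation as in the diff: fix each MSE column at 6 decimal places.
df['Training MSE'] = df['Training MSE'].apply(lambda x: f"{x:.6f}")
df['Validation MSE'] = df['Validation MSE'].apply(lambda x: f"{x:.6f}")

print("\nSummary of Training Methods:")
print(df.to_string(index=False))
```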
