# params_GPT2.py – Software AnatomyNLM
# Copyright 2023 b<>com. All rights reserved.
# This software is licensed under the Apache License, Version 2.0.
# You may not use this file except in compliance with the license.
# You may obtain a copy of the license at:
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Created on Wed Sep 27 07:46:20 2023
@author: Majd SALEH
"""
import keras_nlp
from keras.utils.layer_utils import count_params
import Transformer_params_calculator as params_calc
def main(L, d_k, d_v, d_e, M, d_f, v, n):
    """Compare the GPT-2 trainable-parameter count from the derived formula
    against the count reported by ``keras.utils.layer_utils.count_params``.

    Parameters
    ----------
    L : int
        Number of transformer blocks.
    d_k : int
        Dimension of the key vectors.
    d_v : int
        Dimension of the value vectors.
    d_e : int
        Dimension of the embedding vectors.
    M : int
        Number of attention heads.
    d_f : int
        Feed-forward (inner) dimension.
    v : int
        Vocabulary size.
    n : int
        Maximum sequence length.

    The function prints a report of the comparison; it returns nothing.
    """
    # -------------------------------------------------------------------------
    # Load the pretrained GPT-2 language model (causal LM variant, with head).
    model = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    # model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
    # -------------------------------------------------------------------------
    # Print the model summary, expanding nested layers for inspection.
    model.summary(expand_nested=True)
    # -------------------------------------------------------------------------
    # Count trainable parameters from the model and print their shapes.
    print("Trainable parameters' shapes:")
    separator = "-" * 60
    for tw in model.trainable_weights:
        print(tw.shape)
    print(separator)
    print("Trainable parameters' count obtained from the model:")
    total_tp = count_params(model.trainable_weights)
    print(total_tp)
    print(separator)
    # -------------------------------------------------------------------------
    # Count trainable parameters from the derived closed-form formula.
    num_params = params_calc.num_params_GPT2(L, d_k, d_v, d_e, M, d_f, v, n)
    print("Trainable parameters' count from the derived formula:")
    print(num_params)
    print(separator)
if __name__ == '__main__':
    # Hyper-parameters of the GPT-2 "gpt2_base_en" preset.
    hyper_params = {
        "v": 50257,   # vocabulary size
        "n": 1024,    # maximum sequence length
        "d_e": 768,   # embedding dimension (12 heads * 64)
        "M": 12,      # number of attention heads
        "L": 12,      # number of transformer blocks
        "d_k": 64,    # key dimension per head
        "d_v": 64,    # value dimension per head
        "d_f": 3072,  # feed-forward dimension
    }
    main(**hyper_params)