Keras Sequential model
This is a companion notebook for the excellent book Deep Learning with Python, Second Edition (code provided by François Chollet).
The Sequential model is the most approachable Keras API: it’s basically a Python list of layers. As such, it’s limited to simple (sequential) stacks of layers.
Setup
from tensorflow import keras
from tensorflow.keras import layers
Sequential class
model = keras.Sequential([
    layers.Dense(64, activation="relu"),
    layers.Dense(10, activation="softmax")
])
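Conceptually, each Dense layer computes activation(dot(input, kernel) + bias), and the Sequential model feeds each layer's output into the next. The plain-Python sketch below walks one sample through a stack shaped like the one above. It is an illustration only: the 3-feature input and the uniform(-0.1, 0.1) weights are arbitrary assumptions, and Keras itself works on batched tensors with its own initializers.

```python
import math
import random


def relu(values):
    return [max(0.0, v) for v in values]


def softmax(values):
    # Subtract the max before exponentiating for numerical stability
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def dense(x, kernel, bias, activation):
    # output[j] = activation(sum_i x[i] * kernel[i][j] + bias[j])
    z = [
        sum(xi * kernel[i][j] for i, xi in enumerate(x)) + bias[j]
        for j in range(len(bias))
    ]
    return activation(z)


def random_kernel(rows, cols, rng):
    # Arbitrary small weights, chosen only for this illustration
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
x = [1.0, 2.0, 3.0]  # one hypothetical sample with 3 features
hidden = dense(x, random_kernel(3, 64, rng), [0.0] * 64, relu)           # like Dense(64, "relu")
probs = dense(hidden, random_kernel(64, 10, rng), [0.0] * 10, softmax)  # like Dense(10, "softmax")
print(len(hidden), len(probs), round(sum(probs), 6))  # 64 10 1.0
```

The softmax at the end turns the 10 raw scores into probabilities that sum to 1, which is why it is the usual last layer of a 10-class classifier.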
Incrementally building
model = keras.Sequential()
model.add(layers.Dense(64, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))
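Both construction styles end up with the same ordered list of layers. To make the "basically a Python list" idea concrete, here is a hypothetical ToySequential class (not part of Keras) whose constructor takes a list and whose add() appends to it, with two small functions standing in for layers.

```python
class ToySequential:
    """Hypothetical stand-in for a sequential stack: an ordered list of callables."""

    def __init__(self, layers=None):
        # Like keras.Sequential([...]): start from a list of layers
        self.layers = list(layers) if layers is not None else []

    def add(self, layer):
        # Like model.add(...): append one layer to the end of the stack
        self.layers.append(layer)

    def __call__(self, x):
        # Feed each layer's output into the next layer, in order
        for layer in self.layers:
            x = layer(x)
        return x


def double(values):
    return [2 * v for v in values]


def shift(values):
    return [v + 1 for v in values]


a = ToySequential([double, shift])  # list style
b = ToySequential()                 # incremental style
b.add(double)
b.add(shift)
print(a([1, 2, 3]), b([1, 2, 3]))  # [3, 5, 7] [3, 5, 7]
```

Because the structure is a single ordered list, there is no way to express branches or multiple inputs, which is exactly the limitation mentioned at the top of this notebook.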
Build a model
As input, we use input_shape=(None, 3). The None batch dimension means the number of samples per batch is variable: the model will process batches of any size in which each sample has shape (3,), i.e. a simple array of 3 values.
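A Dense layer cannot create its weights until it knows how many features each input sample has, because that number is the first dimension of its kernel. The helper below (a hypothetical name, not a Keras function) works out the kernel and bias shapes a stack of Dense layers needs for a given input shape; for (None, 3) and units 64 and 10, they match the shapes of the variables that model.weights reports after building.

```python
def dense_weight_shapes(input_shape, units_per_layer):
    """Kernel and bias shapes for a stack of Dense layers.

    input_shape is a batch shape such as (None, 3); only its last
    dimension (features per sample) determines the first kernel.
    """
    shapes = []
    in_features = input_shape[-1]
    for units in units_per_layer:
        shapes.append(((in_features, units), (units,)))
        in_features = units  # the next layer consumes this layer's output
    return shapes


print(dense_weight_shapes((None, 3), [64, 10]))
# [((3, 64), (64,)), ((64, 10), (10,))]
```

The batch dimension never appears in the weight shapes, which is why it can be left as None.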
model.build(input_shape=(None, 3))
model.weights
[<tf.Variable 'dense_2/kernel:0' shape=(3, 64) dtype=float32, numpy=
array([[ 0.2826782 , 0.2417497 , -0.05836663, 0.24952745, -0.11575036,
-0.14105794, 0.00627172, -0.27456254, -0.1287042 , -0.04035717,
0.184448 , 0.15180314, -0.04409876, 0.25864238, 0.07377666,
0.16428366, 0.13834658, -0.10165681, -0.14240925, 0.01597226,
0.21214521, 0.08710402, -0.10955833, -0.01319629, -0.16369334,
-0.04609835, 0.16951093, -0.11422695, -0.14093272, 0.00276107,
-0.22663775, 0.07431605, 0.14151314, -0.18165556, -0.25529322,
0.06155869, -0.02872646, 0.21174711, 0.15349936, -0.20779619,
-0.17252736, 0.0460569 , 0.09419367, -0.0829 , -0.09096426,
-0.24262328, 0.11593872, -0.14637753, -0.04844654, 0.17973652,
0.23427951, 0.1318292 , 0.16158694, 0.03198919, -0.09110838,
0.26938975, 0.23663473, -0.01077375, -0.02268949, -0.25346667,
-0.01892135, -0.05571799, -0.07512164, -0.08814697],
[-0.24195947, 0.25286376, 0.29157704, 0.17440143, -0.17675611,
0.29094243, -0.08157101, -0.03145474, 0.11797833, 0.15611872,
0.01262742, -0.22699606, 0.04916424, -0.11812451, 0.18900064,
-0.1652983 , 0.14014447, -0.07694185, 0.13507903, -0.21459259,
-0.21704552, -0.02510366, 0.23418874, 0.00336367, -0.23711869,
-0.06082419, -0.2981086 , -0.28445542, 0.13680229, 0.23173332,
-0.15219434, 0.02420676, -0.14985585, 0.21909189, -0.02908722,
0.27220637, -0.14114776, 0.10208118, 0.04449955, -0.20028144,
0.12821311, -0.12000096, 0.22986537, -0.2041235 , 0.03854424,
0.04071441, 0.15225706, -0.11581393, 0.01045698, 0.22192371,
0.09460613, 0.22948337, -0.02796492, -0.051245 , -0.26213908,
0.15558985, 0.17616057, 0.04300615, -0.06333047, 0.24559873,
-0.23299494, -0.01003173, -0.00982976, -0.22057365],
[-0.11273037, 0.26247364, -0.2745475 , -0.29180413, 0.01016629,
-0.17487554, -0.01739892, 0.01346025, -0.09383999, 0.16431198,
-0.110906 , 0.0034062 , 0.1695095 , 0.06775609, 0.22547638,
0.18317783, 0.00054038, 0.21720213, 0.2686953 , 0.06856132,
0.23969126, 0.04505947, 0.24772435, -0.1340534 , 0.2661003 ,
0.00364372, 0.2925592 , -0.2026523 , -0.2258859 , 0.15597475,
0.192054 , -0.00059846, -0.10291038, 0.27247667, 0.22205025,
-0.2595243 , 0.2589751 , 0.03667736, 0.14113435, -0.131831 ,
-0.04710087, -0.17132822, 0.16399184, -0.13850483, 0.12999117,
-0.21982506, 0.1447914 , -0.29201826, -0.03777638, 0.09851852,
0.10695145, 0.14410555, 0.13701653, 0.29396623, 0.24253297,
-0.22996339, -0.02364269, -0.01131505, -0.07063267, 0.0348458 ,
0.10594893, -0.01522812, -0.01860899, -0.15044205]],
dtype=float32)>,
<tf.Variable 'dense_2/bias:0' shape=(64,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>,
<tf.Variable 'dense_3/kernel:0' shape=(64, 10) dtype=float32, numpy=
array([[ 0.0227116 , -0.11105742, 0.21916005, -0.08332005, 0.24296471,
0.2227405 , 0.25771508, -0.17642915, 0.01605007, -0.17918953],
[-0.11198601, -0.25513417, -0.00127828, 0.03253537, -0.11811225,
-0.20744815, -0.25488585, 0.07495144, 0.01301107, 0.12305295],
[-0.08068126, 0.01837113, 0.22749636, -0.14017056, -0.25852892,
0.08594859, 0.02672622, 0.16838437, 0.17546982, 0.14865255],
[ 0.05641568, 0.04687923, -0.12208281, 0.12915713, -0.10566363,
0.26764193, -0.01995844, 0.16542298, -0.10871701, 0.15046057],
[-0.04409443, -0.06392211, -0.27356145, -0.26723915, 0.14522198,
-0.0360457 , -0.04710938, 0.21549538, -0.08259961, -0.0357943 ],
[-0.14304037, 0.18999264, 0.2665176 , 0.11225769, 0.01715231,
0.13444492, -0.10369457, -0.11847356, 0.00268942, -0.12984817],
[-0.05507535, 0.13729471, 0.27961668, -0.1330486 , -0.1118194 ,
-0.2062298 , 0.19906229, -0.11703044, 0.12512952, -0.23288646],
[-0.12215878, -0.05028877, 0.14478579, -0.01053175, 0.27401718,
0.07245386, -0.01024175, 0.14503568, -0.26744384, 0.19058079],
[-0.1139501 , -0.11130202, -0.26256302, -0.13598533, 0.02726343,
0.05046174, -0.26882333, -0.026676 , 0.08728185, 0.21774504],
[-0.11839731, 0.04882893, 0.12299815, -0.2599662 , 0.16727012,
0.1803537 , -0.25208202, -0.17002226, 0.26642606, 0.08154455],
[ 0.00598964, -0.25482467, -0.06871052, 0.08533493, 0.03285757,
-0.12505262, 0.22925994, 0.10890403, 0.11673069, 0.0337801 ],
[-0.03310305, -0.14233832, 0.04183352, -0.07953618, 0.21416712,
0.14172086, 0.19805121, 0.27707782, -0.02544436, 0.21469042],
[-0.19563001, 0.04014832, 0.22591802, 0.13549042, 0.00592965,
-0.2279095 , -0.24906816, 0.01830214, 0.15673727, 0.16839063],
[-0.17512563, -0.20069964, 0.16562027, -0.05383922, -0.09928906,
-0.24388468, 0.20519531, -0.21608615, -0.24536487, 0.06215116],
[ 0.14227608, -0.03319633, -0.01149756, 0.11757824, -0.22534323,
0.24636236, -0.14564534, 0.07447845, 0.11842787, 0.24155226],
[-0.11731577, -0.24292657, -0.09989673, 0.28227618, -0.24485096,
0.2002593 , -0.22064932, -0.27931702, -0.0679395 , 0.19920072],
[ 0.05468574, 0.09213287, -0.22578709, -0.0244385 , -0.1426517 ,
-0.10320021, -0.09566167, -0.15192202, -0.16512883, 0.27052644],
[ 0.14523488, -0.04513611, -0.0197213 , -0.02283931, 0.00127882,
-0.06654392, -0.02208057, 0.22867993, -0.12515588, -0.28147826],
[-0.08756414, -0.26039654, 0.24498418, -0.27918416, -0.03895712,
0.01204684, 0.07729626, -0.25768465, 0.00730079, 0.11953881],
[-0.01533568, -0.28437448, -0.09470837, 0.19159552, 0.04241586,
0.23649171, -0.02788156, -0.2463947 , -0.15466684, 0.28197512],
[ 0.01217487, -0.23272991, -0.03559268, -0.10867207, -0.24214673,
0.10347122, -0.21003893, -0.00662482, -0.26657578, 0.16093999],
[-0.18764117, 0.21695593, -0.17500621, -0.17758694, -0.12731421,
-0.0184052 , 0.27466002, -0.2731235 , 0.15947849, 0.16490778],
[ 0.11493817, 0.13910729, -0.24848594, 0.19265798, 0.05819044,
0.00161412, -0.1606789 , -0.12757558, 0.05787176, -0.06857325],
[-0.16540676, -0.06942044, 0.11635676, -0.22922024, -0.02355534,
-0.03926913, 0.04614705, 0.00384259, 0.118379 , 0.05644155],
[ 0.03266335, 0.25526795, 0.23688522, -0.17452264, 0.27752265,
0.13963193, 0.13582444, -0.24107021, 0.18948069, -0.22057104],
[ 0.10681504, 0.08417356, 0.22577521, -0.24380106, -0.2714516 ,
0.09115446, 0.22181728, 0.2786174 , 0.16901967, 0.27663216],
[ 0.28214952, 0.26111713, 0.15805769, 0.04456577, -0.20399912,
0.2525579 , -0.21956275, -0.2518011 , 0.2589009 , 0.15441933],
[ 0.03210691, -0.17892143, -0.2790965 , -0.23397797, -0.12601474,
-0.1843094 , -0.00033233, -0.11967742, 0.16478705, 0.05147281],
[-0.2100276 , -0.21319014, 0.19939956, -0.12155043, -0.00118488,
0.2788054 , 0.14748845, -0.25435475, 0.26713142, -0.09573269],
[-0.12282069, 0.04055178, -0.07216601, -0.08296894, -0.03722358,
0.2557349 , 0.05643979, 0.02678543, -0.23654853, -0.14930965],
[-0.21942908, 0.26818374, 0.2711011 , -0.08819449, 0.25800756,
-0.0775543 , -0.0685423 , 0.07053077, -0.19228438, 0.15546453],
[-0.02288148, 0.22397122, 0.01737478, 0.25386217, -0.22039846,
-0.12739839, -0.14631534, -0.24097076, -0.1979242 , 0.16015527],
[-0.20723145, -0.24874209, -0.20234114, 0.06359249, -0.06654024,
-0.0367415 , -0.25965247, 0.25881943, -0.09978382, 0.01308081],
[-0.02068844, -0.14234124, 0.25291792, 0.28060636, 0.01828939,
0.20907512, 0.1955671 , 0.17785856, -0.01837817, 0.19368383],
[ 0.01163149, -0.04084526, -0.22893077, 0.17049173, 0.10836229,
0.11900771, -0.23876789, 0.22959718, -0.0790877 , -0.18509153],
[-0.01294157, -0.07983795, -0.02466226, 0.28205565, -0.25292608,
-0.02668598, -0.19692093, 0.00818145, -0.23145251, -0.18205221],
[ 0.23367211, 0.24347326, 0.1234127 , -0.12516367, 0.18918008,
0.223901 , 0.1703668 , -0.1597994 , 0.04592568, 0.14886642],
[ 0.07242563, 0.04856294, 0.09821436, 0.02847558, -0.27761132,
-0.0975105 , -0.16396323, -0.0160751 , 0.20444745, 0.1743817 ],
[ 0.22725096, 0.07651511, 0.20171663, -0.0382005 , -0.2535014 ,
0.2428414 , 0.19316924, -0.13883036, 0.1390709 , 0.19727951],
[-0.03731509, -0.08706963, 0.04758394, -0.19286829, -0.20768902,
0.17123109, 0.18550268, -0.10874288, -0.2796373 , -0.22547175],
[-0.14806247, 0.06824917, 0.06949037, -0.22299978, 0.2756212 ,
-0.17971757, 0.16690215, 0.0509927 , -0.24552815, -0.2130152 ],
[-0.25983727, 0.28111818, 0.18532759, 0.07777786, 0.19683269,
-0.11684313, -0.0190202 , 0.25294152, 0.2688261 , -0.21490869],
[-0.28319404, 0.10766447, 0.20378241, 0.12080377, -0.09629595,
0.2008099 , 0.1234237 , -0.22897604, -0.18349743, 0.27788773],
[-0.2841503 , -0.13964917, 0.20021886, -0.09898022, -0.1812427 ,
0.07425204, 0.05729967, 0.22853097, -0.19274032, -0.13685228],
[ 0.24986735, -0.14380215, 0.05483645, 0.04961327, -0.24411348,
-0.08334428, 0.19719255, -0.08729461, 0.22737321, 0.14047116],
[-0.20240664, 0.20164093, 0.05406007, 0.05829111, 0.11699671,
-0.03660707, 0.03010759, -0.22588892, -0.16960958, -0.08234476],
[-0.08976611, -0.25159043, -0.21576694, -0.04624535, 0.16310138,
0.11668876, 0.09123489, 0.11777532, 0.18514103, 0.19199109],
[-0.03330836, 0.03585234, -0.0454091 , -0.27826872, 0.07515293,
0.17013273, 0.20453253, 0.14011088, -0.05214702, 0.17547506],
[ 0.22536227, 0.15563351, 0.11436877, 0.02922487, 0.18187684,
-0.10304916, -0.05022298, -0.03793003, -0.00061372, -0.17427775],
[-0.021559 , 0.04070622, -0.03319564, -0.1861834 , 0.10490647,
-0.24110253, -0.18209708, -0.19290258, -0.10217156, -0.11148512],
[ 0.26249275, 0.13463345, -0.28446525, 0.0084078 , -0.19129707,
-0.12181464, 0.11688161, 0.23600379, 0.10549825, 0.21703747],
[ 0.15863305, 0.0899078 , -0.23113152, -0.0747181 , 0.20299783,
-0.13408852, -0.10583518, 0.07816035, 0.02337489, -0.0484326 ],
[ 0.14874142, -0.0968443 , 0.13962099, -0.25640187, 0.16081408,
-0.1709216 , -0.06350961, -0.22494417, -0.20407508, 0.15411967],
[-0.02977607, -0.00175914, 0.17356071, -0.27876046, -0.12089054,
-0.2729321 , -0.23154856, -0.14235313, 0.21474555, -0.2843 ],
[ 0.01441514, 0.24067512, -0.18067977, 0.19038567, -0.10801885,
-0.01788306, 0.19336918, 0.27651796, 0.26251146, -0.13158728],
[ 0.24935874, 0.06310111, -0.09378305, -0.2692763 , 0.01248866,
0.18675834, 0.13483578, 0.22451726, 0.24030146, 0.00683257],
[-0.24297464, 0.25801042, 0.18131006, -0.10502113, 0.05550426,
0.11546999, -0.01112029, -0.25688937, 0.0695487 , -0.24913889],
[-0.08311157, -0.13582627, 0.27336392, 0.25429365, -0.09435149,
-0.15393601, 0.17118576, -0.2065309 , -0.20882371, -0.16374686],
[-0.05297574, -0.05977911, -0.18794043, 0.04125264, 0.18313748,
-0.05674006, -0.18896624, 0.24253413, 0.08855343, 0.07062358],
[-0.21827197, 0.21259293, -0.02283069, -0.0180327 , 0.1536724 ,
-0.20623158, -0.13351431, 0.13682213, 0.11213455, -0.05159855],
[ 0.05925208, 0.26357362, -0.26900378, -0.14568037, 0.20320046,
-0.27207616, 0.01259071, -0.05142766, 0.01171827, -0.18537328],
[-0.04391758, -0.06941324, 0.03756186, -0.10019761, 0.21796522,
0.013455 , -0.22225592, -0.19769338, 0.09341758, -0.2639161 ],
[-0.14900164, -0.11312152, -0.11325139, 0.14102861, 0.07807955,
0.09603262, 0.12314886, -0.25037706, -0.18202682, -0.1254163 ],
[-0.16640025, 0.23894766, 0.28395763, 0.16329989, -0.11318886,
-0.07199201, 0.1381489 , -0.01620194, 0.21754631, -0.2388964 ]],
dtype=float32)>,
<tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]
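The printed values follow from Dense's default initializers, kernel_initializer="glorot_uniform" and bias_initializer="zeros": every bias starts at 0 and every kernel entry is small. Glorot uniform samples from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)), where for a 2-D kernel fan_in and fan_out are its number of rows and columns. The sketch computes both limits and checks a few entries copied from the output above against them.

```python
import math


def glorot_uniform_limit(fan_in, fan_out):
    # Glorot (Xavier) uniform samples from U(-limit, limit) with this limit
    return math.sqrt(6.0 / (fan_in + fan_out))


# For a Dense kernel of shape (inputs, units): fan_in = inputs, fan_out = units
limit_1 = glorot_uniform_limit(3, 64)   # dense_2/kernel, shape (3, 64)
limit_2 = glorot_uniform_limit(64, 10)  # dense_3/kernel, shape (64, 10)
print(round(limit_1, 4), round(limit_2, 4))  # 0.2993 0.2847

# A few large-magnitude entries copied from the model.weights output above
kernel_1_samples = [-0.2981086, 0.29396623, -0.29201826]
kernel_2_samples = [-0.28446525, -0.28437448, 0.28395763]
print(all(abs(v) <= limit_1 for v in kernel_1_samples))  # True
print(all(abs(v) <= limit_2 for v in kernel_2_samples))  # True
```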
Model summary
model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense_2 (Dense)             (None, 64)                256
 dense_3 (Dense)             (None, 10)                650
=================================================================
Total params: 906
Trainable params: 906
Non-trainable params: 0
_________________________________________________________________
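The Param # column follows directly from the weight shapes: a Dense layer with n inputs and m units has an n x m kernel plus m biases, i.e. n * m + m parameters. A quick check of the numbers in the summary:

```python
def dense_param_count(in_features, units):
    # One weight per (input, unit) pair plus one bias per unit
    return in_features * units + units


layer_counts = [dense_param_count(3, 64), dense_param_count(64, 10)]
print(layer_counts, sum(layer_counts))  # [256, 650] 906
```

All 906 parameters are trainable because Dense layers have no non-trainable weights.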
Naming models and layers
model = keras.Sequential(name="my_example_model")
model.add(layers.Dense(64, activation="relu", name="my_first_layer"))
model.add(layers.Dense(10, activation="softmax", name="my_last_layer"))
model.build((None, 3))
model.summary()
Model: "my_example_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 my_first_layer (Dense)      (None, 64)                256
 my_last_layer (Dense)       (None, 10)                650
=================================================================
Total params: 906
Trainable params: 906
Non-trainable params: 0
_________________________________________________________________
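When name is omitted, Keras generates a default one. The summaries in this notebook (the second unnamed model is "sequential_1" with layers "dense_2" and "dense_3", and a later unnamed model is "sequential_2" with "dense_4" and "dense_5") are consistent with a per-prefix counter: the first object gets the bare prefix and later ones get _1, _2, and so on. The NameRegistry class below is a hypothetical sketch of that observed pattern, not Keras's actual implementation.

```python
from collections import defaultdict


class NameRegistry:
    """Hypothetical sketch of default naming: 'dense', 'dense_1', 'dense_2', ..."""

    def __init__(self):
        self._counts = defaultdict(int)

    def unique_name(self, prefix):
        n = self._counts[prefix]
        self._counts[prefix] += 1
        # The first object of a kind gets the bare prefix
        return prefix if n == 0 else f"{prefix}_{n}"


registry = NameRegistry()
print([registry.unique_name("dense") for _ in range(4)])
# ['dense', 'dense_1', 'dense_2', 'dense_3']
```

Passing explicit names, as in my_example_model above, avoids depending on how many objects were created earlier in the session.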
Specifying input shape
Use keras.Input to declare the shape of the inputs. Note that the shape argument must be the shape of each sample, not the shape of one batch: shape=(3,), not (None, 3).
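keras.Input(shape=(3,)) describes the same inputs as model.build(input_shape=(None, 3)) earlier: Keras prepends an unspecified batch dimension to the per-sample shape. The two helpers below are hypothetical, written only to illustrate the convention that None stands for any batch size while the per-sample dimensions are fixed.

```python
def batch_shape_from_sample_shape(sample_shape):
    # The per-sample shape with an unspecified (None) batch dimension in front
    return (None,) + tuple(sample_shape)


def accepts(declared_batch_shape, actual_batch_shape):
    # None matches any size; every other dimension must agree exactly
    return len(declared_batch_shape) == len(actual_batch_shape) and all(
        d is None or d == a for d, a in zip(declared_batch_shape, actual_batch_shape)
    )


declared = batch_shape_from_sample_shape((3,))
print(declared)                    # (None, 3)
print(accepts(declared, (32, 3)))  # True: any batch size works
print(accepts(declared, (32, 4)))  # False: wrong number of features
```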
model = keras.Sequential()
model.add(keras.Input(shape=(3,)))
model.add(layers.Dense(64, activation="relu"))
model.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense_4 (Dense)             (None, 64)                256
=================================================================
Total params: 256
Trainable params: 256
Non-trainable params: 0
_________________________________________________________________
model.add(layers.Dense(10, activation="softmax"))
model.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense_4 (Dense)             (None, 64)                256
 dense_5 (Dense)             (None, 10)                650
=================================================================
Total params: 906
Trainable params: 906
Non-trainable params: 0
_________________________________________________________________