Keras Sequential model#

This is a companion notebook for the excellent book Deep Learning with Python, Second Edition (code provided by François Chollet).

The Sequential model is the most approachable Keras API: it is essentially a Python list of layers. As such, it is limited to simple, linear stacks of layers, where each layer has exactly one input and one output.
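To make the "Python list" analogy concrete, here is a minimal, framework-free sketch. `ToySequential` is a hypothetical class written only for illustration, not Keras's implementation: it stores callables in a list and feeds each one's output into the next. It supports both construction styles (passing a list, or appending with `add()`), and it makes the limitation visible: a single chain has no place for branches or multiple inputs.

```python
class ToySequential:
    """Hypothetical illustration: a list of callables applied in order."""

    def __init__(self, layers=None):
        self.layers = list(layers or [])

    def add(self, layer):
        # Incremental building is just appending to the list.
        self.layers.append(layer)

    def __call__(self, x):
        # Each layer's output becomes the next layer's input: a strictly linear chain.
        for layer in self.layers:
            x = layer(x)
        return x


def double(values):
    return [2 * v for v in values]


def shift(values):
    return [v + 1 for v in values]


model = ToySequential([double])  # construct from a list
model.add(shift)                 # or grow it incrementally
print(model([1, 2, 3]))          # [3, 5, 7]
```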

Setup#

from tensorflow import keras
from tensorflow.keras import layers

Sequential class#

model = keras.Sequential([
    layers.Dense(64, activation="relu"),
    layers.Dense(10, activation="softmax")
])
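As a rough picture of what this two-layer model computes for one sample, the sketch below applies `activation(dot(x, kernel) + bias)` twice in plain Python: ReLU over 64 units, then softmax over 10 units. The 3-feature input, the uniform random weights, and the helper names (`dense`, `relu`, `softmax`, `random_matrix`) are assumptions for illustration only; Keras uses its own initializers and vectorized tensor operations.

```python
import math
import random


def dense(x, kernel, bias, activation):
    """One Dense layer for a single sample: activation(x @ kernel + bias)."""
    # kernel has shape (len(x), units); bias has shape (units,)
    z = [sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
         for j in range(len(bias))]
    return activation(z)


def relu(z):
    return [max(0.0, v) for v in z]


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows, cols, rng):
    # Illustrative initialization only, not Keras's default initializer.
    return [[rng.uniform(-0.3, 0.3) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
x = [1.0, 2.0, 3.0]  # one sample with 3 features (assumed)
W1, b1 = random_matrix(3, 64, rng), [0.0] * 64
W2, b2 = random_matrix(64, 10, rng), [0.0] * 10

hidden = dense(x, W1, b1, relu)         # 64 non-negative values
probs = dense(hidden, W2, b2, softmax)  # 10 probabilities summing to 1
print(len(hidden), len(probs), round(sum(probs), 6))
```

The softmax output always sums to 1, which is why the last layer of a 10-class classifier is typically written this way.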

Building a model incrementally#

model = keras.Sequential()
model.add(layers.Dense(64, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))

Build a model#

A Sequential model has no weights until it is built, because the shape of each layer's weights depends on the shape of its input. Calling build() with input_shape=(None, 3) creates them:

  • None is the batch dimension: the number of samples per batch is variable.

  • Each sample has shape (3,), i.e. a flat vector of 3 values.
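The weight shapes follow directly from the input shape, which is why they cannot exist before build() is called. A minimal sketch, assuming each Dense layer owns a kernel of shape (input_dim, units) and a bias of shape (units,); `dense_weight_shapes` is a hypothetical helper, not a Keras function:

```python
def dense_weight_shapes(input_shape, units_per_layer):
    """Kernel and bias shapes for a stack of Dense layers."""
    shapes = []
    in_dim = input_shape[-1]  # the batch dimension (None) does not affect weight shapes
    for units in units_per_layer:
        shapes.append((in_dim, units))  # kernel
        shapes.append((units,))         # bias
        in_dim = units                  # this layer's output feeds the next layer
    return shapes


print(dense_weight_shapes((None, 3), [64, 10]))
# [(3, 64), (64,), (64, 10), (10,)]
```

These are the same four shapes reported by model.weights for this model.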

model.build(input_shape=(None, 3))
model.weights
[<tf.Variable 'dense_2/kernel:0' shape=(3, 64) dtype=float32, numpy=
 array([[ 0.2826782 ,  0.2417497 , -0.05836663,  0.24952745, -0.11575036,
         -0.14105794,  0.00627172, -0.27456254, -0.1287042 , -0.04035717,
          0.184448  ,  0.15180314, -0.04409876,  0.25864238,  0.07377666,
          0.16428366,  0.13834658, -0.10165681, -0.14240925,  0.01597226,
          0.21214521,  0.08710402, -0.10955833, -0.01319629, -0.16369334,
         -0.04609835,  0.16951093, -0.11422695, -0.14093272,  0.00276107,
         -0.22663775,  0.07431605,  0.14151314, -0.18165556, -0.25529322,
          0.06155869, -0.02872646,  0.21174711,  0.15349936, -0.20779619,
         -0.17252736,  0.0460569 ,  0.09419367, -0.0829    , -0.09096426,
         -0.24262328,  0.11593872, -0.14637753, -0.04844654,  0.17973652,
          0.23427951,  0.1318292 ,  0.16158694,  0.03198919, -0.09110838,
          0.26938975,  0.23663473, -0.01077375, -0.02268949, -0.25346667,
         -0.01892135, -0.05571799, -0.07512164, -0.08814697],
        [-0.24195947,  0.25286376,  0.29157704,  0.17440143, -0.17675611,
          0.29094243, -0.08157101, -0.03145474,  0.11797833,  0.15611872,
          0.01262742, -0.22699606,  0.04916424, -0.11812451,  0.18900064,
         -0.1652983 ,  0.14014447, -0.07694185,  0.13507903, -0.21459259,
         -0.21704552, -0.02510366,  0.23418874,  0.00336367, -0.23711869,
         -0.06082419, -0.2981086 , -0.28445542,  0.13680229,  0.23173332,
         -0.15219434,  0.02420676, -0.14985585,  0.21909189, -0.02908722,
          0.27220637, -0.14114776,  0.10208118,  0.04449955, -0.20028144,
          0.12821311, -0.12000096,  0.22986537, -0.2041235 ,  0.03854424,
          0.04071441,  0.15225706, -0.11581393,  0.01045698,  0.22192371,
          0.09460613,  0.22948337, -0.02796492, -0.051245  , -0.26213908,
          0.15558985,  0.17616057,  0.04300615, -0.06333047,  0.24559873,
         -0.23299494, -0.01003173, -0.00982976, -0.22057365],
        [-0.11273037,  0.26247364, -0.2745475 , -0.29180413,  0.01016629,
         -0.17487554, -0.01739892,  0.01346025, -0.09383999,  0.16431198,
         -0.110906  ,  0.0034062 ,  0.1695095 ,  0.06775609,  0.22547638,
          0.18317783,  0.00054038,  0.21720213,  0.2686953 ,  0.06856132,
          0.23969126,  0.04505947,  0.24772435, -0.1340534 ,  0.2661003 ,
          0.00364372,  0.2925592 , -0.2026523 , -0.2258859 ,  0.15597475,
          0.192054  , -0.00059846, -0.10291038,  0.27247667,  0.22205025,
         -0.2595243 ,  0.2589751 ,  0.03667736,  0.14113435, -0.131831  ,
         -0.04710087, -0.17132822,  0.16399184, -0.13850483,  0.12999117,
         -0.21982506,  0.1447914 , -0.29201826, -0.03777638,  0.09851852,
          0.10695145,  0.14410555,  0.13701653,  0.29396623,  0.24253297,
         -0.22996339, -0.02364269, -0.01131505, -0.07063267,  0.0348458 ,
          0.10594893, -0.01522812, -0.01860899, -0.15044205]],
       dtype=float32)>,
 <tf.Variable 'dense_2/bias:0' shape=(64,) dtype=float32, numpy=
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'dense_3/kernel:0' shape=(64, 10) dtype=float32, numpy=
 array([[ 0.0227116 , -0.11105742,  0.21916005, -0.08332005,  0.24296471,
          0.2227405 ,  0.25771508, -0.17642915,  0.01605007, -0.17918953],
        [-0.11198601, -0.25513417, -0.00127828,  0.03253537, -0.11811225,
         -0.20744815, -0.25488585,  0.07495144,  0.01301107,  0.12305295],
        [-0.08068126,  0.01837113,  0.22749636, -0.14017056, -0.25852892,
          0.08594859,  0.02672622,  0.16838437,  0.17546982,  0.14865255],
        [ 0.05641568,  0.04687923, -0.12208281,  0.12915713, -0.10566363,
          0.26764193, -0.01995844,  0.16542298, -0.10871701,  0.15046057],
        [-0.04409443, -0.06392211, -0.27356145, -0.26723915,  0.14522198,
         -0.0360457 , -0.04710938,  0.21549538, -0.08259961, -0.0357943 ],
        [-0.14304037,  0.18999264,  0.2665176 ,  0.11225769,  0.01715231,
          0.13444492, -0.10369457, -0.11847356,  0.00268942, -0.12984817],
        [-0.05507535,  0.13729471,  0.27961668, -0.1330486 , -0.1118194 ,
         -0.2062298 ,  0.19906229, -0.11703044,  0.12512952, -0.23288646],
        [-0.12215878, -0.05028877,  0.14478579, -0.01053175,  0.27401718,
          0.07245386, -0.01024175,  0.14503568, -0.26744384,  0.19058079],
        [-0.1139501 , -0.11130202, -0.26256302, -0.13598533,  0.02726343,
          0.05046174, -0.26882333, -0.026676  ,  0.08728185,  0.21774504],
        [-0.11839731,  0.04882893,  0.12299815, -0.2599662 ,  0.16727012,
          0.1803537 , -0.25208202, -0.17002226,  0.26642606,  0.08154455],
        [ 0.00598964, -0.25482467, -0.06871052,  0.08533493,  0.03285757,
         -0.12505262,  0.22925994,  0.10890403,  0.11673069,  0.0337801 ],
        [-0.03310305, -0.14233832,  0.04183352, -0.07953618,  0.21416712,
          0.14172086,  0.19805121,  0.27707782, -0.02544436,  0.21469042],
        [-0.19563001,  0.04014832,  0.22591802,  0.13549042,  0.00592965,
         -0.2279095 , -0.24906816,  0.01830214,  0.15673727,  0.16839063],
        [-0.17512563, -0.20069964,  0.16562027, -0.05383922, -0.09928906,
         -0.24388468,  0.20519531, -0.21608615, -0.24536487,  0.06215116],
        [ 0.14227608, -0.03319633, -0.01149756,  0.11757824, -0.22534323,
          0.24636236, -0.14564534,  0.07447845,  0.11842787,  0.24155226],
        [-0.11731577, -0.24292657, -0.09989673,  0.28227618, -0.24485096,
          0.2002593 , -0.22064932, -0.27931702, -0.0679395 ,  0.19920072],
        [ 0.05468574,  0.09213287, -0.22578709, -0.0244385 , -0.1426517 ,
         -0.10320021, -0.09566167, -0.15192202, -0.16512883,  0.27052644],
        [ 0.14523488, -0.04513611, -0.0197213 , -0.02283931,  0.00127882,
         -0.06654392, -0.02208057,  0.22867993, -0.12515588, -0.28147826],
        [-0.08756414, -0.26039654,  0.24498418, -0.27918416, -0.03895712,
          0.01204684,  0.07729626, -0.25768465,  0.00730079,  0.11953881],
        [-0.01533568, -0.28437448, -0.09470837,  0.19159552,  0.04241586,
          0.23649171, -0.02788156, -0.2463947 , -0.15466684,  0.28197512],
        [ 0.01217487, -0.23272991, -0.03559268, -0.10867207, -0.24214673,
          0.10347122, -0.21003893, -0.00662482, -0.26657578,  0.16093999],
        [-0.18764117,  0.21695593, -0.17500621, -0.17758694, -0.12731421,
         -0.0184052 ,  0.27466002, -0.2731235 ,  0.15947849,  0.16490778],
        [ 0.11493817,  0.13910729, -0.24848594,  0.19265798,  0.05819044,
          0.00161412, -0.1606789 , -0.12757558,  0.05787176, -0.06857325],
        [-0.16540676, -0.06942044,  0.11635676, -0.22922024, -0.02355534,
         -0.03926913,  0.04614705,  0.00384259,  0.118379  ,  0.05644155],
        [ 0.03266335,  0.25526795,  0.23688522, -0.17452264,  0.27752265,
          0.13963193,  0.13582444, -0.24107021,  0.18948069, -0.22057104],
        [ 0.10681504,  0.08417356,  0.22577521, -0.24380106, -0.2714516 ,
          0.09115446,  0.22181728,  0.2786174 ,  0.16901967,  0.27663216],
        [ 0.28214952,  0.26111713,  0.15805769,  0.04456577, -0.20399912,
          0.2525579 , -0.21956275, -0.2518011 ,  0.2589009 ,  0.15441933],
        [ 0.03210691, -0.17892143, -0.2790965 , -0.23397797, -0.12601474,
         -0.1843094 , -0.00033233, -0.11967742,  0.16478705,  0.05147281],
        [-0.2100276 , -0.21319014,  0.19939956, -0.12155043, -0.00118488,
          0.2788054 ,  0.14748845, -0.25435475,  0.26713142, -0.09573269],
        [-0.12282069,  0.04055178, -0.07216601, -0.08296894, -0.03722358,
          0.2557349 ,  0.05643979,  0.02678543, -0.23654853, -0.14930965],
        [-0.21942908,  0.26818374,  0.2711011 , -0.08819449,  0.25800756,
         -0.0775543 , -0.0685423 ,  0.07053077, -0.19228438,  0.15546453],
        [-0.02288148,  0.22397122,  0.01737478,  0.25386217, -0.22039846,
         -0.12739839, -0.14631534, -0.24097076, -0.1979242 ,  0.16015527],
        [-0.20723145, -0.24874209, -0.20234114,  0.06359249, -0.06654024,
         -0.0367415 , -0.25965247,  0.25881943, -0.09978382,  0.01308081],
        [-0.02068844, -0.14234124,  0.25291792,  0.28060636,  0.01828939,
          0.20907512,  0.1955671 ,  0.17785856, -0.01837817,  0.19368383],
        [ 0.01163149, -0.04084526, -0.22893077,  0.17049173,  0.10836229,
          0.11900771, -0.23876789,  0.22959718, -0.0790877 , -0.18509153],
        [-0.01294157, -0.07983795, -0.02466226,  0.28205565, -0.25292608,
         -0.02668598, -0.19692093,  0.00818145, -0.23145251, -0.18205221],
        [ 0.23367211,  0.24347326,  0.1234127 , -0.12516367,  0.18918008,
          0.223901  ,  0.1703668 , -0.1597994 ,  0.04592568,  0.14886642],
        [ 0.07242563,  0.04856294,  0.09821436,  0.02847558, -0.27761132,
         -0.0975105 , -0.16396323, -0.0160751 ,  0.20444745,  0.1743817 ],
        [ 0.22725096,  0.07651511,  0.20171663, -0.0382005 , -0.2535014 ,
          0.2428414 ,  0.19316924, -0.13883036,  0.1390709 ,  0.19727951],
        [-0.03731509, -0.08706963,  0.04758394, -0.19286829, -0.20768902,
          0.17123109,  0.18550268, -0.10874288, -0.2796373 , -0.22547175],
        [-0.14806247,  0.06824917,  0.06949037, -0.22299978,  0.2756212 ,
         -0.17971757,  0.16690215,  0.0509927 , -0.24552815, -0.2130152 ],
        [-0.25983727,  0.28111818,  0.18532759,  0.07777786,  0.19683269,
         -0.11684313, -0.0190202 ,  0.25294152,  0.2688261 , -0.21490869],
        [-0.28319404,  0.10766447,  0.20378241,  0.12080377, -0.09629595,
          0.2008099 ,  0.1234237 , -0.22897604, -0.18349743,  0.27788773],
        [-0.2841503 , -0.13964917,  0.20021886, -0.09898022, -0.1812427 ,
          0.07425204,  0.05729967,  0.22853097, -0.19274032, -0.13685228],
        [ 0.24986735, -0.14380215,  0.05483645,  0.04961327, -0.24411348,
         -0.08334428,  0.19719255, -0.08729461,  0.22737321,  0.14047116],
        [-0.20240664,  0.20164093,  0.05406007,  0.05829111,  0.11699671,
         -0.03660707,  0.03010759, -0.22588892, -0.16960958, -0.08234476],
        [-0.08976611, -0.25159043, -0.21576694, -0.04624535,  0.16310138,
          0.11668876,  0.09123489,  0.11777532,  0.18514103,  0.19199109],
        [-0.03330836,  0.03585234, -0.0454091 , -0.27826872,  0.07515293,
          0.17013273,  0.20453253,  0.14011088, -0.05214702,  0.17547506],
        [ 0.22536227,  0.15563351,  0.11436877,  0.02922487,  0.18187684,
         -0.10304916, -0.05022298, -0.03793003, -0.00061372, -0.17427775],
        [-0.021559  ,  0.04070622, -0.03319564, -0.1861834 ,  0.10490647,
         -0.24110253, -0.18209708, -0.19290258, -0.10217156, -0.11148512],
        [ 0.26249275,  0.13463345, -0.28446525,  0.0084078 , -0.19129707,
         -0.12181464,  0.11688161,  0.23600379,  0.10549825,  0.21703747],
        [ 0.15863305,  0.0899078 , -0.23113152, -0.0747181 ,  0.20299783,
         -0.13408852, -0.10583518,  0.07816035,  0.02337489, -0.0484326 ],
        [ 0.14874142, -0.0968443 ,  0.13962099, -0.25640187,  0.16081408,
         -0.1709216 , -0.06350961, -0.22494417, -0.20407508,  0.15411967],
        [-0.02977607, -0.00175914,  0.17356071, -0.27876046, -0.12089054,
         -0.2729321 , -0.23154856, -0.14235313,  0.21474555, -0.2843    ],
        [ 0.01441514,  0.24067512, -0.18067977,  0.19038567, -0.10801885,
         -0.01788306,  0.19336918,  0.27651796,  0.26251146, -0.13158728],
        [ 0.24935874,  0.06310111, -0.09378305, -0.2692763 ,  0.01248866,
          0.18675834,  0.13483578,  0.22451726,  0.24030146,  0.00683257],
        [-0.24297464,  0.25801042,  0.18131006, -0.10502113,  0.05550426,
          0.11546999, -0.01112029, -0.25688937,  0.0695487 , -0.24913889],
        [-0.08311157, -0.13582627,  0.27336392,  0.25429365, -0.09435149,
         -0.15393601,  0.17118576, -0.2065309 , -0.20882371, -0.16374686],
        [-0.05297574, -0.05977911, -0.18794043,  0.04125264,  0.18313748,
         -0.05674006, -0.18896624,  0.24253413,  0.08855343,  0.07062358],
        [-0.21827197,  0.21259293, -0.02283069, -0.0180327 ,  0.1536724 ,
         -0.20623158, -0.13351431,  0.13682213,  0.11213455, -0.05159855],
        [ 0.05925208,  0.26357362, -0.26900378, -0.14568037,  0.20320046,
         -0.27207616,  0.01259071, -0.05142766,  0.01171827, -0.18537328],
        [-0.04391758, -0.06941324,  0.03756186, -0.10019761,  0.21796522,
          0.013455  , -0.22225592, -0.19769338,  0.09341758, -0.2639161 ],
        [-0.14900164, -0.11312152, -0.11325139,  0.14102861,  0.07807955,
          0.09603262,  0.12314886, -0.25037706, -0.18202682, -0.1254163 ],
        [-0.16640025,  0.23894766,  0.28395763,  0.16329989, -0.11318886,
         -0.07199201,  0.1381489 , -0.01620194,  0.21754631, -0.2388964 ]],
       dtype=float32)>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]
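The biases are all zero and the kernel values stay within a narrow symmetric range. This is consistent with the Dense layer's default initializers, `"zeros"` for the bias and `"glorot_uniform"` for the kernel, which samples from [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)). The small sketch below just evaluates that formula for the two kernels; `glorot_uniform_limit` is a hypothetical helper name.

```python
import math


def glorot_uniform_limit(fan_in, fan_out):
    """Half-width of the Glorot (Xavier) uniform range: sqrt(6 / (fan_in + fan_out))."""
    return math.sqrt(6.0 / (fan_in + fan_out))


print(round(glorot_uniform_limit(3, 64), 4))   # 0.2993, first kernel (3, 64)
print(round(glorot_uniform_limit(64, 10), 4))  # 0.2847, second kernel (64, 10)
```

Every printed value of dense_2/kernel lies within about ±0.2993, and every value of dense_3/kernel within about ±0.2847.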

Model summary#

model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_2 (Dense)             (None, 64)                256       
                                                                 
 dense_3 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 906
Trainable params: 906
Non-trainable params: 0
_________________________________________________________________
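The "Param #" column can be checked by hand: a Dense layer has one weight per (input, unit) pair plus one bias per unit. A minimal sketch; `dense_param_count` is a hypothetical helper:

```python
def dense_param_count(input_dim, units):
    """Parameters in one Dense layer: input_dim * units weights plus units biases."""
    return input_dim * units + units


first = dense_param_count(3, 64)     # 3*64 + 64 = 256
second = dense_param_count(64, 10)   # 64*10 + 10 = 650
print(first, second, first + second)  # 256 650 906
```

All 906 parameters are trainable, since Dense layers have no non-trainable state.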

Naming models and layers#

model = keras.Sequential(name="my_example_model")
model.add(layers.Dense(64, activation="relu", name="my_first_layer"))
model.add(layers.Dense(10, activation="softmax", name="my_last_layer"))
model.build((None, 3))
model.summary()
Model: "my_example_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 my_first_layer (Dense)      (None, 64)                256       
                                                                 
 my_last_layer (Dense)       (None, 10)                650       
                                                                 
=================================================================
Total params: 906
Trainable params: 906
Non-trainable params: 0
_________________________________________________________________
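Without explicit names, Keras assigns default names from per-prefix counters, so the first Dense layer created in a session is `dense`, the next `dense_1`, and so on. That explains why the earlier summary showed `sequential_1` with `dense_2` and `dense_3`: the first model in this notebook already used `sequential`, `dense`, and `dense_1`. The sketch below is a simplified model of that behavior, assuming a bare name for the first object and a numeric suffix afterwards; `NameRegistry` is hypothetical, not Keras code. In Keras, `keras.backend.clear_session()` resets these counters.

```python
from collections import defaultdict


class NameRegistry:
    """Hypothetical sketch of per-prefix auto-naming: 'dense', 'dense_1', 'dense_2', ..."""

    def __init__(self):
        self._counts = defaultdict(int)

    def unique_name(self, prefix):
        n = self._counts[prefix]
        self._counts[prefix] += 1
        # The first object of a kind gets the bare prefix; later ones get a suffix.
        return prefix if n == 0 else f"{prefix}_{n}"


registry = NameRegistry()
names = [registry.unique_name("dense") for _ in range(4)]
print(names)  # ['dense', 'dense_1', 'dense_2', 'dense_3']
```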

Specifying input shape#

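Use keras.Input to declare the shape of the inputs up front. Note that the shape argument is the shape of each sample, not the shape of a batch: shape=(3,) corresponds to a batch shape of (None, 3). Because the input shape is known from the start, the model is built as soon as layers are added, so summary() works after every add().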
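The shapes in the summaries follow a simple rule: the batch axis is prepended as None to the per-sample shape, and each Dense layer replaces only the last axis with its number of units. A minimal sketch of that bookkeeping; `dense_output_shape` is a hypothetical helper, not a Keras function:

```python
def dense_output_shape(input_shape, units):
    """Output shape of a Dense layer: only the last axis changes, to `units`."""
    return tuple(input_shape[:-1]) + (units,)


sample_shape = (3,)                   # what keras.Input(shape=...) receives
batch_shape = (None,) + sample_shape  # what the model sees: (None, 3)
hidden = dense_output_shape(batch_shape, 64)
print(batch_shape, hidden, dense_output_shape(hidden, 10))
# (None, 3) (None, 64) (None, 10)
```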

model = keras.Sequential()
model.add(keras.Input(shape=(3,)))
model.add(layers.Dense(64, activation="relu"))
model.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_4 (Dense)             (None, 64)                256       
                                                                 
=================================================================
Total params: 256
Trainable params: 256
Non-trainable params: 0
_________________________________________________________________
model.add(layers.Dense(10, activation="softmax"))
model.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_4 (Dense)             (None, 64)                256       
                                                                 
 dense_5 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 906
Trainable params: 906
Non-trainable params: 0
_________________________________________________________________