# Optimizer and loss shared by the train/valid step functions below.
optimizer = optimizers.Nadam()
loss_func = losses.SparseCategoricalCrossentropy()

# Running per-epoch metrics; each is reset at the end of every epoch
# inside train_model so successive epochs report independent numbers.
train_loss = metrics.Mean(name='train_loss')
train_metric = metrics.SparseCategoricalAccuracy(name='train_accuracy')
valid_loss = metrics.Mean(name='valid_loss')
valid_metric = metrics.SparseCategoricalAccuracy(name='valid_accuracy')
@tf.function
def train_step(model, features, labels):
    """Run a single optimization step on one batch.

    Records the forward pass under a gradient tape, applies the resulting
    gradients with the module-level ``optimizer``, and accumulates the
    batch loss/accuracy into the module-level training metrics.
    """
    with tf.GradientTape() as tape:
        # training=True so layers like Dropout/BatchNorm use train behavior.
        predictions = model(features, training=True)
        loss = loss_func(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)
@tf.function
def valid_step(model, features, labels):
    """Evaluate one validation batch and accumulate the validation metrics.

    Fix: pass ``training=False`` explicitly so layers such as Dropout and
    BatchNormalization are guaranteed to run in inference mode — mirroring
    the explicit ``training=True`` used in ``train_step`` and not relying
    on Keras's implicit learning-phase default.
    """
    predictions = model(features, training=False)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)
def train_model(model, ds_train, ds_valid, epochs, log_freq=1):
    """Train ``model`` for ``epochs`` epochs with per-epoch validation.

    Args:
        model: A Keras model (called by ``train_step``/``valid_step``).
        ds_train: Iterable of ``(features, labels)`` training batches.
        ds_valid: Iterable of ``(features, labels)`` validation batches.
        epochs: Number of epochs to run.
        log_freq: Print metrics every ``log_freq`` epochs (default 1,
            matching the original behavior of logging every epoch; the
            original's ``epoch % 1 == 0`` was always true).

    Side effects: updates/resets the module-level metric objects and
    prints progress via ``printbar``/``tf.print``.
    """
    # Hoisted out of the loop: the format template never changes.
    logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'

    for epoch in tf.range(1, epochs + 1):
        for features, labels in ds_train:
            train_step(model, features, labels)
        for features, labels in ds_valid:
            valid_step(model, features, labels)

        if epoch % log_freq == 0:
            printbar()  # NOTE(review): defined elsewhere in this file — confirm
            tf.print(tf.strings.format(
                logs,
                (epoch, train_loss.result(), train_metric.result(),
                 valid_loss.result(), valid_metric.result())))
            tf.print("")

        # Reset running metrics so each epoch's numbers are independent.
        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()
# Kick off training: 10 epochs over ds_train, evaluating each epoch on ds_test.
# NOTE(review): the test set is being used as the validation split here —
# confirm that is intentional (no separate held-out test evaluation remains).
train_model(model,ds_train,ds_test,10)