A custom training demo using tf.data, Eager Execution and Keras
Published: 2019-06-19

1, Basic steps of machine learning

  • Import and parse the data sets.
  • Select the type of model.
  • Train the model.
  • Evaluate the model's effectiveness.
  • Use the trained model to make predictions.

2, Limitations of eager mode

Once eager execution is enabled, it cannot be disabled within the same program.
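The code in this post targets TF 1.x, where eager execution is switched on once with tf.enable_eager_execution() at program startup and then stays on. In TF 2.x it is on by default, and graph execution comes back through tf.function rather than by turning eager mode off. A minimal sketch, assuming TensorFlow 2.x is installed (add_one is a hypothetical helper):

```python
import tensorflow as tf

# TF 2.x: eager execution is already on, no enable call is needed
eager_at_top_level = tf.executing_eagerly()

# Ops run immediately and return concrete values
value = tf.add(1, 2).numpy()

trace_flags = []

@tf.function
def add_one(x):
    # This Python append only runs while tf.function traces the body into a graph
    trace_flags.append(tf.executing_eagerly())
    return x + 1

out = add_one(tf.constant(1)).numpy()
```

Inside the traced function tf.executing_eagerly() reports False, while the rest of the program keeps running eagerly.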

3, tf.data.Dataset

TensorFlow's Dataset API handles many common cases for loading data into a model.

The default behavior of make_csv_dataset is to shuffle the data (shuffle=True, shuffle_buffer_size=10000) and repeat the dataset forever (num_epochs=None).

batch_size = 32

# In newer TensorFlow versions this function is tf.data.experimental.make_csv_dataset (tf.contrib was removed in TF 2.0)

Standard format of the returned dataset: the make_csv_dataset function returns a tf.data.Dataset of (features, label) pairs, where features is a dictionary: {'feature_name': value}.

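# train_dataset_fp (path to the training CSV), column_names and label_name
# are defined elsewhere (not shown in this post)
train_dataset = tf.contrib.data.make_csv_dataset(
    train_dataset_fp,
    batch_size,
    column_names=column_names,
    label_name=label_name,
    num_epochs=1)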
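To see the (features, label) structure without downloading the iris file, here is a sketch that assumes TensorFlow 2.x and writes a tiny hypothetical CSV (four made-up rows in the iris column layout) to a temporary directory. It calls tf.data.experimental.make_csv_dataset; column names are read from the header row here, and shuffle=False keeps the row order fixed.

```python
import os
import tempfile

import tensorflow as tf

# Four made-up rows in the iris column layout; species is an integer class id
csv_text = (
    "sepal_length,sepal_width,petal_length,petal_width,species\n"
    "5.1,3.5,1.4,0.2,0\n"
    "7.0,3.2,4.7,1.4,1\n"
    "6.3,3.3,6.0,2.5,2\n"
    "4.9,3.0,1.4,0.2,0\n"
)
csv_path = os.path.join(tempfile.mkdtemp(), "toy_iris.csv")
with open(csv_path, "w") as f:
    f.write(csv_text)

toy_dataset = tf.data.experimental.make_csv_dataset(
    csv_path,
    batch_size=2,
    label_name="species",
    num_epochs=1,   # the default (None) would repeat forever
    shuffle=False,  # the default (True) would shuffle the rows
)

# One element is a (features, label) pair; features is a dict of column tensors
features, labels = next(iter(toy_dataset))
feature_names = sorted(features.keys())
first_labels = labels.numpy().tolist()
first_sepal_lengths = features["sepal_length"].numpy().tolist()
num_batches = sum(1 for _ in toy_dataset)
```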

4, Iterating over the dataset

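next(iterator[, default]) returns the next item from the iterator. If default is given and the iterator is exhausted, default is returned; otherwise StopIteration is raised. Applied to iter(train_dataset), it pulls out the first batch:

features, labels = next(iter(train_dataset))
features.get("sepal_length")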
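A small sketch of both uses of next(), assuming TensorFlow 2.x: pulling a single batch from a tf.data pipeline, and the default argument on a plain Python iterator. The toy dataset and values are hypothetical.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5).batch(2)          # batches: [0 1], [2 3], [4]
first_batch = next(iter(ds)).numpy().tolist()   # reads only the first batch

it = iter([10, 20])
a = next(it)            # 10
b = next(it)            # 20
c = next(it, "done")    # exhausted: the default comes back instead of StopIteration
```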

5, How to stack the per-column feature values into whole rows

def pack_features_vector(features, labels):
  """Pack the features into a single array."""
  features = tf.stack(list(features.values()), axis=1)
  return features, labels
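What tf.stack(..., axis=1) does to a dict of columns can be checked on a hypothetical batch of two examples (TF 2.x assumed). The stacked column order follows the dict's insertion order.

```python
import tensorflow as tf

# A made-up batch of 2 examples, one tensor per column, like make_csv_dataset yields
features = {
    "sepal_length": tf.constant([5.1, 7.0]),
    "sepal_width": tf.constant([3.5, 3.2]),
    "petal_length": tf.constant([1.4, 4.7]),
    "petal_width": tf.constant([0.2, 1.4]),
}

# axis=1: columns are placed side by side, so each row is one example
packed = tf.stack(list(features.values()), axis=1)
first_row = packed[0].numpy().tolist()

# axis=0 would instead keep each column as a row
by_column = tf.stack(list(features.values()), axis=0)
```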

6, Applying the packing function to the dataset

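train_dataset = train_dataset.map(pack_features_vector)

# each element is now (features, labels), with features of shape (batch_size, 4)
for a, b in train_dataset:
    print(a, b)

# keep one packed batch in features/labels for the model checks below
features, labels = next(iter(train_dataset))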

7, Understanding tf.data.Dataset.map()

7.1, Definition

map(
    map_func,
    num_parallel_calls=None
)
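num_parallel_calls controls how many elements map_func processes concurrently. A sketch assuming a recent TF 2.x, where tf.data.AUTOTUNE lets the runtime pick the parallelism; by default the output order is still the input order.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)

# num_parallel_calls=None (default): elements are mapped one at a time
sequential = ds.map(lambda x: x * x)

# tf.data.AUTOTUNE: the runtime chooses how many elements to map in parallel
parallel = ds.map(lambda x: x * x, num_parallel_calls=tf.data.AUTOTUNE)

sequential_values = [int(v) for v in sequential.as_numpy_iterator()]
parallel_values = [int(v) for v in parallel.as_numpy_iterator()]
```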


7.2, Example

# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { 1, 2, 3, 4, 5 }

a.map(lambda x: x + 1) = { 2, 3, 4, 5, 6 }
7.3, Explanation

This transformation applies map_func to each element of this dataset, and returns a new dataset containing the transformed elements, in the same order as they appeared in the input.

import tensorflow as tf

def fun(x):
    return x + 1

ds = tf.data.Dataset.range(5)
ds = ds.map(fun)
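Reading the mapped dataset back makes the `{ 0..4 } -> { 1..5 }` behavior concrete. The second half shows why pack_features_vector takes two arguments: when each element is a tuple, map unpacks it into separate positional arguments. A sketch assuming TF 2.x; the pair values are made up.

```python
import tensorflow as tf

def fun(x):
    return x + 1

ds = tf.data.Dataset.range(5).map(fun)
plus_one = [int(v) for v in ds.as_numpy_iterator()]

# Tuple elements are unpacked into the lambda's two parameters
pairs = tf.data.Dataset.from_tensor_slices(([1, 2, 3], [10, 20, 30]))
swapped = pairs.map(lambda a, b: (b, a))
swapped_values = [(int(first), int(second)) for first, second in swapped.as_numpy_iterator()]
```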

8, Choosing a model

model = tf.keras.Sequential([
  tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(4,)),  # input shape required
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(3)
])

predictions = model(features)

predictions[:5]  # raw logits for the first five examples

tf.nn.softmax(predictions[:5])  # logits converted to per-class probabilities

print("Prediction: {}".format(tf.argmax(predictions, axis=1)))

print("    Labels: {}".format(labels))
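The last Dense(3) layer outputs logits; softmax turns each row into probabilities that sum to 1, and argmax picks the same class either way because softmax preserves order. A sketch assuming TF 2.x, with a random hypothetical batch instead of iris data; tf.keras.Input is used here in place of the input_shape argument.

```python
import tensorflow as tf

tf.random.set_seed(0)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(3),  # one raw logit per class
])

features = tf.random.uniform((5, 4))  # made-up batch of 5 examples
logits = model(features)
probs = tf.nn.softmax(logits)         # each row now sums to 1

row_sums = tf.reduce_sum(probs, axis=1).numpy().tolist()
same_argmax = bool(tf.reduce_all(tf.argmax(logits, axis=1) == tf.argmax(probs, axis=1)))
```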

9, Training

9.1, Defining the loss function

def loss(model, x, y):
  y_ = model(x)
  return tf.losses.sparse_softmax_cross_entropy(labels=y, logits=y_)

l = loss(model, features, labels)
print("Loss test: {}".format(l))
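tf.losses.sparse_softmax_cross_entropy is a TF 1.x function; in TF 2.x the usual counterpart is tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True). Both compute the mean of -log(softmax(logits)[true class]) over integer labels, which the sketch below checks by hand (TF 2.x assumed, logits and labels made up).

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1],
                      [0.5, 2.5, 0.3]])
labels = tf.constant([0, 1])  # integer class ids, not one-hot

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
keras_loss = float(loss_object(labels, logits))  # averaged over the batch

# Low-level op: per-example losses, averaged here
nn_loss = float(tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)))

# By hand: pick log-probability of the true class in each row
log_probs = tf.nn.log_softmax(logits)
manual_loss = float(-tf.reduce_mean(tf.gather(log_probs, labels, batch_dims=1)))
```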

9.2, Computing the gradients

def grad(model, inputs, targets):
  with tf.GradientTape() as tape:
    loss_value = loss(model, inputs, targets)
  return loss_value, tape.gradient(loss_value, model.trainable_variables)
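tf.GradientTape records operations on variables inside the with block, and tape.gradient differentiates the result with respect to one variable or a list of them, returning gradients in the same order, just as grad() does with model.trainable_variables. A toy sketch (TF 2.x assumed); a non-persistent tape can only be asked for gradients once, so the second computation uses a new tape.

```python
import tensorflow as tf

w = tf.Variable(3.0)
with tf.GradientTape() as tape:
    loss_value = w * w              # recorded on the tape
dw = tape.gradient(loss_value, w)   # d(w^2)/dw = 2w = 6.0

a = tf.Variable(2.0)
b = tf.Variable(-1.0)
with tf.GradientTape() as tape2:
    y = 3.0 * a + b * b
da, db = tape2.gradient(y, [a, b])  # dy/da = 3.0, dy/db = 2b = -2.0
```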

9.3, The optimizer

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

global_step = tf.Variable(0)

loss_value, grads = grad(model, features, labels)

print("Step: {}, Initial Loss: {}".format(global_step.numpy(),
                                          loss_value.numpy()))
optimizer.apply_gradients(zip(grads, model.trainable_variables), global_step)
print("Step: {},         Loss: {}".format(global_step.numpy(),
                                          loss(model, features, labels).numpy()))
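tf.train.GradientDescentOptimizer only exists in TF 1.x (TF 2.x has tf.keras.optimizers.SGD). The update it applies is plain gradient descent, w <- w - learning_rate * gradient, and passing global_step to apply_gradients also increments that counter by one. A sketch of one such step written out by hand, assuming TF 2.x and a hypothetical toy loss w²:

```python
import tensorflow as tf

learning_rate = 0.01
w = tf.Variable(1.0)
global_step = tf.Variable(0)

with tf.GradientTape() as tape:
    loss_before = w * w
g = tape.gradient(loss_before, w)   # 2 * w = 2.0

w.assign_sub(learning_rate * g)     # w <- w - lr * g = 0.98
global_step.assign_add(1)           # what apply_gradients(..., global_step) does

loss_after = float(w * w)
```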

10, Iterative training

## Note: Rerunning this cell uses the same model variables

from tensorflow import contrib
tfe = contrib.eager
# keep results for plotting
train_loss_results = []
train_accuracy_results = []
num_epochs = 201
for epoch in range(num_epochs):
  epoch_loss_avg = tfe.metrics.Mean()
  epoch_accuracy = tfe.metrics.Accuracy()
  # Training loop - using batches of 32
  for x, y in train_dataset:
    # Optimize the model
    loss_value, grads = grad(model, x, y)
    optimizer.apply_gradients(zip(grads, model.trainable_variables),
                              global_step)
    # Track progress
    epoch_loss_avg(loss_value)  # add current batch loss
    # compare predicted label to actual label
    epoch_accuracy(tf.argmax(model(x), axis=1, output_type=tf.int32), y)
  # end epoch
  train_loss_results.append(epoch_loss_avg.result())
  train_accuracy_results.append(epoch_accuracy.result())
  if epoch % 50 == 0:
    print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(epoch,
                                                                epoch_loss_avg.result(),
                                                                epoch_accuracy.result()))
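tf.contrib.eager (tfe) does not exist in TF 2.x. The closest counterparts are tf.keras.metrics.Mean and tf.keras.metrics.SparseCategoricalAccuracy (which takes logits and does the argmax itself), with tf.keras.optimizers.SGD as the optimizer. The sketch below assumes TF 2.x and, to run offline, trains on hypothetical synthetic data (three Gaussian blobs in 4-D) rather than the iris file.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
rng = np.random.default_rng(0)

# Made-up data: 150 points around three well-separated 4-D centers
centers = np.array([[0, 0, 0, 0], [3, 3, 3, 3], [-3, -3, -3, -3]], dtype=np.float32)
y_all = rng.integers(0, 3, size=150).astype(np.int32)
x_all = centers[y_all] + rng.normal(scale=0.5, size=(150, 4)).astype(np.float32)
train_ds = tf.data.Dataset.from_tensor_slices((x_all, y_all)).shuffle(150, seed=0).batch(32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(3),
])
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

train_loss_results = []
train_accuracy_results = []
for epoch in range(50):
    epoch_loss_avg = tf.keras.metrics.Mean()
    epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    for x, y in train_ds:
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            loss_value = loss_object(y, logits)
        grads = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        epoch_loss_avg.update_state(loss_value)   # running mean of batch losses
        epoch_accuracy.update_state(y, logits)    # argmax(logits) compared with y
    train_loss_results.append(float(epoch_loss_avg.result()))
    train_accuracy_results.append(float(epoch_accuracy.result()))
```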

11, Evaluation

import os

test_url = "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv"

test_fp = tf.keras.utils.get_file(fname=os.path.basename(test_url),
                                  origin=test_url)

test_dataset = tf.contrib.data.make_csv_dataset(
    test_fp,
    batch_size,
    column_names=column_names,
    label_name='species',
    num_epochs=1,
    shuffle=False)
test_dataset = test_dataset.map(pack_features_vector)

test_accuracy = tfe.metrics.Accuracy()

for (x, y) in test_dataset:
  logits = model(x)
  prediction = tf.argmax(logits, axis=1, output_type=tf.int32)
  test_accuracy(prediction, y)
print("Test set accuracy: {:.3%}".format(test_accuracy.result()))

# labels and predictions of the last test batch, side by side
tf.stack([y, prediction], axis=1)
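tfe.metrics.Accuracy corresponds to tf.keras.metrics.Accuracy in TF 2.x: it compares already-computed class ids with the labels. A sketch with made-up logits for four examples (TF 2.x assumed), where three of four predictions match:

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.1, 0.3],
                      [0.2, 0.1, 1.5],
                      [0.3, 2.2, 0.1],
                      [1.0, 0.9, 0.8]])
y = tf.constant([0, 2, 2, 0])
prediction = tf.argmax(logits, axis=1, output_type=tf.int32)  # [0, 2, 1, 0]

test_accuracy = tf.keras.metrics.Accuracy()
test_accuracy.update_state(y, prediction)
acc = float(test_accuracy.result())  # 3 of 4 correct

side_by_side = tf.stack([y, prediction], axis=1)  # one (label, prediction) row per example
```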

12, Prediction

# class_names (label id -> species name) is defined elsewhere, not shown in this post
predict_dataset = tf.convert_to_tensor([
    [5.1, 3.3, 1.7, 0.5,],
    [5.9, 3.0, 4.2, 1.5,],
    [6.9, 3.1, 5.4, 2.1]
])
predictions = model(predict_dataset)
for i, logits in enumerate(predictions):
  class_idx = tf.argmax(logits).numpy()
  p = tf.nn.softmax(logits)[class_idx]
  name = class_names[class_idx]
  print("Example {} prediction: {} ({:4.1f}%)".format(i, name, 100*p))
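The prediction loop can be run standalone with stand-in values. This sketch assumes TF 2.x and uses hypothetical class names and hand-written logits in place of class_names and model(predict_dataset); for each row it takes the argmax class and reports that class's softmax probability.

```python
import tensorflow as tf

# Hypothetical stand-ins for class_names and model(predict_dataset)
class_names = ["Iris setosa", "Iris versicolor", "Iris virginica"]
predictions = tf.constant([[3.0, 1.0, 0.0],
                           [0.0, 2.0, 1.0]])

results = []
for i, logits in enumerate(predictions):
    class_idx = int(tf.argmax(logits))
    p = float(tf.nn.softmax(logits)[class_idx])
    name = class_names[class_idx]
    results.append((name, round(100 * p, 1)))
    print("Example {} prediction: {} ({:4.1f}%)".format(i, name, 100 * p))
```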


Reposted from: https://www.cnblogs.com/augustone/p/10511400.html
