import tensorflow as tf

# Build two constant tensors and add them; this only defines graph nodes.
a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
result = a + b
print(result)
# Tensor("add:0", shape=(2,), dtype=float32)
Computation graph (Graph)
Describes the computation process of the neural network; it only builds the graph and does not execute any computation.
Session
Executes the node operations defined in the computation graph.
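To make the distinction concrete, here is a minimal sketch (assuming TensorFlow 1.x): printing result from the example above only describes the graph node, while running it inside a Session actually executes the addition.

import tensorflow as tf

a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
result = a + b               # graph node only; nothing is computed yet

with tf.Session() as sess:
    print(sess.run(result))  # executes the node: prints [4. 6.]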
import tensorflow as tf

# Matrix-multiply a 1x2 input by a 2x1 weight matrix; again only the graph is built.
x = tf.constant([[1.0, 2.0]])
w = tf.constant([[3.0], [4.0]])
y = tf.matmul(x, w)
print(y)
# Tensor("MatMul:0", shape=(1, 1), dtype=float32)
# Define forward propagation
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

# Compute the result in a session
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    print("y is:", sess.run(y, feed_dict={x: [[0.7, 0.5]]}))
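The snippet above assumes the input x and the weight variables w1 and w2 are already defined. A minimal sketch of those definitions, with shapes inferred from the two-feature input and the parameter values printed below; the stddev and seed values are illustrative assumptions:

x = tf.placeholder(tf.float32, shape=(None, 2))                # one sample = 2 features
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))   # 2x3 hidden-layer weights
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))   # 3x1 output-layer weights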
# 3 Create a session and train for STEPS rounds
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # Print the parameter values before training
    print("w1:\n", sess.run(w1))
    print("w2:\n", sess.run(w2))
    print("\n")

    # Train the model
    STEPS = 3000
    for i in range(STEPS):
        start = (i * BATCH_SIZE) % 32
        end = start + BATCH_SIZE
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
        if i % 500 == 0:
            total_loss = sess.run(loss, feed_dict={x: X, y_: Y})
            print("After %d training step(s), loss on all data is %g" % (i, total_loss))

    # Print the parameter values after training
    print("w1:\n", sess.run(w1))
    print("w2:\n", sess.run(w2))
# w1:
# [[-0.81131822  1.48459876  0.06532937]
#  [-2.4427042   0.0992484   0.59122431]]
# w2:
# [[-0.81131822]
#  [ 1.48459876]
#  [ 0.06532937]]
#
# After 0 training step(s), loss on all data is 5.13118
# After 500 training step(s), loss on all data is 0.429111
# After 1000 training step(s), loss on all data is 0.409789
# After 1500 training step(s), loss on all data is 0.399923
# After 2000 training step(s), loss on all data is 0.394146
# After 2500 training step(s), loss on all data is 0.390597
#
# w1:
# [[-0.70006633  0.9136318   0.08953571]
#  [-2.3402493  -0.14641267  0.58823055]]
# w2:
# [[-0.06024267]
#  [ 0.91956186]
#  [-0.0682071 ]]
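The training loop also relies on names the excerpt does not define: a dataset X with labels Y, a BATCH_SIZE, a label placeholder y_, a loss, and a train_step. One plausible way to set them up, consistent with the 32-sample batching above; the constants, the synthetic labeling rule, and the learning rate are assumptions:

import numpy as np

BATCH_SIZE = 8
SEED = 23455

# Synthetic dataset: 32 samples with 2 features; label is 1 when x0 + x1 < 1 (assumed rule)
rdm = np.random.RandomState(SEED)
X = rdm.rand(32, 2)
Y = [[int(x0 + x1 < 1)] for (x0, x1) in X]

y_ = tf.placeholder(tf.float32, shape=(None, 1))   # ground-truth labels

# Mean-squared-error loss and a gradient-descent training step (assumed choices)
loss = tf.reduce_mean(tf.square(y - y_))
train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)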
Steps to build a neural network (a filled-in sketch follows the template below)
1 Preparation
import statements, constant definitions, dataset generation
2 Forward propagation: define the inputs, parameters, and outputs
x = y_ =
w1 = w2 =
a = y =
3 Backpropagation: define the loss function and the backpropagation (training) method
loss = train_step =
4 Create a session and train for STEPS rounds
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    STEPS =
    for i in range(STEPS):
        start =
        end =
        sess.run(train_step, feed_dict)
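Putting the four steps together, here is a minimal end-to-end sketch under the same assumptions as above (synthetic dataset, mean-squared-error loss, gradient descent; all constants are illustrative):

import tensorflow as tf
import numpy as np

# 1 Preparation: constants and a synthetic dataset
BATCH_SIZE = 8
SEED = 23455
rdm = np.random.RandomState(SEED)
X = rdm.rand(32, 2)
Y = [[int(x0 + x1 < 1)] for (x0, x1) in X]

# 2 Forward propagation: inputs, parameters, outputs
x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

# 3 Backpropagation: loss function and training method
loss = tf.reduce_mean(tf.square(y - y_))
train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)

# 4 Session: initialize variables and train for STEPS rounds
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    STEPS = 3000
    for i in range(STEPS):
        start = (i * BATCH_SIZE) % 32
        end = start + BATCH_SIZE
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
        if i % 500 == 0:
            total_loss = sess.run(loss, feed_dict={x: X, y_: Y})
            print("After %d training step(s), loss on all data is %g" % (i, total_loss))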