程序猿python学习AIphaZero,TensorFlow强化学习AI游戏,100行代码运行看看!

打败世界冠军?AIphaGo Zero原理?

没错,本篇文章利用100行代码展示如何利用TensorFlow框架编写一个很简单的深度强化游戏AI核心部分,希望在本篇文章里,同学们能学到DQN网络原理。再也不用担心麻麻说我学机器学习搬砖啦!

Deep Q Network是DeepMind在2013年提出来的网络,是第一个成功地将深度学习和强化学习结合起来的模型,也是打败世界围棋冠军柯洁AIphaGO Zero核心原理,启发了后续一系列的工作。这些后续工作中比较有名的有Double DQN, Prioritized Replay 和 Dueling Network。.

游戏操作:

按住鼠标左键左移小棒子,按住鼠标右键右移小棒子。每次用棒子接住小方块得一分。通过深度强化学习算法,让计算机自动完成游戏操作。

安装Python依赖库

lPipinstall pygame

lPip install numpy

核心代码展示

定义CNN卷积网络:

1.defconvolutional_neural_network(input_image):

2.weights = {'w_conv1':tf.Variable(tf.zeros([8, 8, 4, 32])),

3.'w_conv2':tf.Variable(tf.zeros([4, 4, 32, 64])),

4.'w_conv3':tf.Variable(tf.zeros([3, 3, 64, 64])),

5.'w_fc4':tf.Variable(tf.zeros([3456, 784])),

6.'w_out':tf.Variable(tf.zeros([784, output]))}

7.

8.biases = {'b_conv1':tf.Variable(tf.zeros([32])),

9.'b_conv2':tf.Variable(tf.zeros([64])),

10.'b_conv3':tf.Variable(tf.zeros([64])),

11.'b_fc4':tf.Variable(tf.zeros([784])),

12.'b_out':tf.Variable(tf.zeros([output]))}

13.

17.conv3_flat = tf.reshape(conv3, [-1, 3456])

19.

20.output_layer = tf.matmul(fc4, weights['w_out']) + biases['b_out']

21.returnoutput_layer

其中包含三个卷积层,一个全连接层,通过relu激活函数输出给下一层。

训练神经网络方法:

1.deftrain_neural_network(input_image):

2.predict_action = convolutional_neural_network(input_image)

3.

4.argmax = tf.placeholder("float", [None, output])

5.gt = tf.placeholder("float", [None])

6.

7.action = tf.reduce_sum(tf.mul(predict_action, argmax), reduction_indices = 1)

8.cost = tf.reduce_mean(tf.square(action - gt))

9.optimizer = tf.train.AdamOptimizer(1e-6).minimize(cost)

10.

11.game = Game()

12.D = deque()

13.

14._, image = game.step(MOVE_STAY)

15.#转换为灰度值

16.image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY)

17.#转换为二值

18.ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)

19.input_image_data = np.stack((image, image, image, image), axis = 2)

20.

21.with tf.Session() as sess:

22.sess.run(tf.initialize_all_variables())

23.

24.saver = tf.train.Saver()

25.

26.n = 0

27.epsilon = INITIAL_EPSILON

28.whileTrue:

29.action_t = predict_action.eval(feed_dict = )[0]

30.

31.argmax_t = np.zeros([output], dtype=np.int)

32.if(random.random()

33.maxIndex = random.randrange(output)

34.else:

35.maxIndex = np.argmax(action_t)

36.argmax_t[maxIndex] = 1

37.ifepsilon > FINAL_EPSILON:

38.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

39.

41.# if event.type == QUIT:

42.# pygame.quit()

43.# sys.exit()

44.reward, image = game.step(list(argmax_t))

45.

46.image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY)

47.ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)

48.image = np.reshape(image, (80, 100, 1))

49.input_image_data1 = np.append(image, input_image_data[:, :, 0:3], axis = 2)

50.

51.D.append((input_image_data, argmax_t, reward, input_image_data1))

52.

53.iflen(D) > REPLAY_MEMORY:

54.D.popleft()

55.

56.ifn > OBSERVE:

57.minibatch = random.sample(D, BATCH)

58.input_image_data_batch = [d[0]fordinminibatch]

59.argmax_batch = [d[1]fordinminibatch]

60.reward_batch = [d[2]fordinminibatch]

61.input_image_data1_batch = [d[3]fordinminibatch]

62.

63.gt_batch = []

64.

65.out_batch = predict_action.eval(feed_dict = )

66.

67.foriinrange(0, len(minibatch)):

68.gt_batch.append(reward_batch[i] + LEARNING_RATE * np.max(out_batch[i]))

69.

70.optimizer.run(feed_dict = )

71.

72.input_image_data = input_image_data1

73.n = n+1

74.

75.ifn % 10000 == 0:

76.saver.save(sess,'game.cpk', global_step = n)#保存模型

77.

78.print(n,"epsilon:", epsilon," ","action:", maxIndex," ","reward:", reward)

79.

80.

81.train_neural_network(input_image)

训练效果:

AI傻乎乎的自动尝试玩这款游戏,不断试错,玩的不亦乐乎。

项目总结:

本次项目展示100行python代码,实现了利用TensorFlow框架展示深度强化学习的效果。

本文来自企鹅号 - angtk昂钛客媒体

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏锦小年的博客

Nilearn学习笔记3-提取时间序列建立功能连接体

在nilearn库中,提供了两种从fmri数据中提取时间序列的方法,一种基于脑分区(Time-series from a brain parcellation ...

3505
来自专栏利炳根的专栏

学习笔记CB001:NLTK库、语料库、词概率、双连词、词典

聊天机器人知识主要是自然语言处理。包括语言分析和理解、语言生成、机器学习、人机对话、信息检索、信息传输与信息存储、文本分类、自动文摘、数学方法、语言资源、系统评...

29010
来自专栏从流域到海域

A Gentle Introduction to Autocorrelation and Partial Autocorrelation (译文)

A Gentle Introduction to Autocorrelation and Partial Autocorrelation 自相关和偏自相关的简单...

3026
来自专栏人工智能

使用10几行Python代码,快速建立视觉模型识别图像

视觉 进化的作用,让人类对图像的处理非常高效。 这里,我给你展示一张照片。 ? 如果我这样问你: 你能否分辨出图片中哪个是猫,哪个是狗? 你可能立即会觉得自己遭...

3669
来自专栏AI研习社

Github 项目推荐 | YOLOv3 的最小化 PyTorch 实现

https://github.com/eriklindernoren/PyTorch-YOLOv3

3962
来自专栏AI科技大本营的专栏

无人驾驶汽车系统入门:基于深度学习的实时激光雷达点云目标检测及ROS实现

近年来,随着深度学习在图像视觉领域的发展,一类基于单纯的深度学习模型的点云目标检测方法被提出和应用,本文将详细介绍其中一种模型——SqueezeSeg,并且使用...

2341
来自专栏企鹅号快讯

不用@微信官方了,Python20行自动戴帽!

这两天被朋友圈里@微信官方要求戴帽的消息刷屏了,会玩的都悄咪咪地用美图秀秀一类的app给自己头像p一顶然后可高兴地表示“哎呀好神奇hhhh”,呆萌的当然就一直等...

2307
来自专栏数据小魔方

ggplot2双坐标轴的解决方案

本来没有打算写这一篇的,因为在一幅图表中使用双坐标轴确实不是一个很好地习惯,无论是信息传递的效率还是数据表达的准确性而言。 但是最近有好几个小伙伴儿跟我咨询关于...

4419
来自专栏AI研习社

手把手教你如何用 OpenCV + Python 实现人脸识别

下午的时候,配好了 OpenCV 的 Python 环境,OpenCV 的 Python 环境搭建。于是迫不及待的想体验一下 opencv 的人脸识别,如下文。...

8037
来自专栏小詹同学

人脸检测(一)——基于单文档的应用台程序

Opencv自带训练好的人脸模型(人脸的人眼、口等器官类似),此文基于vs2013建立应用台单文档程序,具体建立过程不予详细叙述,主要记录利用的Opencv自带...

4055

扫码关注云+社区

领取腾讯云代金券