首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在Kaggle笔记本上安装"Tree Ensemble Layer“

如何在Kaggle笔记本上安装"Tree Ensemble Layer“
EN

Stack Overflow用户
提问于 2021-01-29 23:02:38
回答 1查看 134关注 0票数 0

我想在Kaggle Notebook上尝试以下代码,但我找不到安装tf_trees的方法。

代码语言:javascript
运行
复制
from tensorflow import keras
from tf_trees import TEL

tree_layer = TEL(output_logits_dim=2, trees_num=10, depth=3)

model = keras.Sequential()
model.add(keras.layers.BatchNormalization())
model.add(tree_layer)

似乎无法使用!pip安装来安装tf_trees

如果有人能提出一个解决方案,我将不胜感激。谢谢。

资源:https://github.com/google-research/google-research/tree/master/tf_trees

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-02-10 06:18:06

首先打开互联网支持,然后从github克隆google-research存储库:

代码语言:javascript
运行
复制
!git clone https://github.com/google-research/google-research.git

然后我们需要g++的编译和链接选项,因此运行以下代码片段:

代码语言:javascript
运行
复制
import tensorflow as tf; 
print(" ".join(tf.sysconfig.get_compile_flags()))

代码语言:javascript
运行
复制
import tensorflow as tf; 
print(" ".join(tf.sysconfig.get_link_flags()))

对于我的笔记本,我得到了以下标志:

代码语言:javascript
运行
复制
-I/opt/conda/lib/python3.7/site-packages/tensorflow/include -D_GLIBCXX_USE_CXX11_ABI=0
-L/opt/conda/lib/python3.7/site-packages/tensorflow -l:libtensorflow_framework.so.2

在此之后,只需用上面的输出替换变量${TF_CFLAGS[@]}${TF_LFLAGS[@]}

代码语言:javascript
运行
复制
!g++ -std=c++11 -shared google-research/tf_trees/neural_trees_ops.cc google-research/tf_trees/neural_trees_kernels.cc google-research/tf_trees/neural_trees_helpers.cc -o google-research/tf_trees/neural_trees_ops.so -fPIC -I/opt/conda/lib/python3.7/site-packages/tensorflow/include -D_GLIBCXX_USE_CXX11_ABI=0 -L/opt/conda/lib/python3.7/site-packages/tensorflow -l:libtensorflow_framework.so.2 -O2

最后,我们需要添加系统路径

代码语言:javascript
运行
复制
import sys
sys.path.insert(1, '/kaggle/working/google-research')

并运行您的代码片段

代码语言:javascript
运行
复制
from tensorflow import keras
from tf_trees import TEL

tree_layer = TEL(output_logits_dim=2, trees_num=10, depth=3)

model = keras.Sequential()
model.add(keras.layers.BatchNormalization())
model.add(tree_layer)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65956893

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档