首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何使用tensorflow.js v3.13.0进行模型推理后保存在GPU上的数据?

如何使用tensorflow.js v3.13.0进行模型推理后保存在GPU上的数据?
EN

Stack Overflow用户
提问于 2022-01-13 22:09:58
回答 1查看 315关注 0票数 0

我刚刚意识到tfjsv3.13.0中有一个新特性(请参阅https://github.com/tensorflow/tfjs/pull/5953)。我试图使用新的dataToGPU()张量方法将模型输出保留在GPU上,因为发送数据回CPU的data()方法在我的用例中花费了太多的时间。但是,当我调用新方法并尝试将它创建的WebGLTexture绑定到我的WebGLRenderingContext时,我会得到以下错误。

代码语言:javascript
运行
复制
WebGL: INVALID_OPERATION: bindTexture: object does not belong to this context

我猜这是因为纹理是在一个与我想要绑定纹理的画布不一样的上下文上创建的。因此,为了解决这个问题,似乎还有另外一个特性,它为tfjs的HTMLCanvasElement后端声明提供了一个OffscreenCanvasOffscreenCanvas(参见https://github.com/tensorflow/tfjs/pull/5983)。但是,我并没有在代码中声明任何后端,所以我不确定如何使用这些特性。

有人能告诉我如何在运行模型时实例化和使用WebGL后端吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-15 02:01:21

有关如何注册基于后端的自定义webgl的示例,请参阅GitHub https://github.com/vladmandic/human/blob/main/src/tfjs/humangl.ts上的以下内容

这里添加了此代码的副本,以防上面的链接失败:

代码语言:javascript
运行
复制
/** TFJS custom backend registration */

import type { Human } from '../human';
import { log } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import * as image from '../image/image';
import * as models from '../models';
import type { AnyCanvas } from '../exports';
// import { env } from '../env';

export const config = {
  name: 'humangl',
  priority: 999,
  canvas: <null | AnyCanvas>null,
  gl: <null | WebGL2RenderingContext>null,
  extensions: <string[]> [],
  webGLattr: { // https://www.khronos.org/registry/webgl/specs/latest/1.0/#5.2
    alpha: false,
    antialias: false,
    premultipliedAlpha: false,
    preserveDrawingBuffer: false,
    depth: false,
    stencil: false,
    failIfMajorPerformanceCaveat: false,
    desynchronized: true,
  },
};

function extensions(): void {
  /*
  https://www.khronos.org/registry/webgl/extensions/
  https://webglreport.com/?v=2
  */
  const gl = config.gl;
  if (!gl) return;
  config.extensions = gl.getSupportedExtensions() as string[];
  // gl.getExtension('KHR_parallel_shader_compile');
}

/**
 * Registers custom WebGL2 backend to be used by Human library
 *
 * @returns void
 */
export async function register(instance: Human): Promise<void> {
  // force backend reload if gl context is not valid
  if (instance.config.backend !== 'humangl') return;
  if ((config.name in tf.engine().registry) && (!config.gl || !config.gl.getParameter(config.gl.VERSION))) {
    log('error: humangl backend invalid context');
    models.reset(instance);
    /*
    log('resetting humangl backend');
    await tf.removeBackend(config.name);
    await register(instance); // re-register
    */
  }
  if (!tf.findBackend(config.name)) {
    try {
      config.canvas = await image.canvas(100, 100);
    } catch (err) {
      log('error: cannot create canvas:', err);
      return;
    }
    try {
      config.gl = config.canvas?.getContext('webgl2', config.webGLattr) as WebGL2RenderingContext;
      const glv2 = config.gl.getParameter(config.gl.VERSION).includes('2.0');
      if (!glv2) {
        log('override: using fallback webgl backend as webgl 2.0 is not detected');
        instance.config.backend = 'webgl';
        return;
      }
      if (config.canvas) {
        config.canvas.addEventListener('webglcontextlost', async (e) => {
          log('error: humangl:', e.type);
          log('possible browser memory leak using webgl or conflict with multiple backend registrations');
          instance.emit('error');
          throw new Error('backend error: webgl context lost');
          // log('resetting humangl backend');
          // env.initial = true;
          // models.reset(instance);
          // await tf.removeBackend(config.name);
          // await register(instance); // re-register
        });
        config.canvas.addEventListener('webglcontextrestored', (e) => {
          log('error: humangl context restored:', e);
        });
        config.canvas.addEventListener('webglcontextcreationerror', (e) => {
          log('error: humangl context create:', e);
        });
      }
    } catch (err) {
      log('error: cannot get WebGL context:', err);
      return;
    }
    try {
      tf.setWebGLContext(2, config.gl);
    } catch (err) {
      log('error: cannot set WebGL context:', err);
      return;
    }
    try {
      const ctx = new tf.GPGPUContext(config.gl);
      tf.registerBackend(config.name, () => new tf.MathBackendWebGL(ctx), config.priority);
    } catch (err) {
      log('error: cannot register WebGL backend:', err);
      return;
    }
    try {
      const kernels = tf.getKernelsForBackend('webgl');
      kernels.forEach((kernelConfig) => {
        const newKernelConfig = { ...kernelConfig, backendName: config.name };
        tf.registerKernel(newKernelConfig);
      });
    } catch (err) {
      log('error: cannot update WebGL backend registration:', err);
      return;
    }
    const current = tf.backend().getGPGPUContext ? tf.backend().getGPGPUContext().gl : null;
    if (current) {
      log(`humangl webgl version:${current.getParameter(current.VERSION)} renderer:${current.getParameter(current.RENDERER)}`);
    } else {
      log('error: no current gl context:', current, config.gl);
      return;
    }
    try {
      tf.ENV.set('WEBGL_VERSION', 2);
    } catch (err) {
      log('error: cannot set WebGL backend flags:', err);
      return;
    }
    extensions();
    log('backend registered:', config.name);
  }
}
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70703710

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档