前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >JAX 中文文档(十四)

JAX 中文文档(十四)

作者头像
ApacheCN_飞龙
发布2024-06-22 08:45:44
1520
发布2024-06-22 08:45:44
举报
文章被收录于专栏:信数据得永生

原文:jax.readthedocs.io/en/latest/

jax.scipy 模块

原文:jax.readthedocs.io/en/latest/jax.scipy.html

jax.scipy.cluster

| vq(obs, code_book[, check_finite]) | 将观测值分配给代码簿中的代码。 | ## jax.scipy.fft

dct(x[, type, n, axis, norm])

计算输入的离散余弦变换

dctn(x[, type, s, axes, norm])

计算输入的多维离散余弦变换

idct(x[, type, n, axis, norm])

计算输入的离散余弦变换的逆变换

| idctn(x[, type, s, axes, norm]) | 计算输入的多维离散余弦变换的逆变换 | ## jax.scipy.integrate

| trapezoid(y[, x, dx, axis]) | 使用复合梯形法则沿指定轴积分。 | ## jax.scipy.interpolate

| RegularGridInterpolator(points, values[, …]) | 对正规矩形网格上的点进行插值。 | ## jax.scipy.linalg

block_diag(*arrs)

从输入数组创建块对角矩阵。

cho_factor(a[, lower, overwrite_a, check_finite])

基于 Cholesky 的线性求解因式分解

cho_solve(c_and_lower, b[, overwrite_b, …])

使用 Cholesky 分解解线性系统

cholesky(a[, lower, overwrite_a, check_finite])

计算矩阵的 Cholesky 分解。

det(a[, overwrite_a, check_finite])

计算矩阵的行列式

eigh()

计算 Hermitian 矩阵的特征值和特征向量

eigh_tridiagonal(d, e, *[, eigvals_only, …])

解对称实三对角矩阵的特征值问题

expm(A, *[, upper_triangular, max_squarings])

计算矩阵指数

expm_frechet()

计算矩阵指数的 Frechet 导数

funm(A, func[, disp])

评估矩阵值函数

hessenberg()

计算矩阵的 Hessenberg 形式

hilbert(n)

创建阶数为 n 的 Hilbert 矩阵。

inv(a[, overwrite_a, check_finite])

返回方阵的逆矩阵

lu()

计算 LU 分解

lu_factor(a[, overwrite_a, check_finite])

基于 LU 的线性求解因式分解

lu_solve(lu_and_piv, b[, trans, …])

使用 LU 分解解线性系统

polar(a[, side, method, eps, max_iterations])

计算极分解

qr()

计算数组的 QR 分解

rsf2csf(T, Z[, check_finite])

将实数舒尔形式转换为复数舒尔形式。

schur(a[, output])

计算舒尔分解

solve(a, b[, lower, overwrite_a, …])

解线性方程组

solve_triangular(a, b[, trans, lower, …])

解上(或下)三角线性方程组

sqrtm(A[, blocksize])

计算矩阵的平方根

svd()

计算奇异值分解

| toeplitz(c[, r]) | 构造 Toeplitz 矩阵 | ## jax.scipy.ndimage

| map_coordinates(input, coordinates, order[, …]) | 使用插值将输入数组映射到新坐标。 | ## jax.scipy.optimize

minimize(fun, x0[, args, tol, options])

最小化一个或多个变量的标量函数。

| OptimizeResults(x, success, status, fun, …) | 优化结果对象。 | ## jax.scipy.signal

fftconvolve(in1, in2[, mode, axes])

使用快速傅里叶变换(FFT)卷积两个 N 维数组。

convolve(in1, in2[, mode, method, precision])

两个 N 维数组的卷积。

convolve2d(in1, in2[, mode, boundary, …])

两个二维数组的卷积。

correlate(in1, in2[, mode, method, precision])

两个 N 维数组的互相关。

correlate2d(in1, in2[, mode, boundary, …])

两个二维数组的互相关。

csd(x, y[, fs, window, nperseg, noverlap, …])

使用 Welch 方法估计交叉功率谱密度(CSD)。

detrend(data[, axis, type, bp, overwrite_data])

从数据中移除线性或分段线性趋势。

istft(Zxx[, fs, window, nperseg, noverlap, …])

执行逆短时傅里叶变换(ISTFT)。

stft(x[, fs, window, nperseg, noverlap, …])

计算短时傅里叶变换(STFT)。

| welch(x[, fs, window, nperseg, noverlap, …]) | 使用 Welch 方法估计功率谱密度(PSD)。 | ## jax.scipy.spatial.transform

Rotation(quat)

三维旋转。

| Slerp(times, timedelta, rotations, rotvecs) | 球面线性插值旋转。 | ## jax.scipy.sparse.linalg

bicgstab(A, b[, x0, tol, atol, maxiter, M])

使用双共轭梯度稳定迭代解决 Ax = b。

cg(A, b[, x0, tol, atol, maxiter, M])

使用共轭梯度法解决 Ax = b。

| gmres(A, b[, x0, tol, atol, restart, …]) | GMRES 解决线性系统 A x = b,给定 A 和 b。 | ## jax.scipy.special

bernoulli(n)

生成前 N 个伯努利数。

beta()

贝塔函数

betainc(a, b, x)

正则化的不完全贝塔函数。

betaln(a, b)

贝塔函数绝对值的自然对数

digamma(x)

Digamma 函数

entr(x)

熵函数

erf(x)

误差函数

erfc(x)

误差函数的补函数

erfinv(x)

误差函数的反函数

exp1(x)

指数积分函数。

expi

指数积分函数。

expit(x)

逻辑 sigmoid(expit)函数

expn

广义指数积分函数。

factorial(n[, exact])

阶乘函数

gamma(x)

伽马函数。

gammainc(a, x)

正则化的下不完全伽马函数。

gammaincc(a, x)

正则化的上不完全伽马函数。

gammaln(x)

伽马函数绝对值的自然对数。

gammasgn(x)

伽马函数的符号。

hyp1f1

1F1 超几何函数。

i0(x)

修改贝塞尔函数零阶。

i0e(x)

指数缩放的修改贝塞尔函数零阶。

i1(x)

修改贝塞尔函数一阶。

i1e(x)

指数缩放的修改贝塞尔函数一阶。

log_ndtr

对数正态分布函数。

logit

对数几率函数。

logsumexp()

对数-总和-指数归约。

lpmn(m, n, z)

第一类相关勒让德函数(ALFs)。

lpmn_values(m, n, z, is_normalized)

第一类相关勒让德函数(ALFs)。

multigammaln(a, d)

多变量伽马函数的自然对数。

ndtr(x)

正态分布函数。

ndtri§

正态分布函数的反函数。

poch

Pochhammer 符号。

polygamma(n, x)

多次伽马函数。

spence(x)

斯宾斯函数,也称实数域下的二元对数函数。

sph_harm(m, n, theta, phi[, n_max])

计算球谐函数。

xlog1py

计算 x*log(1 + y),当 x=0 时返回 0。

xlogy

计算 x*log(y),当 x=0 时返回 0。

zeta

赫维茨 ζ 函数。

kl_div(p, q)

库尔巴克-莱布勒散度。

| rel_entr(p, q) | 相对熵函数。 | ## jax.scipy.stats

mode(a[, axis, nan_policy, keepdims])

计算数组沿轴的众数(最常见的值)。

rankdata(a[, method, axis, nan_policy])

计算数组沿轴的排名。

sem(a[, axis, ddof, nan_policy, keepdims])

计算均值的标准误差。

jax.scipy.stats.bernoulli

logpmf(k, p[, loc])

伯努利对数概率质量函数。

pmf(k, p[, loc])

伯努利概率质量函数。

cdf(k, p)

伯努利累积分布函数。

| ppf(q, p) | 伯努利百分位点函数。 | ### jax.scipy.stats.beta

logpdf(x, a, b[, loc, scale])

Beta 对数概率分布函数。

pdf(x, a, b[, loc, scale])

Beta 概率分布函数。

cdf(x, a, b[, loc, scale])

Beta 累积分布函数。

logcdf(x, a, b[, loc, scale])

Beta 对数累积分布函数。

sf(x, a, b[, loc, scale])

Beta 分布生存函数。

| logsf(x, a, b[, loc, scale]) | Beta 分布对数生存函数。 | ### jax.scipy.stats.betabinom

logpmf(k, n, a, b[, loc])

Beta-二项式对数概率质量函数。

| pmf(k, n, a, b[, loc]) | Beta-二项式概率质量函数。 | ### jax.scipy.stats.binom

logpmf(k, n, p[, loc])

二项式对数概率质量函数。

| pmf(k, n, p[, loc]) | 二项式概率质量函数。 | ### jax.scipy.stats.cauchy

logpdf(x[, loc, scale])

柯西对数概率分布函数。

pdf(x[, loc, scale])

柯西概率分布函数。

cdf(x[, loc, scale])

柯西累积分布函数。

logcdf(x[, loc, scale])

柯西对数累积分布函数。

sf(x[, loc, scale])

柯西分布对数生存函数。

logsf(x[, loc, scale])

柯西对数生存函数。

isf(q[, loc, scale])

柯西分布逆生存函数。

| ppf(q[, loc, scale]) | 柯西分布分位点函数。 | ### jax.scipy.stats.chi2

logpdf(x, df[, loc, scale])

卡方分布对数概率分布函数。

pdf(x, df[, loc, scale])

卡方概率分布函数。

cdf(x, df[, loc, scale])

卡方累积分布函数。

logcdf(x, df[, loc, scale])

卡方对数累积分布函数。

sf(x, df[, loc, scale])

卡方生存函数。

| logsf(x, df[, loc, scale]) | 卡方对数生存函数。 | ### jax.scipy.stats.dirichlet

logpdf(x, alpha)

狄利克雷对数概率分布函数。

| pdf(x, alpha) | 狄利克雷概率分布函数。 | ### jax.scipy.stats.expon

logpdf(x[, loc, scale])

指数对数概率分布函数。

| pdf(x[, loc, scale]) | 指数概率分布函数。 | ### jax.scipy.stats.gamma

logpdf(x, a[, loc, scale])

伽玛对数概率分布函数。

pdf(x, a[, loc, scale])

伽玛概率分布函数。

cdf(x, a[, loc, scale])

伽玛累积分布函数。

logcdf(x, a[, loc, scale])

伽玛对数累积分布函数。

sf(x, a[, loc, scale])

伽玛生存函数。

| logsf(x, a[, loc, scale]) | 伽玛对数生存函数。 | ### jax.scipy.stats.gennorm

cdf(x, beta)

广义正态累积分布函数。

logpdf(x, beta)

广义正态对数概率分布函数。

| pdf(x, beta) | 广义正态概率分布函数。 | ### jax.scipy.stats.geom

logpmf(k, p[, loc])

几何对数概率质量函数。

| pmf(k, p[, loc]) | 几何概率质量函数。 | ### jax.scipy.stats.laplace

cdf(x[, loc, scale])

拉普拉斯累积分布函数。

logpdf(x[, loc, scale])

拉普拉斯对数概率分布函数。

| pdf(x[, loc, scale]) | 拉普拉斯概率分布函数。 | ### jax.scipy.stats.logistic

cdf(x[, loc, scale])

Logistic 累积分布函数。

isf(x[, loc, scale])

Logistic 分布逆生存函数。

logpdf(x[, loc, scale])

Logistic 对数概率分布函数。

pdf(x[, loc, scale])

Logistic 概率分布函数。

ppf(x[, loc, scale])

Logistic 分位点函数。

| sf(x[, loc, scale]) | Logistic 分布生存函数。 | ### jax.scipy.stats.multinomial

logpmf(x, n, p)

多项式对数概率质量函数。

| pmf(x, n, p) | 多项分布概率质量函数。 | ### jax.scipy.stats.multivariate_normal

logpdf(x, mean, cov[, allow_singular])

多元正态分布对数概率分布函数。

| pdf(x, mean, cov) | 多元正态分布概率分布函数。 | ### jax.scipy.stats.nbinom

logpmf(k, n, p[, loc])

负二项分布对数概率质量函数。

| pmf(k, n, p[, loc]) | 负二项分布概率质量函数。 | ### jax.scipy.stats.norm

logpdf(x[, loc, scale])

正态分布对数概率分布函数。

pdf(x[, loc, scale])

正态分布概率分布函数。

cdf(x[, loc, scale])

正态分布累积分布函数。

logcdf(x[, loc, scale])

正态分布对数累积分布函数。

ppf(q[, loc, scale])

正态分布百分点函数。

sf(x[, loc, scale])

正态分布生存函数。

logsf(x[, loc, scale])

正态分布对数生存函数。

| isf(q[, loc, scale]) | 正态分布逆生存函数。 | ### jax.scipy.stats.pareto

logpdf(x, b[, loc, scale])

帕累托对数概率分布函数。

| pdf(x, b[, loc, scale]) | 帕累托分布概率分布函数。 | ### jax.scipy.stats.poisson

logpmf(k, mu[, loc])

泊松分布对数概率质量函数。

pmf(k, mu[, loc])

泊松分布概率质量函数。

| cdf(k, mu[, loc]) | 泊松分布累积分布函数。 | ### jax.scipy.stats.t

logpdf(x, df[, loc, scale])

学生 t 分布对数概率分布函数。

| pdf(x, df[, loc, scale]) | 学生 t 分布概率分布函数。 | ### jax.scipy.stats.truncnorm

cdf(x, a, b[, loc, scale])

截断正态分布累积分布函数。

logcdf(x, a, b[, loc, scale])

截断正态分布对数累积分布函数。

logpdf(x, a, b[, loc, scale])

截断正态分布对数概率分布函数。

logsf(x, a, b[, loc, scale])

截断正态分布对数生存函数。

pdf(x, a, b[, loc, scale])

截断正态分布概率分布函数。

| sf(x, a, b[, loc, scale]) | 截断正态分布对数生存函数。 | ### jax.scipy.stats.uniform

logpdf(x[, loc, scale])

均匀分布对数概率分布函数。

pdf(x[, loc, scale])

均匀分布概率分布函数。

cdf(x[, loc, scale])

均匀分布累积分布函数。

ppf(q[, loc, scale])

均匀分布百分点函数。

jax.scipy.stats.gaussian_kde

gaussian_kde(dataset[, bw_method, weights])

高斯核密度估计器

gaussian_kde.evaluate(points)

对给定点评估高斯核密度估计器。

gaussian_kde.integrate_gaussian(mean, cov)

加权高斯积分分布。

gaussian_kde.integrate_box_1d(low, high)

在给定限制下积分分布。

gaussian_kde.integrate_kde(other)

集成两个高斯核密度估计分布的乘积。

gaussian_kde.resample(key[, shape])

从估计的概率密度函数中随机采样数据集

gaussian_kde.pdf(x)

概率密度函数

gaussian_kde.logpdf(x)

对数概率密度函数

jax.scipy.stats.vonmises

logpdf(x, kappa)

von Mises 对数概率分布函数。

| pdf(x, kappa) | von Mises 概率分布函数。 | ### jax.scipy.stats.wrapcauchy

logpdf(x, c)

Wrapped Cauchy 对数概率分布函数。

pdf(x, c)

Wrapped Cauchy 概率分布函数。

jax.scipy.stats.bernoulli.logpmf

原文:jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.logpmf.html

代码语言:javascript
复制
jax.scipy.stats.bernoulli.logpmf(k, p, loc=0)

伯努利对数概率质量函数。

scipy.stats.bernoulli 的 JAX 实现 logpmf

伯努利概率质量函数定义如下

[\begin{split}f(k) = \begin{cases} 1 - p, & k = 0 \ p, & k = 1 \ 0, & \mathrm{otherwise} \end{cases}\end{split}]

参数:

返回值:

logpmf 值的数组

返回类型:

Array

另请参阅

  • jax.scipy.stats.bernoulli.cdf()
  • jax.scipy.stats.bernoulli.pmf()
  • jax.scipy.stats.bernoulli.ppf()

jax.scipy.stats.bernoulli.pmf

原文:jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.pmf.html

代码语言:javascript
复制
jax.scipy.stats.bernoulli.pmf(k, p, loc=0)

伯努利概率质量函数。

scipy.stats.bernoulli pmf 的 JAX 实现

伯努利概率质量函数定义为

[\begin{split}f(k) = \begin{cases} 1 - p, & k = 0 \ p, & k = 1 \ 0, & \mathrm{otherwise} \end{cases}\end{split}]

参数:

返回:

pmf 值数组

返回类型:

数组

参见

  • jax.scipy.stats.bernoulli.cdf()
  • jax.scipy.stats.bernoulli.logpmf()
  • jax.scipy.stats.bernoulli.ppf()

jax.scipy.stats.bernoulli.cdf

原文:jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.cdf.html

代码语言:javascript
复制
jax.scipy.stats.bernoulli.cdf(k, p)

伯努利累积分布函数。

scipy.stats.bernoulli 的 JAX 实现 cdf

伯努利累积分布函数被定义为:

[f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p)]

其中 (f_{pmf}(k, p)) 是伯努利概率质量函数 jax.scipy.stats.bernoulli.pmf()

参数:

返回:

cdf 值的数组

返回类型:

Array

另请参见

  • jax.scipy.stats.bernoulli.logpmf()
  • jax.scipy.stats.bernoulli.pmf()
  • jax.scipy.stats.bernoulli.ppf()

jax.scipy.stats.bernoulli.ppf

原文:jax.readthedocs.io/en/latest/_autosummary/jax.scipy.stats.bernoulli.ppf.html

代码语言:javascript
复制
jax.scipy.stats.bernoulli.ppf(q, p)

伯努利百分点函数。

JAX 实现的 scipy.stats.bernoulli ppf

百分点函数是累积分布函数的反函数,jax.scipy.stats.bernoulli.cdf()

参数:

返回:

ppf 值数组

返回类型:

Array

另见

  • jax.scipy.stats.bernoulli.cdf()
  • jax.scipy.stats.bernoulli.logpmf()
  • jax.scipy.stats.bernoulli.pmf()

jax.lax 模块

原文:jax.readthedocs.io/en/latest/jax.lax.html

jax.lax 是支持诸如 jax.numpy 等库的基本操作的库。通常会定义转换规则,例如 JVP 和批处理规则,作为对 jax.lax 基元的转换。

许多基元都是等价于 XLA 操作的薄包装,详细描述请参阅XLA 操作语义文档。

在可能的情况下,优先使用诸如 jax.numpy 等库,而不是直接使用 jax.laxjax.numpy API 遵循 NumPy,因此比 jax.lax API 更稳定,更不易更改。

Operators

abs(x)

按元素绝对值:(|x|)。

acos(x)

按元素求反余弦:(\mathrm{acos}(x))。

acosh(x)

按元素求反双曲余弦:(\mathrm{acosh}(x))。

add(x, y)

按元素加法:(x + y)。

after_all(*operands)

合并一个或多个 XLA 令牌值。

approx_max_k(operand, k[, …])

以近似方式返回 operand 的最大 k 值及其索引。

approx_min_k(operand, k[, …])

以近似方式返回 operand 的最小 k 值及其索引。

argmax(operand, axis, index_dtype)

计算沿着 axis 的最大元素的索引。

argmin(operand, axis, index_dtype)

计算沿着 axis 的最小元素的索引。

asin(x)

按元素求反正弦:(\mathrm{asin}(x))。

asinh(x)

按元素求反双曲正弦:(\mathrm{asinh}(x))。

atan(x)

按元素求反正切:(\mathrm{atan}(x))。

atan2(x, y)

两个变量的按元素反正切:(\mathrm{atan}({x \over y}))。

atanh(x)

按元素求反双曲正切:(\mathrm{atanh}(x))。

batch_matmul(lhs, rhs[, precision])

批量矩阵乘法。

bessel_i0e(x)

指数缩放修正贝塞尔函数 (0) 阶:(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x))

bessel_i1e(x)

指数缩放修正贝塞尔函数 (1) 阶:(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x))

betainc(a, b, x)

按元素的正则化不完全贝塔积分。

bitcast_convert_type(operand, new_dtype)

按元素位转换。

bitwise_and(x, y)

按位与运算:(x \wedge y)。

bitwise_not(x)

按位取反:(\neg x)。

bitwise_or(x, y)

按位或运算:(x \vee y)。

bitwise_xor(x, y)

按位异或运算:(x \oplus y)。

population_count(x)

按元素计算 popcount,即每个元素中设置的位数。

broadcast(operand, sizes)

广播数组,添加新的前导维度。

broadcast_in_dim(operand, shape, …)

包装 XLA 的 BroadcastInDim 操作符。

broadcast_shapes()

返回经过 NumPy 广播后的形状。

broadcast_to_rank(x, rank)

添加 1 的前导维度,使 x 的等级为 rank。

broadcasted_iota(dtype, shape, dimension)

iota的便捷封装器。

cbrt(x)

元素级立方根:(\sqrt[3]{x})。

ceil(x)

元素级向上取整:(\left\lceil x \right\rceil)。

clamp(min, x, max)

元素级 clamp 函数。

clz(x)

元素级计算前导零的个数。

collapse(operand, start_dimension[, …])

将数组的维度折叠为单个维度。

complex(x, y)

元素级构造复数:(x + jy)。

concatenate(operands, dimension)

沿指定维度连接一系列数组。

conj(x)

元素级复数的共轭函数:(\overline{x})。

conv(lhs, rhs, window_strides, padding[, …])

conv_general_dilated的便捷封装器。

convert_element_type(operand, new_dtype)

元素级类型转换。

conv_dimension_numbers(lhs_shape, rhs_shape, …)

将卷积维度编号转换为 ConvDimensionNumbers。

conv_general_dilated(lhs, rhs, …[, …])

带有可选扩展的通用 n 维卷积运算符。

conv_general_dilated_local(lhs, rhs, …[, …])

带有可选扩展的通用 n 维非共享卷积运算符。

conv_general_dilated_patches(lhs, …[, …])

提取符合 conv_general_dilated 接受域的补丁。

conv_transpose(lhs, rhs, strides, padding[, …])

计算 N 维卷积的“转置”的便捷封装器。

conv_with_general_padding(lhs, rhs, …[, …])

conv_general_dilated的便捷封装器。

cos(x)

元素级余弦函数:(\mathrm{cos}(x))。

cosh(x)

元素级双曲余弦函数:(\mathrm{cosh}(x))。

cumlogsumexp(operand[, axis, reverse])

沿轴计算累积 logsumexp。

cummax(operand[, axis, reverse])

沿轴计算累积最大值。

cummin(operand[, axis, reverse])

沿轴计算累积最小值。

cumprod(operand[, axis, reverse])

沿轴计算累积乘积。

cumsum(operand[, axis, reverse])

沿轴计算累积和。

digamma(x)

元素级 digamma 函数:(\psi(x))。

div(x, y)

元素级除法:(x \over y)。

dot(lhs, rhs[, precision, …])

向量/向量,矩阵/向量和矩阵/矩阵乘法。

dot_general(lhs, rhs, dimension_numbers[, …])

通用的点积/收缩运算符。

dynamic_index_in_dim(operand, index[, axis, …])

对 dynamic_slice 的便捷封装,用于执行整数索引。

dynamic_slice(operand, start_indices, …)

封装了 XLA 的 DynamicSlice 操作符。

dynamic_slice_in_dim(operand, start_index, …)

方便地封装了应用于单个维度的 lax.dynamic_slice()。

dynamic_update_index_in_dim(operand, update, …)

方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新大小为 1 的切片。

dynamic_update_slice(operand, update, …)

封装了 XLA 的 DynamicUpdateSlice 操作符。

dynamic_update_slice_in_dim(operand, update, …)

方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新一个切片。

eq(x, y)

元素级相等:(x = y)。

erf(x)

元素级误差函数:(\mathrm{erf}(x))。

erfc(x)

元素级补充误差函数:(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x))。

erf_inv(x)

元素级反误差函数:(\mathrm{erf}^{-1}(x))。

exp(x)

元素级指数函数:(e^x)。

expand_dims(array, dimensions)

将任意数量的大小为 1 的维度插入到数组中。

expm1(x)

元素级运算 (e^{x} - 1)。

fft(x, fft_type, fft_lengths)

floor(x)

元素级向下取整:(\left\lfloor x \right\rfloor)。

full(shape, fill_value[, dtype, sharding])

返回填充值为 fill_value 的形状数组。

full_like(x, fill_value[, dtype, shape, …])

基于示例数组 x 创建类似于 np.full 的完整数组。

gather(operand, start_indices, …[, …])

Gather 操作符。

ge(x, y)

元素级大于或等于:(x \geq y)。

gt(x, y)

元素级大于:(x > y)。

igamma(a, x)

元素级正则化不完全 gamma 函数。

igammac(a, x)

元素级补充正则化不完全 gamma 函数。

imag(x)

提取复数的虚部:(\mathrm{Im}(x))。

index_in_dim(operand, index[, axis, keepdims])

方便地封装了 lax.slice(),用于执行整数索引。

index_take(src, idxs, axes)

integer_pow(x, y)

元素级幂运算:(x^y),其中 (y) 是固定整数。

iota(dtype, size)

封装了 XLA 的 Iota 操作符。

is_finite(x)

元素级 (\mathrm{isfinite})。

le(x, y)

元素级小于或等于:(x \leq y)。

lgamma(x)

元素级对数 gamma 函数:(\mathrm{log}(\Gamma(x)))。

log(x)

元素级自然对数:(\mathrm{log}(x))。

log1p(x)

元素级 (\mathrm{log}(1 + x))。

logistic(x)

元素级 logistic(sigmoid)函数:(\frac{1}{1 + e^{-x}})。

lt(x, y)

元素级小于:(x < y)。

max(x, y)

元素级最大值:(\mathrm{max}(x, y))

min(x, y)

元素级最小值:(\mathrm{min}(x, y))

mul(x, y)

元素级乘法:(x \times y)。

ne(x, y)

按位不等于:(x \neq y)。

neg(x)

按位取负:(-x)。

nextafter(x1, x2)

返回 x1 在 x2 方向上的下一个可表示的值。

pad(operand, padding_value, padding_config)

对数组应用低、高和/或内部填充。

polygamma(m, x)

按位多次 gamma 函数:(\psi^{(m)}(x))。

population_count(x)

按位人口统计,统计每个元素中设置的位数。

pow(x, y)

按位幂运算:(x^y)。

random_gamma_grad(a, x)

Gamma 分布导数的按位计算。

real(x)

按位提取实部:(\mathrm{Re}(x))。

reciprocal(x)

按位倒数:(1 \over x)。

reduce(operands, init_values, computation, …)

封装了 XLA 的 Reduce 运算符。

reduce_precision(operand, exponent_bits, …)

封装了 XLA 的 ReducePrecision 运算符。

reduce_window(operand, init_value, …[, …])

rem(x, y)

按位取余:(x \bmod y)。

reshape(operand, new_sizes[, dimensions])

封装了 XLA 的 Reshape 运算符。

rev(operand, dimensions)

封装了 XLA 的 Rev 运算符。

rng_bit_generator(key, shape[, dtype, algorithm])

无状态的伪随机数位生成器。

rng_uniform(a, b, shape)

有状态的伪随机数生成器。

round(x[, rounding_method])

按位四舍五入。

rsqrt(x)

按位倒数平方根:(1 \over \sqrt{x})。

scatter(operand, scatter_indices, updates, …)

Scatter-update 运算符。

scatter_add(operand, scatter_indices, …[, …])

Scatter-add 运算符。

scatter_apply(operand, scatter_indices, …)

Scatter-apply 运算符。

scatter_max(operand, scatter_indices, …[, …])

Scatter-max 运算符。

scatter_min(operand, scatter_indices, …[, …])

Scatter-min 运算符。

scatter_mul(operand, scatter_indices, …[, …])

Scatter-multiply 运算符。

shift_left(x, y)

按位左移:(x \ll y)。

shift_right_arithmetic(x, y)

按位算术右移:(x \gg y)。

shift_right_logical(x, y)

按位逻辑右移:(x \gg y)。

sign(x)

按位符号函数。

sin(x)

按位正弦函数:(\mathrm{sin}(x))。

sinh(x)

按位双曲正弦函数:(\mathrm{sinh}(x))。

slice(operand, start_indices, limit_indices)

封装了 XLA 的 Slice 运算符。

slice_in_dim(operand, start_index, limit_index)

lax.slice() 的单维度应用封装。

sort()

封装了 XLA 的 Sort 运算符。

sort_key_val(keys, values[, dimension, …])

沿着dimension排序keys并对values应用相同的置换。

sqrt(x)

逐元素平方根:(\sqrt{x})。

square(x)

逐元素平方:(x²)。

squeeze(array, dimensions)

从数组中挤出任意数量的大小为 1 的维度。

sub(x, y)

逐元素减法:(x - y)。

tan(x)

逐元素正切:(\mathrm{tan}(x))。

tanh(x)

逐元素双曲正切:(\mathrm{tanh}(x))。

top_k(operand, k)

返回operand最后一轴上的前k个值及其索引。

transpose(operand, permutation)

包装 XLA 的Transpose运算符。

zeros_like_array(x)

zeta(x, q)

逐元素 Hurwitz zeta 函数:(\zeta(x, q))

控制流操作符

associative_scan(fn, elems[, reverse, axis])

使用关联二元操作并行执行扫描。

cond(pred, true_fun, false_fun, *operands[, …])

根据条件应用true_fun或false_fun。

fori_loop(lower, upper, body_fun, init_val, *)

通过归约到jax.lax.while_loop()从lower到upper循环。

map(f, xs)

在主要数组轴上映射函数。

scan(f, init[, xs, length, reverse, unroll, …])

在主要数组轴上扫描函数并携带状态。

select(pred, on_true, on_false)

根据布尔谓词在两个分支之间选择。

select_n(which, *cases)

从多个情况中选择数组值。

switch(index, branches, *operands[, operand])

根据index应用恰好一个branches。

while_loop(cond_fun, body_fun, init_val)

在cond_fun为 True 时重复调用body_fun。

自定义梯度操作符

stop_gradient(x)

停止梯度计算。

custom_linear_solve(matvec, b, solve[, …])

使用隐式定义的梯度执行无矩阵线性求解。

custom_root(f, initial_guess, solve, …[, …])

可微分求解函数的根。

并行操作符

all_gather(x, axis_name, *[, …])

在所有副本中收集x的值。

all_to_all(x, axis_name, split_axis, …[, …])

映射轴的实例化和映射不同轴。

pdot(x, y, axis_name[, pos_contract, …])

psum(x, axis_name, *[, axis_index_groups])

在映射的轴axis_name上进行全归约求和。

psum_scatter(x, axis_name, *[, …])

像psum(x, axis_name),但每个设备仅保留部分结果。

pmax(x, axis_name, *[, axis_index_groups])

在映射的轴axis_name上计算全归约最大值。

pmin(x, axis_name, *[, axis_index_groups])

在映射的轴axis_name上计算全归约最小值。

pmean(x, axis_name, *[, axis_index_groups])

在映射的轴axis_name上计算全归约均值。

ppermute(x, axis_name, perm)

根据置换 perm 执行集体置换。

pshuffle(x, axis_name, perm)

使用替代置换编码的 jax.lax.ppermute 的便捷包装器

pswapaxes(x, axis_name, axis, *[, …])

将 pmapped 轴 axis_name 与非映射轴 axis 交换。

axis_index(axis_name)

返回沿映射轴 axis_name 的索引。

与分片相关的操作符

with_sharding_constraint(x, shardings)

在 jitted 计算中约束数组的分片机制

线性代数操作符 (jax.lax.linalg)

cholesky(x, *[, symmetrize_input])

Cholesky 分解。

eig(x, *[, compute_left_eigenvectors, …])

一般矩阵的特征分解。

eigh(x, *[, lower, symmetrize_input, …])

Hermite 矩阵的特征分解。

hessenberg(a)

将方阵约化为上 Hessenberg 形式。

lu(x)

带有部分主元列主元分解。

householder_product(a, taus)

单元 Householder 反射的乘积。

qdwh(x, *[, is_hermitian, max_iterations, …])

基于 QR 的动态加权 Halley 迭代进行极分解。

qr(x, *[, full_matrices])

QR 分解。

schur(x, *[, compute_schur_vectors, …])

svd()

奇异值分解。

triangular_solve(a, b, *[, left_side, …])

三角解法。

tridiagonal(a, *[, lower])

将对称/Hermitian 矩阵约化为三对角形式。

tridiagonal_solve(dl, d, du, b)

计算三对角线性系统的解。

参数类

代码语言:javascript
复制
class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)

描述卷积的批量、空间和特征维度。

参数:

  • lhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
  • rhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(输出特征维度,输入特征维度,空间维度…)。
  • out_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
代码语言:javascript
复制
jax.lax.ConvGeneralDilatedDimensionNumbers

alias of tuple[str, str, str] | ConvDimensionNumbers | None

代码语言:javascript
复制
class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)

描述了传递给 XLA 的 Gather 运算符 的维度号参数。有关维度号含义的详细信息,请参阅 XLA 文档。

Parameters:

  • offset_dims (tuple[int, …**]) – gather 输出中偏移到从操作数切片的数组中的维度的集合。必须是升序整数元组,每个代表输出的一个维度编号。
  • collapsed_slice_dims (tuple[int, …**]) – operand 中具有 slice_sizes[i] == 1 的维度 i 的集合,这些维度不应在 gather 输出中具有对应维度。必须是一个升序整数元组。
  • start_index_map (tuple[int, …**]) – 对于 start_indices 中的每个维度,给出应该被切片的操作数中对应的维度。必须是一个大小等于 start_indices.shape[-1] 的整数元组。

与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐含的;总是存在一个索引向量维度,且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。

代码语言:javascript
复制
class jax.lax.GatherScatterMode(value)

描述了如何处理 gather 或 scatter 中的越界索引。

可能的值包括:

CLIP:

索引将被夹在最近的范围内值上,即整个要收集的窗口都在范围内。

FILL_OR_DROP:

如果收集窗口的任何部分越界,则返回整个窗口,即使其他部分原本在界内的元素也将用常量填充。如果分散窗口的任何部分越界,则整个窗口将被丢弃。

PROMISE_IN_BOUNDS:

用户承诺索引在范围内。不会执行额外检查。实际上,根据当前的 XLA 实现,这意味着越界的 gather 将被夹在范围内,但越界的 scatter 将被丢弃。如果索引越界,则梯度将不正确。

代码语言:javascript
复制
class jax.lax.Precision(value)

lax 函数的精度枚举

JAX 函数的精度参数通常控制加速器后端(即 TPU 和 GPU)上的数组计算速度和精度之间的权衡。成员包括:

默认:

最快模式,但最不准确。在 bfloat16 中执行计算。别名:'default''fastest''bfloat16'

高:

较慢但更准确。以 3 个 bfloat16 传递执行 float32 计算,或在可用时使用 tensorfloat32。别名:'high''bfloat16_3x''tensorfloat32'

最高:

最慢但最准确。根据适用情况在 float32 或 float64 中执行计算。别名:'highest''float32'

代码语言:javascript
复制
jax.lax.PrecisionLike

别名为 str | Precision | tuple[str, str] | tuple[Precision, Precision] | None

代码语言:javascript
复制
class jax.lax.RoundingMethod(value)

一个枚举。

代码语言:javascript
复制
class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)

描述了对 XLA 的 Scatter 操作符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。

参数:

  • update_window_dims (Sequence[int]) – 更新中作为窗口维度的维度集合。必须是整数元组,按升序排列,每个表示一个维度编号。
  • inserted_window_dims (Sequence[int]) – 必须插入更新形状的大小为 1 的窗口维度集合。必须是整数元组,按升序排列,每个表示输出的维度编号的镜像图。这些是 gather 情况下 collapsed_slice_dims 的镜像图。
  • scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,大小等于 scatter_indices.shape[-1]。

与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是有一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,添加一个尺寸为 1 的尾随维度。

jax.random 模块

原文:jax.readthedocs.io/en/latest/jax.random.html

伪随机数生成的实用程序。

jax.random 包提供了多种例程,用于确定性生成伪随机数序列。

基本用法

代码语言:javascript
复制
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches)) 

PRNG keys

与 NumPy 和 SciPy 用户习惯的 有状态 伪随机数生成器(PRNGs)不同,JAX 随机函数都要求作为第一个参数传递一个显式的 PRNG 状态。随机状态由我们称之为 key 的特殊数组元素类型描述,通常由 jax.random.key() 函数生成:

代码语言:javascript
复制
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0] 

然后,可以在 JAX 的任何随机数生成例程中使用该 key:

代码语言:javascript
复制
>>> random.uniform(key)
Array(0.41845703, dtype=float32) 

请注意,使用 key 不会修改它,因此重复使用相同的 key 将导致相同的结果:

代码语言:javascript
复制
>>> random.uniform(key)
Array(0.41845703, dtype=float32) 

如果需要新的随机数,可以使用 jax.random.split() 生成新的子 key:

代码语言:javascript
复制
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32) 

注意

类型化的 key 数组,例如上述 key<fry>,在 JAX v0.4.16 中引入。在此之前,key 通常以 uint32 数组表示,其最终维度表示 key 的位级表示。

两种形式的 key 数组仍然可以通过 jax.random 模块创建和使用。新式的类型化 key 数组使用 jax.random.key() 创建。传统的 uint32 key 数组使用 jax.random.PRNGKey() 创建。

要在两者之间进行转换,使用 jax.random.key_data()jax.random.wrap_key_data()。当与 JAX 外部系统(例如将数组导出为可序列化格式)交互或将 key 传递给基于 JAX 的库时,可能需要传统的 key 格式。

否则,建议使用类型化的 key。传统 key 相对于类型化 key 的注意事项包括:

  • 它们有一个额外的尾维度。
  • 它们具有数字数据类型 (uint32),允许进行通常不用于 key 的操作,例如整数算术。
  • 它们不包含有关 RNG 实现的信息。当传统 key 传递给 jax.random 函数时,全局配置设置确定 RNG 实现(参见下文的“高级 RNG 配置”)。

要了解更多关于此升级以及 key 类型设计的信息,请参阅 JEP 9263

高级

设计和背景

TLDR:JAX PRNG = Threefry counter PRNG + 一个功能数组导向的 分裂模型

更多详细信息,请参阅 docs/jep/263-prng.md

总结一下,JAX PRNG 还包括但不限于以下要求:

  1. 确保可重现性,
  2. 良好的并行化,无论是向量化(生成数组值)还是多副本、多核计算。特别是它不应在随机函数调用之间使用顺序约束。
高级 RNG 配置

JAX 提供了几种 PRNG 实现。可以通过可选的 impl 关键字参数选择特定的实现。如果在密钥构造函数中没有传递 impl 选项,则实现由全局 jax_default_prng_impl 配置标志确定。

  • 默认,“threefry2x32”: 基于 Threefry 哈希函数构建的基于计数器的 PRNG
  • 实验性 一种仅包装了 XLA 随机位生成器(RBG)算法的 PRNG。请参阅 TF 文档
    • “rbg” 使用 ThreeFry 进行分割,并使用 XLA RBG 进行数据生成。
    • “unsafe_rbg” 仅用于演示目的,使用 RBG 进行分割(使用未经测试的虚构算法)和生成。

    这些实验性实现生成的随机流尚未经过任何经验随机性测试(例如 Big Crush)。生成的随机比特可能会在 JAX 的不同版本之间变化。

不使用默认 RNG 的可能原因是:

  1. 可能编译速度较慢(特别是对于 Google Cloud TPU)
  2. 在 TPU 上执行速度较慢
  3. 不支持高效的自动分片/分区

这里是一个简短的总结:

属性

Threefry

Threefry*

rbg

unsafe_rbg

rbg**

unsafe_rbg**

在 TPU 上最快

可以高效分片(使用 pjit)

在分片中相同

在 CPU/GPU/TPU 上相同

在 JAX/XLA 版本间相同

(*): 设置了jax_threefry_partitionable=1

(**): 设置了XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1

“rbg” 和 “unsafe_rbg” 之间的区别在于,“rbg” 用于生成随机值时使用了较不稳定/研究较少的哈希函数(但不用于 jax.random.split 或 jax.random.fold_in),而 “unsafe_rbg” 还额外在 jax.random.split 和 jax.random.fold_in 中使用了更不稳定的哈希函数。因此,在不同密钥生成的随机流质量方面不那么安全。

要了解有关 jax_threefry_partitionable 的更多信息,请参阅jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers

API 参考

密钥创建与操作

PRNGKey(seed, *[, impl])

给定整数种子创建伪随机数生成器(PRNG)密钥。

key(seed, *[, impl])

给定整数种子创建伪随机数生成器(PRNG)密钥。

key_data(密钥)

恢复 PRNG 密钥数组下的密钥数据位。

wrap_key_data(key_bits_array, *[, impl])

将密钥数据位数组包装成 PRNG 密钥数组。

fold_in(key, data)

将数据折叠到 PRNG 密钥中,形成新的 PRNG 密钥。

split(key[, num])

将 PRNG 密钥按添加一个前导轴拆分为 num 个新密钥。

clone(key)

克隆一个密钥以便重复使用。

随机抽样器

ball(key, d[, p, shape, dtype])

从单位 Lp 球中均匀采样。

bernoulli(key[, p, shape])

采样给定形状和均值的伯努利分布随机值。

beta(key, a, b[, shape, dtype])

采样给定形状和浮点数数据类型的贝塔分布随机值。

binomial(key, n, p[, shape, dtype])

采样给定形状和浮点数数据类型的二项分布随机值。

bits(key[, shape, dtype])

以无符号整数的形式采样均匀比特。

categorical(key, logits[, axis, shape])

从分类分布中采样随机值。

cauchy(key[, shape, dtype])

采样给定形状和浮点数数据类型的柯西分布随机值。

chisquare(key, df[, shape, dtype])

采样给定形状和浮点数数据类型的卡方分布随机值。

choice(key, a[, shape, replace, p, axis])

从给定数组中生成随机样本。

dirichlet(key, alpha[, shape, dtype])

采样给定形状和浮点数数据类型的狄利克雷分布随机值。

double_sided_maxwell(key, loc, scale[, …])

从双边 Maxwell 分布中采样。

exponential(key[, shape, dtype])

采样给定形状和浮点数数据类型的指数分布随机值。

f(key, dfnum, dfden[, shape, dtype])

采样给定形状和浮点数数据类型的 F 分布随机值。

gamma(key, a[, shape, dtype])

采样给定形状和浮点数数据类型的伽马分布随机值。

generalized_normal(key, p[, shape, dtype])

从广义正态分布中采样。

geometric(key, p[, shape, dtype])

采样给定形状和浮点数数据类型的几何分布随机值。

gumbel(key[, shape, dtype])

采样给定形状和浮点数数据类型的 Gumbel 分布随机值。

laplace(key[, shape, dtype])

采样给定形状和浮点数数据类型的拉普拉斯分布随机值。

loggamma(key, a[, shape, dtype])

采样给定形状和浮点数数据类型的对数伽马分布随机值。

logistic(key[, shape, dtype])

采样给定形状和浮点数数据类型的 logistic 随机值。

lognormal(key[, sigma, shape, dtype])

采样给定形状和浮点数数据类型的对数正态分布随机值。

maxwell(key[, shape, dtype])

从单边 Maxwell 分布中采样。

multivariate_normal(key, mean, cov[, shape, …])

采样给定均值和协方差的多变量正态分布随机值。

normal(key[, shape, dtype])

采样给定形状和浮点数数据类型的标准正态分布随机值。

orthogonal(key, n[, shape, dtype])

从正交群 O(n) 中均匀采样。

pareto(key, b[, shape, dtype])

采样给定形状和浮点数数据类型的帕累托分布随机值。

permutation(key, x[, axis, independent])

返回随机排列的数组或范围。

poisson(key, lam[, shape, dtype])

采样给定形状和整数数据类型的泊松分布随机值。

rademacher(key[, shape, dtype])

从 Rademacher 分布中采样。

randint(key, shape, minval, maxval[, dtype])

用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机整数值。

[rayleigh(key, scale[, shape, dtype])

用给定的形状和浮点数数据类型示例瑞利随机值。

t(key, df[, shape, dtype])

用给定的形状和浮点数数据类型示例学生 t 分布随机值。

triangular(key, left, mode, right[, shape, …])

用给定的形状和浮点数数据类型示例三角形随机值。

truncated_normal(key, lower, upper[, shape, …])

用给定的形状和数据类型示例截断标准正态随机值。

uniform(key[, shape, dtype, minval, maxval])

用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机值。

[wald(key, mean[, shape, dtype])

用给定的形状和浮点数数据类型示例瓦尔德随机值。

weibull_min(key, scale, concentration[, …])

从威布尔分布中采样。

jax.sharding 模块

原文:jax.readthedocs.io/en/latest/jax.sharding.html

代码语言:javascript
复制
class jax.sharding.Sharding

描述了jax.Array如何跨设备布局。

代码语言:javascript
复制
property addressable_devices: set[Device]

Sharding中由当前进程可寻址的设备集合。

代码语言:javascript
复制
addressable_devices_indices_map(global_shape)

从可寻址设备到它们包含的数组数据切片的映射。

addressable_devices_indices_map 包含适用于可寻址设备的device_indices_map部分。

参数:

global_shape (tuple[int, …**])

返回类型:

Mapping[Device, tuple[slice, …] | None]

代码语言:javascript
复制
property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

代码语言:javascript
复制
devices_indices_map(global_shape)

返回从设备到它们包含的数组切片的映射。

映射包括所有全局设备,即包括来自其他进程的不可寻址设备。

参数:

global_shape (tuple[int, …**])

返回类型:

Mapping[Device, tuple[slice, …]]

代码语言:javascript
复制
is_equivalent_to(other, ndim)

如果两个分片等效,则返回True

如果它们在相同设备上放置了相同的逻辑数组分片,则两个分片是等效的。

例如,如果NamedShardingPositionalSharding都将数组的相同分片放置在相同的设备上,则它们可能是等效的。

参数:

  • self (Sharding)
  • other (Sharding)
  • ndim (int)

返回类型:

bool

代码语言:javascript
复制
property is_fully_addressable: bool

此分片是否是完全可寻址的?

如果当前进程能够寻址Sharding中列出的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable 等效于 “is_local”。

代码语言:javascript
复制
property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则分片是完全复制的。

代码语言:javascript
复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript
复制
shard_shape(global_shape)

返回每个设备上数据的形状。

此函数返回的分片形状是从global_shape和分片属性计算得出的。

参数:

global_shape (tuple[int, …**])

返回类型:

tuple[int, …]

代码语言:javascript
复制
with_memory_kind(kind)

返回具有指定内存类型的新分片实例。

参数:

kind (str)

返回类型:

分片

代码语言:javascript
复制
class jax.sharding.SingleDeviceSharding

基类:分片

一个将其数据放置在单个设备上的分片

参数:

device – 单个设备

示例

代码语言:javascript
复制
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0]) 
代码语言:javascript
复制
property device_set: set[Device]

分片跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的非可寻址设备。

代码语言:javascript
复制
devices_indices_map(global_shape)

返回从设备到每个包含的数组片段的映射。

映射包括所有全局设备,即包括来自其他进程的非可寻址设备。

参数:

global_shape (tuple[int, …**])

返回类型:

映射[设备, tuple[slice, …]]

代码语言:javascript
复制
property is_fully_addressable: bool

此分片是否完全可寻址?

如果当前进程可以寻址分片中命名的所有设备,则称分片完全可寻址。is_fully_addressable在多进程 JAX 中等同于“is_local”。

代码语言:javascript
复制
property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则分片完全复制。

代码语言:javascript
复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript
复制
with_memory_kind(kind)

返回具有指定内存类型的新分片实例。

参数:

kind (str)

返回类型:

单设备分片

代码语言:javascript
复制
class jax.sharding.NamedSharding

基类:分片

一个NamedSharding使用命名轴来表示分片。

一个NamedSharding是设备Mesh和描述如何跨该网格对数组进行分片的PartitionSpec的组合。

一个Mesh是 JAX 设备的多维 NumPy 数组,其中网格的每个轴都有一个名称,例如 'x''y'

一个PartitionSpec是一个元组,其元素可以是None、一个网格轴或一组网格轴的元组。每个元素描述如何在零个或多个网格维度上对输入维度进行分区。例如,PartitionSpec('x', 'y')表示数据的第一维在网格的 x 轴上进行分片,第二维在网格的 y 轴上进行分片。

分布式数组和自动并行化(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)教程详细讲解了如何使用MeshPartitionSpec,包括更多细节和图示。

参数:

  • mesh – 一个jax.sharding.Mesh对象。
  • spec – 一个 jax.sharding.PartitionSpec 对象。

示例

代码语言:javascript
复制
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec) 
代码语言:javascript
复制
property addressable_devices: set[Device]

当前进程可以访问的Sharding中的设备集。

代码语言:javascript
复制
property device_set: set[Device]

Sharding跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

代码语言:javascript
复制
property is_fully_addressable: bool

此分片是否完全可寻址?

一个分片如果当前进程可以访问Sharding中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable等同于“is_local”。

代码语言:javascript
复制
property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则称分片为完全复制。

代码语言:javascript
复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript
复制
property mesh

(self) -> object

代码语言:javascript
复制
property spec

(self) -> object

代码语言:javascript
复制
with_memory_kind(kind)

返回具有指定内存类型的新Sharding实例。

参数:

kind (str)

返回类型:

NamedSharding

代码语言:javascript
复制
class jax.sharding.PositionalSharding(devices, *, memory_kind=None)

基类:Sharding

参数:

  • devices (Sequence*[xc.Device]* | np.ndarray)
  • memory_kind (str | None)
代码语言:javascript
复制
property device_set: set[Device]

Sharding跨越的设备集。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

代码语言:javascript
复制
property is_fully_addressable: bool

此分片是否完全可寻址?

一个分片如果当前进程可以访问Sharding中列出的所有设备,则被视为完全可寻址。在多进程 JAX 中,is_fully_addressable等同于“is_local”。

代码语言:javascript
复制
property is_fully_replicated: bool

此分片是否完全复制?

如果每个设备都有整个数据的完整副本,则称分片为完全复制。

代码语言:javascript
复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript
复制
with_memory_kind(kind)

返回具有指定内存类型的新Sharding实例。

参数:

kind (str)

返回类型:

PositionalSharding

代码语言:javascript
复制
class jax.sharding.PmapSharding

基类:Sharding

描述了jax.pmap()使用的分片。

代码语言:javascript
复制
classmethod default(shape, sharded_dim=0, devices=None)

创建一个PmapSharding,与jax.pmap()使用的默认放置方式匹配。

参数:

  • shape (tuple[int, …**]) – 输入数组的形状。
  • sharded_dim (int") – 输入数组进行分片的维度。默认为 0。
  • devicesSequence[Device] | None) – 可选的设备序列。如果省略,隐含的
  • usedpmap 使用的设备顺序是) – jax.local_devices()
  • of这是顺序) – jax.local_devices()

返回类型:

PmapSharding

代码语言:javascript
复制
property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括其他进程的非可寻址设备。

代码语言:javascript
复制
property devices

(self)-> ndarray

代码语言:javascript
复制
devices_indices_map(global_shape)

返回设备到每个包含的数组切片的映射。

映射包括所有全局设备,即包括其他进程的非可寻址设备。

参数:

global_shape元组[int,…**]

返回类型:

Mapping[Device元组[切片,…]]

代码语言:javascript
复制
is_equivalent_to(other, ndim)

如果两个分片等效,则返回True

如果它们将相同的逻辑数组分片放置在相同的设备上,则两个分片是等效的。

例如,如果NamedShardingPositionalSharding将数组的相同分片放置在相同的设备上,则它们可能是等效的。

参数:

  • selfPmapSharding
  • otherPmapSharding
  • ndimint

返回类型:

布尔(“in Python v3.12”)

代码语言:javascript
复制
property is_fully_addressable: bool

这个分片是否完全可寻址?

如果当前进程能够处理Sharding中命名的所有设备,则分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable相当于“is_local”。

代码语言:javascript
复制
property is_fully_replicated: bool

这个分片是否完全复制?

如果每个设备都有完整数据的副本,则分片是完全复制的。

代码语言:javascript
复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript
复制
shard_shape(global_shape)

返回每个设备上数据的形状。

此函数返回的分片形状是从global_shape和分片属性计算而来的。

参数:

global_shape元组[int,…**]

返回类型:

元组[int,…]

代码语言:javascript
复制
property sharding_spec

(self)-> jax::ShardingSpec

代码语言:javascript
复制
with_memory_kind(kind)

返回具有指定内存类型的新 Sharding 实例。

参数:

kindstr

代码语言:javascript
复制
class jax.sharding.GSPMDSharding

基类:Sharding

代码语言:javascript
复制
property device_set: set[Device]

这个Sharding跨越的设备集合。

在多控制器 JAX 中,设备集是全局的,即包括来自其他进程的不可寻址设备。

代码语言:javascript
复制
property is_fully_addressable: bool

此分片是否完全可寻址?

如果当前进程可以访问Sharding中命名的所有设备,则分片是完全可寻址的。is_fully_addressable相当于多进程 JAX 中的“is_local”。

代码语言:javascript
复制
property is_fully_replicated: bool

此分片是否完全复制?

一个分片是完全复制的,如果每个设备都有整个数据的完整副本。

代码语言:javascript
复制
property memory_kind: str | None

返回分片的内存类型。

代码语言:javascript
复制
with_memory_kind(kind)

返回具有指定内存类型的新 Sharding 实例。

参数:

kindstr

返回类型:

GSPMDSharding

代码语言:javascript
复制
class jax.sharding.PartitionSpec(*partitions)

元组描述如何在设备网格上对数组进行分区。

每个元素都可以是None、字符串或字符串元组。有关更多详细信息,请参阅jax.sharding.NamedSharding的文档。

此类存在,以便 JAX 的 pytree 实用程序可以区分分区规范和应视为 pytrees 的元组。

代码语言:javascript
复制
class jax.sharding.Mesh(devices, axis_names)

声明在此管理器范围内可用的硬件资源。

特别是,所有axis_names在管理块内都变成有效的资源名称,并且可以在jax.experimental.pjit.pjit()in_axis_resources参数中使用,还请参阅 JAX 的多进程编程模型(jax.readthedocs.io/en/latest/multi_process.html)和分布式数组与自动并行化教程(jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

如果您在多线程中编译,请确保with Mesh上下文管理器位于线程将执行的函数内部。

参数:

  • devicesndarray) - 包含 JAX 设备对象(例如从jax.devices()获得的对象)的 NumPy ndarray 对象。
  • axis_namestuple[Any, …**]) - 资源轴名称序列,用于分配给devices参数的维度。其长度应与devices的秩匹配。

示例

代码语言:javascript
复制
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 
代码语言:javascript
复制
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 
代码语言:javascript
复制
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 
代码语言:javascript
复制
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) 

jax.debug 模块

原文:jax.readthedocs.io/en/latest/jax.debug.html

运行时值调试实用工具

jax.debug.print 和 jax.debug.breakpoint 描述了如何利用 JAX 的运行时值调试功能。

callback(callback, *args[, ordered])

调用可分阶段的 Python 回调函数。

print(fmt, *args[, ordered])

打印值,并在 JAX 函数中工作。

breakpoint(*[, backend, filter_frames, …])

在程序中某一点设置断点。

调试分片实用工具

能够在分段函数内(和外部)检查和可视化数组分片的函数。

inspect_array_sharding(value, *, callback)

在 JIT 编译函数内部启用检查数组分片。

visualize_array_sharding(arr, **kwargs)

可视化数组的分片。

visualize_sharding(shape, sharding, *[, …])

使用 rich 可视化 Sharding。

jax.dlpack 模块

原文:jax.readthedocs.io/en/latest/jax.dlpack.html

from_dlpack(external_array[, device, copy])

返回一个 DLPack 张量的 Array 表示形式。

to_dlpack(x[, stream, src_device, …])

返回一个封装了 Array x 的 DLPack 张量。

jax.distributed 模块

原文:jax.readthedocs.io/en/latest/jax.distributed.html

initialize([coordinator_address, …])

初始化 JAX 分布式系统。

shutdown()

关闭分布式系统。

jax.dtypes 模块

原文:jax.readthedocs.io/en/latest/jax.dtypes.html

bfloat16

bfloat16 浮点数值

canonicalize_dtype(dtype[, allow_extended_dtype])

根据config.x64_enabled配置将 dtype 转换为规范的 dtype。

float0

对应于相同名称的标量类型和 dtype 的 DType 类。

issubdtype(a, b)

如果第一个参数是类型代码在类型层次结构中较低/相等,则返回 True。

prng_key()

PRNG Key dtypes 的标量类。

result_type(*args[, return_weak_type_flag])

方便函数,用于应用 JAX 参数 dtype 提升。

scalar_type_of(x)

返回与 JAX 值关联的标量类型。

jax.flatten_util 模块

原文:jax.readthedocs.io/en/latest/jax.flatten_util.html

函数列表

-

ravel_pytree(pytree)

将一个数组的 pytree 展平(压缩)为一个 1D 数组。

jax.image 模块

原文:jax.readthedocs.io/en/latest/jax.image.html

图像操作函数。

更多的图像操作函数可以在建立在 JAX 之上的库中找到,例如 PIX

图像操作函数

resize(image, shape, method[, antialias, …])

图像调整大小。

scale_and_translate(image, shape, …[, …])

对图像应用缩放和平移。

参数类

代码语言:javascript
复制
class jax.image.ResizeMethod(value)

图像调整大小方法。

可能的取值包括:

NEAREST:

最近邻插值。

LINEAR:

线性插值

LANCZOS3:

Lanczos 重采样,使用半径为 3 的核。

LANCZOS5:

Lanczos 重采样,使用半径为 5 的核。

CUBIC:

三次插值,使用 Keys 三次核。

jax.nn 模块

原文:jax.readthedocs.io/en/latest/jax.nn.html

  • jax.nn.initializers 模块

神经网络库常见函数。

激活函数

relu

线性整流单元激活函数。

relu6

线性整流单元 6 激活函数。

sigmoid(x)

Sigmoid 激活函数。

softplus(x)

Softplus 激活函数。

sparse_plus(x)

稀疏加法函数。

sparse_sigmoid(x)

稀疏 Sigmoid 激活函数。

soft_sign(x)

Soft-sign 激活函数。

silu(x)

SiLU(又称 swish)激活函数。

swish(x)

SiLU(又称 swish)激活函数。

log_sigmoid(x)

对数 Sigmoid 激活函数。

leaky_relu(x[, negative_slope])

泄漏整流线性单元激活函数。

hard_sigmoid(x)

硬 Sigmoid 激活函数。

hard_silu(x)

硬 SiLU(swish)激活函数。

hard_swish(x)

硬 SiLU(swish)激活函数。

hard_tanh(x)

硬\tanh 激活函数。

elu(x[, alpha])

指数线性单元激活函数。

celu(x[, alpha])

连续可微的指数线性单元激活函数。

selu(x)

缩放的指数线性单元激活函数。

gelu(x[, approximate])

高斯误差线性单元激活函数。

glu(x[, axis])

门控线性单元激活函数。

squareplus(x[, b])

Squareplus 激活函数。

mish(x)

Mish 激活函数。

其他函数

softmax(x[, axis, where, initial])

Softmax 函数。

log_softmax(x[, axis, where, initial])

对数 Softmax 函数。

logsumexp()

对数-总和-指数归约。

standardize(x[, axis, mean, variance, …])

通过减去mean并除以(\sqrt{\mathrm{variance}})来标准化数组。

one_hot(x, num_classes, *[, dtype, axis])

对给定索引进行 One-hot 编码。

jax.nn.initializers 模块

原文:jax.readthedocs.io/en/latest/jax.nn.initializers.html

与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。

初始化器

该模块提供了与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。

初始化器是一个函数,接受三个参数:(key, shape, dtype),并返回一个具有形状shape和数据类型dtype的数组。参数key是一个 PRNG 密钥(例如来自jax.random.key()),用于生成初始化数组的随机数。

constant(value[, dtype])

构建一个返回常数值数组的初始化器。

delta_orthogonal([scale, column_axis, dtype])

构建一个用于增量正交核的初始化器。

glorot_normal([in_axis, out_axis, …])

构建一个 Glorot 正态初始化器(又称 Xavier 正态初始化器)。

glorot_uniform([in_axis, out_axis, …])

构建一个 Glorot 均匀初始化器(又称 Xavier 均匀初始化器)。

he_normal([in_axis, out_axis, batch_axis, dtype])

构建一个 He 正态初始化器(又称 Kaiming 正态初始化器)。

he_uniform([in_axis, out_axis, batch_axis, …])

构建一个 He 均匀初始化器(又称 Kaiming 均匀初始化器)。

lecun_normal([in_axis, out_axis, …])

构建一个 Lecun 正态初始化器。

lecun_uniform([in_axis, out_axis, …])

构建一个 Lecun 均匀初始化器。

normal([stddev, dtype])

构建一个返回实数正态分布随机数组的初始化器。

ones(key, shape[, dtype])

返回一个填充为一的常数数组的初始化器。

orthogonal([scale, column_axis, dtype])

构建一个返回均匀分布正交矩阵的初始化器。

truncated_normal([stddev, dtype, lower, upper])

构建一个返回截断正态分布随机数组的初始化器。

uniform([scale, dtype])

构建一个返回实数均匀分布随机数组的初始化器。

variance_scaling(scale, mode, distribution)

初始化器,根据权重张量的形状调整其尺度。

zeros(key, shape[, dtype])

返回一个填充零的常数数组的初始化器。

jax.ops 模块

原文:jax.readthedocs.io/en/latest/jax.ops.html

段落约简运算符

| segment_max(data, segment_ids[, …]) | 计算数组段内的最大值。 |

函数 jax.ops.index_update、jax.ops.index_add 等已在 JAX 0.2.22 中弃用,并已移除。请改用 JAX 数组上的 jax.numpy.ndarray.at 属性。

segment_min(data, segment_ids[, …])

segment_prod(data, segment_ids[, …])

segment_sum(data, segment_ids[, …])

jax.profiler 模块

原文:jax.readthedocs.io/en/latest/jax.profiler.html

跟踪和时间分析

描述了如何利用 JAX 的跟踪和时间分析功能进行程序性能分析。

start_server(port)

在指定端口启动分析器服务器。

start_trace(log_dir[, create_perfetto_link, …])

启动性能分析跟踪。

stop_trace()

停止当前正在运行的性能分析跟踪。

trace(log_dir[, create_perfetto_link, …])

上下文管理器,用于进行性能分析跟踪。

annotate_function(func[, name])

生成函数执行的跟踪事件的装饰器。

TraceAnnotation

在分析器中生成跟踪事件的上下文管理器。

StepTraceAnnotation(name, **kwargs)

在分析器中生成步骤跟踪事件的上下文管理器。

设备内存分析

请参阅设备内存分析,了解 JAX 的设备内存分析功能简介。

device_memory_profile([backend])

捕获 JAX 设备内存使用情况,格式为 pprof 协议缓冲区。

save_device_memory_profile(filename[, backend])

收集设备内存使用情况,并将其写入文件。

jax.stages 模块

原文:jax.readthedocs.io/en/latest/jax.stages.html

接口到编译执行过程的各个阶段。

JAX 转换,例如jax.jitjax.pmap,也支持一种通用的显式降阶和预编译执行 ahead of time 的方式。 该模块定义了代表这一过程各个阶段的类型。

有关更多信息,请参阅AOT walkthrough

代码语言:javascript
复制
class jax.stages.Wrapped(*args, **kwargs)

一个准备好进行追踪、降阶和编译的函数。

此协议反映了诸如jax.jit之类的函数的输出。 调用它会导致 JIT(即时)降阶、编译和执行。 它也可以在编译之前明确降阶,并在执行之前编译结果。

代码语言:javascript
复制
__call__(*args, **kwargs)

执行包装的函数,根据需要进行降阶和编译。

代码语言:javascript
复制
lower(*args, **kwargs)

明确为给定的参数降阶此函数。

一个降阶函数被从 Python 阶段化,并翻译为编译器的输入语言,可能以依赖于后端的方式。 它已准备好进行编译,但尚未编译。

返回:

一个Lowered实例,表示降阶。

返回类型:

降阶

代码语言:javascript
复制
trace(*args, **kwargs)

明确为给定的参数追踪此函数。

一个追踪函数被从 Python 阶段化,并翻译为一个 jaxpr。 它已准备好进行降阶,但尚未降阶。

返回:

一个Traced实例,表示追踪。

返回类型:

追踪

代码语言:javascript
复制
class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)

降阶一个根据参数类型和值特化的函数。

降阶是一种准备好进行编译的计算。 此类将降阶与稍后编译和执行所需的剩余信息一起携带。 它还提供了一个通用的 API,用于查询 JAX 各种降阶路径(jit()pmap()等)中降阶计算的属性。

参数:

  • 降阶XlaLowering
  • args_infoAny
  • out_treePyTreeDef
  • no_kwargsbool
代码语言:javascript
复制
as_text(dialect=None)

此降阶的人类可读文本表示。

旨在可视化和调试目的。 这不必是有效的也不一定可靠的序列化。 它直接传递给外部调用者。

参数:

方言str | ) – 可选字符串,指定一个降阶方言(例如,“stablehlo”)

返回类型:

str

代码语言:javascript
复制
compile(compiler_options=None)

编译,并返回相应的Compiled实例。

参数:

compiler_options (dict[str, str | bool] | None)

返回类型:

Compiled

代码语言:javascript
复制
compiler_ir(dialect=None)

这种降低的任意对象表示。

旨在调试目的。这不是有效的也不是可靠的序列化。输出在不同调用之间没有一致性的保证。

如果不可用,则返回None,例如基于后端、编译器或运行时。

参数:

dialect (str | None) – 可选字符串,指定一个降低方言(例如“stablehlo”)

返回类型:

Any | None

代码语言:javascript
复制
cost_analysis()

执行成本估算的摘要。

旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

Any | None

代码语言:javascript
复制
property in_tree: PyTreeDef

一对(位置参数、关键字参数)的树结构。

代码语言:javascript
复制
class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)

编译后的函数专门针对类型/值进行了优化表示。

编译计算与可执行文件相关联,并提供执行所需的剩余信息。它还为查询 JAX 的各种编译路径和后端中编译计算属性提供了一个共同的 API。

参数:

  • args_info (Any)
  • out_tree (PyTreeDef)
代码语言:javascript
复制
__call__(*args, **kwargs)

将自身作为函数调用。

代码语言:javascript
复制
as_text()

这是可执行文件的人类可读文本表示。

旨在可视化和调试。这不是有效的也不是可靠的序列化。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

str | None

代码语言:javascript
复制
cost_analysis()

执行成本估算的摘要。

旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

Any | None

代码语言:javascript
复制
property in_tree: PyTreeDef

(位置参数,关键字参数) 的树结构。

代码语言:javascript
复制
memory_analysis()

估计内存需求的摘要。

用于可视化和调试目的。由此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如嵌套的字典、列表和具有数字叶子的元组)。然而,其结构可以是任意的:在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能是不一致的。

返回 None 如果不可用,例如基于后端、编译器或运行时。

返回类型:

任意 | None

代码语言:javascript
复制
runtime_executable()

此可执行对象的任意对象表示。

用于调试目的。这不是有效也不是可靠的序列化。输出不能保证在不同调用之间的一致性。

返回 None 如果不可用,例如基于后端、编译器或运行时。

返回类型:

任意 | None

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-06-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • jax.scipy 模块
    • jax.scipy.cluster
      • jax.scipy.stats.bernoulli
      • jax.scipy.stats.gaussian_kde
      • jax.scipy.stats.vonmises
  • jax.scipy.stats.bernoulli.logpmf
  • jax.scipy.stats.bernoulli.pmf
  • jax.scipy.stats.bernoulli.cdf
  • jax.scipy.stats.bernoulli.ppf
  • jax.lax 模块
    • Operators
      • 控制流操作符
        • 自定义梯度操作符
          • 并行操作符
            • 与分片相关的操作符
              • 线性代数操作符 (jax.lax.linalg)
                • 参数类
                • jax.random 模块
                  • 基本用法
                    • PRNG keys
                      • 高级
                        • 设计和背景
                        • 高级 RNG 配置
                      • API 参考
                        • 密钥创建与操作
                        • 随机抽样器
                    • jax.sharding 模块
                      • jax.debug 模块
                        • 运行时值调试实用工具
                          • 调试分片实用工具
                          • jax.dlpack 模块
                          • jax.distributed 模块
                          • jax.dtypes 模块
                          • jax.flatten_util 模块
                            • 函数列表
                            • jax.image 模块
                              • 图像操作函数
                                • 参数类
                                • jax.nn 模块
                                  • 激活函数
                                    • 其他函数
                                    • jax.nn.initializers 模块
                                      • 初始化器
                                      • jax.ops 模块
                                        • | segment_max(data, segment_ids[, …]) | 计算数组段内的最大值。 |
                                        • jax.profiler 模块
                                          • 跟踪和时间分析
                                            • 设备内存分析
                                            • jax.stages 模块
                                              领券
                                              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档