
TensorFlow Plot (tfplot)

Author: 狼啸风云
Modified 2022-09-04 21:43:42

Original article: https://tensorflow-plot.readthedocs.io/en/latest/api/index.html


1. Showcases of tfplot

This guide gives a quick tour of the tfplot library. Feel free to skip the setup section of this document.

import tfplot
tfplot.__version__
'0.3.0.dev0'

1. Setup: Utilities and Data

In order to see the images generated from the plot ops, we introduce a simple utility function which takes a Tensor as an input and displays the resulting image after executing it in a TensorFlow session.

You may skip this section and go straight to the showcase.

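import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf  # tfplot targets the TensorFlow 1.x graph/session API

sess = tf.InteractiveSession()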
def execute_op_as_image(op):
    """
    Evaluate the given `op` and return the content PNG image as `PIL.Image`.

    - If op is a plot op (e.g. RGBA Tensor) the image or
      a list of images will be returned
    - If op is summary proto (i.e. `op` was a summary op),
      the image content will be extracted from the proto object.
    """
    print ("Executing: " + str(op))
    ret = sess.run(op)
    plt.close()

    if isinstance(ret, np.ndarray):
        if len(ret.shape) == 3:
            # single image
            return Image.fromarray(ret)
        elif len(ret.shape) == 4:
            return [Image.fromarray(r) for r in ret]
        else:
            raise ValueError("Invalid rank : %d" % len(ret.shape))

    elif isinstance(ret, (str, bytes)):
        from io import BytesIO
        s = tf.Summary()
        s.ParseFromString(ret)
        ims = []
        for i in range(len(s.value)):
            png_string = s.value[i].image.encoded_image_string
            im = Image.open(BytesIO(png_string))
            ims.append(im)
        plt.close()
        if len(ims) == 1: return ims[0]
        else: return ims

    else:
        raise TypeError("Unknown type: " + str(ret))

And some data:

import scipy.misc
import scipy.ndimage

def fake_attention():
    attention = np.zeros([16, 16], dtype=np.float32)
    attention[11, 8] = 1.0
    attention[9, 9] = 1.0
    attention = scipy.ndimage.gaussian_filter(attention, sigma=1.5)
    return attention

# scipy.misc.face() was removed in recent SciPy releases; use scipy.datasets.face() there.
sample_image = scipy.misc.face()
attention_map = fake_attention()

# display the data
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].imshow(sample_image); axs[0].set_title('image')
axs[1].imshow(attention_map, cmap='jet'); axs[1].set_title('attention')
plt.show()

And we finally wrap these numpy values into TensorFlow ops:

# the input to plot_op
image_tensor = tf.constant(sample_image, name='image')
attention_tensor = tf.constant(attention_map, name='attention')
print(image_tensor)
print(attention_tensor)

Tensor("image:0", shape=(768, 1024, 3), dtype=uint8)
Tensor("attention:0", shape=(16, 16), dtype=float32)

2. tfplot.autowrap: The Main End-User API

Use tfplot.autowrap to design a custom plot function of your own.

1. Decorator to define a TF op that draws a plot

With tfplot.autowrap, you can wrap a Python function that returns a matplotlib Figure (or AxesSubplot) into a TensorFlow op, similar to tf.py_func.

@tfplot.autowrap
def plot_scatter(x, y):
    # NEVER use plt.XXX, or matplotlib.pyplot.
    # Use tfplot.subplots() instead of plt.subplots() to avoid thread-safety issues.
    fig, ax = tfplot.subplots(figsize=(3, 3))
    ax.scatter(x, y, color='green')
    return fig

x = tf.constant(np.arange(10), dtype=tf.float32)
y = tf.constant(np.arange(10) ** 2, dtype=tf.float32)
execute_op_as_image(plot_scatter(x, y))

Executing: Tensor("plot_scatter:0", shape=(?, ?, 4), dtype=uint8)

We can create subplots as well. Note also that keyword arguments (kwargs) can be passed in addition to the Tensor arguments (positional arguments).

@tfplot.autowrap
def plot_image_and_attention(im, att, cmap=None):
    fig, axes = tfplot.subplots(1, 2, figsize=(7, 4))
    fig.suptitle('Image and Heatmap')
    axes[0].imshow(im)
    axes[1].imshow(att, cmap=cmap)
    return fig

op = plot_image_and_attention(sample_image, attention_map, cmap='jet')
execute_op_as_image(op)

Executing: Tensor("plot_image_and_attention:0", shape=(?, ?, 4), dtype=uint8)

Sometimes, it can be cumbersome to create instances of fig and ax. If you want to have them automatically created and injected, use a keyword argument named fig and/or ax:

@tfplot.autowrap(figsize=(2, 2))
def plot_scatter(x, y, *, ax, color='red'):
    ax.set_title('x^2')
    ax.scatter(x, y, color=color)

x = tf.constant(np.arange(10), dtype=tf.float32)
y = tf.constant(np.arange(10) ** 2, dtype=tf.float32)
execute_op_as_image(plot_scatter(x, y))

Executing: Tensor("plot_scatter_1:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_showcases_23_1.png
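How might such injection work? Below is a minimal stdlib sketch, not tfplot's implementation. The wrapper inspects the plot function's signature for parameters named `fig` or `ax` and passes in freshly created objects. `FakeFigure`, `FakeAxes` and `inject_fig_ax` are hypothetical stand-ins for matplotlib's `Figure`/`Axes` and for the wrapping logic.

```python
import inspect


class FakeAxes:
    """Hypothetical stand-in for a matplotlib Axes."""
    def __init__(self, figure):
        self.figure = figure
        self.calls = []  # records drawing calls so we can inspect them


class FakeFigure:
    """Hypothetical stand-in for a matplotlib Figure with a single subplot."""
    def __init__(self, figsize=None):
        self.figsize = figsize
        self.axes = [FakeAxes(self)]


def inject_fig_ax(plot_func, figsize=None):
    """Return a wrapper that supplies new `fig`/`ax` objects if plot_func declares them."""
    params = inspect.signature(plot_func).parameters
    wants_fig = "fig" in params
    wants_ax = "ax" in params

    def wrapper(*args, **kwargs):
        fig = FakeFigure(figsize=figsize)
        if wants_fig:
            kwargs["fig"] = fig
        if wants_ax:
            kwargs["ax"] = fig.axes[0]
        ret = plot_func(*args, **kwargs)
        # If the plot function returns nothing, fall back to the figure we created.
        return fig if ret is None else ret

    return wrapper


def plot_scatter(x, y, *, ax, color="red"):
    ax.calls.append(("scatter", list(x), list(y), color))


fig = inject_fig_ax(plot_scatter, figsize=(2, 2))([0, 1, 2], [0, 1, 4])
print(fig.figsize, fig.axes[0].calls)
```

The plot function only declares `ax`, so it never has to create or return a figure itself.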

2. Wrapping Matplotlib’s AxesPlot or Seaborn Plot

You can use tfplot.autowrap (or raw APIs such as tfplot.plot, etc.) to plot anything by writing a customized plotting function on your own, but sometimes we may want to convert already existing plot functions from common libraries such as matplotlib and seaborn.

To do this, you can still use tfplot.autowrap.

1. Matplotlib

Matplotlib provides a variety of plot methods defined in the Axes class (an instance is usually bound to a variable named ax).

rs = np.random.RandomState(42)
x = rs.randn(100)
y = 2 * x + rs.randn(100)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_title("Created from matplotlib API")
plt.show()
../_images/guide_showcases_27_0.png

We can wrap the Axes.scatter() method as a TensorFlow op as follows:

from matplotlib.axes import Axes
tf_scatter = tfplot.autowrap(Axes.scatter, figsize=(4, 4))

plot_op = tf_scatter(x, y)
execute_op_as_image(plot_op)

Executing: Tensor("scatter:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_showcases_29_1.png

2. Seaborn

Seaborn provides many useful axes-level plot functions that can be used out of the box. Most of these functions accept an ax=... parameter that specifies the axes to draw on.

See seaborn’s example gallery for interesting features seaborn provides.

import seaborn as sns

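# compare numerically: as strings, '0.13' >= '0.8' would be False
assert tuple(int(p) for p in sns.__version__.split('.')[:2]) >= (0, 8), \
    'Use seaborn >= v0.8.0, otherwise `import seaborn as sns` will affect the default matplotlib style.'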
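Comparing version strings with `>=` compares them character by character, so a newer release such as `'0.13.2'` would wrongly fail a `>= '0.8'` check. Below is a stdlib sketch of a numeric comparison. The `version_tuple` helper is hypothetical; `packaging.version.Version` is a more complete alternative.

```python
import re


def version_tuple(version, parts=2):
    """Parse the leading numeric components of a version string, e.g. '0.13.2' -> (0, 13)."""
    numbers = []
    for piece in version.split(".")[:parts]:
        match = re.match(r"\d+", piece)   # tolerate suffixes such as '0rc1'
        numbers.append(int(match.group()) if match else 0)
    return tuple(numbers)


# Lexicographic string comparison gets this wrong: '1' < '8' at the third character.
print("0.13.2" >= "0.8")                  # False
print(version_tuple("0.13.2") >= (0, 8))  # True
```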

barplot: (Discrete) Probability Distribution

# https://seaborn.pydata.org/generated/seaborn.barplot.html

y = np.random.RandomState(42).normal(size=[18])
y = np.exp(y) / np.exp(y).sum() # softmax
y = tf.constant(y, dtype=tf.float32)

ATARI_ACTIONS = [
    '⠀', '●', '↑', '→', '←', '↓', '↗', '↖', '↘', '↙',
    '⇑', '⇒', '⇐', '⇓', '⇗', '⇖', '⇘', '⇙' ]
x = tf.constant(ATARI_ACTIONS)

op = tfplot.autowrap(sns.barplot, palette='Blues_d')(x, y)
execute_op_as_image(op)

Executing: Tensor("barplot:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_showcases_33_1.png
y = np.random.RandomState(42).normal(size=[3, 18])

y = np.exp(y) / np.exp(y).sum(axis=1).reshape([-1, 1]) # softmax example-wise
y = tf.constant(y, dtype=tf.float32)

ATARI_ACTIONS = [
    '⠀', '●', '↑', '→', '←', '↓', '↗', '↖', '↘', '↙',
    '⇑', '⇒', '⇐', '⇓', '⇗', '⇖', '⇘', '⇙' ]
x = tf.broadcast_to(tf.constant(ATARI_ACTIONS), y.shape)

from IPython.display import display  # notebook helper for showing each image

op = tfplot.autowrap(sns.barplot, palette='Blues_d', batch=True)(x, y)
for im in execute_op_as_image(op):
    display(im)

Executing: Tensor("barplot_1/PlotImages:0", shape=(3, ?, ?, 4), dtype=uint8)
../_images/guide_showcases_34_1.png
../_images/guide_showcases_34_2.png
../_images/guide_showcases_34_3.png
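The two barplot snippets above normalize scores with a softmax. The first applies it over the whole vector; the batched case applies it per row (per example). Below is a stdlib sketch of the per-row version. Subtracting the row maximum before exponentiating is a standard numerical-stability trick that the snippets above do not use. It leaves the result unchanged.

```python
import math


def softmax_rows(rows):
    """Apply a softmax independently to each row of a 2-D list of scores."""
    result = []
    for row in rows:
        m = max(row)                          # shift for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


probs = softmax_rows([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
print([round(sum(p), 6) for p in probs])      # each row sums to 1
```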

Heatmap

Let’s wrap seaborn’s heatmap function as a TensorFlow operation, with some additional default kwargs. This is very useful for visualization.

# @seealso https://seaborn.pydata.org/examples/heatmap_annotation.html
tf_heatmap = tfplot.autowrap(sns.heatmap, figsize=(9, 6))

op = tf_heatmap(attention_map, cbar=True, annot=True, fmt=".2f")
execute_op_as_image(op)

Executing: Tensor("heatmap:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_showcases_37_1.png

What if we don’t want axes and colorbars, but only the map itself? Compare this with plain tf.summary.image, which would just give a grayscale image.

# draw only the heatmap itself, without axes, colorbar, etc.
tf_heatmap = tfplot.autowrap(sns.heatmap, figsize=(4, 4), tight_layout=True,
                             cmap='jet', cbar=False, xticklabels=False, yticklabels=False)

op = tf_heatmap(attention_map, name='HeatmapImage')
execute_op_as_image(op)

Executing: Tensor("HeatmapImage:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_showcases_39_1.png

3. And Many More!

This document has covered the basic usage of tfplot, but there is more:

  • tfplot.contrib: contains off-the-shelf functions for creating plot operations that are useful in practice, in a few lines (without the hassle of writing a function body). See [contrib.ipynb] for a tour of the available APIs.
  • tfplot.plot(), tfplot.plot_many(), etc.: Low-level APIs.
  • tfplot.summary: One-liner APIs for creating TF summary operations.

import tfplot.contrib

For example, probmap and probmap_simple create an image Tensor that visualizes a probability map:

op = tfplot.contrib.probmap(attention_map, figsize=(4, 3))
execute_op_as_image(op)
Executing: Tensor("probmap:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_showcases_45_1.png
op = tfplot.contrib.probmap_simple(attention_map, figsize=(3, 3), vmin=0, vmax=1)
execute_op_as_image(op)

Executing: Tensor("probmap_1:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_showcases_46_1.png

That’s all! Please take a look at the API documentation and more examples if you are interested.

2. tfplot.contrib: Some pre-defined plot ops

The tfplot.contrib package contains off-the-shelf functions for defining plotting operations that can be useful across many typical use cases. Unfortunately, it may not provide flexible, fine-grained customization beyond the current parameters. If it does not fit what you want, consider designing your own plotting functions using tfplot.autowrap.

import tfplot.contrib

for fn in sorted(tfplot.contrib.__all__):
    print("%-20s" % fn, tfplot.contrib.__dict__[fn].__doc__.split('\n')[1].strip())

batch                Make an autowrapped plot function (... -> RGBA tf.Tensor) work in a batch
probmap              Display a heatmap in color. The resulting op will be a RGBA image Tensor.
probmap_simple       Display a heatmap in color, but only displays the image content.

1. probmap

For example, probmap and probmap_simple create an image Tensor that visualizes a probability map:

attention_op = tf.constant(attention_map, name="attention_op")
print(attention_op)

op = tfplot.contrib.probmap(attention_map, figsize=(4, 3))
execute_op_as_image(op)

Tensor("attention_op:0", shape=(16, 16), dtype=float32)
Executing: Tensor("probmap:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_contrib_11_1.png
op = tfplot.contrib.probmap_simple(attention_map, figsize=(3, 3),
                                   vmin=0, vmax=1)
execute_op_as_image(op)

Executing: Tensor("probmap_1:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_contrib_12_1.png
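The `vmin=0, vmax=1` arguments presumably fix the range that the colormap spans, instead of deriving it from the data. This is an assumption: this document does not say how probmap_simple uses them. In matplotlib's `Normalize`, values are rescaled linearly so that `vmin` maps to the low end of the colormap and `vmax` to the high end. By default, out-of-range values end up at the extreme colors. Below is a stdlib sketch of that rescaling. `normalize` is a hypothetical helper, and it simply clips rather than modeling matplotlib's separate under/over colors.

```python
def normalize(value, vmin, vmax):
    """Linearly map value from [vmin, vmax] to [0, 1], clipping out-of-range values."""
    if vmax == vmin:
        return 0.0                    # degenerate range: everything maps to the low end
    t = (value - vmin) / (vmax - vmin)
    return min(1.0, max(0.0, t))


# With a fixed range, a peak of 0.25 stays dim instead of being stretched to full intensity.
attention_row = [0.0, 0.05, 0.25]
print([normalize(v, 0.0, 1.0) for v in attention_row])
print([normalize(v, min(attention_row), max(attention_row)) for v in attention_row])
```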

2. Auto-batch mode (tfplot.contrib.batch)

In many cases, we may want to make plotting operations behave in a batch manner. You can use tfplot.contrib.batch to make those functions work in a batch mode:

# batch version
N = 5
p = np.zeros([N, N, N])
for i in range(N):
    p[i, i, i] = 1.0

p = tf.constant(p, name="batch_tensor"); print(p)                      # (batch_size, 5, 5)
op = tfplot.contrib.batch(tfplot.contrib.probmap)(p, figsize=(3, 2))   # (batch_size, H, W, 4)

results = execute_op_as_image(op)      # list of N images
Image.fromarray(np.hstack([np.asarray(im) for im in results]))

Tensor("batch_tensor:0", shape=(5, 5, 5), dtype=float64)
Executing: Tensor("probmap_2/PlotImages:0", shape=(5, ?, ?, 4), dtype=uint8)
../_images/guide_contrib_15_1.png
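Conceptually, batch mode applies the single-example plot function to each slice along the first axis and stacks the results. That is why the output gains a leading batch dimension, (5, ?, ?, 4). Below is a framework-free sketch of that idea. `batchify` is a hypothetical name and plain lists stand in for tensors. The real tfplot.contrib.batch builds TensorFlow ops rather than calling the function eagerly.

```python
def batchify(fn):
    """Hypothetical helper: turn a per-example function into one over batches.

    Every positional argument is treated as a batch and sliced along its first
    axis; keyword arguments are passed unchanged to each per-example call.
    """
    def batched(*batches, **kwargs):
        sizes = {len(b) for b in batches}
        if len(sizes) > 1:
            raise ValueError("all batched arguments must have the same length")
        return [fn(*example, **kwargs) for example in zip(*batches)]
    return batched


def one_hot_diagonal(n):
    """Build the n x n x n batch from the example: slice i has a single 1 at (i, i)."""
    return [[[1.0 if (r == i and c == i) else 0.0 for c in range(n)]
             for r in range(n)] for i in range(n)]


def peak_position(grid, label=""):
    """Stand-in for a plot function: report where the maximum of one map is."""
    _, r, c = max((v, r, c) for r, row in enumerate(grid) for c, v in enumerate(row))
    return (label, r, c)


p = one_hot_diagonal(5)
print(batchify(peak_position)(p, label="peak"))
```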

3. More APIs

1. Low-level APIs: tfplot.plot()

The following examples show the usage of the most general form of the API, tfplot.plot(). Its usage is very similar to that of tf.py_func().

Conceptually, we can draw any matplotlib plot as a TensorFlow op. One thing to remember is that the plot_func function (passed to tfplot.plot()) should be implemented using object-oriented APIs of matplotlib, not pyplot.XXX APIs (or matplotlib.pyplot.XXX) in order to avoid thread-safety issues.

1. A basic example

def test_figure():
    fig, ax = tfplot.subplots(figsize=(3, 3))
    ax.text(0.5, 0.5, "Hello World!",
            ha='center', va='center', size=24)
    return fig

plot_op = tfplot.plot(test_figure, [])
execute_op_as_image(plot_op)

Executing: Tensor("Plot:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_more-api_4_1.png

2. With Arguments

def figure_attention(attention):
    fig, ax = tfplot.subplots(figsize=(4, 3))
    im = ax.imshow(attention, cmap='jet')
    fig.colorbar(im)
    return fig

plot_op = tfplot.plot(figure_attention, [attention_tensor])
execute_op_as_image(plot_op)

Executing: Tensor("Plot_1:0", shape=(?, ?, 4), dtype=uint8)
../_images/guide_more-api_6_1.png

3. Examples of using kwargs

# the plot function can have additional kwargs for providing configuration points
def overlay_attention(attention, image,
                      alpha=0.5, cmap='jet'):
    fig = tfplot.Figure(figsize=(4, 4))
    ax = fig.add_subplot(1, 1, 1)
    ax.axis('off')
    fig.subplots_adjust(0, 0, 1, 1)  # get rid of margins

    H, W = attention.shape
    # extent is (left, right, bottom, top): stretch both layers over the same W x H box
    ax.imshow(image, extent=[0, W, 0, H])
    ax.imshow(attention, cmap=cmap,
              alpha=alpha, extent=[0, W, 0, H])
    return fig

plot_op = tfplot.plot(overlay_attention, [attention_tensor, image_tensor])
execute_op_as_image(plot_op)

Executing: Tensor("Plot_2:0", shape=(?, ?, 4), dtype=uint8)
# the kwargs to `tfplot.plot()` are passed to the plot function (i.e. `overlay_attention`)
# during the execution of the plot operation.
plot_op = tfplot.plot(overlay_attention, [attention_tensor, image_tensor],
                      cmap='gray', alpha=0.8)
execute_op_as_image(plot_op)

Executing: Tensor("Plot_3:0", shape=(?, ?, 4), dtype=uint8)
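The point of the comment above is that keyword arguments are captured when the op is built and only applied when it runs. The same deferred binding can be sketched without TensorFlow using functools.partial. `build_plot_op` is a hypothetical name. The returned closure stands in for the graph op that sess.run would execute.

```python
import functools


def build_plot_op(plot_func, in_values, **kwargs):
    """Hypothetical sketch: capture inputs and kwargs now, call plot_func later."""
    bound = functools.partial(plot_func, **kwargs)
    return lambda: bound(*in_values)


def describe_overlay(attention, image, alpha=0.5, cmap="jet"):
    # Stand-in for overlay_attention: just report the settings it would use.
    return {"alpha": alpha, "cmap": cmap, "shapes": (len(attention), len(image))}


default_op = build_plot_op(describe_overlay, [[0.1] * 16, [0] * 768])
gray_op = build_plot_op(describe_overlay, [[0.1] * 16, [0] * 768], cmap="gray", alpha=0.8)
print(default_op())  # uses the function's own defaults
print(gray_op())     # uses the kwargs captured at build time
```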

4. plot_many() – the batch version

# make a fake batch
batch_size = 3
attention_batch = tf.random_gamma([batch_size, 7, 7], alpha=0.3, seed=42)
image_batch = tf.tile(tf.expand_dims(image_tensor, 0),
                      [batch_size, 1, 1, 1], name='image_batch')
print (attention_batch)
print (image_batch)

# plot_many()
plot_op = tfplot.plot_many(overlay_attention, [attention_batch, image_batch])
images = execute_op_as_image(plot_op)

Tensor("random_gamma/Maximum:0", shape=(3, 7, 7), dtype=float32)
Tensor("image_batch:0", shape=(3, 768, 1024, 3), dtype=uint8)
Executing: Tensor("PlotMany/PlotImages:0", shape=(3, ?, ?, 4), dtype=uint8)
# just see the three images
_, axes = plt.subplots(1, 3, figsize=(10, 3))
for i in range(3):
    axes[i].set_title("%d : [%dx%d]" % (i, images[i].height, images[i].width))
    axes[i].imshow(images[i])
plt.show()

5. Wrap once, use it as a factory – tfplot.autowrap() or tfplot.wrap()

Let’s wrap the function overlay_attention, which

  • takes a heatmap (attention) and an RGB image (image), and
  • plots the heatmap on top of the image,

so that it can be applied directly to Tensors:

plot_op = tfplot.autowrap(overlay_attention)(attention_tensor, image_tensor)
execute_op_as_image(plot_op)

Executing: Tensor("overlay_attention:0", shape=(?, ?, 4), dtype=uint8)

A cleaner style, in a functional way!

6. Batch example

tf_plot_attention = tfplot.wrap(overlay_attention, name='PlotAttention', batch=True)
print (tf_plot_attention)

<function wrap[__main__.overlay_attention] at 0x127f26f28>

Then we can call the resulting tf_plot_attention function to build new TensorFlow ops:

plot_op = tf_plot_attention(attention_batch, image_batch)
images = execute_op_as_image(plot_op)
images

Executing: Tensor("PlotAttention/PlotImages:0", shape=(3, ?, ?, 4), dtype=uint8)
[<PIL.Image.Image image mode=RGBA size=288x288 at 0x12A896470>,
 <PIL.Image.Image image mode=RGBA size=288x288 at 0x12A896390>,
 <PIL.Image.Image image mode=RGBA size=288x288 at 0x12A8962E8>]
# just see the three images
_, axes = plt.subplots(1, 3, figsize=(10, 3))
for i in range(3):
    axes[i].set_title("%d : [%dx%d]" % (i, images[i].height, images[i].width))
    axes[i].imshow(images[i])
plt.show()
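The 288x288 sizes reported above are consistent with the figure size. Pixels = inches × dots per inch, and overlay_attention uses figsize=(4, 4), so 4 × 72 = 288. The 72 dpi value is inferred from these numbers; this document does not state it. Matplotlib's figsize is (width, height), while image arrays are (height, width, channels). The sketch below is a hypothetical helper doing that arithmetic.

```python
def figure_pixel_shape(figsize, dpi=72, channels=4):
    """Shape (height, width, channels) of an image rendered from a figure.

    figsize follows matplotlib's (width, height)-in-inches convention.
    The default dpi of 72 is an assumption inferred from the 288x288 outputs.
    """
    width_in, height_in = figsize
    return (int(round(height_in * dpi)), int(round(width_in * dpi)), channels)


print(figure_pixel_shape((4, 4)))   # (288, 288, 4)
print(figure_pixel_shape((4, 3)))   # (216, 288, 4): 3 in tall, 4 in wide
```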

2. tfplot.summary (deprecated)

Finally, we can directly create a TensorFlow summary op from input tensors. The usage is similar to tf.summary.image(): it is a shortcut for creating plot ops and then creating image summaries from them.

import tfplot.summary

1. tfplot.summary.plot()

# Just directly add a single plot result into a summary
summary_op = tfplot.summary.plot("plot_summary", test_figure, [])
print(summary_op)
execute_op_as_image(summary_op)

Tensor("plot_summary/ImageSummary:0", shape=(), dtype=string)
Executing: Tensor("plot_summary/ImageSummary:0", shape=(), dtype=string)
../_images/guide_more-api_28_1.png

2. tfplot.summary.plot_many() – the batch version

# batch of attention maps --> image summary
batch_size, H, W = 4, 4, 4
batch_attentions = np.zeros((batch_size, H, W), dtype=np.float32)
for b in range(batch_size):
    batch_attentions[b, b, b] = 1.0

# Note that tfplot.summary.plot_many() takes an input in a batch form
def figure_attention_demo2(attention):
    fig, ax = tfplot.subplots(figsize=(4, 3))
    im = ax.imshow(attention, cmap='jet')
    fig.colorbar(im)
    return fig
summary_op = tfplot.summary.plot_many("batch_attentions_summary", figure_attention_demo2,
                                      [batch_attentions], max_outputs=4)
print(summary_op)
images = execute_op_as_image(summary_op)

Tensor("batch_attentions_summary/ImageSummary:0", shape=(), dtype=string)
Executing: Tensor("batch_attentions_summary/ImageSummary:0", shape=(), dtype=string)
# just see the 4 images in the summary
_, axes = plt.subplots(2, 2, figsize=(8, 6))
for i in range(batch_size):
    axes[i//2, i%2].set_title("%d : [%dx%d]" % (i, images[i].height, images[i].width))
    axes[i//2, i%2].imshow(images[i])
plt.show()
../_images/guide_more-api_31_0.png

3. API Reference

1. tfplot

1. Wrapper functions

tfplot.autowrap(*args, **kwargs)

Wrap a function as a TensorFlow operation, similar to tfplot.wrap() (as a decorator or with a normal function call), but with additional features such as auto-creating matplotlib figures.

  • (fig, ax) matplotlib objects are automatically created and injected if plot_func has a keyword argument named fig and/or ax. In such cases, we do not need to call tfplot.subplots() manually to create matplotlib figure/axes objects. If you need to create fig and ax manually, consider using tfplot.wrap() instead.
  • It automatically handles the return value of plot_func. If it returns nothing (None) but fig was automatically injected, the injected figure is drawn; if it returns an Axes, the associated Figure is used.
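The return-value rules above amount to a small dispatch. Below is a sketch that uses hypothetical stand-in classes (`StubFigure`, `StubAxes`) in place of matplotlib's Figure and Axes. A real Axes does expose its parent as `.figure`. `resolve_figure` is not a tfplot function, and raising errors for other cases is an assumption of this sketch.

```python
class StubFigure:
    """Hypothetical stand-in for matplotlib.figure.Figure."""


class StubAxes:
    """Hypothetical stand-in for matplotlib.axes.Axes; knows its parent figure."""
    def __init__(self, figure):
        self.figure = figure


def resolve_figure(ret, injected_fig=None):
    """Pick the figure to render from a plot function's return value."""
    if ret is None:
        if injected_fig is None:
            raise ValueError("plot function returned None and no figure was injected")
        return injected_fig           # drew onto the auto-created figure
    if isinstance(ret, StubAxes):
        return ret.figure             # use the Axes' parent figure
    if isinstance(ret, StubFigure):
        return ret                    # already a figure
    raise TypeError("unsupported return type: %r" % type(ret).__name__)


fig = StubFigure()
print(resolve_figure(None, injected_fig=fig) is fig, resolve_figure(StubAxes(fig)) is fig)
```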

Example

>>> @tfplot.autowrap(figsize=(3, 3))
>>> def plot_imshow(img, *, fig, ax):
>>>    ax.imshow(img)
>>>
>>> plot_imshow(an_image_tensor)
Tensor("plot_imshow:0", shape=(?, ?, 4), dtype=uint8)

Parameters:

  • plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details. Additionally, if this function has a parameter named fig and/or ax, new instances of Figure and/or AxesSubplot will be created and passed.
  • batch – If True, all the tensors passed as argument will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of plot_func will be used.
  • figsize – The figure size for the figure to be created.
  • tight_layout – If True, the resulting figure will have no margins for axis. Equivalent to calling fig.subplots_adjust(0, 0, 1, 1).
  • kwargs_default – Optional kwargs that will be passed by default to plot_func when executed inside a TensorFlow graph.

tfplot.wrap(*args, **kwargs)

Wrap a plot function as a TensorFlow operation. It will return a Python function that creates a TensorFlow plot operation applying the arguments as input. It can also be used as a decorator.

For example:

>>> @tfplot.wrap
>>> def plot_imshow(img):
>>>    fig, ax = tfplot.subplots()
>>>    ax.imshow(img)
>>>    return fig
>>>
>>> plot_imshow(an_image_tensor)
Tensor("plot_imshow:0", shape=(?, ?, 4), dtype=uint8)

Or, if plot_func is a Python function that takes numpy arrays as input and draws a plot by returning a matplotlib Figure, we can wrap this function as a Tensor factory, such as:

>>> tf_plot = tfplot.wrap(plot_func, name="MyPlot", batch=True)
>>> # x, y = get_batch_inputs(batch_size=4, ...)
>>> plot_x = tf_plot(x)
Tensor("MyPlot:0", shape=(4, ?, ?, 4), dtype=uint8)
>>> plot_y = tf_plot(y)
Tensor("MyPlot_1:0", shape=(4, ?, ?, 4), dtype=uint8)
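The names MyPlot:0 and MyPlot_1:0 above come from TensorFlow making op names unique within a graph by appending a counter. The same happens with Plot_1, Plot_2 and so on earlier in this document. Below is a simplified, stdlib-only model of that scheme. `UniqueNamer` is hypothetical and ignores details of TensorFlow's real implementation, such as nested name scopes.

```python
class UniqueNamer:
    """Hypothetical model: the first use of a name is kept, later uses get _1, _2, ..."""

    def __init__(self):
        self._counts = {}

    def unique_name(self, name):
        count = self._counts.get(name, 0)
        self._counts[name] = count + 1
        return name if count == 0 else "%s_%d" % (name, count)


namer = UniqueNamer()
print([namer.unique_name("MyPlot") for _ in range(3)])  # ['MyPlot', 'MyPlot_1', 'MyPlot_2']
```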

Parameters:

  • plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details.
  • batch – If True, all the tensors passed as argument will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of plot_func will be used.
  • kwargs – Optional kwargs that will be passed by default to plot_func when executed inside a TensorFlow graph.

Returns: A python function that will create a TensorFlow plot operation, passing the provided arguments.

tfplot.wrap_axesplot(axesplot_func, _sentinel=None, batch=False, name=None, figsize=None, tight_layout=False, **kwargs)

DEPRECATED: Use tfplot.autowrap() instead. Will be removed in the next version.

Wrap an axesplot function as a TensorFlow operation. It will return a python function that creates a TensorFlow plot operation applying the arguments as input.

An axesplot function axesplot_func can be either:

  • an unbound method of the matplotlib Axes (or AxesSubplot) class, such as Axes.scatter() and Axes.text(), or
  • a simple Python function that takes the named argument ax, of type Axes or AxesSubplot, on which the plot will be drawn. A good example of this family is seaborn.heatmap(ax=...).
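The first case works because an unbound method such as Axes.scatter is an ordinary function whose first parameter is the instance. So `Axes.scatter(ax, x, y)` is equivalent to `ax.scatter(x, y)`. Below is a stdlib sketch that handles both cases with one hypothetical helper, `call_axesplot`, with `ToyAxes` standing in for matplotlib's Axes. Telling the two forms apart by checking for a parameter named `ax` is an assumption of this sketch, not tfplot's code.

```python
import inspect


class ToyAxes:
    """Hypothetical stand-in for matplotlib.axes.Axes."""
    def __init__(self):
        self.drawn = []

    def scatter(self, x, y):
        self.drawn.append(("scatter", x, y))


def heatmap_like(data, ax=None, cmap=None):
    """Stand-in for a seaborn-style function that takes the target axes as `ax=`."""
    ax.drawn.append(("heatmap", data, cmap))
    return ax


def call_axesplot(axesplot_func, ax, *args, **kwargs):
    """Draw on `ax` with either an unbound method or a function taking `ax=`."""
    if "ax" in inspect.signature(axesplot_func).parameters:
        return axesplot_func(*args, ax=ax, **kwargs)
    return axesplot_func(ax, *args, **kwargs)  # unbound method: ax plays `self`


ax = ToyAxes()
call_axesplot(ToyAxes.scatter, ax, [1, 2], [3, 4])
call_axesplot(heatmap_like, ax, [[0.5]], cmap="jet")
print(ax.drawn)
```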

The resulting function can be used as a Tensor factory. When the created tensorflow plot op is being executed, a new matplotlib figure which consists of a single AxesSubplot will be created, and the axes plot will be used as an argument for axesplot_func. For example,

>>> import seaborn.apionly as sns
>>> tf_heatmap = tfplot.wrap_axesplot(sns.heatmap, name="HeatmapPlot", figsize=(4, 4), cmap='jet')

>>> plot_op = tf_heatmap(attention_map)
Tensor("HeatmapPlot:0", shape=(?, ?, 4), dtype=uint8)

Parameters:

  • axesplot_func – An unbound method of matplotlib Axes or AxesSubplot, or a Python function or callable that has the ax parameter for specifying the axes to draw on.
  • batch – If True, all the tensors passed as argument will be assumed to be batched. Default value is False.
  • name – A default name for the operation (optional). If not given, the name of axesplot_func will be used.
  • figsize – The figure size for the figure to be created.
  • tight_layout – If True, the resulting figure will have no margins for axis. Equivalent to calling fig.subplots_adjust(0, 0, 1, 1).
  • kwargs – Optional kwargs that will be passed by default to axesplot_func.

Returns: A Python function that will create a TensorFlow plot operation, passing the provided arguments and a new instance of AxesSubplot into axesplot_func.

2. Raw Plot Ops

tfplot.plot(plot_func, in_tensors, name='Plot', **kwargs)

Create a TensorFlow op which draws a plot in an image. The resulting image is a 3-D uint8 tensor.

Given a Python function plot_func, which takes numpy arrays as its inputs (the evaluations of in_tensors) and returns a matplotlib Figure object as its output, wrap this function as a TensorFlow op. The returned figure will be rendered as an RGBA image upon execution.

Parameters:

  • plot_func – A Python function or callable that accepts numpy ndarray objects as arguments, matching the corresponding tf.Tensor objects in in_tensors. It should return a new instance of matplotlib.figure.Figure, which contains the resulting plot image.
  • in_tensors – A list of tf.Tensor objects.
  • name – A name for the operation (optional).
  • kwargs – Additional keyword arguments passed to plot_func (optional).

Returns: A single uint8 Tensor of shape (?, ?, 4), containing the plot image that plot_func computes.

tfplot.plot_many(plot_func, in_tensors, name='PlotMany', max_outputs=None, **kwargs)

A batch version of plot. Create a TensorFlow op which draws a plot for each image. The resulting images are given in a 4-D uint8 Tensor of shape [batch_size, height, width, 4].

Parameters:

  • plot_func – A Python function or callable that accepts numpy ndarray objects as arguments, matching the corresponding tf.Tensor objects in in_tensors. It should return a new instance of matplotlib.figure.Figure, which contains the resulting plot image. The shape (height, width) of the generated figure should be the same for every plot.
  • in_tensors – A list of tf.Tensor objects.
  • name – A name for the operation (optional).
  • max_outputs – Max number of batch elements to generate plots for (optional).
  • kwargs – Additional keyword arguments passed to plot_func (optional).

Returns: A single uint8 Tensor of shape (B, ?, ?, 4), containing the B plot images, each of which is computed by plot_func. B is batch_size (the number of batch elements in each tensor from in_tensors) or max_outputs, whichever is smaller.
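Two rules govern the batch output. B = min(batch_size, max_outputs), and every per-example image must have the same (height, width) so that the results can be stacked into one 4-D tensor. Below is a framework-free sketch of both rules. `plot_many_sketch` is hypothetical, and its plot functions return plain (height, width) shapes instead of images. Raising an error on mismatched shapes is this sketch's choice.

```python
def plot_many_sketch(plot_func, in_batches, max_outputs=None, **kwargs):
    """Call plot_func once per batch element and check that output shapes agree."""
    batch_size = len(in_batches[0])
    if any(len(b) != batch_size for b in in_batches):
        raise ValueError("all inputs must share the same batch size")
    n = batch_size if max_outputs is None else min(batch_size, max_outputs)

    outputs = [plot_func(*(b[i] for b in in_batches), **kwargs) for i in range(n)]
    shapes = set(outputs)
    if len(shapes) > 1:
        raise ValueError("per-example images differ in (height, width): %s" % sorted(shapes))
    return outputs


def fixed_size_plot(attention, image, figsize=(4, 4), dpi=72):
    # Stand-in plot function: returns only the (height, width) it would render.
    return (figsize[1] * dpi, figsize[0] * dpi)


attentions, images = [0, 1, 2, 3], ["a", "b", "c", "d"]
print(len(plot_many_sketch(fixed_size_plot, [attentions, images])))                 # 4
print(len(plot_many_sketch(fixed_size_plot, [attentions, images], max_outputs=2)))  # 2
```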


2. tfplot.figure

Figure utilities.

tfplot.figure.to_array(fig)

Convert a matplotlib figure fig into a 3D numpy array.

Example

>>> fig, ax = tfplot.subplots(figsize=(4, 4))
>>> # draw whatever, e.g. ax.text(0.5, 0.5, "text")

>>> im = to_array(fig)   # ndarray [288, 288, 4]

Parameters:

fig – A matplotlib.figure.Figure object.

Returns:

A numpy ndarray of shape (?, ?, 4), containing an RGB-A image of the figure.

tfplot.figure.to_summary(fig, tag)

Convert a matplotlib figure fig into a TensorFlow Summary object that can be fed directly into tf.summary.FileWriter.

Example

>>> fig, ax = ...    # (as above)
>>> summary = to_summary(fig, tag='MyFigure/image')

>>> type(summary)
tensorflow.core.framework.summary_pb2.Summary
>>> summary_writer.add_summary(summary, global_step=global_step)

Parameters:

  • fig – A matplotlib.figure.Figure object.
  • tag (string) – The tag name of the created summary.

Returns: A TensorFlow Summary protobuf object containing the plot image as an image summary.
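An image summary carries the plot as encoded PNG bytes; the execute_op_as_image helper in the setup section reads them from encoded_image_string. As a stdlib-only illustration of what those bytes look like, the sketch below encodes a tiny RGBA PNG and reads the width and height back from its IHDR chunk. `encode_rgba_png` and `png_size` are hypothetical helpers, not part of tfplot or TensorFlow.

```python
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def _chunk(kind, data):
    """Build one PNG chunk: length, type, data, CRC-32 over type + data."""
    return (struct.pack(">I", len(data)) + kind + data
            + struct.pack(">I", zlib.crc32(kind + data) & 0xFFFFFFFF))


def encode_rgba_png(width, height, rgba=(255, 0, 0, 255)):
    """Encode a solid-color 8-bit RGBA image as PNG bytes."""
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 6, 0, 0, 0)  # depth 8, color type 6 = RGBA
    row = b"\x00" + bytes(rgba) * width                           # filter byte 0, then pixels
    idat = zlib.compress(row * height)
    return PNG_SIGNATURE + _chunk(b"IHDR", ihdr) + _chunk(b"IDAT", idat) + _chunk(b"IEND", b"")


def png_size(data):
    """Return (width, height) from the IHDR chunk that must follow the signature."""
    if not data.startswith(PNG_SIGNATURE):
        raise ValueError("not a PNG")
    length, kind = struct.unpack(">I4s", data[8:16])
    if kind != b"IHDR" or length != 13:
        raise ValueError("missing IHDR chunk")
    return struct.unpack(">II", data[16:24])


png = encode_rgba_png(288, 216)
print(png_size(png))  # (288, 216)
```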

tfplot.figure.subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw)

Create a figure and a set of subplots, as in pyplot.subplots().

It works much like pyplot.subplots(), but differs in that it has none of pyplot's side effects (e.g. modifying thread state such as the current figure or current subplot).

(docstrings inherited from matplotlib.pyplot.subplots)

Parameters:

  • nrows, ncols – Number of rows/columns of the subplot grid.
  • sharex, sharey – Controls sharing of properties among x (sharex) or y (sharey) axes:
    • True or ‘all’: x- or y-axis will be shared among all subplots.
    • False or ‘none’: each subplot x- or y-axis will be independent.
    • ‘row’: each subplot row will share an x- or y-axis.
    • ‘col’: each subplot column will share an x- or y-axis.

    When subplots have a shared x-axis along a column, only the x tick labels of the bottom subplot are created. Similarly, when subplots have a shared y-axis along a row, only the y tick labels of the first column subplot are created. To later turn other subplots’ ticklabels on, use tick_params().

  • squeeze (bool, optional, default: True) –
    • If True, extra dimensions are squeezed out from the returned array of Axes:
      • if only one subplot is constructed (nrows=ncols=1), the resulting single Axes object is returned as a scalar.
      • for Nx1 or 1xM subplots, the returned object is a 1D numpy object array of Axes objects.
      • for NxM, subplots with N>1 and M>1 are returned as a 2D array.
    • If False, no squeezing at all is done: the returned Axes object is always a 2D array containing Axes instances, even if it ends up being 1x1.
  • subplot_kw (dict, optional) – Dict with keywords passed to the add_subplot() call used to create each subplot.
  • gridspec_kw (dict, optional) – Dict with keywords passed to the GridSpec constructor used to create the grid the subplots are placed on.
  • **fig_kw – All additional keyword arguments are passed to the figure() call.
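The sharing behavior can be checked directly in matplotlib. This sketch assumes matplotlib ≥ 2.1 and calls `Figure.subplots` rather than tfplot. With `sharex='col'`, axes in the same column are linked, so setting the x-limits on one also updates the other. The neighboring column keeps its own limits.

```python
from matplotlib.figure import Figure

fig = Figure()
axes = fig.subplots(2, 2, sharex='col')  # 2x2 grid; x-axis shared per column

# Change the x-limits of the top-left subplot only.
axes[0, 0].set_xlim(0, 5)

# The bottom-left subplot is in the same column, so it follows the change;
# the top-right subplot is in another column and keeps its own limits.
same_column = axes[1, 0].get_xlim()
other_column = axes[0, 1].get_xlim()
print(same_column, other_column)
```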

Returns:

  • fig (matplotlib.figure.Figure object)
  • ax (Axes object or array of Axes objects.) – ax can be either a single matplotlib.axes.Axes object or an array of Axes objects if more than one subplot was created. The dimensions of the resulting array can be controlled with the squeeze keyword, see above.
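You can see how `squeeze` shapes the returned object with a minimal sketch that calls matplotlib's `Figure.subplots` directly. It assumes matplotlib ≥ 2.1, and tfplot is not required. Each grid is built on its own figure so the results do not overlap.

```python
from matplotlib.axes import Axes
from matplotlib.figure import Figure

single = Figure().subplots(1, 1)               # 1x1, squeezed: a bare Axes
row = Figure().subplots(1, 3)                  # 1xM: 1-D array of Axes
grid = Figure().subplots(2, 2)                 # NxM: 2-D array of Axes
kept = Figure().subplots(1, 1, squeeze=False)  # never squeezed: 2-D array

print(isinstance(single, Axes))                # True
print(row.shape, grid.shape, kept.shape)       # (3,) (2, 2) (1, 1)
```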

Examples

First create some toy data:

>>> import numpy as np
>>> import tfplot
>>> x = np.linspace(0, 2*np.pi, 400)
>>> y = np.sin(x**2)

Create a figure with a single subplot:

>>> fig, ax = tfplot.subplots()
>>> ax.plot(x, y)
>>> ax.set_title('Simple plot')

Create two subplots and unpack the output array immediately:

>>> f, (ax1, ax2) = tfplot.subplots(1, 2, sharey=True)
>>> ax1.plot(x, y)
>>> ax1.set_title('Sharing Y axis')
>>> ax2.scatter(x, y)

Create four polar axes and access them through the returned array:

>>> fig, axes = tfplot.subplots(2, 2, subplot_kw=dict(polar=True))
>>> axes[0, 0].plot(x, y)
>>> axes[1, 1].scatter(x, y)

Share the x-axis within each column of subplots:

>>> tfplot.subplots(2, 2, sharex='col')

Share the y-axis within each row of subplots:

>>> tfplot.subplots(2, 2, sharey='row')

Share both the x- and y-axes among all subplots:

>>> tfplot.subplots(2, 2, sharex='all', sharey='all')

Note that this is the same as

>>> tfplot.subplots(2, 2, sharex=True, sharey=True)

3. tfplot.contrib

Some predefined plot functions.

tfplot.contrib.probmap(*args, **kwargs_call)

Display a heatmap in color. The resulting op will be an RGBA image Tensor.

Parameters:

  • x – A 2-D image-like tensor to draw.
  • cmap – Matplotlib colormap. Defaults to ‘jet’.
  • axis – If True (default), x-axis and y-axis will appear.
  • colorbar – If True (default), a colorbar will be placed on the right.
  • vmin – A scalar. Minimum value of the range. See matplotlib.axes.Axes.imshow.
  • vmax – A scalar. Maximum value of the range. See matplotlib.axes.Axes.imshow.

Returns: A uint8 Tensor of shape (?, ?, 4) containing the resulting plot.
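As a rough illustration of what probmap draws, the following sketch uses plain matplotlib. It renders a 2-D array as a heatmap with optional axes and colorbar, then returns the pixels as a NumPy RGBA array. `probmap_rgba`, its `figsize`/`dpi` defaults, and the NumPy return type are assumptions made for this example. The real tfplot.contrib.probmap builds a TensorFlow op instead.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def probmap_rgba(x, cmap='jet', axis=True, colorbar=True,
                 vmin=None, vmax=None, figsize=(4, 3), dpi=50):
    """Hypothetical: draw a 2-D array as a heatmap, return (H, W, 4) uint8."""
    x = np.asarray(x)
    if x.ndim != 2:
        raise ValueError('expected a 2-D array, got shape %s' % (x.shape,))
    fig = Figure(figsize=figsize, dpi=dpi)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(1, 1, 1)
    im = ax.imshow(x, cmap=cmap, vmin=vmin, vmax=vmax)
    if not axis:
        ax.axis('off')           # hide the x- and y-axis
    if colorbar:
        fig.colorbar(im, ax=ax)  # colorbar placed to the right of the image
    canvas.draw()
    return np.array(canvas.buffer_rgba())


attention = np.random.rand(16, 16)
image = probmap_rgba(attention, vmin=0.0, vmax=1.0)
print(image.shape, image.dtype)  # (150, 200, 4) uint8
```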

tfplot.contrib.probmap_simple(x, **kwargs)

Display a heatmap in color, showing only the image content. The resulting op will be an RGBA image Tensor.

It is equivalent to probmap with colorbar and axis turned off. See the documentation of probmap for available arguments.
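Sometimes only the colored image content is wanted. One alternative is to skip figure rendering and look the values up in a colormap. This technique is swapped in here for illustration; it is not how probmap_simple is implemented. The sketch assumes matplotlib ≥ 3.5 for the `matplotlib.colormaps` registry, and `colorize` is a hypothetical helper. Unlike a rendered plot, the result has exactly one pixel per input element.

```python
import matplotlib
import numpy as np
from matplotlib.colors import Normalize


def colorize(x, cmap='jet', vmin=None, vmax=None):
    """Hypothetical: map a 2-D array to an (H, W, 4) uint8 RGBA image."""
    x = np.asarray(x, dtype=float)
    norm = Normalize(vmin=vmin, vmax=vmax)  # None: use the data min / max
    colormap = matplotlib.colormaps[cmap]   # registry lookup (matplotlib >= 3.5)
    return colormap(norm(x), bytes=True)    # bytes=True gives uint8 in [0, 255]


rgba = colorize(np.arange(12).reshape(3, 4))
print(rgba.shape, rgba.dtype)  # (3, 4, 4) uint8
```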

tfplot.contrib.batch(func)

Make an autowrapped plot function (… -> RGBA tf.Tensor) work in batch mode.

Example

>>> p
Tensor("p:0", shape=(batch_size, 16, 16, 4), dtype=uint8)
>>> tfplot.contrib.batch(tfplot.contrib.probmap)(p)
Tensor("probmap/PlotImages:0", shape=(batch_size, ?, ?, 4), dtype=uint8)
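Conceptually, batch mode applies the per-example plot function to each element along the leading (batch) axis, then stacks the resulting RGBA images. The sketch below mimics that with NumPy only. `batch` and the toy plot function `solid_rgba` are hypothetical stand-ins, not tfplot's implementation, which builds TensorFlow ops.

```python
import numpy as np


def batch(func):
    """Hypothetical NumPy stand-in: map func over the leading (batch) axis."""
    def batched(*inputs, **kwargs):
        sizes = {len(x) for x in inputs}
        if len(sizes) != 1:
            raise ValueError('inputs must share one batch size, got %s' % sorted(sizes))
        # zip(*inputs) yields the i-th element of every input as one tuple.
        images = [func(*elems, **kwargs) for elems in zip(*inputs)]
        return np.stack(images)  # shape: (batch_size, H, W, 4)
    return batched


def solid_rgba(x, height=8, width=8):
    """Toy plot function: an opaque gray image whose level is mean(x)."""
    level = np.uint8(255 * np.clip(np.mean(x), 0.0, 1.0))
    image = np.full((height, width, 4), level, dtype=np.uint8)
    image[..., 3] = 255  # opaque alpha channel
    return image


p = np.random.rand(5, 16, 16)  # a batch of five 16x16 inputs
print(batch(solid_rgba)(p).shape)  # (5, 8, 8, 4)
```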

4. tfplot.summary

Summary Op utilities.

tfplot.summary.wrap(plot_func, _sentinel=None, batch=False, name=None, **kwargs)

Wrap a plot function as a TensorFlow summary builder. It returns a Python function that creates a TensorFlow op which evaluates to a Summary protocol buffer containing an image.

The resulting function (say summary_wrapped) will have the following signature:

summary_wrapped(name, tensor, # [more input tensors ...],
                max_outputs=3, collections=None)

Examples

Given a plot function which returns a matplotlib Figure,

>>> def figure_heatmap(data, cmap='jet'):
...     fig, ax = tfplot.subplots()
...     ax.imshow(data, cmap=cmap)
...     return fig

we can wrap it as a summary builder function:

>>> summary_heatmap = tfplot.summary.wrap(figure_heatmap, batch=True)

Now, when building your computation graph, call it to build summary ops like tf.summary.image:

>>> heatmap_tensor
<tf.Tensor 'heatmap_tensor:0' shape=(16, 128, 128) dtype=float32>
>>>
>>> summary_heatmap("heatmap/original", heatmap_tensor)
>>> summary_heatmap("heatmap/cmap_gray", heatmap_tensor, cmap='gray')
>>> summary_heatmap("heatmap/no_default_collections", heatmap_tensor, collections=[])

Parameters:

  • plot_func – A python function or callable to wrap. See the documentation of tfplot.plot() for details.
  • batch – If True, all the tensors passed as argument will be assumed to be batched. Default value is False.
  • name – A default name for the plot op (optional). If not given, the name of plot_func will be used.
  • kwargs – Optional keyword arguments that will be passed by default to plot().

Returns: A python function that will create a TensorFlow summary operation, passing the provided arguments into plot op.
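The wrap-once, call-many pattern can be sketched in plain Python. Keyword arguments given at wrap time act as defaults, and per-call keyword arguments override them. In batch mode, only the first `max_outputs` elements are plotted. `wrap_plot` and `describe_heatmap` are hypothetical. Instead of a TensorFlow summary op, the wrapped function simply returns the summary name and the plot results.

```python
import itertools


def wrap_plot(plot_func, batch=False, **default_kwargs):
    """Hypothetical sketch: return a builder that calls plot_func with defaults."""

    def summary_wrapped(name, *tensors, max_outputs=3, **kwargs):
        call_kwargs = {**default_kwargs, **kwargs}  # per-call values win
        if batch:
            # Pair up the i-th element of every input; plot at most max_outputs.
            elements = itertools.islice(zip(*tensors), max_outputs)
            results = [plot_func(*elems, **call_kwargs) for elems in elements]
        else:
            results = [plot_func(*tensors, **call_kwargs)]
        return name, results

    return summary_wrapped


def describe_heatmap(data, cmap='jet'):
    """Toy plot function: returns a description instead of a figure."""
    return '%s heatmap of %r' % (cmap, data)


summary_heatmap = wrap_plot(describe_heatmap, batch=True, cmap='viridis')
name, plots = summary_heatmap('heatmap/cmap_gray', [1, 2, 3, 4, 5], cmap='gray')
print(name, len(plots))  # heatmap/cmap_gray 3
print(plots[0])          # gray heatmap of 1
```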

tfplot.summary.plot(name, plot_func, in_tensors, collections=None, **kwargs)

Create a TensorFlow op that outputs a Summary protocol buffer containing the image produced by a single plot operation (i.e. an image summary).

It is essentially a one-line wrapper around tfplot.ops.plot() and tf.summary.image().

The generated Summary object contains a single image summary value holding the rendered plot.

Parameters:

  • name – The name of scope for the generated ops and the summary op. Will also serve as a series name prefix in TensorBoard.
  • plot_func – A python function or callable, specifying the plot operation as in tfplot.plot(). See the documentation at tfplot.plot().
  • in_tensors – A list of Tensor objects, as in plot().
  • collections – Optional list of ops.GraphKeys. The collections to add the summary to. Defaults to [_ops.GraphKeys.SUMMARIES].
  • kwargs – Optional keyword arguments passed to plot().

Returns: A scalar Tensor of type string. The serialized Summary protocol buffer (tensorflow operation).
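An image summary stores each plot as encoded PNG bytes, because TensorFlow's tf.summary.image encodes images as PNG. The hypothetical helper below sketches only the rendering half of that pipeline and assumes nothing beyond matplotlib: it turns a figure into PNG bytes. Building the actual Summary protocol buffer is left to TensorFlow.

```python
import io

from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def figure_to_png(fig):
    """Hypothetical helper: encode a matplotlib Figure as PNG bytes."""
    FigureCanvasAgg(fig)  # attach an off-screen Agg canvas (no pyplot involved)
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=fig.dpi)
    return buf.getvalue()


fig = Figure(figsize=(2, 2), dpi=50)  # 100 x 100 pixels
ax = fig.add_subplot(1, 1, 1)
ax.plot([0, 1, 2], [0, 1, 4])
png = figure_to_png(fig)
print(png[:8])  # b'\x89PNG\r\n\x1a\n'
```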

tfplot.summary.plot_many(name, plot_func, in_tensors, max_outputs=3, collections=None, **kwargs)

Create a TensorFlow op that outputs a Summary protocol buffer in which plots are drawn in batch mode. This is the batch version of tfplot.summary.plot().

Specifically, all the input tensors in_tensors passed to plot_func are assumed to have the same batch size. For each batch element, the corresponding slices of the input tensors are passed to plot_func as input.

The resulting Summary contains multiple (up to max_outputs) image summary values, each of which contains a plot rendered by plot_func.

Parameters:

  • name – The name of scope for the generated ops and the summary op. Will also serve as a series name prefix in TensorBoard.
  • plot_func – A python function or callable, specifying the plot operation as in tfplot.plot(). See the documentation at tfplot.plot().
  • in_tensors – A list of Tensor objects, the input to plot_func but each in a batch.
  • max_outputs – Max number of batch elements to generate plots for.
  • collections – Optional list of ops.GraphKeys. The collections to add the summary to. Defaults to [_ops.GraphKeys.SUMMARIES].
  • kwargs – Optional keyword arguments passed to plot().

Returns: A scalar Tensor of type string. The serialized Summary protocol buffer (tensorflow operation).
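The batch behavior of plot_many can be sketched the same way. Slice every input along its leading axis, render one figure per batch element, and keep at most `max_outputs` images. `plot_many_png` and `figure_heatmap_np` are hypothetical names. The plot function follows the convention of the tfplot.summary.wrap example and returns a matplotlib Figure. However, it builds that figure with `matplotlib.figure.Figure` instead of tfplot.subplots.

```python
import io
import itertools

import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def plot_many_png(plot_func, in_tensors, max_outputs=3, **kwargs):
    """Hypothetical: render plot_func per batch element, up to max_outputs PNGs."""
    if len({len(t) for t in in_tensors}) != 1:
        raise ValueError('all inputs must have the same batch size')
    pngs = []
    # zip(*in_tensors) yields the i-th element of every input as one tuple.
    for elems in itertools.islice(zip(*in_tensors), max_outputs):
        fig = plot_func(*elems, **kwargs)
        FigureCanvasAgg(fig)  # off-screen canvas for rendering
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=fig.dpi)
        pngs.append(buf.getvalue())
    return pngs


def figure_heatmap_np(data, cmap='jet'):
    """Per-element plot function that returns a matplotlib Figure."""
    fig = Figure(figsize=(2, 2), dpi=50)
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(data, cmap=cmap)
    return fig


heatmaps = np.random.rand(5, 8, 8)  # a batch of five 8x8 maps
pngs = plot_many_png(figure_heatmap_np, [heatmaps], max_outputs=3, cmap='gray')
print(len(pngs))  # 3
```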

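This article is part of the Tencent Cloud self-media sharing program and is shared from the author's personal site/blog.
Originally published: August 11, 2019. In case of infringement, contact cloudcommunity@tencent.com for removal.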
