首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >三维张量与四维张量的比较

三维张量与四维张量的比较
EN

Stack Overflow用户
提问于 2016-09-07 18:21:54
回答 1查看 756关注 0票数 2

我有以下的U-Net,我使用它来分割灰度PNG图像.

代码语言:javascript
复制
import cv2
import os
from sklearn.utils import shuffle
import tensorflow as tf
import numpy as np


OVERALLSIZE = int(float(input('Choose the number of images you want (<5635) : ')))
PATH = input('give absolute path to image')
TESTSIZE = int(float(input('Choose the number of the data you want to use as test (<5635) : ')))

######################################################################################################################

images = [img for img in os.listdir(PATH + '/Xtrain') if img.endswith('png')]
# put random_state to 1
images = shuffle(images,random_state = 0)
masks = [name[:-4]+'_mask.png' for name in images]

images, masks = images[:OVERALLSIZE], masks[:OVERALLSIZE]
images_, masks_  = [cv2.imread(PATH + '/Xtrain/' + img, cv2.IMREAD_GRAYSCALE).astype(np.int) for img in images], \
                   [cv2.imread(PATH + '/ytrain/' + msk, cv2.IMREAD_GRAYSCALE).astype(np.int) for msk in masks]

######################################################################################################################


X_train, y_train, X_test, y_test = np.asarray(images_[TESTSIZE:])/255., \
                                   np.asarray(masks_[TESTSIZE:]), \
                                   np.asarray(images_[:TESTSIZE])/255., \
                                   np.asarray(masks_[:TESTSIZE])



x = tf.placeholder(tf.float32, shape=[None, 420, 580], name='x')
y_ = tf.placeholder(tf.float32, shape=[None, 420, 580], name='y_')

sess = tf.InteractiveSession()

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev = 0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def convoer(inputs, shape, flag):
    W = weight_variable(shape)
    b = bias_variable([shape[3]])

    temp = shape
    temp[2] = shape[3]

    Wa = weight_variable(temp)
    ba = bias_variable([shape[3]])

    conv = tf.nn.relu(conv2d(inputs, W) + b)
    conv = tf.nn.relu(conv2d(conv, Wa) + ba)
    pool = max_pool_2x2(conv)

    if flag: return pool
    elif not flag: return conv

def upconvoer(inputs, shape, height, width):
    W = weight_variable(shape)
    b = bias_variable([shape[3]])

    temp = shape
    temp[2] = shape[3]

    Wa = weight_variable(temp)
    ba = bias_variable([shape[3]])

    up = tf.image.resize_images(inputs, height, width)
    conv = tf.nn.relu(conv2d(up, W) + b)
    conv = tf.nn.relu(conv2d(conv, Wa) + ba)

    return conv

def conv2d(x, W):
    return tf.nn.conv2d(x,W,strides = [1,1,1,1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize = [1,2,2,1], strides = [1,2,2,1], padding='SAME')

def U():
    inputs = tf.reshape(x, [-1,420,580,1])

    pool1 = convoer(inputs, [3,3,1,32], True)

    pool2 = convoer(pool1, [3,3,32,64], True)
    pool3 = convoer(pool2, [3,3,64,128], True)
    pool4 = convoer(pool3, [3,3,128,256], True)
    conv5 = convoer(pool4, [3,3,256,512], False)

    conv6 = upconvoer(conv5, [3,3,512,256], 73, 53)
    conv7 = upconvoer(conv6, [3,3,256,128], 145, 105)
    conv8 = upconvoer(conv7, [3,3,128,64], 290, 210)
    conv9 = upconvoer(conv8, [3,3,64,32], 420, 580)

    W10 = weight_variable([1,1,32,1])
    b10 = bias_variable([1])

    conv10 = tf.nn.sigmoid(conv2d(conv9, W10) + b10)

    y = conv10

    return y

y = U()

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess.run(train_step, feed_dict={x: X_train, y_: y_train})

交叉熵步骤,我试图乘一个三维张量y_批,grayscale_in_height,grayscale_in_width and 4d张量y批,高度,宽度,通道。

我得到以下错误:

代码语言:javascript
复制
ValueError: Incompatible shapes for broadcasting: (?, 420, 580) and (?, 420, 580, 1)

我尝试在3个不同的位置重塑y:在U函数中,在定义它的地方,在cross_entropy中,但是没有一个能工作。

EN

回答 1

Stack Overflow用户

发布于 2022-03-01 08:24:06

在张量中加入一个tf.newaxis。

代码语言:javascript
复制
x = x[:,:, :, tf.newaxis]

或者使用tf.squeeze去掉y中的附加轴。

代码语言:javascript
复制
y = tf.squeeze(y, axis=-1)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/39376770

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档