类似于《论文实践学习 - Stacked Hourglass Networks for Human Pose Estimation》,本文基于 Docker-Torch 估计人体关节点。
这里只简单进行测试估计结果,由于显存有限,未能加入所有的 scale_search.
# 输入参数有两个, 第二个参数可省略, 默认为 'mean'
th demo.lua imglist.txt 'max'
# or
th demo.lua imglist.txt
require 'paths'
paths.dofile('util.lua')
paths.dofile('img.lua')
--------------------------------------------------------------------------------
-- Initialization
--------------------------------------------------------------------------------
-- Read the batch of image file names from the list file given as arg[1].
a = loadImageNames(arg[1])
-- Load the pre-trained model, then switch it to CUDA and evaluation mode.
m = torch.load( '../checkpoints/mpii/crf_parts/model.t7')
m:cuda()
m:evaluate()

-- Parameters
-- When true, the main loop also forwards a flipped batch and averages both outputs.
local isflip = true
-- When true, the main loop subtracts 0.5 from the input batch before the forward pass.
-- (Previously spelled `tru`, an undefined global that evaluated to nil and
-- silently disabled the subtraction.)
local minusmean = true
-- Scales evaluated per image; choose the list according to available GPU memory.
local scale_search = {1.0, 1.1}
-- local scale_search = {0.7,0.8,0.9,1.0,1.1,1.2} -- used in paper with NVIDIA Titan X (12 GB memory).

-- Displays a convenient progress bar
idxs = torch.range(1, a.nsamples)
nsamples = idxs:nElement()
xlua.progress(0,nsamples)

-- Per image: 16 joints, each filled by the main loop with two peak indices
-- and the peak heatmap value.
preds = torch.Tensor(nsamples,16,3)
-- Per image: the 3x256x256 network input crop.
imgs = torch.Tensor(nsamples,3,256,256)
-- Directory the image names in `a` are resolved against.
local imgpath = '../data/image/'
--------------------------------------------------------------------------------
-- Main loop
--------------------------------------------------------------------------------
-- For every listed image: build one crop per scale, run the network (optionally
-- averaging with the flipped input), map the heatmaps back to image size, fuse
-- them across scales and record the peak of each of the 16 joint heatmaps.
for idx = 1,nsamples do
  -- Set up input image
  local imgname = paths.concat(imgpath, a['images'][idxs[idx]])
  print(imgname)
  local im = image.load(imgname)
  -- Assumes the person has already been cropped and the image resized to 256.
  local original_scale = 256/200
  local center = {128.0, 128.0}

  -- One stack of 16 image-sized heatmaps per scale, and the matching input batch.
  local hmpyra = torch.zeros(#scale_search, 16, im:size(2), im:size(3))
  local batch = torch.Tensor(#scale_search, 3, 256, 256)

  for is, factor in ipairs(scale_search) do
    local scale = original_scale*factor
    local inp = crop(im, center, scale, 0, 256)
    batch[{is, {}, {}, {}}]:copy(inp)
    -- Overwritten on every pass, so imgs[idx] ends up holding the last scale's crop.
    imgs[idx]:copy(inp)
  end

  -- minus mean
  if minusmean then
    batch:add(-0.5)
  end

  -- Get network output
  local out = m:forward(batch:cuda())

  -- Get flipped output
  if isflip then
    -- Clone first: the second forward pass below would otherwise be free to
    -- reuse the tensors returned by the first one.
    out = applyFn(function (x) return x:clone() end, out)
    local flippedOut = m:forward(flip(batch):cuda())
    flippedOut = applyFn(function (x) return flip(shuffleLR(x)) end, flippedOut)
    out = applyFn(function (x,y) return x:add(y):div(2) end, out, flippedOut)
  end
  cutorch.synchronize()

  -- Use the last network output and clamp negative responses to zero.
  local hm = out[#out]:float()
  hm[hm:lt(0)] = 0

  -- Get heatmaps (original image size)
  for is, factor in ipairs(scale_search) do
    local hm_img = getHeatmaps(im, center, original_scale*factor, 0, 256, hm[is])
    hmpyra[{is, {}, {}, {}}]:copy(hm_img:sub(1, 16))
  end

  -- fuse heatmap across scales: 'max' when requested via arg[2], mean otherwise
  local fuseHm
  if arg[2] == 'max' then
    fuseHm = hmpyra:max(1)
  else
    fuseHm = hmpyra:mean(1)
  end
  fuseHm = fuseHm[1]
  fuseHm[fuseHm:lt(0)] = 0

  -- get predictions: locate the peak of each joint heatmap
  for p = 1,16 do
    -- Reduce along dimension 2 first, then along dimension 1 of that result.
    local maxy, iy = fuseHm[p]:max(2)
    local _, ix = maxy:max(1)
    ix = torch.squeeze(ix)

    preds[idx][p][2] = ix
    preds[idx][p][1] = iy[ix]
    preds[idx][p][3] = maxy[ix]
  end

  xlua.progress(idx, nsamples)
  collectgarbage()
end
-- Persist the results: joint predictions under 'preds' and the network input
-- crops under 'imgs', both in a single HDF5 file.
local h5out = hdf5.open('../preds/preds.h5', 'w')
h5out:write('preds', preds)
h5out:write('imgs', imgs)
h5out:close()
#!/usr/bin/env python
import h5py
import scipy.misc as scm
import matplotlib.pyplot as plt
# Joint names listed in the order of their row index in a per-image pose array.
_JOINT_NAMES = [
    'r_ankle', 'r_knee', 'r_hip',
    'l_hip', 'l_knee', 'l_ankle',
    'pelvis', 'thorax', 'neck', 'head',
    'r_wrist', 'r_elbow', 'r_shoulder',
    'l_shoulder', 'l_elbow', 'l_wrist',
]

# Maps a joint name to its 0-based row index.
JointsIndex = {name: index for index, name in enumerate(_JOINT_NAMES)}

# Pairs of joints that are connected by a line segment when drawing the skeleton.
JointPairs = [
    ['head', 'neck'],
    ['neck', 'thorax'],
    ['thorax', 'r_shoulder'],
    ['thorax', 'l_shoulder'],
    ['r_shoulder', 'r_elbow'],
    ['r_elbow', 'r_wrist'],
    ['l_shoulder', 'l_elbow'],
    ['l_elbow', 'l_wrist'],
    ['pelvis', 'r_hip'],
    ['pelvis', 'l_hip'],
    ['r_hip', 'r_knee'],
    ['r_knee', 'r_ankle'],
    ['l_hip', 'l_knee'],
    ['l_knee', 'l_ankle'],
    ['thorax', 'pelvis'],
]

# Matplotlib format string for each entry of JointPairs (same order, same length).
StickType = [
    'r-', 'r-', 'g-', 'b-', 'g-',
    'g-', 'b-', 'b-', 'c-', 'm-',
    'c-', 'c-', 'm-', 'm-', 'r-',
]
# Image file names, one per line. Lines are kept exactly as read, including
# their trailing newline.
with open('../test/imglist.txt', 'r') as list_file:
    imgs = list_file.readlines()
images_path = '../data/image/'

# Read the 'preds' dataset; the context manager closes the file even if the
# read raises.
with h5py.File('preds.h5', 'r') as f:
    # Materialise the key view so it stays usable after the file is closed.
    f_keys = list(f.keys())
    #imgs = f['imgs'][:]
    preds = f['preds'][:]

# Every listed image must have exactly one prediction entry.
assert len(imgs) == len(preds)
# Show each image with its predicted skeleton drawn on top, one figure per image.
for i in range(len(imgs)):
    # Strip only the line terminator; slicing off the last character would
    # truncate the name when the final line has no trailing newline.
    filename = images_path + imgs[i].rstrip('\r\n')
    # scipy.misc.imread was removed from SciPy; pyplot's reader is used instead.
    img = plt.imread(filename)
    pose = preds[i]
    # img = imgs[i].transpose(1,2,0)
    plt.axis('off')
    plt.imshow(img)

    # for j in range(16):
    #     if pose[j][0] > 0 and pose[j][1] > 0:
    #         plt.scatter(pose[j][0], pose[j][1], marker='o', color='r', s=15)
    # plt.show()

    # Draw one segment per joint pair, skipping pairs where either endpoint
    # has a non-positive coordinate.
    for (joint_a, joint_b), stick in zip(JointPairs, StickType):
        idx1 = JointsIndex[joint_a]
        idx2 = JointsIndex[joint_b]
        if pose[idx1][0] > 0 and pose[idx1][1] > 0 and \
           pose[idx2][0] > 0 and pose[idx2][1] > 0:
            joints_x = [pose[idx1][0], pose[idx2][0]]
            joints_y = [pose[idx1][1], pose[idx2][1]]
            plt.plot(joints_x, joints_y, stick, linewidth=3)
    plt.show()
print('Done.')