首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >mat1 dim 1必须匹配mat2 dim 0 ?

mat1 dim 1必须匹配mat2 dim 0 ?

提问于 2022-04-25 16:34:09
回答 0关注 0查看 200

我是PyTorch新手,我一直收到错误消息mat1 dim1 must match mat1 dim0

这是我的网络代码

def train(model,dataloader,batch_size,epochs,optimizer, name='Model_avec_clusters', save_model=2):

model.train()

l=[]

n_epoch = -1

for epoch in tqdm(range(epochs)):

n_epoch += 1

running_loss = 0.0

count = 0

for traj, client, stand, taxi, week, day, quarter, targets in dataloader:

traj = traj.to(device)

targets = targets.to(device)

client = client.to(device)

stand = stand.to(device)

taxi = taxi.to(device)

week = week.to(device)

day = day.to(device)

quarter = quarter.to(device)

outputs = model(traj, client, stand, taxi, week, day, quarter)

loss = criterion(outputs.double(),targets)

optimizer.zero_grad()

loss.backward()

optimizer.step()

# statistics

running_loss += loss.data.item()

count += 1

epoch_loss = running_loss / count

l.append(epoch_loss)

if epoch%1==0:

print('Loss: {:.4f}'.format(

epoch_loss))

if epoch%save_model==0:

torch.save(model.state_dict(),f"model_cluster/{name}_{epoch}.pth")

with open(f"loss_cluster/{name}_scores.txt", 'w') as file:

file.writelines(["%s\n" % item for item in l])

plt.plot(l)

return l

l = train(network, dataloader, batch_size = batch_size, epochs=5, optimizer=optimizer)

报这个错

回答

和开发者交流更多问题细节吧,去 写回答
相关文章

相似问题

相关问答用户
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档