我是PyTorch新手，我一直收到错误消息 `RuntimeError: mat1 dim 1 must match mat2 dim 0`。
这是我的训练代码：
def train(model, dataloader, batch_size, epochs, optimizer,
          name='Model_avec_clusters', save_model=2):
    """Train ``model`` for ``epochs`` epochs and return the mean loss of each epoch.

    This function relies on names defined elsewhere in the file: ``criterion``
    (the loss function), ``device`` (where tensors are moved), ``torch``,
    ``tqdm`` and ``plt``.

    Args:
        model: the network to train. It is called with every element of a
            batch except the last, i.e.
            ``model(traj, client, stand, taxi, week, day, quarter)``.
        dataloader: iterable of batches laid out as
            ``(traj, client, stand, taxi, week, day, quarter, targets)``.
        batch_size: not used here (batching is done by ``dataloader``); kept
            so existing callers keep working.
        epochs: number of passes over ``dataloader``.
        optimizer: optimizer stepping the parameters of ``model``.
        name: prefix for the checkpoint and loss-log file names.
        save_model: write a checkpoint and the loss log every ``save_model``
            epochs, starting at epoch 0. A falsy value disables saving.

    Returns:
        list[float]: mean training loss of each epoch, in order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Note:
        This function performs no matrix multiplication itself; it only calls
        ``model`` and ``criterion``. A "mat1 dim 1 must match mat2 dim 0"
        error is therefore presumably raised inside ``model(...)``. Check that
        each linear layer's input size matches the feature size it actually
        receives.
    """
    import os  # local import: only needed to create the output directories

    model.train()
    history = []
    for epoch in tqdm(range(epochs)):
        epoch_loss = _run_epoch(model, dataloader, optimizer)
        history.append(epoch_loss)
        print('Loss: {:.4f}'.format(epoch_loss))

        # `save_model and ...` avoids ZeroDivisionError when saving is disabled.
        if save_model and epoch % save_model == 0:
            os.makedirs("model_cluster", exist_ok=True)
            os.makedirs("loss_cluster", exist_ok=True)
            torch.save(model.state_dict(), f"model_cluster/{name}_{epoch}.pth")
            # Rewrite the full history so the file always holds every epoch so far.
            with open(f"loss_cluster/{name}_scores.txt", 'w', encoding="utf-8") as file:
                file.writelines(f"{item}\n" for item in history)
    plt.plot(history)
    return history


def _run_epoch(model, dataloader, optimizer):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss."""
    running_loss = 0.0
    count = 0
    for *features, targets in dataloader:
        # Move every model input and the targets to the training device.
        features = [tensor.to(device) for tensor in features]
        targets = targets.to(device)
        outputs = model(*features)
        # Outputs are cast to float64, presumably to match the targets' dtype — confirm.
        loss = criterion(outputs.double(), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() returns a plain Python float; the .data accessor is discouraged.
        running_loss += loss.item()
        count += 1
    if count == 0:
        raise ValueError("dataloader yielded no batches")
    return running_loss / count
# Train for five epochs; `l` receives the list of per-epoch mean losses from train().
l = train(
    network,
    dataloader,
    optimizer=optimizer,
    epochs=5,
    batch_size=batch_size,
)
运行上面的代码时就会报这个错误。
相似问题