首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Python: tqdm没有显示进度条

Python: tqdm没有显示进度条
EN

Stack Overflow用户
提问于 2021-05-14 13:20:31
回答 3查看 8.1K关注 0票数 8

我已经为我的网络的PyTorch函数编写了fit代码。但是当我在循环中使用tqdm时,它不会从0%增加到我无法理解的原因。

以下是代码:

代码语言:javascript
运行
复制
from tqdm.notebook import tqdm

def fit(model, train_dataset, val_dataset, epochs=1, batch_size=32, warmup_prop=0, lr=5e-5):

    device = torch.device('cuda:1')
    model.to(device)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    optimizer = AdamW(model.parameters(), lr=lr)
    
    num_warmup_steps = int(warmup_prop * epochs * len(train_loader))
    num_training_steps = epochs * len(train_loader)
    
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

    loss_fct = nn.BCEWithLogitsLoss(reduction='mean').to(device)
    
    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        
        optimizer.zero_grad()
        avg_loss = 0
        
        for step, (x, y_batch) in tqdm(enumerate(train_loader), total=len(train_loader)): 
            y_pred = model(x.to(device))
            
            loss = loss_fct(y_pred.view(-1).float(), y_batch.float().to(device))
            loss.backward()
            avg_loss += loss.item() / len(train_loader)


            optimizer.step()
            scheduler.step()
            model.zero_grad()
            optimizer.zero_grad()
                
        model.eval()
        preds = []
        truths = []
        avg_val_loss = 0.

        with torch.no_grad():
            for x, y_batch in val_loader:                
                y_pred = model(x.to(device))
                loss = loss_fct(y_pred.detach().view(-1).float(), y_batch.float().to(device))
                avg_val_loss += loss.item() / len(val_loader)
                
                probs = torch.sigmoid(y_pred).detach().cpu().numpy()
                preds += list(probs.flatten())
                truths += list(y_batch.numpy().flatten())
            score = roc_auc_score(truths, preds)
            
        
        dt = time.time() - start_time
        lr = scheduler.get_last_lr()[0]
        print(f'Epoch {epoch + 1}/{epochs} \t lr={lr:.1e} \t t={dt:.0f}s \t loss={avg_loss:.4f} \t val_loss={avg_val_loss:.4f} \t val_auc={score:.4f}')

输出

使用所需参数执行fit函数后的输出如下所示:

0%| | 0/6986 [00:00<?, ?it/s]

怎么解决这个问题?

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2021-05-14 14:55:34

当你从tqdm.notebook进口的时候,这意味着你使用的是木星笔记本,对吗?如果不是,你必须做from tqdm import tqdm

我简化了示例代码,使其变得非常小,如下所示:

代码语言:javascript
运行
复制
import time
from tqdm.notebook import tqdm

l = [None] * 10000

for i, e in tqdm(enumerate(l), total = len(l)): 
    time.sleep(0.01)

并在Google Colab jupyter笔记本上执行。它给我展示了这样一个很棒的进度条:

因此,这意味着tqdm在笔记本模式下正确工作。因此,您的可迭代或循环代码有一些问题,而不是tqdm。可能的原因可能是您的内部循环需要很长时间,所以即使是1次迭代(在您的情况下总共6986次迭代中)也要花费很长时间,而且不会在进度栏中显示。

另一个原因是您的可迭代性要花费很长时间才能生成第二个元素,而且您还必须检查它是否有效。

我还看到您向我们展示了ASCII进度条,它不是笔记本中通常显示的进度条(笔记本通常显示图形条)。所以也许你根本不在笔记本里?然后,您必须执行from tqdm import tqdm而不是from tqdm.notebook import tqdm

另外,首先尝试简化您的代码(只是暂时的),以确定在您的情况下,原因是否真的与tqdm模块有关,而不是使用您的可迭代或循环代码。试着从我上面提供的代码开始。

另外,与tqdm相比,只在循环中打印类似print(step)的内容,它是否至少在屏幕上打印两行?

如果在我的代码中我执行from tqdm import tqdm,然后在控制台Python中执行它,那么我得到:

代码语言:javascript
运行
复制
10%|███████████▉              | 950/10000 [00:14<02:20, 64.37it/s]

这意味着控制台版本也能工作。

票数 7
EN

Stack Overflow用户

发布于 2021-09-06 12:52:12

这可能发生在木星,如果笔记本不可信-如果是这样,点击右上角的“不可信”框。

票数 3
EN

Stack Overflow用户

发布于 2021-11-21 13:03:13

这是因为在终端环境中使用from tqdm.notebook import tqdm而不是from tqdm import tqdm

提供一个示例来说明这个问题:

代码语言:javascript
运行
复制
from tqdm.notebook import tqdm

if __name__ == '__main__':
    data = range(10000)
    for i, item in enumerate(tqdm(range(len(data)))):
        i = i + 1

它将显示:

但是,如果使用此代码示例:

代码语言:javascript
运行
复制
from tqdm import tqdm

if __name__ == '__main__':
    data = range(10000)
    for i, item in enumerate(tqdm(range(len(data)))):
        i = i + 1

它将展示:

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67535060

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档