
How to install torchcrf and fix the import error?

Stack Overflow user
Asked on 2022-09-27 14:12:00
Answers: 1, Views: 54, Followers: 0, Votes: 0

I am using Windows 11 Pro x64 and PyCharm 2022.2.2 (Professional Edition), Build #PY-222.4167.33, built on September 15, 2022. Python version:

Microsoft Windows [Version 10.0.22621.521]
(c) Microsoft Corporation. All rights reserved.

C:\Users\donhu>python
Python 3.10.7 (tags/v3.10.7:6cc6b13, Sep  5 2022, 14:08:36) [MSC v.1933 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>>

My code:

import argparse
import yaml
import pandas as pd
import torch
import torchcrf
import transformers
from data import Dataset
from engines import train_fn
import warnings

warnings.filterwarnings("ignore")
parser = argparse.ArgumentParser()
parser.add_argument("--data_file", type=str)
parser.add_argument("--hyps_file", type=str)
args = parser.parse_args()
data_file = yaml.load(open(args.data_file), Loader=yaml.FullLoader)
hyps_file = yaml.load(open(args.hyps_file), Loader=yaml.FullLoader)
train_loader = torch.utils.data.DataLoader(
    Dataset(
        df=pd.read_csv(data_file["train_df_path"]),
        tag_names=data_file["tag_names"],
        tokenizer=transformers.AutoTokenizer.from_pretrained(hyps_file["encoder"], use_fast=False),
    ),
    num_workers=hyps_file["num_workers"],
    batch_size=hyps_file["batch_size"],
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    Dataset(
        df=pd.read_csv(data_file["val_df_path"]),
        tag_names=data_file["tag_names"],
        tokenizer=transformers.AutoTokenizer.from_pretrained(hyps_file["encoder"], use_fast=False),
    ),
    num_workers=hyps_file["num_workers"],
    batch_size=hyps_file["batch_size"] * 2,
)
loaders = {
    "train": train_loader,
    "val": val_loader,
}
model = transformers.RobertaForTokenClassification.from_pretrained(hyps_file["encoder"],
                                                                   num_labels=data_file["num_tags"])
if hyps_file["use_crf"]:
    criterion = torchcrf.CRF(num_tags=data_file["num_tags"], batch_first=True)
else:
    criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=float(hyps_file["lr"]))
train_fn(
    loaders, model, torch.device(hyps_file["device"]), hyps_file["device_ids"],
    criterion,
    optimizer,
    epochs=hyps_file["epochs"],
    ckp_path="../ckps/{}.pt".format(hyps_file["encoder"].split("/")[-1]),
)

I installed it through PyCharm.

I also installed it from the command line:

pip install torchcrf

I also tried:

pip install pytorch-crf
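A likely explanation for the first attempt: PyPI compares project names case-insensitively after PEP 503 normalization, so pip install torchcrf installs the TorchCRF project, whose module is imported as TorchCRF. Python imports are case-sensitive (on Windows too, unless PYTHONCASEOK is set), so import torchcrf still fails. The pytorch-crf project is a different package, and it is the one that provides a module named torchcrf. The sketch below reproduces only the PEP 503 name normalization rule; the function name normalize is illustrative and not part of pip.

```python
import re


def normalize(name: str) -> str:
    """Normalize a project name the way PEP 503 does for index lookups."""
    # Runs of '-', '_' and '.' collapse to a single '-', and case is ignored.
    return re.sub(r"[-_.]+", "-", name).lower()


if __name__ == "__main__":
    # Both spellings refer to the same PyPI project (TorchCRF)...
    print(normalize("torchcrf") == normalize("TorchCRF"))  # True
    # ...which is not the pytorch-crf project.
    print(normalize("pytorch-crf") == normalize("TorchCRF"))  # False
```

Note that the import name is not derived from the project name at all; it is whatever top-level module the package ships.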

But it did not work.

How can I install torchcrf and fix the import error?
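If pip install pytorch-crf reports success but import torchcrf still fails, a common cause is that pip installed into a different interpreter than the one PyCharm runs (PyCharm projects often use their own virtual environment). A minimal, stdlib-only check, meant to be run from inside PyCharm (a Run configuration or the Python Console); the helper name import_report is hypothetical:

```python
import importlib.util
import sys


def import_report(names=("torchcrf", "TorchCRF")):
    """Report the running interpreter and whether each top-level module can be found."""
    report = {"executable": sys.executable, "version": sys.version.split()[0]}
    for name in names:
        # find_spec returns None for a missing top-level module instead of raising.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, value in import_report().items():
        print(f"{key}: {value}")
```

Running the printed executable as <executable> -m pip install pytorch-crf from a terminal installs the package into exactly the interpreter that ran the check.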


1 answer

Stack Overflow user

Posted on 2022-09-27 14:25:03

Install the TorchCRF package (pip install TorchCRF) and import CRF from it:

from TorchCRF import CRF

Then rewrite the code to use it:

import argparse
import yaml
import pandas as pd
import torch
from TorchCRF import CRF
import transformers
from data import Dataset
from engines import train_fn
import warnings

warnings.filterwarnings("ignore")
parser = argparse.ArgumentParser()
parser.add_argument("--data_file", type=str)
parser.add_argument("--hyps_file", type=str)
args = parser.parse_args()
data_file = yaml.load(open(args.data_file), Loader=yaml.FullLoader)
hyps_file = yaml.load(open(args.hyps_file), Loader=yaml.FullLoader)
train_loader = torch.utils.data.DataLoader(
    Dataset(
        df=pd.read_csv(data_file["train_df_path"]),
        tag_names=data_file["tag_names"],
        tokenizer=transformers.AutoTokenizer.from_pretrained(hyps_file["encoder"], use_fast=False),
    ),
    num_workers=hyps_file["num_workers"],
    batch_size=hyps_file["batch_size"],
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    Dataset(
        df=pd.read_csv(data_file["val_df_path"]),
        tag_names=data_file["tag_names"],
        tokenizer=transformers.AutoTokenizer.from_pretrained(hyps_file["encoder"], use_fast=False),
    ),
    num_workers=hyps_file["num_workers"],
    batch_size=hyps_file["batch_size"] * 2,
)
loaders = {
    "train": train_loader,
    "val": val_loader,
}
model = transformers.RobertaForTokenClassification.from_pretrained(hyps_file["encoder"],
                                                                   num_labels=data_file["num_tags"])
if hyps_file["use_crf"]:
    criterion = CRF(num_tags=data_file["num_tags"], batch_first=True)
else:
    criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=float(hyps_file["lr"]))
train_fn(
    loaders, model, torch.device(hyps_file["device"]), hyps_file["device_ids"],
    criterion,
    optimizer,
    epochs=hyps_file["epochs"],
    ckp_path="../ckps/{}.pt".format(hyps_file["encoder"].split("/")[-1]),
)
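The rewritten call still passes num_tags= and batch_first=True. Those are the constructor arguments of the pytorch-crf package's torchcrf.CRF (CRF(num_tags, batch_first=False)); this answer does not confirm that TorchCRF's CRF accepts the same keywords, so it is worth checking the installed class before training. A small sketch using inspect.signature; the helper name crf_init_params is hypothetical:

```python
import importlib
import inspect


def crf_init_params(module_name: str):
    """Return the constructor parameter names of module_name.CRF, or None if the module is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return None
    # inspect.signature on a class reports __init__'s parameters without `self`.
    return list(inspect.signature(module.CRF).parameters)


if __name__ == "__main__":
    for name in ("torchcrf", "TorchCRF"):
        print(name, crf_init_params(name))
```

If the names printed for the installed package differ from the keywords used above, the CRF(...) call needs to be adjusted to match that package.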

Reference: https://pypi.org/project/TorchCRF/

Votes: 0
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/73869068
