1. <ul id="0c1fb"></ul>

      <noscript id="0c1fb"><video id="0c1fb"></video></noscript>
      <noscript id="0c1fb"><listing id="0c1fb"><thead id="0c1fb"></thead></listing></noscript>

      99热在线精品一区二区三区_国产伦精品一区二区三区女破破_亚洲一区二区三区无码_精品国产欧美日韩另类一区

      RELATEED CONSULTING
      相關(guān)咨詢
      選擇下列產(chǎn)品馬上在線溝通
      服務(wù)時間:8:30-17:00
      你可能遇到了下面的問題
      關(guān)閉右側(cè)工具欄

      新聞中心

      這里有您想知道的互聯(lián)網(wǎng)營銷解決方案
      怎么使用pytorch框架-創(chuàng)新互聯(lián)

      這篇文章主要講解了“怎么使用pytorch框架”,文中的講解內(nèi)容簡單清晰,易于學(xué)習(xí)與理解,下面請大家跟著小編的思路慢慢深入,一起來研究和學(xué)習(xí)“怎么使用pytorch框架”吧!

      我們一直強(qiáng)調(diào)成都網(wǎng)站設(shè)計(jì)、成都網(wǎng)站建設(shè)對于企業(yè)的重要性,如果您也覺得重要,那么就需要我們慎重對待,選擇一個安全靠譜的網(wǎng)站建設(shè)公司,企業(yè)網(wǎng)站我們建議是要么不做,要么就做好,讓網(wǎng)站能真正成為企業(yè)發(fā)展過程中的有力推手。專業(yè)網(wǎng)站制作公司不一定是大公司,創(chuàng)新互聯(lián)作為專業(yè)的網(wǎng)絡(luò)公司選擇我們就是放心。

        中文新聞情感分類 Bert-Pytorch-transformers

        使用pytorch框架以及transformers包,以及Bert的中文預(yù)訓(xùn)練模型

        文件目錄

        data

        Train_DataSet.csv

        Train_DataSet_Label.csv

        main.py

        NewsData.py

        #main.py

        from transformers import BertTokenizer

        from transformers import BertForSequenceClassification

        from transformers import BertConfig

        from transformers import BertPreTrainedModel

        import torch

        import torch.nn as nn

        from transformers import BertModel

        import time

        import argparse

        from NewsData import NewsData

        import os

        def get_train_args():

        parser=argparse.ArgumentParser()

        parser.add_argument('--batch_size',type=int,default=10,help = '每批數(shù)據(jù)的數(shù)量')

        parser.add_argument('--nepoch',type=int,default=3,help = '訓(xùn)練的輪次')

        parser.add_argument('--lr',type=float,default=0.001,help = '學(xué)習(xí)率')

        parser.add_argument('--gpu',type=bool,default=True,help = '是否使用gpu')

        parser.add_argument('--num_workers',type=int,default=2,help='dataloader使用的線程數(shù)量')

        parser.add_argument('--num_labels',type=int,default=3,help='分類類數(shù)')

        parser.add_argument('--data_path',type=str,default='./data',help='數(shù)據(jù)路徑')

        opt=parser.parse_args()

        print(opt)

        return opt

        def get_model(opt):

        #類方法.from_pretrained()獲取預(yù)訓(xùn)練模型,num_labels是分類的類數(shù)

        model = BertForSequenceClassification.from_pretrained('bert-base-chinese',num_labels=opt.num_labels)

        return model

        def get_data(opt):

        #NewsData繼承于pytorch的Dataset類

        trainset = NewsData(opt.data_path,is_train = 1)

        trainloader=torch.utils.data.DataLoader(trainset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers)

        testset = NewsData(opt.data_path,is_train = 0)

        testloader=torch.utils.data.DataLoader(testset,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers)

        return trainloader,testloader

        def train(epoch,model,trainloader,testloader,optimizer,opt):

        print('\ntrain-Epoch: %d' % (epoch+1))

        model.train()

        start_time = time.time()

        print_step = int(len(trainloader)/10)

        for batch_idx,(sue,label,posi) in enumerate(trainloader):

        if opt.gpu:

        sue = sue.cuda()

        posi = posi.cuda()

        label = label.unsqueeze(1).cuda()

        optimizer.zero_grad()

        #輸入?yún)?shù)為詞列表、位置列表、標(biāo)簽

        outputs = model(sue, position_ids=posi,labels = label)

        loss, logits = outputs[0],outputs[1]

        loss.backward()

        optimizer.step()

        if batch_idx % print_step == 0:

        print("Epoch:%d [%d|%d] loss:%f" %(epoch+1,batch_idx,len(trainloader),loss.mean()))

        print("time:%.3f" % (time.time() - start_time))

        def test(epoch,model,trainloader,testloader,opt):

        print('\ntest-Epoch: %d' % (epoch+1))

        model.eval()

        total=0

        correct=0

        with torch.no_grad():

        for batch_idx,(sue,label,posi) in enumerate(testloader):

        if opt.gpu:

        sue = sue.cuda()

        posi = posi.cuda()

        labels = label.unsqueeze(1).cuda()

        label = label.cuda()

        else:

        labels = label.unsqueeze(1)

        outputs = model(sue, labels=labels)

        loss, logits = outputs[:2]

        _,predicted=torch.max(logits.data,1)

        total+=sue.size(0)

        correct+=predicted.data.eq(label.data).cpu().sum()

        s = ("Acc:%.3f" %((1.0*correct.numpy())/total))

        print(s)

        if __name__=='__main__':

        opt = get_train_args()

        model = get_model(opt)

        trainloader,testloader = get_data(opt)

        if opt.gpu:

        model.cuda()

        optimizer=torch.optim.SGD(model.parameters(),lr=opt.lr,momentum=0.9)

        if not os.path.exists('./model.pth'):

        for epoch in range(opt.nepoch):

        train(epoch,model,trainloader,testloader,optimizer,opt)

        test(epoch,model,trainloader,testloader,opt)

        torch.save(model.state_dict(),'./model.pth')

        else:鄭州治療婦科哪個醫(yī)院好 http://www.120kdfk.com/

        model.load_state_dict(torch.load('model.pth'))

        print('模型存在,直接test')

        test(0,model,trainloader,testloader,opt)

        #NewsData.py

        from transformers import BertTokenizer

        from transformers import BertForSequenceClassification

        from transformers import BertConfig

        from transformers import BertPreTrainedModel

        import torch

        import torch.nn as nn

        from transformers import BertModel

        import time

        import argparse

        class NewsData(torch.utils.data.Dataset):

        def __init__(self,root,is_train = 1):

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

        self.data_num = 7346

        self.x_list = []

        self.y_list = []

        self.posi = []

        with open(root + '/Train_DataSet.csv',encoding='UTF-8') as f:

        for i in range(self.data_num+1):

        line = f.readline()[:-1] + '這是一個中性的數(shù)據(jù)'

        data_one_str = line.split(',')[len(line.split(','))-2]

        data_two_str = line.split(',')[len(line.split(','))-1]

        if len(data_one_str) < 6:

        z = len(data_one_str)

        data_one_str = data_one_str + ',' + data_two_str[0:min(200,len(data_two_str))]

        else:

        data_one_str = data_one_str

        if i==0:

        continue

        word_l = self.tokenizer.encode(data_one_str, add_special_tokens=False)

        if len(word_l)<100:

        while(len(word_l)!=100):

        word_l.append(0)

        else:

        word_l = word_l[0:100]

        word_l.append(102)

        l = word_l

        word_l = [101]

        word_l.extend(l)

        self.x_list.append(torch.tensor(word_l))

        self.posi.append(torch.tensor([i for i in range(102)]))

        with open(root + '/Train_DataSet_Label.csv',encoding='UTF-8') as f:

        for i in range(self.data_num+1):

        #print(i)

        label_one = f.readline()[-2]

        if i==0:

        continue

        label_one = int(label_one)

        self.y_list.append(torch.tensor(label_one))

        #訓(xùn)練集或者是測試集

        if is_train == 1:

        self.x_list = self.x_list[0:6000]

        self.y_list = self.y_list[0:6000]

        self.posi = self.posi[0:6000]

        else:

        self.x_list = self.x_list[6000:]

        self.y_list = self.y_list[6000:]

        self.posi = self.posi[6000:]

        self.len = len(self.x_list)

        def __getitem__(self, index):

        return self.x_list[index], self.y_list[index],self.posi[index]

        def __len__(self):

        return self.len

      感謝各位的閱讀,以上就是“怎么使用pytorch框架”的內(nèi)容了,經(jīng)過本文的學(xué)習(xí)后,相信大家對怎么使用pytorch框架這一問題有了更深刻的體會,具體使用情況還需要大家實(shí)踐驗(yàn)證。這里是創(chuàng)新互聯(lián),小編將為大家推送更多相關(guān)知識點(diǎn)的文章,歡迎關(guān)注!


      文章標(biāo)題:怎么使用pytorch框架-創(chuàng)新互聯(lián)
      地址分享:http://www.ef60e0e.cn/article/djisdp.html
      99热在线精品一区二区三区_国产伦精品一区二区三区女破破_亚洲一区二区三区无码_精品国产欧美日韩另类一区
      1. <ul id="0c1fb"></ul>

        <noscript id="0c1fb"><video id="0c1fb"></video></noscript>
        <noscript id="0c1fb"><listing id="0c1fb"><thead id="0c1fb"></thead></listing></noscript>

        龙口市| 洛阳市| 湾仔区| 建瓯市| 恩施市| 商水县| 凌源市| 那曲县| 洞头县| 泸溪县| 肥城市| 宁明县| 南涧| 塔河县| 巩义市| 南康市| 西峡县| 漳平市| 南京市| 绍兴县| 墨玉县| 灯塔市| 八宿县| 托里县| 蕲春县| 嘉定区| 苗栗市| 黎川县| 集安市| 信丰县| 于都县| 清水县| 漳浦县| 梁平县| 乳山市| 岚皋县| 砀山县| 泸州市| 根河市| 淮安市| 宽甸|