(1)学习FastText的原理和使用,通过10折交叉验证划分数据集。
(2)注意fasttext.train_supervised
这里predict
后的返回值结果,因为要概率值最大的那个label,所以包括在后面的栗子我们会发现有一坨model.predict(x)[0][0].split('__')[-1]
,千万不要慌,就是去第一个label然后因为加上了下划线嘛,所以去掉下划线分割出的几坨东西,我们取最后一坨就是想要的label,ex:上面的__label__baking
就处理得到baking
啦。
One-hot、Bag of Words、N-gram、TF-IDF等方法都存在一定问题:转换得到的向量维度很高,需要较长的训练实践;没有考虑单词与单词之间的关系,只是进行了统计。
而深度学习用于文本表示,可以将其映射到一个低维空间,比如FastText
、Word2Vec
和Bert
。
FastText是一个三层神经网络:输入层、隐含层、输出层。通过embedding层将单词映射到稠密空间,然后将句子中所有的单词在embedding空间中进行平均,进而完成分类。
具体的论文:Bag of Tricks for Efficient Text Classification, https://arxiv.org/abs/1607.01759
FastText在文本分类任务上,是优于TF-IDF的:
首先是fasttext
包的下载,如果在anaconda的prompt用命令pip install fasttext
下载不了,可以直接在该网址找到对应自己python解释器版本的whl
文件下载,然后很多博客都说接着在cmd
命令pip install
下载,但是我试了是会报错的,所以试了下回到prompt下载发现就可以了(安装成功显示如下)。
还有一个细节要注意,因为代码开头会import fasttext
,就是说文件名我们是不能命名为fasttext
,否则会冲突,即报错如下:
AttributeError: partially initialized module 'fasttext' has no attribute 'train_supervised' (most likely due to a circular import)
1)多线程训练:fastText在训练的时候是采用的多线程进行训练的。每个训练线程在更新参数时并没有加锁,这会给参数更新带来一些噪音,但是不会影响最终的结果。无论是 google 的 word2vec 实现,还是 fastText 库,都没有加锁。线程的默认是12个,可以手动的进行设置。
2)分层softmax:fastText在计算softmax的时候采用分层softmax,这样可以大大提高运行的效率。
(3)使用了Hierarchical softmax其实就是所谓的霍夫曼树结构:该树每个叶节点都是一个词语,softmax的结果无非就是一个概率,那么我们要找某一个词时,就是计算到该单词的路径中概率积。
首先看后面会用到的fasttext.train_supervised
的参数:
input_file 训练文件路径(必须) output 输出文件路径(必须) label_prefix 标签前缀 default __label__ lr 学习率 default 0.1 lr_update_rate 学习率更新速率 default 100 dim 词向量维度 default 100 ws 上下文窗口大小 default 5 epoch epochs 数量 default 5 min_count 最低词频 default 5 word_ngrams n-gram 设置 default 1 loss 损失函数 {ns,hs,softmax} default softmax minn 最小字符长度 default 0 maxn 最大字符长度 default 0 thread 线程数量 default 12 t 采样阈值 default 0.0001 silent 禁用 c++ 扩展日志输出 default 1 encoding 指定 input_file 编码 default utf-8 pretrained_vectors 指定使用已有的词向量 .vec 文件 default None
在调参之前,对于训练数据的样本有这样的规定。
每条数据+"\t"+“label_prefix”+标签
也就是标签在每条数据之后,并用label_prefix
作为前缀(下面使用label_ft
)。
此处学习了Facebook的fastText官方文档,中文文档比英文文档少了点东西,墙裂建议看这里的英文文档。
中文文档:http://fasttext.apachecn.org/#/doc/zh/supervised-tutorial
(1)n-gram
示例: who am I? n-gram设置为2
n-gram特征有,who, who am, am, am I, I
(2)n-char
示例: where, n=3, 设置起止符<, > n-char特征有,<wh, whe, her, ere, er>
所以对于中文而言,输出的词可能没必要再细分,故n-char可为0,但是在词与词语义中会有一定联系,根据情况来设定n-gram。
官方有个栗子可以学习:
首先导入包和数据进行模型训练:
这里如果处理多分类的思想是给每个label设置一个二分类分类器,通过在fasttext.train_supervised
的-loss one-vs-all
或者-loss ova
参数设定。
>>> import fasttext >>> model = fasttext.train_supervised(input="cooking.train", lr=0.5, epoch=25, wordNgrams=2, bucket=200000, dim=50, loss='ova') Read 0M words Number of words: 14543 Number of labels: 735 Progress: 100.0% words/sec/thread: 72104 lr: 0.000000 loss: 4.340807 ETA: 0h 0m
想要尽可能多的预测则将k
参数设置为-1,然后为了只要概率大于0.5概率的label,可以设置threshold
为0.5:
>>> model.predict("Which baking dish is best to bake a banana bread ?", k=-1, threshold=0.5) ((u''__label__baking, u'__label__bananas', u'__label__bread'), array([1.00000, 0.939923, 0.592677])) >>> model.test("cooking.valid", k=-1) (3000L, 0.702, 0.2)
注意这里predict
后的返回值结果,因为要概率值最大的那个label,所以包括在后面的栗子我们会发现有一坨model.predict(x)[0][0].split('__')[-1]
,千万不要慌,就是去第一个label然后因为加上了下划线嘛,所以去掉下划线分割出的几坨东西,我们取最后一坨就是想要的label,ex:上面的__label__baking
就处理得到baking
啦。
详细看注释。
# -*- coding: utf-8 -*- """ Created on Fri Nov 5 09:04:43 2021 @author: 86493 """ import pandas as pd from sklearn.metrics import f1_score import fasttext # 转换为FastText需要的格式 train_df = pd.read_csv('train_set.csv', sep = '\t', nrows =15000) # astype是强制类型转换 train_df['label_ft'] = '__label__' + train_df['label'].astype(str) train_df[['text', 'label_ft']].iloc[: -5000].to_csv('train.csv', index = None, header = None, sep = '\t') # word_ngrams=2和上个task的ngram_range不是一个意思 model = fasttext.train_supervised('train.csv', lr = 1.0, # 学习率 wordNgrams = 2, # 字母组合长度 verbose = 2, # minCount = 1, # 过滤小于minCount的单词 epoch = 25, # 迭代次数 loss = 'hs') # 默认的loss是负采样 # 最后5k条测试样本作为验证集,对预测结果做解析,列表生成式 val_pred = [model.predict(x)[0][0].split('__')[-1] for x in train_df.iloc[-5000:]['text']] print(f1_score(train_df['label'].values[-5000:].astype(str), # [10000:]也行 val_pred, average = 'macro'))
运行的过程,最后的F1值为0.8229238895393863:
Read 9M words Number of words: 5341 Number of labels: 14 Read 9M words Number of words: 5341 Number of labels: 14 Progress: 0.1% words/sec/thread: 29827 lr: 0.999370 avg.loss: 2.461391 ETA: 0h18m 9s Read 9M words Number of words: 5341 Number of labels: 14 Progress: 0.2% words/sec/thread: 34835 lr: 0.998136 avg.loss: 2.420802 ETA: 0h15m34s Read 9M words Number of words: 5341 Number of labels: 14 Progress: 0.3% words/sec/thread: 32120 lr: 0.997353 avg.loss: 2.450788 ETA: 0h16m49s Read 9M words Number of words: 5341 Number of labels: 14 Progress: 0.6% words/sec/thread: 56698 lr: 0.993602 avg.loss: 2.382696 ETA: 0h 9m30s Read 9M words Number of words: 5341 Number of labels: 14 Progress: 1.1% words/sec/thread: 78496 lr: 0.988754 avg.loss: 2.180967 ETA: 0h 6m49s Read 9M words Number of words: 5341 Number of labels: 14 Progress: 1.4% words/sec/thread: 79590 lr: 0.986145 avg.loss: 2.087085 ETA: 0h 6m43s Read 9M words Number of words: 5341 Number of labels: 14 Progress: 2.7% words/sec/thread: 130232 lr: 0.973349 avg.loss: 1.691365 ETA: 0h 4m 3s Read 9M words Number of words: 5341 Number of labels: 14 。。。。。。。 Progress: 97.4% words/sec/thread: 601597 lr: 0.025779 avg.loss: 0.150194 ETA: 0h 0m 1s Read 9M words Number of words: 5341 Number of labels: 14 Progress: 99.8% words/sec/thread: 603704 lr: 0.004354 avg.loss: 0.147263 ETA: 0h 0m 0s Read 9M words Number of words: 5341 Number of labels: 14 Progress: 100.0% words/sec/thread: 603532 lr: 0.000000 avg.loss: 0.146719 ETA: 0h 0m 0s 0.8229238895393863
调参:
(1)通过阅读文档,要弄清楚这些参数的大致含义,那些参数会增加模型的复杂度。
(2)通过在验证集上进行验证模型精度,找到模型在是否过拟合还是欠拟合。
10折交叉验证,每折使用9/10的数据进行训练,剩余1/10作为验证集检验模型的效果。这里需要注意每折的划分必须保证标签的分布与整个数据集的分布一致。
label2id = {} for i in range(total): label = str(all_labels[i]) if label not in label2id: label2id[label] = [i] else: label2id[label].append(i)
通过10折划分,我们选择最后一份完成剩余的实验,即索引为9的一份做为验证集,索引为1-8的作为训练集,然后基于验证集的结果调整超参数,使得模型性能更优。
(1)阿里天池平台
(2)fastText训练和使用
(3)fasttext中文文档:http://fasttext.apachecn.org/#/
(4)https://github.com/apachecn/fasttext-doc-zh/
(5)fastText官网:https://fasttext.cc/