本项目使用google-bert/bert-base-chinese
模型进行中文文本分类任务,使用中文数据集进行训练,训练完成后,可以导出模型,进行预测。
数据集下载地址
相关问题分析
- 创建新的虚拟环境
conda create -n bert_env python==3.8
- 激活环境
conda activate bert_env
- 安装依赖包,临时使用镜像源
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
环境搭建完成,激活依赖下载完成后可以按需求分别执行predict_single.py
和predict_batch.py
文件。
在THUCNews/data/test.csv数据集上进行测试的结果如下:
混淆矩阵的详细参考这里
分类指标如下:
指标 | 值 |
---|---|
Accuracy | 0.9434 |
Precision | 0.9438 |
Recall | 0.9434 |
项目采用Apache License 2.0许可。
bert-base-chinese
需要自行下载,下载方式参考classification-base-bert/bert-base-chinese/README.md
classification-base-bert/model_config.py
的model_path
是预训练后的模型文件,推理之前需要先执行model_train.py
文件进行训练。