文本分类¶
此教程主要针对初级用户,内容主要包括利用MindText的classification进行网络训练。
目录结构说明¶
mindtext/classification/
├── config # 模型、数据集等配置文件目录
│ └── fasttext
│ └── fasttext.yaml # 模型配置文件
├── dataset # 数据集的读取和处理
│ ├── dataset.py # 生成训练、验证和测试数据
│ ├── FastTextDataPreProcess.py # 数据预处理
│ ├── load_data.py # 加载数据
│ └── mindrecord.py # 中间数据文件处理
├── docs
│ └── getting_started.py # 用户使用指南
├── model
│ ├── backbones
│ | └── fasttext.py # 模型的骨干架构
| ├── classifiers
│ | └── base.py # 模型分类器
| ├── build_model.py # 创建模型
│ ├── loss.py # 创建loss
│ └── optimizer.py # 创建优化器
├── test # 测试模型
├── tools
│ ├── eval.py # 评估模型
│ ├── export.py # 导出模型的checkpoint文件
│ ├── infer.py # 模型推理
│ └── train.py # 训练模型
└── utils
├── config.py # 处理yaml文件
└── lr_schedule.py # 学习率设置(可选)
环境安装与配置¶
下载MindText并进入文件夹
git clone https://gitee.com/mindspore/mindtext.git
cd mindtext
安装
python setup.py install
所需环境:
pip install pandas==1.2.4
pip install numpy==1.20.3
pip install mindspore==1.2.0
pip install tqdm==4.61.1
pip install PyYAML==5.4.1
pip install scikit_learn==0.24.1
pip install spacy==2.3.1
python -m spacy download en_core_web_lg==2.3.1
数据准备¶
下载并解压数据集.
你可以从数据集下载页面下载,并按下方目录结构放置:
/root/fasttext/ag_news_csv
├── ag_news_csv
│ ├── train.csv
│ └── text.csv
├── dbpedia_csv
│ ├── train.csv
│ └── text.csv
├── yelp_review_polarity_csv
│ ├── train.csv
│ └── text.csv
自定义配置文件¶
进入./config/fasttext目录,打开fasttext.yaml文件
fasttext.yaml文件中有多个参数配置, 案例如下:
# Builtin Configurations
model_name: "fasttext" # 模型名称
device_target: "GPU" # 设备,可选GPU, ASCEND
PREPROCESS: # 数据预处理参数
max_len: 467 # 数据集最长文本长度
mid_dir_path: "./ag_temp_data" # mindrecord生成路径
vocab_file_path: "your_path/vocab.txt" # 生成/读取词表路径
MODEL_PARAMETERS: # 模型参数
vocab_size: 1383812 # 词表大小
embedding_dims: 16 # 词嵌入大小
num_class: 4 # 类别数
OPTIMIZER: # 优化器参数
function: "Adam" # 优化器类型,以Adam优化器为例
lr: 0.20 # 学习率
min_lr: 0.000001 # 最小学习率
decay_steps: 236 # 学习率衰减补偿
warmup_steps: 400000 # warm_up步长
poly_lr_scheduler_power: 0.001 # 学习率策略
TRAIN: # 训练参数
data_path: "your_path/train.csv" # 训练集路径
batch_size: 512 # batch_size
buckets: [64, 128, 467] # 训练集数据加载块大小
epoch: 5 # 训练epoch数
epoch_count: 1
loss_function: "SoftmaxCrossEntropyWithLogits" # 损失函数类型
pretrain_ckpt_dir: "" # 断点训练检查点
save_ckpt_steps: 116 # 检查点保存步长
save_ckpt_dir: "your_path" # 检查点保存路径
keep_ckpt_max: 10 # 最大检查点数
run_distribute: False # 分布式训练,默认False
distribute_batch_size_gpu: 64 # 分布式训练单卡batch_size
VALID: # 测试参数
data_path: "your_path/test.csv" # 测试集路径
batch_size: 512 # batch_size
model_ckpt: "your_path/fasttext-*_***.ckpt" # 模型检查点
test_buckets: [467] # 测试集数据加载块大小
INFER: # 推断参数
data_path: "your_path/test.csv" # 推断数据路径
batch_size: 2048 # batch_size
model_ckpt: "your_path/fasttext-*_***.ckpt" # 模型检查点
buckets: [467] # 推断数据加载块大小
EXPORT: # 模型导出参数
device_id: 0 # 设备id
ckpt_file: "your_path/fasttext-*_***.ckpt" # 检查点路径
file_name: "fasttexts" # 文件名称
file_format: "AIR" # 文件类型,可选AIR, ONNX, MINDIR
模型训练¶
进入mindtext/classification/tools目录。
cd mindtext/classification/tools
执行下面的命令开始模型训练:
python train.py -c ../configs/fasttext/fasttext.yaml
模型评估¶
进入mindtext/classification/tools目录。
执行下面的命令开始模型评估:
python eval.py -c ../configs/fasttext/fasttext.yaml
模型导出¶
进入mindtext/classification/tools目录。
执行下面的命令开始模型导出:
python export.py -c ../configs/fasttext/fasttext.yaml
模型预测¶
进入mindtext/classification/tools目录。
执行下面的命令开始模型预测:
python infer.py -c ../configs/fasttext/fasttext.yaml
-c参数是指定训练的配置文件路径,训练的具体超参数可查看yaml文件