基于pytorch的通用多特征分类库

nymph的Python项目详细描述


nymph

基于Pytorch的多特征分类框架

概述

基于Pytorch的多特征序列标注和普通分类框架,包装的还算可以。可以直接照搬demo,拿csv文件去训练预测。

功能

  • 多特征分类(特征包括字符型、数值型,其中字符型最好是单个词而非词组或句子)
  • 输出详细分类详情

原理

  • 预处理:对各列非数值类数据分别构建词表并使用Embedding获得低维稠密向量,对数值类数据进行标准化,然后拼接获得各行对应向量
  • 模型:
    • 普通分类:全连接神经网络,NormClassifier(具体效果看特征)
    • 序列标注:Bi-LSTM-CRF,SeqClassifier(效果较好)
  • 预测:使用sklearn获取f1分数,并且获得各类别分类详情

安装

使用如下命令进行安装

pip install -U nymph

使用示例

训练数据

数据可见test.csv

如图:

test_data

普通分类

训练模型

源码如下,具体可参见train_demo_by_norm.py

# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportNormDataset,split_datasetfromnymph.modulesimportNormClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves'if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=NormClassifier()classifier.init_data_processor(data,target_name='label')# 构建数据集norm_ds=NormDataset(data)train_ratio=0.7dev_ratio=0.2test_ratio=0.1train_ds,dev_ds,test_ds=split_dataset(norm_ds,(train_ratio,dev_ratio,test_ratio))# 训练模型# classifier.train(train_set=train_ds, dev_set=dev_ds, save_path=save_path)classifier.train(train_set=norm_ds,dev_set=norm_ds,save_path=save_path)# 测试模型test_score=classifier.score(norm_ds)print('test_score',test_score)# 预测模型pred=classifier.predict(norm_ds)print(pred)
训练结果

终端输出

train_demo_by_norm_result

预测模型

源码如下,具体可参见predict_demo_by_norm.py

# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportNormDataset,split_datasetfromnymph.modulesimportNormClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves'if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=NormClassifier()# 加载分类器classifier.load(save_path)# 构建数据集norm_ds=NormDataset(data)# 预测模型pred=classifier.predict(norm_ds)print(pred)# 获取各类别分类结果,并保存信息至文件中classifier.report(norm_ds,'report.csv')# 对数据进行预测,并将数据和预测结果写入到新的文件中classifier.summary(norm_ds,'summary.csv')
预测结果

如图:predict_demo_by_norm_result

report.csv内容

report

summary.csv内容

summary

序列标注

训练模型

源码如下,具体可参见train_demo_by_seq.py

# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportSeqDataset,split_datasetfromnymph.modulesimportSeqClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves_seq'defsplit_fn(dataset:list):returnlist(range(len(dataset)+1))if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=SeqClassifier()classifier.init_data_processor(data,target_name='label')# 构建数据集norm_ds=SeqDataset(data,split_fn=split_fn,min_len=4)train_ratio=0.7dev_ratio=0.2test_ratio=0.1train_ds,dev_ds,test_ds=split_dataset(norm_ds,(train_ratio,dev_ratio,test_ratio))# 训练模型# classifier.train(train_set=train_ds, dev_set=dev_ds, save_path=save_path)classifier.train(train_set=norm_ds,dev_set=norm_ds,save_path=save_path)# 测试模型test_score=classifier.score(norm_ds)print('test_score',test_score)# 预测模型pred=classifier.predict(norm_ds)print(pred)
训练结果

终端输出

train_demo_by_seq_result

预测模型

源码如下,具体可参见predict_demo_by_seq.py

# -*- coding: utf-8 -*-importosimportpandasaspdfromnymph.dataimportSeqDataset,split_datasetfromnymph.modulesimportSeqClassifierproject_path=os.path.abspath(os.path.join(__file__,'../../'))data_path=os.path.join(project_path,r'data\test.csv')save_path='demo_saves_seq'defsplit_fn(dataset:list):returnlist(range(len(dataset)+1))if__name__=='__main__':# 读取数据data=pd.read_csv(data_path)# 构建分类器classifier=SeqClassifier()# 加载分类器classifier.load(save_path)# 构建数据集seq_ds=SeqDataset(data,split_fn=split_fn,min_len=4)# 预测模型pred=classifier.predict(seq_ds)print(pred)# 获取各类别分类结果,并保存信息至文件中classifier.report(seq_ds,'seq_demo_report.csv')# 对数据进行预测,并将数据和预测结果写入到新的文件中classifier.summary(seq_ds,'seq_demo_summary.csv')

如图:predict_demo_by_seq_result

seq_demo_report.csv内容

seq_demo_report

seq_demo_summary.csv内容

seq_demo_summary

更新历史

  • 0.1.0: 初始化项目,增加全连接模型
  • 0.2.0: 增加序列标注模型,大幅重构项目结构与内部实现代码
  • 0.2.1: 更新代码,使GPU和CPU下同时可用
  • 0.2.2: 增加将训练过程的loss和f1值写入到TensorBoard中
  • 0.2.3: 增加Norm Classifier的TensorBoard功能

参考

  1. python - Sorting list based on values from another list? - Stack Overflow

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
java为什么即使我已经给出了代码中的所有权限,该代码也没有在emulator中运行?   java Android Studio正在抛出“线程中的异常”main“javax.net.ssl.SSLException:收到致命警报:协议\版本”   java中的for循环嵌套foreach语句   java读取/src/main/resources和/webinf/classes下的文件   java无法以此格式构造JSON响应   身份验证尝试从CAS secure rest api获取响应,但从java客户端获取登录页面作为响应   如何在java中使用excel从第1列和第3列获取单元格值,并将其作为键值对放入map中   在Java程序中打开Windows虚拟键盘   java有没有递归调用findMatch的方法?   java Pig脚本/命令,用于根据多个字符串筛选文件   java最小数量应匹配,应与POST匹配   java打开/关闭声音按钮不工作   Java嵌入式数据库持久性   java在方法调用时引发异常   java文本文件被覆盖而不是保存的问题   java Hibernate sql注释