图形处理器
FlyAI-GPU的Python项目详细描述
Flyai
整体运行流程
下载读取csv数据-->处理csv中的数据,转成机器可识别的矩阵-->分批返回数据-->编写算法训练模型-->验证模型、保存模型、使用模型
app.yaml
是项目的配置文件,项目目录下必须存在这个文件
processor.py
处理数据的输入输出文件,把通过csv文件返回的数据,处理成能让程序识别、训练的矩阵。
可以自己定义输入输出的方法名,在
app.yaml
中声明即可。默认为:
definput_x(self,**datas):''' 参数为csv中作为输入x的一条数据,该方法会被Dataset多次调用 :param datas: 输入的数据列表 :return: 返回矩阵 '''passdefinput_y(self,**datas):''' 参数为csv中作为输入y的一条数据,该方法会被Dataset多次调用 :param datas: 数据标签列表 :return: 返回矩阵 '''passdefoutput_y(self,data):''' 验证时使用,把模型输出的y转为对应的结果 :param data: 预测返回的数据 :return: 返回预测的标签 '''pass
dataset.py
该文件在
flyai.dataset
包中,通过next_batch()
方法获得x_train
y_train
x_test
y_test
数据
main.py
程序入口,编写算法,训练模型的文件。在该文件中实现自己的算法,然后通过
dataset.py
中的next_batch
方法获取训练和测试数据。
model.py
训练好模型之后可以继承
flyai.model.base
包中的base
重写下面三个方法实现模型的保存、验证和使用。defpredict(self,path,name,**data):''' 使用模型 :param path: 模型所在的路径 :param name: 模型的名字 :param data: 模型的输入参数 :return: '''passdefevaluate(self,path,name):''' 验证模型 :param path: 模型的路径 :param name: 模型的名字 :return: 返回验证的准确率 '''passdefsave_model(self,session,path,name,overwrite=False):''' 保存模型 :param session: 训练模型的sessopm :param path: 要保存模型的路径 :param name: 要保存模型的名字 :param overwrite: 是否覆盖当前模型 :return: '''self.check(path,overwrite)
path.py
可以设置数据文件、模型文件的存放路径。
tensorflow_accuracy.py
该文件为验证文件,成功训练模型之后。会调用模型,并给该模型打分。根据不同类别的模型,需要实现不同的验证。