利用tensorflow的估计器api快速创建机器学习模型的框架。

estimator的Python项目详细描述


利用tensorflow的估计器api快速创建机器学习模型的框架。

安装

Install TensorFlow

pip install tensorflow

然后运行:

pip install estimator

建议使用virtual environment

开始

fromestimatorimportModelimporttensorflowastf# Define the network architecture - layers, number of units, activations etc.defnetwork(inputs):hidden=tf.layers.Dense(units=64,activation=tf.nn.relu)(inputs)outputs=tf.layers.Dense(units=10)(hidden)returnoutputs# Configure the learning process - loss, optimizer, evaluation metrics etc.model=Model(network,loss='sparse_softmax_cross_entropy',optimizer=('GradientDescent',0.001),metrics=['accuracy'])# Train the model using training datamodel.train(x_train,y_train,epochs=30,batch_size=128)# Evaluate the model performance on test or validation dataloss_and_metrics=model.evaluate(x_test,y_test)# Use the model to make predictions for new datapredictions=model.predict(x)# or call the model directlypredictions=model(x)

提供更多配置选项:

model=Model(network,loss='sparse_softmax_cross_entropy',optimizer=optimizer('GradientDescent',0.001),metrics=['accuracy'],model_dir='/tmp/my_model')

您还可以使用自定义的损失和度量函数:

defcustom_loss(labels,outputs):passdefcustom_metric(labels,outputs):passmodel=Model(network,loss=custom_loss,optimizer=('GradientDescent',0.001),metrics=['accuracy',custom_metric])

示例:cnn mnist分类器

本例基于tensorflow的MNIST example

fromestimatorimportModel,GradientDescent,TRAINimporttensorflowastfdefnetwork(x,mode):x=tf.reshape(x,[-1,28,28,1])x=tf.layers.Conv2D(filters=32,kernel_size=[5,5],padding='same',activation=tf.nn.relu)(x)x=tf.layers.MaxPooling2D(pool_size=[2,2],strides=2)(x)x=tf.layers.Conv2D(filters=64,kernel_size=[5,5],padding='same',activation=tf.nn.relu)(x)x=tf.layers.MaxPooling2D(pool_size=[2,2],strides=2)(x)x=tf.layers.Flatten()(x)x=tf.layers.Dense(units=1024,activation=tf.nn.relu)(x)x=tf.layers.Dropout(rate=0.4)(x,training=mode==TRAIN)x=tf.layers.Dense(units=10)(x)returnx# Configure the learning processmodel=Model(network,loss='sparse_softmax_cross_entropy',optimizer=('GradientDescent',0.001))

mode参数指定模型是用于训练、评估还是预测。

模型函数

为了获得更多的控制,可以使用Estimator类在函数内部配置模型:

fromestimatorimportEstimator,PREDICTimporttensorflowastfdefmodel(features,labels,mode):# Define the network architecturehidden=tf.layers.Dense(units=64,activation=tf.nn.relu)(features)outputs=tf.layers.Dense(units=10)(hidden)predictions=tf.argmax(outputs,axis=1)# In prediction mode, simply return predictions without configuring learning processifmode==PREDICT:returnpredictions# Configure the learning process for training and evaluation modesloss=tf.losses.sparse_softmax_cross_entropy(labels,outputs)optimizer=tf.train.GradientDescentOptimizer(0.001)accuracy=tf.metrics.accuracy(labels,predictions)returndict(loss=loss,optimizer=optimizer,metrics={'accuracy':accuracy})# Create the model using model functionmodel=Estimator(model)# Train the modelmodel.train(x_train,y_train,epochs=30,batch_size=128)

许可证

MIT

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
标头中的java cachecontrol未反映在jetty服务器上   java根据XSLT版本选择XSLT处理器   使用Lombok项目的java不明确方法调用   java powershell为每个文件构建一个要执行的字符串   java如何在Vaadin组合框中添加搜索图标?   从输入流读取有限长度的java最佳实践   动态操作后GridLayoutManager中的java项高度   java理解ThreadPoolExecutor中的池大小   在java中保持地址空间不变   java无法理解为什么我有空对象引用   java优化项目Euler#22   java会因为多线程而覆盖代码中的DataSnapshot吗   java文件夹层次结构遍历   java在循环中动态创建方法   音频Java无法组合2个以上。wav文件   java在具有UTF8样式名称的文件夹/目录中运行可运行的JAR文件   java如何在具有动态根键时反序列化JSON