我在学习 TensorFlow 官网上的教程 https://www.tensorflow.org/tutorials/estimator/linear 。
这是一段训练线性模型（教程中使用的是 LinearClassifier）的代码，但我没能运行成功，而且看不懂报出的错误。
运行时报错：TypeError: unsupported callable（不受支持的可调用对象）。
这是代码：
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from six.moves import urllib
import tensorflow.compat.v2.feature_column as fc
import tensorflow as tf
# Load the Titanic train/eval splits used by the TensorFlow linear-estimator tutorial.
dftrain, dfeval = (
    pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv'),
    pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv'),
)
# Detach the 'survived' label column; what remains in each frame are the features.
y_train, y_eval = dftrain.pop('survived'), dfeval.pop('survived')

CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck', 'embark_town', 'alone']
NUMERIC_COLUMNS = ['age', 'fare']

feature_columns = []
# Categorical features: the vocabulary is the set of distinct values seen in the training split.
for feature_name in CATEGORICAL_COLUMNS:
    vocabulary = dftrain[feature_name].unique()
    feature_columns.append(
        tf.feature_column.categorical_column_with_vocabulary_list(
            feature_name, vocabulary
        )
    )
# Numeric features are passed through as float32 columns.
for feature_name in NUMERIC_COLUMNS:
    feature_columns.append(
        tf.feature_column.numeric_column(feature_name, dtype=tf.float32)
    )
def make_input_fn(data_df, label_df, num_epochs=20, shuffle=True, batch_size=32,
                  shuffle_buffer=1000):
    """Build an Estimator ``input_fn`` from a feature frame and a label series.

    Args:
        data_df: Feature data; converted with ``dict(data_df)`` so each column
            becomes one named feature.
        label_df: Labels aligned row-for-row with ``data_df``.
        num_epochs: How many times the batched dataset is repeated.
        shuffle: Whether to shuffle rows before batching.
        batch_size: Number of rows per batch.
        shuffle_buffer: Buffer size passed to ``Dataset.shuffle`` (default 1000,
            the value previously hard-coded).

    Returns:
        The zero-argument function ``input_function`` itself -- NOT the Dataset
        it builds. ``Estimator.train``/``evaluate`` inspect the signature of the
        ``input_fn`` they receive and call it themselves; handing them a Dataset
        (as ``return input_function()`` did) raises
        ``TypeError: unsupported callable``.
    """
    def input_function():
        # Pair a {column_name: values} dict with the labels, one element per row.
        ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))
        if shuffle:
            ds = ds.shuffle(shuffle_buffer)
        ds = ds.batch(batch_size).repeat(num_epochs)
        return ds

    # Return the function object; the Estimator invokes it when it needs data.
    return input_function
# Input functions: training shuffles for the default number of epochs;
# evaluation makes a single ordered pass.
train_input_fn = make_input_fn(data_df=dftrain, label_df=y_train)
eval_input_fn = make_input_fn(data_df=dfeval, label_df=y_eval, num_epochs=1, shuffle=False)

# Linear classifier over the feature columns built above.
linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns)

# train/evaluate take a callable input_fn (not a Dataset object).
linear_est.train(input_fn=train_input_fn)
result = linear_est.evaluate(input_fn=eval_input_fn)

clear_output()
print(result)
以下是我得到的完整错误信息：
Traceback (most recent call last):
File "C:/Users/gotru/PycharmProjects/tensor/Linear_Regression.py", line 46, in <module>
linear_est.train(train_input_fn)
File "C:\Users\gotru\anaconda3\envs\tensorflow\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 374, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "C:\Users\gotru\anaconda3\envs\tensorflow\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1164, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "C:\Users\gotru\anaconda3\envs\tensorflow\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1191, in _train_model_default
input_fn, ModeKeys.TRAIN))
File "C:\Users\gotru\anaconda3\envs\tensorflow\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1028, in _get_features_and_labels_from_input_fn
self._call_input_fn(input_fn, mode))
File "C:\Users\gotru\anaconda3\envs\tensorflow\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1106, in _call_input_fn
input_fn_args = function_utils.fn_args(input_fn)
File "C:\Users\gotru\anaconda3\envs\tensorflow\lib\site-packages\tensorflow_core\python\util\function_utils.py", line 57, in fn_args
args = tf_inspect.getfullargspec(fn).args
File "C:\Users\gotru\anaconda3\envs\tensorflow\lib\site-packages\tensorflow_core\python\util\tf_inspect.py", line 257, in getfullargspec
return _getfullargspec(target)
File "C:\Users\gotru\anaconda3\envs\tensorflow\lib\inspect.py", line 1132, in getfullargspec
raise TypeError('unsupported callable') from ex
TypeError: unsupported callable
提前谢谢
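问题出在 make_input_fn 的最后一行：return input_function() 返回的是 Dataset 对象，而 Estimator 需要的是函数本身；改为 return input_function 后，下面的代码即可正常运行，没有任何错误。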
输出
相关问题 更多 >
编程相关推荐