TensorFlow LinearClassifier断言失败:[标签必须<=n\u类1]

2024-06-28 18:52:37 发布

您现在位置:Python中文网/ 问答频道 /正文

我收到一个错误,指出我拥有的标签数量大于我正在使用的tf.estimator.LinearClassifier的n\u classes-1

我假设这与用于训练和测试的输入有关,它决定了特性和标签。我已经为此测试了不同的配置,但找不到正确的答案。我使用的数据是一个包含4个int值的CSV,最后一个是标签。我在python3.6的Windows上运行

def my_input_fn(data_file, num_epochs, batch_size):
    dataset = tf.data.experimental.make_csv_dataset(
        data_file,
        batch_size=batch_size,
        column_names=_CSV_COLUMNS, # ['int1', 'int2', 'int3', 'int4'] 
        label_name='int4',
        na_value="?",
        num_epochs=num_epochs,
        ignore_errors=True)
    return dataset

train_inpf = functools.partial(my_input_fn, train_file, num_epochs=2, shuffle=True, batch_size=32)
test_inpf = functools.partial(my_input_fn, test_file, num_epochs=1, shuffle=False, batch_size=1)

如果有用的话,下面是我如何设置分类器的:作为特性使用的3个int列被指定为分类数据

col1 = tf.feature_column.categorical_column_with_vocabulary_list(
    'int1', column_uniques_lists['int1'], dtype=tf.int64)

col2 = tf.feature_column.categorical_column_with_vocabulary_list(
    'int2', column_uniques_lists['int2'], dtype=tf.int64)


col3 = tf.feature_column.categorical_column_with_vocabulary_list(
    'int3', column_uniques_lists['int3'], dtype=tf.int64)

my_categorical_columns = [col1,col2,col3]

classifier = tf.estimator.LinearClassifier(feature_columns=my_categorical_columns,                                           
n_classes=len(column_uniques_lists['int4']), model_dir='.\\SaveLC\\model_dir')

列\u uniques\u列出了一个字典,其中包含每列中包含的所有唯一值

int4列中有7个唯一的值,每个值对应一个类,因此我希望让模型运行时根据[int1,int2,int3]的输入响应的int4进行预测


Tags: sizemytfbatchcolumnnumfeaturefile