我想用自定义生成器类训练模型,但model.fit()给了我以下错误:
Traceback (most recent call last):
File "C:/Users/benja/PycharmProjects/mri/modelTrainer.py", line 100, in <module>
the_generator = DataGenerator()
TypeError: 'module' object is not callable
下面是我编写的DataGenerator类:
import numpy as np
import math
from tensorflow.keras.utils import Sequence
import os
import nibabel as nib
import pandas as pd
niftiFilesDirPath = './train/nifti/'
class DataGenerator(Sequence):
def __init__(self):
csvFileName = "combined.csv"
niftiFileNames = [s for s in os.listdir(niftiFilesDirPath) if s.endswith(".nii.gz")]
print("Files fount: ", len(niftiFileNames))
dataframe = pd.read_csv(niftiFilesDirPath + csvFileName)
niftiFileLables = []
for niftiFileName in niftiFileNames:
label = dataframe.loc[dataframe["Image ID"] == int(niftiFileName.split(".")[0])]
labelValue = label['Has Parkinson'].values[0]
if labelValue == 0:
niftiFileLables.append([0,1])
else:
niftiFileLables.append([1,0])
self.x, self.y = niftiFileNames, niftiFileLables
self.batch_size = 8
def __len__(self):
return math.ceil(len(self.x) / self.batch_size)
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) *
self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) *
self.batch_size]
niftiImagesList = []
for niftiFileName in batch_x:
niftiFile = os.path.join(niftiFilesDirPath, niftiFileName)
theImage = nib.load(niftiFile)
imageNpArray = theImage.get_fdata()
niftiImagesList.append(imageNpArray)
print(imageNpArray.shape)
print(imageNpArray.dtype)
return np.array(niftiImagesList), np.array(batch_y)
下面是我想在DataGenerator类上培训的模型:
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv3D, MaxPool3D
from tensorflow.keras import optimizers, losses
import DataGenerator
model = Sequential()
model.add(Conv3D(8, (3, 3, 3), activation='relu', input_shape=(256,256,128,1)))
model.add(MaxPool3D((3, 3, 3)))
model.add(Dense(256, activation='tanh'))
model.add(Dense(2, activation='linear'))
# setup model
model.compile(optimizer=optimizers.Adam(1e-3),
loss=losses.mean_squared_error,
metrics=['mae'])
# Generators
the_generator = DataGenerator()
# Train model on dataset
model.fit(x=the_generator, epochs=10)
代码似乎是正确的,但我得到了错误,尽管多次尝试。如何将tf.keras.utils.Sequence与Tensorflow 2中的model.fit()一起使用
因为这条线:
它将是一个模块,您需要在模块内部导入定义,而不是模块本身。此错误与Python语法有关,与TensorFlow或Keras无关
相关问题 更多 >
编程相关推荐