如何使用python在Tensorflow、CNN中创建学习模型的多个实例?

2024-06-25 23:12:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个类可以实例化一个对象 et=EyeTracker() 其中,EyeTracker()是一个具有init构造函数的类,可以找到here。实际上,我想创建多个模型,并将一个模型的某些部分分配给新创建的模型,然后丢弃旧模型。下面有个错误。任何帮助都将不胜感激。我真的被困住了。我也问了一个类似的问题here。所以这两个答案都是受欢迎的

def train(args):
    train_data, val_data = load_data(args.input)
    train_data = prepare_data(train_data)
    val_data = prepare_data(val_data)
    with tf.variable_scope("", reuse=True) as scope:
        et = EyeTracker()
        train_loss_history, train_err_history, val_loss_history, val_err_history = et.train(train_data, val_data, \
                                            lr=args.learning_rate, \
                                            batch_size=args.batch_size, \
                                            max_epoch=args.max_epoch, \
                                            min_delta=1e-4, \
                                            patience=args.patience, \
                                            print_per_epoch=args.print_per_epoch,
                                            out_model=args.save_model)
        save some parts of the (et)
        scope.reuse_variables()
        et = EyeTracker()
        Assign some parts of previous (et) to the new one and continue training
        train_loss_history, train_err_history, val_loss_history, val_err_history = et.train(train_data, val_data, \
                                            lr=args.learning_rate, \
                                            batch_size=args.batch_size, \
                                            max_epoch=args.max_epoch, \
                                            min_delta=1e-4, \
                                            patience=args.patience, \
                                            print_per_epoch=args.print_per_epoch,
                                            out_model=args.save_model)

错误是

Variable conv1_eye_w does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope? i am really sorry, if my question is annoying.


Tags: 模型datasizebatchargstrainvalhistory
1条回答
网友
1楼 · 发布于 2024-06-25 23:12:56

部分解决。我将默认构造函数init更改为成员函数initialize(),并将值作为参数传递,如下所示

g = tf.Graph()
with g.as_default(): 
            et = EyeTracker()
            et.initialize(96,256,384,64,96,256,384,64)            
            result_temp = et.train(n_epoch, train_data, val_data, lr=args.learning_rate, batch_size=args.batch_size, max_epoch=args.max_epoch, min_delta=1e-4, patience=args.patience, print_per_epoch=args.print_per_epoch, out_model=args.save_model)

g = tf.Graph()
with g.as_default(): 
            et = EyeTracker()
            et.initialize(80,256,384,64,96,256,384,64)           
            result_temp = et.train(n_epoch, train_data, val_data, lr=args.learning_rate, batch_size=args.batch_size, max_epoch=args.max_epoch, min_delta=1e-4, patience=args.patience, print_per_epoch=args.print_per_epoch, out_model=args.save_model)

相关问题 更多 >