Pytorch向后()的任何Tensorflow等价物?尝试将渐变发送回TF模型以进行backprop

2024-10-01 19:26:07 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试实现一个分割学习模型,在这个模型中,我在客户机上的TF模型接收数据并生成一个中间输出。这个中间输出将被发送到运行Pytorch模型的服务器,该服务器将把它作为输入,并将损失降至最低。然后,我的服务器将客户机梯度发送回TF模型,以便TF模型更新其权重

如何让我的TF模型使用从服务器返回的梯度更新其权重

# pytorch client
client_output.backward(client_grad)
optimizer.step()

使用PyTorch,我可以只做client_pred.backward(client_grad)client_optimizer.step()

如何使用Tensorflow客户端实现同样的功能?我用tape.gradient(client_grad, model.trainable_weights)试过GradientTape,但它一个也没给我。我认为这是因为在磁带上下文中没有计算,client_grad只是一个持有梯度的张量,没有连接到模型的层

有什么方法可以用tf的apply_gradients()或compute_gradients()来实现这一点吗

我只有客户端最后一层的渐变(由服务器发送)。我正在尝试计算客户端的所有渐变并更新权重

多谢各位


class TensorflowModel(tf.keras.Model):
        def __init__(self, D_in, H, D_out):
            super(TensorflowModel, self).__init__()
            self.d1 = Dense(H, activation='relu', input_shape=(D_in,))
            self.d2 = Dense(D_out)

        def call(self, x):
            x = self.d1(x)
            return self.d2(x)

tensorflowModel = TensorflowModel(D_in, H, D_out)
tensorflowOptimizer = tf.optimizers.Adam(lr=1e-4)

serverModel = torch.nn.Sequential(
        torch.nn.Linear(10, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 10)
    )
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(serverModel.parameters(), lr=1e-4)

for t in range(N):
    // let x be minibatch
    // let y be labels of minibatch

    client_pred = tensorflowModel(x)

    client_output = torch.from_numpy(client_pred.numpy())
    client_output.requires_grad = True

    y_pred = serverModel(client_output)
    loss = loss_fn(y_pred, y)  
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() // update server weights

    // now retrieve client grad for last layer
    client_grad = client_output.grad.detach().clone().numpy()
    client_grad = tf.convert_to_tensor(client_grad) // change to tf tensor

    // now compute all client gradients and update client weights
    // HOW DO I DO THIS? 

我应该如何更新客户端权重?如果客户机是pytorch模型,我可以只做client_pred.backward(client_grad)和client_optimizer.step()。我不知道如何使用梯度带来计算梯度,因为client_grad是在服务器上计算的,是一个pytorch张量,转换成tf张量


Tags: 模型self服务器clientoutputtfnntorch

热门问题