Converting a custom optimizer implemented in PyTorch to TensorFlow

Posted 2024-10-17 06:31:48


I'm not very familiar with PyTorch, so I'm hoping you can help me convert this code to TensorFlow. I tried looking at a few articles, but the material is fairly dense, so a bit of explanation would also be valuable so that the whole community can benefit from it.

"""
import torch
import copy

class PESG(torch.optim.Optimizer):

    def __init__(self, model, a=None, b=None, alpha=None, imratio=0.1, margin=1.0, lr=0.1, gamma=500, clip_value=1.0, weight_decay=1e-5, **kwargs):

        assert a is not None, 'Found no variable a!'
        assert b is not None, 'Found no variable b!'
        assert alpha is not None, 'Found no variable alpha!'

        self.p = imratio
        self.margin = margin
        self.model = model

        self.lr = lr
        self.gamma = gamma
        self.clip_value = clip_value
        self.weight_decay = weight_decay

        self.a = a
        self.b = b
        self.alpha = alpha

        # TODO!
        # Reference point (randomly initialized) and accumulator for every parameter,
        # used by the proximal term in step() and refreshed in update_regularizer().
        self.model_ref = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            self.model_ref.append(torch.empty(var.shape).normal_(mean=0, std=0.01).cuda())
        self.model_acc = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            self.model_acc.append(torch.zeros(var.shape, dtype=torch.float32, device="cuda", requires_grad=False).cuda())
        self.T = 0
        self.step_counts = 0

        def get_parameters(params):
            for p in params:
                yield p
        self.params = get_parameters(list(model.parameters()) + [a, b])
        self.defaults = dict(lr=self.lr,
                             margin=margin,
                             gamma=gamma,
                             p=imratio,
                             a=self.a,
                             b=self.b,
                             alpha=self.alpha,
                             clip_value=clip_value,
                             weight_decay=weight_decay,
                             model_ref=self.model_ref,
                             model_acc=self.model_acc
                             )

        super(PESG, self).__init__(self.params, self.defaults)
 
    @property
    def optim_steps(self):
        return self.step_counts

    def update_lr(self, lr):
        self.param_groups[0]['lr'] = lr

    @torch.no_grad()
    def step(self):
        """Performs a single optimization step."""
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            self.lr = group['lr']

            p = group['p']
            gamma = group['gamma']
            m = group['margin']

            model_ref = group['model_ref']
            model_acc = group['model_acc']

            a = group['a']
            b = group['b']
            alpha = group['alpha']

            # Primal update: clipped gradient step plus a proximal pull towards
            # model_ref, followed by weight decay; iterates are accumulated in model_acc.
            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                p.data = p.data - group['lr'] * (torch.clamp(p.grad.data, -clip_value, clip_value) + 1/gamma * (p.data - model_ref[i].data)) - group['lr'] * weight_decay * p.data
                model_acc[i].data = model_acc[i].data + p.data

            # Dual (alpha) ascent step, clamped to stay non-negative.
            alpha.data = alpha.data + group['lr'] * (2 * (m + b.data - a.data) - 2 * alpha.data)
            alpha.data = torch.clamp(alpha.data, 0, 999)

        self.T += 1
        self.step_counts += 1

    def zero_grad(self):
        self.model.zero_grad()
        self.a.grad = None
        self.b.grad = None
        self.alpha.grad = None
    
    def update_regularizer(self, decay_factor=None):
        if decay_factor is not None:
            self.param_groups[0]['lr'] = self.param_groups[0]['lr'] / decay_factor
            print('Reducing learning rate to %.5f @ T=%s!' % (self.param_groups[0]['lr'], self.T))

        print('Updating regularizer @ T=%s!' % (self.T))
        # Move the reference point to the average iterate and reset the accumulators.
        for i, param in enumerate(self.model_ref):
            self.model_ref[i].data = self.model_acc[i].data / self.T
        for i, param in enumerate(self.model_acc):
            self.model_acc[i].data = torch.zeros(param.shape, dtype=torch.float32, device="cuda", requires_grad=False).cuda()
        self.T = 0
    

"""
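To make the question more concrete, here is a rough, untested sketch of how the core updates in step() and update_regularizer() might be expressed with tf.Variable operations in eager mode. The class name PESG_TF, the constructor signature, and the way gradients are passed in are my own assumptions, not an official LibAUC TensorFlow API, so please correct anything that is off:

"""
import tensorflow as tf

class PESG_TF:
    # Rough TensorFlow counterpart of the PyTorch optimizer above (assumed design, untested).
    def __init__(self, variables, a, b, alpha, imratio=0.1, margin=1.0, lr=0.1,
                 gamma=500, clip_value=1.0, weight_decay=1e-5):
        # Model weights plus the auxiliary variables a and b are updated together,
        # mirroring the PyTorch step(); alpha is handled separately.
        self.vars = list(variables) + [a, b]
        self.a, self.b, self.alpha = a, b, alpha
        self.p = imratio  # kept only for parity with the PyTorch version
        self.margin, self.lr = margin, lr
        self.gamma, self.clip_value, self.weight_decay = gamma, clip_value, weight_decay
        # Counterparts of model_ref / model_acc.
        self.model_ref = [tf.Variable(tf.random.normal(v.shape, stddev=0.01), trainable=False)
                          for v in self.vars]
        self.model_acc = [tf.Variable(tf.zeros(v.shape), trainable=False) for v in self.vars]
        self.T = 0

    def step(self, grads):
        # grads: list of gradients w.r.t. self.vars (e.g. from tf.GradientTape).
        for v, g, ref, acc in zip(self.vars, grads, self.model_ref, self.model_acc):
            if g is None:
                continue
            g = tf.clip_by_value(g, -self.clip_value, self.clip_value)
            # Clipped gradient step + proximal pull towards the reference point + weight decay.
            v.assign_sub(self.lr * (g + (v - ref) / self.gamma) + self.lr * self.weight_decay * v)
            acc.assign_add(v)
        # Dual (alpha) ascent step, clamped to [0, 999] as in the PyTorch code.
        new_alpha = self.alpha + self.lr * (2.0 * (self.margin + self.b - self.a) - 2.0 * self.alpha)
        self.alpha.assign(tf.clip_by_value(new_alpha, 0.0, 999.0))
        self.T += 1

    def update_regularizer(self, decay_factor=None):
        if decay_factor is not None:
            self.lr = self.lr / decay_factor
        # Move the reference point to the average iterate and reset the accumulators.
        for ref, acc in zip(self.model_ref, self.model_acc):
            ref.assign(acc / self.T)
            acc.assign(tf.zeros_like(acc))
        self.T = 0
"""

In eager mode the gradients would presumably come from a tf.GradientTape over the AUC-margin loss, e.g. grads = tape.gradient(loss, opt.vars) followed by opt.step(grads). I have not verified that this matches the PyTorch version numerically, which is exactly where I'd appreciate help.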

