如何在tensorflow/keras中使用预定的内核列表初始化Conv2D层?

2024-09-30 20:28:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我想使用Conv2D层来跨越输入图像并运行三个2x2内核

这不是tensorflow的目的,但我真的希望使用tensorflow作为后端引擎来高效运行内核,并在不同的设备GPU和/或CPU之间分配工作负载

我尝试了如下代码。但它似乎不太管用

import tensorflow as tf

class InitConvKernels(tf.keras.initializers.Initializer):

  def __init__(self, num_kernels, kernel_tensor):
    self.kernel_list= kernel_tensor
    self.index = -1
    self.num_kernels = num_kernels

  def __call__(self, shape, dtype=None):
    index += 1 
    assert(self.index <= self.num_kernels) # doesn't affect anything
    tf.print(shape) # doesn't work
    return self.kernel_list[index]

  def get_config(self):
    return {'kernel_list': self.kernel_list, 'num_kernels': self.num_kernels}

我正在调用自定义初始值设定项,但返回的层为空:

kernel_list = tf.constant([[[-1, -1],  [-1, -1]], [[1, 1],   [1, 1]],  [[-1, 1],  [1, -1]],])
layer = layers.Conv2D(
    filters=3,
    kernel_size=2,
    kernel_initializer=InitConvKernels(3,kernel_list),
    bias_initializer=initializers.Zeros()
)

layer.variables为空([]layer.layer.get_weights()也是空的([]

我的目标是评估输入图像上kernel_list中三个内核的卷积,并汇总所有结果


Tags: 图像selflayerindextftensorflowdefkernel
1条回答
网友
1楼 · 发布于 2024-09-30 20:28:38
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D


response = requests.get('https://upload.wikimedia.org/wikipedia/commons/thumb/0/02/Stack_Overflow_logo.svg/1280px-Stack_Overflow_logo.svg.png')
image = Image.open(BytesIO(response.content))

正在加载image from urlenter image description here

构建一个运行kernel的模型(运行更多内核,使kernel_init成为生成器,并在初始化Conv2D时随时调整过滤器的数量)

def kernel_init(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[:,:,0,0] = np.array([[1,0,1],[-1,0,-1],[1,0,1]])
    return kernel

#Build Keras model
model = Sequential()
model.add(Conv2D(1, [3,3], kernel_initializer=kernel_init, 
                 input_shape=(251,1280,4), padding="valid"))
model.build()

# To apply existing filter, we use predict with no training
out = model.predict(image)

以及可视化输出:

import matplotlib.pyplot as plt
plt.matshow(out[0,:,:,0])

enter image description here

编辑: 值得一提的是OpenAI's Triton,它可以帮助使用更高级的语言和框架(如pytorch)来运行高效的GPU代码:

Python-like programming language which enables researchers with no CUDA experience to write highly efficient GPU code—most of the time on par with what an expert would be able to produce.

相关问题 更多 >