用于X射线图像和骨骼检测的图像阈值算法

2024-09-27 07:32:01 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个使用 OpenCV(Python)的小项目,其中一步是读取人体的 X 射线图像,并将其转换为二值图像:白色像素表示有骨骼的位置,黑色像素表示没有骨骼。

由于某个区域的“骨骼部分”有时可能比另一个区域的“非骨骼部分”还要暗,因此简单的全局阈值不起作用。我也试过自适应阈值(adaptive threshold),但看不出有什么区别。

我想出了一个简单的算法:对每一行分别应用一个简单阈值。代码如下:

def threshhold(image, val, method=None):
    """Binarise *image* row by row with a per-row adaptive threshold.

    For each row the cut-off is ``method(row) + val * (row.max() - row.min())``;
    pixels strictly above it become 255 and all others 0 (the same rule as
    ``cv2.THRESH_BINARY`` with maxval 255, evaluated here in one vectorised
    NumPy comparison instead of one OpenCV call per row).

    Args:
        image: 2-D array, typically a grayscale uint8 image. Not modified.
        val: fraction of each row's value range added to the row centre.
        method: callable reducing a 1-D row to a scalar centre. Defaults to
            ``np.median``; ``np.mean`` is the usual alternative.

    Returns:
        A new array with the same shape and dtype as *image*, containing
        only 0 and 255.
    """
    # Resolved here (not as a default argument) so the name ``np`` is only
    # looked up at call time.
    if method is None:
        method = np.median
    image = np.asarray(image)
    if image.size == 0:
        return image.copy()
    centres = np.array([method(row) for row in image], dtype=np.float64)
    # Cast before subtracting so unsigned dtypes can never wrap around.
    spans = image.max(axis=1).astype(np.float64) - image.min(axis=1)
    cutoffs = centres + val * spans
    mask = image > cutoffs[:, np.newaxis]
    return np.where(mask, 255, 0).astype(image.dtype)

下面的代码完成同样的工作,只是按列而不是按行处理:

def threshhold2(image, val, method=None):
    """Binarise *image* column by column with a per-column adaptive threshold.

    For each column the cut-off is
    ``method(col) + val * (col.max() - col.min())``; pixels strictly above it
    become 255 and all others 0 (the same rule as ``cv2.THRESH_BINARY`` with
    maxval 255, evaluated in one vectorised NumPy comparison).

    Args:
        image: 2-D array, typically a grayscale uint8 image. Not modified.
        val: fraction of each column's value range added to the column centre.
        method: callable reducing a 1-D column to a scalar centre. Defaults
            to ``np.median``; ``np.mean`` is the usual alternative.

    Returns:
        A new array with the same shape and dtype as *image*, containing
        only 0 and 255.
    """
    # Resolved here (not as a default argument) so the name ``np`` is only
    # looked up at call time.
    if method is None:
        method = np.median
    image = np.asarray(image)
    if image.size == 0:
        return image.copy()
    centres = np.array([method(col) for col in image.T], dtype=np.float64)
    # Cast before subtracting so unsigned dtypes can never wrap around.
    spans = image.max(axis=0).astype(np.float64) - image.min(axis=0)
    cutoffs = centres + val * spans
    mask = image > cutoffs[np.newaxis, :]
    return np.where(mask, 255, 0).astype(image.dtype)

该方法在这张图像上效果不错:demo_image_1。这张效果不太好,但也不算太糟:demo_image_2。这张效果非常糟糕:[图片]。这张只有左半边看起来不错:[图片]。

如你所见,该算法只对部分图像有效。希望有经验的朋友能提供一些思路。

顺便说一句,这些图像并非我本人所有。

完整源代码:

import os
import cv2
import numpy as np

# Names of the entries in ./data_set, in sorted order; current_file is the
# index of the entry currently loaded by init().
files_to_see = sorted(os.listdir("data_set"))
current_file = 0

print(files_to_see)

def slice(image, size):
    """Cut *image* into non-overlapping ``size`` x ``size`` tiles.

    Tiles are listed in row-major order (left to right, then top to bottom).
    Any right or bottom strip narrower than ``size`` is dropped.

    Returns:
        ``(rows, cols, tiles)``: the tile-grid dimensions and a list of
        ``((row, col), tile)`` pairs. For NumPy arrays each tile is a slice
        view of *image*, not a copy.
    """
    rows = image.shape[0] // size
    cols = image.shape[1] // size
    tiles = [
        ((r, c), image[r * size:(r + 1) * size, c * size:(c + 1) * size])
        for r in range(rows)
        for c in range(cols)
    ]
    return rows, cols, tiles

def normalize(image):
    """Linearly stretch *image* so its minimum maps to 0 and its maximum to 255.

    Each pixel ``p`` becomes ``round((p - min) * 255 / (max - min))``, using
    round-half-to-even like Python's ``round``, and is stored in the input's
    dtype. The input is not modified.

    A constant image (``max == min``) has no range to stretch and is mapped
    to all zeros instead of dividing by zero. An empty image is returned as
    an unchanged copy.
    """
    image = np.asarray(image)
    if image.size == 0:
        return image.copy()
    min_pix = float(image.min())
    max_pix = float(image.max())
    if max_pix == min_pix:
        return np.zeros_like(image)
    # Work in float64: the subtraction cannot wrap for unsigned dtypes, and
    # the scale factor is not truncated to an integer.
    scale = 255.0 / (max_pix - min_pix)
    stretched = (image.astype(np.float64) - min_pix) * scale
    return np.rint(stretched).astype(image.dtype)


def threshhold(image, val, method):
    """Binarise *image* row by row with a per-row adaptive threshold.

    For each row the cut-off is ``method(row) + val * (row.max() - row.min())``;
    pixels strictly above it become 255 and all others 0 (the same rule as
    ``cv2.THRESH_BINARY`` with maxval 255, evaluated in one vectorised NumPy
    comparison instead of one OpenCV call per row).

    Args:
        image: 2-D array, typically a grayscale uint8 image. Not modified.
        val: fraction of each row's value range added to the row centre.
        method: callable reducing a 1-D row to a scalar centre, e.g.
            ``np.median`` or ``np.mean``.

    Returns:
        A new array with the same shape and dtype as *image*, containing
        only 0 and 255.
    """
    image = np.asarray(image)
    if image.size == 0:
        return image.copy()
    centres = np.array([method(row) for row in image], dtype=np.float64)
    # Cast before subtracting so unsigned dtypes can never wrap around.
    spans = image.max(axis=1).astype(np.float64) - image.min(axis=1)
    cutoffs = centres + val * spans
    mask = image > cutoffs[:, np.newaxis]
    return np.where(mask, 255, 0).astype(image.dtype)

def threshhold2(image, val, method):
    """Binarise *image* column by column with a per-column adaptive threshold.

    For each column the cut-off is
    ``method(col) + val * (col.max() - col.min())``; pixels strictly above it
    become 255 and all others 0 (the same rule as ``cv2.THRESH_BINARY`` with
    maxval 255, evaluated in one vectorised NumPy comparison).

    Args:
        image: 2-D array, typically a grayscale uint8 image. Not modified.
        val: fraction of each column's value range added to the column centre.
        method: callable reducing a 1-D column to a scalar centre, e.g.
            ``np.median`` or ``np.mean``.

    Returns:
        A new array with the same shape and dtype as *image*, containing
        only 0 and 255.
    """
    image = np.asarray(image)
    if image.size == 0:
        return image.copy()
    centres = np.array([method(col) for col in image.T], dtype=np.float64)
    # Cast before subtracting so unsigned dtypes can never wrap around.
    spans = image.max(axis=0).astype(np.float64) - image.min(axis=0)
    cutoffs = centres + val * spans
    mask = image > cutoffs[np.newaxis, :]
    return np.where(mask, 255, 0).astype(image.dtype)

def recalculate_threshhold(v):
    """Re-run the column-wise threshold using the current trackbar settings.

    The argument ``v`` is ignored; both trackbar positions are re-read from
    the "xb labeler" window instead.

    Updates the module globals ``thresh_current_image`` and the 128-pixel
    tile grid ``y_c``, ``x_c`` and ``slices``.
    """
    global original_current_image, thresh_current_image, y_c, x_c, slices

    # "method" trackbar: 0 selects the median, anything else the mean.
    use_median = cv2.getTrackbarPos("method", "xb labeler") == 0
    method = np.median if use_median else np.mean
    # Scale the integer trackbar position down to a fraction.
    fraction = cv2.getTrackbarPos("threshhold_value", "xb labeler") / 1000
    thresh_current_image = threshhold2(original_current_image, fraction, method)
    y_c, x_c, slices = slice(thresh_current_image, 128)

def thresh_current_image_mouse_event(event, x, y, flags, param):
    """On a left click, show the 128x128 tile under the cursor in "slice".

    Reads the tile grid (``slices``, ``x_c``, ``y_c``) last built from the
    thresholded image. ``flags`` and ``param`` belong to the OpenCV mouse
    callback signature and are unused.
    """
    if event != cv2.EVENT_LBUTTONDOWN:
        return
    col, row = x // 128, y // 128
    print(col, row)
    # slice() drops right/bottom strips narrower than a tile; a click there
    # would otherwise pick the wrong tile or raise IndexError.
    if not (0 <= col < x_c and 0 <= row < y_c):
        return
    cv2.imshow("slice", slices[col + row * x_c][1])

# GUI setup. The "xb labeler" window hosts two trackbars; moving either one
# calls recalculate_threshhold.
#   threshhold_value: 0..1000, divided by 1000 inside recalculate_threshhold.
#   method: 0 selects np.median, 1 selects np.mean as the per-column centre.
cv2.namedWindow("xb labeler")
cv2.createTrackbar("threshhold_value", "xb labeler", 0, 1000, recalculate_threshhold)
cv2.createTrackbar("method", "xb labeler", 0, 1, recalculate_threshhold)

# The "thresh_current_image" window receives mouse events, handled by
# thresh_current_image_mouse_event (which reacts to event code 1).
cv2.namedWindow("thresh_current_image")
cv2.setMouseCallback("thresh_current_image", thresh_current_image_mouse_event)

def init():
    """Load the current data-set image and recompute its thresholded view.

    Reads ``data_set/<files_to_see[current_file]>`` as grayscale, resizes it
    to 512x512, stretches its contrast with normalize(), smooths it with a
    5x5 Gaussian blur, then calls recalculate_threshhold(), which rebuilds
    ``thresh_current_image`` and the tile grid (``y_c``, ``x_c``, ``slices``).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (missing, or not an
            image); cv2.imread reports this by returning None.
    """
    global original_current_image
    path = os.path.join("data_set", files_to_see[current_file])
    # IMREAD_GRAYSCALE is the intended flag; cv2.CV_8UC1 is a Mat type
    # constant that only worked because both happen to equal 0.
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"could not read image {path!r}")
    image = cv2.resize(image, (512, 512))
    image = normalize(image)
    original_current_image = cv2.GaussianBlur(image, (5, 5), 10)
    # recalculate_threshhold already re-slices the thresholded image, so no
    # separate slice() call is needed here.
    recalculate_threshhold(1)

if not files_to_see:
    raise SystemExit("no files found in data_set/")
init()

# Main loop. 'p' saves the current binary mask and advances to the next
# file; 'q' or Esc quits. The loop also ends once the last file is saved,
# instead of indexing past the end of files_to_see.
while True:
    cv2.imshow("thresh_current_image", thresh_current_image)
    cv2.imshow("xb labeler", original_current_image)
    # Mask to the low byte: some backends set extra modifier bits.
    k = cv2.waitKey(1) & 0xFF
    if k in (ord('q'), 27):
        break
    if k == ord('p'):
        # NOTE(review): every save overwrites the same "ssq.png"; confirm
        # whether one output file per input image was intended.
        cv2.imwrite("ssq.png", thresh_current_image)
        current_file += 1
        if current_file >= len(files_to_see):
            break
        init()

cv2.destroyAllWindows()

编辑:添加原始图像:

enter image description here

enter image description here

enter image description here

enter image description here


Tags: image, def, np, val, current, min, cv2, method

热门问题