Python:使用多边形剪裁光栅图像的函数的代码检查

2024-09-27 09:35:31 发布

您现在位置:Python中文网/ 问答频道 /正文

我想请大家审查以下代码。我有一张带空间参考的图像和一个多边形,并编写了下面的代码来剪裁这张图像,将剪裁区域保存为一张新图像。此函数基于要素类的几何图形剪裁光栅:基于几何图形的剪裁是指使用要素类中所有要素的实际边界来剪裁光栅,而不是使用这些要素的最小边界矩形。

输入:一个多边形图层和一个或多个光栅图层(波段);输出:剪裁到多边形边界的新光栅图层。

import osgeo.gdal
import shapefile
import struct, numpy, pylab
import numpy as np
import ogr
import osr,gdal
from shapely.geometry import Polygon
import osgeo.gdal as gdal
import sys
from osgeo import gdal, gdalnumeric, ogr, osr
import Image,ImageDraw

def world2Pixel(geoMatrix, x, y):
    """
    Convert geospatial coordinates to (column, row) pixel indices.

    Uses a GDAL geotransform (as returned by ``gdal.Dataset.GetGeoTransform()``)
    (http://geospatialpython.com/2011/02/clip-raster-using-shapefile.html).
    geoMatrix
    [0] = top left x (x Origin)
    [1] = w-e pixel resolution (pixel Width)
    [2] = rotation, 0 if image is "north up"
    [3] = top left y (y Origin)
    [4] = rotation, 0 if image is "north up"
    [5] = n-s pixel resolution (pixel Height, negative for "north up")

    The rotation terms [2] and [4] are ignored, so the result is only valid
    for "north up" rasters. Results are rounded to the nearest integer.

    Parameters
    ----------
    geoMatrix : sequence of float
        GDAL geotransform (at least 6 items).
    x, y : float or array-like
        Geospatial coordinates; scalars or arrays of the same shape.

    Returns
    -------
    (pixel, line) : tuple
        Column and row indices with an integer dtype and the shape of x / y.
    """
    ulX = geoMatrix[0]
    ulY = geoMatrix[3]
    xDist = geoMatrix[1]
    yDist = geoMatrix[5]
    pixel = np.round((np.asarray(x) - ulX) / xDist).astype(int)
    # Divide by the row resolution (not the column one) so non-square pixels
    # work; yDist is negative for "north up" images, giving rows that grow
    # downwards.
    line = np.round((np.asarray(y) - ulY) / yDist).astype(int)
    return (pixel, line)

def Pixel2world(geoMatrix, x, y):
    """
    Convert (column, row) pixel indices to geospatial coordinates.

    Uses the origin (geoMatrix[0], geoMatrix[3]) and the pixel size
    (geoMatrix[1], geoMatrix[5]) of a GDAL geotransform; the rotation terms
    geoMatrix[2] and geoMatrix[4] are ignored. For integer indices the result
    is the top-left corner of that pixel.

    Returns a (coorX, coorY) tuple.
    """
    originX, originY = geoMatrix[0], geoMatrix[3]
    pixelWidth, pixelHeight = geoMatrix[1], geoMatrix[5]
    return originX + x * pixelWidth, originY + y * pixelHeight

def _clip_collect_polygons(layer):
    """
    Return clones of every polygon in ``layer``, expanding MultiPolygons.

    Geometries are cloned because ``GetGeometryRef`` returns a reference owned
    by its feature, which becomes invalid once the feature is released.

    Raises SystemExit if a feature holds a non-polygonal geometry or if the
    layer contains no polygon at all.
    """
    polygons = []
    layer.ResetReading()
    feature = layer.GetNextFeature()
    while feature is not None:
        geometry = feature.GetGeometryRef()
        if geometry is not None:
            # GT_Flatten maps 2.5D types (e.g. wkbPolygon25D) to their 2D type.
            geomType = ogr.GT_Flatten(geometry.GetGeometryType())
            if geomType == ogr.wkbPolygon:
                polygons.append(geometry.Clone())
            elif geomType == ogr.wkbMultiPolygon:
                for i in range(geometry.GetGeometryCount()):
                    polygons.append(geometry.GetGeometryRef(i).Clone())
            else:
                raise SystemExit('This module can only load polygon')
        feature = layer.GetNextFeature()
    if not polygons:
        raise SystemExit('The layer contains no polygon')
    return polygons


def _clip_polygon_mask(polygons, geoMatrix, width, height):
    """
    Rasterise ``polygons`` onto a boolean array of shape (height, width).

    Pixel (0, 0) of the result is the top-left pixel of ``geoMatrix`` (a
    "north up" GDAL geotransform). True marks pixels inside a polygon: ring 0
    of each polygon is filled, and the remaining rings (holes) are cleared.
    """
    originX, originY = geoMatrix[0], geoMatrix[3]
    xDist, yDist = geoMatrix[1], geoMatrix[5]
    # Mode 'L', not '1': older Numpy/PIL combinations do not convert bilevel
    # images with numpy.array(img) reliably.
    img = Image.new('L', (width, height), 0)
    draw = ImageDraw.Draw(img)
    for polygon in polygons:
        for ringIndex in range(polygon.GetGeometryCount()):
            ring = polygon.GetGeometryRef(ringIndex)
            vertices = [
                (int(round((ring.GetX(i) - originX) / xDist)),
                 int(round((ring.GetY(i) - originY) / yDist)))
                for i in range(ring.GetPointCount())
            ]
            if len(vertices) < 3:
                # Degenerate ring: nothing to fill.
                continue
            if ringIndex == 0:
                draw.polygon(vertices, outline=1, fill=1)
            else:
                draw.polygon(vertices, fill=0)
    return np.array(img) > 0


def RASTERClipByPolygon(inFile, poly, outFile):
    """
    Clip a raster to the polygon features of a vector layer and save it.

    The output covers the bounding box of all features in the layer, clamped
    to the source raster. Only pixels inside at least one polygon keep their
    values (polygon holes are excluded). Every other pixel is set to the
    band's nodata value, or to 0 when the band has none.

    Parameters
    ----------
    inFile : str
        Path of the georeferenced ("north up") raster to clip.
    poly : str
        Path of an OGR data source (e.g. a shapefile) with exactly one layer
        of Polygon / MultiPolygon features. It is assumed to use the same
        coordinate system as the raster; no reprojection is done.
    outFile : str
        Path of the raster to create. It uses the same GDAL driver, band
        count and data type (of band 1) as ``inFile``. Drivers without
        ``Create`` support may fail here.

    Raises
    ------
    SystemExit
        If an input cannot be opened, the raster has no bands or is rotated,
        the vector source does not have exactly one layer, a feature is not a
        (multi)polygon, the polygons do not overlap the raster, or the output
        cannot be created.
    """
    # Open the image read-only and read its georeferencing.
    ds = gdal.Open(inFile, gdal.GA_ReadOnly)
    if ds is None:
        raise SystemExit("The raster could not openned")
    if ds.RasterCount == 0:
        raise SystemExit('The raster has no bands')
    ulX, xDist, rtnX, ulY, rtnY, yDist = ds.GetGeoTransform()[:6]
    if rtnX != 0 or rtnY != 0:
        # The pixel/world conversions below assume a "north up" raster.
        raise SystemExit('Rotated rasters are not supported')
    dsWKT = ds.GetProjectionRef()
    driver = ds.GetDriver()

    # Open the vector source holding the clip polygons.
    shp = ogr.Open(poly)
    if shp is None:
        raise SystemExit('The shapefile could not be opened')
    if shp.GetLayerCount() != 1:
        raise SystemExit('The shapefile must have exactly one layer')
    layer = shp.GetLayer(0)
    polygons = _clip_collect_polygons(layer)

    # Pixel window covering the layer extent. floor/ceil keep partially
    # covered edge pixels, sorting handles either axis orientation, and
    # clamping keeps the window inside the raster.
    X_min, X_max, Y_min, Y_max = layer.GetExtent()
    cols = sorted(((X_min - ulX) / xDist, (X_max - ulX) / xDist))
    rows = sorted(((Y_min - ulY) / yDist, (Y_max - ulY) / yDist))
    col0 = max(int(np.floor(cols[0])), 0)
    col1 = min(int(np.ceil(cols[1])), ds.RasterXSize)
    row0 = max(int(np.floor(rows[0])), 0)
    row1 = min(int(np.ceil(rows[1])), ds.RasterYSize)
    pxWidth = col1 - col0
    pxHeight = row1 - row0
    if pxWidth <= 0 or pxHeight <= 0:
        raise SystemExit('The polygon does not overlap the raster')

    # The output grid keeps the source resolution. Its origin is the top-left
    # corner of the window's first pixel, so the mask and data stay aligned.
    targetGeo = (ulX + col0 * xDist, xDist, rtnX,
                 ulY + row0 * yDist, rtnY, yDist)
    inside = _clip_polygon_mask(polygons, targetGeo, pxWidth, pxHeight)

    target_ds = driver.Create(outFile, pxWidth, pxHeight, ds.RasterCount,
                              ds.GetRasterBand(1).DataType)
    if target_ds is None:
        raise SystemExit('The output raster could not be created')
    target_ds.SetGeoTransform(targetGeo)
    # Keep the source projection. If the source has none, fall back to an
    # arbitrary local CS (needs GDAL >= 1.7.0).
    target_ds.SetProjection(dsWKT if dsWKT else 'LOCAL_CS["arbitrary"]')

    for bandno in range(1, ds.RasterCount + 1):
        band = ds.GetRasterBand(bandno)
        nodata = band.GetNoDataValue()
        # Read only the clip window rather than the whole band.
        window = band.ReadAsArray(col0, row0, pxWidth, pxHeight)
        window[~inside] = 0 if nodata is None else nodata
        outBand = target_ds.GetRasterBand(bandno)
        outBand.WriteArray(window)
        if nodata is not None:
            outBand.SetNoDataValue(nodata)

    # Drop band references, then the datasets, so GDAL flushes and closes
    # the output file.
    band = outBand = None
    target_ds = None
    shp = None
    ds = None

Tags: oftheimageimporttargetbandifds

热门问题