Napisałem małą funkcję, aby interpolować sekwencje czasowe nieregularnie próbkowanych obrazów rastrowych, aby były równomiernie rozmieszczone w czasie (poniżej). Działa dobrze, ale po prostu wiem, że szukam tego, że brakuje mi skrótów. Szukam Numpy Ninja, aby dać mi więc porady dotyczące uderzenia swojej składni, a może też zdobędzie trochę wydajności.

Twoje zdrowie!

import numpy as np

def interp_rasters(rasters, chrons, sampleFactor = 1):
    nFrames = round(len(chrons) * sampleFactor)
    interpChrons = np.linspace(np.min(chrons), np.max(chrons), nFrames)
    frames, rows, cols, channels = rasters.shape
    interpRasters = np.zeros((nFrames, rows, cols, channels), dtype = 'uint8')
    outs = []
    for row in range(rows):
        for col in range(cols):
            for channel in range(channels):
                pixelSeries = rasters[:, row, col, channel]
                interpRasters[:, row, col, channel] = np.interp(
                    interpChrons,
                    chrons,
                    pixelSeries
                    )
    return interpRasters
0
Rohan S Byrne 19 marzec 2020, 03:58

1 odpowiedź

Najlepsza odpowiedź

Gdy wartości Y ma być pod uwagę, muszą być 1d, nie widzę sposobu, aby nie zapętlić przez NP.ARrays. Jeśli tablice RASTERS i International są przekształcane, jak w funkcji można stosować jedną pętlę, bez wyraźnego indeksowania. Dało to około 10% poprawy prędkości dla moich wymyślonych danych testowych.

import numpy as np

frames = 10
rows = 5
cols = 10
channels = 3

np.random.seed(1234)

rasters = np.random.randint(0,256, size=(frames, rows, cols, channels))
chrons = np.random.randint(0, 256, size  = 10 )


# The original function.
def interp_rasters(rasters, chrons, sampleFactor = 1):
    nFrames = round(len(chrons) * sampleFactor)
    interpChrons = np.linspace(np.min(chrons), np.max(chrons), nFrames)
    frames, rows, cols, channels = rasters.shape
    interpRasters = np.zeros((nFrames, rows, cols, channels), dtype = 'uint8')
    outs = []
    for row in range(rows):
        for col in range(cols):
            for channel in range(channels):
                pixelSeries = rasters[:, row, col, channel]
                interpRasters[:, row, col, channel] = np.interp(
                    interpChrons,
                    chrons,
                    pixelSeries
                    )
    return interpRasters

def interp_rasters2(rasters, chrons, sampleFactor = 1):
    nFrames = round(len(chrons) * sampleFactor)
    interpChrons = np.linspace(np.min(chrons), np.max(chrons), nFrames)
    frames, rows, cols, channels = rasters.shape
    interpRasters = np.zeros((nFrames, rows, cols, channels), dtype = 'uint8')

    # Create reshaped arrays pointing to the same data 
    dat_in = rasters.reshape(frames, rows*cols*channels).T  
    # shape (r*c*c, frames)

    dat_out = interpRasters.reshape(nFrames, rows*cols*channels).T  
    # shape (r*c*c, frames)

    for pixelseries, row_out in zip(dat_in, dat_out):
        # Loop through all data in one loop.
        row_out[:] = np.interp( interpChrons, chrons, pixelseries )
    return interpRasters  
    # As dat_out and interpRasters share the same data return interpRasters

print(np.isclose(interp_rasters(rasters, chrons), interp_rasters2(rasters, chrons)).all())
# True  # The results are the same from the two functions.

%timeit interp_rasters(rasters, chrons)
# 568 µs ± 2.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit interp_rasters2(rasters, chrons)
# 520 µs ± 239 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Edytuj Jest również np.apply_along_axis. To usuwa wyraźne pętle, zmniejsza ilość kodu, ale jest wolniejsze niż poprzednie rozwiązania.

def interp_rasters3(rasters, chrons, sampleFactor = 1):
    nFrames = round(len(chrons) * sampleFactor)
    interpChrons = np.linspace(np.min(chrons), np.max(chrons), nFrames)

    def func( arr ):  # Define the function to apply along the axis
        return np.interp( interpChrons, chrons, arr )

    return np.apply_along_axis( func, 0, rasters ).astype( np.uint8 )

print(np.isclose(interp_rasters(rasters, chrons), interp_rasters3(rasters, chrons)).all())
# True

Myślę, że rozumiem wersję 3 lepszą niż wersja 1 lub 2 w ciągu 6 miesięcy, jeśli prędkość nie jest krytyczna.

HTH

1
Tls Chris 20 marzec 2020, 20:43