Próbuję napisać własną funkcję straty dla UNET. W tej funkcji chcę przypisać wszystkie różnice większe niż 10 między y_true
a y_pred
do 10 oraz wszystkie różnice mniejsze niż 1 do 1. Czy istnieje wygodny sposób porównywania i przypisywania do tensorów?
def weighted_cross_entropyy(i):
def loss(y_true, y_pred):
diff = K.abs(y_true - y_pred)
diff[K.less(diff, 1)] == 1
diff[K.greater(diff, 10)] == 10
return K.mean(K.square((y_pred - y_true)* diff), axis= -1)
return loss
1 odpowiedź
Próbowałem rozwiązać problem i dotarłem tutaj, ale nadal otrzymuję poniższy błąd
Def weighted_cross_entropyy (i):
def loss(y_true, y_pred):
def f1():
return K.mean(K.square(y_pred - y_true), axis= - 1)
def f2():
return K.mean(K.square(y_pred - y_true), axis= - 1) * 10
def f3(w):
return K.mean(K.square(y_pred - y_true), axis= - 1) * w
w = K.sqrt(K.sum(K.square(y_true - y_pred), axis=-1))
print(w)
r = tf.case([(tf.less(w, 0), f1), (tf.greater(w, 10), f2)], default = f3(w), exclusive=True)
return r
return loss
Błąd: Kształt musi mieć rangę 0, ale ma rangę 3 dla 'loss_6 / conv2d_168_loss / case / cond / Switch' (op: 'Switch') z kształtami wejściowymi: [?, 128,800], [?, 128,800].
Podobne pytania
Nowe pytania
python
Python to wielozadaniowy, wielozadaniowy język programowania dynamicznie typowany. Został zaprojektowany tak, aby był szybki do nauczenia się, zrozumienia i użycia oraz wymuszania czystej i jednolitej składni. Należy pamiętać, że Python 2 oficjalnie nie jest obsługiwany od 01-01-2020. Mimo to, w przypadku pytań Pythona specyficznych dla wersji, dodaj znacznik [python-2.7] lub [python-3.x]. Korzystając z wariantu Pythona (np. Jython, PyPy) lub biblioteki (np. Pandas i NumPy), należy umieścić go w tagach.