Jak mogę stwierdzić, że Tensor wygląda następująco (na przykład):

[ True, True ]
[ True, True, True, False, False, False, False ]
[ True, True, True, False, False ]
[ True, False, False, False, False ]
[ False, False ]

Ale odrzuć takie dane wejściowe:

[ True, False, True, False, False, True, False ]
[ False, False, False, False, True ]

Mówiąc bardziej ogólnie: chcę sprawdzić, czy tensor składa się tylko z sekwencji od 0 do N wartości True, po której następuje od 0 do N wartości False. Jak mogę to zrobić z Tensorflow 2?

1
miho 12 marzec 2020, 14:51

2 odpowiedzi

Najlepsza odpowiedź

Inne podejście, badanie wskaźników pierwiastków:

import tensorflow as tf

def is_valid(t):
  where_false = tf.where(~t)
  return len(where_false) == 0 or all( idx_true < min(where_false) for idx_true in tf.where(t))

assert is_valid(tf.constant([ True, True ]))
assert is_valid(tf.constant([ True, True, True, False, False, False, False ]))
assert is_valid(tf.constant([ True, True, True, False, False ]))
assert is_valid(tf.constant([ True, False, False, False, False ]))
assert is_valid(tf.constant([ False, False ]))
assert not is_valid(tf.constant([ True, False, True, False, False, True, False ]))
assert not is_valid(tf.constant([ False, False, False, False, True ]))

Chodzi o to, że wszystkie wartości True powinny pojawić się przed pierwszym False, jeśli takie istnieją.

1
GPhilo 12 marzec 2020, 12:13

Oto jeden ze sposobów, w jaki możesz to zrobić:

import tensorflow as tf

def is_valid(a):
    # a is assumed to be a 1D boolean array
    a = tf.convert_to_tensor(a)
    # Convert to integer
    a_int = tf.dtypes.cast(a, tf.int32)
    # Take pairwise differences
    diff = a_int[1:] - a_int[:-1]
    # Check all differences are zero or negative (no transitions from False to True)
    return tf.reduce_all(diff <= 0)

# Valid examples
tf.print(is_valid([ True, True ]))
# 1
tf.print(is_valid([ True, True, True, False, False, False, False ]))
# 1
tf.print(is_valid([ True, True, True, False, False ]))
# 1
tf.print(is_valid([ True, False, False, False, False ]))
# 1
tf.print(is_valid([ False, False ]))
# 1

# Invalid examples
tf.print(is_valid([ True, False, True, False, False, True, False ]))
# 0
tf.print(is_valid([ False, False, False, False, True ]))
# 0

Uwaga: is_valid zwraca skalarny tensor boolowski, mimo że tf.print drukuje go jako liczbę całkowitą.

1
jdehesa 12 marzec 2020, 12:00