Monday, April 30, 2012

Selecting rows from numpy ndarray

I want to select only certain rows from a numpy array based on the value in the second column. For example, this test array has integers from 1 to 10 in the second column.



>>> import numpy
>>> test = numpy.array([numpy.arange(100), numpy.random.randint(1, 11, 100)]).transpose()
>>> test[:10, :]
array([[ 0,  6],
       [ 1,  7],
       [ 2, 10],
       [ 3,  4],
       [ 4,  1],
       [ 5, 10],
       [ 6,  6],
       [ 7,  4],
       [ 8,  6],
       [ 9,  7]])


If I want only the rows where the second value is 4, it is easy.



>>> test[test[:, 1] == 4]
array([[ 3,  4],
       [ 7,  4],
       [16,  4],
       ...
       [81,  4],
       [83,  4],
       [88,  4]])
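To spell out why this works, here is a minimal sketch on a small fixed array. The names `sample`, `mask` and `selected` are made up for illustration and are not part of the session above. Comparing a column with a scalar gives one boolean per row, and indexing with that boolean array keeps only the rows marked True.

```python
import numpy

# Small fixed array standing in for the random `test` array (illustrative only).
sample = numpy.array([[0, 6], [1, 4], [2, 10], [3, 4]])

# Element-wise comparison of the second column with 4: one boolean per row.
mask = sample[:, 1] == 4

# Boolean indexing keeps only the rows where the mask is True.
selected = sample[mask]

print(mask)
print(selected)
```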


But how do I achieve the same result when there is more than one wanted value? The wanted list can be of arbitrary length. For example, I may want all rows where the second column is 2, 4, or 6.



>>> wanted = [2, 4, 6]


The only way I have come up with is to use a list comprehension and then convert the result back into an array. It works, but it seems too convoluted.



>>> test[numpy.array([test[x, 1] in wanted for x in range(len(test))])]
array([[ 0,  6],
       [ 3,  4],
       [ 6,  6],
       ...
       [90,  2],
       [91,  6],
       [92,  2]])


Is there a better way to do this in numpy itself that I am missing?
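One candidate, offered as a sketch rather than a definitive answer, is `numpy.isin`, which tests each element of an array for membership in a list of values and returns a boolean mask. This assumes a numpy version that provides `numpy.isin` (1.13 or later); older releases offer `numpy.in1d` for the same purpose on 1-D input. The fixed seed and the names `mask`, `selected` and `expected` are only for illustration.

```python
import numpy

# Fixed seed so the sketch is repeatable; the original session was unseeded.
numpy.random.seed(0)
test = numpy.array([numpy.arange(100), numpy.random.randint(1, 11, 100)]).transpose()
wanted = [2, 4, 6]

# One boolean per row: True where the second column is any of the wanted values.
mask = numpy.isin(test[:, 1], wanted)
selected = test[mask]

# The list-comprehension version from the post, kept for comparison.
expected = test[numpy.array([test[x, 1] in wanted for x in range(len(test))])]

print(numpy.array_equal(selected, expected))
```

Because the membership test runs inside numpy instead of in a Python-level loop, it should also scale better to large arrays and long `wanted` lists, though I have not measured that here.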




