NumPy ndarray.argmax()

The numpy.ndarray.argmax() method returns the indices of the maximum values along a specified axis in a NumPy array. If no axis is provided, it returns the index of the maximum value in the flattened array.

Syntax

ndarray.argmax(axis=None, out=None, *, keepdims=False)
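
The method is the array-bound form of the module-level function np.argmax(). As a small sketch with illustrative values, the two calls below should return the same indices.

```python
import numpy as np

arr = np.array([[4, 9, 2],
                [8, 1, 6]])

# Method form and function form compute the same result
method_result = arr.argmax(axis=1)
function_result = np.argmax(arr, axis=1)

print(method_result)    # [1 0]
print(function_result)  # [1 0]
```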

Parameters

Parameter | Type | Description
--- | --- | ---
axis | int or None, optional | Axis along which to find the index of the maximum value. By default (None), the index is computed over the flattened array.
out | ndarray, optional | Array in which to store the result. It must have the same shape as the expected output and a suitable integer dtype.
keepdims | bool, optional | If True, the reduced axis is kept in the result as a dimension of size one, so the result broadcasts correctly against the input array.
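
The out parameter is not used in the examples below. A minimal sketch: preallocate an array with the shape the reduction will produce. Choosing dtype np.intp (NumPy's native index integer type) is an assumption made here to match the integer type argmax() returns.

```python
import numpy as np

arr = np.array([[10, 35, 22],
                [47, 5, 30]])

# argmax(axis=0) yields one index per column, so preallocate 3 slots
out = np.empty(arr.shape[1], dtype=np.intp)

# The result is written into `out` instead of a newly allocated array
arr.argmax(axis=0, out=out)
print(out)  # [1 0 1]
```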

Return Value

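Returns the index of the maximum value. When axis is None, this is a single integer (a NumPy integer scalar) indexing into the flattened array. When axis is given, it is an array of indices whose shape is the input shape with that axis removed (or kept as size one if keepdims=True). If the maximum value occurs more than once, the index of its first occurrence is returned.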
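
Two details of the return value are easy to miss: ties resolve to the first occurrence, and in float arrays a NaN is treated as the maximum. The arrays below are illustrative; np.nanargmax() is NumPy's NaN-ignoring variant.

```python
import numpy as np

# Ties: the maximum 7 appears at indices 1 and 2; the first occurrence wins
ties = np.array([3, 7, 7, 1])
print(ties.argmax())  # 1

# NaN: argmax treats NaN as the maximum and returns the index of the first NaN
with_nan = np.array([1.0, np.nan, 3.0])
print(with_nan.argmax())       # 1
print(np.nanargmax(with_nan))  # 2 (NaN ignored)

# The scalar result can be used directly as an index
print(ties[ties.argmax()])  # 7
```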


Examples

1. Finding the Index of the Maximum Value in a Flattened Array

In this example, we create a 1D NumPy array and use argmax() to find the index of the maximum element.

import numpy as np

arr = np.array([10, 25, 14, 36, 9])

index = arr.argmax()
print(index)

Output:

3

The maximum value in the array is 36, which is at index 3.
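
With a 1D array, the flattened index and the position are the same thing. For a multidimensional array called without axis, the result is an index into the flattened (row-major) array. A short sketch with an illustrative 2D array shows how np.unravel_index() converts that flat index back to coordinates.

```python
import numpy as np

arr = np.array([[10, 35, 22],
                [47, 5, 30]])

# Without axis, argmax works on the flattened array [10, 35, 22, 47, 5, 30]
flat_index = arr.argmax()
print(flat_index)  # 3

# Convert the flat index back to (row, column) coordinates
row, col = np.unravel_index(flat_index, arr.shape)
print(row, col)       # 1 0
print(arr[row, col])  # 47
```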

2. Using the axis Parameter in ndarray.argmax()

Here, we apply argmax() along different axes in a 2D array.

import numpy as np

arr = np.array([[10, 35, 22],
                [47, 5, 30]])

index_axis0 = arr.argmax(axis=0)
print("Max indices along axis 0:", index_axis0)

index_axis1 = arr.argmax(axis=1)
print("Max indices along axis 1:", index_axis1)

Output:

Max indices along axis 0: [1 0 1]
Max indices along axis 1: [1 0]

For axis=0, argmax() compares values down each column and returns, for each column, the row index of its maximum. For axis=1, it compares values across each row and returns, for each row, the column index of its maximum.
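
To see what these indices point at, a sketch can pair them with np.arange() to read the maximum values back out of the array. This should match arr.max() along the same axis.

```python
import numpy as np

arr = np.array([[10, 35, 22],
                [47, 5, 30]])

index_axis0 = arr.argmax(axis=0)  # row index of each column's maximum: [1 0 1]
index_axis1 = arr.argmax(axis=1)  # column index of each row's maximum: [1 0]

# Pair each row index with its column number to read the column maxima
col_max = arr[index_axis0, np.arange(arr.shape[1])]
print(col_max)  # [47 35 30]

# Pair each row number with its column index to read the row maxima
row_max = arr[np.arange(arr.shape[0]), index_axis1]
print(row_max)  # [35 47]
```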

3. Keeping Dimensions with keepdims=True in ndarray.argmax()

Setting keepdims=True keeps the reduced axis in the result as a dimension of size one, so the result has the same number of dimensions as the input array.

import numpy as np

arr = np.array([[2, 8, 1],
                [5, 3, 7]])

result = arr.argmax(axis=1, keepdims=True)
print(result)

Output:

[[1]
 [2]]

Because of keepdims=True, the result has shape (2, 1) instead of (2,). Each row's index sits in a column vector, which broadcasts against the original (2, 3) array.
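
The broadcasting benefit can be checked with a short sketch, assuming NumPy 1.22 or later (where keepdims was added to argmax). np.take_along_axis() expects an index array with the same number of dimensions as the input, which is exactly what keepdims=True produces.

```python
import numpy as np

arr = np.array([[2, 8, 1],
                [5, 3, 7]])

flat = arr.argmax(axis=1)                 # shape (2,)
kept = arr.argmax(axis=1, keepdims=True)  # shape (2, 1)
print(flat.shape, kept.shape)  # (2,) (2, 1)

# The (2, 1) index array lines up with arr along axis 1,
# so take_along_axis can gather each row's maximum value
row_max = np.take_along_axis(arr, kept, axis=1)
print(row_max)
# [[8]
#  [7]]

# The (2, 1) result also broadcasts against a row of column numbers,
# producing a (2, 3) boolean mask that is True where each row's maximum sits
mask = np.arange(arr.shape[1]) == kept
print(mask)
```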