NumPy ndarray.argmax()
The numpy.ndarray.argmax() method returns the indices of the maximum values along a specified axis of a NumPy array. If no axis is provided, it returns the index of the maximum value in the flattened array.
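The flattened-index behavior is easiest to see on a 2D array. In the sketch below, the array values and variable names are illustrative. It uses np.unravel_index, a standard NumPy function, to convert the flat position back into row and column coordinates.

```python
import numpy as np

# A 2D array; with no axis, argmax() searches the flattened (row-major) array
grid = np.array([[4, 9, 2],
                 [8, 1, 6]])

flat_index = grid.argmax()          # index into grid.ravel()
print(flat_index)                   # 1

# np.unravel_index converts the flat index back to (row, column) coordinates
row, col = np.unravel_index(flat_index, grid.shape)
print(row, col)                     # 0 1
print(grid[row, col])               # 9
```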
Syntax
ndarray.argmax(axis=None, out=None, *, keepdims=False)
Parameters
Parameter | Type | Description |
---|---|---|
axis | int, optional | Axis along which to find the index of the maximum value. If None (the default), the index is into the flattened array. |
out | ndarray, optional | If provided, the result is written into this array, which should have the appropriate shape and dtype. |
keepdims | bool, optional | If True, the reduced axis is kept in the result as a dimension of size one, so the result broadcasts correctly against the input array. |
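Here is a brief sketch of the out parameter. The buffer name out and the array values are illustrative. The buffer uses np.intp, the integer dtype that argmax() itself returns, so that its dtype matches the expected result.

```python
import numpy as np

arr = np.array([[3, 9, 4],
                [7, 2, 8]])

# Preallocate an output buffer: one index per row, using NumPy's index dtype
out = np.empty(arr.shape[0], dtype=np.intp)

arr.argmax(axis=1, out=out)   # writes the results into `out`
print(out)                    # [1 2]
```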
Return Value
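Returns the index of the maximum value as an integer when axis is None, or an array of indices (dtype np.intp) when an axis is specified. The array has the input's shape with the reduced axis removed, or kept as size one if keepdims=True. If the maximum value occurs more than once, the index of its first occurrence is returned.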
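This short check shows the two return types. The illustrative arrays contain repeated maxima so that the tie-breaking behavior is visible.

```python
import numpy as np

scores = np.array([3, 7, 7, 1])

# With no axis, the result is a NumPy integer scalar, not an array
idx = scores.argmax()
print(idx)                            # 1  (first of the two 7s)
print(isinstance(idx, np.integer))    # True

# With an axis, the result is an integer array of indices
table = np.array([[5, 5, 2],
                  [0, 4, 4]])
row_idx = table.argmax(axis=1)
print(row_idx)                        # [0 1]  (ties resolve to the first column)
print(row_idx.dtype == np.intp)       # True
```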
Examples
1. Finding the Index of the Maximum Value in a Flattened Array
In this example, we create a 1D NumPy array and use argmax() to find the index of its maximum element.
import numpy as np
arr = np.array([10, 25, 14, 36, 9])
index = arr.argmax()
print(index)
Output:
3
The maximum value in the array is 36, which is at index 3.
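Because argmax() returns a position rather than a value, you can use the result directly as an index. This follow-up reuses the array from the example above.

```python
import numpy as np

arr = np.array([10, 25, 14, 36, 9])
index = arr.argmax()

# Indexing with the result retrieves the maximum itself
print(arr[index])                 # 36
print(arr[index] == arr.max())    # True
```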
2. Using the axis Parameter in ndarray.argmax()
Here, we apply argmax() along each axis of a 2D array.
import numpy as np
arr = np.array([[10, 35, 22],
                [47, 5, 30]])
index_axis0 = arr.argmax(axis=0)
print("Max indices along axis 0:", index_axis0)
index_axis1 = arr.argmax(axis=1)
print("Max indices along axis 1:", index_axis1)
Output:
Max indices along axis 0: [1 0 1]
Max indices along axis 1: [1 0]
For axis=0, argmax() compares the values down each column and returns the row index of each column's maximum. For axis=1, it compares the values across each row and returns the column index of each row's maximum.
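Integer-array indexing turns the indices from each axis back into the maximum values: pair each index with the column or row it belongs to. This sketch reuses the array from the example above. The variable names are illustrative.

```python
import numpy as np

arr = np.array([[10, 35, 22],
                [47, 5, 30]])

col_rows = arr.argmax(axis=0)   # row index of each column's maximum: [1 0 1]
row_cols = arr.argmax(axis=1)   # column index of each row's maximum: [1 0]

# Pair each index with its column (axis=0) or row (axis=1) to read the maxima
col_max = arr[col_rows, np.arange(arr.shape[1])]
row_max = arr[np.arange(arr.shape[0]), row_cols]
print(col_max)                  # [47 35 30]
print(row_max)                  # [35 47]

# For a 2D array, axis=-1 refers to the last axis, i.e. the same as axis=1
print(arr.argmax(axis=-1))      # [1 0]
```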
3. Keeping Dimensions with keepdims=True in ndarray.argmax()
Setting keepdims=True keeps the reduced axis in the result as a dimension of size one, so the result has the same number of dimensions as the input array.
import numpy as np
arr = np.array([[2, 8, 1],
                [5, 3, 7]])
result = arr.argmax(axis=1, keepdims=True)
print(result)
Output:
[[1]
 [2]]
Because keepdims=True, the index of each row's maximum is returned as a column vector of shape (2, 1) instead of a 1D array of shape (2,). This lets the result broadcast against the original (2, 3) array.
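The difference keepdims makes shows up in two common follow-up steps. The first passes the indices to np.take_along_axis, a standard NumPy function. The second broadcasts them against a row of column positions. The variable names below are illustrative.

```python
import numpy as np

arr = np.array([[2, 8, 1],
                [5, 3, 7]])

flat = arr.argmax(axis=1)                  # shape (2,)
kept = arr.argmax(axis=1, keepdims=True)   # shape (2, 1)
print(flat.shape, kept.shape)              # (2,) (2, 1)

# The (2, 1) index array has the same number of dimensions as arr, so it
# can be passed straight to np.take_along_axis to pull out each row's maximum
row_max = np.take_along_axis(arr, kept, axis=1)
print(row_max.ravel())                     # [8 7]

# Comparing a (3,) row of column positions with the (2, 1) indices
# broadcasts to (2, 3): True marks the position of each row's maximum
mask = np.arange(arr.shape[1]) == kept
print(mask)
```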