More on groupby
You already saw simple grouping and aggregation in an earlier chapter. In this lesson, we will apply groupby to solve more complex problems.
Imputation of missing values
Previously, you learned how to address missing "age" values by substituting them with the mean or median age.
While using the overall mean or median is a simple approach, a more refined method involves grouping passengers by their passenger class (pclass) and then calculating the mean age for each class to impute missing ages within those respective groups.
Let's first justify why a different mean is appropriate by computing the mean for each group:
import pandas as pd
df = pd.read_csv(
"https://raw.githubusercontent.com/jravi123/datasets/refs/heads/main/titanic.csv"
)
df.groupby("pclass")["age"].mean()
pclass    age
1         38.233441
2         29.877630
3         25.140620
As you can see from the result, first-class passengers had a much higher mean age (about 38) than second-class (about 30) and third-class (about 25) passengers.
Now let's impute each missing age with the mean age of that passenger's pclass. First, let's look at the missing ages in pclass 1; the first few are shown below:
df[(df.pclass == 1) & df.age.isna()]["age"]
       age
31     NaN
55     NaN
64     NaN
166    NaN
Let's also check a few rows for pclass 2 and 3:
df[(df.pclass == 2) & df.age.isna()]["age"].head(2)
       age
17     NaN
181    NaN
df[(df.pclass == 3) & df.age.isna()]["age"].head(3)
      age
5     NaN
19    NaN
26    NaN
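Rather than inspecting a few rows at a time, the missing ages can also be counted per class in one step. The sketch below uses a small made-up DataFrame (the name `toy` and its values are assumptions for illustration, not Titanic data) so the result can be checked by hand:

```python
import numpy as np
import pandas as pd

# Hypothetical toy data, not the Titanic dataset
toy = pd.DataFrame(
    {
        "pclass": [1, 1, 2, 3, 3, 3],
        "age": [38.0, np.nan, np.nan, 22.0, np.nan, np.nan],
    }
)

# isna() marks each missing age as True; grouping that boolean Series
# by pclass and summing counts the True values in each class
missing_per_class = toy["age"].isna().groupby(toy["pclass"]).sum()
print(missing_per_class)
```

The same expression should work on the lesson's `df` as long as it has `age` and `pclass` columns.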
Now, we will replace the NaNs with the mean age for that pclass group using the transform function:
df["age"] = df.groupby("pclass")["age"].transform(lambda x: x.fillna(x.mean()))
Let's verify the values for the ones that are displayed above:
df.iloc[[31, 55, 64, 17, 5]][["age", "pclass"]]
      age          pclass
31    38.233441    1
55    38.233441    1
64    38.233441    1
17    29.877630    2
5     25.140620    3
As the output above shows, each missing age was replaced with the mean age of that passenger's pclass.
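To see what transform does on data small enough to verify by hand, here is a minimal sketch. The `toy` DataFrame, its values, and the `age_filled` column name are made up for illustration:

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: two classes, one missing age in each
toy = pd.DataFrame(
    {
        "pclass": [1, 1, 1, 2, 2, 2],
        "age": [40.0, np.nan, 50.0, 20.0, 30.0, np.nan],
    }
)

# transform returns a Series with the same index as the input,
# so the result can be assigned straight back as a column.
# mean() skips NaN by default, so each group mean uses known ages only.
toy["age_filled"] = toy.groupby("pclass")["age"].transform(
    lambda s: s.fillna(s.mean())
)

print(toy)
```

In this sketch the missing age in class 1 becomes 45.0 (the mean of 40 and 50) and the one in class 2 becomes 25.0 (the mean of 20 and 30), while the known ages are left unchanged.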
Applying multiple aggregate functions
Let's say we want to find the mean, the median, and also the sum of the fare in each pclass. You could do all three using the aggregate function:
df.groupby("pclass")["fare"].aggregate(["sum", "mean", "median"])
pclass    sum           mean         median
1         18177.4125    84.154687    60.2875
2         3801.8417     20.662183    14.2500
3         6714.6951     13.675550    8.0500
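If you want more descriptive column names than "sum", "mean", and "median", pandas also supports named aggregation through `agg`, which is an alias of `aggregate`. The sketch below is one possible way to write it; the toy data and the output column names (`total_fare`, `mean_fare`, `median_fare`) are assumptions chosen for illustration:

```python
import pandas as pd

# Hypothetical toy data, not the Titanic dataset
toy = pd.DataFrame(
    {
        "pclass": [1, 1, 2, 2, 3, 3],
        "fare": [80.0, 100.0, 20.0, 30.0, 8.0, 12.0],
    }
)

# Each keyword becomes an output column: name=(input column, function)
summary = toy.groupby("pclass").agg(
    total_fare=("fare", "sum"),
    mean_fare=("fare", "mean"),
    median_fare=("fare", "median"),
)

print(summary)
```

Named aggregation is mainly a readability choice: the numbers are the same as with a list of function names, but the resulting columns say which input they came from.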
Filtering on groups
You can also filter on groups, keeping only the rows that belong to groups satisfying a condition.
Suppose you want to find the decks where the fare is more than 20 for all passengers in that deck category. You could run a query like the one below:
df.groupby("deck").filter(lambda x: all(x.fare > 20))["deck"].unique()
array(['C', 'B'], dtype=object)
It turns out that 'C' and 'B' are the only decks on which every passenger paid a fare above 20.
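The behavior of filter is easier to follow on a tiny example. The sketch below uses made-up data (the `toy` DataFrame and its values are assumptions) and also shows an alternative that reaches the same deck names through a per-group minimum:

```python
import pandas as pd

# Hypothetical toy data: deck A has one fare that is not above 20
toy = pd.DataFrame(
    {
        "deck": ["A", "A", "B", "B", "C"],
        "fare": [10.0, 50.0, 30.0, 40.0, 25.0],
    }
)

# filter keeps every row of each group for which the function returns True
kept = toy.groupby("deck").filter(lambda g: (g["fare"] > 20).all())
decks_from_filter = sorted(kept["deck"].unique())

# Alternative: every fare in a group exceeds 20 exactly when the
# group's minimum fare exceeds 20
min_fare = toy.groupby("deck")["fare"].min()
decks_from_min = sorted(min_fare[min_fare > 20].index)

print(decks_from_filter)
print(decks_from_min)
```

Both approaches select decks B and C here. filter is the better fit when you need the surviving rows themselves, while the minimum-based version is enough when only the group labels are needed.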