Spark SQL and DataFrames
While RDDs provide a powerful, low-level API, most modern Spark applications are built using Spark SQL and its primary abstraction, the DataFrame. Spark SQL is a module for structured data processing that brings the power of SQL and relational databases to Spark.
What is a DataFrame?
A DataFrame is a distributed collection of data organized into named columns. Conceptually, it is equivalent to a table in a relational database or a data frame in R/Python (like pandas), but with two key differences:
- Distributed: The data is partitioned and distributed across the worker nodes of a cluster.
- Optimized: Operations on DataFrames are automatically optimized by Spark's Catalyst Optimizer, which creates highly efficient physical execution plans. This often makes DataFrame code significantly faster than equivalent RDD code.
DataFrames can be constructed from a wide range of sources, including structured data files (JSON, CSV, Parquet), tables in Hive, external databases, or existing RDDs.

Working with DataFrames
You can interact with DataFrames in two primary ways: using the DataFrame API (a fluent, programmatic style) or by registering the DataFrame as a temporary view and using standard SQL queries.

A Simple Example
Here’s a quick look at how you might read a JSON file and perform some basic analysis using the DataFrame API in PySpark.
```python
from pyspark.sql import SparkSession

# 1. Create a SparkSession
# This is the entry point for any Spark SQL functionality
spark = (
    SparkSession.builder.appName("My Spark SQL Example")
    .master("local[*]")
    .getOrCreate()
)

# 2. Read data into a DataFrame
# Spark infers the schema from the JSON data. By default it expects
# JSON Lines (one JSON object per line); pass multiLine=True otherwise.
df = spark.read.json("path/to/your/customers.json")

# 3. Perform operations (transformations)
# Transformations are lazy: nothing is computed until an action runs.
# Select specific columns
names_and_ages = df.select("name", "age")
# Filter the data
adults = df.filter(df.age >= 21)
# Group by a column and aggregate
avg_age_by_city = df.groupBy("city").avg("age")

# 4. Trigger a computation (action)
# The .show() action displays the first 20 rows of the resulting DataFrame
print("Names and Ages:")
names_and_ages.show()
print("Adult Customers:")
adults.show()
print("Average Age by City:")
avg_age_by_city.show()

# Stop the SparkSession
spark.stop()
```
DataFrame API Cheat Sheet
Here are some of the most common functions in the PySpark DataFrame API. Let df be a DataFrame object.
Creating and Reading Data
| Description | Example |
|---|---|
| Create a SparkSession | spark = SparkSession.builder.master("local").getOrCreate() |
| Read a local JSON file | df = spark.read.json("examples/customers.json") |
| Read an HDFS CSV file | df = spark.read.csv("hdfs:///data/customers.csv", header=True, inferSchema=True) |
Inspecting Data
| Description | Example |
|---|---|
| Print the schema | df.printSchema() |
| Show the first 20 rows | df.show() or df.show(20) |
| Get summary statistics | df.describe().show() |
| Get the first n rows | df.head(5) |
| Get the last n rows | df.tail(5) |
| Count the total number of rows | df.count() |
| Get column data types | df.dtypes |
Selecting and Filtering
| Description | Example |
|---|---|
| Select specific columns | df.select("name", "city") |
| Select using column objects | df.select(df['name'], df['city']) |
| Filter rows based on a condition | df.filter(df.age > 30) |
| Alternative filter syntax | df.where("age > 30") |
Manipulating Columns and Rows
| Description | Example |
|---|---|
| Add or replace a column | df.withColumn("age_plus_ten", df.age + 10) |
| Rename a column | df.withColumnRenamed("age", "user_age") |
| Change all column labels | df.toDF("col1", "col2", ...) |
| Drop a column | df.drop("city") |
| Drop duplicate rows | df.dropDuplicates() or df.drop_duplicates() |
Handling Missing Data
| Description | Example |
|---|---|
| Fill nulls in all numeric columns | df.fillna(0) |
| Fill nulls in specific columns | df.fillna({"age": 0, "city": "Unknown"}) |
| Alternative na syntax | df.na.fill(0) |
Grouping and Aggregating
Grouping data is a two-step process. First, you create a GroupedData object using groupBy(). Then, you apply an aggregation function to that object.
```python
grouped_data = df.groupBy("city")
```
| Description | Example |
|---|---|
| Count records per group | grouped_data.count().show() |
| Sum numerical columns per group | grouped_data.sum("age", "salary").show() |
| Get max value per group | grouped_data.max("salary").show() |
| Get min value per group | grouped_data.min("salary").show() |
| Get average value per group | grouped_data.avg("age").show() |
| Compute multiple aggregates | grouped_data.agg({"age": "avg", "salary": "max"}).show() |
Analyzing with Window Functions
While groupBy aggregates data into a single row per group, Window Functions compute a result for each row based on a "window" of related input rows. This is useful for tasks like calculating running totals, moving averages, or ranking items within a group.
The process involves three main steps:
- Define a Window Specification: This defines how the rows are partitioned and ordered. You use Window.partitionBy() to group data and Window.orderBy() to sort it within each group.
- Select a Window Function: Choose a function to apply over the window, such as rank(), row_number(), lag(), or aggregate functions like sum().
- Apply the Function: Call .over(window_spec) on the selected function.
Example: Find the Top 2 Highest-Paid Employees in Each Department
Imagine you have a DataFrame of employees and want to find the top earners in each department without losing the individual employee records.
```python
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, rank

# Create (or reuse) a SparkSession; the previous example stopped its session
spark = SparkSession.builder.master("local[*]").getOrCreate()

# Sample data
data = [
    ("James", "Sales", 3000),
    ("Michael", "Sales", 4600),
    ("Robert", "Sales", 4100),
    ("Maria", "Finance", 3000),
    ("James", "Finance", 3000),
    ("Scott", "Finance", 3300),
    ("Jen", "Finance", 3900),
    ("Jeff", "Marketing", 3000),
    ("Kumar", "Marketing", 2000),
    ("Saif", "Sales", 4100),
]
columns = ["employee_name", "department", "salary"]
df = spark.createDataFrame(data, schema=columns)

# 1. Define the window spec
# Partition by department and order by salary in descending order
window_spec = Window.partitionBy("department").orderBy(col("salary").desc())

# 2. Apply the rank() function over the window
ranked_df = df.withColumn("rank", rank().over(window_spec))

# 3. Filter to find the top 2 employees in each department.
# The final orderBy only makes the printed row order deterministic.
top_earners_df = ranked_df.filter(col("rank") <= 2).orderBy(
    "department", "rank", "employee_name"
)
top_earners_df.show()
# +-------------+----------+------+----+
# |employee_name|department|salary|rank|
# +-------------+----------+------+----+
# |          Jen|   Finance|  3900|   1|
# |        Scott|   Finance|  3300|   2|
# |         Jeff| Marketing|  3000|   1|
# |        Kumar| Marketing|  2000|   2|
# |      Michael|     Sales|  4600|   1|
# |       Robert|     Sales|  4100|   2|
# |         Saif|     Sales|  4100|   2|
# +-------------+----------+------+----+
```
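In this example, rank() is used. It assigns the same rank to rows with identical values in the orderBy column (e.g., Robert and Saif both get rank 2) and then leaves a gap, so the next Sales employee, James, would receive rank 4 rather than 3. If a unique ranking is needed, row_number() would be a better choice; dense_rank() keeps ties but leaves no gaps.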
Why Not Use groupBy?
The key difference is that groupBy aggregates data, which means it collapses multiple rows into a single summary row. If we tried to solve the problem above with groupBy, we could find the maximum salary in each department (df.groupBy("department").max("salary")), but we would lose the information about which employee earns that salary.
A Window function, on the other hand, preserves the original rows. It calculates a value for each row based on the other rows in its "window" (or partition). This allows you to perform group-level calculations (like ranking) while keeping the detail of each individual record.
- Use groupBy when you need a single, aggregated result per group.
- Use a Window function when you need to compute a value for each row based on its group.
Other Utilities
| Description | Example |
|---|---|
| Get the underlying SparkContext | sc = spark.sparkContext |
| Get a column object | col_obj = df['age'] or col_obj = df.age |