This project demonstrates how to incorporate LLM capabilities into a data science workflow, focusing on the 2025-02-04 tidytuesday Simpsons dataset. First, the data is explored, picking out some patterns in the Simpsons episodes with the help of LLMs for sentiment analysis, episode summarisation and understanding character dynamics. This was done using a local instance of the Llama 3.2 model, served through Ollama.
Once we had a deeper understanding of the data, we turned to feature engineering to prepare the data for modeling. Here we took advantage of the GPT-4o-mini model via the OpenAI API; this included creating embeddings of the episode summaries from the exploratory data analysis. Next, we modeled the data in order to predict episode ratings, and explored the variables driving the prediction.
Finally, we generated some new data using Llama 3.2 to predict ratings for these hypothetical episodes. To conclude, we suggested future analysis directions, such as exploring different modeling approaches and LLM providers.
Before we begin, let’s load all our dependencies:
pacman::p_load(
  tidyverse,
  tidymodels,
  tidytext,
  ollamar,
  ellmer,
  mall,
  highcharter,
  echarts4r,
  shapviz,
  patchwork,
  finetune,
  vip,
  DALEX,
  httr2,
  learntidymodels,
  htmltools
)
For this project we will use the Simpsons episode data, which was part of tidytuesday 2025, week 5; it covers episodes from 2010-2016. We read in the data as follows:
# Import the datasets:
tuesdata <- tidytuesdayR::tt_load(2025, week = 5)

simpsons_characters <- tuesdata$simpsons_characters
simpsons_episodes <- tuesdata$simpsons_episodes
simpsons_locations <- tuesdata$simpsons_locations
simpsons_script_lines <- tuesdata$simpsons_script_lines
This leaves us with four data frames: simpsons_characters, which has the characters included in the episodes; simpsons_episodes, which has information about the episodes such as air date, rating, etc.; simpsons_locations, which has information about the location in the episode where each script line was said; and finally simpsons_script_lines, which has the actual script lines the characters said. Let’s take a look at how the IMDb ratings and US viewers changed over time, highlighting the min and max for each of the metrics:
# Create the plot data:
simpsons_plot_data <- simpsons_episodes %>%
  select(id, imdb_rating, original_air_date, us_viewers_in_millions) %>%
  arrange(original_air_date) %>%
  group_by(original_air_date) %>%
  summarise(
    imdb_rating = mean(imdb_rating),
    us_viewers_in_millions = mean(us_viewers_in_millions)
  ) %>%
  ungroup()
# Plot time series:
simpsons_plot_data %>%
  e_charts_("original_air_date") %>%
  e_line_("imdb_rating",
    name = "IMDB Rating",
    smooth = TRUE,
    x_index = 1,
    y_index = 1
  ) %>%
  e_mark_point(
    "IMDB Rating",
    data = list(type = "min"),
    x_index = 1,
    y_index = 1
  ) %>%
  e_mark_point(
    "IMDB Rating",
    data = list(type = "max"),
    x_index = 1,
    y_index = 1
  ) %>%
  e_area(
    us_viewers_in_millions,
    smooth = TRUE,
    symbol = "none",
    name = "US Viewers (millions)"
  ) %>%
  e_mark_point("US Viewers (millions)", data = list(type = "min")) %>%
  e_mark_point("US Viewers (millions)", data = list(type = "max")) %>%
  e_grid(width = "40%", left = "5%") %>%
  e_grid(width = "45%", left = "50%", right = "2%") %>%
  e_x_axis(name = "") %>%
  e_y_axis(name = "") %>%
  e_x_axis(gridIndex = 1, name = "") %>%
  e_y_axis(gridIndex = 1, name = "") %>%
  e_title("Simpson episode IMDB rating & US viewers") %>%
  e_theme("purple-passion") %>%
  e_legend(right = 0) %>%
  e_tooltip(
    trigger = "axis",
    axisPointer = list(
      type = "cross"
    )
  ) %>%
  e_datazoom(x_index = 0) %>%
  e_datazoom(x_index = 1) %>%
  e_datazoom(type = "inside", x_index = 0) %>%
  e_datazoom(type = "inside", x_index = 1) %>%
  e_toolbox(
    show = TRUE,
    right = 10,
    top = 15,
    itemSize = 18,
    iconStyle = list(
      borderColor = "#ffffff",
      borderWidth = 1
    ),
    emphasis = list(
      iconStyle = list(
        borderColor = "#ffcc00",
        borderWidth = 2
      )
    ),
    feature = list(
      dataZoom = list(
        show = TRUE,
        title = list(
          zoom = "Zoom",
          back = "Reset Zoom"
        ),
        iconStyle = list(
          borderColor = "#ffffff"
        )
      ),
      restore = list(
        show = TRUE,
        title = "Reset",
        iconStyle = list(
          borderColor = "#ffffff"
        )
      ),
      saveAsImage = list(
        show = TRUE,
        title = "Save as Image",
        iconStyle = list(
          borderColor = "#ffffff"
        )
      )
    )
  )
Figure 1. Simpson episode KPIs over time.
It seems the episode with the lowest rating aired on the 20th of May 2012, with a rating of 4.5 on IMDb. In contrast, the episode with the highest rating aired on the 13th of December 2015. As for the number of viewers, the episode that aired on the 10th of January 2010 saw a whopping 14.62 million US viewers. Let’s home in on the script lines data frame and leverage some LLM capabilities. We would like to understand the sentiment of each episode based on its script lines. To do this we first have to wrangle our data frame, such that each row has an episode and all the script lines associated with it:
# Collapse all script lines from an episode to a single row:
simpsons_script_lines %>%
  filter(spoken_words != 'NA') %>%
  select(episode_id, raw_character_text, spoken_words) %>%
  group_by(episode_id) %>%
  summarise(spoken_words = paste(spoken_words, collapse = " ")) %>%
  ungroup() -> episodes_w_id
Now we have all our episode lines in one row per episode. From this we can generate the episode sentiment using the LLM. However, instead of running the sentiment analysis on the whole data frame, let’s first summarise each episode using the LLM and then run the sentiment analysis on that text instead. This will speed up our inference time. We can do this using the mall package, with a local instance of the llama 3.2 model:
# Get the llama 3.2 model:
ollamar::pull("llama3.2")

# Run the summarize operation:
simpson_summary <- llm_summarize(
  episodes_w_id,
  spoken_words,
  pred_name = "episode_summary",
  additional_prompt = "Summarise the transcripts from these episodes of The Simpsons. Please do not mention episode in the summary, simply summarise the episode.",
  max_words = 30
)
We pass the data frame to the llm_summarize function, specify the column the summary should be based on and the name of the new summary column, and finally limit the response to 30 words per row. Now that we have these summaries, we can in turn pass them back to the LLM to run a sentiment analysis:
# Run the sentiment operation:
simpsons_episodes_w_sent <- llm_sentiment(
  simpson_summary,
  episode_summary,
  pred_name = "episode_sentiment"
)
This leaves us with the following data frame:
simpsons_episodes_w_sent
# A tibble: 118 × 4
episode_id spoken_words episode_summary episode_sentiment
<dbl> <chr> <chr> <chr>
1 450 "Oh, I love going to aquatic pa… "grampa meets … neutral
2 451 "Time to try out my new deep fr… "krusty's rati… negative
3 452 "Holy moly! YOU LIKE POTATO... … "homer wins th… negative
4 453 "I'm not a dude, I'm a hottie! … "marge wants h… negative
5 454 "Hey, I like art, okay? Ach. Th… "mabel discove… positive
6 455 "No. I don't care! It's stuck! … "marge gets fr… negative
7 456 "Thanks to our new G.P.S., I'll… "bart simpson … negative
8 457 "What about Sodom and Gomorrah?… "Homer Simpson… negative
9 458 "Burns, you're coming with us. … "burns gets ar… negative
10 459 "I can't believe you're making … "the chief get… neutral
# ℹ 108 more rows
Let’s take a look at whether there is a difference between the average values for the IMDb rating and the US viewers based on the sentiment of the episode:
# Get the summary statistics by sentiment:
simpsons_episodes %>%
  select(id, imdb_rating, original_air_date, us_viewers_in_millions) %>%
  left_join(
    simpsons_episodes_w_sent %>%
      select(episode_id, episode_sentiment),
    by = c("id" = "episode_id")
  ) %>%
  group_by(episode_sentiment) %>%
  summarise(
    avg_rating = mean(imdb_rating),
    stddev = sd(imdb_rating),
    avg_viewers = mean(us_viewers_in_millions),
    stddev_viewers = sd(us_viewers_in_millions)
  ) %>%
  na.omit() -> summaries
# Pivot the summaries dataframe to long format:
chart_data <- summaries %>%
  select(episode_sentiment, avg_rating, stddev, avg_viewers, stddev_viewers) %>%
  arrange(desc(avg_rating)) %>%
  mutate(episode_sentiment = fct_inorder(episode_sentiment)) %>%
  pivot_longer(
    cols = c(avg_rating, avg_viewers),
    names_to = "metric",
    values_to = "value"
  ) %>%
  mutate(
    metric = case_when(
      metric == "avg_rating" ~ "Average Rating",
      metric == "avg_viewers" ~ "Average Viewers (millions)"
    )
  )
# Create the plot:
chart_data %>%
  hchart(
    "column",
    hcaes(x = episode_sentiment, y = value, group = metric)
  ) %>%
  hc_xAxis(title = list(text = "Episode Sentiment")) %>%
  hc_yAxis(title = list(text = "Value")) %>%
  hc_title(text = "Average Rating and Viewers by Episode Sentiment") %>%
  hc_tooltip(
    formatter = JS(
      "
      function() {
        return '<b>' + this.series.name + '</b><br/>' + this.y.toFixed(2);
      }
      "
    )
  ) %>%
  hc_legend(
    align = "right",
    verticalAlign = "top"
  ) %>%
  hc_add_theme(hc_theme_smpl())
Figure 2. Simpson episode sentiment KPIs.
It looks like there are only marginal differences between the average rating and the viewers for the different sentiments, with neutral having the highest average for both KPIs, followed by negative and finally positive. The sample size is not large enough to determine whether the differences are statistically significant, but it hints at some slight (albeit marginal) variation; with more episodes analysed, more fruitful findings might be established.
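One way such a comparison could be formalised, once more episodes are available, is a simple non-parametric test of ratings across the sentiment groups. This is a sketch, not part of the original analysis; it reuses the join from the chunk above, before any summarising:

```r
# Hypothetical significance check (sketch): compare IMDb ratings across
# sentiment groups with a Kruskal-Wallis test instead of eyeballing the means.
simpsons_episodes %>%
  select(id, imdb_rating) %>%
  left_join(simpsons_episodes_w_sent %>%
              select(episode_id, episode_sentiment),
            by = c("id" = "episode_id")) %>%
  filter(!is.na(episode_sentiment), !is.na(imdb_rating)) %>%
  kruskal.test(imdb_rating ~ episode_sentiment, data = .)
```

A Kruskal-Wallis test is used here rather than ANOVA because the group sizes are small and the ratings need not be normally distributed.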
Let’s further explore some of the temporal elements of the script data to answer the following question: how has the Simpson family dialogue evolved over the time frame in which the episodes aired? More specifically, we want to know how the amount of dialogue has changed for the characters, focusing on Homer, Marge, Lisa and Bart. To do so, we filter for the dialogue lines only and then aggregate the word_count column to give us a notion of the total dialogue over time:
# Select relevant characters & dialogue lines only:
simpsons_script_lines %>%
  select(
    episode_id,
    number,
    raw_text,
    speaking_line,
    raw_character_text,
    spoken_words,
    normalized_text,
    location_id,
    word_count
  ) %>%
  filter(
    speaking_line == TRUE &
      raw_character_text %in%
        c('Homer Simpson', 'Marge Simpson', 'Lisa Simpson', 'Bart Simpson')
  ) %>%
  arrange(episode_id, number) -> main_char_lines
# Connect to episodes data frame, to get a date:
main_char_lines %>%
  left_join(
    simpsons_episodes %>%
      select(id, original_air_date),
    by = c("episode_id" = "id")
  ) %>%
  select(raw_character_text, word_count, original_air_date) %>%
  group_by(raw_character_text, original_air_date) %>%
  summarise(word_count = sum(word_count)) %>%
  ungroup() -> to_plot

to_plot
# A tibble: 462 × 3
raw_character_text original_air_date word_count
<chr> <date> <dbl>
1 Bart Simpson 2010-01-03 179
2 Bart Simpson 2010-01-10 199
3 Bart Simpson 2010-01-31 295
4 Bart Simpson 2010-02-14 188
5 Bart Simpson 2010-02-21 66
6 Bart Simpson 2010-03-14 359
7 Bart Simpson 2010-03-21 362
8 Bart Simpson 2010-03-28 143
9 Bart Simpson 2010-04-11 207
10 Bart Simpson 2010-04-18 104
# ℹ 452 more rows
With this in hand we can visualise the total for each character over time:
# Format the data to wide format:
to_plot %>%
  pivot_wider(names_from = raw_character_text, values_from = word_count) %>%
  replace_na(list(
    `Homer Simpson` = 0,
    `Marge Simpson` = 0,
    `Lisa Simpson` = 0,
    `Bart Simpson` = 0
  )) -> to_plot

# Plot river chart:
to_plot %>%
  e_charts(original_air_date) %>%
  e_river(`Homer Simpson`) %>%
  e_river(`Marge Simpson`) %>%
  e_river(`Lisa Simpson`) %>%
  e_river(`Bart Simpson`) %>%
  e_tooltip(trigger = "axis") %>%
  e_title("Simpsons Family Character Word Count Over Time") %>%
  e_legend(top = 30)
Figure 3. Simpson family dialogue evolution.
Let’s turn our gaze to one final plot, focusing on the Simpson episode script lines. We create a new data frame with the speaking lines only:
# Filter for speaking lines only:
simpsons_script_lines %>%
  select(episode_id, number, raw_text, speaking_line, raw_character_text,
         spoken_words, normalized_text, location_id, word_count) %>%
  filter(speaking_line == TRUE) %>%
  arrange(episode_id, number) -> simpsons_script_lines

simpsons_script_lines
# A tibble: 26,169 × 9
episode_id number raw_text speaking_line raw_character_text spoken_words
<dbl> <dbl> <chr> <lgl> <chr> <chr>
1 450 1 Homer Simpso… TRUE Homer Simpson Oh, I love …
2 450 2 Homer Simpso… TRUE Homer Simpson Now to seal…
3 450 5 Marge Simpso… TRUE Marge Simpson Oh, there's…
4 450 6 Marge Simpso… TRUE Marge Simpson A bird eati…
5 450 7 Lisa Simpson… TRUE Lisa Simpson I wanna do …
6 450 9 Bart Simpson… TRUE Bart Simpson I'm gonna t…
7 450 10 Grampa Simps… TRUE Grampa Simpson My feet hur…
8 450 11 Announcer: A… TRUE Announcer Attention, …
9 450 12 Marge Simpso… TRUE Marge Simpson That's us!
10 450 13 Announcer: I… TRUE Announcer In one minu…
# ℹ 26,159 more rows
# ℹ 3 more variables: normalized_text <chr>, location_id <dbl>,
# word_count <dbl>
I want to understand which character is being addressed by the character who is speaking. Let’s see if the LLM can help us here; we do the following:
# Supply the prompt:
my_prompt <- paste(
  "Answer a question.",
  "Return only the answer, no explanation. Return only the name no punctuation and all lower case letters",
  "Acceptable answers are character names, if it is not clear say 'not clear', do not try to guess",
  "Answer this about the following text, who is the character addressing?:"
)

# Create a new data frame with the character being addressed:
character_being_addressed <- llm_custom(
  simpsons_script_lines %>%
    slice_sample(prop = 0.1), # Only used a sample because of performance
  normalized_text,
  my_prompt,
  pred_name = "character"
)
The LLM call to the local llama3.2 model took a long time, therefore I only used 10% of the data to demonstrate the concept.
Let’s filter out some of the examples where the LLM couldn’t quite come up with the character being addressed:
# Filter for explicit characters only:
character_being_addressed %>%
  filter(!grepl("clear", character, ignore.case = TRUE)) %>%
  select(raw_character_text, character, word_count) -> chars_final

chars_final
# A tibble: 922 × 3
raw_character_text character word_count
<chr> <chr> <dbl>
1 Raymondo The person (implied by the first-person narrat… 6
2 Homer Simpson The Reader 6
3 Lise Maman 14
4 Homer Simpson The Boys. 5
5 Marge Simpson Tomato 7
6 Kang The Chosen One 15
7 Selma Bouvier Ryan Seacrest 18
8 Homer Simpson Marge 8
9 Seymour Skinner Mr Testacleese 7
10 Bart Simpson Maggie 16
# ℹ 912 more rows
The raw_character_text column is the character saying the script line, and the character column is the character the LLM identified as being addressed.
I created a vector of stop words, which will be used when cleaning the character names:
# Define stop words to remove:
stop_words <- c("the", "a", "an", "and", "or", "but", "in", "on", "at", "to",
                "for", "of", "with", "by", "is", "are", "was", "were", "being",
                "been", "have", "has", "had", "do", "does", "did", "will",
                "would", "could", "should", "may", "might", "must", "shall",
                "can")
We need to clean up our columns a bit and filter for significant interactions:
# Let's create a helper function to aid in the cleaning of the names:
clean_character_name <- function(name) {
  name %>%
    # Force to lowercase:
    str_to_lower() %>%
    # Remove punctuation:
    str_remove_all("[[:punct:]]") %>%
    # Split into words:
    str_split("\\s+") %>%
    map(
      ~ {
        # Remove stop words, using object from before:
        words <- .x[!.x %in% stop_words]
        # Remove empty strings, in case they exist:
        words <- words[words != ""]
        # Keep only first name for Simpson family members:
        if (length(words) >= 2 && "simpson" %in% words) {
          words <- words[words != "simpson"]
        }
        # Join back together:
        paste(words, collapse = " ")
      }
    ) %>%
    unlist() %>%
    str_trim() %>%
    # Handle unknowns:
    ifelse(. == "", "unknown", .)
}
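A couple of hypothetical spot-checks (these calls are mine, not from the original analysis) show how the helper behaves:

```r
# Illustrative calls; expected results follow from the logic above:
clean_character_name("Homer Simpson!")  # "homer"  (punctuation and surname dropped)
clean_character_name("The Chosen One")  # "chosen one"  (stop word "the" removed)
clean_character_name("")                # "unknown"
```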
# We apply the cleaning to our dataframe:
character_data_cleaned <- chars_final %>%
  mutate(
    # Clean both character columns:
    raw_character_text_clean = map_chr(
      raw_character_text,
      clean_character_name
    ),
    character_clean = map_chr(character, clean_character_name)
  ) %>%
  # Remove rows where cleaned names are the same (self-references):
  filter(raw_character_text_clean != character_clean) %>%
  # Remove unknown characters:
  filter(raw_character_text_clean != "unknown" & character_clean != "unknown")
# Next we filter for only those characters which have significant
# interactions (at least twice):
character_data_filtered <- character_data_cleaned %>%
  group_by(raw_character_text_clean) %>%
  mutate(char_frequency = n()) %>%
  ungroup() %>%
  filter(char_frequency >= 2) %>%
  group_by(character_clean) %>%
  mutate(target_frequency = n()) %>%
  ungroup() %>%
  filter(target_frequency >= 2)

character_data_filtered
# A tibble: 368 × 7
raw_character_text character word_count raw_character_text_clean
<chr> <chr> <dbl> <chr>
1 Homer Simpson The Reader 6 homer
2 Homer Simpson The Boys. 5 homer
3 Homer Simpson Marge 8 homer
4 Bart Simpson Maggie 16 bart
5 Bart Simpson She 13 bart
6 Adult Bart Dad 9 adult bart
7 Bart Simpson The Parent(s) 7 bart
8 Marge Simpson You. 18 marge
9 Homer Simpson You. 10 homer
10 Snake Jailbird You 7 snake jailbird
# ℹ 358 more rows
# ℹ 3 more variables: character_clean <chr>, char_frequency <int>,
# target_frequency <int>
Next, we create the nodes and the edges that connect the characters, using the count of unique character combinations as the weight of the connections between the nodes, where the thickness of a connection shows how many times one character was addressed by another:
# Create edges with interaction counts as weights:
edges <- character_data_filtered %>%
  group_by(source = raw_character_text_clean, target = character_clean) %>%
  summarise(value = n(), .groups = "drop") %>%
  mutate(lineWidth = pmax(value * 2, 1),
         lineWidth = pmin(lineWidth, 15))
# Create nodes:
all_characters <- unique(c(edges$source, edges$target))

nodes <- data.frame(
  name = all_characters,
  stringsAsFactors = FALSE
) %>%
  left_join(
    edges %>%
      group_by(source) %>%
      summarise(out_interactions = sum(value), .groups = "drop") %>%
      rename(name = source),
    by = "name"
  ) %>%
  left_join(
    edges %>%
      group_by(target) %>%
      summarise(in_interactions = sum(value), .groups = "drop") %>%
      rename(name = target),
    by = "name"
  ) %>%
  mutate(
    total_interactions = coalesce(out_interactions, 0) +
      coalesce(in_interactions, 0),
    value = total_interactions,
    size = pmax(sqrt(total_interactions * 10), 10),
    grp = case_when(
      name %in% c("homer", "marge", "bart", "lisa", "maggie") ~
        "Simpson Family",
      TRUE ~ "Other"
    )
  ) %>%
  select(name, value, size, grp)
# Create network graph:
e_charts() %>%
  e_graph(
    layout = "force",
    draggable = TRUE,
    focusNodeAdjacency = TRUE,
    roam = TRUE,
    zoom = 0.3,
    center = c(400, 300),
    force = list(
      repulsion = 1500,
      edgeLength = 120,
      gravity = 0.1,
      friction = 0.6
    )
  ) %>%
  e_graph_nodes(nodes, name, value, size, grp) %>%
  e_graph_edges(edges, source, target, value, lineWidth) %>%
  e_tooltip() %>%
  e_title("Character Interaction Network") %>%
  e_legend(orient = "vertical", right = 0, top = "middle")
Figure 4. Network graph of character interactions.
It looks like Homer has the most interactions of the Simpson family. In contrast, from the Other category, the character you has the most. It seems the LLM is identifying when the viewer is being addressed by the characters in the episode; is this right? 🤔 Let’s take this with a grain of salt since it was the output from an LLM, but in any case, you get the point!
We will also use the LLM capabilities to help in the feature engineering process. For this we will use the text-embedding-ada-002 model from OpenAI. We briefly turn to the dark side and create the embeddings with the following Python code:
# Import dependencies:
from openai import OpenAI
import pandas as pd
import numpy as np
import yaml
import os
import ast
from sklearn.preprocessing import LabelEncoder

# Specify the models:
LLM_MODEL = "gpt-4o-mini"
EMBEDDING_MODEL = "text-embedding-ada-002"

client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])

# Create a function to help us create the embedding:
def get_embeddings(text):
    response = client.embeddings.create(
        input=[text],
        model=EMBEDDING_MODEL
    )
    embedding = response.data[0].embedding
    return embedding

# Assign the embedding to a new column:
simpsons_episodes_w_sent['summary_embedding'] = simpsons_episodes_w_sent['episode_summary'].apply(get_embeddings)

# Expand the embedding lists into one column per dimension
# (index taken from the source frame, not an undefined `df`):
embeddings_df = pd.DataFrame(simpsons_episodes_w_sent['summary_embedding'].tolist(),
                             index=simpsons_episodes_w_sent.index)

simpsons_episodes_w_sent_embed = pd.concat(
    [simpsons_episodes_w_sent[['episode_id', 'spoken_words', 'episode_summary', 'sentiment']],
     embeddings_df],
    axis=1
)

# The final dataframe:
simpsons_episodes_w_sent_embed
# A tibble: 118 × 1,541
episode_id spoken_words episode_summary `0` `1` `2` `3`
<dbl> <chr> <chr> <dbl> <dbl> <dbl> <dbl>
1 450 "Oh, I love g… "grampa meets … 0.0157 -0.0121 0.00869 -0.0278
2 451 "Time to try … "krusty's rati… -0.0190 -0.0282 0.0240 -0.00696
3 452 "Holy moly! Y… "homer wins th… 0.00208 -0.0228 -0.00576 -0.0177
4 453 "I'm not a du… "marge wants h… -0.0112 -0.0257 0.0229 0.00146
5 454 "Hey, I like … "mabel discove… -0.0105 -0.0175 0.00581 -0.0201
6 455 "No. I don't … "marge gets fr… 0.00502 0.00254 -0.00325 -0.0294
7 456 "Thanks to ou… "bart simpson … -0.0102 0.00155 0.00697 -0.0115
8 457 "What about S… "Homer Simpson… 0.00900 -0.0408 -0.0123 -0.00303
9 458 "Burns, you'r… "burns gets ar… 0.00506 -0.0389 0.0129 -0.0444
10 459 "I can't beli… "the chief get… 0.0197 -0.00870 0.00621 -0.0142
# ℹ 108 more rows
# ℹ 1,534 more variables: `4` <dbl>, `5` <dbl>, `6` <dbl>, `7` <dbl>,
# `8` <dbl>, `9` <dbl>, `10` <dbl>, `11` <dbl>, `12` <dbl>, `13` <dbl>,
# `14` <dbl>, `15` <dbl>, `16` <dbl>, `17` <dbl>, `18` <dbl>, `19` <dbl>,
# `20` <dbl>, `21` <dbl>, `22` <dbl>, `23` <dbl>, `24` <dbl>, `25` <dbl>,
# `26` <dbl>, `27` <dbl>, `28` <dbl>, `29` <dbl>, `30` <dbl>, `31` <dbl>,
# `32` <dbl>, `33` <dbl>, `34` <dbl>, `35` <dbl>, `36` <dbl>, `37` <dbl>, …
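Each text-embedding-ada-002 vector has 1,536 dimensions, and a common sanity check on such embeddings is to compare two summaries via cosine similarity. Here is a small illustrative sketch in the same Python register as above; the toy 4-dimensional vectors are made up, not real embeddings:

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two embedding vectors."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy vectors standing in for episode-summary embeddings:
emb_a = [0.1, 0.2, 0.0, 0.4]
emb_b = [0.1, 0.2, 0.0, 0.4]   # identical text -> similarity 1.0
emb_c = [-0.4, 0.0, 0.2, -0.1]

print(round(cosine_similarity(emb_a, emb_b), 4))  # 1.0
print(round(cosine_similarity(emb_a, emb_c), 4))
```

Identical summaries score 1; unrelated ones drift towards 0 (or below), which is also the intuition behind using these vectors as model features.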
Now, with our embeddings in hand, we rearrange our data frame and create our train and test splits, as well as our cross-validation folds:
# Create base dataset for modeling:
simpsons_episodes_w_sent_embed %>%
  inner_join(simpsons_episodes %>%
               select(id, imdb_rating),
             by = c("episode_id" = "id")) %>%
  select(episode_id, imdb_rating, everything()) -> base_simpson

# Create splits & folds:
set.seed(123)
simpson_split <- initial_split(base_simpson, prop = 0.8)
simpson_train <- training(simpson_split)
simpson_test <- testing(simpson_split)

simpson_folds <- vfold_cv(simpson_train, v = 10)
Next, we create our recipe, specifying imdb_rating
as the variable we want to predict:
# Create the recipe:
simpson_recipe <- recipe(imdb_rating ~ ., data = simpson_train) %>%
  update_role(episode_id, new_role = "ID") %>%
  step_dummy(all_nominal())
Next, we specify the model and combine it with the recipe in a workflow:
# Specify the model:
xgb_spec <-
  boost_tree(
    trees = tune(),
    mtry = tune(),
    min_n = tune(),
    learn_rate = 0.01
  ) %>%
  set_engine("xgboost") %>%
  set_mode("regression")

# Put the workflow and the recipe together:
xgb_wf <- workflow(simpson_recipe, xgb_spec)

xgb_wf
For the hyperparameter tuning we will use tune_race_anova() from the finetune package, which iteratively eliminates tuning parameter combinations that are unlikely to be the best, using a repeated measures ANOVA model:
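The tuning chunk itself is not shown here; a minimal sketch of the call that would produce xgb_simpson_rs looks like the following. The grid size and control settings are assumptions for illustration, not the original post's values:

```r
# Sketch of the racing tune step (grid and control values are assumptions):
set.seed(345)
xgb_simpson_rs <- tune_race_anova(
  xgb_wf,
  resamples = simpson_folds,
  grid = 20,
  metrics = metric_set(rmse, rsq),
  control = control_race(verbose_elim = TRUE)
)
```

We can then visualise how candidates were eliminated over the resamples: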
xgb_simpson_rs %>%
  plot_race() +
  theme_minimal() +
  ggtitle(
    label = "Model candidates - Racing Methods",
    subtitle = "The model represented by the color red makes it to the final stage"
  ) +
  theme(plot.title.position = "plot")
We can also inspect the best model and its parameters:
# Pick the best model according to rmse:
xgb_simpson_rs %>%
  show_best(metric = "rmse")
# A tibble: 5 × 9
mtry trees min_n .metric .estimator mean n std_err .config
<int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
1 777 1246 13 rmse standard 0.436 10 0.0507 Preprocessor1_Model11
2 3 1571 40 rmse standard 0.447 10 0.0511 Preprocessor1_Model01
3 420 628 15 rmse standard 0.449 10 0.0462 Preprocessor1_Model06
4 1468 1420 4 rmse standard 0.451 10 0.0467 Preprocessor1_Model20
5 289 1322 17 rmse standard 0.459 10 0.0500 Preprocessor1_Model04
The best model achieves an RMSE of 0.436, which is OK, although it could be better. Let’s perform the last fit:
# Conduct the last fit:
xgb_last <-
  xgb_wf %>%
  finalize_workflow(select_best(xgb_simpson_rs, metric = "rmse")) %>%
  last_fit(simpson_split)
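Since last_fit() also evaluates the finalised model on the held-out test set, we could sanity-check generalisation at this point. A small sketch (these calls are standard tune accessors; the output is not part of the original post):

```r
# Inspect hold-out performance of the finalised model:
collect_metrics(xgb_last)

# And peek at the test-set predictions alongside the truth:
collect_predictions(xgb_last) %>%
  select(.pred, imdb_rating) %>%
  head()
```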
Let’s take a look at the variables driving the outcome:
fitted_xgb <- extract_fit_parsnip(xgb_last)

vip(fitted_xgb, geom = "col") +
  theme_minimal() +
  ggtitle(label = "Variable importance for XGBoost model") +
  theme(plot.title.position = "plot")
It looks like the embeddings with the most importance are 639 and 1111. Which entries have the highest values for the top embedding?
# Feature importance dataframe:
simpson_train %>%
  select(-imdb_rating, -episode_id) %>%
  colnames() %>%
  as_tibble() %>%
  inner_join(vi(fitted_xgb), by = c("value" = "Variable")) %>%
  arrange(desc(Importance)) -> importance_df

top_feature <- importance_df$value[1]

# Data frame for top feature:
top_feature_analysis <- base_simpson %>%
  arrange(desc(!!sym(top_feature))) %>%
  head(10)
top_feature_analysis %>%
  inner_join(simpsons_episodes_w_sent,
             by = join_by(episode_id)) %>%
  select(episode_summary, 639) %>%
  as_tibble()
# A tibble: 10 × 2
episode_summary `636`
<chr> <dbl>
1 a man has a great idea for beer cozies but his daughter is unintere… 0.0182
2 marge opens a sandwich store, faces competition from another franch… -0.0104
3 bart learns about eating insects for protein from a book. -0.00569
4 mabel discovers eliza s Simpson's family was involved in the underg… -0.0191
5 Radioactive Man's arch-nemesis, Petroleus Rex, threatens to destroy… -0.0223
6 a 2nd-grade debate between isabel gutierrez and lisa simpson on a b… 0.0130
7 burns gets arrested and imprisoned for stealing a valuable painting… 0.0127
8 homer gets sent to prison for attempting to bribe a public official… -0.0142
9 homers prank war gets out of hand with vandal art causing a stir in… -0.00410
10 homers dog gets lost at a village set for a medieval movie, mr burn… -0.00306
Finally, let’s use our LLM to come up with some new data on which we can make predictions. Again we use our local llama3.2 model, with custom prompts:
# Create the system prompt for episode generation:
episode_system_prompt <- create_message(
  "You are a creative writer for The Simpsons TV show. Generate brief, creative episode summaries (30 words only) that capture the show's humor and style. Focus on the plot and tone of the episode, do not return the episode title just the summary itself.",
  "system"
)

# Provide some high level themes for episode generation:
episode_prompts <- c(
  "Create a Simpson episode about Homer getting a new job",
  "Create a Simpson episode about Bart's latest school prank",
  "Create a Simpson episode about Marge starting a new hobby",
  "Create a Simpson episode about Lisa's activism",
  "Create a Simpson episode about Springfield facing a crisis"
)

# Create requests for episode generation:
episode_reqs <- lapply(episode_prompts, function(prompt) {
  messages <- append_message(prompt, "user", episode_system_prompt)
  chat("llama3.2", messages, output = "req")
})

# Activate parallelism:
episode_resps <- req_perform_parallel(episode_reqs)

# Retrieve the generated summaries:
episodes_df <- bind_rows(lapply(episode_resps, resp_process, "df"))
episode_summaries <- episodes_df$content
episode_summaries
# A tibble: 5 × 1
content
<chr>
1 "Homer lands a dream job as a professional couch tester, but his lack of disc…
2 "When Bart replaces Mrs. Krabappel's grade book with a whoopee cushion, chaos…
3 "When Marge discovers a passion for competitive knitting, Homer's loud outbur…
4 "When a gentrification project threatens to ruin the town's eclectic music sc…
5 "As a sudden, toxic gas leak threatens to poison the town, Homer's antics ina…
Next, we want to run the sentiment analysis on these proposed Simpson episodes. For this, let’s use a custom system prompt, rather than the llm_sentiment() function. We specify the prompt below:
# Create the prompt for sentiment analysis:
sentiment_system_prompt <- create_message(
  "Your only task is to evaluate the sentiment/tone of Simpson episode summaries. Classify them as 'positive' (upbeat, funny, heartwarming), 'negative' (dark, sad, cynical), or 'neutral' (mixed or balanced tone). Respond with only one word.",
  "system"
)

# Create the request:
sentiment_reqs <- lapply(episode_summaries, function(episode) {
  messages <- append_message(
    paste("Classify the sentiment of this Simpson episode summary:", episode),
    "user",
    sentiment_system_prompt
  )
  chat("llama3.2", messages, output = "req")
})

# Classify using parallelism:
sentiment_resps <- req_perform_parallel(sentiment_reqs)

# Retrieve the responses:
sentiments_df <- bind_rows(lapply(sentiment_resps, resp_process, "df"))

# Combine the df:
results <- data.frame(
  episode_summary = episode_summaries,
  episode_sentiment = sentiments_df$content,
  stringsAsFactors = FALSE
)

results
# A tibble: 5 × 2
content episode_sentiment
<chr> <chr>
1 "Homer lands a dream job as a professional couch tester, bu… Neutral
2 "When Bart replaces Mrs. Krabappel's grade book with a whoo… Negative
3 "When Marge discovers a passion for competitive knitting, H… Negative
4 "When a gentrification project threatens to ruin the town's… Positive
5 "As a sudden, toxic gas leak threatens to poison the town, … Positive
Finally, we create the embeddings based on the newly generated episode summaries, using the method outlined in the feature engineering section with the text-embedding-ada-002 model:
ollama_episode_embeddings_to_pred
# A tibble: 5 × 1,538
episode_id episode_sentiment `0` `1` `2` `3` `4`
<int> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1 neutral 0.0152 -0.0226 0.0103 -0.0199 0.00137
2 2 neutral 0.0178 -0.00724 0.00651 -0.0198 -0.00993
3 3 neutral -0.000269 -0.00554 0.0167 -0.0149 -0.00416
4 4 neutral 0.0153 -0.0236 0.00790 -0.00765 -0.00120
5 5 neutral 0.0441 -0.00497 0.00628 -0.0224 -0.0250
# ℹ 1,531 more variables: `5` <dbl>, `6` <dbl>, `7` <dbl>, `8` <dbl>,
# `9` <dbl>, `10` <dbl>, `11` <dbl>, `12` <dbl>, `13` <dbl>, `14` <dbl>,
# `15` <dbl>, `16` <dbl>, `17` <dbl>, `18` <dbl>, `19` <dbl>, `20` <dbl>,
# `21` <dbl>, `22` <dbl>, `23` <dbl>, `24` <dbl>, `25` <dbl>, `26` <dbl>,
# `27` <dbl>, `28` <dbl>, `29` <dbl>, `30` <dbl>, `31` <dbl>, `32` <dbl>,
# `33` <dbl>, `34` <dbl>, `35` <dbl>, `36` <dbl>, `37` <dbl>, `38` <dbl>,
# `39` <dbl>, `40` <dbl>, `41` <dbl>, `42` <dbl>, `43` <dbl>, `44` <dbl>, …
With all our feature data ready, we can pass it to the model and see how each embedding/feature drives the predictions:
# Create the predictions, using the fitted workflow from the last fit
# (predicting with the untrained xgb_wf would error):
preds <- predict(extract_workflow(xgb_last), ollama_episode_embeddings_to_pred) %>%
  cbind(ollama_episode_embeddings %>%
          mutate(episode_id = row_number()) %>%
          select(episode_id, episode_summary, sentiment)) %>%
  as_tibble() %>%
  select(episode_id, episode_summary, sentiment, .pred)

# Prep & bake the data:
ollama_episode_embeddings_prep <- bake(
  prep(simpson_recipe),
  has_role("predictor"),
  new_data = NULL,
  composition = "matrix"
)
# Create SHAP matrix:
shap <- shapviz(extract_fit_engine(xgb_last),
                X_pred = ollama_episode_embeddings_prep,
                x = ollama_episode_embeddings_to_pred)

# Create the plots:
sv_waterfall(shap, row_id = 1,
             fill_colors = c("steelblue", "orange")) -> p1
sv_waterfall(shap, row_id = 2,
             fill_colors = c("steelblue", "orange")) -> p2
sv_waterfall(shap, row_id = 3,
             fill_colors = c("#FF6B6B", "#4ECDC4")) -> p3
sv_waterfall(shap, row_id = 4,
             fill_colors = c("#FF6B6B", "#4ECDC4")) -> p4

# Compose the plots using patchwork:
(p1 + p2) / (p3 + p4) +
  plot_annotation(
    title = "Waterfall force plots for the first four hypothetical episodes",
    subtitle = "The first episode is the one with the highest IMDB rating"
  ) +
  theme(plot.title.position = "plot")
We can take a closer look at the Simpsons episode summary with the highest predicted rating:
When a gentrification project threatens to ruin the town’s eclectic music scene, Lisa rallies her friends to protest, but her idealistic efforts clash with Homer’s lovable apathy, and Marge must intervene to save the community.
In this project we explored the 2025-02-04 tidytuesday Simpsons dataset, using LLMs to uncover key trends in the episodes. These included summarising episode scripts, sentiment analysis and identifying character interactions, all done with a llama3.2 model running locally. Not only did we gain a deeper understanding of the episodes, we also used the LLM to generate features for modeling the episode rating, leveraging an embedding model via the OpenAI API.
Next, we turned to modeling. We took the features generated with the LLMs and trained an XGBoost model, tuning its hyperparameters with racing methods to optimise performance. The results uncovered the variables that were driving the predictions of the episode ratings.
Finally, we leveraged LLM capabilities to create new data, upon which our workflow could make predictions, and once again explored the variables driving the outcome.
One thing that might stand out is how many features the embeddings create. A potential improvement to the workflow could be to apply dimensionality reduction. This can be done as follows:
# Prep the recipe:
simpson_recipe <- recipe(imdb_rating ~ ., data = simpson_train) %>%
  update_role(episode_id, new_role = "ID") %>%
  step_dummy(all_nominal()) %>%
  step_zv(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_pca(all_numeric_predictors(), id = "pca", num_comp = 4) %>%
  prep()

# Create a data frame:
tidy(simpson_recipe, id = "pca", type = "coef") %>%
  filter(component %in% c("PC1", "PC2", "PC3", "PC4")) %>%
  group_by(component) %>%
  slice_head(n = 10) -> to_plot
Once we have the component loadings we can plot them, looking at the first four components and the first ten embeddings:
# Create the first plot:
to_plot %>%
  filter(component == "PC1") %>%
  e_charts(terms) %>%
  e_bar(value) %>%
  e_color("#45B7D1") %>%
  e_grid(left = "20%") %>%
  e_tooltip(confine = TRUE) %>%
  e_x_axis(
    name = "embedding",
    nameLocation = "middle",
    nameGap = 30,
    nameTextStyle = list(fontSize = 12),
    axisLabel = list(interval = 0)
  ) %>%
  e_y_axis(name = "value") -> p1
# Create the second plot:
to_plot %>%
  filter(component == "PC2") %>%
  e_charts(terms) %>%
  e_bar(value) %>%
  e_color("#4ECDC4") %>%
  e_grid(left = "20%") %>%
  e_tooltip() %>%
  e_x_axis(
    name = "embedding",
    nameLocation = "middle",
    nameGap = 30,
    nameTextStyle = list(fontSize = 12),
    axisLabel = list(interval = 0)
  ) %>%
  e_y_axis(name = "value") -> p2
# Create the third plot:
to_plot %>%
  filter(component == "PC3") %>%
  e_charts(terms) %>%
  e_bar(value) %>%
  e_color("#96CEB4") %>%
  e_grid(left = "20%") %>%
  e_tooltip(confine = TRUE) %>%
  e_x_axis(
    name = "embedding",
    nameLocation = "middle",
    nameGap = 30,
    nameTextStyle = list(fontSize = 12),
    axisLabel = list(interval = 0)
  ) %>%
  e_y_axis(name = "value") -> p3
# Create the fourth plot:
to_plot %>%
  filter(component == "PC4") %>%
  e_charts(terms) %>%
  e_bar(value) %>%
  e_color("#FF6B6B") %>%
  e_grid(left = "20%") %>%
  e_tooltip() %>%
  e_x_axis(
    name = "embedding",
    nameLocation = "middle",
    nameGap = 30,
    nameTextStyle = list(fontSize = 12),
    axisLabel = list(interval = 0)
  ) %>%
  e_y_axis(name = "value") -> p4
# Compose the plots:
div(style = "display: grid; grid-template-columns: 1fr 1fr; grid-gap: 10px;",
    div(p1),
    div(p2),
    div(p3),
    div(p4)
)
Figure 8. The first 10 embeddings of each of the first four principal components.
Using components instead of the full embeddings might improve performance, but at the expense of explainability. Similarly, we might wish to experiment with other models, not only for predicting the episode rating but also for the LLM itself. Different LLMs behave differently and could therefore produce different results: another model might pick out different character interactions than we saw in the EDA section, or assign different sentiment to the episodes. Finally, it would be great to run the same analysis on a much larger dataset of Simpsons episodes (the tidytuesday dataset was limited in scope due to GitHub repo size constraints) to see whether the same trends appear in other seasons.
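As a sketch of how swapping the prediction model might look, the XGBoost spec could be replaced with, say, a random forest while keeping the rest of the workflow intact. This assumes the `simpson_recipe` from earlier; the names `rf_spec`, `rf_wf` and `simpson_folds` here are hypothetical:

```r
# A minimal sketch, assuming simpson_recipe (and a resamples object) exist:
# swap the boosted-tree spec for a random forest and reuse the same workflow.
rf_spec <- rand_forest(mtry = tune(), trees = 500, min_n = tune()) %>%
  set_engine("ranger") %>%
  set_mode("regression")

rf_wf <- workflow() %>%
  add_recipe(simpson_recipe) %>%
  add_model(rf_spec)

# The same racing approach used for XGBoost would then apply:
# rf_res <- tune_race_anova(rf_wf, resamples = simpson_folds, grid = 20)
```

Because only the model spec changes, the SHAP and tuning steps above would carry over with minimal edits.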
sessionInfo()
R version 4.4.1 (2024-06-14)
Platform: aarch64-apple-darwin20
Running under: macOS 15.5
Matrix products: default
BLAS: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
time zone: Europe/London
tzcode source: internal
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] htmltools_0.5.8.1 httr2_1.1.2 DALEX_2.4.3 vip_0.4.1
[5] finetune_1.2.1 patchwork_1.3.0 shapviz_0.9.7 echarts4r_0.4.5
[9] highcharter_0.9.4 mall_0.1.0 ellmer_0.1.1 ollamar_1.2.2
[13] tidytext_0.4.2 yardstick_1.3.1 workflowsets_1.1.0 workflows_1.1.4
[17] tune_1.2.1 rsample_1.2.1 recipes_1.1.0 parsnip_1.2.1
[21] modeldata_1.4.0 infer_1.0.7 dials_1.3.0 scales_1.3.0
[25] broom_1.0.6 tidymodels_1.2.0 lubridate_1.9.3 forcats_1.0.0
[29] stringr_1.5.1 dplyr_1.1.4 purrr_1.0.4 readr_2.1.5
[33] tidyr_1.3.1 tibble_3.2.1 ggplot2_3.5.2 tidyverse_2.0.0
loaded via a namespace (and not attached):
[1] rstudioapi_0.16.0 jsonlite_1.8.9 magrittr_2.0.3
[4] farver_2.1.2 rmarkdown_2.29 fs_1.6.5
[7] vctrs_0.6.5 shades_1.4.0 curl_6.2.2
[10] janeaustenr_1.0.0 xgboost_1.7.8.1 TTR_0.24.4
[13] parallelly_1.38.0 htmlwidgets_1.6.4 tokenizers_0.3.0
[16] zoo_1.8-12 ggfittext_0.10.2 igraph_2.0.3
[19] mime_0.12 lifecycle_1.0.4 iterators_1.0.14
[22] pkgconfig_2.0.3 Matrix_1.7-0 R6_2.5.1
[25] fastmap_1.2.0 future_1.34.0 shiny_1.8.1.1
[28] digest_0.6.37 colorspace_2.1-0 furrr_0.3.1
[31] SnowballC_0.7.1 labeling_0.4.3 timechange_0.3.0
[34] compiler_4.4.1 bit64_4.5.2 withr_3.0.2
[37] S7_0.2.0 backports_1.5.0 MASS_7.3-60.2
[40] lava_1.8.0 rappdirs_0.3.3 tools_4.4.1
[43] quantmod_0.4.26 httpuv_1.6.15 future.apply_1.11.2
[46] nnet_7.3-19 glue_1.8.0 promises_1.3.2
[49] grid_4.4.1 gggenes_0.5.1 generics_0.1.3
[52] gtable_0.3.5 tzdb_0.4.0 class_7.3-22
[55] data.table_1.15.4 hms_1.1.3 utf8_1.2.4
[58] foreach_1.5.2 pillar_1.10.0 vroom_1.6.5
[61] later_1.4.2 splines_4.4.1 lhs_1.2.0
[64] lattice_0.22-6 survival_3.7-0 bit_4.5.0.1
[67] tidyselect_1.2.1 coro_1.1.0 knitr_1.49
[70] xfun_0.50 hardhat_1.4.0 timeDate_4032.109
[73] stringi_1.8.4 DiceDesign_1.10 yaml_2.3.10
[76] pacman_0.5.1 evaluate_1.0.3 codetools_0.2-20
[79] archive_1.1.12 cli_3.6.3 rpart_4.1.23
[82] reticulate_1.40.0 xtable_1.8-4 munsell_0.5.1
[85] Rcpp_1.0.13-1 globals_0.16.3 png_0.1-8
[88] parallel_4.4.1 gower_1.0.1 assertthat_0.2.1
[91] GPfit_1.0-8 listenv_0.9.1 ipred_0.9-15
[94] rlist_0.4.6.2 xts_0.14.0 prodlim_2024.06.25
[97] crayon_1.5.3 rlang_1.1.6