An XGBoost Classification Model on Beach Volleyball in R

An example of classification using XGBoost and tidymodels in R

Image by valentinaalemanno from Pixabay

NOTE: These days I am following Julia Silge to learn the tidymodels framework better. This post is inspired by what I learned from her. You can find her screencasts here

Beach volleyball is a team sport played by two teams of two players on a sand court divided by a net. As in indoor volleyball, the objective of the game is to send the ball over the net and ground it on the opponent’s side of the court, and to prevent the same effort by the opponent. Beach volleyball matches are quite popular around the world, especially in countries such as the US and Brazil. The sport most likely originated in 1915 on Waikiki Beach in Hawaii, while the modern two-player game originated in Santa Monica, California. It has been an Olympic sport since the 1996 Summer Olympics.

This dataset contains a large record of beach volleyball matches, with approximately 76,500 rows. Each row contains the statistics of one match. Some of the important features available in the data are gender, winner and loser statistics, match outcome, date, and player details.

The objective of this modelling exercise is to predict the outcome of a match using the information available. It is a binary classification problem, and while there are several ways to handle such problems, in this exercise we will use the XGBoost algorithm. Let’s get started.

Let’s load the dataset

vb_matches <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-05-19/vb_matches.csv', guess_max = 76000)

vb_matches %>% 
    head() %>% 
    knitr::kable()
(The first six rows come from the 2002 AVP Huntington Beach tournament; the table is too wide to reproduce here, as it spans all 65 columns — circuit, tournament, country, date, gender, match number, winner and loser player names, birthdates, ages, heights, countries, ranks, score, duration, bracket, round, and the per-player totals for attacks, kills, errors, hit percentage, aces, serve errors, blocks and digs.)

You will see the guess_max argument in read_csv here. By default, read_csv guesses each column’s type by looking at the data in that column, using only the first 1,000 rows. guess_max controls how many rows it reads before it assigns a type to the column.
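
If you want to check what readr actually guessed, a minimal sketch (assuming readr is available, e.g. via library(tidyverse)):

# Print the column specification readr settled on after reading the file;
# with a small guess_max, columns whose first rows are all NA can be guessed
# wrongly (e.g. as logical), which is why we raised it here
readr::spec(vb_matches)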

Let’s explore the dataset

We will use the skim function from the skimr package to understand the data. As you can see, we have 76,756 rows and 65 variables, which is a moderately large number of variables. We have character, numeric, date and difftime variables. There is a lot of missing data, so it makes sense to explore the missing data separately.

skimr::skim(vb_matches)
Data summary

| Name                   | vb_matches |
|------------------------|------------|
| Number of rows         | 76756      |
| Number of columns      | 65         |
| Column type frequency: |            |
| character              | 17         |
| Date                   | 5          |
| difftime               | 1          |
| numeric                | 42         |
| Group variables        | None       |

Variable type: character

| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---------------|-----------|---------------|-----|-----|-------|----------|------------|
| circuit       | 0    | 1.00 | 3 | 4  | 0 | 2    | 0 |
| tournament    | 0    | 1.00 | 3 | 22 | 0 | 177  | 0 |
| country       | 0    | 1.00 | 4 | 22 | 0 | 51   | 0 |
| gender        | 0    | 1.00 | 1 | 1  | 0 | 2    | 0 |
| w_player1     | 0    | 1.00 | 6 | 29 | 0 | 3388 | 0 |
| w_p1_country  | 12   | 1.00 | 4 | 20 | 0 | 85   | 0 |
| w_player2     | 0    | 1.00 | 5 | 30 | 0 | 3431 | 0 |
| w_p2_country  | 5    | 1.00 | 4 | 20 | 0 | 87   | 0 |
| w_rank        | 148  | 1.00 | 1 | 7  | 0 | 812  | 0 |
| l_player1     | 0    | 1.00 | 5 | 29 | 0 | 5713 | 0 |
| l_p1_country  | 18   | 1.00 | 4 | 20 | 0 | 109  | 0 |
| l_player2     | 0    | 1.00 | 5 | 30 | 0 | 5689 | 0 |
| l_p2_country  | 10   | 1.00 | 4 | 20 | 0 | 111  | 0 |
| l_rank        | 1240 | 0.98 | 1 | 7  | 0 | 837  | 0 |
| score         | 22   | 1.00 | 4 | 25 | 0 | 6624 | 0 |
| bracket       | 0    | 1.00 | 6 | 21 | 0 | 36   | 0 |
| round         | 4939 | 0.94 | 7 | 8  | 0 | 10   | 0 |

Variable type: Date

| skim_variable  | n_missing | complete_rate | min        | max        | median     | n_unique |
|----------------|-----------|---------------|------------|------------|------------|----------|
| date           | 0    | 1.00 | 2000-09-16 | 2019-08-29 | 2009-08-25 | 658  |
| w_p1_birthdate | 383  | 1.00 | 1953-06-13 | 2004-07-15 | 1981-10-30 | 2805 |
| w_p2_birthdate | 408  | 0.99 | 1952-10-11 | 2004-06-08 | 1981-10-15 | 2847 |
| l_p1_birthdate | 1059 | 0.99 | 1953-06-13 | 2004-12-01 | 1982-03-28 | 4236 |
| l_p2_birthdate | 959  | 0.99 | 1949-12-04 | 2004-08-12 | 1982-03-20 | 4282 |

Variable type: difftime

| skim_variable | n_missing | complete_rate | min      | max       | median   | n_unique |
|---------------|-----------|---------------|----------|-----------|----------|----------|
| duration      | 2249 | 0.97 | 120 secs | 8040 secs | 00:42:00 | 108 |

Variable type: numeric

| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---------------|-----------|---------------|------|----|----|-----|-----|-----|------|------|
| year                  | 0     | 1.00 | 2010.29 | 5.48  | 2000.00 | 2006.00 | 2009.00 | 2015.00 | 2019.00 | ▃▇▆▅▇ |
| match_num             | 0     | 1.00 | 31.84   | 23.55 | 1.00    | 13.00   | 27.00   | 47.00   | 137.00  | ▇▅▂▁▁ |
| w_p1_age              | 383   | 1.00 | 28.73   | 5.05  | 13.55   | 25.01   | 28.34   | 32.05   | 59.86   | ▂▇▃▁▁ |
| w_p1_hgt              | 3966  | 0.95 | 73.71   | 3.64  | 63.00   | 71.00   | 74.00   | 76.00   | 85.00   | ▁▅▇▃▁ |
| w_p2_age              | 408   | 0.99 | 28.83   | 4.85  | 13.41   | 25.30   | 28.60   | 31.95   | 52.40   | ▁▇▆▁▁ |
| w_p2_hgt              | 4016  | 0.95 | 73.73   | 3.69  | 61.00   | 71.00   | 74.00   | 76.00   | 85.00   | ▁▃▇▆▁ |
| l_p1_age              | 1059  | 0.99 | 28.34   | 5.26  | 13.23   | 24.54   | 27.94   | 31.68   | 60.86   | ▂▇▂▁▁ |
| l_p1_hgt              | 6988  | 0.91 | 73.40   | 3.62  | 61.00   | 71.00   | 73.00   | 76.00   | 85.00   | ▁▃▇▅▁ |
| l_p2_age              | 959   | 0.99 | 28.37   | 5.12  | 13.41   | 24.71   | 28.01   | 31.58   | 67.76   | ▂▇▁▁▁ |
| l_p2_hgt              | 6983  | 0.91 | 73.52   | 3.65  | 61.00   | 71.00   | 74.00   | 76.00   | 85.00   | ▁▃▇▅▁ |
| w_p1_tot_attacks      | 62178 | 0.19 | 25.89   | 10.00 | 0.00    | 19.00   | 24.00   | 32.00   | 142.00  | ▇▅▁▁▁ |
| w_p1_tot_kills        | 62178 | 0.19 | 14.73   | 5.34  | 0.00    | 11.00   | 14.00   | 18.00   | 40.00   | ▂▇▅▁▁ |
| w_p1_tot_errors       | 62413 | 0.19 | 2.90    | 2.27  | 0.00    | 1.00    | 2.00    | 4.00    | 32.00   | ▇▁▁▁▁ |
| w_p1_tot_hitpct       | 62185 | 0.19 | 0.48    | 0.23  | -0.70   | 0.38    | 0.48    | 0.58    | 20.00   | ▇▁▁▁▁ |
| w_p1_tot_aces         | 60560 | 0.21 | 1.32    | 1.45  | 0.00    | 0.00    | 1.00    | 2.00    | 14.00   | ▇▂▁▁▁ |
| w_p1_tot_serve_errors | 62417 | 0.19 | 2.03    | 1.65  | 0.00    | 1.00    | 2.00    | 3.00    | 13.00   | ▇▃▁▁▁ |
| w_p1_tot_blocks       | 60560 | 0.21 | 1.70    | 2.15  | 0.00    | 0.00    | 1.00    | 3.00    | 14.00   | ▇▂▁▁▁ |
| w_p1_tot_digs         | 62178 | 0.19 | 8.35    | 5.48  | 0.00    | 4.00    | 8.00    | 12.00   | 46.00   | ▇▅▁▁▁ |
| w_p2_tot_attacks      | 62174 | 0.19 | 26.11   | 10.09 | 0.00    | 19.00   | 25.00   | 32.00   | 124.00  | ▇▇▁▁▁ |
| w_p2_tot_kills        | 62174 | 0.19 | 14.80   | 5.33  | 0.00    | 11.00   | 14.00   | 18.00   | 41.00   | ▂▇▅▁▁ |
| w_p2_tot_errors       | 62413 | 0.19 | 2.92    | 2.29  | 0.00    | 1.00    | 2.00    | 4.00    | 34.00   | ▇▁▁▁▁ |
| w_p2_tot_hitpct       | 62181 | 0.19 | 0.48    | 0.16  | -0.68   | 0.37    | 0.47    | 0.58    | 3.50    | ▁▇▁▁▁ |
| w_p2_tot_aces         | 60556 | 0.21 | 1.19    | 1.36  | 0.00    | 0.00    | 1.00    | 2.00    | 10.00   | ▇▁▁▁▁ |
| w_p2_tot_serve_errors | 62413 | 0.19 | 1.93    | 1.62  | 0.00    | 1.00    | 2.00    | 3.00    | 13.00   | ▇▃▁▁▁ |
| w_p2_tot_blocks       | 60556 | 0.21 | 1.69    | 2.19  | 0.00    | 0.00    | 1.00    | 3.00    | 16.00   | ▇▂▁▁▁ |
| w_p2_tot_digs         | 62174 | 0.19 | 8.54    | 5.56  | 0.00    | 4.00    | 8.00    | 12.00   | 52.00   | ▇▃▁▁▁ |
| l_p1_tot_attacks      | 62179 | 0.19 | 27.13   | 11.11 | 0.00    | 19.00   | 26.00   | 34.00   | 330.00  | ▇▁▁▁▁ |
| l_p1_tot_kills        | 62179 | 0.19 | 12.77   | 5.76  | 0.00    | 9.00    | 12.00   | 16.00   | 41.00   | ▃▇▃▁▁ |
| l_p1_tot_errors       | 62413 | 0.19 | 4.38    | 2.76  | 0.00    | 2.00    | 4.00    | 6.00    | 30.00   | ▇▂▁▁▁ |
| l_p1_tot_hitpct       | 62189 | 0.19 | 0.31    | 0.18  | -0.80   | 0.21    | 0.32    | 0.42    | 4.25    | ▃▇▁▁▁ |
| l_p1_tot_aces         | 60561 | 0.21 | 0.78    | 1.04  | 0.00    | 0.00    | 0.00    | 1.00    | 9.00    | ▇▂▁▁▁ |
| l_p1_tot_serve_errors | 62418 | 0.19 | 2.10    | 1.66  | 0.00    | 1.00    | 2.00    | 3.00    | 12.00   | ▇▃▁▁▁ |
| l_p1_tot_blocks       | 60561 | 0.21 | 1.00    | 1.53  | 0.00    | 0.00    | 0.00    | 2.00    | 14.00   | ▇▁▁▁▁ |
| l_p1_tot_digs         | 62179 | 0.19 | 7.19    | 5.17  | 0.00    | 3.00    | 6.00    | 10.00   | 51.00   | ▇▂▁▁▁ |
| l_p2_tot_attacks      | 62178 | 0.19 | 26.68   | 10.81 | -6.00   | 19.00   | 26.00   | 33.00   | 128.00  | ▃▇▁▁▁ |
| l_p2_tot_kills        | 62178 | 0.19 | 12.57   | 5.66  | 0.00    | 8.00    | 12.00   | 16.00   | 42.00   | ▃▇▃▁▁ |
| l_p2_tot_errors       | 62413 | 0.19 | 4.32    | 2.71  | 0.00    | 2.00    | 4.00    | 6.00    | 28.00   | ▇▃▁▁▁ |
| l_p2_tot_hitpct       | 62189 | 0.19 | 0.31    | 0.18  | -0.67   | 0.21    | 0.32    | 0.42    | 3.50    | ▂▇▁▁▁ |
| l_p2_tot_aces         | 60560 | 0.21 | 0.78    | 1.06  | 0.00    | 0.00    | 0.00    | 1.00    | 11.00   | ▇▁▁▁▁ |
| l_p2_tot_serve_errors | 62417 | 0.19 | 2.05    | 1.66  | 0.00    | 1.00    | 2.00    | 3.00    | 15.00   | ▇▂▁▁▁ |
| l_p2_tot_blocks       | 60560 | 0.21 | 1.06    | 1.56  | 0.00    | 0.00    | 0.00    | 2.00    | 13.00   | ▇▁▁▁▁ |
| l_p2_tot_digs         | 62178 | 0.19 | 7.14    | 5.18  | 0.00    | 3.00    | 6.00    | 10.00   | 41.00   | ▇▃▁▁▁ |

We will explore the missing data now. We can use the DataExplorer package for this, as it provides some very useful visual summaries of the dataset.

DataExplorer::plot_str(vb_matches)
DataExplorer::plot_intro(vb_matches)

Missing 1_1

Missing 1

DataExplorer::plot_missing(vb_matches)

Missing 2

Prepare data

We will do some transformations to make the data ready for modelling.

  • We will combine the player 1 and player 2 stats for both winners and losers into single team-level columns.
  • We will keep the gender, circuit and year variables and drop the rest of the columns for now.

vb_parsed <- vb_matches %>%
    transmute(
        circuit,
        gender,
        year,
        w_attacks = w_p1_tot_attacks + w_p2_tot_attacks,
        w_kills = w_p1_tot_kills + w_p2_tot_kills,
        w_errors = w_p1_tot_errors + w_p2_tot_errors,
        w_aces = w_p1_tot_aces + w_p2_tot_aces,
        w_serve_errors = w_p1_tot_serve_errors + w_p2_tot_serve_errors,
        w_blocks = w_p1_tot_blocks + w_p2_tot_blocks,
        w_digs = w_p1_tot_digs + w_p2_tot_digs,
        l_attacks = l_p1_tot_attacks + l_p2_tot_attacks,
        l_kills = l_p1_tot_kills + l_p2_tot_kills,
        l_errors = l_p1_tot_errors + l_p2_tot_errors,
        l_aces = l_p1_tot_aces + l_p2_tot_aces,
        l_serve_errors = l_p1_tot_serve_errors + l_p2_tot_serve_errors,
        l_blocks = l_p1_tot_blocks + l_p2_tot_blocks,
        l_digs = l_p1_tot_digs + l_p2_tot_digs
    ) %>%
    na.omit()

winners <- vb_parsed %>% 
    select(circuit,
           gender,
           year,
           w_attacks:w_digs) %>% 
    rename_with(~str_remove_all(., "w_")) %>% 
    mutate(win = "win")

losers <- vb_parsed %>% 
    select(circuit,
           gender,
           year,
           l_attacks:l_digs) %>% 
    rename_with(~str_remove_all(., "l_")) %>% 
    mutate(win = "lose")

vb_df <- bind_rows(winners, losers) %>% 
    mutate_if(is.character, factor)

vb_df %>% 
    head() %>% 
    knitr::kable()
(First six rows of vb_df, all from AVP men’s matches in 2004: the columns are circuit, gender, year, the team-level attacks, kills, errors, aces, serve_errors, blocks and digs, and the win/lose outcome.)

We created two separate data frames for winners and losers and then combined them by binding rows. Our final data frame is ready for modelling. It contains 28,664 rows and 11 columns. There are fewer rows than in the raw data because we removed all rows with missing values, but the sample size is still large enough for modelling.
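
As a quick sanity check (a small sketch, assuming dplyr is loaded), we can confirm the class balance; by construction every match contributes one winning and one losing row, so the two classes should be equal in size:

# Count rows per outcome class; the construction guarantees a 50/50 split
vb_df %>% 
  count(win)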

Let’s perform some EDA

We can plot some visualizations to check whether these variables are good enough to differentiate the two classes, i.e. win and lose.

vb_df %>%
  pivot_longer(attacks:digs, names_to = "stat", values_to = "value") %>%
  ggplot(aes(
    x = gender,
    y = value,
    fill = win,
    color = win
  )) +
  geom_boxplot(alpha = 0.5) +
  facet_wrap( ~ stat, scales = "free_y", nrow = 2) +
  labs(x = NULL,
       fill = NULL,
       color = NULL)

EDA 1

vb_df %>%
  pivot_longer(attacks:digs, names_to = "stat", values_to = "value") %>%
  ggplot(aes(
    x = circuit,
    y = value,
    fill = win,
    color = win
  )) +
  geom_boxplot(alpha = 0.5) +
  facet_wrap( ~ stat, scales = "free_y", nrow = 2) +
  labs(x = NULL,
       fill = NULL,
       color = NULL)

EDA 2

These boxplots suggest that some of the features, such as kills, errors and attacks, are good predictors of the outcome class. This is evident visually, and we will confirm the hypothesis while modelling.
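
To complement the boxplots with numbers, a minimal sketch (again assuming dplyr is loaded) is to compare the average of each stat for winners and losers:

# Mean of each team-level stat by outcome; larger gaps hint at stronger predictors
vb_df %>%
  group_by(win) %>%
  summarise(across(attacks:digs, mean))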

Build a model

As we are going to build an XGBoost model, there is little need to pre-process the data separately. The formula interface in the workflow converts the nominal variables (circuit, gender) into dummy variables for us, and tree-based models such as XGBoost do not require the numeric variables to be normalized. Hence we will do the following things in order to build the model:

  • We will create a split
  • We will create model specs
  • We will create grid for model tuning
  • We will create a workflow for systematic process flow
  • We will also create resamples using k-fold cross validation for effective tuning.
  • Finally we will train the model

Data Split

set.seed(123)
vb_split <- vb_df %>% 
  initial_split(strata = win)

vb_train <- training(vb_split)
vb_test <- testing(vb_split)
vb_split
## <Analysis/Assess/Total>
## <21498/7166/28664>

Model Spec

In the arguments we have declared the parameters to tune but have not specified values for them, as we are going to find the best hyperparameters by tuning the model.

xgb_spec <- boost_tree(
  trees = 1000,
  tree_depth = tune(),
  min_n = tune(),
  loss_reduction = tune(),
  sample_size = tune(),
  mtry = tune(),
  learn_rate = tune()
) %>%
  set_engine(engine = "xgboost") %>%
  set_mode(mode = "classification")

xgb_spec
## Boosted Tree Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
##   tree_depth = tune()
##   learn_rate = tune()
##   loss_reduction = tune()
##   sample_size = tune()
## 
## Computational engine: xgboost

Tuning Grid using a Latin Hypercube

xgb_grid <- grid_latin_hypercube(tree_depth(),
                     min_n(),
                     loss_reduction(),
                     sample_size = sample_prop(),
                     finalize(mtry(),vb_train),
                     learn_rate(),
                     size = 20)
xgb_grid
## # A tibble: 20 x 6
##    tree_depth min_n loss_reduction sample_size  mtry learn_rate
##         <int> <int>          <dbl>       <dbl> <int>      <dbl>
##  1          6     7       1.85e- 5       0.125    10   2.34e- 4
##  2          7    25       5.43e- 7       0.922     3   1.45e- 5
##  3         13    19       9.87e- 2       0.537     4   4.11e- 7
##  4         11    39       2.57e- 1       0.303     7   6.58e- 3
##  5         13    11       9.46e- 3       0.561     5   6.87e- 6
##  6          9    30       2.34e+ 1       0.354     4   4.37e- 3
##  7         15     6       7.46e+ 0       0.617     7   1.35e- 4
##  8          4    36       4.93e- 4       0.663     5   5.36e- 5
##  9          9    27       3.31e- 6       0.170     9   1.23e- 3
## 10          2    15       9.23e- 9       0.979     1   1.56e- 6
## 11         10    33       1.12e- 7       0.859     8   1.19e- 8
## 12          5    19       2.91e- 9       0.449     2   2.69e- 8
## 13         12    37       1.14e+ 0       0.897    11   1.81e-10
## 14          2     3       2.33e- 2       0.747     3   1.22e- 7
## 15         10    21       4.50e-10       0.817     2   1.15e- 9
## 16          1     9       3.99e- 8       0.190     8   3.02e- 7
## 17          4    13       1.94e- 3       0.468     6   2.89e- 9
## 18          5    31       1.67e- 4       0.379     9   4.96e-10
## 19          8    15       1.02e-10       0.713    10   4.86e- 2
## 20         14    28       4.07e- 6       0.274     6   1.62e- 2

Model Workflow

xgb_wf <- workflow() %>%
  add_formula(win ~ .) %>%
  add_model(xgb_spec)
xgb_wf
## == Workflow ========================================================================================================
## Preprocessor: Formula
## Model: boost_tree()
## 
## -- Preprocessor ----------------------------------------------------------------------------------------------------
## win ~ .
## 
## -- Model -----------------------------------------------------------------------------------------------------------
## Boosted Tree Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
##   tree_depth = tune()
##   learn_rate = tune()
##   loss_reduction = tune()
##   sample_size = tune()
## 
## Computational engine: xgboost
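
As a side note, the same workflow could be expressed with a recipe instead of a bare formula. A minimal, untested sketch (assuming the recipes package is available; the rest of this post sticks with the formula interface above):

library(recipes)

# A recipe that dummy-encodes the factor predictors (circuit, gender) explicitly
vb_rec <- recipe(win ~ ., data = vb_train) %>%
  step_dummy(all_nominal(), -all_outcomes())

xgb_wf_rec <- workflow() %>%
  add_recipe(vb_rec) %>%
  add_model(xgb_spec)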

k-fold Cross Validation for resampling

set.seed(123)
vb_folds <- vfold_cv(vb_train, strata = win)
vb_folds
## #  10-fold cross-validation using stratification 
## # A tibble: 10 x 2
##    splits               id    
##    <list>               <chr> 
##  1 <split [19.3K/2.1K]> Fold01
##  2 <split [19.3K/2.1K]> Fold02
##  3 <split [19.3K/2.1K]> Fold03
##  4 <split [19.3K/2.1K]> Fold04
##  5 <split [19.3K/2.1K]> Fold05
##  6 <split [19.3K/2.1K]> Fold06
##  7 <split [19.3K/2.1K]> Fold07
##  8 <split [19.3K/2.1K]> Fold08
##  9 <split [19.3K/2.1K]> Fold09
## 10 <split [19.4K/2.1K]> Fold10

Setting up a parallel environment for faster processing and fitting the model

XGBoost is a computationally intensive algorithm, so it is better to use parallel processing to save time. The verbose = TRUE argument in control_grid() is optional; it simply prints the tuning progress to the R console so you can follow how the model fitting is going.

doParallel::registerDoParallel()
set.seed(234)
xgb_res <- tune_grid(
  xgb_wf,
  resamples = vb_folds,
  grid = xgb_grid,
  control = control_grid(save_pred = T, verbose = T)
)
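
If you want more control over the parallel backend, one option (a hedged sketch; the right number of workers depends on your machine) is to register an explicit number of cores and switch back to sequential execution once tuning is done:

# Register all but one core, run the tuning, then return to sequential execution
doParallel::registerDoParallel(cores = parallel::detectCores() - 1)
# ... tune_grid() call goes here ...
foreach::registerDoSEQ()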

Explore the results

Let’s see which hyperparameters yield the best AUC value.

best_auc <- xgb_res %>% 
  show_best(metric = "roc_auc")
best_auc
## # A tibble: 5 x 11
##    mtry min_n tree_depth learn_rate loss_reduction sample_size .metric
##   <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>  
## 1     6    28         14 0.0162     0.00000407           0.274 roc_auc
## 2     7    39         11 0.00658    0.257                0.303 roc_auc
## 3    10    15          8 0.0486     0.000000000102       0.713 roc_auc
## 4     7     6         15 0.000135   7.46                 0.617 roc_auc
## 5     5    11         13 0.00000687 0.00946              0.561 roc_auc
## # ... with 4 more variables: .estimator <chr>, mean <dbl>, n <int>,
## #   std_err <dbl>

Similarly, let’s see which hyperparameters yield the best accuracy.

best_acc <- xgb_res %>% 
  show_best(metric = "accuracy")
best_acc
## # A tibble: 5 x 11
##    mtry min_n tree_depth learn_rate loss_reduction sample_size .metric
##   <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>  
## 1     6    28         14    1.62e-2 0.00000407           0.274 accura~
## 2     7    39         11    6.58e-3 0.257                0.303 accura~
## 3    10    15          8    4.86e-2 0.000000000102       0.713 accura~
## 4     5    11         13    6.87e-6 0.00946              0.561 accura~
## 5     4    19         13    4.11e-7 0.0987               0.537 accura~
## # ... with 4 more variables: .estimator <chr>, mean <dbl>, n <int>,
## #   std_err <dbl>

Visualize the results

We can visually see how the average performance of the model, in terms of AUC, changed across the different candidate values of the tuning parameters.

xgb_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  select(mean, mtry:sample_size) %>%
  pivot_longer(mtry:sample_size,
               names_to = "parameter",
               values_to = "value") %>%
  ggplot(aes(value, mean, color = parameter), alpha = 0.5) +
  geom_line(size = 1, show.legend = F) +
  facet_wrap( ~ parameter, scales = "free_x") 

Model Perf 1

xgb_res %>%
  collect_metrics() %>%
  transmute(trial = row_number(), .metric, mean) %>%
  ggplot(aes(trial, mean, color = .metric), alpha = 0.5) +
  geom_line(size = 1, show.legend = T) +
  geom_point(size = 2, show.legend = T) +
  labs(x = "Trials",
       y = "Mean Parametric Value",
       title = "Accuracy & AUC trends for different trials")  

Model Perf 2

Select the best hyperparameters

We will select the best hyperparameters based on roc_auc and finalize the workflow.

best_auc <- xgb_res %>% 
  select_best(metric = "roc_auc")

final_xgb <- xgb_wf %>% 
  finalize_workflow(best_auc)

final_xgb
## == Workflow ========================================================================================================
## Preprocessor: Formula
## Model: boost_tree()
## 
## -- Preprocessor ----------------------------------------------------------------------------------------------------
## win ~ .
## 
## -- Model -----------------------------------------------------------------------------------------------------------
## Boosted Tree Model Specification (classification)
## 
## Main Arguments:
##   mtry = 6
##   trees = 1000
##   min_n = 28
##   tree_depth = 14
##   learn_rate = 0.016199890154566
##   loss_reduction = 4.07341934175922e-06
##   sample_size = 0.27370679514017
## 
## Computational engine: xgboost

Variable Importance Metrics

It will be interesting to see which variables, according to the model, are most important in predicting the outcome.

library(vip)

final_xgb %>% 
  fit(data = vb_train) %>% 
  pull_workflow_fit() %>% 
  vip(geom = "point")

VIP Plot

The VIP plot suggests that kills, errors, attacks, blocks and digs are the most important variables. If you remember, this was our hypothesis before building the model, so the plot confirms it.

Training and Testing Final Model

Finally, we will fit the finalized workflow to the whole training data and evaluate it on the test data using last_fit(). The final model performance is as follows:

final_res <- final_xgb %>% 
  last_fit(vb_split)

final_metric <- final_res %>%
  collect_metrics()

final_metric    
## # A tibble: 2 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.842
## 2 roc_auc  binary         0.927

The final model performs more or less the same on the test set as during cross-validation, which indicates there is no overfitting, and we are good to go further.
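
To make that comparison explicit, a small sketch is to put the best cross-validated ROC AUC next to the test-set estimate:

# Best cross-validated ROC AUC from tuning
xgb_res %>% 
  show_best(metric = "roc_auc", n = 1) %>% 
  select(mean, std_err)

# Held-out test ROC AUC from the final fit
final_res %>% 
  collect_metrics() %>% 
  filter(.metric == "roc_auc")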

Confusion Matrix

final_conf_mat <- final_res %>% 
  collect_predictions() %>% 
  conf_mat(win, .pred_class)

final_conf_mat
##           Truth
## Prediction lose  win
##       lose 2967  513
##       win   616 3070
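
The conf_mat object can also be summarised into class-level metrics such as sensitivity and specificity, or plotted; a quick sketch using yardstick helpers:

# Accuracy, sensitivity, specificity and friends derived from the confusion matrix
summary(final_conf_mat)

# A heatmap view of the same matrix
autoplot(final_conf_mat, type = "heatmap")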

ROC Curve Plot

final_res %>% 
  collect_predictions() %>% 
  roc_curve(win, .pred_win) %>% 
  autoplot()

Final AUC

Save our data and objects

Although this is not a necessary step in modelling, it is good practice to save the R objects, as it saves time when we come back to the model later. We can simply load all the objects back in the next session rather than recreating them.

save.image(file = "allobjects.RData")
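
In a later session, we can restore everything with:

# Reload all saved objects into the current R session
load("allobjects.RData")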

We did it! We created an XGBoost model to predict wins in beach volleyball matches. We tuned the hyperparameters using a standard tidymodels workflow and measured the performance of the model, which is pretty good in this case. I hope this helps. Thank you so much for reading. See you again in the next post!

Gaurav Sharma
Data Enthusiast | Engineer | INTJ

My research interests include Manufacturing Analytics, Industry 4.0 and Factory Digitization.