Chapter 28 Nonlinear regression
Nonlinear regression is a form of regression analysis in which observational data are modeled by a function which is a nonlinear combination of the model parameters and depends on one or more independent variables.
Some nonlinear data sets can be transformed to a linear model.
Sone can not be transformed. For such modeling methods of Numerical analysis should be applied such as Newton’s method, Gauss-Newton method and Levenberg–Marquardt method.
Математическое моделирование7
Практика
Нелинейные модели:
В практических примерах ниже показано как
оценивать полиномиальную регрессию;
аппроксимировать нелинейные модели ступенчатыми функциями;
строить сплайны;
работать с локальной регрессией;
строить обобщённые линейные модели (GAM).: полиномиальная регрессия, полиномиальная логистическая регрессия, ступенчатая модель, обобщённая линейная модель.
Модели: Wage {ISLR}
Данные
1], глава 7.
Подробные комментарии к коду лабораторных см. в [
library('ISLR') # набор данных Auto
library('splines') # сплайны
library('gam') # обобщённые аддитивные модели
## Warning: package 'gam' was built under R version 3.3.3
## Loading required package: foreach
## Warning: package 'foreach' was built under R version 3.3.3
## Loaded gam 1.14
library('akima') # график двумерной плоскости
## Warning: package 'akima' was built under R version 3.3.3
library('ggplot2') # красивые графики
## Warning: package 'ggplot2' was built under R version 3.3.3
<- 1
my.seed 3000 работников-мужчин среднеатлантического региона Wage. Присоединяем его к пространству имён функцией attach(), и дальше обращаемся напрямую к столбцам таблицы.
Работаем с набором данных по зарплатам
attach(Wage)
:
Работаем со столбцами* wage – заработная плата работника до уплаты налогов;
* age – возраст работника в годах.
Полиномиальная регрессия
Зависимость зарплаты от возраста250.
Судя по графику ниже, ззаимосвязь заработной платы и возраста нелинейна. Наблюдается также группа наблюдений с высоким значением wage, граница проходит примерно на уровне
<- ggplot(data = Wage, aes(x = age, y = wage))
gp <- gp + geom_point() + geom_abline(slope = 0, intercept = 250, col = 'red')
gp
gp
Подгоняем полином четвёртой степени для зависимости заработной платы от возраста.
<- lm(wage ~ poly(age, 4), data = Wage)
fit round(coef(summary(fit)), 2)
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 111.70 0.73 153.28 0.00
## poly(age, 4)1 447.07 39.91 11.20 0.00
## poly(age, 4)2 -478.32 39.91 -11.98 0.00
## poly(age, 4)3 125.52 39.91 3.14 0.00
## poly(age, 4)4 -77.91 39.91 -1.95 0.05
poly(age, 4) создаёт таблицу с базисом ортогональных полиномов: линейные комбинации значений переменной age в степенях от 1 до 4.
Функция
round(head(poly(age, 4)), 3)
## 1 2 3 4
## [1,] -0.039 0.056 -0.072 0.087
## [2,] -0.029 0.026 -0.015 -0.003
## [3,] 0.004 -0.015 0.000 0.014
## [4,] 0.001 -0.015 0.005 0.013
## [5,] 0.012 -0.010 -0.011 0.010
## [6,] 0.018 -0.002 -0.017 -0.001
# можно получить сами значения age в заданных степенях
round(head(poly(age, 4, raw = T)), 3)
## 1 2 3 4
## [1,] 18 324 5832 104976
## [2,] 24 576 13824 331776
## [3,] 45 2025 91125 4100625
## [4,] 43 1849 79507 3418801
## [5,] 50 2500 125000 6250000
## [6,] 54 2916 157464 8503056
# на прогноз не повлияет, но оценки параметров изменяются
.2 <- lm(wage ~ poly(age, 4, raw = T), data = Wage)
fitround(coef(summary(fit.2)), 2)
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -184.15 60.04 -3.07 0.00
## poly(age, 4, raw = T)1 21.25 5.89 3.61 0.00
## poly(age, 4, raw = T)2 -0.56 0.21 -2.74 0.01
## poly(age, 4, raw = T)3 0.01 0.00 2.22 0.03
## poly(age, 4, raw = T)4 0.00 0.00 -1.95 0.05
# границы изменения переменной age
<- range(age)
agelims
# значения age, для которых делаем прогноз (от min до max с шагом 1)
<- seq(from = agelims[1], to = agelims[2])
age.grid
# рассчитать прогнозы и их стандартные ошибки
<- predict(fit, newdata = list(age = age.grid), se = T)
preds
# границы доверительного интервала для заработной платы
<- cbind(lower.bound = preds$fit - 2*preds$se.fit,
se.bands upper.bound = preds$fit + 2*preds$se.fit)
# смотрим результат
round(head(se.bands), 2)
## lower.bound upper.bound
## 1 41.33 62.53
## 2 49.76 67.24
## 3 57.39 71.76
## 4 64.27 76.09
## 5 70.44 80.27
## 6 75.94 84.28
4 презентации (рис. 7.1 книги). Функция matlines() рисует грфик столбцов одной матрицы против столбцов другой.
Рисуем левую панель графика со слайда
# наблюдения
plot(age, wage, xlim = agelims, cex = 0.5, col = 'darkgrey')
# заголовок
title('Полином четвёртой степени')
# модель
lines(age.grid, preds$fit, lwd = 2, col = 'blue')
# доверительные интервалы прогноза
matlines(x = age.grid, y = se.bands, lwd = 1, col = 'blue', lty = 3)
poly() совпадают.
Убедимся, что прогнозы по моделям с различными вызовами
# прогнозы по второму вызову модели
<- predict(fit.2, newdata = list(age = age.grid), se = T)
preds2
# максимальное расхождение между прогнозами по двум вариантам вызова модели
max(abs(preds$fit - preds2$fit))
## [1] 7.389644e-13
1 до 5 с помощью дисперсионного анализа (ANOVA).
Теперь подбираем степень полинома, сравнивая модели со степенями от
.1 <- lm(wage ~ age, data = Wage)
fit.2 <- lm(wage ~ poly(age, 2), data = Wage)
fit.3 <- lm(wage ~ poly(age, 3), data = Wage)
fit.4 <- lm(wage ~ poly(age, 4), data = Wage)
fit.5 <- lm(wage ~ poly(age, 5), data = Wage)
fit
round(anova(fit.1, fit.2, fit.3, fit.4, fit.5), 2)
Res.Df<dbl>
RSS<dbl>
Df<dbl>
Sum of Sq<dbl>
F<dbl>
Pr(>F)
<dbl>
2998 5022216 NA NA NA NA
2997 4793430 1 228786.01 143.59 0.00
2996 4777674 1 15755.69 9.89 0.00
2995 4771604 1 6070.15 3.81 0.05
2994 4770322 1 1282.56 0.80 0.37
5 rows
-значения для проверки нулевой гипотезы: текущая модель не даёт статистически значимого сокращения RSS по сравнению с предыдущей моделью. Можно сделать вывод, что степени 3 достаточно, дальнейшее увеличение степени не даёт значимого улучшения качества модели.
Рассматриваются пять моделей, в которых степени полинома от age идут по возрастанию. В крайнем правом столбце таблице приводятся p
> 250 от возраста
Зависимость вероятности получать зарплату 250, от возраста.
Теперь вернёмся к группе наблюдений с высоким wage. Рассмотрим зависимость вероятности того, что величина зарплаты больше glm() и указываем тип модели binomial:
Подгоняем логистическую регрессию и делаем прогнозы, для этого используем функцию для оценки обобщённой линейной модели
<- glm(I(wage > 250) ~ poly(age, 4), data = Wage, family = 'binomial')
fit
# прогнозы
<- predict(fit, newdata = list(age = age.grid), se = T)
preds
# пересчитываем доверительные интервалы и прогнозы в исходные ЕИ
<- exp(preds$fit) / (1 + exp(preds$fit))
pfit <- cbind(lower.bound = preds$fit - 2*preds$se.fit,
se.bands.logit upper.bound = preds$fit + 2*preds$se.fit)
<- exp(se.bands.logit)/(1 + exp(se.bands.logit))
se.bands
# результат - доверительный интервал для вероятности события
# "Заработная плата выше 250".
round(head(se.bands), 3)
## lower.bound upper.bound
## 1 0 0.002
## 2 0 0.003
## 3 0 0.004
## 4 0 0.005
## 5 0 0.006
## 6 0 0.007
4 слайда презентации (рис. 7.1 книги). Рисуем правую панель.
Достраиваем график с
# сетка для графика (изображаем вероятности, поэтому интервал изменения y мал)
plot(age, I(wage > 250), xlim = agelims, type = 'n', ylim = c(0, 0.2),
ylab = 'P(Wage > 250 | Age)')
# фактические наблюдения показываем засечками
points(jitter(age), I((wage > 250) / 5), cex = 0.5, pch = '|', col = 'darkgrey')
# модель
lines(age.grid, pfit, lwd = 2, col = 'blue')
# доверительные интервалы
matlines(age.grid, se.bands, lwd = 1, col = 'blue', lty = 3)
# заголовок
title('Полином четвёртой степени')
Ступенчатые функции
Для начала определим несколько интервалов, на каждом из которых будем моделировать зависимость wage от age своим средним уровнем.
# нарезаем предиктор age на 4 равных интервала
table(cut(age, 4))
##
## (17.9,33.5] (33.5,49] (49,64.5] (64.5,80.1]
## 750 1399 779 72
# подгоняем линейную модель на интервалах
<- lm(wage ~ cut(age, 4), data = Wage)
fit round(coef(summary(fit)), 2)
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 94.16 1.48 63.79 0.00
## cut(age, 4)(33.5,49] 24.05 1.83 13.15 0.00
## cut(age, 4)(49,64.5] 23.66 2.07 11.44 0.00
## cut(age, 4)(64.5,80.1] 7.64 4.99 1.53 0.13
# прогноз -- это средние по `wage` на каждом интервале
<- predict(fit, newdata = list(age = age.grid), se = T)
preds.cut
# интервальный прогноз
<- cbind(lower.bound = preds.cut$fit - 2*preds.cut$se.fit,
se.bands.cut upper.bound = preds.cut$fit + 2*preds.cut$se.fit)
7 презентации (рис. 7.2 книги).
Воспроизведём график со слайда
# наблюдения
plot(age, wage, xlim = agelims, cex = 0.5, col = 'darkgrey')
# модель
lines(age.grid, preds.cut$fit, lwd = 2, col = 'darkgreen')
# доверительные интервалы прогноза
matlines(x = age.grid, y = se.bands.cut, lwd = 1, col = 'darkgreen', lty = 3)
# заголовок
title('Ступенчатая функция')
250.
Правая часть графика, для вероятности того, что зарплата выше
<- glm(I(wage > 250) ~ cut(age, 4), data = Wage, family = 'binomial')
fit
# прогнозы
<- predict(fit, newdata = list(age = age.grid), se = T)
preds
# пересчитываем доверительные интервалы и прогнозы в исходные ЕИ
<- exp(preds$fit) / (1 + exp(preds$fit))
pfit <- cbind(lower.bound = preds$fit - 2*preds$se.fit,
se.bands.logit upper.bound = preds$fit + 2*preds$se.fit)
<- exp(se.bands.logit)/(1 + exp(se.bands.logit))
se.bands
# результат - доверительный интервал для вероятности события
# "Заработная плата выше 250".
round(head(se.bands), 3)
## lower.bound upper.bound
## 1 0.003 0.016
## 2 0.003 0.016
## 3 0.003 0.016
## 4 0.003 0.016
## 5 0.003 0.016
## 6 0.003 0.016
# сетка для графика (изображаем вероятности, поэтому интервал изменения y мал)
plot(age, I(wage > 250), xlim = agelims, type = 'n', ylim = c(0, 0.2),
ylab = 'P(Wage > 250 | Age)')
# фактические наблюдения показываем засечками
points(jitter(age), I((wage > 250) / 5), cex = 0.5, pch = '|', col = 'darkgrey')
# модель
lines(age.grid, pfit, lwd = 2, col = 'darkgreen')
# доверительные интервалы
matlines(age.grid, se.bands, lwd = 1, col = 'darkgreen', lty = 3)
# заголовок
title('Ступенчатая функция')
Сплайны
Построим кубический сплайн с тремя узлами.
# кубический сплайн с тремя узлами
<- lm(wage ~ bs(age, knots = c(25, 40, 60)), data = Wage)
fit # прогноз
<- predict(fit, newdata = list(age = age.grid), se = T)
preds.spl 6 степеней свободы. Если функции bs(), которая создаёт матрицу с базисом для полиномиального сплайна, передать только степени свободы, она распределит узлы равномерно. В данном случае это квартили распределения age.
Теперь построим натуральный по трём узлам. Три узла это
# 3 узла -- 6 степеней свободы (столбцы матрицы)
dim(bs(age, knots = c(25, 40, 60)))
## [1] 3000 6
# если не указываем узлы явно...
dim(bs(age, df = 6))
## [1] 3000 6
# они привязываются к квартилям
attr(bs(age, df = 6), 'knots')
## 25% 50% 75%
## 33.75 42.00 51.00
# натуральный сплайн
<- lm(wage ~ ns(age, df = 4), data = Wage)
fit2 <- predict(fit2, newdata = list(age = age.grid), se = T)
preds.spl2
График сравнения кубического и натурального сплайнов.
par(mfrow = c(1, 1), mar = c(4.5, 4.5, 1, 8.5), oma = c(0, 0, 0, 0), xpd = T)
# наблюдения
plot(age, wage, col = 'grey')
# модель кубического сплайна
lines(age.grid, preds.spl$fit, lwd = 2)
# доверительный интервал
lines(age.grid, preds.spl$fit + 2*preds.spl$se, lty = 'dashed')
lines(age.grid, preds.spl$fit - 2*preds.spl$se, lty = 'dashed')
# натуральный сплайн
lines(age.grid, preds.spl2$fit, col = 'red', lwd = 2)
# легенда
legend("topright", inset = c(-0.7, 0),
c('Кубический \n с 3 узлами', 'Натуральный'),
lwd = rep(2, 2), col = c('black', 'red'))
# заголовок
title("Сплайны")
20 (рисунок 7.8 книги).
Построим график со слайда
par(mfrow = c(1, 1), mar = c(4.5, 4.5, 1, 1), oma = c(0, 0, 4, 0))
# наблюдения
plot(age, wage, xlim = agelims, cex = 0.5, col = 'darkgrey')
# заголовок
title('Сглаживающий сплайн')
# подгоняем модель с 16 степенями свободы
<- smooth.spline(age, wage, df = 16)
fit
# подгоняем модель с подбором лямбды с помощью перекрёстной проверки
<- smooth.spline(age, wage, cv = T)
fit2 ## Warning in smooth.spline(age, wage, cv = T): cross-validation with non-
## unique 'x' values seems doubtful
$df
fit2## [1] 6.794596
# рисуем модель
lines(fit, col = 'red', lwd = 2)
lines(fit2, col = 'blue', lwd = 2)
legend('topright',
c('16 df', '6.8 df'),
col = c('red', 'blue'), lty = 1, lwd = 2, cex = 0.8)
Локальная регрессия24 (рис. 7.10).
Строим график со слайда
plot(age, wage, xlim = agelims, cex = 0.5, col = 'darkgrey')
title('Локальная регрессия')
# подгоняем модель c окном 0.2
<- loess(wage ~ age, span = 0.2, data = Wage)
fit
# подгоняем модель c окном 0.5
<- loess(wage ~ age, span = 0.5, data = Wage)
fit2
# рисум модели
lines(age.grid, predict(fit, data.frame(age = age.grid)),
col = 'red', lwd = 2)
lines(age.grid, predict(fit2, data.frame(age = age.grid)),
col = 'blue', lwd = 2)
# легенда
legend('topright',
c('s = 0.2', 's = 0.5'),
col = c('red', 'blue'), lty = 1, lwd = 2, cex = 0.8)
Обобщённые аддитивные модели (GAM) с непрерывным откликом4 (year), 5 (age) с категориальным предиктором edication.
Построим GAM на натуральных сплайнах степеней
# GAM на натуральных сплайнах
<- gam(wage ~ ns(year, 4) + ns(age, 5) + education, data = Wage)
gam.ns
Также построим модель на сглаживающих сплайнах.
# GAM на сглаживающих сплайнах
<- gam(wage ~ s(year, 4) + s(age, 5) + education, data = Wage)
gam.m3 28 (рис. 7.12).
График со слайда
par(mfrow = c(1, 3))
plot(gam.m3, se = T, col = 'blue')
27 (рис. 7.11).
График со слайда
par(mfrow = c(1, 3))
plot(gam.ns, se = T, col = 'red')
График функции от year похож на прямую. Сделаем ANOVA, чтобы понять, какая степень для year лучше.
<- gam(wage ~ s(age, 5) + education, data = Wage) # без year
gam.m1 <- gam(wage ~ year + s(age, 5) + education, data = Wage) # year^1
gam.m2
anova(gam.m1, gam.m2, gam.m3, test = 'F')
Resid. Df<dbl>
Resid. Dev<dbl>
Df<dbl>
Deviance<dbl>
F<dbl>
Pr(>F)
<dbl>
2990 3711731 NA NA NA NA
2989 3693842 1.000000 17889.243 14.477130 0.0001447167
2986 3689770 2.999989 4071.134 1.098212 0.3485661430
3 rows
Третья модель статистически не лучше второй. Кроме того, один из параметров этой модели незначим.
# сводка по модели gam.m3
summary(gam.m3)
##
## Call: gam(formula = wage ~ s(year, 4) + s(age, 5) + education, data = Wage)
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -119.43 -19.70 -3.33 14.17 213.48
##
## (Dispersion Parameter for gaussian family taken to be 1235.69)
##
## Null Deviance: 5222086 on 2999 degrees of freedom
## Residual Deviance: 3689770 on 2986 degrees of freedom
## AIC: 29887.75
##
## Number of Local Scoring Iterations: 2
##
## Anova for Parametric Effects
## Df Sum Sq Mean Sq F value Pr(>F)
## s(year, 4) 1 27162 27162 21.981 2.877e-06 ***
## s(age, 5) 1 195338 195338 158.081 < 2.2e-16 ***
## education 4 1069726 267432 216.423 < 2.2e-16 ***
## Residuals 2986 3689770 1236
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Anova for Nonparametric Effects
## Npar Df Npar F Pr(F)
## (Intercept)
## s(year, 4) 3 1.086 0.3537
## s(age, 5) 4 32.380 <2e-16 ***
## education
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Работаем с моделью gam.m2.
# прогноз по обучающей выборке
<- predict(gam.m2, newdata = Wage)
preds
Также можно использовать в GAM локальные регрессии.
# GAM на локальных регрессиях
<- gam(wage ~ s(year, df = 4) + lo(age, span = 0.7) + education,
gam.lo data = Wage)
par(mfrow = c(1, 3))
plot.gam(gam.lo, se = T, col = 'green')
# модель со взаимодействием регрессоров year и age
<- gam(wage ~ lo(year, age, span = 0.5) + education, data = Wage)
gam.lo.i ## Warning in lo.wam(x, z, wz, fit$smooth, which, fit$smooth.frame,
## bf.maxit, : liv too small. (Discovered by lowesd)
## Warning in lo.wam(x, z, wz, fit$smooth, which, fit$smooth.frame,
## bf.maxit, : lv too small. (Discovered by lowesd)
## Warning in lo.wam(x, z, wz, fit$smooth, which, fit$smooth.frame,
## bf.maxit, : liv too small. (Discovered by lowesd)
## Warning in lo.wam(x, z, wz, fit$smooth, which, fit$smooth.frame,
## bf.maxit, : lv too small. (Discovered by lowesd)
plot(gam.lo.i)
Логистическая GAM250.
Построим логистическую GAM для всероятности того, что wage превышает
<- gam(I(wage > 250) ~ year + s(age, df = 5) + education,
gam.lr family = 'binomial', data = Wage)
par(mfrow = c(1, 3))
plot(gam.lr, se = T, col = 'green')
# уровни образования по группам разного достатка
table(education, I(wage > 250))
##
## education FALSE TRUE
## 1. < HS Grad 268 0
## 2. HS Grad 966 5
## 3. Some College 643 7
## 4. College Grad 663 22
## 5. Advanced Degree 381 45
> 250, поэтому убираем её.
В категории с самым низким уровнем образования нет wage
<- gam(I(wage > 250) ~ year + s(age, df = 5) + education,
gam.lr.s family = 'binomial', data = Wage,
subset = (education != "1. < HS Grad"))
# график
par(mfrow = c(1, 3))
plot(gam.lr.s, se = T, col = 'green')
detach(Wage)
# Nonlinear modeling
Математическое моделирование8
Практика
Нелинейные модели:
В практических примерах ниже показано как
строить регрессионные деревья;
строить деревья классификации;
делать обрезку дерева;
использовать бэггинг, бустинг, случайный лес для улучшения качества прогнозирования.: деревья решений.
Модели: Sales {ISLR}, Boston {ISLR}
Данные
1], глава 8.
Подробные комментарии к коду лабораторных см. в [
library('tree') # деревья
## Warning: package 'tree' was built under R version 3.4.4
library('ISLR') # наборы данных
library('MASS')
library('randomForest') # случайный лес
## Warning: package 'randomForest' was built under R version 3.4.4
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
library('gbm')
## Warning: package 'gbm' was built under R version 3.4.4
## Loading required package: survival
## Loading required package: lattice
## Loading required package: splines
## Loading required package: parallel
## Loaded gbm 2.1.3
Деревья решений:
Загрузим таблицу с данными по продажам детских кресел и добавим к ней переменную High – “высокие продажи” со значениями
8 (тыс. шт.);
Yes если продажи больше
No в противном случае.
?Carseats## starting httpd help server ... done
attach(Carseats)
# новая переменная
<- ifelse(Sales <= 8, "No", "Yes")
High
# присоединяем к таблице данных
<- data.frame(Carseats, High)
Carseats
Строим дерево для категориального отклика High, отбросив непрерывный отклик Sales.
# модель бинарного дерева
<- tree(High ~ . -Sales, Carseats)
tree.carseats summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
# график результата
plot(tree.carseats) # ветви
text(tree.carseats, pretty=0) # подписи
# посмотреть всё дерево в консоли
tree.carseats ## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 541.500 No ( 0.59000 0.41000 )
## 2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
## 4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
## 8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
## 16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
## 17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
## 9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
## 18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
## 19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
## 5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
## 10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
## 20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
## 40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
## 80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
## 160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
## 161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
## 81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
## 41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
## 21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
## 42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
## 84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
## 85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
## 170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
## 171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
## 342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
## 343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
## 43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
## 86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
## 87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
## 174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
## 348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
## 349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
## 175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
## 11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
## 22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
## 44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
## 88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
## 89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
## 45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
## 23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
## 46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
## 47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
## 94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
## 95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
## 3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
## 6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
## 12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
## 24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
## 25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
## 13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
## 7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
## 14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
## 15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
Теперь построим дерево на обучающей выборке и оценим ошибку на тестовой.
# ядро генератора случайных чисел
set.seed(2)
# обучающая выборка
<- sample(1:nrow(Carseats), 200)
train
# тестовая выборка
<- Carseats[-train,]
Carseats.test <- High[-train]
High.test
# строим дерево на обучающей выборке
<- tree(High ~ . -Sales, Carseats, subset = train)
tree.carseats
# делаем прогноз
<- predict(tree.carseats, Carseats.test, type = "class")
tree.pred
# матрица неточностей
<- table(tree.pred, High.test)
tbl
tbl## High.test
## tree.pred No Yes
## No 86 27
## Yes 30 57
# оценка точности
<- sum(diag(tbl))/sum(tbl)
acc.test
acc.test## [1] 0.715
: доля верных прогнозов: 0.72.
Обобщённая характеристика точности
cv.tree() проводит кросс-валидацию для выбора лучшего дерева, аргумент prune.misclass означает, что мы минимизируем ошибку классификации.
Теперь обрезаем дерево, используя в качестве критерия частоту ошибок классификации. Функция
set.seed(3)
<- cv.tree(tree.carseats, FUN = prune.misclass)
cv.carseats # имена элементов полученного объекта
names(cv.carseats)
## [1] "size" "dev" "k" "method"
# сам объект
cv.carseats## $size
## [1] 19 17 14 13 9 7 3 2 1
##
## $dev
## [1] 55 55 53 52 50 56 69 65 80
##
## $k
## [1] -Inf 0.0000000 0.6666667 1.0000000 1.7500000 2.0000000
## [7] 4.2500000 5.0000000 23.0000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
# графики изменения параметров метода по ходу обрезки дерева ###################
# 1. ошибка с кросс-валидацией в зависимости от числа узлов
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b",
ylab = 'Частота ошибок с кросс-вал. (dev)',
xlab = 'Число узлов (size)')
# размер дерева с минимальной ошибкой
<- cv.carseats$size[cv.carseats$dev == min(cv.carseats$dev)]
opt.size abline(v = opt.size, col = 'red', 'lwd' = 2) # соотв. вертикальная прямая
mtext(opt.size, at = opt.size, side = 1, col = 'red', line = 1)
# 2. ошибка с кросс-валидацией в зависимости от штрафа на сложность
plot(cv.carseats$k, cv.carseats$dev, type = "b",
ylab = 'Частота ошибок с кросс-вал. (dev)',
xlab = 'Штраф за сложность (k)')
9. Оценим точность дерева с 9 узлами.
Как видно на графике слева, минимум частоты ошибок достигается при числе узлов
# дерево с 9 узлами
<- prune.misclass(tree.carseats, best = 9)
prune.carseats
# визуализация
plot(prune.carseats)
text(prune.carseats, pretty = 0)
# прогноз на тестовую выборку
<- predict(prune.carseats, Carseats.test, type = "class")
tree.pred
# матрица неточностей
<- table(tree.pred, High.test)
tbl
tbl## High.test
## tree.pred No Yes
## No 94 24
## Yes 22 60
# оценка точности
<- sum(diag(tbl))/sum(tbl)
acc.test
acc.test## [1] 0.77
0.77. Увеличив количество узлов, получим более глубокое дерево, но менее точное.
Точность этой модели чуть выше точности исходного дерева и составляет
# дерево с 13 узлами
<- prune.misclass(tree.carseats, best = 15)
prune.carseats
# визуализация
plot(prune.carseats)
text(prune.carseats, pretty = 0)
# прогноз на тестовую выборку
<- predict(prune.carseats, Carseats.test, type = "class")
tree.pred
# матрица неточностей
<- table(tree.pred, High.test)
tbl
tbl## High.test
## tree.pred No Yes
## No 86 22
## Yes 30 62
# оценка точности
<- sum(diag(tbl))/sum(tbl)
acc.test
acc.test## [1] 0.74
# сбрасываем графические параметры
par(mfrow = c(1, 1))
Регрессионные деревья
Воспользуемся набором данных Boston.
?Boston
# обучающая выборка
set.seed(1)
<- sample(1:nrow(Boston), nrow(Boston)/2) # обучающая выборка -- 50%
train : медианная стоимости домов, в которых живут собственники (тыс. долл.).
Построим дерево регрессии для зависимой переменной medv
# обучаем модель
<- tree(medv ~ ., Boston, subset = train)
tree.boston summary(tree.boston)
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "lstat" "rm" "dis"
## Number of terminal nodes: 8
## Residual mean deviance: 12.65 = 3099 / 245
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -14.10000 -2.04200 -0.05357 0.00000 1.96000 12.60000
# визуализация
plot(tree.boston)
text(tree.boston, pretty = 0)
Снова сделаем обрезку дерева в целях улучшения качества прогноза.
<- cv.tree(tree.boston)
cv.boston
# размер дерева с минимальной ошибкой
plot(cv.boston$size, cv.boston$dev, type = 'b')
<- cv.boston$size[cv.boston$dev == min(cv.boston$dev)]
opt.size abline(v = opt.size, col = 'red', 'lwd' = 2) # соотв. вертикальная прямая
mtext(opt.size, at = opt.size, side = 1, col = 'red', line = 1)
8 узлами. Покажем, как при желании можно обрезать дерево до 7 узлов (ошибка ненамного выше, чем минимальная).
В данном случаем минимум ошибки соответствует самому сложному дереву, с
# дерево с 7 узлами
= prune.tree(tree.boston, best = 7)
prune.boston
# визуализация
plot(prune.boston)
text(prune.boston, pretty = 0)
Прогноз сделаем по необрезанному дереву, т.к. там ошибка, оцененная по методу перекрёстной проверки, минимальна.
# прогноз по лучшей модели (8 узлов)
<- predict(tree.boston, newdata = Boston[-train, ])
yhat <- Boston[-train, "medv"]
boston.test
# график "прогноз -- реализация"
plot(yhat, boston.test)
# линия идеального прогноза
abline(0, 1)
# MSE на тестовой выборке
<- mean((yhat - boston.test)^2)
mse.test 25.05 (тыс.долл.).
MSE на тестовой выборке равна
Бэггинг и метод случайного леса=p, поэтому и то, и другое можно построить функцией randomForest().
Рассмотрим более сложные методы улучшения качества дерева. Бэггинг – частный случай случайного леса с m
13 предикторов на каждом шаге (аргумент mtry).
Для начала используем бэггинг, причём возьмём все
# бэггинг с 13 предикторами
set.seed(1)
<- randomForest(medv ~ ., data = Boston, subset = train,
bag.boston mtry = 13, importance = TRUE)
bag.boston##
## Call:
## randomForest(formula = medv ~ ., data = Boston, mtry = 13, importance = TRUE, subset = train)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 13
##
## Mean of squared residuals: 11.15723
## % Var explained: 86.49
# прогноз
= predict(bag.boston, newdata = Boston[-train, ])
yhat.bag
# график "прогноз -- реализация"
plot(yhat.bag, boston.test)
# линия идеального прогноза
abline(0, 1)
# MSE на тестовой
<- mean((yhat.bag - boston.test)^2)
mse.test
mse.test## [1] 13.50808
13.51.
Ошибка на тестовой выборке равна
Можно изменить число деревьев с помощью аргумента ntree.
<- randomForest(medv ~ ., data = Boston, subset = train,
bag.boston mtry = 13, ntree = 25)
# прогноз
<- predict(bag.boston, newdata = Boston[-train, ])
yhat.bag
# MSE на тестовой
<- mean((yhat.bag - boston.test)^2)
mse.test
mse.test## [1] 13.94835
Но, как видно, это только ухудшает прогноз.6 предикторов на каждом шаге.
Теперь попробуем вырастить случайный лес. Берём
# обучаем модель
set.seed(1)
<- randomForest(medv ~ ., data = Boston, subset = train,
rf.boston mtry = 6, importance = TRUE)
# прогноз
<- predict(rf.boston, newdata = Boston[-train, ])
yhat.rf
# MSE на тестовой выборке
<- mean((yhat.rf - boston.test)^2)
mse.test
# важность предикторов
importance(rf.boston) # оценки
## %IncMSE IncNodePurity
## crim 12.132320 986.50338
## zn 1.955579 57.96945
## indus 9.069302 882.78261
## chas 2.210835 45.22941
## nox 11.104823 1044.33776
## rm 31.784033 6359.31971
## age 10.962684 516.82969
## dis 15.015236 1224.11605
## rad 4.118011 95.94586
## tax 8.587932 502.96719
## ptratio 12.503896 830.77523
## black 6.702609 341.30361
## lstat 30.695224 7505.73936
varImpPlot(rf.boston) # графики
11.66, что ниже, чем для бэггинга.
Ошибка по модели случайного леса равна
Бустинг5000 регрессионных деревьев с глубиной 4.
Построим
set.seed(1)
<- gbm(medv ~ ., data = Boston[train, ], distribution = "gaussian",
boost.boston n.trees = 5000, interaction.depth = 4)
# график и таблица относительной важности переменных
summary(boost.boston)
# графики частной зависимости для двух наиболее важных предикторов
par(mfrow = c(1, 2))
plot(boost.boston, i = "rm")
plot(boost.boston, i = "lstat")
# прогноз
<- predict(boost.boston, newdata = Boston[-train, ], n.trees = 5000)
yhat.boost
# MSE на тестовой
<- mean((yhat.boost - boston.test)^2)
mse.test
mse.test## [1] 11.84434
0.2.
Настройку бустинга можно делать с помощью гиперпараметра λ (аргумент shrinkage). Установим его равным
# меняем значение гиперпараметра (lambda) на 0.2 -- аргумент shrinkage
<- gbm(medv ~ ., data = Boston[train, ], distribution = "gaussian",
boost.boston n.trees = 5000, interaction.depth = 4,
shrinkage = 0.2, verbose = F)
# прогноз
<- predict(boost.boston, newdata = Boston[-train, ], n.trees = 5000)
yhat.boost
# MSE а тестовой
<- mean((yhat.boost - boston.test)^2)
mse.test
mse.test## [1] 11.51109
Таким образом, изменив гиперпараметр, мы ещё немного снизили ошибку прогноза.