首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何将两个线性回归预测模型(每个数据帧的子集)合并为数据帧的一列

如何将两个线性回归预测模型(每个数据帧的子集)合并为数据帧的一列
EN

Stack Overflow用户
提问于 2014-07-22 08:05:18
回答 1查看 2.3K关注 0票数 1

我想要构建两个线性回归模型,这些模型基于数据集的两个子集,然后有一个列,其中包含每个子集的预测值。下面是我的数据框架示例:

代码语言:javascript
运行
复制
dat <- read.table(text = " cats birds    wolfs     snakes
 0        3        8         7
 1        3        8         7
 1        1        2         3
 0        1        2         3
 0        1        2         3
 1        6        1         1
 0        6        1         1
 1        6        1         1   ",header = TRUE) 

首先,我建立了两个模型:

代码语言:javascript
运行
复制
# one is for wolfs ~ snakes where cats=0
f0<-lm(wolfs~snakes,data=dat,subset=dat$cats==0)

#the second model is for wolfs ~ snakes where cats=1
f1<-lm(wolfs~snakes,data=dat,subset=dat$cats==1)

然后,我对每个模型进行了预测:

代码语言:javascript
运行
复制
f0_predict<-predict(f0,data=dat,subset=dat$cats==1,type='response')
f1_predict<-predict(f1,data=dat,subset=dat$cats==0,type='response')

这很好,但我无法找到将其插入到原始数据帧的方法,如果是cats==0,我将获得模型对cats==0行的预测值,如果是cat==1,则将获得模型的预测值,而在同一列中,cats==1的名称是: full_prediction。例如,输出应该是(带有伪预测值):

代码语言:javascript
运行
复制
  cats   birds    wolfs     snakes full_prediction
     0        3        8         7        0.6
     1        3        8         7        0.5
     1        1        2         3        0.4
     0        1        2         3        0.3
     0        1        2         3        0.3
     1        6        1         1        0.7
     0        6        1         1        0.1
     1        6        1         1        0.7

如果您查看第6-8行,您可以看到full_prediction的值为cats==1为0.7,cats==0为0.1,您知道如何做这种事情吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2014-07-22 08:10:26

使用splitunsplit

代码语言:javascript
运行
复制
dat.l <- split(dat, dat$cats)

dat.l <- lapply(dat.l, function(x){
  mod <- lm(wolfs~snakes,data=x)
  x$full_prediction <- predict(mod,data=x,type='response')
  return(x)
})

unsplit(dat.l, dat$cats)

输出:

代码语言:javascript
运行
复制
cats birds wolfs snakes full_prediction
1    0     3     8      7       7.5789474
2    1     3     8      7       7.6666667
3    1     1     2      3       3.0000000
4    0     1     2      3       2.6315789
5    0     1     2      3       2.6315789
6    1     6     1      1       0.6666667
7    0     6     1      1       0.1578947
8    1     6     1      1       0.6666667

dplyr解决方案将是:

代码语言:javascript
运行
复制
require(dplyr)
dat %>% 
  group_by(cats) %>%
  do({
    mod <- lm(wolfs~snakes, data = .)
    pred <- predict(mod)
    data.frame(., pred)
  })
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/24881923

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档