当前位置: 移动技术网 > IT编程>开发语言>Java > java实现随机森林RandomForest的示例代码

java实现随机森林RandomForest的示例代码

2019年07月19日  | 移动技术网IT编程  | 我要评论

随机森林是由多棵树组成的分类或回归方法。主要思想来源于bagging算法,bagging技术思想主要是给定一弱分类器及训练集,让该学习算法训练多轮,每轮的训练集由原始训练集中有放回的随机抽取,大小一般跟原始训练集相当,这样依次训练多个弱分类器,最终的分类由这些弱分类器组合,对于分类问题一般采用多数投票法,对于回归问题一般采用简单平均法。随机森林在bagging的基础上,每个弱分类器都是决策树,决策树的生成过程中中,在属性的选择上增加了依一定概率选择属性,在这些属性中选择最佳属性及分割点,传统做法一般是全部属性中去选择最佳属性,这样随机森林有了样本选择的随机性,属性选择的随机性,这样一来增加了每个分类器的差异性、不稳定性及一定程度上避免每个分类器的过拟合(一般决策树有过拟合现象),由此组合分类器增加了最终的泛化能力。下面是代码的简单实现

/**
 * 随机森林 回归问题
 * @author ysh  1208706282
 *
 */
public class randomforest {
  list<sample> msamples;
  list<cart> mcarts;
  double mfeaturerate;
  int mmaxdepth;
  int mminleaf;
  random mrandom;
  /**
   * 加载数据  回归树
   * @param path
   * @param regex
   * @throws exception
   */
  public void loaddata(string path,string regex) throws exception{
    msamples = new arraylist<sample>();
    bufferedreader reader = new bufferedreader(new filereader(path));
    string line = null;
    string splits[] = null;
    sample sample = null;
    while(null != (line=reader.readline())){
      splits = line.split(regex);
      sample = new sample();
      sample.label = double.valueof(splits[0]);
      sample.feature = new arraylist<double>(splits.length-1);
      for(int i=0;i<splits.length-1;i++){
        sample.feature.add(new double(splits[i+1]));
      }
      msamples.add(sample);
    }
    reader.close();
  }
  public void train(int iters){
    mcarts = new arraylist<cart>(iters);
    cart cart = null;
    for(int iter=0;iter<iters;iter++){
      cart = new cart();
      cart.mfeaturerate = mfeaturerate;
      cart.mmaxdepth = mmaxdepth;
      cart.mminleaf = mminleaf;
      cart.mrandom = mrandom;
      list<sample> s = new arraylist<sample>(msamples.size());
      for(int i=0;i<msamples.size();i++){
        s.add(msamples.get(cart.mrandom.nextint(msamples.size())));
      }
      cart.setdata(s);
      cart.train();
      mcarts.add(cart);
      system.out.println("iter: "+iter);
      s = null;
    }
  }
  /**
   * 回归问题简单平均法 分类问题多数投票法
   * @param sample
   * @return
   */
  public double classify(sample sample){
    double val = 0;
    for(cart cart:mcarts){
      val += cart.classify(sample);
    }
    return val/mcarts.size();
  }
  /**
   * @param args
   * @throws exception 
   */
  public static void main(string[] args) throws exception {
    // todo auto-generated method stub
    randomforest forest = new randomforest();
    forest.loaddata("f:/2016-contest/20161001/train_data_1.csv", ",");
    forest.mfeaturerate = 0.8;
    forest.mmaxdepth = 3;
    forest.mminleaf = 1;
    forest.mrandom = new random();
    forest.mrandom.setseed(100);
    forest.train(100);
    
    list<sample> samples = cart.loadtestdata("f:/2016-contest/20161001/valid_data_1.csv", true, ",");
    double sum = 0;
    for(sample s:samples){
      double val = forest.classify(s);
      sum += (val-s.label)*(val-s.label);
      system.out.println(val+" "+s.label);
    }
    system.out.println(sum/samples.size()+" "+sum);
    system.out.println(system.currenttimemillis());
  }

}

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持移动技术网。

如对本文有疑问, 点击进行留言回复!!

相关文章:

验证码:
移动技术网