当前位置: 移动技术网 > IT编程>开发语言>Java > 通过反射注解批量插入数据到DB的实现方法

通过反射注解批量插入数据到DB的实现方法

2019年07月19日  | 移动技术网IT编程  | 我要评论

批量导入思路

最近遇到一个需要批量导入数据问题。后来考虑运用反射做成一个工具类,思路是首先定义注解接口,在bean类上加注解,运行时通过反射获取传入bean的注解,自动生成需要插入db的sql,根据设置的参数值批量提交。不需要写具体的sql,也没有dao的实现,这样一来批量导入的实现就和具体的数据库表彻底解耦。实际批量执行的sql如下:

insert into company_candidate(company_id,user_id,card_id,facebook_id,type,create_time,weight,score) values (?,?,?,?,?,?,?,?) on duplicate key update type=?,weight=?,score=?

第一步,定义注解接口

注解接口table中定义了数据库名和表名。retentionpolicy.runtime表示该注解保存到运行时,因为我们需要在运行时,去读取注解参数来生成具体的sql。

@documented
@retention(retentionpolicy.runtime)
@target(elementtype.type)
public @interface table {
  /**
   * 表名
   * @return
   */
  string tablename() default "";
  /**
   * 数据库名称
   * @return
   */
  string dbname();
}

注解接口tablefield中定义了数据库表名的各个具体字段名称,以及该字段是否忽略(忽略的话就会以数据库表定义默认值填充,db非null字段的注解不允许出现把ignore注解设置为true)。update注解是在主键在db重复时,需要更新的字段。

@documented
@retention(retentionpolicy.runtime)
@target(elementtype.field)
public @interface tablefield {
  /**
   * 对应数据库字段名称
   * @return
   */
  string fieldname() default "";
  /**
   * 是否是主键
   * @return
   */
  boolean pk() default false;
  /**
   * 是否忽略该字段
   * @return
   */
  boolean ignore() default false;
  /**
   * 当数据存在时,是否更新该字段
   * @return
   */
  boolean update() default false;
}

第二步,给bean添加注解

给bean添加注解(为了简洁省略了import和set/get方法以及其他属性),@tablefield(fieldname = "company_id")表示companyid字段对应db表的字段名为"company_id",其中updatetime属性的注解含有ignore=true,表示该属性值会被忽略。另外serialversionuid属性由于没有@tablefield注解,在更新db时也会被忽略。

代码如下:

@table(dbname = "company", tablename = "company_candidate")
public class companycandidatemodel implements serializable{
 private static final long serialversionuid = -1234554321773322135l;
 @tablefield(fieldname = "company_id")
 private int companyid;
 @tablefield(fieldname = "user_id")
 private int userid;
 //名片id
 @tablefield(fieldname = "card_id")
 private int cardid;
 //facebookid
 @tablefield(fieldname = "facebook_id")
 private long facebookid;
  @tablefield(fieldname="type", update = true)
 private int type;
 @tablefield(fieldname = "create_time")
 private date createtime;
 @tablefield(fieldname = "update_time", ignore=true)
 private date updatetime;
 // 权重
  @tablefield(fieldname="weight", update = true)
 private int weight;
 // 分值
  @tablefield(fieldname="score", update = true)
 private double score;

第三步,读取注解的反射工具类

读取第二步bean类的注解的反射工具类。利用反射getannotation(tablefield.class)读取注解信息,为批量sql的拼接最好准备。

gettablebeanfieldmap()方法里生成一个linkedhashmap对象,是为了保证生成插入sql的field顺序,之后也能按同样的顺序给参数赋值,避免错位。getsqlparamfields()方法也类似,是为了给preparedstatement设置参数用。

代码如下:

public class reflectutil {
  /**
   * <class,<表定义field名,bean定义field>>的map缓存
   */
  private static final map<class<?>, map<string field="">> classtablebeanfieldmap = new hashmap<class<?>, map<string field="">>();
  // 用来按顺序填充sql参数,其中存储的field和classtablebeanfieldmap保存同样的顺序,但数量多出on duplicate key update部分field
  private static final map<class<?>, list<field>> sqlparamfieldsmap = new hashmap<class<?>, list<field>>(); 
  private reflectutil(){};
  /**
   * 获取该类上所有@tablefield注解,且没有忽略的字段的map。
   * <br />返回一个有序的linkedhashmap类型
   * <br />其中key为db表中的字段,value为bean类里的属性field对象
   * @param clazz
   * @return
   */
  public static map<string field=""> gettablebeanfieldmap(class<?> clazz) {
   // 从缓存获取
   map<string field=""> fieldsmap = classtablebeanfieldmap.get(clazz);
   if (fieldsmap == null) {
   fieldsmap = new linkedhashmap<string field="">();
      for (field field : clazz.getdeclaredfields()) {// 获得所有声明属性数组的一个拷贝
       tablefield annotation = field.getannotation(tablefield.class);
        if (annotation != null && !annotation.ignore() && !"".equals(annotation.fieldname())) {
          field.setaccessible(true);// 方便后续获取私有域的值
         fieldsmap.put(annotation.fieldname(), field);
        }
  }
      // 放入缓存
      classtablebeanfieldmap.put(clazz, fieldsmap);
   }
   return fieldsmap;
  }
  /**
   * 获取该类上所有@tablefield注解,且没有忽略的字段的map。on duplicate key update后需要更新的字段追加在list最后,为了填充参数值准备
   * <br />返回一个有序的arraylist类型
   * <br />其中key为db表中的字段,value为bean类里的属性field对象
   * @param clazz
   * @return
   */
  public static list<field> getsqlparamfields(class<?> clazz) {
   // 从缓存获取
   list<field> sqlparamfields = sqlparamfieldsmap.get(clazz);
   if (sqlparamfields == null) {
   // 获取所有参数字段
     map<string field=""> fieldsmap = gettablebeanfieldmap(clazz);
   sqlparamfields = new arraylist<field>(fieldsmap.size() * 2);
     // sql后段on duplicate key update需要更新的字段
     list<field> updateparamfields = new arraylist<field>();
   iterator<entry<string field="">> iter = fieldsmap.entryset().iterator();
   while (iter.hasnext()) {
    entry<string field=""> entry = (entry<string field="">) iter.next();
    field field = entry.getvalue();
    // insert语句对应sql参数字段
    sqlparamfields.add(field);
        // on duplicate key update后面语句对应sql参数字段
        tablefield annotation = field.getannotation(tablefield.class);
    if (annotation != null && !annotation.ignore() && annotation.update()) {
    updateparamfields.add(field);
    }
   }
   sqlparamfields.addall(updateparamfields);
      // 放入缓存
   sqlparamfieldsmap.put(clazz, sqlparamfields);
   }
   return sqlparamfields;
  }
  /**
   * 获取表名,对象中使用@table的tablename来标记对应数据库的表名,若未标记则自动将类名转成小写
   * 
   * @param clazz
   * @return
   */
  public static string gettablename(class<?> clazz) {
    table table = clazz.getannotation(table.class);
    if (table != null && table.tablename() != null && !"".equals(table.tablename())) {
      return table.tablename();
    }
    // 当未配置@table的tablename,自动将类名转成小写
    return clazz.getsimplename().tolowercase();
  }
  /**
   * 获取数据库名,对象中使用@table的dbname来标记对应数据库名
   * @param clazz
   * @return
   */
  public static string getdbname(class<?> clazz) {
    table table = clazz.getannotation(table.class);
    if (table != null && table.dbname() != null) {
      // 注解@table的dbname
      return table.dbname();
    }
    return "";
  }

第四步,生成sql语句

根据上一步的方法,生成真正执行的sql语句。

insert into company_candidate(company_id,user_id,card_id,facebook_id,type,create_time,weight,score) values (?,?,?,?,?,?,?,?) on duplicate key update type=?,weight=?,score=?

代码如下:

public class sqlutil {
  private static final char comma = ',';
  private static final char brackets_begin = '(';
  private static final char brackets_end = ')';
  private static final char question_mark = '?';
  private static final char equal_sign = '=';
  private static final string insert_begin = "insert into ";
  private static final string insert_valurs = " values ";
  private static final string duplicate_update = " on duplicate key update ";
  // 数据库表名和对应insertupdatesql的缓存
  private static final map<string string=""> tableinsertsqlmap = new hashmap<string string="">();
  /**
   * 获取插入的sql语句,对象中使用@tablefield的fieldname来标记对应数据库的列名,若未标记则忽略
   * 必须标记@tablefield(fieldname = "company_id")注解
   * @param tablename
   * @param fieldsmap
   * @return
   * @throws exception
   */
  public static string getinsertsql(string tablename, map<string field=""> fieldsmap) throws exception {
   string sql = tableinsertsqlmap.get(tablename);
   if (sql == null) {
   stringbuilder sbsql = new stringbuilder(300).append(insert_begin);
   stringbuilder sbvalue = new stringbuilder(insert_valurs);
   stringbuilder sbupdate = new stringbuilder(100).append(duplicate_update);
   sbsql.append(tablename);
   sbsql.append(brackets_begin);
   sbvalue.append(brackets_begin);
   iterator<entry<string field="">> iter = fieldsmap.entryset().iterator();
   while (iter.hasnext()) {
    entry<string field=""> entry = (entry<string field="">) iter.next();
    string tablefieldname = entry.getkey();
    field field = entry.getvalue();
    sbsql.append(tablefieldname);
    sbsql.append(comma);
    sbvalue.append(question_mark);
    sbvalue.append(comma);
    tablefield tablefield = field.getannotation(tablefield.class);
    if (tablefield != null && tablefield.update()) {
    sbupdate.append(tablefieldname);
    sbupdate.append(equal_sign);
    sbupdate.append(question_mark);
    sbupdate.append(comma);
    }
   }
   // 去掉最后的逗号
   sbsql.deletecharat(sbsql.length() - 1);
   sbvalue.deletecharat(sbvalue.length() - 1);
   sbsql.append(brackets_end);
   sbvalue.append(brackets_end);
   sbsql.append(sbvalue);
   if (!sbupdate.tostring().equals(duplicate_update)) {
    sbupdate.deletecharat(sbupdate.length() - 1);
    sbsql.append(sbupdate);
   }
   sql = sbsql.tostring();
   tableinsertsqlmap.put(tablename, sql);
   }
    return sql;
  }

第五步,批量sql插入实现

从连接池获取connection,sqlutil.getinsertsql()获取执行的sql语句,根据sqlparamfields来为preparedstatement填充参数值。当循环的值集合到达batchnum时就提交一次。

代码如下:

  /**
   * 批量插入,如果主键一致则更新。结果返回更新记录条数<br />
   * @param datalist
   *      要插入的对象list
   * @param batchnum
   *      每次批量插入条数
   * @return 更新记录条数
   */
  public int batchinsertsql(list<? extends object> datalist, int batchnum) throws exception {
   if (datalist == null || datalist.isempty()) {
   return 0;
   }
    class<?> clazz = datalist.get(0).getclass();
    string tablename = reflectutil.gettablename(clazz);
    string dbname = reflectutil.getdbname(clazz);
    connection connnection = null;
    preparedstatement preparedstatement = null;
    // 获取所有需要更新到db的属性域
    map<string field=""> fieldsmap = reflectutil.gettablebeanfieldmap(datalist.get(0).getclass());
    // 根据需要插入更新的字段生成sql语句
    string sql = sqlutil.getinsertsql(tablename, fieldsmap);
    log.debug("prepare to start batch operation , sql = " + sql + " , dbname = " + dbname);
    // 获取和sql语句同样顺序的填充参数fields
    list<field> sqlparamfields = reflectutil.getsqlparamfields(datalist.get(0).getclass());
    // 最终更新结果条数
    int result = 0;
    int parameterindex = 1;// sql填充参数开始位置为1
    // 执行错误的对象
    list<object> errorsrecords = new arraylist</object><object>(batchnum);//指定数组大小
    // 计数器,batchnum提交后内循环累计次数
    int innercount = 0;
    try {
      connnection = this.getconnection(dbname);
      // 设置非自动提交
      connnection.setautocommit(false);
      preparedstatement = connnection.preparestatement(sql);
      // 当前操作的对象
      object object = null;
      int totalrecordcount = datalist.size();
      for (int current = 0; current < totalrecordcount; current++) {
        innercount++;
        object = datalist.get(current);
       parameterindex = 1;// 开始参数位置为1
       for(field field : sqlparamfields) {
       // 放入insert语句对应sql参数
          preparedstatement.setobject(parameterindex++, field.get(object));
       }
       errorsrecords.add(object);
        preparedstatement.addbatch();
        // 达到批量次数就提交一次
        if (innercount >= batchnum || current >= totalrecordcount - 1) {
          // 执行batch操作
          preparedstatement.executebatch();
          preparedstatement.clearbatch();
          // 提交
          connnection.commit();
          // 记录提交成功条数
          result += innercount;
          innercount = 0;
          errorsrecords.clear();
        }
        // 尽早让gc回收
        datalist.set(current, null);
      }
      return result;
    } catch (exception e) {
      // 失败后处理方法
      callbackimpl.getinstance().exectuer(sql, errorsrecords, e);
      batchdbexception be = new batchdbexception("batch run error , dbname = " + dbname + " sql = " + sql, e);
      be.initcause(e);
      throw be;
    } finally {
      // 关闭
      if (preparedstatement != null) {
       preparedstatement.clearbatch();
        preparedstatement.close();
      }
      if (connnection != null)
        connnection.close();
    }
  }

最后,批量工具类使用例子

在mysql下的开发环境下测试,5万条数据大概13秒。

list<companycandidatemodel> updatedatalist = new arraylist<companycandidatemodel>(50000);
// ...为updatedatalist填充数据
int result = batchjdbctemplate.batchinsertsql(updatedatalist, 50);

总结

以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作具有一定的参考学习价值,谢谢大家对移动技术网的支持。如果你想了解更多相关内容请查看下面相关链接

如对本文有疑问, 点击进行留言回复!!

相关文章:

验证码:
移动技术网