Table of Contents
原理
ORM 实现
1. 通过注解来将 Java Bean 和数据库字段关联
2. 反射工具类
3. 简单的 model 示例
4. 注解解析
5. 数据库操作
6. 结合反射实现查询操作
使用动态代理实现 @Query @Select 类似功能
1. 动态代理
2. 注解
3. 表设计
4. model
5. repository
7. 大体流程
8. 代理使用
9. 将生成代理放入 Spring IOC 容器中
10. invoke方法处理
在这篇文章中,主要用到了注解、反射、动态代理、正则、Jexl3 表达式来实现一个 ORM。这些技术在大多数框架中都会用到,本文只是简单地使用一下,而不做原理分析。
在使用 ORM 框架时,我们可以像操作对象一样操作数据的存储。这是怎么实现的呢?我们知道,数据库认识的是 SQL 语句,并不认识 Java Bean!同时,我们在使用 ORM 时,需要根据 ORM 框架的规定定义我们的 Bean,这又是为什么?
这是因为 ORM 为我们提供了将对象操作转化为对应 SQL 语句的能力。例如 save(bean) 需要转化成一条 insert 语句,update(bean) 则需要转化成对应的 update 语句。
通常 insert 语句的格式为:
insert into 表名 (列名) values (值)
update 语句的格式为:
update 表名 set 列名 = 值 where id = 值
从上面的格式可以看出,如果我们能从对象中得出表名、列名和值,就可以写一个简单的 ORM 框架。
上一篇文章介绍了注解,以及自定义注解和解析注解的方法。通过使用注解,我们可以完成一个简单的 ORM。
要想实现对数据库的操作,我们必须知道数据表名以及表中字段的名称和类型。正如 Hibernate 使用注解标识 model 与数据库的映射关系一样,这里我使用了 Java Persistence API(JPA)中的注解。
注解说明
注解 | 作用 | 使用说明 |
---|---|---|
@Entity | 标记这是一个实体 | 标注在类上,标明该类可以被 ORM 处理 |
@Table | 标记实体对应的表 | 标注在类上,标明该类对应的数据库表名 |
@Id | 标记该字段是id | 标注在字段上,标明该字段为 id |
@Column | 标记该字段对应的列信息 | 标记在字段上,标明对应的列信息,主要对应列名 |
字段属性类 DBColumnInfo:用来存储对象中字段与数据表列的对应关系
1package com.zyndev.tool.fastsql.core;
2
3import lombok.Data;
4
5import javax.persistence.GenerationType;
6
7/**
8 * The type Db column info.
9 *
10 * @author 张瑀楠 zyndev@gmail.com
11 * @version 0.0.1
12 */
13@Data
14class DBColumnInfo {
15
16 /**
17 * (Optional) The primary key generation strategy
18 * that the persistence provider must use to
19 * generate the annotated entity primary key.
20 */
21 private GenerationType strategy = GenerationType.AUTO;
22
23 private String fieldName;
24
25 /**
26 * The name of the column
27 */
28 private String columnName;
29
30 /**
31 * Whether the column is a unique key.
32 */
33 private boolean unique;
34
35 /**
36 * Whether the database column is nullable.
37 */
38 private boolean nullable = true;
39
40 /**
41 * Whether the column is included in SQL INSERT
42 */
43 private boolean insertAble = true;
44
45 /**
46 * Whether the column is included in SQL UPDATE
47 */
48 private boolean updatable = true;
49
50 /**
51 * The SQL fragment that is used when
52 * generating the DDL for the column.
53 */
54 private String columnDefinition;
55
56 /**
57 * The name of the table that contains the column.
58 * If absent the column is assumed to be in the primary table.
59 */
60 private String table;
61
62 /**
63 * (Optional) The column length. (Applies only if a
64 * string-valued column is used.)
65 */
66 private int length = 255;
67
68 private boolean id = false;
69
70}
反射工具类 BeanReflectionUtil 提供一些常用的反射操作。
通过反射,我们可以动态地获取一个类所有成员变量的信息,并为这些变量取值或赋值。
1package com.zyndev.tool.fastsql.util;
2
3
4import java.lang.reflect.Field;
5import java.lang.reflect.Method;
6import java.util.ArrayList;
7import java.util.HashMap;
8import java.util.List;
9import java.util.Map;
10
11
/**
 * Reflection helpers used by the ORM layer: read and write bean fields, create instances
 * and invoke methods by name.
 *
 * @author yunan.zhang zyndev@gmail.com
 * @version 0.0.3
 * @date 2017年12月26日 16点36分
 */
public class BeanReflectionUtil {

    /**
     * Reads a field of {@code obj} regardless of its access modifier.
     * Kept for backward compatibility; behaves exactly like {@link #getFieldValue(Object, String)}.
     *
     * @param obj          the object to read from
     * @param propertyName the field name
     * @return the field's current value
     * @throws Exception if the field does not exist or cannot be read
     */
    public static Object getPrivatePropertyValue(Object obj, String propertyName) throws Exception {
        return getFieldValue(obj, propertyName);
    }

    /**
     * Reads a public static field.
     *
     * @param className fully qualified name of the class that declares or inherits the field
     * @param fieldName name of the public static field
     * @return the field's current value
     * @throws Exception if the class or field cannot be found or read
     */
    public static Object getStaticFieldValue(String className, String fieldName) throws Exception {
        Field field = Class.forName(className).getField(fieldName);
        // The receiver argument is ignored for static fields; null makes that explicit.
        return field.get(null);
    }

    /**
     * Reads the value of {@code fieldName} on {@code obj}, regardless of its access modifier.
     * The field may be declared on obj's class or on any superclass; the most specific
     * declaration wins.
     *
     * @param obj       the object to read from
     * @param fieldName the field name
     * @return the field's current value
     * @throws NoSuchFieldException if no class in obj's hierarchy declares the field
     * @throws NullPointerException if {@code obj} is null
     * @throws Exception            if the field cannot be read
     */
    public static Object getFieldValue(Object obj, String fieldName) throws Exception {
        java.util.Objects.requireNonNull(obj, "obj");
        Field field = findField(obj.getClass(), fieldName);
        field.setAccessible(true);
        return field.get(obj);
    }

    /**
     * Invokes a public instance method by name.
     *
     * <p>Parameter types are taken from the runtime classes of {@code args}, so null elements
     * are not supported and methods declaring primitive parameters will not be matched.
     *
     * @param owner      the receiver
     * @param methodName the method name
     * @param args       the arguments; {@code null} is treated as "no arguments"
     * @return the method's return value
     * @throws Exception if the method cannot be found or invoked
     */
    public Object invokeMethod(Object owner, String methodName, Object[] args) throws Exception {
        Object[] actualArgs = args == null ? new Object[0] : args;
        Method method = owner.getClass().getMethod(methodName, argumentTypes(actualArgs));
        return method.invoke(owner, actualArgs);
    }

    /**
     * Invokes a public static method by name.
     *
     * <p>Same parameter-type matching rules as {@link #invokeMethod(Object, String, Object[])}.
     *
     * @param className  fully qualified class name
     * @param methodName the method name
     * @param args       the arguments; {@code null} is treated as "no arguments"
     * @return the method's return value
     * @throws Exception if the class or method cannot be found or invoked
     */
    public Object invokeStaticMethod(String className, String methodName, Object[] args) throws Exception {
        Object[] actualArgs = args == null ? new Object[0] : args;
        Method method = Class.forName(className).getMethod(methodName, argumentTypes(actualArgs));
        return method.invoke(null, actualArgs);
    }

    /**
     * Creates an instance of the named class through its public no-arg constructor.
     *
     * @param className fully qualified class name
     * @return the new instance
     * @throws Exception if the class cannot be loaded or instantiated
     */
    public static Object newInstance(String className) throws Exception {
        return newInstance(Class.forName(className));
    }

    /**
     * Creates an instance of {@code clazz} through its public no-arg constructor.
     *
     * @param clazz the class to instantiate
     * @return the new instance
     * @throws Exception if the class has no public no-arg constructor or instantiation fails
     */
    public static Object newInstance(Class<?> clazz) throws Exception {
        return clazz.getConstructor().newInstance();
    }

    /**
     * Returns the fields declared directly on the named class (any access modifier,
     * inherited fields excluded).
     *
     * @param className fully qualified class name
     * @return the declared fields
     * @throws ClassNotFoundException if the class cannot be loaded
     */
    public static Field[] getBeanDeclaredFields(String className) throws ClassNotFoundException {
        return Class.forName(className).getDeclaredFields();
    }

    /**
     * Returns the public methods of the named class, including inherited ones.
     *
     * <p>NOTE(review): despite its name this uses {@link Class#getMethods()}, not
     * {@code getDeclaredMethods()}; the behaviour is kept because callers may rely on it.
     *
     * @param className fully qualified class name
     * @return the public methods
     * @throws ClassNotFoundException if the class cannot be loaded
     */
    public static Method[] getBeanDeclaredMethods(String className) throws ClassNotFoundException {
        return Class.forName(className).getMethods();
    }

    /**
     * Best-effort copy of same-named instance fields from {@code source} to {@code target}.
     *
     * <p>Static fields are skipped. Static finals such as {@code serialVersionUID} cannot be
     * written reflectively, and copying class state is not part of copying an object.
     * A field whose value is not assignable to the target field's type (for example, null into
     * a primitive) is skipped rather than aborting the copy. Any other failure is printed and
     * ends the copy.
     *
     * @param source the object to read from
     * @param target the object to write to
     */
    public static void copyFields(Object source, Object target) {
        try {
            Map<String, Field> targetFields = new HashMap<>();
            for (Field field : target.getClass().getDeclaredFields()) {
                targetFields.put(field.getName(), field);
            }
            for (Field sourceField : source.getClass().getDeclaredFields()) {
                Field targetField = targetFields.get(sourceField.getName());
                if (targetField == null || isStatic(sourceField) || isStatic(targetField)) {
                    continue;
                }
                sourceField.setAccessible(true);
                Object value = sourceField.get(source);
                targetField.setAccessible(true);
                try {
                    targetField.set(target, value);
                } catch (IllegalArgumentException incompatibleType) {
                    // Type mismatch between same-named fields: leave this target field untouched.
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Finds {@code fieldName} on {@code type} or the nearest superclass that declares it. */
    private static Field findField(Class<?> type, String fieldName) throws NoSuchFieldException {
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // Not declared at this level; keep walking up the hierarchy.
            }
        }
        throw new NoSuchFieldException(type.getName() + "." + fieldName);
    }

    /** Runtime classes of {@code args}, used as the parameter-type signature for lookup. */
    private static Class<?>[] argumentTypes(Object[] args) {
        Class<?>[] types = new Class<?>[args.length];
        for (int i = 0; i < args.length; i++) {
            types[i] = args[i].getClass();
        }
        return types;
    }

    /** Whether {@code field} is static. */
    private static boolean isStatic(Field field) {
        return java.lang.reflect.Modifier.isStatic(field.getModifiers());
    }
}
1package com.zyndev.tool.fastsql;
2
3import javax.persistence.*;
4
5/**
6 * @author 张瑀楠 zyndev@gmail.com
7 * @date 2017/11/30 下午11:21
8 */
9@Data
10@Entity
11@Table(name = "tb_student")
12public class Student {
13
14 @Id
15 @Column
16 private Integer id;
17
18 @Column
19 private String name;
20
21 @Column(updatable = true, insertable = true, nullable = false)
22 private Integer age;
23
24}
对实体对象上的注解进行解析,得到类与表、字段与列的对应关系
1package com.zyndev.tool.fastsql.core;
2
3
4import com.zyndev.tool.fastsql.util.StringUtil;
5
6import javax.persistence.Column;
7import javax.persistence.Id;
8import javax.persistence.Table;
9import java.lang.reflect.Field;
10import java.util.ArrayList;
11import java.util.HashMap;
12import java.util.List;
13import java.util.Map;
14
15
16/**
17 * The type Annotation parser.
18 * <p>
19 * 注解解析工具
20 *
21 * @author yunan.zhang zyndev@gmail.com
22 * @version 0.0.3
23 * @since 2017-12-26 16:59:07
24 */
25public class AnnotationParser {
26
27 /**
28 * 存储类名和数据库表名的关系
29 * 使用三个cache 主要为了减少反射的次数,提高效率
30 */
31 private final static Map<String, String> tableNameCache = new HashMap<>(30);
32
33 /**
34 * 存储类名和数据库列的关联关系
35 */
36 private final static Map<String, String> tableAllColumnNameCache = new HashMap<>(30);
37
38 /**
39 * 存储类名和对应的数据库全列名的关系
40 */
41 private final static Map<String, List<DBColumnInfo>> tableAllDBColumnCache = new HashMap<>(30);
42
43 /**
44 * Gets table name.
45 * 获得表名
46 * 判断是否有@Table注解,如果有得到注解的name,如果name为空,则使用类名做为表名
47 * 如果没有@Table返回null
48 *
49 * @param <E> the type parameter
50 * @param entity the entity
51 * @return the table name
52 */
53 public static <E> String getTableName(E entity) {
54 String tableName = tableNameCache.get(entity.getClass().getName());
55 if (tableName == null) {
56 Table table = entity.getClass().getAnnotation(Table.class);
57 if (table != null && StringUtil.isNotBlank(table.name())) {
58 tableName = table.name();
59 } else {
60 tableName = entity.getClass().getSimpleName();
61 }
62 tableNameCache.put(entity.getClass().getName(), tableName);
63 }
64 return tableName;
65 }
66
67 /**
68 * Gets all db column info.
69 */
70 public static <E> List<DBColumnInfo> getAllDBColumnInfo(E entity) {
71 List<DBColumnInfo> dbColumnInfoList = tableAllDBColumnCache.get(entity.getClass().getName());
72 if (dbColumnInfoList == null) {
73 dbColumnInfoList = new ArrayList<>();
74 DBColumnInfo dbColumnInfo;
75 Field[] fields = entity.getClass().getDeclaredFields();
76 for (Field field : fields) {
77 Column column = field.getAnnotation(Column.class);
78 if (column != null) {
79 dbColumnInfo = new DBColumnInfo();
80 if (StringUtil.isBlank(column.name())) {
81 dbColumnInfo.setColumnName(field.getName());
82 } else {
83 dbColumnInfo.setColumnName(column.name());
84 }
85 if (null != field.getAnnotation(Id.class)) {
86 dbColumnInfo.setId(true);
87 }
88 dbColumnInfo.setFieldName(field.getName());
89 dbColumnInfoList.add(dbColumnInfo);
90 }
91 }
92 tableAllDBColumnCache.put(entity.getClass().getName(), dbColumnInfoList);
93 }
94 return dbColumnInfoList;
95 }
96
97 /**
98 * 返回表字段的所有字段 column1,column2,column3
99 *
100 * @param <E> the type parameter
101 * @param entity the entity
102 * @return string
103 */
104 public static <E> String getTableAllColumn(E entity) {
105
106 String allColumn = tableAllColumnNameCache.get(entity.getClass().getName());
107 if (allColumn == null) {
108 List<DBColumnInfo> dbColumnInfoList = getAllDBColumnInfo(entity);
109 StringBuilder allColumnsInfo = new StringBuilder();
110 int i = 1;
111 for (DBColumnInfo dbColumnInfo : dbColumnInfoList) {
112 allColumnsInfo.append(dbColumnInfo.getColumnName());
113 if (i != dbColumnInfoList.size()) {
114 allColumnsInfo.append(",");
115 }
116 i++;
117 }
118 allColumn = allColumnsInfo.toString();
119 tableAllColumnNameCache.put(entity.getClass().getName(), allColumn);
120 }
121 return allColumn;
122
123 }
124}
数据库交互使用 Spring 提供的 JdbcTemplate,这里没有再自己写一套 DBUtil。
保存一个entity
保存操作相对简单,这里主要是将 entity 转换为 insert 语句
1/**
2 * Save int.
3 * @param entity the entity
4 * @return the int
5 */
6@Override
7public int save(Object entity) {
8 try {
9 String tableName = AnnotationParser.getTableName(entity);
10 StringBuilder property = new StringBuilder();
11 StringBuilder value = new StringBuilder();
12 List<Object> propertyValue = new ArrayList<>();
13 List<DBColumnInfo> dbColumnInfoList = AnnotationParser.getAllDBColumnInfo(entity);
14
15 for (DBColumnInfo dbColumnInfo : dbColumnInfoList) {
16 if (dbColumnInfo.isId() || !dbColumnInfo.isInsertAble()) {
17 continue;
18 }
19 // 不为null
20 Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
21 if (o != null) {
22 property.append(",").append(dbColumnInfo.getColumnName());
23 value.append(",").append("?");
24 propertyValue.add(o);
25 }
26 }
27 String sql = "insert into " + tableName + "(" + property.toString().substring(1) + ") values(" + value.toString().substring(1) + ")";
28 return this.getJdbcTemplate().update(sql, propertyValue.toArray());
29 } catch (Exception e) {
30 e.printStackTrace();
31 }
32 return 0;
33}
更新操作
更新操作相对于保存来说,多了一步确定 where 条件的操作
1/**
2 * Update int.
3 *
4 * @param entity the entity
5 * @param ignoreNull the ignore null
6 * @param columns the columns
7 * @return the int
8 */
9@Override
10public int update(Object entity, boolean ignoreNull, String... columns) {
11 try {
12 String tableName = AnnotationParser.getTableName(entity);
13 StringBuilder property = new StringBuilder();
14 StringBuilder where = new StringBuilder();
15 List<Object> propertyValue = new ArrayList<>();
16 List<Object> wherePropertyValue = new ArrayList<>();
17 List<DBColumnInfo> dbColumnInfos = AnnotationParser.getAllDBColumnInfo(entity);
18 for (DBColumnInfo dbColumnInfo : dbColumnInfos) {
19
20 Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
21 if (dbColumnInfo.isId()) {
22 where.append(" and ").append(dbColumnInfo.getColumnName()).append(" = ? ");
23 wherePropertyValue.add(o);
24 } else if (ignoreNull || o != null) {
25 property.append(",").append(dbColumnInfo.getColumnName()).append("=?");
26 propertyValue.add(o);
27 }
28 }
29
30 if (wherePropertyValue.isEmpty()) {
31 throw new IllegalArgumentException("更新表 [" + tableName + "] 无法找到id, 请求数据:" + entity);
32 }
33
34 String sql = "update " + tableName + " set " + property.toString().substring(1) + " where " + where.toString().substring(5);
35 propertyValue.addAll(wherePropertyValue);
36 return this.getJdbcTemplate().update(sql, propertyValue.toArray());
37 } catch (Exception e) {
38 e.printStackTrace();
39 }
40 return 0;
41}
删除操作
相对 update 简单一点
1/**
2 * Delete int.
3 * <p>根据id 删除对应的数据</p>
4 *
5 * @param entity the entity
6 * @return the int
7 */
8@Override
9public int delete(Object entity) {
10 try {
11 String tableName = AnnotationParser.getTableName(entity);
12 StringBuilder where = new StringBuilder(" 1=1 ");
13 List<Object> whereValue = new ArrayList<>(5);
14 List<DBColumnInfo> dbColumnInfos = AnnotationParser.getAllDBColumnInfo(entity);
15 for (DBColumnInfo dbColumnInfo : dbColumnInfos) {
16 if (dbColumnInfo.isId()) {
17 Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
18 if (null != o) {
19 whereValue.add(o);
20 }
21 where.append(" and `").append(dbColumnInfo.getColumnName()).append("` = ? ");
22 }
23 }
24
25 if (whereValue.size() == 0) {
26 throw new IllegalStateException("delete " + tableName + " id 无对应值,不能删除");
27 }
28 String sql = "delete from " + tableName + " where " + where.toString();
29 return this.getJdbcTemplate().update(sql, whereValue);
30 } catch (Exception e) {
31 e.printStackTrace();
32 }
33 return 0;
34}
通过上面的示例,就可以简单地实现一个 ORM。为了更好地使用,我们还需要提供自己编写 SQL 的方案。
这里直接使用 JDK 动态代理来实现。
在 Java Persistence API 中没有我们需要的 @Query 和 @Param,这里我们自定义这两个注解;同时,为了让 Query 支持返回主键,再定义一个 ReturnGeneratedKey 注解。
Query.java
1package com.zyndev.tool.fastsql.annotation;
2
3import java.lang.annotation.ElementType;
4import java.lang.annotation.Retention;
5import java.lang.annotation.RetentionPolicy;
6import java.lang.annotation.Target;
7
/**
 * Declares the SQL statement a repository method executes (查询操作 sql).
 *
 * <p>Applicable to methods only and retained at runtime, so the repository proxy can read it
 * through reflection.
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.1
 * @since 2017/12/22 17:26
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Query {

    /**
     * The SQL statement to execute (sql 语句).
     */
    String value();
}
Param.java
1package com.zyndev.tool.fastsql.annotation;
2
3import java.lang.annotation.ElementType;
4import java.lang.annotation.Retention;
5import java.lang.annotation.RetentionPolicy;
6import java.lang.annotation.Target;
7
/**
 * Binds a repository method parameter to a named placeholder in the {@code @Query} SQL.
 *
 * <p>For example, a parameter annotated {@code @Param("id")} supplies the value of {@code :id}.
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.1
 * @since 2017/12/22 17:29
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface Param {

    /**
     * Name of the SQL placeholder (without the leading colon) whose value this parameter
     * provides (指出这个值是 SQL 语句中哪个参数的值,使用命名参数).
     */
    String value();
}
ReturnGeneratedKey.java
1package com.zyndev.tool.fastsql.annotation;
2
3
4import java.lang.annotation.*;
5
6
/**
 * Marks an insert {@code @Query} method whose return value should be the database-generated
 * primary key rather than the affected-row count (返回主键).
 *
 * @author 张瑀楠 zyndev@gmail.com
 * @version 0.0.3
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface ReturnGeneratedKey {
}
-- Sample table used by the UserRepository examples.
-- utf8mb4 instead of utf8: MySQL's "utf8" is the 3-byte utf8mb3 encoding and cannot store
-- 4-byte characters (e.g. emoji) that may appear in user-supplied names.
-- NOTE(review): AUTO_INCREMENT=3 looks like a leftover from SHOW CREATE TABLE output; new ids start at 3.
CREATE TABLE `tb_user` (
  `id` int(11) NOT NULL AUTO_INCREMENT,
  `uid` varchar(40) DEFAULT NULL,
  `account_name` varchar(40) DEFAULT NULL,
  `nick_name` varchar(23) DEFAULT NULL,
  `password` varchar(30) DEFAULT NULL,
  `phone` varchar(16) DEFAULT NULL,
  `register_time` timestamp NULL DEFAULT NULL,
  `update_time` timestamp NULL DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=3 DEFAULT CHARSET=utf8mb4;
这里使用一个 User.java 作为例子:
User.java
1package com.zyndev.tool.fastsql.repository;
2
3import lombok.*;
4
5import java.io.Serializable;
6import java.util.Date;
7
8/**
9 * @author 张瑀楠 zyndev@gmail.com
10 * @version 1.0
11 * @date 2017-12-27 15:04:13
12 */
13@Data
14public class User implements Serializable {
15
16 private static final long serialVersionUID = 1L;
17
18 private Integer id;
19
20 private String uid;
21
22 private String accountName;
23
24 private String nickName;
25
26 private String password;
27
28 private String phone;
29
30 private Date registerTime;
31
32 private Date updateTime;
33
34}
package com.zyndev.tool.fastsql.repository;

import com.zyndev.tool.fastsql.annotation.Param;
import com.zyndev.tool.fastsql.annotation.Query;
import com.zyndev.tool.fastsql.annotation.ReturnGeneratedKey;
import org.springframework.stereotype.Repository;

import java.util.List;

/**
 * Example repository for table {@code tb_user}.
 *
 * <p>There is no implementing class: a dynamic proxy executes each method's {@link Query} SQL.
 * Placeholders are either positional ({@code ?1} is the first argument) or named
 * ({@code :id}, {@code :user.accountName}), with names supplied through {@link Param}.
 *
 * @version 1.0
 * @author 张瑀楠 zyndev@gmail.com
 * @date 2017 /12/22 18:13
 */
@Repository
public interface UserRepository {

    /** Total number of users. */
    @Query("select count(*) from tb_user")
    Integer getCount();

    /** Deletes the user with the given id; {@code true} when at least one row was removed. */
    @Query("delete from tb_user where id = ?1")
    Boolean deleteById(int id);

    // NOTE(review): matching on the stored password value implies plaintext storage; demo only.
    /** Number of users whose password equals {@code password}. */
    @Query("select count(*) from tb_user where password = ?1 ")
    int getCountByPassword(@Param("password") String password);

    /** {@code uid} of the user whose password equals {@code password}. */
    @Query("select uid from tb_user where password = ?1 ")
    String getUidByPassword(@Param("password") String password);

    /** The user with the given id. */
    @Query("select * from tb_user where id = :id ")
    User getUserById(@Param("id") Integer id);

    /** All users with the given account name. */
    @Query("select * " +
            " from tb_user " +
            " where account_name = :accountName ")
    List<User> getUserByAccountName(@Param("accountName") String accountName);

    /** Inserts {@code user} with an explicit id; returns the affected-row count. */
    @Query("insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time) " +
            "values(:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )")
    int saveUser(@Param("id") Integer id, @Param("user") User user);

    /** Inserts {@code user} and returns the database-generated id. */
    @ReturnGeneratedKey
    @Query("insert into tb_user(account_name, password, uid, nick_name, register_time, update_time) " +
            "values(:user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )")
    int saveUser(@Param("user") User user);
}
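在上面,我们已经完成了一些准备工作,包括:自定义注解(@Query、@Param、@ReturnGeneratedKey)、表设计、model 以及 repository。
接下来,我们看看如何将这些整合在一起。
大致流程:
1. 扫描带有 @Repository 注解的接口,并为其生成动态代理;
2. 将生成的代理放入 Spring IOC 容器中;
3. 代理方法执行时,解析方法上的 @Query、@Param、@ReturnGeneratedKey 注解,并取得方法的返回值类型。

FacadeProxy.java
为 Repository 生成代理。当代理方法执行时,会回调 invoke 方法;invoke 中的逻辑写在 StatementParser.java 中,以防止类的功能过大。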
1package com.zyndev.tool.fastsql.core;
2
3import org.apache.commons.logging.Log;
4import org.apache.commons.logging.LogFactory;
5
6import java.lang.reflect.InvocationHandler;
7import java.lang.reflect.Method;
8import java.lang.reflect.Proxy;
9
10/**
11 * 生成代理
12 *
13 * @author 张瑀楠 zyndev@gmail.com
14 * @version 0.0.1
15 * @since 2017 /12/23 上午12:40
16 */
17@SuppressWarnings("unchecked")
18public class FacadeProxy implements InvocationHandler {
19
20 private final static Log logger = LogFactory.getLog(FacadeProxy.class);
21
22 @Override
23 public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
24 return StatementParser.invoke(proxy, method, args);
25 }
26
27 /**
28 * New mapper proxy t.
29 */
30 protected static <T> T newMapperProxy(Class<T> mapperInterface) {
31 logger.info(" 生成代理:" + mapperInterface.getName());
32 ClassLoader classLoader = mapperInterface.getClassLoader();
33 Class<?>[] interfaces = new Class[]{mapperInterface};
34 FacadeProxy proxy = new FacadeProxy();
35 return (T) Proxy.newProxyInstance(classLoader, interfaces, proxy);
36 }
37}
在这里使用了 BeanFactoryAware,关于这部分内容会单独写一篇文章,这里不再详细说明。
1package com.zyndev.tool.fastsql.core;
2
3import com.zyndev.tool.fastsql.util.ClassScanner;
4import com.zyndev.tool.fastsql.util.StringUtil;
5import org.springframework.beans.BeansException;
6import org.springframework.beans.factory.BeanFactory;
7import org.springframework.beans.factory.BeanFactoryAware;
8import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
9import org.springframework.beans.factory.support.BeanDefinitionRegistry;
10import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
11import org.springframework.core.type.AnnotationMetadata;
12import org.springframework.stereotype.Repository;
13
14import java.io.IOException;
15import java.util.Set;
16
17/**
18 * The type Fast sql repository registrar.
19 *
20 * @author 张瑀楠 zyndev@gmail.com
21 * @version 1.0
22 * @date 2018 /2/23 12:26
23 */
24public class FastSqlRepositoryRegistrar implements ImportBeanDefinitionRegistrar, BeanFactoryAware {
25
26 private ConfigurableListableBeanFactory beanFactory;
27
28 @Override
29 public void registerBeanDefinitions(AnnotationMetadata annotationMetadata, BeanDefinitionRegistry beanDefinitionRegistry) {
30 System.out.println("FastSqlRepositoryRegistrar registerBeanDefinitions ");
31 String basePackage = "com.sparrow";
32 ClassScanner classScanner = new ClassScanner();
33 Set<Class<?>> classSet = null;
34 try {
35 classSet = classScanner.getPackageAllClasses(basePackage, true);
36 } catch (IOException | ClassNotFoundException e) {
37 e.printStackTrace();
38 }
39 for (Class clazz : classSet) {
40 if (clazz.getAnnotation(Repository.class) != null) {
41 beanFactory.registerSingleton(StringUtil.firstCharToLowerCase(clazz.getSimpleName()), FacadeProxy.newMapperProxy(clazz));
42 }
43 }
44 }
45
46 @Override
47 public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
48 this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
49 }
50}
在前面生成动态代理时可以看到,所有的 invoke 逻辑都由 StatementParser.java 处理,这也是一个重量级的方法。
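invoke 执行流程说明:
invoke(Object proxy, Method method, Object[] args)
1. 从方法的 @Query 注解中取得 sql 语句,无法取到 sql 则抛异常;
2. 解析方法参数上的 @Param 注解,将参数与对应的 Param 名称关联,同时按位置记录参数。例如 getCountByPassword(@Param("password") String password) 中:?1->arg0,password->arg0;
3. 使用正则 (?i)select([\s\S]*?) 判断是否为查询语句,若是查询语句,则根据方法的返回值类型封装结果;
4. 若为非查询语句,判断方法上是否有 @ReturnGeneratedKey 注解;
5. 没有 @ReturnGeneratedKey,则直接执行语句并返回对应的结果;
6. 有 @ReturnGeneratedKey 并且是 insert 语句,则返回生成的主键。
关于重写 sql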
1@Query("insert into tb_user(id,
2account_name,
3password,
4uid,
5nick_name,
6register_time,
7update_time) values(
8 :id,
9 :user.accountName,
10 :user.password,
11 :user.uid,
12 :user.nickName,
13 :user.registerTime,
14 :user.updateTime )")
15 public int saveUser(@Param("id") Integer id, @Param("user") User user);
首先获取sql
insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time)
values(:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )
可以看出,这既不是标准的 SQL,也不是 JDBC 可以识别的 SQL。这里我们使用正则 \?\d+(\.[A-Za-z]+)?|:[A-Za-z0-9]+(\.[A-Za-z]+)?
来提取出 ?1、?2、:1、:2、:id、:user.accountName 这类特殊标志,并将其替换为 ?。
替换过程中,每替换一个 ?,就记录下这个 ? 所代表的参数值。
替换后的 SQL 为:
insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time)
values(?, ?, ?, ?, ?, ?, ? )
这样的 SQL 就可以被 JDBC 处理了,同时对应的参数依次为:
:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime
这里的 :id 可以直接从参数 id 中获取,而 :user.accountName 则需要从参数 :user(即 user)中通过反射获取。这样,SQL 的重写就完成了。
返回结果集的封装同样是通过反射完成的,可以直接看下面的代码。
StatementParser.java
1package com.zyndev.tool.fastsql.core;
2
3import com.sun.org.apache.bcel.internal.generic.IF_ACMPEQ;
4import com.zyndev.tool.fastsql.annotation.Param;
5import com.zyndev.tool.fastsql.annotation.Query;
6import com.zyndev.tool.fastsql.annotation.ReturnGeneratedKey;
7import com.zyndev.tool.fastsql.convert.BeanConvert;
8import com.zyndev.tool.fastsql.convert.ListConvert;
9import com.zyndev.tool.fastsql.convert.SetConvert;
10import com.zyndev.tool.fastsql.util.BeanReflectionUtil;
11import com.zyndev.tool.fastsql.util.StringUtil;
12import org.apache.commons.logging.Log;
13import org.apache.commons.logging.LogFactory;
14import org.springframework.jdbc.core.*;
15import org.springframework.jdbc.support.GeneratedKeyHolder;
16import org.springframework.jdbc.support.KeyHolder;
17import org.springframework.jdbc.support.rowset.SqlRowSet;
18import sun.reflect.generics.reflectiveObjects.NotImplementedException;
19
20import java.lang.reflect.Method;
21import java.lang.reflect.Parameter;
22import java.lang.reflect.ParameterizedType;
23import java.lang.reflect.Type;
24import java.sql.Connection;
25import java.sql.PreparedStatement;
26import java.sql.SQLException;
27import java.sql.Statement;
28import java.util.HashMap;
29import java.util.List;
30import java.util.Map;
31
32/**
33 * sql 语句解析
34 * <p>
35 * 暂时只能处理 select count(*) from tb_user 类似语句
36 *
37 * @author 张瑀楠 zyndev@gmail.com
38 * @version 0.0.1
39 * @since 2017 /12/23 下午12:11
40 */
41class StatementParser {
42
43 private final static Log logger = LogFactory.getLog(StatementParser.class);
44
45 private static PreparedStatementCreator getPreparedStatementCreator(final String sql, final Object[] args, final boolean returnKeys) {
46 PreparedStatementCreator creator = new PreparedStatementCreator() {
47
48 @Override
49 public PreparedStatement createPreparedStatement(Connection con) throws SQLException {
50 PreparedStatement ps = con.prepareStatement(sql);
51 if (returnKeys) {
52 ps = con.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
53 } else {
54 ps = con.prepareStatement(sql);
55 }
56
57 if (args != null) {
58 for (int i = 0; i < args.length; i++) {
59 Object arg = args[i];
60 if (arg instanceof SqlParameterValue) {
61 SqlParameterValue paramValue = (SqlParameterValue) arg;
62 StatementCreatorUtils.setParameterValue(ps, i + 1, paramValue,
63 paramValue.getValue());
64 } else {
65 StatementCreatorUtils.setParameterValue(ps, i + 1,
66 SqlTypeValue.TYPE_UNKNOWN, arg);
67 }
68 }
69 }
70 return ps;
71 }
72 };
73 return creator;
74 }
75
76 /**
77 * 此处对 Repository 中方法进行解析,解析成对应的sql 和 参数
78 * <p>
79 * sql 来自于 @Query 注解的 value
80 * 参数 来自方法的参数
81 * <p>
82 * 注意根据返回值的不同封装结果集
83 *
84 * @param proxy 执行对象
85 * @param method 执行方法
86 * @param args 参数
87 * @return object
88 */
89 static Object invoke(Object proxy, Method method, Object[] args) throws Exception {
90
91 JdbcTemplate jdbcTemplate = DataSourceHolder.getInstance().getJdbcTemplate();
92
93 boolean logDebug = logger.isDebugEnabled();
94
95 String methodReturnType = method.getReturnType().getName();
96 Query query = method.getAnnotation(Query.class);
97
98 if (null == query || StringUtil.isBlank(query.value())) {
99 logger.error(method.toGenericString() + " 无 query 注解或 SQL 为空");
100 throw new IllegalStateException(method.toGenericString() + " 无 query 注解或 SQL 为空");
101 }
102
103 String originSql = query.value().trim();
104
105 System.out.println("sql:" + query.value());
106 Map<String, Object> namedParamMap = new HashMap<>();
107 Parameter[] parameters = method.getParameters();
108 if (args != null && args.length > 0) {
109 for (int i = 0; i < args.length; ++i) {
110 Param param = parameters[i].getAnnotation(Param.class);
111 if (null != param) {
112 namedParamMap.put(param.value(), args[i]);
113 }
114 namedParamMap.put("?" + (i + 1), args[i]);
115 }
116 }
117
118 if (logDebug) {
119 logger.debug("执行 sql: " + originSql);
120 }
121
122 // 判断 sql 类型, 判断是否为 select 开头语句
123 boolean isQuery = originSql.trim().matches("(?i)select([\\s\\S]*?)");
124 Object[] params = null;
125 // rewrite sql
126 if (null != args && args.length > 0) {
127 List<String> results = StringUtil.matches(originSql, "\\?\\d+(\\.[A-Za-z]+)?|:[A-Za-z0-9]+(\\.[A-Za-z]+)?");
128 if (results.isEmpty()) {
129 params = args;
130 } else {
131 params = new Object[results.size()];
132 for (int i = 0; i < results.size(); ++i) {
133 if (results.get(i).charAt(0) == ':') {
134 originSql = originSql.replaceFirst(results.get(i), "?");
135 // 判断是否是 param.param 的格式
136 if (!results.get(i).contains(".")) {
137 params[i] = namedParamMap.get(results.get(i).substring(1));
138 } else {
139 String[] paramArgs = results.get(i).split("\\.");
140 Object param = namedParamMap.get(paramArgs[0].substring(1));
141 params[i] = BeanReflectionUtil.getFieldValue(param, paramArgs[1]);
142 }
143 continue;
144 }
145 int paramIndex = Integer.parseInt(results.get(i).substring(1));
146 originSql = originSql.replaceFirst("\\?" + paramIndex, "?");
147 params[i] = namedParamMap.get(results.get(i));
148 }
149 }
150 }
151
152
153 System.out.println("execute sql:" + originSql);
154 System.out.print("params : ");
155 if (null != params) {
156 for (Object o : params) {
157 System.out.print(o + ",\t");
158 }
159 }
160 System.out.println("\n");
161
162
163 /**
164 * 如果返回值是基本类型或者其包装类
165 */
166 System.out.println(methodReturnType);
167 if (isQuery) {
168 // 查询方法
169 if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
170 return jdbcTemplate.queryForObject(originSql, params, Integer.class);
171 } else if ("java.lang.String".equals(methodReturnType)) {
172 return jdbcTemplate.queryForObject(originSql, params, String.class);
173 } else if ("java.util.List".equals(methodReturnType) || "java.util.Set".equals(methodReturnType)) {
174 String typeName = null;
175 Type returnType = method.getGenericReturnType();
176 if (returnType instanceof ParameterizedType) {
177 Type[] types = ((ParameterizedType) returnType).getActualTypeArguments();
178 if (null == types || types.length > 1) {
179 throw new IllegalArgumentException("当返回值为 list 时,必须标明具体类型,且只有一个");
180 }
181 typeName = types[0].getTypeName();
182 }
183 Object obj = BeanReflectionUtil.newInstance(typeName);
184 SqlRowSet rowSet = jdbcTemplate.queryForRowSet(originSql, params);
185 if ("java.util.List".equals(methodReturnType)) {
186 return ListConvert.convert(rowSet, obj);
187 }
188 return SetConvert.convert(rowSet, obj);
189 } else if ("java.util.Map".equals(methodReturnType)) {
190 throw new NotImplementedException();
191 } else {
192 SqlRowSet rowSet = jdbcTemplate.queryForRowSet(originSql, params);
193 Object obj = BeanReflectionUtil.newInstance(methodReturnType);
194 return BeanConvert.convert(rowSet, obj);
195 }
196 } else {
197 // 非查询方法
198 // 判断是否是insert 语句
199 ReturnGeneratedKey returnGeneratedKeyAnnotation = method.getAnnotation(ReturnGeneratedKey.class);
200 if (returnGeneratedKeyAnnotation == null) {
201 int retVal = jdbcTemplate.update(originSql, params);
202 if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
203 return retVal;
204 } else if ("java.lang.Boolean".equals(methodReturnType)) {
205 return retVal > 0;
206 }
207 } else {
208 // 判断是否是 insert 语句
209 boolean isInsertSql = originSql.trim().matches("(?i)insert([\\s\\S]*?)");
210 if (isInsertSql) {
211 KeyHolder keyHolder = new GeneratedKeyHolder();
212 PreparedStatementCreator preparedStatementCreator = getPreparedStatementCreator(originSql, params, true);
213 jdbcTemplate.update(preparedStatementCreator, keyHolder);
214 if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
215 return keyHolder.getKey().intValue();
216 } else if ("java.lang.Long".equals(methodReturnType) || "long".equals(methodReturnType)) {
217 return keyHolder.getKey().longValue();
218 }
219 logger.error(method.toGenericString() + " 返回主键id应该为 int 或者 long 类型 ");
220 throw new IllegalArgumentException(method.toGenericString() + " 返回主键id应该为 int 或者 long 类型 ");
221 } else {
222 logger.error(method.toGenericString() + " 非 insert 语句 无法返回 GeneratedKey:sql语句为:" + originSql);
223 throw new IllegalStateException(method.toGenericString() + " 非 insert 语句 无法返回 GeneratedKey:sql语句为:" + originSql);
224 }
225 }
226 }
227 return null;
228 }
229}
至此,一个简单的 ORM 就实现了。其实实现 ORM 并不难,难的是细心处理各种可能出现的 Bug。