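I recently needed to turn a Shiro-based application into a stateless service, and in the process came to understand several Shiro issues more deeply. This post uses a Spring Boot project integrated with Shiro to walk through the source code of shiroFilter's initialization and request-interception flow. Here is the Shiro filter configuration used throughout: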
@Bean(name = "shiroFilter")
public ShiroFilterFactoryBean shiroFilterFactoryBean(SecurityManager securityManager) {
ShiroFilterFactoryBean shiroFilterFactoryBean = new ShiroFilterFactoryBean();
//Shiro's core security interface; this property is required
shiroFilterFactoryBean.setSecurityManager(securityManager);
Map<String, Filter> filterMap = new LinkedHashMap<>();
filterMap.put("authc", new AjaxPermissionsAuthorizationFilter());
shiroFilterFactoryBean.setFilters(filterMap);
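/*Define the Shiro filter chains as a Map
* The leading '/' of each key (the value attribute in XML configuration) is relative to HttpServletRequest.getContextPath()
* anon: its filter is empty and does nothing but let the request through; a * after .do or .jsp stands for parameters, e.g. login.jsp?main
* authc: pages under this filter can only be accessed after authentication; it is Shiro's built-in filter org.apache.shiro.web.filter.authc.FormAuthenticationFilter
*/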
Map<String, String> filterChainDefinitionMap = new LinkedHashMap<>();
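/* Filter chain definitions are matched top-down in order, so /** normally goes last: this is a classic pitfall, and getting the order wrong quietly breaks things;
authc: the URL requires authentication; anon: the URL can be accessed anonymously */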
filterChainDefinitionMap.put("/", "anon");
filterChainDefinitionMap.put("/AdminLTE/**", "anon");
filterChainDefinitionMap.put("/bootstrap/**", "anon");
filterChainDefinitionMap.put("/cron/**", "anon");
filterChainDefinitionMap.put("/flat_ui/**", "anon");
filterChainDefinitionMap.put("/font-awesome/**", "anon");
filterChainDefinitionMap.put("/lonicons/**", "anon");
filterChainDefinitionMap.put("/iview/**", "anon");
filterChainDefinitionMap.put("/layer/**", "anon");
filterChainDefinitionMap.put("/libs/**", "anon");
filterChainDefinitionMap.put("/**/*.js", "anon");
filterChainDefinitionMap.put("/**/*.html", "anon");
filterChainDefinitionMap.put("/**/*.shtml", "anon");
filterChainDefinitionMap.put("/task/**", "anon");
filterChainDefinitionMap.put("/templates/**", "anon");
filterChainDefinitionMap.put("/login/auth", "anon");
filterChainDefinitionMap.put("/login/logout", "anon");
filterChainDefinitionMap.put("/error", "anon");
//let Swagger resources through
filterChainDefinitionMap.put("/swagger-ui.html", "anon");
filterChainDefinitionMap.put("/swagger-resources", "anon");
filterChainDefinitionMap.put("/v2/api-docs", "anon");
filterChainDefinitionMap.put("/webjars/springfox-swagger-ui/**", "anon");
filterChainDefinitionMap.put("/configuration/security", "anon");
filterChainDefinitionMap.put("/configuration/ui", "anon");
filterChainDefinitionMap.put("/**", "authc");
shiroFilterFactoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap);
return shiroFilterFactoryBean;
}
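When Spring Boot initializes the servlet web application context, it first registers the servlets, filters and listeners found among the beans. This happens in org.springframework.boot.web.servlet.ServletContextInitializerBeans#addAdaptableBeans: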
protected void addAdaptableBeans(ListableBeanFactory beanFactory) {
MultipartConfigElement multipartConfig = getMultipartConfig(beanFactory);
//wrap Servlet beans as ServletRegistrationBeans
addAsRegistrationBean(beanFactory, Servlet.class,
new ServletRegistrationBeanAdapter(multipartConfig));
//wrap Filter beans (including shiroFilter) as FilterRegistrationBeans
addAsRegistrationBean(beanFactory, Filter.class,
new FilterRegistrationBeanAdapter());
for (Class<?> listenerType : ServletListenerRegistrationBean
.getSupportedTypes()) {
addAsRegistrationBean(beanFactory, EventListener.class,
(Class<EventListener>) listenerType,
new ServletListenerRegistrationBeanAdapter());
}
}
The actual registration logic lives in org.springframework.boot.web.servlet.ServletContextInitializerBeans#addAsRegistrationBean(org.springframework.beans.factory.ListableBeanFactory, java.lang.Class, org.springframework.boot.web.servlet.ServletContextInitializerBeans.RegistrationBeanAdapter):
protected <T> void addAsRegistrationBean(ListableBeanFactory beanFactory,
Class<T> type, RegistrationBeanAdapter<T> adapter) {
addAsRegistrationBean(beanFactory, type, type, adapter);
}
private <T, B extends T> void addAsRegistrationBean(ListableBeanFactory beanFactory,
Class<T> type, Class<B> beanType, RegistrationBeanAdapter<T> adapter) {
List<Map.Entry<String, B>> entries = getOrderedBeansOfType(beanFactory, beanType,
this.seen);
for (Entry<String, B> entry : entries) {
String beanName = entry.getKey();
B bean = entry.getValue();
if (this.seen.add(bean)) {
// One that we haven't already seen
RegistrationBean registration = adapter.createRegistrationBean(beanName,
bean, entries.size());
int order = getOrder(bean);
registration.setOrder(order);
this.initializers.add(type, registration);
if (logger.isTraceEnabled()) {
logger.trace(
"Created " + type.getSimpleName() + " initializer for bean '"
+ beanName + "'; order=" + order + ", resource="
+ getResourceDescription(beanName, beanFactory));
}
}
}
}
Next, look at org.springframework.boot.web.servlet.ServletContextInitializerBeans#getOrderedBeansOfType(org.springframework.beans.factory.ListableBeanFactory, java.lang.Class, java.util.Set):
private <T> List<Entry<String, T>> getOrderedBeansOfType(
ListableBeanFactory beanFactory, Class<T> type, Set<?> excludes) {
String[] names = beanFactory.getBeanNamesForType(type, true, false);
Map<String, T> map = new LinkedHashMap<>();
for (String name : names) {
if (!excludes.contains(name) && !ScopedProxyUtils.isScopedTarget(name)) {
//name is shiroFilter, type is Filter
//the beanFactory here is a DefaultListableBeanFactory
T bean = beanFactory.getBean(name, type);
if (!excludes.contains(bean)) {
map.put(name, bean);
}
}
}
List<Entry<String, T>> beans = new ArrayList<>();
beans.addAll(map.entrySet());
beans.sort((o1, o2) -> AnnotationAwareOrderComparator.INSTANCE
.compare(o1.getValue(), o2.getValue()));
return beans;
}
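To make the collect-then-sort step concrete, here is a plain-Java model of what getOrderedBeansOfType and addAsRegistrationBean do together. It is a simplified sketch under stated assumptions, not Spring's code. BeanEntry and the registration strings are hypothetical. A Class#isInstance check replaces getBeanNamesForType, and a stable sort on an int order replaces AnnotationAwareOrderComparator.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Simplified model of ServletContextInitializerBeans: collect beans of a type,
// skip ones already seen, sort by order, and wrap each in a "registration".
final class RegistrationSketch {

    // Hypothetical stand-in for a bean: a name, an instance and an order value.
    static final class BeanEntry {
        final String name;
        final Object bean;
        final int order;

        BeanEntry(String name, Object bean, int order) {
            this.name = name;
            this.bean = bean;
            this.order = order;
        }
    }

    // Plays the role of the 'seen' set: a bean is only ever wrapped once.
    private final Set<Object> seen = new HashSet<>();

    // Mirrors getOrderedBeansOfType: keep beans assignable to 'type', then sort.
    static List<BeanEntry> orderedBeansOfType(Map<String, BeanEntry> factory, Class<?> type) {
        List<BeanEntry> result = new ArrayList<>();
        for (BeanEntry e : factory.values()) {
            if (type.isInstance(e.bean)) {
                result.add(e);
            }
        }
        // List.sort is stable, so beans with equal order keep their original order.
        result.sort(Comparator.comparingInt(e -> e.order));
        return result;
    }

    // Mirrors addAsRegistrationBean: wrap each bean that has not been seen yet.
    List<String> addAsRegistrationBeans(Map<String, BeanEntry> factory, Class<?> type) {
        List<String> registrations = new ArrayList<>();
        for (BeanEntry e : orderedBeansOfType(factory, type)) {
            if (seen.add(e.bean)) {
                registrations.add(type.getSimpleName() + ":" + e.name + "(order=" + e.order + ")");
            }
        }
        return registrations;
    }
}
```

Suppose Runnable stands in for Filter. A shiroFilter entry with order 0 is then registered before a filter with order 10. Calling addAsRegistrationBeans a second time returns nothing, because both beans are already in seen.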
Here getBean is called on the DefaultListableBeanFactory to obtain the shiroFilter instance. Look at org.springframework.beans.factory.support.AbstractBeanFactory#getBean(java.lang.String, java.lang.Class):
public <T> T getBean(String name, Class<T> requiredType) throws BeansException {
return this.doGetBean(name, requiredType, (Object[])null, false);
}
It then enters org.springframework.beans.factory.support.AbstractBeanFactory#doGetBean:
protected <T> T doGetBean(String name, @Nullable Class<T> requiredType, @Nullable Object[] args, boolean typeCheckOnly) throws BeansException {
String beanName = this.transformedBeanName(name);
//look up the ShiroFilterFactoryBean singleton (the factory itself, not yet the filter)
Object sharedInstance = this.getSingleton(beanName);
Object bean;
if (sharedInstance != null && args == null) {
if (this.logger.isTraceEnabled()) {
if (this.isSingletonCurrentlyInCreation(beanName)) {
this.logger.trace("Returning eagerly cached instance of singleton bean '" + beanName + "' that is not fully initialized yet - a consequence of a circular reference");
} else {
this.logger.trace("Returning cached instance of singleton bean '" + beanName + "'");
}
}
//resolve the actual bean: for a FactoryBean this ends up calling getObject()
bean = this.getObjectForBeanInstance(sharedInstance, name, beanName, (RootBeanDefinition)null);
} else {
......................
So doGetBean first obtains the ShiroFilterFactoryBean instance and then calls getObjectForBeanInstance to get the filter instance from it:
protected Object getObjectForBeanInstance(Object beanInstance, String name, String beanName, @Nullable RootBeanDefinition mbd) {
String currentlyCreatedBean = (String)this.currentlyCreatedBean.get();
if (currentlyCreatedBean != null) {
this.registerDependentBean(beanName, currentlyCreatedBean);
}
return super.getObjectForBeanInstance(beanInstance, name, beanName, mbd);
}
Layer by layer, this call reaches org.apache.shiro.spring.web.ShiroFilterFactoryBean#getObject. On the way, it passes through FactoryBeanRegistrySupport#getObjectFromFactoryBean and doGetObjectFromFactoryBean. getObject is what creates the shiroFilter instance:
public Object getObject() throws Exception {
if (instance == null) {
instance = createInstance();
}
return instance;
}
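The FactoryBean indirection above can be modeled in a few lines. This sketch uses a hypothetical MiniFactoryBean interface and a plain map of singletons in place of Spring's registries. It illustrates only two behaviors. First, asking for "shiroFilter" returns the product of getObject(), while "&shiroFilter" returns the factory itself. Second, the product is created once and then reused, as in ShiroFilterFactoryBean#getObject.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

// Minimal model of how a bean factory treats a FactoryBean. All types here
// are hypothetical stand-ins, not Spring's classes.
final class FactoryBeanSketch {

    interface MiniFactoryBean<T> {
        T getObject();
    }

    // Plays the role of ShiroFilterFactoryBean: creates the product lazily, once.
    static final class LazyFilterFactory implements MiniFactoryBean<Object> {
        private final Supplier<Object> creator;
        private Object instance;
        int createCount;

        LazyFilterFactory(Supplier<Object> creator) {
            this.creator = creator;
        }

        @Override
        public Object getObject() {
            if (instance == null) {          // same null check as ShiroFilterFactoryBean#getObject
                instance = creator.get();    // stands in for createInstance()
                createCount++;
            }
            return instance;
        }
    }

    private final Map<String, Object> singletons = new HashMap<>();

    void registerSingleton(String name, Object bean) {
        singletons.put(name, bean);
    }

    // Rough equivalent of doGetBean + getObjectForBeanInstance.
    Object getBean(String name) {
        boolean wantFactory = name.startsWith("&");
        String beanName = wantFactory ? name.substring(1) : name;
        Object shared = singletons.get(beanName);
        if (shared instanceof MiniFactoryBean && !wantFactory) {
            return ((MiniFactoryBean<?>) shared).getObject();
        }
        return shared;
    }
}
```

Two getBean("shiroFilter") calls return the same product, and createInstance's stand-in runs only once.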
The filter itself is created lazily, on the first call, by createInstance:
protected AbstractShiroFilter createInstance() throws Exception {
log.debug("Creating Shiro Filter instance.");
SecurityManager securityManager = getSecurityManager();
if (securityManager == null) {
String msg = "SecurityManager property must be set.";
throw new BeanInitializationException(msg);
}
......................
FilterChainManager manager = createFilterChainManager();
PathMatchingFilterChainResolver chainResolver = new PathMatchingFilterChainResolver();
chainResolver.setFilterChainManager(manager);
............
return new SpringShiroFilter((WebSecurityManager) securityManager, chainResolver);
}
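This method does four things:
1. It checks that a SecurityManager has been set.
2. It builds a FilterChainManager with createFilterChainManager.
3. It wraps that manager in a PathMatchingFilterChainResolver.
4. It returns a SpringShiroFilter built from the SecurityManager and the resolver.

The interesting part is createFilterChainManager: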
protected FilterChainManager createFilterChainManager() {
DefaultFilterChainManager manager = new DefaultFilterChainManager();
Map<String, Filter> defaultFilters = manager.getFilters();
//apply global settings if necessary:
for (Filter filter : defaultFilters.values()) {
applyGlobalPropertiesIfNecessary(filter);
}
//Apply the acquired and/or configured filters:
//these are the filters configured at the start of this article, i.e. AjaxPermissionsAuthorizationFilter
Map<String, Filter> filters = getFilters();
if (!CollectionUtils.isEmpty(filters)) {
for (Map.Entry<String, Filter> entry : filters.entrySet()) {
String name = entry.getKey();
Filter filter = entry.getValue();
applyGlobalPropertiesIfNecessary(filter);
if (filter instanceof Nameable) {
((Nameable) filter).setName(name);
}
//'init' argument is false, since Spring-configured filters should be initialized
//in Spring (i.e. 'init-method=blah') or implement InitializingBean:
manager.addFilter(name, filter, false);
}
}
//build up the chains:
Map<String, String> chains = getFilterChainDefinitionMap();
if (!CollectionUtils.isEmpty(chains)) {
for (Map.Entry<String, String> entry : chains.entrySet()) {
String url = entry.getKey();
String chainDefinition = entry.getValue();
manager.createChain(url, chainDefinition);
}
}
return manager;
}
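Three steps are involved: creating the FilterChainManager, loading the configured filters, and loading the filterChainDefinition mappings. The DefaultFilterChainManager constructor already registers the default filters: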
public DefaultFilterChainManager() {
this.filters = new LinkedHashMap<String, Filter>();
this.filterChains = new LinkedHashMap<String, NamedFilterList>();
addDefaultFilters(false);
}
Here is addDefaultFilters:
protected void addDefaultFilters(boolean init) {
for (DefaultFilter defaultFilter : DefaultFilter.values()) {
addFilter(defaultFilter.name(), defaultFilter.newInstance(), init, false);
}
}
Which default filters exist? They are defined by the DefaultFilter enum:
public enum DefaultFilter {
anon(AnonymousFilter.class),
authc(FormAuthenticationFilter.class),
authcBasic(BasicHttpAuthenticationFilter.class),
logout(LogoutFilter.class),
noSessionCreation(NoSessionCreationFilter.class),
perms(PermissionsAuthorizationFilter.class),
port(PortFilter.class),
rest(HttpMethodPermissionFilter.class),
roles(RolesAuthorizationFilter.class),
ssl(SslFilter.class),
user(UserFilter.class);
// ... constructor and newInstance() omitted
}
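In other words, the filters map is pre-populated with one instance of each built-in filter. The configuration at the start of this article uses two of them, anon and authc. Recall that the configuration also registers its own authc filter: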
filterMap.put("authc", new AjaxPermissionsAuthorizationFilter());
shiroFilterFactoryBean.setFilters(filterMap);
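createFilterChainManager passes these configured filters to manager.addFilter. This overload overwrites any existing entry with the same name. As a result, AjaxPermissionsAuthorizationFilter replaces the default FormAuthenticationFilter registered under authc in the filters map. In this project, AjaxPermissionsAuthorizationFilter is itself a FormAuthenticationFilter subclass, so it inherits the authc login logic. The relevant DefaultFilterChainManager fields: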
public class DefaultFilterChainManager implements FilterChainManager {
private static transient final Logger log = LoggerFactory.getLogger(DefaultFilterChainManager.class);
private FilterConfig filterConfig;
private Map<String, Filter> filters; //pool of filters available for creating chains
private Map<String, NamedFilterList> filterChains; //key: chain name, value: chain
// ... methods omitted
}
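The overwrite rule is easy to miss, so here is a small model of the filter pool. It is a sketch, not Shiro's class. Filters are represented by class-name strings, and only three of the DefaultFilter entries are included. The boolean parameter is the overwrite flag: defaults are added with overwrite false, configured filters with overwrite true.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch of how DefaultFilterChainManager's filter pool is populated:
// defaults first (as in addDefaultFilters), then configured filters that
// overwrite entries with the same name.
final class FilterPoolSketch {

    private final Map<String, String> filters = new LinkedHashMap<>();

    FilterPoolSketch() {
        // A subset of the DefaultFilter enum, enough for this article's config.
        addFilter("anon", "AnonymousFilter", false);
        addFilter("authc", "FormAuthenticationFilter", false);
        addFilter("logout", "LogoutFilter", false);
    }

    // Same rule as addFilter(name, filter, init, overwrite): store the filter
    // only if nothing is registered under that name yet or overwrite is true.
    void addFilter(String name, String filter, boolean overwrite) {
        if (!filters.containsKey(name) || overwrite) {
            filters.put(name, filter);   // LinkedHashMap keeps the original key position
        }
    }

    String get(String name) {
        return filters.get(name);
    }

    List<String> names() {
        return new ArrayList<>(filters.keySet());
    }
}
```

After addFilter("authc", "AjaxPermissionsAuthorizationFilter", true), authc maps to the custom filter and keeps its original position. An addFilter call with overwrite false leaves an existing entry such as anon untouched.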
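That concludes initialization. filters now holds the pool of named filters. filterChains maps each URL pattern, in definition order, to the NamedFilterList that createChain built for it. Next, the interception flow.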
The SpringShiroFilter class hierarchy was mentioned earlier. Before entering the interception flow, here is SpringShiroFilter's structure:
private static final class SpringShiroFilter extends AbstractShiroFilter {
protected SpringShiroFilter(WebSecurityManager webSecurityManager, FilterChainResolver resolver) {
super();
if (webSecurityManager == null) {
throw new IllegalArgumentException("WebSecurityManager property cannot be null.");
}
setSecurityManager(webSecurityManager);
if (resolver != null) {
setFilterChainResolver(resolver);
}
}
}
SpringShiroFilter extends AbstractShiroFilter, which extends OncePerRequestFilter, which extends NameableFilter...
Of these classes, OncePerRequestFilter is the one that actually implements doFilter. As anyone who has done web programming knows, every request that reaches a filter goes through its doFilter method. The core of shiroFilter's interception is therefore OncePerRequestFilter#doFilter:
public final void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
String alreadyFilteredAttributeName = getAlreadyFilteredAttributeName();
if ( request.getAttribute(alreadyFilteredAttributeName) != null ) {
log.trace("Filter '{}' already executed. Proceeding without invoking this filter.", getName());
filterChain.doFilter(request, response);
} else //noinspection deprecation
//taken when the filter is disabled (or shouldNotFilter is true); filters are enabled by default
if (/* added in 1.2: */ !isEnabled(request, response) ||
/* retain backwards compatibility: */ shouldNotFilter(request) ) {
log.debug("Filter '{}' is not enabled for the current request. Proceeding without invoking this filter.",
getName());
filterChain.doFilter(request, response);
} else {
// Do invoke this filter...
log.trace("Filter '{}' not yet executed. Executing now.", getName());
request.setAttribute(alreadyFilteredAttributeName, Boolean.TRUE);
try {
doFilterInternal(request, response, filterChain);
} finally {
// Once the request has finished, we're done and we don't
// need to mark as 'already filtered' any more.
request.removeAttribute(alreadyFilteredAttributeName);
}
}
}
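The already-filtered attribute makes the filter run only once per request, even if the same request passes through it again (for example on a forward). Below is a servlet-free sketch of that guard. FakeRequest is a hypothetical stand-in for ServletRequest attributes, Runnable stands in for the FilterChain, and the isEnabled/shouldNotFilter branch is left out.

```java
import java.util.HashMap;
import java.util.Map;

// Simplified model of OncePerRequestFilter's "already filtered" guard.
final class OncePerRequestSketch {

    // Hypothetical request object that only holds attributes.
    static final class FakeRequest {
        final Map<String, Object> attributes = new HashMap<>();
    }

    private final String name;
    int internalCalls;

    OncePerRequestSketch(String name) {
        this.name = name;
    }

    // Shiro builds this name from the filter name plus a ".FILTERED" suffix.
    String alreadyFilteredAttributeName() {
        return name + ".FILTERED";
    }

    void doFilter(FakeRequest request, Runnable chain) {
        String attr = alreadyFilteredAttributeName();
        if (request.attributes.get(attr) != null) {
            chain.run();                       // already executed: just pass the request on
            return;
        }
        request.attributes.put(attr, Boolean.TRUE);
        try {
            doFilterInternal(request, chain);
        } finally {
            request.attributes.remove(attr);   // clear the mark once this invocation is done
        }
    }

    void doFilterInternal(FakeRequest request, Runnable chain) {
        internalCalls++;
        chain.run();
    }
}
```

If the chain re-enters the same filter, the inner call sees the attribute and passes the request on. doFilterInternal therefore runs exactly once, and the attribute is removed afterwards.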
An ordinary request takes the last branch and calls doFilterInternal, which is implemented in AbstractShiroFilter:
protected void doFilterInternal(ServletRequest servletRequest, ServletResponse servletResponse, final FilterChain chain)
throws ServletException, IOException {
Throwable t = null;
try {
final ServletRequest request = prepareServletRequest(servletRequest, servletResponse, chain);
final ServletResponse response = prepareServletResponse(request, servletResponse, chain);
final Subject subject = createSubject(request, response);
//noinspection unchecked
subject.execute(new Callable() {
public Object call() throws Exception {
updateSessionLastAccessTime(request, response);
executeChain(request, response, chain);
return null;
}
});
........................
This leads to org.apache.shiro.web.servlet.AbstractShiroFilter#executeChain:
protected void executeChain(ServletRequest request, ServletResponse response, FilterChain origChain)
throws IOException, ServletException {
FilterChain chain = getExecutionChain(request, response, origChain);
chain.doFilter(request, response);
}
Next, getExecutionChain:
protected FilterChain getExecutionChain(ServletRequest request, ServletResponse response, FilterChain origChain) {
FilterChain chain = origChain;
//the FilterChainResolver passed in at construction time, in practice a PathMatchingFilterChainResolver
FilterChainResolver resolver = getFilterChainResolver();
if (resolver == null) {
log.debug("No FilterChainResolver configured. Returning original FilterChain.");
return origChain;
}
//ask the resolver for the FilterChain matching this request
FilterChain resolved = resolver.getChain(request, response, origChain);
if (resolved != null) {
log.trace("Resolved a configured FilterChain for the current request.");
chain = resolved;
} else {
log.trace("No FilterChain configured for the current request. Using the default.");
}
return chain;
}
getExecutionChain thus takes the PathMatchingFilterChainResolver and asks it for a FilterChain. If nothing matches, the original chain is used. The key method is PathMatchingFilterChainResolver#getChain:
public FilterChain getChain(ServletRequest request, ServletResponse response, FilterChain originalChain) {
FilterChainManager filterChainManager = getFilterChainManager();
if (!filterChainManager.hasChains()) {
return null;
}
String requestURI = getPathWithinApplication(request);
for (String pathPattern : filterChainManager.getChainNames()) {
if (pathMatches(pathPattern, requestURI)) {
if (log.isTraceEnabled()) {
log.trace("Matched path pattern [" + pathPattern + "] for requestURI [" + requestURI + "]. " +
"Utilizing corresponding filter chain...");
}
return filterChainManager.proxy(originalChain, pathPattern);
}
}
return null;
}
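getChain returns on the first matching pattern, and chain names keep their LinkedHashMap definition order. A /** entry placed first would therefore capture every request, which is the pitfall mentioned in the configuration comments. The sketch below reproduces first-match resolution. Its matcher is a simplified stand-in for Shiro's AntPathMatcher that only handles /**, * and ?, so treat it as an illustration rather than a drop-in replacement.

```java
import java.util.Map;
import java.util.regex.Pattern;

// First-match resolution over an ordered map of chain definitions, in the
// style of PathMatchingFilterChainResolver#getChain.
final class FirstMatchSketch {

    // Simplified Ant-style matching, implemented by translating the pattern to a regex.
    static boolean pathMatches(String pattern, String path) {
        StringBuilder re = new StringBuilder();
        int i = 0;
        while (i < pattern.length()) {
            if (pattern.startsWith("/**", i)) {
                re.append("(?:/.*)?");   // zero or more path segments
                i += 3;
            } else if (pattern.charAt(i) == '*') {
                re.append("[^/]*");      // any characters within one segment
                i++;
            } else if (pattern.charAt(i) == '?') {
                re.append("[^/]");       // exactly one character within a segment
                i++;
            } else {
                re.append(Pattern.quote(String.valueOf(pattern.charAt(i))));
                i++;
            }
        }
        return path.matches(re.toString());
    }

    // Returns the chain definition of the first matching pattern, or null.
    static String resolve(Map<String, String> chains, String path) {
        for (Map.Entry<String, String> e : chains.entrySet()) {
            if (pathMatches(e.getKey(), path)) {
                return e.getValue();
            }
        }
        return null;
    }
}
```

With the order used in this article, /login/auth resolves to anon and /user/list resolves to authc. Moving /** to the top makes /login/auth resolve to authc as well.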
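getChain walks the chain names in definition order and returns on the first pattern that matches, which is why /** must come last. What it returns is a proxy FilterChain created by filterChainManager.proxy, i.e. org.apache.shiro.web.filter.mgt.DefaultFilterChainManager#proxy: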
public FilterChain proxy(FilterChain original, String chainName) {
//chainName is the matched path pattern; getChain fetches the NamedFilterList registered for it during initialization
NamedFilterList configured = getChain(chainName);
if (configured == null) {
String msg = "There is no configured chain under the name/key [" + chainName + "].";
throw new IllegalArgumentException(msg);
}
//wrap the configured filters and the original chain in a ProxiedFilterChain
return configured.proxy(original);
}
In other words, proxy looks up the filter list stored under the matched pattern. It then wraps that list, together with the original servlet FilterChain, in a ProxiedFilterChain object.
Now back to executeChain:
protected void executeChain(ServletRequest request, ServletResponse response, FilterChain origChain)
throws IOException, ServletException {
FilterChain chain = getExecutionChain(request, response, origChain);
chain.doFilter(request, response);
}
What matters here is chain.doFilter. Since the returned FilterChain is a ProxiedFilterChain, the call goes to org.apache.shiro.web.servlet.ProxiedFilterChain#doFilter:
public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
if (this.filters == null || this.filters.size() == this.index) {
//we've reached the end of the wrapped chain, so invoke the original one:
if (log.isTraceEnabled()) {
log.trace("Invoking original filter chain.");
}
this.orig.doFilter(request, response);
} else {
if (log.isTraceEnabled()) {
log.trace("Invoking wrapped filter at index [" + this.index + "]");
}
this.filters.get(this.index++).doFilter(request, response, this);
}
}
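With this index-based design, a filter continues the chain simply by calling doFilter on the chain it was given, and stops it by not calling it. The following servlet-free sketch models that mechanism. The Chain and Filter interfaces and the log list are hypothetical simplifications of the servlet types.

```java
import java.util.List;

// Index-based chain in the style of ProxiedFilterChain: run the configured
// filters one by one, then fall through to the original chain.
final class ProxiedChainSketch {

    interface Chain {
        void doFilter(List<String> log);
    }

    interface Filter {
        void doFilter(List<String> log, Chain chain);
    }

    static final class ProxiedChain implements Chain {
        private final Chain orig;
        private final List<Filter> filters;
        private int index;

        ProxiedChain(Chain orig, List<Filter> filters) {
            this.orig = orig;
            this.filters = filters;
        }

        @Override
        public void doFilter(List<String> log) {
            if (filters == null || filters.size() == index) {
                orig.doFilter(log);                       // end of the wrapped chain
            } else {
                filters.get(index++).doFilter(log, this); // pass 'this' so the filter can continue
            }
        }
    }
}
```

A filter that calls chain.doFilter lets the request reach the original chain. A filter that returns without calling it, as a denying authc filter effectively does, ends processing there.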
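Each call hands the ProxiedFilterChain itself to the next filter. Every filter that calls chain.doFilter therefore advances the index, until the original chain runs. For a path mapped to anon, the only filter is AnonymousFilter, whose onPreHandle always returns true:
public class AnonymousFilter extends PathMatchingFilter {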
@Override
protected boolean onPreHandle(ServletRequest request, ServletResponse response, Object mappedValue) {
// Always return true since we allow access to anyone
return true;
}
}
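For a path mapped to authc, the filter is AjaxPermissionsAuthorizationFilter. Its class hierarchy, from subclass to superclass, is: AjaxPermissionsAuthorizationFilter, FormAuthenticationFilter, AuthenticatingFilter, AuthenticationFilter, AccessControlFilter, PathMatchingFilter, AdviceFilter, OncePerRequestFilter.
Its doFilter call likewise enters OncePerRequestFilter#doFilter, which calls AdviceFilter#doFilterInternal, which in turn calls PathMatchingFilter#preHandle. When a configured path matches, preHandle calls isFilterChainContinued, which then calls AccessControlFilter#onPreHandle: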
public boolean onPreHandle(ServletRequest request, ServletResponse response, Object mappedValue) throws Exception {
return isAccessAllowed(request, response, mappedValue) || onAccessDenied(request, response, mappedValue);
}
It first calls org.apache.shiro.web.filter.authc.AuthenticatingFilter#isAccessAllowed:
@Override
protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) {
return super.isAccessAllowed(request, response, mappedValue) ||
(!isLoginRequest(request, response) && isPermissive(mappedValue));
}
which in turn calls org.apache.shiro.web.filter.authc.AuthenticationFilter#isAccessAllowed:
protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) {
Subject subject = getSubject(request, response);
return subject.isAuthenticated();
}
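Suppose the authentication check subject.isAuthenticated() fails. Then isAccessAllowed returns false, unless the permissive flag applies to a non-login request. AccessControlFilter#onPreHandle then calls onAccessDenied, which in this project is implemented by AjaxPermissionsAuthorizationFilter.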
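Because onPreHandle uses ||, onAccessDenied runs only when isAccessAllowed returns false. The article does not show the project's AjaxPermissionsAuthorizationFilter, so the denial handler below is a hypothetical example of what an AJAX-oriented filter might do. It writes a 401 JSON body and returns false instead of redirecting to a login page. FakeResponse is a stand-in for the servlet response, and the isPermissive check from AuthenticatingFilter is omitted.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of AccessControlFilter#onPreHandle's short-circuit:
// isAccessAllowed(...) || onAccessDenied(...).
final class AccessControlSketch {

    // Hypothetical minimal response: status code plus body.
    static final class FakeResponse {
        int status = 200;
        String body = "";
    }

    // Records which methods were invoked, in order.
    final List<String> calls = new ArrayList<>();

    boolean isAccessAllowed(boolean authenticated) {
        calls.add("isAccessAllowed");
        return authenticated;          // AuthenticationFilter checks subject.isAuthenticated()
    }

    // Hypothetical AJAX-style denial; the project's real implementation is not shown.
    boolean onAccessDenied(FakeResponse response) {
        calls.add("onAccessDenied");
        response.status = 401;
        response.body = "{\"code\":401,\"msg\":\"not logged in\"}";
        return false;                  // in Shiro, false means the filter chain is not continued
    }

    boolean onPreHandle(boolean authenticated, FakeResponse response) {
        return isAccessAllowed(authenticated) || onAccessDenied(response);
    }
}
```

For an authenticated request only isAccessAllowed is recorded and the response is untouched. For an anonymous one, both methods run and the response carries the 401 body.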
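This concludes the source walkthrough of shiroFilter's initialization and interception. The post was written under time pressure, so the layout and flow may not be ideal; please read it critically.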