Monosoul's Dev Blog A blog to write down dev-related stuff I face

How to customize dependency injection in Spring (Part 1)

I work with Spring Framework a lot. I can’t remember a single project at work where I didn’t use it. There is one part of Spring that I use especially often: DI (Dependency Injection). And just as often, I run into situations where it turns out not to be smart enough for me. This article is the result of my attempts to customize dependency injection in Spring to make it smarter, and to get a deeper understanding of how it works. I hope you’ll find it useful as well!

Before going further, I recommend watching Evgeny Borisov’s “Spring the Ripper” series: part 1 and part 2.

Introduction

Let’s imagine that we were asked to develop a service to provide fortunes and horoscopes. The service will have many components, but 2 of them stand out the most:

  • Globa (Pavel Globa is a famous Russian astrologer), which implements the FortuneTeller interface and will tell fortunes;
  • Gypsy, which implements the HoroscopeTeller interface and will provide horoscopes.

The service will also have multiple endpoints (controllers) to get fortunes and horoscopes. To restrict access to the service’s endpoints, we will also use an aspect that is applied to the controllers’ methods. It will look like this:

RestrictionAspect.java
```java
@Aspect
@Component
@Slf4j
public class RestrictionAspect {

    private final Predicate<String> ipIsAllowed;

    public RestrictionAspect(@NonNull final Predicate<String> ipIsAllowed) {
        this.ipIsAllowed = ipIsAllowed;
    }

    @Before("execution(public * com.github.monosoul.fortuneteller.web.*.*(..))")
    public void checkAccess() {
        val ip = getRequestSourceIp();
        log.debug("Source IP: {}", ip);

        if (!ipIsAllowed.test(ip)) {
            throw new AccessDeniedException(format("Access for IP [%s] is denied", ip));
        }
    }

    private String getRequestSourceIp() {
        val requestAttributes = currentRequestAttributes();
        Assert.state(requestAttributes instanceof ServletRequestAttributes,
                "RequestAttributes needs to be a ServletRequestAttributes");

        val request = ((ServletRequestAttributes) requestAttributes).getRequest();

        return request.getRemoteAddr();
    }
}
```

We will use an implementation of the ipIsAllowed predicate to check whether the source IP is allowed to access the endpoint. In general, it could be any other aspect as well, for example an aspect for authorization.
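The article doesn’t show how ipIsAllowed itself is implemented. As a minimal sketch, assuming a fixed allow-list of exact IP strings, such a predicate could look like the following. The class name IpAllowList is hypothetical; a real service might use CIDR ranges or configuration properties instead.

```java
import java.util.Set;
import java.util.function.Predicate;

// Hypothetical implementation of the ipIsAllowed predicate: an exact-match allow-list.
final class IpAllowList implements Predicate<String> {
    private final Set<String> allowedIps;

    IpAllowList(Set<String> allowedIps) {
        this.allowedIps = Set.copyOf(allowedIps); // immutable defensive copy
    }

    @Override
    public boolean test(String ip) {
        // Immutable sets reject contains(null), so a missing IP is handled explicitly and denied.
        return ip != null && allowedIps.contains(ip);
    }
}

public class IpAllowListDemo {
    public static void main(String[] args) {
        Predicate<String> ipIsAllowed = new IpAllowList(Set.of("127.0.0.1", "10.0.0.5"));
        System.out.println(ipIsAllowed.test("127.0.0.1"));   // true
        System.out.println(ipIsAllowed.test("192.168.1.1")); // false
    }
}
```

In the application, an object like this would be registered as a bean of type Predicate<String> so that RestrictionAspect can receive it through its constructor.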

At that point we’ve developed the service and it works great. Now, let’s talk about how to test it.

How to test it?

How can we make sure that the aspect got applied correctly? There are multiple ways of doing that.

We can write separate tests for the aspect and for the controllers without creating a Spring context (it is the context that creates a proxy applying the aspect to the controller; you can read more about it in the official documentation). But in that case we won’t test that the aspect is actually applied to the controllers and works the way it should.

We can write a test that gets the whole application context up and running, but in that case:

  • the test will take quite a long time to start, since Spring will create all the beans the application needs;
  • we will have to prepare valid test data that lets the whole chain of bean calls run without throwing an NPE.

But we only want to test that the aspect is properly applied to the controllers! We don’t want to test the services used by the controllers, prepare test data, or waste time waiting for the tests to start. Hence, we will write a test that initializes only a small portion of the context: the test context will have a real aspect bean and a real controller bean, while the rest of the context will be mocked.

How to create mock beans?

There are multiple ways to create mock beans in Spring. For illustration, we will use one of the service’s controllers, PersonalizedHoroscopeTellController. Its code looks like this:

PersonalizedHoroscopeTellController.java
```java
@Slf4j
@RestController
@RequestMapping(
        value = "/horoscope",
        produces = APPLICATION_JSON_UTF8_VALUE
)
public class PersonalizedHoroscopeTellController {

    private final HoroscopeTeller horoscopeTeller;
    private final Function<String, ZodiacSign> zodiacSignConverter;
    private final Function<String, String> nameNormalizer;

    public PersonalizedHoroscopeTellController(
            final HoroscopeTeller horoscopeTeller,
            final Function<String, ZodiacSign> zodiacSignConverter,
            final Function<String, String> nameNormalizer
    ) {
        this.horoscopeTeller = horoscopeTeller;
        this.zodiacSignConverter = zodiacSignConverter;
        this.nameNormalizer = nameNormalizer;
    }

    @GetMapping(value = "/tell/personal/{name}/{sign}")
    public PersonalizedHoroscope tell(@PathVariable final String name, @PathVariable final String sign) {
        log.info("Received name: {}; sign: {}", name, sign);

        return PersonalizedHoroscope.builder()
                .name(
                        nameNormalizer.apply(name)
                )
                .horoscope(
                        horoscopeTeller.tell(
                                zodiacSignConverter.apply(sign)
                        )
                )
                .build();
    }
}
```

Define a Java config class with dependencies in every test

We can write a Java config in every test class, where we define the aspect and controller beans as well as mock beans for every controller dependency. This is an imperative way of defining beans, since we explicitly tell Spring how every bean should be created.

In that case a test for the controller will look like this:

javaconfig/PersonalizedHoroscopeTellControllerTest.java
```java
@SpringJUnitConfig
public class PersonalizedHoroscopeTellControllerTest {

    private static final int LIMIT = 10;

    @Autowired
    private PersonalizedHoroscopeTellController controller;
    @Autowired
    private Predicate<String> ipIsAllowed;

    @Test
    void doNothingWhenAllowed() {
        when(ipIsAllowed.test(anyString())).thenReturn(true);

        controller.tell(randomAlphabetic(LIMIT), randomAlphabetic(LIMIT));
    }

    @Test
    void throwExceptionWhenNotAllowed() {
        when(ipIsAllowed.test(anyString())).thenReturn(false);

        assertThatThrownBy(() -> controller.tell(randomAlphabetic(LIMIT), randomAlphabetic(LIMIT)))
                .isInstanceOf(AccessDeniedException.class);
    }

    @Configuration
    @Import(AspectConfiguration.class)
    @EnableAspectJAutoProxy
    public static class Config {

        @Bean
        public PersonalizedHoroscopeTellController personalizedHoroscopeTellController(
                final HoroscopeTeller horoscopeTeller,
                final Function<String, ZodiacSign> zodiacSignConverter,
                final Function<String, String> nameNormalizer
        ) {
            return new PersonalizedHoroscopeTellController(horoscopeTeller, zodiacSignConverter, nameNormalizer);
        }

        @Bean
        public HoroscopeTeller horoscopeTeller() {
            return mock(HoroscopeTeller.class);
        }

        @Bean
        public Function<String, ZodiacSign> zodiacSignConverter() {
            return mock(Function.class);
        }

        @Bean
        public Function<String, String> nameNormalizer() {
            return mock(Function.class);
        }
    }
}
```

Such a test is pretty bulky. Also, we would have to write a Java config for every controller’s test. Even though it would look different in every test, its goal would be the same: define the controller bean and mocks for its dependencies. So, in the end, it would be more or less the same for every controller. Being a highly professional developer, I am lazy, so I decided this is not the way to go.

@MockBean annotation on every dependency field

The @MockBean annotation was introduced in Spring Boot Test version 1.4.0. It’s much like the @Mock annotation from Mockito (and actually uses Mockito internally), with one difference: a mock created with @MockBean is automatically put into the Spring context. This is a declarative way of defining mock beans, since the beans are created implicitly.

In that case the test will look like this:

mockbean/PersonalizedHoroscopeTellControllerTest.java
```java
@SpringJUnitConfig
public class PersonalizedHoroscopeTellControllerTest {

    private static final int LIMIT = 10;

    @MockBean
    private HoroscopeTeller horoscopeTeller;
    @MockBean
    private Function<String, ZodiacSign> zodiacSignConverter;
    @MockBean
    private Function<String, String> nameNormalizer;
    @MockBean
    private Predicate<String> ipIsAllowed;

    @Autowired
    private PersonalizedHoroscopeTellController controller;

    @Test
    void doNothingWhenAllowed() {
        when(ipIsAllowed.test(anyString())).thenReturn(true);

        controller.tell(randomAlphabetic(LIMIT), randomAlphabetic(LIMIT));
    }

    @Test
    void throwExceptionWhenNotAllowed() {
        when(ipIsAllowed.test(anyString())).thenReturn(false);

        assertThatThrownBy(() -> controller.tell(randomAlphabetic(LIMIT), randomAlphabetic(LIMIT)))
                .isInstanceOf(AccessDeniedException.class);
    }

    @Configuration
    @Import({PersonalizedHoroscopeTellController.class, RestrictionAspect.class, RequestContextHolderConfigurer.class})
    @EnableAspectJAutoProxy
    public static class Config {
    }
}
```

There is still a Java config here, though it’s much more compact. What I see as a disadvantage, though, is that I had to define fields for the controller’s dependencies just to put an annotation on them; after that they’re never used. Also, if for some reason you’re using a version of Spring Boot Test older than 1.4.0, you won’t be able to use this annotation.

With that said, I came up with the idea of a different way to mock controller dependencies. I want it to work like this…

@Automocked annotation on a dependent component

I’d like to have an annotation called @Automocked that I can put on the test class field holding the controller, so that mock beans for the controller’s dependencies are created automatically and put into the context.

The test class will then look like this:

automocked/PersonalizedHoroscopeTellControllerTest.java
```java
@SpringJUnitConfig
@ContextConfiguration(classes = AspectConfiguration.class)
@TestExecutionListeners(listeners = AutomockTestExecutionListener.class, mergeMode = MERGE_WITH_DEFAULTS)
public class PersonalizedHoroscopeTellControllerTest {

    private static final int LIMIT = 10;

    @Automocked
    private PersonalizedHoroscopeTellController controller;

    @Autowired
    private Predicate<String> ipIsAllowed;

    @Test
    void doNothingWhenAllowed() {
        when(ipIsAllowed.test(anyString())).thenReturn(true);

        controller.tell(randomAlphabetic(LIMIT), randomAlphabetic(LIMIT));
    }

    @Test
    void throwExceptionWhenNotAllowed() {
        when(ipIsAllowed.test(anyString())).thenReturn(false);

        assertThatThrownBy(() -> controller.tell(randomAlphabetic(LIMIT), randomAlphabetic(LIMIT)))
                .isInstanceOf(AccessDeniedException.class);
    }
}
```

How does it work?

Let’s figure out how it works and why we need each part of it.

TestExecutionListener

Spring Framework has an interface called TestExecutionListener. It provides an API for plugging your own logic into different stages of test execution, e.g. when the test instance is prepared, before or after a test method is invoked, etc. It has a set of implementations “out of the box”, for example: DirtiesContextTestExecutionListener, which cleans up the context if you put the related annotation (@DirtiesContext) on the test; DependencyInjectionTestExecutionListener, which injects dependencies into test instances; etc. To apply our own TestExecutionListener to a test, we need to put the @TestExecutionListeners annotation on the test class and specify the implementation class.

Ordered

Spring also has an interface called Ordered. It is used to specify that the objects implementing it should be sorted in a particular order. For example, if you have multiple implementations of the same interface and inject them into a collection, they will be ordered according to how they implement Ordered. For TestExecutionListener implementations, it determines the order in which the listeners are applied.

Implementing TestExecutionListener

So, our own listener will implement 2 interfaces: TestExecutionListener and Ordered. We’ll call it AutomockTestExecutionListener, and it will look like this:

AutomockTestExecutionListener.java
```java
@Slf4j
public class AutomockTestExecutionListener implements TestExecutionListener, Ordered {

    @Override
    public int getOrder() {
        return 1900;
    }

    @Override
    public void prepareTestInstance(final TestContext testContext) {
        val beanFactory = ((DefaultListableBeanFactory) testContext.getApplicationContext().getAutowireCapableBeanFactory());
        setByNameCandidateResolver(beanFactory);

        for (val field : testContext.getTestClass().getDeclaredFields()) {
            if (field.getAnnotation(Automocked.class) == null) {
                continue;
            }
            log.debug("Performing automocking for the field: {}", field.getName());

            makeAccessible(field);
            setField(
                    field,
                    testContext.getTestInstance(),
                    createBeanWithMocks(findConstructorToAutomock(field.getType()), beanFactory)
            );
        }
    }

    private void setByNameCandidateResolver(final DefaultListableBeanFactory beanFactory) {
        if ((beanFactory.getAutowireCandidateResolver() instanceof AutomockedBeanByNameAutowireCandidateResolver)) {
            return;
        }
        beanFactory.setAutowireCandidateResolver(
                new AutomockedBeanByNameAutowireCandidateResolver(beanFactory.getAutowireCandidateResolver())
        );
    }

    private Constructor<?> findConstructorToAutomock(final Class<?> clazz) {
        log.debug("Looking for suitable constructor of {}", clazz.getCanonicalName());

        Constructor<?> fallBackConstructor = clazz.getDeclaredConstructors()[0];
        for (val constructor : clazz.getDeclaredConstructors()) {
            if (constructor.getParameterTypes().length > fallBackConstructor.getParameterTypes().length) {
                fallBackConstructor = constructor;
            }

            val autowired = getAnnotation(constructor, Autowired.class);
            if (autowired != null) {
                return constructor;
            }
        }

        return fallBackConstructor;
    }

    private <T> T createBeanWithMocks(final Constructor<T> constructor, final DefaultListableBeanFactory beanFactory) {
        createMocksForParameters(constructor, beanFactory);
        val clazz = constructor.getDeclaringClass();
        val beanName = forClass(clazz).toString();
        log.debug("Creating bean {}", beanName);

        if (!beanFactory.containsBean(beanName)) {
            val bean = beanFactory.createBean(clazz);
            beanFactory.registerSingleton(beanName, bean);
        }

        return beanFactory.getBean(beanName, clazz);
    }

    private <T> void createMocksForParameters(final Constructor<T> constructor, final DefaultListableBeanFactory beanFactory) {
        log.debug("{} is going to be used for auto mocking", constructor);

        val constructorArgsAmount = constructor.getParameterTypes().length;

        for (int i = 0; i < constructorArgsAmount; i++) {
            val parameterType = forConstructorParameter(constructor, i);
            val beanName = parameterType.toString();

            if (!beanFactory.containsBean(beanName)) {
                beanFactory.registerSingleton(
                        beanName,
                        mock(parameterType.resolve(), withSettings().stubOnly())
                );
            }

            log.debug("Mocked {}", beanName);
        }
    }
}
```

What does it do? First, let’s have a look at the prepareTestInstance() method. It finds all fields annotated with @Automocked:

```java
for (val field : testContext.getTestClass().getDeclaredFields()) {
    if (field.getAnnotation(Automocked.class) == null) {
        continue;
    }
    // ...
}
```

Next, it makes each field it has found accessible (writable):

```java
makeAccessible(field);
```
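The field handling can be illustrated with plain Java reflection. This is a simplified sketch, not the listener itself: the Automocked declaration below is an assumption (it needs RUNTIME retention and a FIELD target, otherwise getAnnotation() would not see it at test time), Field.setAccessible()/Field.set() stand in for Spring’s ReflectionUtils.makeAccessible()/setField(), and FakeTest plus the beanFor function are hypothetical.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.function.Function;

// Assumed declaration: RUNTIME retention is required for reflection to find it.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface Automocked {}

// Hypothetical test class with one automocked field and one plain field.
class FakeTest {
    @Automocked
    private String controller;
    private String untouched;

    String controller() { return controller; }
    String untouched() { return untouched; }
}

public class AutomockedFieldScanDemo {
    // Sets every @Automocked field of the instance to the value produced by beanFor; returns how many were set.
    static int injectAutomocked(Object testInstance, Function<Class<?>, Object> beanFor) {
        int injected = 0;
        for (Field field : testInstance.getClass().getDeclaredFields()) {
            if (field.getAnnotation(Automocked.class) == null) {
                continue; // not annotated: leave the field alone
            }
            field.setAccessible(true); // plays the role of makeAccessible(field)
            try {
                field.set(testInstance, beanFor.apply(field.getType()));
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Could not set field " + field.getName(), e);
            }
            injected++;
        }
        return injected;
    }

    public static void main(String[] args) {
        FakeTest test = new FakeTest();
        int count = injectAutomocked(test, type -> "bean of " + type.getSimpleName());
        System.out.println(count);             // 1
        System.out.println(test.controller()); // bean of String
        System.out.println(test.untouched());  // null
    }
}
```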

After that, it uses the findConstructorToAutomock() method to find the most suitable constructor:

```java
Constructor<?> fallBackConstructor = clazz.getDeclaredConstructors()[0];
for (val constructor : clazz.getDeclaredConstructors()) {
    if (constructor.getParameterTypes().length > fallBackConstructor.getParameterTypes().length) {
        fallBackConstructor = constructor;
    }

    val autowired = getAnnotation(constructor, Autowired.class);
    if (autowired != null) {
        return constructor;
    }
}

return fallBackConstructor;
```

“The most suitable” means the first constructor annotated with @Autowired or, if there is none, the constructor with the most parameters.
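To make the selection rule concrete, here is a stdlib-only sketch of the same logic. Spring’s @Autowired is replaced with a hypothetical local annotation, MarkedForInjection, so the example compiles without Spring; the sample classes are hypothetical too. Like the original, it returns the first marked constructor it encounters and otherwise the one with the most parameters.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Constructor;

// Hypothetical stand-in for Spring's @Autowired.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.CONSTRUCTOR)
@interface MarkedForInjection {}

class NoAnnotation {
    NoAnnotation() {}
    NoAnnotation(String a) {}
    NoAnnotation(String a, Integer b) {}
}

class WithAnnotation {
    WithAnnotation(String a, Integer b) {}
    @MarkedForInjection
    WithAnnotation(String a) {}
}

public class ConstructorSelectionDemo {
    static Constructor<?> findConstructorToAutomock(Class<?> clazz) {
        Constructor<?>[] constructors = clazz.getDeclaredConstructors();
        Constructor<?> fallBack = constructors[0];
        for (Constructor<?> constructor : constructors) {
            // An explicitly marked constructor wins immediately.
            if (constructor.isAnnotationPresent(MarkedForInjection.class)) {
                return constructor;
            }
            // Otherwise remember the constructor with the most parameters.
            if (constructor.getParameterCount() > fallBack.getParameterCount()) {
                fallBack = constructor;
            }
        }
        return fallBack;
    }

    public static void main(String[] args) {
        System.out.println(findConstructorToAutomock(NoAnnotation.class).getParameterCount());   // 2
        System.out.println(findConstructorToAutomock(WithAnnotation.class).getParameterCount()); // 1
    }
}
```

Note that getDeclaredConstructors() returns constructors in no guaranteed order, so if two constructors tie for the most parameters, which one is picked is not predictable.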

The constructor that was found is then passed to the createBeanWithMocks() method, which in turn calls createMocksForParameters(). That method creates mocks for the constructor’s arguments and registers them in the context:

```java
val constructorArgsAmount = constructor.getParameterTypes().length;

for (int i = 0; i < constructorArgsAmount; i++) {
    val parameterType = forConstructorParameter(constructor, i);
    val beanName = parameterType.toString();

    if (!beanFactory.containsBean(beanName)) {
        beanFactory.registerSingleton(
                beanName,
                mock(parameterType.resolve(), withSettings().stubOnly())
        );
    }
}
```

It is important to notice that the string representation of the argument’s type (including its generic parameters) is used as the bean name. So, for an argument of type Function<String, String>, the bean name will be java.util.function.Function<java.lang.String, java.lang.String>. Keep that in mind: it is important, and I’ll get back to it later.
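The listener builds these names with Spring’s ResolvableType. For a parameterized parameter type, the JDK’s own reflection API produces the same kind of string, which makes it easy to see what the bean names look like and why the generic information matters. This sketch uses Constructor.getGenericParameterTypes() and Type.getTypeName() instead of ResolvableType; ZodiacSign and HoroscopeController are simplified hypothetical stand-ins for the article’s classes, and for more exotic types (nested or inner classes) the two APIs may format names differently.

```java
import java.lang.reflect.Constructor;
import java.lang.reflect.Type;
import java.util.function.Function;

enum ZodiacSign { ARIES, TAURUS }

// Simplified stand-in for the controller: only the constructor signature matters here.
class HoroscopeController {
    HoroscopeController(Function<String, ZodiacSign> zodiacSignConverter,
                        Function<String, String> nameNormalizer) {}
}

public class GenericBeanNamesDemo {
    public static void main(String[] args) {
        Constructor<?> constructor = HoroscopeController.class.getDeclaredConstructors()[0];
        Class<?>[] rawTypes = constructor.getParameterTypes();
        Type[] genericTypes = constructor.getGenericParameterTypes();
        for (int i = 0; i < genericTypes.length; i++) {
            // Erased type: identical for both parameters, so it cannot tell them apart.
            System.out.println("raw:     " + rawTypes[i].getName());
            // Full generic type: differs per parameter, so it can serve as a unique bean name.
            System.out.println("generic: " + genericTypes[i].getTypeName());
        }
    }
}
```

Both raw names print as java.util.function.Function, while the generic name of the second parameter is java.util.function.Function<java.lang.String, java.lang.String>, matching the bean name format described above.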

After all the mocks are created and registered in the context, the listener moves on to instantiating the dependent bean (in our case, the controller):

```java
if (!beanFactory.containsBean(beanName)) {
    val bean = beanFactory.createBean(clazz);
    beanFactory.registerSingleton(beanName, bean);
}
```

Another thing I should draw your attention to is that we used 1900 as the order value. The listener has to be called after the context is cleaned up by DirtiesContextBeforeModesTestExecutionListener (order=1500), but before dependency injection is performed by DependencyInjectionTestExecutionListener (order=2000), because our listener creates new beans that may need to be injected.
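The effect of these order values can be sketched without Spring. OrderedListener below is a hypothetical simplification that combines the roles of TestExecutionListener and Ordered; the numbers are the ones quoted above. Sorting by getOrder() in ascending order (a lower value means higher priority, as with Spring’s Ordered) places the automock listener between the two built-in ones.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical simplification of TestExecutionListener + Ordered: a lower order value runs earlier.
interface OrderedListener {
    String name();
    int getOrder();
}

final class SimpleListener implements OrderedListener {
    private final String name;
    private final int order;

    SimpleListener(String name, int order) {
        this.name = name;
        this.order = order;
    }

    @Override public String name() { return name; }
    @Override public int getOrder() { return order; }
}

public class ListenerOrderDemo {
    // Returns the listener names sorted by ascending order value.
    static List<String> sortedNames(List<OrderedListener> listeners) {
        List<OrderedListener> copy = new ArrayList<>(listeners);
        copy.sort(Comparator.comparingInt(OrderedListener::getOrder));
        List<String> names = new ArrayList<>();
        for (OrderedListener listener : copy) {
            names.add(listener.name());
        }
        return names;
    }

    public static void main(String[] args) {
        List<OrderedListener> listeners = List.of(
                new SimpleListener("DependencyInjectionTestExecutionListener", 2000),
                new SimpleListener("AutomockTestExecutionListener", 1900),
                new SimpleListener("DirtiesContextBeforeModesTestExecutionListener", 1500)
        );
        // Prints the 1500, 1900 and 2000 listeners, in that order.
        sortedNames(listeners).forEach(System.out::println);
    }
}
```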

AutowireCandidateResolver

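AutowireCandidateResolver is used to check whether a BeanDefinition satisfies a dependency description. Spring has a few implementations “out of the box”:

  • SimpleAutowireCandidateResolver;
  • GenericTypeAwareAutowireCandidateResolver;
  • QualifierAnnotationAutowireCandidateResolver;
  • ContextAnnotationAutowireCandidateResolver.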

The “out of the box” implementations form a single chain of inheritance, each one extending the previous one. Instead of extending that chain further, we will write a decorator that wraps whichever resolver the bean factory already uses, because that is more flexible.

The resolver is used as follows:

  • Spring takes a dependency descriptor (DependencyDescriptor) for the dependency being injected;
  • then it finds all the BeanDefinitions of a suitable type;
  • it iterates through those definitions, calling the resolver’s isAutowireCandidate() method for each of them;
  • the resolver returns true or false depending on whether the bean definition satisfies the dependency descriptor.
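The steps above can be modelled in a few lines of plain Java. This is a heavily simplified, hypothetical model, not Spring’s code: a map stands in for the registered beans, CandidateResolver stands in for AutowireCandidateResolver, and the resolver used here accepts every bean of the right raw type, because a mock’s runtime class carries no generic parameters to check. With two Function beans registered, both pass, which is the kind of ambiguity that stops Spring from choosing one.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in for AutowireCandidateResolver.
interface CandidateResolver {
    boolean isAutowireCandidate(String beanName, Object bean, Class<?> dependencyType);
}

public class CandidateLookupDemo {
    // Models a resolver that cannot see generic parameters: every bean of the right raw type passes.
    static final CandidateResolver GENERICS_BLIND = (beanName, bean, dependencyType) -> true;

    static List<String> findCandidates(Map<String, Object> beans, Class<?> dependencyType, CandidateResolver resolver) {
        List<String> matches = new ArrayList<>();
        for (Map.Entry<String, Object> entry : beans.entrySet()) {
            // Only beans of a suitable (raw) type are considered.
            if (!dependencyType.isInstance(entry.getValue())) {
                continue;
            }
            // The resolver answers true or false for each remaining candidate.
            if (resolver.isAutowireCandidate(entry.getKey(), entry.getValue(), dependencyType)) {
                matches.add(entry.getKey());
            }
        }
        return matches;
    }

    public static void main(String[] args) {
        Map<String, Object> beans = new LinkedHashMap<>();
        beans.put("nameNormalizer", (Function<String, String>) String::trim);
        beans.put("lengthCalculator", (Function<String, Integer>) String::length);
        beans.put("greeting", "hello");

        // A dependency on Function: at this level the generic parameters are invisible.
        System.out.println(findCandidates(beans, Function.class, GENERICS_BLIND)); // [nameNormalizer, lengthCalculator]
    }
}
```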

But why would we need to implement a resolver ourselves?

Let’s try to figure that out by looking at the controller:

```java
public class PersonalizedHoroscopeTellController {

    private final HoroscopeTeller horoscopeTeller;
    private final Function<String, ZodiacSign> zodiacSignConverter;
    private final Function<String, String> nameNormalizer;

    public PersonalizedHoroscopeTellController(
            final HoroscopeTeller horoscopeTeller,
            final Function<String, ZodiacSign> zodiacSignConverter,
            final Function<String, String> nameNormalizer
    ) {
        this.horoscopeTeller = horoscopeTeller;
        this.zodiacSignConverter = zodiacSignConverter;
        this.nameNormalizer = nameNormalizer;
    }
    // ...
}
```

As you can see, it has 2 dependencies of the same type, Function, but with different generic parameters: String and ZodiacSign in one case, String and String in the other. The problem is that Mockito can’t create a mock that carries generic type information: a mock is created from the raw class (here, Function), so at runtime both mocks look like a plain Function. In other words, if we create mocks for these dependencies, Spring won’t be able to inject them, because there is no generic parameter information to tell them apart. We will see an exception saying that there is more than one bean of type Function in the context. This is exactly the problem our own resolver will solve. Remember that our listener uses the dependency type, generics included, as the bean name, so all that’s left is to teach Spring to compare the dependency type with the bean name.

AutomockedBeanByNameAutowireCandidateResolver

So, the resolver will do exactly what I described above, and its isAutowireCandidate() method implementation will look like this:

AutowireCandidateResolver.isAutowireCandidate()
```java
@Override
public boolean isAutowireCandidate(BeanDefinitionHolder beanDefinitionHolder, DependencyDescriptor descriptor) {
    // Full generic type of the dependency, e.g. java.util.function.Function<java.lang.String, java.lang.String>
    val dependencyTypeName = descriptor.getResolvableType().toString();
    // Automocked beans are registered under their full generic type name
    val candidateTypeName = beanDefinitionHolder.getBeanName();
    val candidateBeanDefinition = beanDefinitionHolder.getBeanDefinition();

    // getBeanClass() never returns null (it throws instead), so hasBeanClass() is the safe check
    if (candidateTypeName.equals(dependencyTypeName)
            && candidateBeanDefinition instanceof AbstractBeanDefinition
            && ((AbstractBeanDefinition) candidateBeanDefinition).hasBeanClass()) {
        return true;
    }

    // Everything else is decided by the wrapped (original) resolver
    return candidateResolver.isAutowireCandidate(beanDefinitionHolder, descriptor);
}
```

It gets the dependency type from the dependency descriptor and the bean name (which holds the string representation of the bean type) from the bean definition holder, then compares them. If they match, it returns true. If they don’t, it delegates the call to the wrapped resolver. Pretty simple, isn’t it?
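The decorator idea can be sketched with a similarly simplified model; all types here are hypothetical stand-ins, not Spring’s API. ByNameCandidateCheck first accepts a bean whose name equals the dependency’s full generic type name and otherwise delegates to the wrapped check, just as the resolver above delegates to candidateResolver. As an assumption made to keep the example small, the wrapped check is modelled as rejecting candidates whose generic types it cannot verify, so only the by-name match remains.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Simplified dependency: raw type plus full generic type name (the format used for the bean names).
final class TypedDependency {
    final Class<?> rawType;
    final String genericTypeName;

    TypedDependency(Class<?> rawType, String genericTypeName) {
        this.rawType = rawType;
        this.genericTypeName = genericTypeName;
    }
}

// Hypothetical stand-in for AutowireCandidateResolver.
interface CandidateCheck {
    boolean isAutowireCandidate(String beanName, Object bean, TypedDependency dependency);
}

// Decorator: match by bean name first, otherwise ask the wrapped check.
final class ByNameCandidateCheck implements CandidateCheck {
    private final CandidateCheck delegate;

    ByNameCandidateCheck(CandidateCheck delegate) {
        this.delegate = delegate;
    }

    @Override
    public boolean isAutowireCandidate(String beanName, Object bean, TypedDependency dependency) {
        if (beanName.equals(dependency.genericTypeName)) {
            return true;
        }
        return delegate.isAutowireCandidate(beanName, bean, dependency);
    }
}

public class ByNameResolverDemo {
    // Assumption for this model: the wrapped check cannot verify the mocks' generics, so it rejects them.
    static final CandidateCheck CANNOT_VERIFY_GENERICS = (beanName, bean, dependency) -> false;

    static List<String> findCandidates(Map<String, Object> beans, TypedDependency dependency, CandidateCheck check) {
        List<String> matches = new ArrayList<>();
        for (Map.Entry<String, Object> entry : beans.entrySet()) {
            boolean rightRawType = dependency.rawType.isInstance(entry.getValue());
            if (rightRawType && check.isAutowireCandidate(entry.getKey(), entry.getValue(), dependency)) {
                matches.add(entry.getKey());
            }
        }
        return matches;
    }

    public static void main(String[] args) {
        String stringToString = "java.util.function.Function<java.lang.String, java.lang.String>";
        String stringToInteger = "java.util.function.Function<java.lang.String, java.lang.Integer>";
        Map<String, Object> beans = new LinkedHashMap<>();
        beans.put(stringToString, (Function<String, String>) String::trim);
        beans.put(stringToInteger, (Function<String, Integer>) String::length);

        CandidateCheck resolver = new ByNameCandidateCheck(CANNOT_VERIFY_GENERICS);
        TypedDependency wanted = new TypedDependency(Function.class, stringToString);
        System.out.println(findCandidates(beans, wanted, resolver)); // [java.util.function.Function<java.lang.String, java.lang.String>]
    }
}
```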

Summary

As with all things in development, there is always more than one way of doing something, and dependency injection customization in Spring is no exception. This time I showed you how to do it in tests with mock beans. The options we’ve seen are:

  • Java config – an imperative way, pretty verbose and with some boilerplate, though it’s the most straightforward and simple one;
  • @MockBean – a declarative way, less verbose than Java config, but still with a bit of boilerplate in the form of redundant annotated fields;
  • @Automocked + a custom resolver – the least verbose option for tests, though its uses are fairly limited, and it has to be written at least once. Still, it can come in handy when you just want to make sure that Spring creates proxies the way you expect.

There is no single best way to do it, but there is definitely at least one way that suits your use case. Choose wisely and enjoy coding!


The source code is available in my GitHub repo: https://github.com/monosoul/spring-di-customization

P.S.

This article is a translation of the same-themed article I wrote in Russian back in 2019. If you speak Russian, you might prefer the original version of it.
