Skip to content

Commit 8977918

Browse files
OleksandrTsvirkunChrisPulmandpvreony
authored
WIP Feature: Use Microsoft KeyedServiceProvider (#1075)
* Use KeyedServiceProvide instead of ContractDictionary for Splat.Microsoft.Extensions.DependencyInjection adapter * refactor service contract check * add another use of refactor, update xmldoc * add net8 as test framework * restore contract dictionary * remove contract dictionary, not actually needed --------- Co-authored-by: Chris Pulman <chris.pulman@yahoo.com> Co-authored-by: dpvreony <dpvreony@users.noreply.github.com>
1 parent 53c0b31 commit 8977918

File tree

2 files changed

+57
-167
lines changed

2 files changed

+57
-167
lines changed

src/Splat.Microsoft.Extensions.DependencyInjection.Tests/Splat.Microsoft.Extensions.DependencyInjection.Tests.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4-
<TargetFrameworks>net6.0</TargetFrameworks>
4+
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
55
<NoWarn>$(NoWarn);1591;CA1707;SA1633;CA2000</NoWarn>
66
<IsPackable>false</IsPackable>
77
<Nullable>enable</Nullable>

src/Splat.Microsoft.Extensions.DependencyInjection/MicrosoftDependencyResolver.cs

Lines changed: 56 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
// The .NET Foundation licenses this file to you under the MIT license.
44
// See the LICENSE file in the project root for full license information.
55

6-
using System.Collections.Concurrent;
7-
using System.Data;
8-
using System.Diagnostics.CodeAnalysis;
6+
using System.Runtime.CompilerServices;
97
using Microsoft.Extensions.DependencyInjection;
108

119
namespace Splat.Microsoft.Extensions.DependencyInjection;
@@ -17,7 +15,6 @@ namespace Splat.Microsoft.Extensions.DependencyInjection;
1715
public class MicrosoftDependencyResolver : IDependencyResolver
1816
{
1917
private const string ImmutableExceptionMessage = "This container has already been built and cannot be modified.";
20-
private static readonly Type _dictionaryType = typeof(ContractDictionary<>);
2118
private readonly object _syncLock = new();
2219
private IServiceCollection? _serviceCollection;
2320
private bool _isImmutable;
@@ -91,29 +88,27 @@ public virtual IEnumerable<object> GetServices(Type? serviceType, string? contra
9188
var isNull = serviceType is null;
9289
serviceType ??= typeof(NullServiceType);
9390

94-
IEnumerable<object> services;
91+
IEnumerable<object> services = Enumerable.Empty<object>();
9592

9693
if (contract is null || string.IsNullOrWhiteSpace(contract))
9794
{
9895
// this is to deal with CS8613 that GetServices returns IEnumerable<object?>?
9996
services = ServiceProvider.GetServices(serviceType)
10097
.Where(a => a is not null)
10198
.Select(a => a!);
102-
103-
if (isNull)
104-
{
105-
services = services
106-
.Cast<NullServiceType>()
107-
.Select(nst => nst.Factory()!);
108-
}
10999
}
110-
else
100+
else if (ServiceProvider is IKeyedServiceProvider serviceProvider)
101+
{
102+
services = serviceProvider.GetKeyedServices(serviceType, contract)
103+
.Where(a => a is not null)
104+
.Select(a => a!);
105+
}
106+
107+
if (isNull)
111108
{
112-
var dic = GetContractDictionary(serviceType, false);
113-
services = dic?
114-
.GetFactories(contract)
115-
.Select(f => f()!)
116-
?? Array.Empty<object>();
109+
services = services
110+
.Cast<NullServiceType>()
111+
.Select(nst => nst.Factory()!);
117112
}
118113

119114
return services;
@@ -142,9 +137,10 @@ public virtual void Register(Func<object?> factory, Type? serviceType, string? c
142137
}
143138
else
144139
{
145-
var dic = GetContractDictionary(serviceType, true);
146-
147-
dic?.AddFactory(contract, factory);
140+
_serviceCollection?.AddKeyedTransient(serviceType, contract, (_, __) =>
141+
isNull
142+
? new NullServiceType(factory)
143+
: factory()!);
148144
}
149145

150146
// required so that it gets rebuilt if not injected externally.
@@ -166,22 +162,18 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null
166162
{
167163
if (contract is null || string.IsNullOrWhiteSpace(contract))
168164
{
169-
var sd = _serviceCollection?.LastOrDefault(s => s.ServiceType == serviceType);
165+
var sd = _serviceCollection?.LastOrDefault(s => !s.IsKeyedService && s.ServiceType == serviceType);
170166
if (sd is not null)
171167
{
172168
_serviceCollection?.Remove(sd);
173169
}
174170
}
175171
else
176172
{
177-
var dic = GetContractDictionary(serviceType, false);
178-
if (dic is not null)
173+
var sd = _serviceCollection?.LastOrDefault(sd => MatchesKeyedContract(serviceType, contract, sd));
174+
if (sd is not null)
179175
{
180-
dic.RemoveLastFactory(contract);
181-
if (dic.IsEmpty)
182-
{
183-
RemoveContractService(serviceType);
184-
}
176+
_serviceCollection?.Remove(sd);
185177
}
186178
}
187179

@@ -196,7 +188,7 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null
196188
/// ignoring the <paramref name="serviceType"/> argument.
197189
/// </summary>
198190
/// <param name="serviceType">The service type to unregister.</param>
199-
/// <param name="contract">This parameter is ignored. Service will be removed from all contracts.</param>
191+
/// <param name="contract">A optional value which will remove only an object registered with the same contract.</param>
200192
public virtual void UnregisterAll(Type? serviceType, string? contract = null)
201193
{
202194
if (_isImmutable)
@@ -208,34 +200,28 @@ public virtual void UnregisterAll(Type? serviceType, string? contract = null)
208200

209201
lock (_syncLock)
210202
{
211-
switch (contract)
203+
if (_serviceCollection is null)
204+
{
205+
// required so that it gets rebuilt if not injected externally.
206+
_serviceProvider = null;
207+
return;
208+
}
209+
210+
IEnumerable<ServiceDescriptor> sds = Enumerable.Empty<ServiceDescriptor>();
211+
212+
if (contract is null || string.IsNullOrWhiteSpace(contract))
212213
{
213-
case null when _serviceCollection is not null:
214-
{
215-
var sds = _serviceCollection
216-
.Where(s => s.ServiceType == serviceType)
217-
.ToList();
218-
219-
foreach (var sd in sds)
220-
{
221-
_serviceCollection.Remove(sd);
222-
}
223-
224-
break;
225-
}
226-
227-
case null:
228-
throw new ArgumentException("There must be a valid contract if there is no service collection.", nameof(contract));
229-
default:
230-
{
231-
var dic = GetContractDictionary(serviceType, false);
232-
if (dic?.TryRemoveContract(contract) == true && dic.IsEmpty)
233-
{
234-
RemoveContractService(serviceType);
235-
}
236-
237-
break;
238-
}
214+
sds = _serviceCollection.Where(s => !s.IsKeyedService && s.ServiceType == serviceType);
215+
}
216+
else
217+
{
218+
sds = _serviceCollection
219+
.Where(sd => MatchesKeyedContract(serviceType, contract, sd));
220+
}
221+
222+
foreach (var sd in sds.ToList())
223+
{
224+
_serviceCollection.Remove(sd);
239225
}
240226

241227
// required so that it gets rebuilt if not injected externally.
@@ -255,16 +241,10 @@ public virtual bool HasRegistration(Type? serviceType, string? contract = null)
255241
{
256242
if (contract is null || string.IsNullOrWhiteSpace(contract))
257243
{
258-
return _serviceCollection?.Any(sd => sd.ServiceType == serviceType) == true;
244+
return _serviceCollection?.Any(sd => !sd.IsKeyedService && sd.ServiceType == serviceType) == true;
259245
}
260246

261-
var dictionary = (ContractDictionary?)_serviceCollection?.FirstOrDefault(sd => sd.ServiceType == GetDictionaryType(serviceType))?.ImplementationInstance;
262-
263-
return dictionary switch
264-
{
265-
null => false,
266-
_ => dictionary.GetFactories(contract).Select(f => f()).Any()
267-
};
247+
return _serviceCollection?.Any(sd => MatchesKeyedContract(serviceType, contract, sd)) == true;
268248
}
269249

270250
if (contract is null)
@@ -273,8 +253,12 @@ public virtual bool HasRegistration(Type? serviceType, string? contract = null)
273253
return service is not null;
274254
}
275255

276-
var dic = GetContractDictionary(serviceType, false);
277-
return dic?.IsEmpty == false;
256+
if (_serviceProvider is IKeyedServiceProvider keyedServiceProvider)
257+
{
258+
return keyedServiceProvider.GetKeyedService(serviceType, contract) is not null;
259+
}
260+
261+
return false;
278262
}
279263

280264
/// <inheritdoc />
@@ -292,103 +276,9 @@ protected virtual void Dispose(bool disposing)
292276
{
293277
}
294278

295-
private static Type GetDictionaryType(Type serviceType) => _dictionaryType.MakeGenericType(serviceType);
296-
297-
private void RemoveContractService(Type serviceType)
298-
{
299-
var dicType = GetDictionaryType(serviceType);
300-
var sd = _serviceCollection?.SingleOrDefault(s => s.ServiceType == serviceType);
301-
302-
if (sd is not null)
303-
{
304-
_serviceCollection?.Remove(sd);
305-
}
306-
}
307-
308-
[SuppressMessage("Naming Rules", "SA1300", Justification = "Intentional")]
309-
private ContractDictionary? GetContractDictionary(Type serviceType, bool createIfNotExists)
310-
{
311-
var dicType = GetDictionaryType(serviceType);
312-
313-
if (ServiceProvider is null)
314-
{
315-
throw new InvalidOperationException("The ServiceProvider is null.");
316-
}
317-
318-
if (_isImmutable)
319-
{
320-
return (ContractDictionary?)ServiceProvider.GetService(dicType);
321-
}
322-
323-
var dic = getDictionary();
324-
if (createIfNotExists && dic is null)
325-
{
326-
lock (_syncLock)
327-
{
328-
if (createIfNotExists)
329-
{
330-
dic = (ContractDictionary?)Activator.CreateInstance(dicType);
331-
332-
if (dic is not null)
333-
{
334-
_serviceCollection?.AddSingleton(dicType, dic);
335-
}
336-
}
337-
}
338-
}
339-
340-
return dic;
341-
342-
ContractDictionary? getDictionary() => _serviceCollection?
343-
.Where(sd => sd.ServiceType == dicType)
344-
.Select(sd => sd.ImplementationInstance)
345-
.Cast<ContractDictionary>()
346-
.SingleOrDefault();
347-
}
348-
349-
private class ContractDictionary
350-
{
351-
private readonly ConcurrentDictionary<string, List<Func<object?>>> _dictionary = new();
352-
353-
public bool IsEmpty => _dictionary.IsEmpty;
354-
355-
public bool TryRemoveContract(string contract) =>
356-
_dictionary.TryRemove(contract, out var _);
357-
358-
public Func<object?>? GetFactory(string contract) =>
359-
GetFactories(contract)
360-
.LastOrDefault();
361-
362-
public IEnumerable<Func<object?>> GetFactories(string contract) =>
363-
_dictionary.TryGetValue(contract, out var collection)
364-
? collection ?? Enumerable.Empty<Func<object?>>()
365-
: Array.Empty<Func<object?>>();
366-
367-
public void AddFactory(string contract, Func<object?> factory) =>
368-
_dictionary.AddOrUpdate(contract, _ => new() { factory }, (_, list) =>
369-
{
370-
(list ??= []).Add(factory);
371-
return list;
372-
});
373-
374-
public void RemoveLastFactory(string contract) =>
375-
_dictionary.AddOrUpdate(contract, [], (_, list) =>
376-
{
377-
var lastIndex = list.Count - 1;
378-
if (lastIndex > 0)
379-
{
380-
list.RemoveAt(lastIndex);
381-
}
382-
383-
// TODO if list empty remove contract entirely
384-
// need to find how to atomically update or remove
385-
// https://github.com/dotnet/corefx/issues/24246
386-
return list;
387-
});
388-
}
389-
390-
[SuppressMessage("Design", "CA1812: Unused class.", Justification = "Used in reflection.")]
391-
private sealed class ContractDictionary<T> : ContractDictionary
392-
{
393-
}
279+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
280+
private static bool MatchesKeyedContract(Type? serviceType, string contract, ServiceDescriptor sd) =>
281+
sd.ServiceType == serviceType
282+
&& sd is { IsKeyedService: true, ServiceKey: string serviceKey }
283+
&& serviceKey == contract;
394284
}

0 commit comments

Comments
 (0)