3
3
// The .NET Foundation licenses this file to you under the MIT license.
4
4
// See the LICENSE file in the project root for full license information.
5
5
6
- using System . Collections . Concurrent ;
7
- using System . Data ;
8
- using System . Diagnostics . CodeAnalysis ;
6
+ using System . Runtime . CompilerServices ;
9
7
using Microsoft . Extensions . DependencyInjection ;
10
8
11
9
namespace Splat . Microsoft . Extensions . DependencyInjection ;
@@ -17,7 +15,6 @@ namespace Splat.Microsoft.Extensions.DependencyInjection;
17
15
public class MicrosoftDependencyResolver : IDependencyResolver
18
16
{
19
17
private const string ImmutableExceptionMessage = "This container has already been built and cannot be modified." ;
20
- private static readonly Type _dictionaryType = typeof ( ContractDictionary < > ) ;
21
18
private readonly object _syncLock = new ( ) ;
22
19
private IServiceCollection ? _serviceCollection ;
23
20
private bool _isImmutable ;
@@ -91,29 +88,27 @@ public virtual IEnumerable<object> GetServices(Type? serviceType, string? contra
91
88
var isNull = serviceType is null ;
92
89
serviceType ??= typeof ( NullServiceType ) ;
93
90
94
- IEnumerable < object > services ;
91
+ IEnumerable < object > services = Enumerable . Empty < object > ( ) ;
95
92
96
93
if ( contract is null || string . IsNullOrWhiteSpace ( contract ) )
97
94
{
98
95
// this is to deal with CS8613 that GetServices returns IEnumerable<object?>?
99
96
services = ServiceProvider . GetServices ( serviceType )
100
97
. Where ( a => a is not null )
101
98
. Select ( a => a ! ) ;
102
-
103
- if ( isNull )
104
- {
105
- services = services
106
- . Cast < NullServiceType > ( )
107
- . Select ( nst => nst . Factory ( ) ! ) ;
108
- }
109
99
}
110
- else
100
+ else if ( ServiceProvider is IKeyedServiceProvider serviceProvider )
101
+ {
102
+ services = serviceProvider . GetKeyedServices ( serviceType , contract )
103
+ . Where ( a => a is not null )
104
+ . Select ( a => a ! ) ;
105
+ }
106
+
107
+ if ( isNull )
111
108
{
112
- var dic = GetContractDictionary ( serviceType , false ) ;
113
- services = dic ?
114
- . GetFactories ( contract )
115
- . Select ( f => f ( ) ! )
116
- ?? Array . Empty < object > ( ) ;
109
+ services = services
110
+ . Cast < NullServiceType > ( )
111
+ . Select ( nst => nst . Factory ( ) ! ) ;
117
112
}
118
113
119
114
return services ;
@@ -142,9 +137,10 @@ public virtual void Register(Func<object?> factory, Type? serviceType, string? c
142
137
}
143
138
else
144
139
{
145
- var dic = GetContractDictionary ( serviceType , true ) ;
146
-
147
- dic ? . AddFactory ( contract , factory ) ;
140
+ _serviceCollection ? . AddKeyedTransient ( serviceType , contract , ( _ , __ ) =>
141
+ isNull
142
+ ? new NullServiceType ( factory )
143
+ : factory ( ) ! ) ;
148
144
}
149
145
150
146
// required so that it gets rebuilt if not injected externally.
@@ -166,22 +162,18 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null
166
162
{
167
163
if ( contract is null || string . IsNullOrWhiteSpace ( contract ) )
168
164
{
169
- var sd = _serviceCollection ? . LastOrDefault ( s => s . ServiceType == serviceType ) ;
165
+ var sd = _serviceCollection ? . LastOrDefault ( s => ! s . IsKeyedService && s . ServiceType == serviceType ) ;
170
166
if ( sd is not null )
171
167
{
172
168
_serviceCollection ? . Remove ( sd ) ;
173
169
}
174
170
}
175
171
else
176
172
{
177
- var dic = GetContractDictionary ( serviceType , false ) ;
178
- if ( dic is not null )
173
+ var sd = _serviceCollection ? . LastOrDefault ( sd => MatchesKeyedContract ( serviceType , contract , sd ) ) ;
174
+ if ( sd is not null )
179
175
{
180
- dic . RemoveLastFactory ( contract ) ;
181
- if ( dic . IsEmpty )
182
- {
183
- RemoveContractService ( serviceType ) ;
184
- }
176
+ _serviceCollection ? . Remove ( sd ) ;
185
177
}
186
178
}
187
179
@@ -196,7 +188,7 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null
196
188
/// ignoring the <paramref name="serviceType"/> argument.
197
189
/// </summary>
198
190
/// <param name="serviceType">The service type to unregister.</param>
199
- /// <param name="contract">This parameter is ignored. Service will be removed from all contracts .</param>
191
+ /// <param name="contract">A optional value which will remove only an object registered with the same contract .</param>
200
192
public virtual void UnregisterAll ( Type ? serviceType , string ? contract = null )
201
193
{
202
194
if ( _isImmutable )
@@ -208,34 +200,28 @@ public virtual void UnregisterAll(Type? serviceType, string? contract = null)
208
200
209
201
lock ( _syncLock )
210
202
{
211
- switch ( contract )
203
+ if ( _serviceCollection is null )
204
+ {
205
+ // required so that it gets rebuilt if not injected externally.
206
+ _serviceProvider = null ;
207
+ return ;
208
+ }
209
+
210
+ IEnumerable < ServiceDescriptor > sds = Enumerable . Empty < ServiceDescriptor > ( ) ;
211
+
212
+ if ( contract is null || string . IsNullOrWhiteSpace ( contract ) )
212
213
{
213
- case null when _serviceCollection is not null :
214
- {
215
- var sds = _serviceCollection
216
- . Where ( s => s . ServiceType == serviceType )
217
- . ToList ( ) ;
218
-
219
- foreach ( var sd in sds )
220
- {
221
- _serviceCollection . Remove ( sd ) ;
222
- }
223
-
224
- break ;
225
- }
226
-
227
- case null :
228
- throw new ArgumentException ( "There must be a valid contract if there is no service collection." , nameof ( contract ) ) ;
229
- default :
230
- {
231
- var dic = GetContractDictionary ( serviceType , false ) ;
232
- if ( dic ? . TryRemoveContract ( contract ) == true && dic . IsEmpty )
233
- {
234
- RemoveContractService ( serviceType ) ;
235
- }
236
-
237
- break ;
238
- }
214
+ sds = _serviceCollection . Where ( s => ! s . IsKeyedService && s . ServiceType == serviceType ) ;
215
+ }
216
+ else
217
+ {
218
+ sds = _serviceCollection
219
+ . Where ( sd => MatchesKeyedContract ( serviceType , contract , sd ) ) ;
220
+ }
221
+
222
+ foreach ( var sd in sds . ToList ( ) )
223
+ {
224
+ _serviceCollection . Remove ( sd ) ;
239
225
}
240
226
241
227
// required so that it gets rebuilt if not injected externally.
@@ -255,16 +241,10 @@ public virtual bool HasRegistration(Type? serviceType, string? contract = null)
255
241
{
256
242
if ( contract is null || string . IsNullOrWhiteSpace ( contract ) )
257
243
{
258
- return _serviceCollection ? . Any ( sd => sd . ServiceType == serviceType ) == true ;
244
+ return _serviceCollection ? . Any ( sd => ! sd . IsKeyedService && sd . ServiceType == serviceType ) == true ;
259
245
}
260
246
261
- var dictionary = ( ContractDictionary ? ) _serviceCollection ? . FirstOrDefault ( sd => sd . ServiceType == GetDictionaryType ( serviceType ) ) ? . ImplementationInstance ;
262
-
263
- return dictionary switch
264
- {
265
- null => false ,
266
- _ => dictionary . GetFactories ( contract ) . Select ( f => f ( ) ) . Any ( )
267
- } ;
247
+ return _serviceCollection ? . Any ( sd => MatchesKeyedContract ( serviceType , contract , sd ) ) == true ;
268
248
}
269
249
270
250
if ( contract is null )
@@ -273,8 +253,12 @@ public virtual bool HasRegistration(Type? serviceType, string? contract = null)
273
253
return service is not null ;
274
254
}
275
255
276
- var dic = GetContractDictionary ( serviceType , false ) ;
277
- return dic ? . IsEmpty == false ;
256
+ if ( _serviceProvider is IKeyedServiceProvider keyedServiceProvider )
257
+ {
258
+ return keyedServiceProvider . GetKeyedService ( serviceType , contract ) is not null ;
259
+ }
260
+
261
+ return false ;
278
262
}
279
263
280
264
/// <inheritdoc />
@@ -292,103 +276,9 @@ protected virtual void Dispose(bool disposing)
292
276
{
293
277
}
294
278
295
- private static Type GetDictionaryType ( Type serviceType ) => _dictionaryType . MakeGenericType ( serviceType ) ;
296
-
297
- private void RemoveContractService ( Type serviceType )
298
- {
299
- var dicType = GetDictionaryType ( serviceType ) ;
300
- var sd = _serviceCollection ? . SingleOrDefault ( s => s . ServiceType == serviceType ) ;
301
-
302
- if ( sd is not null )
303
- {
304
- _serviceCollection ? . Remove ( sd ) ;
305
- }
306
- }
307
-
308
- [ SuppressMessage ( "Naming Rules" , "SA1300" , Justification = "Intentional" ) ]
309
- private ContractDictionary ? GetContractDictionary ( Type serviceType , bool createIfNotExists )
310
- {
311
- var dicType = GetDictionaryType ( serviceType ) ;
312
-
313
- if ( ServiceProvider is null )
314
- {
315
- throw new InvalidOperationException ( "The ServiceProvider is null." ) ;
316
- }
317
-
318
- if ( _isImmutable )
319
- {
320
- return ( ContractDictionary ? ) ServiceProvider . GetService ( dicType ) ;
321
- }
322
-
323
- var dic = getDictionary ( ) ;
324
- if ( createIfNotExists && dic is null )
325
- {
326
- lock ( _syncLock )
327
- {
328
- if ( createIfNotExists )
329
- {
330
- dic = ( ContractDictionary ? ) Activator . CreateInstance ( dicType ) ;
331
-
332
- if ( dic is not null )
333
- {
334
- _serviceCollection ? . AddSingleton ( dicType , dic ) ;
335
- }
336
- }
337
- }
338
- }
339
-
340
- return dic ;
341
-
342
- ContractDictionary ? getDictionary ( ) => _serviceCollection ?
343
- . Where ( sd => sd . ServiceType == dicType )
344
- . Select ( sd => sd . ImplementationInstance )
345
- . Cast < ContractDictionary > ( )
346
- . SingleOrDefault ( ) ;
347
- }
348
-
349
- private class ContractDictionary
350
- {
351
- private readonly ConcurrentDictionary < string , List < Func < object ? > > > _dictionary = new ( ) ;
352
-
353
- public bool IsEmpty => _dictionary . IsEmpty ;
354
-
355
- public bool TryRemoveContract ( string contract ) =>
356
- _dictionary . TryRemove ( contract , out var _ ) ;
357
-
358
- public Func < object ? > ? GetFactory ( string contract ) =>
359
- GetFactories ( contract )
360
- . LastOrDefault ( ) ;
361
-
362
- public IEnumerable < Func < object ? > > GetFactories ( string contract ) =>
363
- _dictionary . TryGetValue ( contract , out var collection )
364
- ? collection ?? Enumerable . Empty < Func < object ? > > ( )
365
- : Array . Empty < Func < object ? > > ( ) ;
366
-
367
- public void AddFactory ( string contract , Func < object ? > factory ) =>
368
- _dictionary . AddOrUpdate ( contract , _ => new ( ) { factory } , ( _ , list ) =>
369
- {
370
- ( list ??= [ ] ) . Add ( factory ) ;
371
- return list ;
372
- } ) ;
373
-
374
- public void RemoveLastFactory ( string contract ) =>
375
- _dictionary . AddOrUpdate ( contract , [ ] , ( _ , list ) =>
376
- {
377
- var lastIndex = list . Count - 1 ;
378
- if ( lastIndex > 0 )
379
- {
380
- list . RemoveAt ( lastIndex ) ;
381
- }
382
-
383
- // TODO if list empty remove contract entirely
384
- // need to find how to atomically update or remove
385
- // https://github.com/dotnet/corefx/issues/24246
386
- return list ;
387
- } ) ;
388
- }
389
-
390
- [ SuppressMessage ( "Design" , "CA1812: Unused class." , Justification = "Used in reflection." ) ]
391
- private sealed class ContractDictionary < T > : ContractDictionary
392
- {
393
- }
279
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
280
+ private static bool MatchesKeyedContract ( Type ? serviceType , string contract , ServiceDescriptor sd ) =>
281
+ sd . ServiceType == serviceType
282
+ && sd is { IsKeyedService : true , ServiceKey : string serviceKey }
283
+ && serviceKey == contract ;
394
284
}
0 commit comments