diff --git a/src/GitHub.Exports/Extensions/VSExtensions.cs b/src/GitHub.Exports/Extensions/VSExtensions.cs index 1fd3e7350a..b4d34e66a5 100644 --- a/src/GitHub.Exports/Extensions/VSExtensions.cs +++ b/src/GitHub.Exports/Extensions/VSExtensions.cs @@ -7,6 +7,8 @@ namespace GitHub.Extensions { public static class VSExtensions { + static IUIProvider cachedUIProvider = null; + public static T TryGetService(this IServiceProvider serviceProvider) where T : class { return serviceProvider.TryGetService(typeof(T)) as T; @@ -14,6 +16,9 @@ public static T TryGetService(this IServiceProvider serviceProvider) where T public static object TryGetService(this IServiceProvider serviceProvider, Type type) { + if (cachedUIProvider != null && type == typeof(IUIProvider)) + return cachedUIProvider; + var ui = serviceProvider as IUIProvider; if (ui != null) return ui.TryGetService(type); @@ -21,7 +26,7 @@ public static object TryGetService(this IServiceProvider serviceProvider, Type t { try { - return serviceProvider.GetService(type); + return GetServiceAndCache(serviceProvider, type, ref cachedUIProvider); } catch (Exception ex) { @@ -33,20 +38,42 @@ public static object TryGetService(this IServiceProvider serviceProvider, Type t public static T GetService(this IServiceProvider serviceProvider) { - return (T)serviceProvider.GetService(typeof(T)); + if (cachedUIProvider != null && typeof(T) == typeof(IUIProvider)) + return (T)cachedUIProvider; + + return (T)GetServiceAndCache(serviceProvider, typeof(T), ref cachedUIProvider); } public static T GetExportedValue(this IServiceProvider serviceProvider) { + if (cachedUIProvider != null && typeof(T) == typeof(IUIProvider)) + return (T)cachedUIProvider; + var ui = serviceProvider as IUIProvider; return ui != null ? ui.GetService() - : VisualStudio.Services.ComponentModel.DefaultExportProvider.GetExportedValue(); + : GetExportedValueAndCache(ref cachedUIProvider); } public static ITeamExplorerSection GetSection(this IServiceProvider serviceProvider, Guid section) { return serviceProvider?.GetService()?.GetSection(section); } + + static object GetServiceAndCache(IServiceProvider provider, Type type, ref CacheType cache) + { + var ret = provider.GetService(type); + if (type == typeof(CacheType)) + cache = (CacheType)ret; + return ret; + } + + static T GetExportedValueAndCache(ref CacheType cache) + { + var ret = VisualStudio.Services.ComponentModel.DefaultExportProvider.GetExportedValue(); + if (typeof(T) == typeof(CacheType)) + cache = (CacheType)(object)ret; + return ret; + } } }