diff --git a/src/Microsoft.ComponentDetection.Common/Telemetry/Records/BcdeExecutionTelemetryRecord.cs b/src/Microsoft.ComponentDetection.Common/Telemetry/Records/BcdeExecutionTelemetryRecord.cs index 8993e29e6..dc1da8df4 100644 --- a/src/Microsoft.ComponentDetection.Common/Telemetry/Records/BcdeExecutionTelemetryRecord.cs +++ b/src/Microsoft.ComponentDetection.Common/Telemetry/Records/BcdeExecutionTelemetryRecord.cs @@ -2,6 +2,9 @@ namespace Microsoft.ComponentDetection.Common.Telemetry.Records { + using System.Threading; + using System.Threading.Tasks; + public class BcdeExecutionTelemetryRecord : BaseDetectionTelemetryRecord { public override string RecordName => "BcdeExecution"; @@ -22,13 +25,16 @@ public class BcdeExecutionTelemetryRecord : BaseDetectionTelemetryRecord public string AgentOSDescription { get; set; } - public static TReturn Track(Func functionToTrack, bool terminalRecord = false) + public static async Task TrackAsync( + Func> functionToTrack, + bool terminalRecord = false, + CancellationToken cancellationToken = default) { using var record = new BcdeExecutionTelemetryRecord(); try { - return functionToTrack(record); + return await functionToTrack(record, cancellationToken); } catch (Exception ex) { diff --git a/src/Microsoft.ComponentDetection.Orchestrator/Orchestrator.cs b/src/Microsoft.ComponentDetection.Orchestrator/Orchestrator.cs index 06104883d..1b8fe88d4 100644 --- a/src/Microsoft.ComponentDetection.Orchestrator/Orchestrator.cs +++ b/src/Microsoft.ComponentDetection.Orchestrator/Orchestrator.cs @@ -27,7 +27,7 @@ public class Orchestrator { private static readonly bool IsLinux = RuntimeInformation.IsOSPlatform(OSPlatform.Linux); - public ScanResult Load(string[] args) + public async Task LoadAsync(string[] args, CancellationToken cancellationToken = default) { var argumentHelper = new ArgumentHelper { ArgumentSets = new[] { new BaseArguments() } }; BaseArguments baseArguments = null; @@ -65,19 +65,20 @@ public ScanResult Load(string[] args) var shouldFailureBeSuppressed = false; // Don't use the using pattern here so we can take care not to clobber the stack - var returnResult = BcdeExecutionTelemetryRecord.Track( - (record) => - { - var executionResult = this.HandleCommand(args, record); - if (executionResult.ResultCode == ProcessingResultCode.PartialSuccess) + var returnResult = await BcdeExecutionTelemetryRecord.TrackAsync( + async (record, ct) => { - shouldFailureBeSuppressed = true; - record.HiddenExitCode = (int)executionResult.ResultCode; - } + var executionResult = await this.HandleCommandAsync(args, record, ct); + if (executionResult.ResultCode == ProcessingResultCode.PartialSuccess) + { + shouldFailureBeSuppressed = true; + record.HiddenExitCode = (int)executionResult.ResultCode; + } - return executionResult; - }, - true); + return executionResult; + }, + true, + cancellationToken); // The order of these things is a little weird, but done this way mostly to prevent any of the logic inside if blocks from being duplicated if (shouldFailureBeSuppressed) @@ -126,35 +127,38 @@ private static void AddAssembliesWithType(Assembly assembly, ContainerConfigu [Import] private static IArgumentHelper ArgumentHelper { get; set; } - public ScanResult HandleCommand(string[] args, BcdeExecutionTelemetryRecord telemetryRecord) + public async Task HandleCommandAsync( + string[] args, + BcdeExecutionTelemetryRecord telemetryRecord, + CancellationToken cancellationToken = default) { var scanResult = new ScanResult() { ResultCode = ProcessingResultCode.Error, }; - var parsedArguments = ArgumentHelper.ParseArguments(args) - .WithParsed(argumentSet => - { - CommandLineArgumentsExporter.ArgumentsForDelayedInjection = argumentSet; + var parsedArguments = ArgumentHelper.ParseArguments(args); + await parsedArguments.WithParsedAsync(async argumentSet => + { + CommandLineArgumentsExporter.ArgumentsForDelayedInjection = argumentSet; - // Don't set production telemetry if we are running the build task in DevFabric. 0.36.0 is set in the task.json for the build task in development, but is calculated during deployment for production. - TelemetryConstants.CorrelationId = argumentSet.CorrelationId; - telemetryRecord.Command = this.GetVerb(argumentSet); + // Don't set production telemetry if we are running the build task in DevFabric. 0.36.0 is set in the task.json for the build task in development, but is calculated during deployment for production. + TelemetryConstants.CorrelationId = argumentSet.CorrelationId; + telemetryRecord.Command = this.GetVerb(argumentSet); - scanResult = this.SafelyExecute(telemetryRecord, () => - { - this.GenerateEnvironmentSpecificTelemetry(telemetryRecord); + scanResult = await this.SafelyExecuteAsync(telemetryRecord, async () => + { + await this.GenerateEnvironmentSpecificTelemetryAsync(telemetryRecord); - telemetryRecord.Arguments = JsonConvert.SerializeObject(argumentSet); - FileWritingService.Init(argumentSet.Output); - Logger.Init(argumentSet.Verbosity); - Logger.LogInfo($"Run correlation id: {TelemetryConstants.CorrelationId.ToString()}"); + telemetryRecord.Arguments = JsonConvert.SerializeObject(argumentSet); + FileWritingService.Init(argumentSet.Output); + Logger.Init(argumentSet.Verbosity); + Logger.LogInfo($"Run correlation id: {TelemetryConstants.CorrelationId.ToString()}"); - return this.Dispatch(argumentSet, CancellationToken.None).GetAwaiter().GetResult(); - }); - }) - .WithNotParsed(errors => + return await this.Dispatch(argumentSet, cancellationToken); + }); + }); + parsedArguments.WithNotParsed(errors => { if (errors.Any(e => e is HelpVerbRequestedError)) { @@ -174,7 +178,7 @@ public ScanResult HandleCommand(string[] args, BcdeExecutionTelemetryRecord tele return scanResult; } - private void GenerateEnvironmentSpecificTelemetry(BcdeExecutionTelemetryRecord telemetryRecord) + private async Task GenerateEnvironmentSpecificTelemetryAsync(BcdeExecutionTelemetryRecord telemetryRecord) { telemetryRecord.AgentOSDescription = RuntimeInformation.OSDescription; @@ -204,7 +208,7 @@ private void GenerateEnvironmentSpecificTelemetry(BcdeExecutionTelemetryRecord t throw new TimeoutException($"The execution did not complete in the alotted time ({taskTimeout} seconds) and has been terminated prior to completion"); } - agentOSMeaningfulDetails[LibSslDetailsKey] = getLibSslPackages.GetAwaiter().GetResult(); + agentOSMeaningfulDetails[LibSslDetailsKey] = await getLibSslPackages; } catch (Exception ex) { @@ -259,11 +263,11 @@ private async Task Dispatch(IScanArguments arguments, CancellationTo return scanResult; } - private ScanResult SafelyExecute(BcdeExecutionTelemetryRecord record, Func wrappedInvocation) + private async Task SafelyExecuteAsync(BcdeExecutionTelemetryRecord record, Func> wrappedInvocation) { try { - return wrappedInvocation(); + return await wrappedInvocation(); } catch (Exception ae) { diff --git a/src/Microsoft.ComponentDetection/Program.cs b/src/Microsoft.ComponentDetection/Program.cs index 232a0ece5..b8843a076 100644 --- a/src/Microsoft.ComponentDetection/Program.cs +++ b/src/Microsoft.ComponentDetection/Program.cs @@ -29,7 +29,7 @@ public static async Task Main(string[] args) var orchestrator = new Orchestrator.Orchestrator(); - var result = orchestrator.Load(args); + var result = await orchestrator.LoadAsync(args); var exitCode = (int)result.ResultCode; if (result.ResultCode == ProcessingResultCode.Error || result.ResultCode == ProcessingResultCode.InputError)