Use .NET 6's ParallelForEachAsync(...)

This commit is contained in:
Tyrrrz 2021-12-13 21:02:11 +02:00
parent b8567d384f
commit 008bb2f591
6 changed files with 87 additions and 99 deletions

View file

@ -14,7 +14,6 @@ using DiscordChatExporter.Core.Exceptions;
using DiscordChatExporter.Core.Exporting; using DiscordChatExporter.Core.Exporting;
using DiscordChatExporter.Core.Exporting.Filtering; using DiscordChatExporter.Core.Exporting.Filtering;
using DiscordChatExporter.Core.Exporting.Partitioning; using DiscordChatExporter.Core.Exporting.Partitioning;
using DiscordChatExporter.Core.Utils.Extensions;
namespace DiscordChatExporter.Cli.Commands.Base; namespace DiscordChatExporter.Cli.Commands.Base;
@ -68,36 +67,47 @@ public abstract class ExportCommandBase : TokenCommandBase
await console.Output.WriteLineAsync($"Exporting {channels.Count} channel(s)..."); await console.Output.WriteLineAsync($"Exporting {channels.Count} channel(s)...");
await console.CreateProgressTicker().StartAsync(async progressContext => await console.CreateProgressTicker().StartAsync(async progressContext =>
{ {
await channels.ParallelForEachAsync(async channel => await Parallel.ForEachAsync(
{ channels,
try new ParallelOptions
{ {
await progressContext.StartTaskAsync($"{channel.Category.Name} / {channel.Name}", async progress => MaxDegreeOfParallelism = Math.Max(1, ParallelLimit),
CancellationToken = cancellationToken
},
async (channel, innerCancellationToken) =>
{
try
{ {
var guild = await Discord.GetGuildAsync(channel.GuildId, cancellationToken); await progressContext.StartTaskAsync(
$"{channel.Category.Name} / {channel.Name}",
async progress =>
{
var guild = await Discord.GetGuildAsync(channel.GuildId, innerCancellationToken);
var request = new ExportRequest( var request = new ExportRequest(
guild, guild,
channel, channel,
OutputPath, OutputPath,
ExportFormat, ExportFormat,
After, After,
Before, Before,
PartitionLimit, PartitionLimit,
MessageFilter, MessageFilter,
ShouldDownloadMedia, ShouldDownloadMedia,
ShouldReuseMedia, ShouldReuseMedia,
DateFormat DateFormat
);
await Exporter.ExportChannelAsync(request, progress, innerCancellationToken);
}
); );
}
await Exporter.ExportChannelAsync(request, progress, cancellationToken); catch (DiscordChatExporterException ex) when (!ex.IsFatal)
}); {
errors[channel] = ex.Message;
}
} }
catch (DiscordChatExporterException ex) when (!ex.IsFatal) );
{
errors[channel] = ex.Message;
}
}, Math.Max(ParallelLimit, 1), cancellationToken);
}); });
// Print result // Print result

View file

@ -48,10 +48,9 @@ public partial record ExportRequest
Snowflake? after = null, Snowflake? after = null,
Snowflake? before = null) Snowflake? before = null)
{ {
// Formats path // Formats path
outputPath = Regex.Replace(outputPath, "%.", m => outputPath = Regex.Replace(outputPath, "%.", m =>
PathEx.EscapePath(m.Value switch PathEx.EscapeFileName(m.Value switch
{ {
"%g" => guild.Id.ToString(), "%g" => guild.Id.ToString(),
"%G" => guild.Name, "%G" => guild.Name,
@ -118,9 +117,6 @@ public partial record ExportRequest
// File extension // File extension
buffer.Append($".{format.GetFileExtension()}"); buffer.Append($".{format.GetFileExtension()}");
// Replace invalid chars return PathEx.EscapeFileName(buffer.ToString());
PathEx.EscapePath(buffer);
return buffer.ToString();
} }
} }

View file

@ -104,6 +104,6 @@ internal partial class MediaDownloader
var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(fileName); var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(fileName);
var fileExtension = Path.GetExtension(fileName); var fileExtension = Path.GetExtension(fileName);
return PathEx.EscapePath(fileNameWithoutExtension.Truncate(42) + '-' + urlHash + fileExtension); return PathEx.EscapeFileName(fileNameWithoutExtension.Truncate(42) + '-' + urlHash + fileExtension);
} }
} }

View file

@ -1,8 +1,5 @@
using System; using System.Collections.Generic;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace DiscordChatExporter.Core.Utils.Extensions; namespace DiscordChatExporter.Core.Utils.Extensions;
@ -23,29 +20,4 @@ public static class AsyncExtensions
public static ValueTaskAwaiter<IReadOnlyList<T>> GetAwaiter<T>( public static ValueTaskAwaiter<IReadOnlyList<T>> GetAwaiter<T>(
this IAsyncEnumerable<T> asyncEnumerable) => this IAsyncEnumerable<T> asyncEnumerable) =>
asyncEnumerable.AggregateAsync().GetAwaiter(); asyncEnumerable.AggregateAsync().GetAwaiter();
public static async ValueTask ParallelForEachAsync<T>(
this IEnumerable<T> source,
Func<T, ValueTask> handleAsync,
int degreeOfParallelism,
CancellationToken cancellationToken = default)
{
using var semaphore = new SemaphoreSlim(degreeOfParallelism);
await Task.WhenAll(source.Select(async item =>
{
// ReSharper disable once AccessToDisposedClosure
await semaphore.WaitAsync(cancellationToken);
try
{
await handleAsync(item);
}
finally
{
// ReSharper disable once AccessToDisposedClosure
semaphore.Release();
}
}));
}
} }

View file

@ -1,17 +1,20 @@
using System.IO; using System.Collections.Generic;
using System.IO;
using System.Text; using System.Text;
namespace DiscordChatExporter.Core.Utils; namespace DiscordChatExporter.Core.Utils;
public static class PathEx public static class PathEx
{ {
public static StringBuilder EscapePath(StringBuilder pathBuffer) private static readonly HashSet<char> InvalidFileNameChars = new(Path.GetInvalidFileNameChars());
public static string EscapeFileName(string path)
{ {
foreach (var invalidChar in Path.GetInvalidFileNameChars()) var buffer = new StringBuilder(path.Length);
pathBuffer.Replace(invalidChar, '_');
return pathBuffer; foreach (var c in path)
buffer.Append(!InvalidFileNameChars.Contains(c) ? c : '_');
return buffer.ToString();
} }
public static string EscapePath(string path) => EscapePath(new StringBuilder(path)).ToString();
} }

View file

@ -210,39 +210,46 @@ public class RootViewModel : Screen
var operations = ProgressManager.CreateOperations(dialog.Channels!.Count); var operations = ProgressManager.CreateOperations(dialog.Channels!.Count);
var successfulExportCount = 0; var successfulExportCount = 0;
await dialog.Channels.Zip(operations).ParallelForEachAsync(async tuple => await Parallel.ForEachAsync(
{ dialog.Channels.Zip(operations),
var (channel, operation) = tuple; new ParallelOptions
try
{ {
var request = new ExportRequest( MaxDegreeOfParallelism = Math.Max(1, _settingsService.ParallelLimit)
dialog.Guild!, },
channel!, async (tuple, cancellationToken) =>
dialog.OutputPath!,
dialog.SelectedFormat,
dialog.After?.Pipe(Snowflake.FromDate),
dialog.Before?.Pipe(Snowflake.FromDate),
dialog.PartitionLimit,
dialog.MessageFilter,
dialog.ShouldDownloadMedia,
_settingsService.ShouldReuseMedia,
_settingsService.DateFormat
);
await exporter.ExportChannelAsync(request, operation);
Interlocked.Increment(ref successfulExportCount);
}
catch (DiscordChatExporterException ex) when (!ex.IsFatal)
{ {
Notifications.Enqueue(ex.Message.TrimEnd('.')); var (channel, operation) = tuple;
try
{
var request = new ExportRequest(
dialog.Guild!,
channel,
dialog.OutputPath!,
dialog.SelectedFormat,
dialog.After?.Pipe(Snowflake.FromDate),
dialog.Before?.Pipe(Snowflake.FromDate),
dialog.PartitionLimit,
dialog.MessageFilter,
dialog.ShouldDownloadMedia,
_settingsService.ShouldReuseMedia,
_settingsService.DateFormat
);
await exporter.ExportChannelAsync(request, operation, cancellationToken);
Interlocked.Increment(ref successfulExportCount);
}
catch (DiscordChatExporterException ex) when (!ex.IsFatal)
{
Notifications.Enqueue(ex.Message.TrimEnd('.'));
}
finally
{
operation.Dispose();
}
} }
finally );
{
operation.Dispose();
}
}, Math.Max(1, _settingsService.ParallelLimit));
// Notify of overall completion // Notify of overall completion
if (successfulExportCount > 0) if (successfulExportCount > 0)