diff --git a/DiscordChatExporter.Cli.Tests/Infra/ExportWrapper.cs b/DiscordChatExporter.Cli.Tests/Infra/ExportWrapper.cs index 4168ac2..af7a8ed 100644 --- a/DiscordChatExporter.Cli.Tests/Infra/ExportWrapper.cs +++ b/DiscordChatExporter.Cli.Tests/Infra/ExportWrapper.cs @@ -1,9 +1,11 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; using System.Reflection; using System.Text.Json; +using System.Threading; using System.Threading.Tasks; using AngleSharp.Dom; using AngleSharp.Html.Dom; @@ -18,6 +20,8 @@ namespace DiscordChatExporter.Cli.Tests.Infra; public static class ExportWrapper { + private static readonly ConcurrentDictionary Locks = new(); + private static readonly string DirPath = Path.Combine( Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) ?? Directory.GetCurrentDirectory(), "ExportCache" @@ -38,22 +42,33 @@ public static class ExportWrapper private static async ValueTask ExportAsync(Snowflake channelId, ExportFormat format) { - var fileName = channelId.ToString() + '.' + format.GetFileExtension(); - var filePath = Path.Combine(DirPath, fileName); + // Lock separately for each channel and format + var channelLock = Locks.GetOrAdd($"{channelId}_{format}", _ => new SemaphoreSlim(1, 1)); + await channelLock.WaitAsync(); - // Perform export only if it hasn't been done before - if (!File.Exists(filePath)) + try { - await new ExportChannelsCommand + var fileName = channelId.ToString() + '.' + format.GetFileExtension(); + var filePath = Path.Combine(DirPath, fileName); + + // Perform export only if it hasn't been done before + if (!File.Exists(filePath)) { - Token = Secrets.DiscordToken, - ChannelIds = new[] { channelId }, - ExportFormat = format, - OutputPath = filePath - }.ExecuteAsync(new FakeConsole()); + await new ExportChannelsCommand + { + Token = Secrets.DiscordToken, + ChannelIds = new[] { channelId }, + ExportFormat = format, + OutputPath = filePath + }.ExecuteAsync(new FakeConsole()); + } + + return await File.ReadAllTextAsync(filePath); + } + finally + { + channelLock.Release(); } - - return await File.ReadAllTextAsync(filePath); } public static async ValueTask ExportAsHtmlAsync(Snowflake channelId) => Html.Parse(