Use dictionaries for storing context members, channels, and roles

pull/1003/head
Tyrrrz 2 years ago
parent a7f4fe0643
commit 2a81abb1a6

@ -5,7 +5,6 @@ using System.Threading;
using System.Threading.Tasks;
using DiscordChatExporter.Core.Discord;
using DiscordChatExporter.Core.Discord.Data;
using DiscordChatExporter.Core.Discord.Data.Common;
using DiscordChatExporter.Core.Exceptions;
using DiscordChatExporter.Core.Utils.Extensions;
using Gress;
@ -24,11 +23,16 @@ public class ChannelExporter
CancellationToken cancellationToken = default)
{
// Build context
var contextMembers = new HashSet<Member>(IdBasedEqualityComparer.Instance);
var contextChannels = await _discord.GetGuildChannelsAsync(request.Guild.Id, cancellationToken);
var contextRoles = await _discord.GetGuildRolesAsync(request.Guild.Id, cancellationToken);
var contextMembers = new Dictionary<Snowflake, Member>();
var contextChannels = (await _discord.GetGuildChannelsAsync(request.Guild.Id, cancellationToken))
.ToDictionary(c => c.Id);
var contextRoles = (await _discord.GetGuildRolesAsync(request.Guild.Id, cancellationToken))
.ToDictionary(r => r.Id);
var context = new ExportContext(
_discord,
request,
contextMembers,
contextChannels,
@ -38,9 +42,6 @@ public class ChannelExporter
// Export messages
await using var messageExporter = new MessageExporter(context);
var exportedAnything = false;
var encounteredUsers = new HashSet<User>(IdBasedEqualityComparer.Instance);
await foreach (var message in _discord.GetMessagesAsync(
request.Channel.Id,
request.After,
@ -48,16 +49,10 @@ public class ChannelExporter
progress,
cancellationToken))
{
cancellationToken.ThrowIfCancellationRequested();
// Skip messages that fail to pass the supplied filter
if (!request.MessageFilter.IsMatch(message))
continue;
// Resolve members for referenced users
foreach (var referencedUser in message.MentionedUsers.Prepend(message.Author))
{
if (!encounteredUsers.Add(referencedUser))
if (contextMembers.ContainsKey(referencedUser.Id))
continue;
var member = await _discord.GetGuildMemberAsync(
@ -66,16 +61,16 @@ public class ChannelExporter
cancellationToken
);
contextMembers.Add(member);
contextMembers[member.Id] = member;
}
// Export message
await messageExporter.ExportMessageAsync(message, cancellationToken);
exportedAnything = true;
// Export the message
if (request.MessageFilter.IsMatch(message))
await messageExporter.ExportMessageAsync(message, cancellationToken);
}
// Throw if no messages were exported
if (!exportedAnything)
if (messageExporter.MessagesExported <= 0)
throw DiscordChatExporterException.ChannelIsEmpty();
}
}

@ -12,31 +12,17 @@ using DiscordChatExporter.Core.Utils.Extensions;
namespace DiscordChatExporter.Core.Exporting;
internal class ExportContext
internal record ExportContext(
DiscordClient Discord,
ExportRequest Request,
IReadOnlyDictionary<Snowflake, Member> Members,
IReadOnlyDictionary<Snowflake, Channel> Channels,
IReadOnlyDictionary<Snowflake, Role> Roles)
{
private readonly ExportAssetDownloader _assetDownloader;
public ExportRequest Request { get; }
public IReadOnlyCollection<Member> Members { get; }
public IReadOnlyCollection<Channel> Channels { get; }
public IReadOnlyCollection<Role> Roles { get; }
public ExportContext(
ExportRequest request,
IReadOnlyCollection<Member> members,
IReadOnlyCollection<Channel> channels,
IReadOnlyCollection<Role> roles)
{
Request = request;
Members = members;
Channels = channels;
Roles = roles;
_assetDownloader = new ExportAssetDownloader(request.OutputAssetsDirPath, request.ShouldReuseAssets);
}
private readonly ExportAssetDownloader _assetDownloader = new(
Request.OutputAssetsDirPath,
Request.ShouldReuseAssets
);
public string FormatDate(DateTimeOffset instant) => Request.DateFormat switch
{
@ -45,18 +31,23 @@ internal class ExportContext
var format => instant.ToLocalString(format)
};
public Member? TryGetMember(Snowflake id) => Members.FirstOrDefault(m => m.Id == id);
public Member? TryGetMember(Snowflake id) => Members.GetValueOrDefault(id);
public Channel? TryGetChannel(Snowflake id) => Channels.FirstOrDefault(c => c.Id == id);
public Channel? TryGetChannel(Snowflake id) => Channels.GetValueOrDefault(id);
public Role? TryGetRole(Snowflake id) => Roles.FirstOrDefault(r => r.Id == id);
public Role? TryGetRole(Snowflake id) => Roles.GetValueOrDefault(id);
public Color? TryGetUserColor(Snowflake id)
{
var member = TryGetMember(id);
var roles = member?.RoleIds.Join(Roles, i => i, r => r.Id, (_, role) => role);
return roles?
var memberRoles = member?
.RoleIds
.Select(TryGetRole)
.WhereNotNull()
.ToArray();
return memberRoles?
.Where(r => r.Color is not null)
.OrderByDescending(r => r.Position)
.Select(r => r.Color)

@ -13,6 +13,8 @@ internal partial class MessageExporter : IAsyncDisposable
private int _partitionIndex;
private MessageWriter? _writer;
public long MessagesExported { get; private set; }
public MessageExporter(ExportContext context)
{
_context = context;
@ -62,6 +64,7 @@ internal partial class MessageExporter : IAsyncDisposable
{
var writer = await GetWriterAsync(cancellationToken);
await writer.WriteMessageAsync(message, cancellationToken);
MessagesExported++;
}
public async ValueTask DisposeAsync() => await ResetWriterAsync();

@ -15,4 +15,13 @@ public static class CollectionExtensions
foreach (var o in source)
yield return (o, i++);
}
public static IEnumerable<T> WhereNotNull<T>(this IEnumerable<T?> source) where T : class
{
foreach (var o in source)
{
if (o is not null)
yield return o;
}
}
}
Loading…
Cancel
Save