Refactor DataService

In preparation for #177
pull/241/head
Alexey Golub 5 years ago
parent b830014a46
commit 457e14d0b6

@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Net;
using System.Net.Http; using System.Net.Http;
using System.Net.Http.Headers; using System.Net.Http.Headers;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -9,61 +10,56 @@ using DiscordChatExporter.Core.Services.Exceptions;
using DiscordChatExporter.Core.Services.Internal; using DiscordChatExporter.Core.Services.Internal;
using Newtonsoft.Json.Linq; using Newtonsoft.Json.Linq;
using Polly; using Polly;
using Tyrrrz.Extensions;
namespace DiscordChatExporter.Core.Services namespace DiscordChatExporter.Core.Services
{ {
public partial class DataService : IDisposable public partial class DataService : IDisposable
{ {
private readonly HttpClient _httpClient = new HttpClient(); private readonly HttpClient _httpClient = new HttpClient();
private readonly IAsyncPolicy<HttpResponseMessage> _httpPolicy;
private async Task<JToken> GetApiResponseAsync(AuthToken token, string resource, string endpoint, public DataService()
params string[] parameters)
{ {
// Create retry policy _httpClient.BaseAddress = new Uri("https://discordapp.com/api/v6");
var retryPolicy = Policy
.Handle<HttpErrorStatusCodeException>(e => (int) e.StatusCode >= 500) // Discord seems to always respond 429 on our first request with unreasonable wait time (10+ minutes).
.Or<HttpErrorStatusCodeException>(e => (int) e.StatusCode == 429) // For that reason the policy will start respecting their retry-after header only after Nth failed response.
.WaitAndRetryAsync(10, _ => TimeSpan.FromSeconds(3)); _httpPolicy = Policy
.HandleResult<HttpResponseMessage>(m => m.StatusCode == HttpStatusCode.TooManyRequests)
// Send request .OrResult(m => m.StatusCode >= HttpStatusCode.InternalServerError)
return await retryPolicy.ExecuteAsync(async () => .WaitAndRetryAsync(6,
(i, result, ctx) =>
{
if (i <= 3)
return TimeSpan.FromSeconds(2 * i);
if (i <= 5)
return TimeSpan.FromSeconds(5 * i);
return result.Result.Headers.RetryAfter.Delta ?? TimeSpan.FromSeconds(10 * i);
},
(response, timespan, retryCount, context) => Task.CompletedTask);
}
private async Task<JToken> GetApiResponseAsync(AuthToken token, string route)
{
using var response = await _httpPolicy.ExecuteAsync(async () =>
{ {
// Create request using var request = new HttpRequestMessage(HttpMethod.Get, route);
const string apiRoot = "https://discordapp.com/api/v6";
using var request = new HttpRequestMessage(HttpMethod.Get, $"{apiRoot}/{resource}/{endpoint}");
// Set authorization header
request.Headers.Authorization = token.Type == AuthTokenType.Bot request.Headers.Authorization = token.Type == AuthTokenType.Bot
? new AuthenticationHeaderValue("Bot", token.Value) ? new AuthenticationHeaderValue("Bot", token.Value)
: new AuthenticationHeaderValue(token.Value); : new AuthenticationHeaderValue(token.Value);
// Add parameters return await _httpClient.SendAsync(request);
foreach (var parameter in parameters) });
{
var key = parameter.SubstringUntil("=");
var value = parameter.SubstringAfter("=");
// Skip empty values
if (string.IsNullOrWhiteSpace(value))
continue;
request.RequestUri = request.RequestUri.SetQueryParameter(key, value);
}
// Get response
using var response = await _httpClient.SendAsync(request);
// Check status code
// We throw our own exception here because default one doesn't have status code
if (!response.IsSuccessStatusCode)
throw new HttpErrorStatusCodeException(response.StatusCode, response.ReasonPhrase);
// Get content // We throw our own exception here because default one doesn't have status code
var raw = await response.Content.ReadAsStringAsync(); if (!response.IsSuccessStatusCode)
throw new HttpErrorStatusCodeException(response.StatusCode, response.ReasonPhrase);
// Parse var jsonRaw = await response.Content.ReadAsStringAsync();
return JToken.Parse(raw); return JToken.Parse(jsonRaw);
});
} }
public async Task<Guild> GetGuildAsync(AuthToken token, string guildId) public async Task<Guild> GetGuildAsync(AuthToken token, string guildId)
@ -72,7 +68,7 @@ namespace DiscordChatExporter.Core.Services
if (guildId == Guild.DirectMessages.Id) if (guildId == Guild.DirectMessages.Id)
return Guild.DirectMessages; return Guild.DirectMessages;
var response = await GetApiResponseAsync(token, "guilds", guildId); var response = await GetApiResponseAsync(token, $"guilds/{guildId}");
var guild = ParseGuild(response); var guild = ParseGuild(response);
return guild; return guild;
@ -80,23 +76,40 @@ namespace DiscordChatExporter.Core.Services
public async Task<Channel> GetChannelAsync(AuthToken token, string channelId) public async Task<Channel> GetChannelAsync(AuthToken token, string channelId)
{ {
var response = await GetApiResponseAsync(token, "channels", channelId); var response = await GetApiResponseAsync(token, $"channels/{channelId}");
var channel = ParseChannel(response); var channel = ParseChannel(response);
return channel; return channel;
} }
public async Task<IReadOnlyList<Guild>> GetUserGuildsAsync(AuthToken token) public async IAsyncEnumerable<Guild> EnumerateUserGuildsAsync(AuthToken token)
{ {
var response = await GetApiResponseAsync(token, "users", "@me/guilds", "limit=100"); var afterId = "";
var guilds = response.Select(ParseGuild).ToArray();
while (true)
{
var route = "users/@me/guilds?limit=100";
if (!string.IsNullOrWhiteSpace(afterId))
route += $"&after={afterId}";
var response = await GetApiResponseAsync(token, route);
return guilds; if (!response.HasValues)
yield break;
foreach (var guild in response.Select(ParseGuild))
{
yield return guild;
afterId = guild.Id;
}
}
} }
public Task<IReadOnlyList<Guild>> GetUserGuildsAsync(AuthToken token) => EnumerateUserGuildsAsync(token).AggregateAsync();
public async Task<IReadOnlyList<Channel>> GetDirectMessageChannelsAsync(AuthToken token) public async Task<IReadOnlyList<Channel>> GetDirectMessageChannelsAsync(AuthToken token)
{ {
var response = await GetApiResponseAsync(token, "users", "@me/channels"); var response = await GetApiResponseAsync(token, "users/@me/channels");
var channels = response.Select(ParseChannel).ToArray(); var channels = response.Select(ParseChannel).ToArray();
return channels; return channels;
@ -104,7 +117,7 @@ namespace DiscordChatExporter.Core.Services
public async Task<IReadOnlyList<Channel>> GetGuildChannelsAsync(AuthToken token, string guildId) public async Task<IReadOnlyList<Channel>> GetGuildChannelsAsync(AuthToken token, string guildId)
{ {
var response = await GetApiResponseAsync(token, "guilds", $"{guildId}/channels"); var response = await GetApiResponseAsync(token, $"guilds/{guildId}/channels");
var channels = response.Select(ParseChannel).ToArray(); var channels = response.Select(ParseChannel).ToArray();
return channels; return channels;
@ -112,36 +125,44 @@ namespace DiscordChatExporter.Core.Services
public async Task<IReadOnlyList<Role>> GetGuildRolesAsync(AuthToken token, string guildId) public async Task<IReadOnlyList<Role>> GetGuildRolesAsync(AuthToken token, string guildId)
{ {
var response = await GetApiResponseAsync(token, "guilds", $"{guildId}/roles"); var response = await GetApiResponseAsync(token, $"guilds/{guildId}/roles");
var roles = response.Select(ParseRole).ToArray(); var roles = response.Select(ParseRole).ToArray();
return roles; return roles;
} }
public async Task<IReadOnlyList<Message>> GetChannelMessagesAsync(AuthToken token, string channelId, private async Task<Message> GetLastMessageAsync(AuthToken token, string channelId, DateTimeOffset? before = null)
DateTimeOffset? after = null, DateTimeOffset? before = null, IProgress<double>? progress = null)
{ {
var result = new List<Message>(); var route = $"channels/{channelId}/messages?limit=1";
if (before != null)
route += $"&before={before.Value.ToSnowflake()}";
var response = await GetApiResponseAsync(token, route);
return response.Select(ParseMessage).FirstOrDefault();
}
public async IAsyncEnumerable<Message> EnumerateMessagesAsync(AuthToken token, string channelId,
DateTimeOffset? after = null, DateTimeOffset? before = null, IProgress<double>? progress = null)
{
// Get the last message // Get the last message
var response = await GetApiResponseAsync(token, "channels", $"{channelId}/messages", var lastMessage = await GetLastMessageAsync(token, channelId, before);
"limit=1", $"before={before?.ToSnowflake()}");
var lastMessage = response.Select(ParseMessage).FirstOrDefault();
// If the last message doesn't exist or it's outside of range - return // If the last message doesn't exist or it's outside of range - return
if (lastMessage == null || lastMessage.Timestamp < after) if (lastMessage == null || lastMessage.Timestamp < after)
{ {
progress?.Report(1); progress?.Report(1);
return result; yield break;
} }
// Get other messages // Get other messages
var firstMessage = default(Message);
var offsetId = after?.ToSnowflake() ?? "0"; var offsetId = after?.ToSnowflake() ?? "0";
while (true) while (true)
{ {
// Get message batch // Get message batch
response = await GetApiResponseAsync(token, "channels", $"{channelId}/messages", var route = $"channels/{channelId}/messages?limit=100&after={offsetId}";
"limit=100", $"after={offsetId}"); var response = await GetApiResponseAsync(token, route);
// Parse // Parse
var messages = response var messages = response
@ -158,30 +179,36 @@ namespace DiscordChatExporter.Core.Services
.TakeWhile(m => m.Id != lastMessage.Id && m.Timestamp < lastMessage.Timestamp) .TakeWhile(m => m.Id != lastMessage.Id && m.Timestamp < lastMessage.Timestamp)
.ToArray(); .ToArray();
// Add to result // Yield messages
result.AddRange(messagesInRange); foreach (var message in messagesInRange)
{
// Set first message if it's not set
firstMessage ??= message;
// Report progress (based on the time range of parsed messages compared to total)
progress?.Report((message.Timestamp - firstMessage.Timestamp).TotalSeconds /
(lastMessage.Timestamp - firstMessage.Timestamp).TotalSeconds);
yield return message;
offsetId = message.Id;
}
// Break if messages were trimmed (which means the last message was encountered) // Break if messages were trimmed (which means the last message was encountered)
if (messagesInRange.Length != messages.Length) if (messagesInRange.Length != messages.Length)
break; break;
// Report progress (based on the time range of parsed messages compared to total)
progress?.Report((result.Last().Timestamp - result.First().Timestamp).TotalSeconds /
(lastMessage.Timestamp - result.First().Timestamp).TotalSeconds);
// Move offset
offsetId = result.Last().Id;
} }
// Add last message // Yield last message
result.Add(lastMessage); yield return lastMessage;
// Report progress // Report progress
progress?.Report(1); progress?.Report(1);
return result;
} }
public Task<IReadOnlyList<Message>> GetMessagesAsync(AuthToken token, string channelId,
DateTimeOffset? after = null, DateTimeOffset? before = null, IProgress<double>? progress = null) =>
EnumerateMessagesAsync(token, channelId, after, before, progress).AggregateAsync();
public async Task<Mentionables> GetMentionablesAsync(AuthToken token, string guildId, public async Task<Mentionables> GetMentionablesAsync(AuthToken token, string guildId,
IEnumerable<Message> messages) IEnumerable<Message> messages)
{ {
@ -214,7 +241,7 @@ namespace DiscordChatExporter.Core.Services
DateTimeOffset? after = null, DateTimeOffset? before = null, IProgress<double>? progress = null) DateTimeOffset? after = null, DateTimeOffset? before = null, IProgress<double>? progress = null)
{ {
// Get messages // Get messages
var messages = await GetChannelMessagesAsync(token, channel.Id, after, before, progress); var messages = await GetMessagesAsync(token, channel.Id, after, before, progress);
// Get mentionables // Get mentionables
var mentionables = await GetMentionablesAsync(token, guild.Id, messages); var mentionables = await GetMentionablesAsync(token, guild.Id, messages);
@ -234,19 +261,6 @@ namespace DiscordChatExporter.Core.Services
return await GetChatLogAsync(token, guild, channel, after, before, progress); return await GetChatLogAsync(token, guild, channel, after, before, progress);
} }
public async Task<ChatLog> GetChatLogAsync(AuthToken token, string channelId, public void Dispose() => _httpClient.Dispose();
DateTimeOffset? after = null, DateTimeOffset? before = null, IProgress<double>? progress = null)
{
// Get channel
var channel = await GetChannelAsync(token, channelId);
// Get the chat log
return await GetChatLogAsync(token, channel, after, before, progress);
}
public void Dispose()
{
_httpClient.Dispose();
}
} }
} }

@ -1,5 +1,7 @@
using System; using System;
using System.Collections.Generic;
using System.Drawing; using System.Drawing;
using System.Threading.Tasks;
namespace DiscordChatExporter.Core.Services.Internal namespace DiscordChatExporter.Core.Services.Internal
{ {
@ -14,5 +16,15 @@ namespace DiscordChatExporter.Core.Services.Internal
} }
public static Color ResetAlpha(this Color color) => Color.FromArgb(1, color); public static Color ResetAlpha(this Color color) => Color.FromArgb(1, color);
public static async Task<IReadOnlyList<T>> AggregateAsync<T>(this IAsyncEnumerable<T> asyncEnumerable)
{
var list = new List<T>();
await foreach (var i in asyncEnumerable)
list.Add(i);
return list;
}
} }
} }
Loading…
Cancel
Save